Passed
Push — master ( 338098...c3d045 )
by Konstantinos
02:29 queued 01:15
created

artificial_artwork.style_model.graph_factory   A

Complexity

Total Complexity 6

Size/Duplication

Total Lines 78
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 6
eloc 42
dl 0
loc 78
rs 10
c 0
b 0
f 0

4 Methods

Rating   Name   Duplication   Size   Complexity  
A LayerMaker.make_layers() 0 3 2
A LayerMaker.layer() 0 6 2
A LayerMaker.relu() 0 2 1
A GraphFactory.create() 0 26 1
1
import re
2
from typing import Dict, Protocol, Any, Iterable
3
import attr
4
from numpy.typing import NDArray
5
6
from .graph_builder import GraphBuilder
7
8
9
# Name -> numpy array mapping; presumably parameter/layer name to its
# pretrained weights (TODO confirm against the code that builds it).
ModelParameters = Dict[str, NDArray]


class ImageSpecs(Protocol):
    """Structural type for the image specification consumed by the graph.

    Any object that exposes the three integer attributes below satisfies
    this protocol; no inheritance is required.
    """

    width: int  # image width
    height: int  # image height
    color_channels: int  # number of colour channels per pixel
15
16
17
class GraphFactory:
    """Assemble the computation graph used to generate ('paint') an image.

    Layers are added through a single ``GraphBuilder`` stored on the class.
    """

    # NOTE(review): this builder is created once at class-definition time and
    # shared by every call to ``create`` (and by subclasses). Whether a second
    # call starts from a clean graph depends on ``GraphBuilder.input`` —
    # confirm before calling ``create`` more than once per process.
    builder = GraphBuilder()

    @classmethod
    def create(cls, config: ImageSpecs, model_design) -> Dict[str, Any]:
        """Create a model for the purpose of 'painting'/generating a picture.

        Creates a Deep Learning Neural Network with most layers having weights
        (aka model parameters) with values extracted from a pre-trained model
        (ie another neural network trained on an image dataset suitably).

        Args:
            config (ImageSpecs): image specification (width, height,
                color_channels); passed to ``GraphBuilder.input`` to create
                the graph's input.
            model_design: object exposing
                ``pretrained_model.reporter``, whose ``get_weights(layer_id)``
                supplies the pretrained weights for each 'conv' layer, and
                ``network_design.network_layers``, the iterable of layer ids
                (e.g. 'conv1_1', 'avgpool1') built in iteration order.

        Returns:
            Dict[str, Any]: the builder's ``graph`` after all layers have been
            added. Because the builder is class-level, this may be the same
            object across calls (see note on ``builder``).
        """
        # each relu_conv_2d uses pretrained model's layer weights for W and b matrices
        # each average pooling layer uses custom weight values
        # all weights are guaranteed to remain constant (see GraphBuilder._conv_2d method)

        cls.builder.input(config)
        LayerMaker(
            cls.builder,
            model_design.pretrained_model.reporter,
        ).make_layers(model_design.network_design.network_layers)

        return cls.builder.graph
46
47
48
@attr.s
class LayerMaker:
    """Add layers to a graph builder, dispatching on each layer id's prefix.

    A layer id is parsed with ``regex``. The captured group is the id with its
    trailing run of digits/underscores removed (e.g. 'conv1_2' -> 'conv',
    'avgpool3' -> 'avgpool'). That prefix selects the callback in
    ``layer_callbacks`` that adds the layer to ``graph_builder``.
    """

    # Builder receiving the layers; must provide ``relu_conv_2d`` and ``avg_pool``.
    graph_builder = attr.ib()
    # Source of pretrained weights; must provide ``get_weights(layer_id)``.
    reporter = attr.ib()

    # layer-type prefix -> callable taking the full layer id.
    # Built per instance (takes_self) so the callables are bound to this
    # instance and its graph_builder.
    layer_callbacks = attr.ib(init=False, default=attr.Factory(lambda self: {
            'conv': self.relu,
            'avgpool': self.graph_builder.avg_pool
        }, takes_self=True)
    )
    # Lazy ``\w+?`` keeps the shortest prefix such that the rest of the id is
    # only digits/underscores.
    regex = attr.ib(init=False, default=re.compile(r'(\w+?)[\d_]*$'))

    def relu(self, layer_id: str):
        """Add a ReLU conv2d layer using the reporter's weights for ``layer_id``.

        Returns whatever ``graph_builder.relu_conv_2d`` returns.
        """
        return self.graph_builder.relu_conv_2d(layer_id, self.reporter.get_weights(layer_id))

    def layer(self, layer_id: str):
        """Build the layer identified by ``layer_id``.

        Returns:
            The result of the selected layer callback.

        Raises:
            UnknownLayerError: if ``layer_id`` is not matched by ``regex`` or its
                parsed prefix has no entry in ``layer_callbacks``.
        """
        match_instance = self.regex.match(layer_id)
        callback = None
        if match_instance is not None:
            # .get instead of [] so that an unsupported prefix (e.g. 'pool5')
            # raises UnknownLayerError rather than a bare KeyError.
            callback = self.layer_callbacks.get(match_instance.group(1))
        if callback is None:
            raise UnknownLayerError(
                f"Failed to construct layer '{layer_id}'. Supported layers are "
                f"[{', '.join(self.layer_callbacks)}] and regex "
                f"used to parse the layer is '{self.regex.pattern}'")
        return callback(layer_id)

    def make_layers(self, layers: Iterable[str]):
        """Build every layer id in ``layers`` in iteration order."""
        for layer_id in layers:
            self.layer(layer_id)
75
76
77
class UnknownLayerError(Exception):
    """Signals that a layer id could not be turned into a graph layer."""
78