Passed
Pull Request — master (#1)
by Konstantinos
59s
created

GraphFactory.create()   A

Complexity

Conditions 2

Size

Total Lines 46
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 46
rs 9.55
c 0
b 0
f 0
cc 2
nop 3
1
### Part of this code is due to the MatConvNet team and is used to load the parameters of the pretrained VGG19 model in the notebook ###
2
3
import os
4
import re
5
from typing import Dict, Tuple, Any, Protocol
6
7
import attr
8
from numpy.typing import NDArray
9
import numpy as np
10
import scipy.io
11
import tensorflow as tf
12
13
from .layers_getter import VggLayersGetter
14
from .image_model import LAYERS as NETWORK_DESIGN
15
16
17
class ImageSpecs(Protocol):
    """Structural type describing the image the network graph is built for.

    Any object exposing these three integer attributes satisfies it; no
    inheritance is required.
    """

    # Image width (used as the width of the network's input layer).
    width: int
    # Image height (used as the height of the network's input layer).
    height: int
    # Number of colour channels of the network's input layer.
    color_channels: int
21
22
23
def load_vgg_model_parameters(path: str) -> Dict[str, Any]:
    """Load pretrained model parameters from a MATLAB ``.mat`` file.

    Args:
        path: location of the ``.mat`` file (e.g. the MatConvNet VGG-19 model).

    Returns:
        The dict produced by ``scipy.io.loadmat``: MATLAB variable names mapped
        to their values. scipy also adds the metadata entries ``__header__``,
        ``__version__`` and ``__globals__``, so not every value is an array.
    """
    return scipy.io.loadmat(path)
25
26
27
class NoImageModelSpesifiedError(Exception):
    """Raised when the location of the pretrained image model is not configured."""
28
29
30
def get_vgg_19_model_path() -> str:
    """Return the path of the pretrained VGG-19 model file (``*.mat``).

    The path is read from the ``AA_VGG_19`` environment variable.

    Raises:
        NoImageModelSpesifiedError: if ``AA_VGG_19`` is not set.
    """
    try:
        return os.environ['AA_VGG_19']
    except KeyError as variable_not_found:
        # Adjacent string literals are concatenated, so each piece must carry
        # its own separating space.
        raise NoImageModelSpesifiedError(
            'No pretrained image model found. '
            'Please download it and set the AA_VGG_19 environment variable '
            'with the path where you stored the model (*.mat file), to '
            'indicate where to locate and load it') from variable_not_found
38
39
40
def load_default_model_parameters():
    """Load the VGG-19 parameters from the file named by ``AA_VGG_19``.

    Propagates NoImageModelSpesifiedError from get_vgg_19_model_path when the
    environment variable is not set.
    """
    return load_vgg_model_parameters(get_vgg_19_model_path())
43
44
45
def get_layers(model_parameters: Dict[str, NDArray]) -> NDArray:
    """Return the first row of the ``'layers'`` entry of the model parameters.

    NOTE(review): row 0 is taken presumably because loadmat returns MATLAB
    cell arrays as 2D ``(1, n)`` arrays — confirm against the model file.
    """
    layers_table = model_parameters['layers']
    first_row = layers_table[0]
    return first_row
47
48
49
class GraphBuilder:
    """Incrementally build a chain of TensorFlow layers.

    Each added layer is stored in ``graph`` under its layer id and becomes the
    input of the next layer added. Builder methods return ``self`` so calls
    can be chained.
    """

    def __init__(self):
        # Layer id -> layer (the input tf.Variable or a layer output tensor).
        self.graph = {}
        # Most recently added layer: the input of the next one.
        self._prev_layer = None

    def _build_layer(self, layer_id: str, layer):
        """Store ``layer`` under ``layer_id``, make it the current tip, return self."""
        self.graph[layer_id] = layer
        self._prev_layer = layer
        return self

    def _require_prev_layer(self):
        """Return the current tip, or raise ValueError if ``input`` was never called."""
        if self._prev_layer is None:
            # Fail with a clear message instead of letting TensorFlow reject None.
            raise ValueError(
                'GraphBuilder has no input layer yet; call input(...) before '
                'adding other layers')
        return self._prev_layer

    def input(self, width: int, height: int, nb_channels=3, dtype='float32', layer_id='input'):
        """Start a new graph whose first layer is a zero-initialised variable.

        Any previously built layers are dropped (``graph`` is replaced by a new
        dict). The variable has shape ``(1, height, width, nb_channels)`` and
        the given ``dtype``.
        """
        self.graph = {}
        initial_value = np.zeros((1, height, width, nb_channels))
        return self._build_layer(layer_id, tf.Variable(initial_value, dtype=dtype))

    def avg_pool(self, layer_id: str):
        """Add a 2x2 average-pooling layer (stride 2, 'SAME' padding) on the tip."""
        prev_layer = self._require_prev_layer()
        pooled = tf.nn.avg_pool(prev_layer, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        return self._build_layer(layer_id, pooled)

    def relu_conv_2d(self, layer_id: str, layer_weights):
        """Add a ReLU wrapped around a 2D convolution of the current tip.

        Args:
            layer_id: key under which the new layer is stored in ``graph``.
            layer_weights: ``(W, b)`` pair — convolution filter and bias. Both
                are embedded as ``tf.constant`` values, so they stay fixed.
                NOTE(review): ``W`` presumably has the
                (height, width, in_channels, out_channels) layout that
                ``conv2d`` expects — confirm against the model file.

        Returns:
            self, to allow chaining.
        """
        W, b = layer_weights
        # Build the convolution first so a missing input layer is reported
        # before any TensorFlow op is created.
        convolved = self._conv_2d(W, b)
        return self._build_layer(layer_id, tf.nn.relu(convolved))

    def _conv_2d(self, W: NDArray, b: NDArray):
        """Return ``conv2d(tip, W) + b`` with stride 1 and 'SAME' padding."""
        prev_layer = self._require_prev_layer()
        W = tf.constant(W)
        # Flatten the bias to 1-D so it is added along the last (channel) axis
        # of the NHWC convolution output.
        b = tf.constant(np.reshape(b, -1))
        return tf.compat.v1.nn.conv2d(prev_layer, filter=W, strides=[1, 1, 1, 1], padding='SAME') + b
83
84
@attr.s
class ModelParameters:
    """Holder for the pretrained VGG-19 parameters dict.

    When constructed without arguments, ``params`` is filled at construction
    time by calling ``load_default_model_parameters`` (which reads the file
    named by the ``AA_VGG_19`` environment variable).
    """

    params = attr.ib(factory=load_default_model_parameters)
87
88
89
class GraphFactory:
    """Create the image-generation network graph from pretrained VGG-19 weights."""

    # Single builder shared by every ``create`` call. ``GraphBuilder.input``
    # assigns a fresh dict to ``builder.graph`` on each call, so a graph
    # returned by an earlier ``create`` is not mutated by a later one.
    builder = GraphBuilder()

    # Extracts the layer type from a layer id by dropping trailing digits and
    # underscores, e.g. 'conv1_1' -> 'conv', 'avgpool3' -> 'avgpool'.
    _LAYER_TYPE_PATTERN = re.compile(r'(\w+?)[\d_]*$')

    @classmethod
    def weights(cls, layer: NDArray) -> Tuple[NDArray, NDArray]:
        """Get the weights and bias for a given layer of the VGG model.

        Args:
            layer: one entry of the array returned by ``get_layers``.

        Returns:
            The ``(W, b)`` pair stored in the layer's ``[0][0][2]`` field.
        """
        # On the raw layers array this is vgg_layers[0][layer_index][0][0][2].
        weights_and_bias = layer[0][0][2]
        return weights_and_bias[0][0], weights_and_bias[0][1]

    @classmethod
    def create(cls, config: ImageSpecs, model_parameters=None) -> Dict[str, Any]:
        """Create a model for the purpose of 'painting'/generating a picture.

        Creates a Deep Learning Neural Network with most layers having weights
        (aka model parameters) with values extracted from a pre-trained model
        (i.e. another neural network trained on an image dataset).

        The graph starts with an input variable sized from ``config``. One
        layer is then added per id in ``NETWORK_DESIGN``, in order:
        - ids with the ``'conv'`` type prefix become ReLU-wrapped convolutions
          using the pretrained weights;
        - ids with the ``'avgpool'`` prefix become average-pooling layers.

        Args:
            config: image specification; its ``width``, ``height`` and
                ``color_channels`` size the input layer.
            model_parameters: VGG-19 parameters dict, as returned by
                ``load_vgg_model_parameters``. When ``None`` (or empty), the
                parameters are loaded from the path in the ``AA_VGG_19``
                environment variable.

        Returns:
            Dict[str, Any]: maps ``'input'`` and every id in ``NETWORK_DESIGN``
            to the corresponding TensorFlow variable/tensor.

        Raises:
            NoImageModelSpesifiedError: if the parameters must be loaded by
                default and ``AA_VGG_19`` is not set.
            KeyError: if a layer id in ``NETWORK_DESIGN`` does not have a
                supported type prefix.
        """
        # Fall back to ModelParameters' default factory (which loads from disk)
        # when no usable parameters were given.
        if model_parameters:
            vgg_model_parameters = ModelParameters(model_parameters)
        else:
            vgg_model_parameters = ModelParameters()

        layer_getter = VggLayersGetter(get_layers(vgg_model_parameters.params))

        def conv_relu(layer_id: str):
            # Look the pretrained layer up by id and feed its (W, b) to the builder.
            layer_weights = cls.weights(layer_getter.id_2_layer[layer_id])
            return cls.builder.relu_conv_2d(layer_id, layer_weights)

        layer_callbacks = {
            'conv': conv_relu,
            'avgpool': cls.builder.avg_pool,
        }

        def add_layer(layer_id: str):
            match = cls._LAYER_TYPE_PATTERN.match(layer_id)
            layer_type = match.group(1) if match else None
            if layer_type not in layer_callbacks:
                raise KeyError(
                    f'Unsupported layer id {layer_id!r}: expected one of the '
                    f'prefixes {sorted(layer_callbacks)} followed by digits/underscores')
            return layer_callbacks[layer_type](layer_id)

        ## Build Graph
        # Convolution layers embed the pretrained W and b as tf.constant values
        # (see GraphBuilder._conv_2d), so they stay fixed. Average-pooling
        # layers have no weights. Only the input layer is a tf.Variable.
        cls.builder.input(config.width, config.height, nb_channels=config.color_channels)
        for layer_id in NETWORK_DESIGN:
            add_layer(layer_id)

        return cls.builder.graph
147