Passed
Pull Request — master (#6)
by Konstantinos
03:56
created

artificial_artwork.style_model.graph_builder   A

Complexity

Total Complexity 6

Size/Duplication

Total Lines 42
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 6
eloc 27
dl 0
loc 42
rs 10
c 0
b 0
f 0

6 Methods

Rating   Name   Duplication   Size   Complexity  
A GraphBuilder.relu_conv_2d() 0 10 1
A GraphBuilder.input() 0 4 1
A GraphBuilder._build_layer() 0 4 1
A GraphBuilder._conv_2d() 0 4 1
A GraphBuilder.avg_pool() 0 3 1
A GraphBuilder.__init__() 0 3 1
1
from typing import Tuple
2
import numpy as np
3
from numpy.typing import NDArray
4
import tensorflow as tf
5
6
7
class GraphBuilder:
    """Incrementally build a feed-forward chain of TensorFlow layers.

    Each builder method creates one layer on top of the most recently added
    one, stores it in ``self.graph`` under the given layer id and makes it the
    input of the next layer. Builder methods return ``self`` so calls can be
    chained, e.g. ``GraphBuilder().input(specs).relu_conv_2d('conv1_1', wb)``.
    """

    def __init__(self):
        # Maps layer id -> the tensor/variable produced for that layer.
        self.graph = {}
        # Most recently added layer; it is the input of the next layer built.
        self._prev_layer = None

    def _build_layer(self, layer_id: str, layer) -> 'GraphBuilder':
        """Store ``layer`` under ``layer_id`` and make it the current tip.

        Returns:
            This builder, to allow chaining.
        """
        self.graph[layer_id] = layer
        self._prev_layer = layer
        return self

    def _require_prev_layer(self):
        """Return the current tip of the chain, failing clearly if there is none.

        Raises:
            ValueError: if no layer has been added yet (``input`` was not called).
        """
        if self._prev_layer is None:
            raise ValueError(
                "No previous layer to build on: call 'input' before adding layers.")
        return self._prev_layer

    def input(self, image_specs) -> 'GraphBuilder':
        """Start a new graph whose first layer is a zero-filled ``tf.Variable``.

        Any previously built layers are discarded.

        Args:
            image_specs: object exposing ``height``, ``width`` and
                ``color_channels``; the variable gets shape
                ``(1, height, width, color_channels)``.

        Returns:
            This builder, to allow chaining.
        """
        self.graph = {}
        shape = (1, image_specs.height, image_specs.width, image_specs.color_channels)
        # Allocate float32 directly instead of numpy's float64 default; the
        # variable itself is float32 either way.
        initial_value = np.zeros(shape, dtype=np.float32)
        return self._build_layer('input', tf.Variable(initial_value, dtype='float32'))

    def avg_pool(self, layer_id: str) -> 'GraphBuilder':
        """Add a 2x2 average-pooling layer (stride 2, 'SAME' padding).

        Args:
            layer_id: key under which the pooled output is stored in ``graph``.

        Returns:
            This builder, to allow chaining.

        Raises:
            ValueError: if called before ``input``.
        """
        previous = self._require_prev_layer()
        pooled = tf.nn.avg_pool(
            previous, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        return self._build_layer(layer_id, pooled)

    def relu_conv_2d(self, layer_id: str,
                     layer_weights: Tuple[NDArray, NDArray]) -> 'GraphBuilder':
        """Add a ReLU applied to a 2D convolution plus bias.

        Args:
            layer_id: key under which the ReLU output is stored in ``graph``.
            layer_weights: ``(W, b)`` pair of already-loaded weights for this
                layer: the convolution kernel ``W`` and the bias ``b``.

        Returns:
            This builder, to allow chaining.

        Raises:
            ValueError: if called before ``input``.
        """
        W, b = layer_weights
        convolved = self._conv_2d(W, b)
        return self._build_layer(layer_id, tf.nn.relu(convolved))

    def _conv_2d(self, W: NDArray, b: NDArray):
        """Convolve the current tip with kernel ``W`` (stride 1, 'SAME') and add ``b``.

        ``b`` is flattened to 1-D before being added to the convolution
        output (presumably one value per output channel — confirm against
        the weight source).

        Raises:
            ValueError: if called before ``input``.
        """
        previous = self._require_prev_layer()
        kernel = tf.constant(W)
        bias = tf.constant(np.reshape(b, (b.size,)))
        return tf.compat.v1.nn.conv2d(
            previous, filter=kernel, strides=[1, 1, 1, 1], padding='SAME') + bias
42