Passed
Pull Request — master (#6)
by Konstantinos
03:56
created

GraphBuilder.avg_pool()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
from typing import Tuple
2
import numpy as np
3
from numpy.typing import NDArray
4
import tensorflow as tf
5
6
7
class GraphBuilder:
    """Fluent builder that records TensorFlow layers in a dict keyed by layer id.

    Each layer method stores its output in ``self.graph`` under the given id,
    makes it the input of the next layer, and returns ``self`` so calls can be
    chained, e.g. ``builder.input(specs).relu_conv_2d('c1', w).avg_pool('p1')``.
    """

    def __init__(self):
        # layer_id -> tensor/variable produced for that layer.
        self.graph = {}
        # Output of the most recently built layer; None until input() is called.
        self._prev_layer = None

    def _build_layer(self, layer_id: str, layer):
        """Store ``layer`` under ``layer_id``, make it the next input, return self."""
        self.graph[layer_id] = layer
        self._prev_layer = layer
        return self

    def _require_prev_layer(self):
        """Return the previous layer's output, raising if no input layer exists yet."""
        if self._prev_layer is None:
            raise ValueError(
                'GraphBuilder has no input layer; call input() before adding layers')
        return self._prev_layer

    def input(self, image_specs):
        """Start a new graph with a zero-initialised float32 input variable.

        Any previously recorded layers are discarded. ``image_specs`` must
        expose ``height``, ``width`` and ``color_channels``; the variable has
        shape ``(1, height, width, color_channels)``, i.e. a batch of one image.

        Returns:
            self, for chaining.
        """
        self.graph = {}
        shape = (1, image_specs.height, image_specs.width, image_specs.color_channels)
        return self._build_layer('input', tf.Variable(np.zeros(shape), dtype='float32'))

    def avg_pool(self, layer_id: str, window: int = 2, stride: int = 2):
        """Add a 'SAME'-padded average-pooling layer over the previous layer.

        Args:
            layer_id: Key under which the pooled tensor is stored in ``self.graph``.
            window: Spatial pooling window size (height and width). Default 2.
            stride: Spatial stride (height and width). Default 2.

        Returns:
            self, for chaining.

        Raises:
            ValueError: if input() has not been called yet.
        """
        prev = self._require_prev_layer()
        # Batch and channel dimensions are never pooled: the input built by
        # input() is laid out (batch, height, width, channels).
        pooled = tf.nn.avg_pool(prev,
                                ksize=[1, window, window, 1],
                                strides=[1, stride, stride, 1],
                                padding='SAME')
        return self._build_layer(layer_id, pooled)

    def relu_conv_2d(self, layer_id: str, layer_weights: Tuple[NDArray, NDArray]):
        """Add a ReLU wrapped around a 2-D convolution of the previous layer.

        Args:
            layer_id: Key under which the activation is stored in ``self.graph``.
            layer_weights: ``(W, b)`` pair of (pretrained) weights supplied by
                the caller. ``W`` is the convolution filter and ``b`` is
                flattened and added as the bias.

        Returns:
            self, for chaining.

        Raises:
            ValueError: if input() has not been called yet.
        """
        W, b = layer_weights
        conv = self._conv_2d(W, b)
        return self._build_layer(layer_id, tf.nn.relu(conv))

    def _conv_2d(self, W: NDArray, b: NDArray):
        """Convolve the previous layer with ``W`` (stride 1, 'SAME' padding) and add ``b``."""
        prev = self._require_prev_layer()
        W = tf.constant(W)
        # Flatten b to 1-D so it broadcasts over the output channels; presumably
        # the pretrained bias can arrive as e.g. (1, n) or (n, 1) — confirm.
        b = tf.constant(np.reshape(b, -1))
        # Filter passed positionally: it is `filter` in the v1 API and
        # `filters` in v2, so this works with either.
        return tf.nn.conv2d(prev, W, strides=[1, 1, 1, 1], padding='SAME') + b
42