Passed
Pull Request — master (#6)
by Konstantinos
03:56
created

conftest.vgg_layers()   B

Complexity

Conditions 1

Size

Total Lines 50
Code Lines 47

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 47
dl 0
loc 50
rs 8.7345
c 0
b 0
f 0
cc 1
nop 0
1
import pytest
2
3
from artificial_artwork.pretrained_model import model_handler
4
5
6
@pytest.fixture
def test_suite():
    """Absolute path of the directory holding this conftest module.

    Symlinks in ``__file__`` are resolved before the parent directory is taken.

    Returns:
        str: the test-suite directory path
    """
    import os
    suite_dir, _conftest_name = os.path.split(os.path.realpath(__file__))
    return suite_dir
11
12
13
@pytest.fixture
def test_image(test_suite):
    """Factory fixture resolving file names inside the suite's ``data`` folder.

    Args:
        test_suite (str): directory of the test suite (see the ``test_suite`` fixture)

    Returns:
        Callable[[str], str]: maps ``file_name`` to ``<test_suite>/data/<file_name>``
    """
    import os

    def resolve_data_file(file_name):
        path_parts = (test_suite, 'data', file_name)
        return os.path.join(*path_parts)
    return resolve_data_file
19
20
21
@pytest.fixture
def disk():
    """The ``Disk`` class from ``artificial_artwork.disk_operations``.

    The class object itself is returned, not an instance.
    """
    from artificial_artwork.disk_operations import Disk as disk_class
    return disk_class
25
26
27
@pytest.fixture
def session():
    """Tensorflow v1 Session factory, with a random seed defined at runtime.

    The fixture returns a class; instantiate it with a seed and use the
    instance as a context manager. Instantiating it resets the default graph
    and creates a new ``tf.compat.v1.Session``; entering it enters that session
    and then sets the graph-level random seed.

    Example (``compute_cost`` stands for whatever function the test exercises)::

        import tensorflow as tf
        with session(2) as test:
            a_C = tf.compat.v1.random_normal([1, 4, 4, 3], mean=1, stddev=4)
            a_G = tf.compat.v1.random_normal([1, 4, 4, 3], mean=1, stddev=4)
            J_content = compute_cost(a_C, a_G)
            assert abs(J_content.eval() - 7.0738883) < 1e-5

    Returns:
        type: the ``MySession`` class; ``MySession(seed)`` is a context manager
        wrapping a tensorflow session with a set random seed
    """
    import tensorflow as tf

    class MySession():
        def __init__(self, seed):
            # Start from a fresh default graph before creating the session.
            tf.compat.v1.reset_default_graph()
            self.tf_session = tf.compat.v1.Session()
            self.seed = seed

        def __enter__(self):
            entering_output = self.tf_session.__enter__()
            # The seed is set only after the session is entered; presumably so it
            # applies to the session's (now default) graph -- TODO confirm.
            tf.compat.v1.set_random_seed(self.seed)
            return entering_output

        def __exit__(self, exc_type, exc_value, exc_traceback):
            # Delegate cleanup to the wrapped session and forward its return
            # value, so exception suppression behaves exactly like the wrapped
            # session's own context manager.
            return self.tf_session.__exit__(exc_type, exc_value, exc_traceback)
    return MySession
56
57
58
@pytest.fixture
def image_factory():
    """Production image factory backed by ``Disk.load_image``.

    Exposes the 'from_disk(file_path, preprocess=True)'.

    Returns:
        ImageFactory: a new ImageFactory instance whose loader is ``Disk.load_image``
    """
    from artificial_artwork.image.image_factory import ImageFactory
    from artificial_artwork.disk_operations import Disk
    image_loader = Disk.load_image
    production_factory = ImageFactory(image_loader)
    return production_factory
70
71
72
@pytest.fixture
def termination_condition_module():
    """Namespace bundling the termination-condition facility and its interface.

    Before returning, it asserts that the facility's registry holds exactly
    the three built-in implementations.

    Returns:
        type: a class named ``M`` with attributes ``facility`` and ``interface``
    """
    from artificial_artwork.termination_condition.termination_condition import (
        TerminationConditionFacility,
        TerminationConditionInterface,
        MaxIterations,
        TimeLimit,
        Convergence,
    )

    expected_registry = {
        'max-iterations': MaxIterations,
        'time-limit': TimeLimit,
        'convergence': Convergence,
    }
    # every test relies on these implementations being registered already
    registered = TerminationConditionFacility.class_registry.subclasses
    assert registered == expected_registry

    namespace_attributes = {
        'facility': TerminationConditionFacility,
        'interface': TerminationConditionInterface,
    }
    return type('M', (), namespace_attributes)
87
88
89
@pytest.fixture
def termination_condition(termination_condition_module):
    """Factory fixture building termination conditions by their registered name.

    Args:
        termination_condition_module: namespace exposing ``facility`` and ``interface``

    Returns:
        Callable: ``(term_cond_type, *args, **kwargs)`` -> the object produced by
        ``facility.create(term_cond_type, *args, **kwargs)``
    """
    namespace = termination_condition_module

    def create_termination_condition(term_cond_type: str, *args, **kwargs) -> namespace.interface:
        builder = namespace.facility
        return builder.create(term_cond_type, *args, **kwargs)
    return create_termination_condition
94
 
95
96
@pytest.fixture
def subscribe():
    """Helper that attaches all given listeners to a broadcaster in one call.

    Returns:
        Callable[[broadcaster, listeners], None]: calls ``broadcaster.add(*listeners)``
    """
    def attach_all(broadcaster, listeners):
        add_listeners = broadcaster.add
        add_listeners(*listeners)
    return attach_all
101
102
103
104
@pytest.fixture
def broadcaster_class():
    """A driver class that runs iterations and notifies a subject after each one.

    Returns:
        type: ``TestSubject(subject, done_callback)``. Its ``iterate()`` method
        loops until ``done_callback()`` is truthy. After every iteration it sets
        ``subject.state`` to an object whose ``metrics`` report the number of
        completed iterations, then calls ``subject.notify()``. It returns the
        total number of completed iterations.
    """
    class TestSubject:
        def __init__(self, subject, done_callback):
            self.subject = subject
            self.done = done_callback

        def iterate(self):
            completed = 0
            while not self.done():
                # do something in the current iteration
                print('Iteration with index', completed)
                completed += 1
                # publish how many iterations have completed so far
                self.subject.state = type('Subject', (), {
                    'metrics': {'iterations': completed},
                })
                self.subject.notify()
            return completed

    return TestSubject
126
127
128
@pytest.fixture
def toy_model_data():
    """Register a tiny in-memory 'toy' model handler and describe its layers.

    Side effect: registers ``ToyModelHandler`` under the name ``'toy'`` with
    ``ModelHandlerFacility.factory``, so ``ModelHandlerFacility.create('toy')``
    can build it. The handler's routines return hard-coded layers instead of
    reading a model file.

    Returns:
        type: a class ``TMD`` whose ``expected_layers`` attribute is the tuple of
        toy layer names
    """
    import numpy as np
    from artificial_artwork.pretrained_model import ModelHandlerFacility
    from artificial_artwork.pre_trained_models.vgg import VggModelRoutines, VggModelHandler

    model_layers = (
        'conv1_1',
        'relu1',
        'maxpool1',
    )
    convo_w_weights_shape = (3, 3, 3, 4)

    class ToyModelRoutines(VggModelRoutines):

        def load_layers(self, file_path: str):
            """Return the toy layers in the same nested layout as the VGG layers.

            ``file_path`` is ignored: everything is built in memory. Fresh arrays
            are created on each call, so callers never share mutable weights.
            """
            # Deterministic weights 1, 2, ..., N (N = product of the shape) laid
            # out in the convolution kernel's shape, plus one bias value.
            weights_count = int(np.prod(convo_w_weights_shape))
            convo_weights = np.arange(1, weights_count + 1, dtype=np.float32).reshape(convo_w_weights_shape)
            convo_bias = np.array([5], dtype=np.float32)
            return {
                'layers': [[
                    [[[[model_layers[0]], 'unused', [[convo_weights, convo_bias]]]]],
                    [[[[model_layers[1]], 'unused', [['W', 'b']]]]],
                    [[[[model_layers[2]], 'unused', [['W', 'b']]]]],
                ]]
            }

    toy_model_routines = ToyModelRoutines()

    # NOTE(review): this registration runs on every fixture invocation; confirm
    # the registry tolerates re-registering 'toy'.
    @ModelHandlerFacility.factory.register_as_subclass('toy')
    class ToyModelHandler(VggModelHandler):
        def _load_model_layers(self):
            return toy_model_routines.load_layers('')['layers'][0]

        @property
        def model_routines(self):
            return toy_model_routines

    return type('TMD', (), {
        'expected_layers': model_layers,
    })
171
172
173
@pytest.fixture
def toy_network_design():
    """Network design for the toy model, built from a single layer.

    Returns:
        type: a class ``ModelDesign`` with ``network_layers`` (tuple of layer
        ids), ``style_layers`` (list of ``(layer_id, weight)`` pairs, each weight
        being ``1 / len(network_layers)``) and ``output_layer`` (a layer id)
    """
    # layers we pick to use for our Neural Network
    network_layers = ('conv1_1',)
    # every style layer contributes with the same weight
    weight = 1.0 / len(network_layers)
    style_layers = [(layer_id, weight) for layer_id in network_layers]
    return type('ModelDesign', (), {
        # reuse the tuple above so the design cannot drift from its style layers
        'network_layers': network_layers,
        'style_layers': style_layers,
        'output_layer': 'conv1_1',
    })
186
187
188
@pytest.fixture
def image_manager_class():
    """The ``ImageManager`` class from ``artificial_artwork.nst_image``.

    The class object itself is returned, not an instance.
    """
    from artificial_artwork.nst_image import ImageManager as manager_class
    return manager_class
192
193
194
## Supported pretrained models and their expected layers
195
196
@pytest.fixture
def vgg_layers():
    """The vgg image model network's layer structure.

    Returns:
        tuple[str, ...]: the 43 layer names, in network order (index 0 to 42)
    """
    # (index, layer name) pairs; trailing comments annotate the layer's
    # weight shape or kind
    vgg_layer_table = (
        (0, 'conv1_1'),  # (3, 3, 3, 64)
        (1, 'relu1_1'),
        (2, 'conv1_2'),  # (3, 3, 64, 64)
        (3, 'relu1_2'),
        (4, 'pool1'),  # maxpool
        (5, 'conv2_1'),  # (3, 3, 64, 128)
        (6, 'relu2_1'),
        (7, 'conv2_2'),  # (3, 3, 128, 128)
        (8, 'relu2_2'),
        (9, 'pool2'),
        (10, 'conv3_1'),  # (3, 3, 128, 256)
        (11, 'relu3_1'),
        (12, 'conv3_2'),  # (3, 3, 256, 256)
        (13, 'relu3_2'),
        (14, 'conv3_3'),  # (3, 3, 256, 256)
        (15, 'relu3_3'),
        (16, 'conv3_4'),  # (3, 3, 256, 256)
        (17, 'relu3_4'),
        (18, 'pool3'),
        (19, 'conv4_1'),  # (3, 3, 256, 512)
        (20, 'relu4_1'),
        (21, 'conv4_2'),  # (3, 3, 512, 512)
        (22, 'relu4_2'),
        (23, 'conv4_3'),  # (3, 3, 512, 512)
        (24, 'relu4_3'),
        (25, 'conv4_4'),  # (3, 3, 512, 512)
        (26, 'relu4_4'),
        (27, 'pool4'),
        (28, 'conv5_1'),  # (3, 3, 512, 512)
        (29, 'relu5_1'),
        (30, 'conv5_2'),  # (3, 3, 512, 512)
        (31, 'relu5_2'),
        (32, 'conv5_3'),  # (3, 3, 512, 512)
        (33, 'relu5_3'),
        (34, 'conv5_4'),  # (3, 3, 512, 512)
        (35, 'relu5_4'),
        (36, 'pool5'),
        (37, 'fc6'),  # fullyconnected (7, 7, 512, 4096)
        (38, 'relu6'),
        (39, 'fc7'),  # fullyconnected (1, 1, 4096, 4096)
        (40, 'relu7'),
        (41, 'fc8'),  # fullyconnected (1, 1, 4096, 1000)
        (42, 'prob'),  # softmax
    )

    # transpose the pairs: one tuple of indices, one tuple of names
    _layer_indices, layer_names = zip(*vgg_layer_table)
    return layer_names
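0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable layer_id does not seem to be defined.
(Likely a false positive: `layer_id` is bound by the generator expression `for _, layer_id in VGG_LAYERS` on the same line.)
Loading history...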
246
247
248
import os
249
# Path to the production VGG-19 model, read from the AA_VGG_19 environment variable (placeholder string when unset).
PRODUCTION_IMAGE_MODEL = os.getenv('AA_VGG_19', 'PRETRAINED_MODEL_NOT_FOUND')
250
251
252
@pytest.fixture
def pre_trained_models_1(vgg_layers, toy_model_data, toy_network_design):
    """Both supported NST model setups, keyed by pretrained-model id.

    Each value is a class ``NSTModel`` with:
      - ``pretrained_model``: a class ``PTM`` exposing ``expected_layers``,
        ``id`` and ``handler``
      - ``network_design``: a ``NetworkDesign``

    Both handlers are created eagerly. Requesting ``toy_model_data`` is what
    registers the ``'toy'`` handler that ``ModelHandlerFacility.create('toy')``
    builds here.
    """
    from artificial_artwork.production_networks import NetworkDesign
    from artificial_artwork.pretrained_model import ModelHandlerFacility

    def nst_model(expected_layers, model_id, handler, network_design):
        # Bundle one pretrained model and its network design into an NSTModel class.
        pretrained_model = type('PTM', (), {
            'expected_layers': expected_layers,
            'id': model_id,
            'handler': handler,
        })
        return type('NSTModel', (), {
            'pretrained_model': pretrained_model,
            'network_design': network_design,
        })

    vgg_model = nst_model(
        vgg_layers,
        'vgg',
        ModelHandlerFacility.create('vgg'),
        NetworkDesign.from_default_vgg(),
    )
    toy_model = nst_model(
        toy_model_data.expected_layers,
        'toy',
        ModelHandlerFacility.create('toy'),
        NetworkDesign(
            toy_network_design.network_layers,
            toy_network_design.style_layers,
            toy_network_design.output_layer,
        ),
    )
    return {
        'vgg': vgg_model,
        'toy': toy_model,
    }
278
279
@pytest.fixture
def model(pre_trained_models_1):
    """The NST model setup that tests should run against.

    Picks the ``'vgg'`` setup when ``PRODUCTION_IMAGE_MODEL`` (read from the
    ``AA_VGG_19`` environment variable) is an existing file. Otherwise it falls
    back to the ``'toy'`` setup.

    Args:
        pre_trained_models_1 (dict): model setups keyed by ``'vgg'`` and ``'toy'``

    Returns:
        the selected ``NSTModel`` class
    """
    import os
    use_production_model = os.path.isfile(PRODUCTION_IMAGE_MODEL)
    return pre_trained_models_1['vgg' if use_production_model else 'toy']
286