Passed
Pull Request — master (#12)
by Konstantinos
01:12
created

conftest   A

Complexity

Total Complexity 19

Size/Duplication

Total Lines 472
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 272
dl 0
loc 472
rs 10
c 0
b 0
f 0
wmc 19

16 Functions

Rating   Name   Duplication   Size   Complexity  
A disk() 0 5 1
A termination_condition() 0 8 1
A broadcaster_class() 0 26 2
A termination_condition_module() 0 22 1
A subscribe() 0 6 1
A model() 0 8 1
B vgg_layers() 0 50 1
B toy_model_data() 0 78 2
A image_factory() 0 13 1
A toy_network_design() 0 14 1
A test_image() 0 8 1
A image_manager_class() 0 5 1
B pre_trained_models_1() 0 62 1
B toy_nst_algorithm() 0 72 2
A test_suite() 0 6 1
A session() 0 32 1
1
import os
2
import typing as t
3
4
import pytest
5
6
7
@pytest.fixture
def test_suite():
    """Absolute path of the directory containing this conftest module.

    Symlinks in the module path are resolved (``os.path.realpath``) before
    the parent directory is taken.
    """
    this_module_path = os.path.realpath(__file__)
    return os.path.dirname(this_module_path)
13
14
15
@pytest.fixture
def test_image(test_suite):
    """Factory fixture resolving an image file name inside ``<test_suite>/data``.

    Args:
        test_suite: the test-suite directory (value of the ``test_suite`` fixture).

    Returns:
        Callable[[str], str]: maps a file name to its joined path under the
        suite's ``data`` directory.
    """
    data_dir_parts = (test_suite, "data")

    def get_image_file_path(file_name):
        return os.path.join(*data_dir_parts, file_name)

    return get_image_file_path
23
24
25
@pytest.fixture
def disk():
    """Expose the production ``Disk`` class itself (not an instance) to tests."""
    from artificial_artwork.disk_operations import Disk as disk_cls

    return disk_cls
30
31
32
@pytest.fixture
def session():
    """Provide a TensorFlow v1 session class whose random seed is set at runtime.

    The fixture returns the ``MySession`` *class*. Tests instantiate it with a
    seed and use the instance as a context manager, for example::

        import tensorflow as tf

        def test_something(session):
            with session(2) as sess:
                a_C = tf.compat.v1.random_normal([1, 4, 4, 3], mean=1, stddev=4)
                ...

    Constructing ``MySession(seed)`` calls ``tf.compat.v1.reset_default_graph()``
    and creates a ``tf.compat.v1.Session``. Entering it enters that session
    and then calls ``tf.compat.v1.set_random_seed(seed)``.

    Returns:
        type: the ``MySession`` class; its ``__enter__`` returns whatever the
        wrapped ``tf.compat.v1.Session.__enter__`` returns.
    """
    import tensorflow as tf

    class MySession:
        def __init__(self, seed):
            # clear the global default graph before creating the session
            tf.compat.v1.reset_default_graph()
            self.tf_session = tf.compat.v1.Session()
            self.seed = seed

        def __enter__(self):
            entering_output = self.tf_session.__enter__()
            tf.compat.v1.set_random_seed(self.seed)
            return entering_output

        def __exit__(self, exc_type, exc_value, traceback):
            # Delegate cleanup to the wrapped session and forward its return
            # value, so its exception-suppression decision is respected
            # (standard context-manager delegation).
            return self.tf_session.__exit__(exc_type, exc_value, traceback)

    return MySession
64
65
66
@pytest.fixture
def image_factory():
    """Production ``ImageFactory`` wired to ``Disk.load_image`` as its loader.

    The factory exposes ``from_disk(file_path, preprocess=True)``.

    Returns:
        ImageFactory: a freshly constructed factory instance.
    """
    from artificial_artwork.disk_operations import Disk
    from artificial_artwork.image.image_factory import ImageFactory

    image_loader = Disk.load_image
    factory = ImageFactory(image_loader)
    return factory
79
80
81
@pytest.fixture
def termination_condition_module():
    """Bundle the termination-condition facility and interface for tests.

    Before returning, it asserts that the facility's class registry contains
    exactly the "max-iterations", "time-limit" and "convergence"
    implementations, because all tests rely on them being registered.

    Returns:
        type: an ad-hoc class named ``"M"`` with attributes ``facility``
        (``TerminationConditionFacility``) and ``interface``
        (``TerminationConditionInterface``).
    """
    from artificial_artwork.termination_condition.termination_condition import (
        Convergence,
        MaxIterations,
        TerminationConditionFacility,
        TerminationConditionInterface,
        TimeLimit,
    )

    expected_implementations = {
        "max-iterations": MaxIterations,
        "time-limit": TimeLimit,
        "convergence": Convergence,
    }
    registered = TerminationConditionFacility.class_registry.subclasses
    # all tests require that the Facility already contains some implementations of TerminationCondition
    assert registered == expected_implementations

    attributes = {
        "facility": TerminationConditionFacility,
        "interface": TerminationConditionInterface,
    }
    return type("M", (), attributes)
105
106
107
@pytest.fixture
def termination_condition(termination_condition_module):
    """Factory fixture creating termination conditions by registered type name.

    Args:
        termination_condition_module: bundle exposing ``facility`` and ``interface``.

    Returns:
        Callable: ``create_termination_condition(term_cond_type, *args, **kwargs)``,
        which forwards all arguments to
        ``termination_condition_module.facility.create`` and returns its result.
    """
    condition_interface = termination_condition_module.interface

    def create_termination_condition(term_cond_type: str, *args, **kwargs) -> condition_interface:
        facility = termination_condition_module.facility
        return facility.create(term_cond_type, *args, **kwargs)

    return create_termination_condition
115
116
117
@pytest.fixture
def subscribe():
    """Provide a helper attaching every listener to a broadcaster in one call.

    The helper calls ``broadcaster.add`` once, passing each listener as a
    separate positional argument, and returns ``None``.
    """

    def attach_all(broadcaster, listeners):
        register = broadcaster.add
        register(*listeners)

    return attach_all
123
124
125
@pytest.fixture
def broadcaster_class():
    """Provide ``TestSubject``, a loop driver that notifies a subject each iteration.

    ``TestSubject(subject, done_callback).iterate()`` loops until
    ``done_callback()`` is truthy. After each iteration it:

    * replaces ``subject.state`` with an object whose ``metrics`` is
      ``{"iterations": <number of completed iterations>}``;
    * calls ``subject.notify()``.

    ``iterate`` returns the number of completed iterations.
    """

    class TestSubject:
        def __init__(self, subject, done_callback):
            self.subject = subject
            self.done = done_callback

        def iterate(self):
            completed = 0
            while not self.done():
                # the "work" of the current iteration is just a print
                print("Iteration with index", completed)
                completed += 1
                # publish progress: `completed` iterations are now finished
                self.subject.state = type(
                    "Subject", (), {"metrics": {"iterations": completed}}
                )
                self.subject.notify()
            return completed

    return TestSubject
151
152
153
@pytest.fixture
def toy_network_design():
    """Network design of the one-layer toy network.

    Returns:
        type: an ad-hoc class named ``ModelDesign`` with attributes:

        * ``network_layers``: tuple of the layer ids used by the network;
        * ``style_layers``: list of ``(layer_id, weight)`` pairs, one per
          network layer, all sharing the weight ``1 / len(network_layers)``;
        * ``output_layer``: the layer id used as output.
    """
    # layers we pick to use for our Neural Network: the Toy Network has 1 layer
    network_layers = ("conv1_1",)
    # equally weight all Style Layers
    weight = 1.0 / len(network_layers)
    # for the Toy Network Design, select all Network Layers to be Style Layers
    style_layers = [(layer_id, weight) for layer_id in network_layers]
    return type(
        "ModelDesign",
        (),
        {
            # reuse the tuple above instead of repeating the literal, so the
            # network layers and the derived style layers cannot drift apart
            "network_layers": network_layers,
            "style_layers": style_layers,
            "output_layer": "conv1_1",
        },
    )
169
170
171
@pytest.fixture
def image_manager_class():
    """Expose the production ``ImageManager`` class itself (not an instance)."""
    from artificial_artwork.nst_image import ImageManager as manager_cls

    return manager_cls
176
177
178
## Supported pretrained models and their expected layers
179
180
181
@pytest.fixture
def vgg_layers():
    """Layer ids of the production VGG image model's complete network, in order.

    Returns:
        tuple[str, ...]: the 43 layer ids, from ``"conv1_1"`` to ``"prob"``.
    """
    # (position, layer id) pairs; trailing comments note the weight shapes
    indexed_layers = (
        (0, "conv1_1"),  # (3, 3, 3, 64)
        (1, "relu1_1"),
        (2, "conv1_2"),  # (3, 3, 64, 64)
        (3, "relu1_2"),
        (4, "pool1"),  # maxpool
        (5, "conv2_1"),  # (3, 3, 64, 128)
        (6, "relu2_1"),
        (7, "conv2_2"),  # (3, 3, 128, 128)
        (8, "relu2_2"),
        (9, "pool2"),
        (10, "conv3_1"),  # (3, 3, 128, 256)
        (11, "relu3_1"),
        (12, "conv3_2"),  # (3, 3, 256, 256)
        (13, "relu3_2"),
        (14, "conv3_3"),  # (3, 3, 256, 256)
        (15, "relu3_3"),
        (16, "conv3_4"),  # (3, 3, 256, 256)
        (17, "relu3_4"),
        (18, "pool3"),
        (19, "conv4_1"),  # (3, 3, 256, 512)
        (20, "relu4_1"),
        (21, "conv4_2"),  # (3, 3, 512, 512)
        (22, "relu4_2"),
        (23, "conv4_3"),  # (3, 3, 512, 512)
        (24, "relu4_3"),
        (25, "conv4_4"),  # (3, 3, 512, 512)
        (26, "relu4_4"),
        (27, "pool4"),
        (28, "conv5_1"),  # (3, 3, 512, 512)
        (29, "relu5_1"),
        (30, "conv5_2"),  # (3, 3, 512, 512)
        (31, "relu5_2"),
        (32, "conv5_3"),  # (3, 3, 512, 512)
        (33, "relu5_3"),
        (34, "conv5_4"),  # (3, 3, 512, 512)
        (35, "relu5_4"),
        (36, "pool5"),
        (37, "fc6"),  # fullyconnected (7, 7, 512, 4096)
        (38, "relu6"),
        (39, "fc7"),  # fullyconnected (1, 1, 4096, 4096)
        (40, "relu7"),
        (41, "fc8"),  # fullyconnected (1, 1, 4096, 1000)
        (42, "prob"),  # softmax
    )
    # transpose the pairs: the second resulting tuple holds the layer ids
    _positions, layer_ids = zip(*indexed_layers)
    return layer_ids
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable layer_id does not seem to be defined.
Loading history...
231
232
233
# Value of the AA_VGG_19 environment variable, read once at import time;
# falls back to the sentinel string when the variable is unset.
PRODUCTION_IMAGE_MODEL = os.getenv("AA_VGG_19", "PRETRAINED_MODEL_NOT_FOUND")
234
235
236
@pytest.fixture
def pre_trained_models_1(vgg_layers, toy_model_data, toy_network_design):
    """Pretrained NST models available to the tests, keyed by model id.

    Only the ``"toy"`` model is currently provided. As a side effect, a
    ``ToyModelHandler`` is registered under ``"toy"`` via
    ``ModelHandlerFacility.factory.register_as_subclass``. It reuses the
    production VGG handler and routines and overrides only layer loading,
    which reads the toy data instead.

    Args:
        vgg_layers: production VGG layer ids; only referenced by the disabled
            ``'vgg'`` entry kept as a comment below.
        toy_model_data: ``(load_layers, layer_ids)`` pair of the toy network.
        toy_network_design: network/style/output layer design of the toy network.

    Returns:
        dict: maps ``"toy"`` to an ad-hoc ``NSTModel`` class exposing
        ``pretrained_model`` (with ``expected_layers``, ``id``, ``handler``)
        and ``network_design`` (a ``NetworkDesign``).
    """
    from artificial_artwork.pretrained_model import ModelHandlerFacility
    from artificial_artwork.production_networks import NetworkDesign

    # the toy loader emulates scipy.io.loadmat: it returns a dict, not an array
    toy_layers_loader: t.Callable[..., t.Dict[str, t.Any]] = toy_model_data[0]
    pretrained_toy_model_layers: t.Tuple[str, ...] = toy_model_data[1]

    # help implement the ModelHandler Interface for the toy handler
    # by using parts of production code, that we don't need to mock
    from artificial_artwork.pre_trained_models.vgg import (
        VggModelHandler,
        VggModelRoutines,
    )

    class ToyModelRoutines(VggModelRoutines):
        # override only critical operations integrating with Prod Pretrained Stored Layers/Weights
        def load_layers(self, file_path: str):
            return toy_layers_loader(file_path)

    toy_model_routines = ToyModelRoutines()

    @ModelHandlerFacility.factory.register_as_subclass("toy")
    class ToyModelHandler(VggModelHandler):
        def _load_model_layers(self):
            return toy_model_routines.load_layers("")["layers"][0]

        @property
        def model_routines(self):
            return toy_model_routines

    return {
        # 'vgg': type('NSTModel', (), {
        #     'pretrained_model': type('PTM', (), {
        #         'expected_layers': vgg_layers,
        #         'id': 'vgg',
        #         'handler': ModelHandlerFacility.create('vgg'),
        #     }),
        #     # Production Style Layers and Output (Content) Layer picked from vgg
        #     'network_design': NetworkDesign.from_default_vgg()
        # }),
        "toy": type(
            "NSTModel",
            (),
            {
                "pretrained_model": type(
                    "PTM",
                    (),
                    {
                        "expected_layers": pretrained_toy_model_layers,  # t.Tuple[str, ...]
                        "id": "toy",
                        "handler": ModelHandlerFacility.create("toy"),
                    },
                ),
                "network_design": NetworkDesign(
                    toy_network_design.network_layers,
                    toy_network_design.style_layers,
                    toy_network_design.output_layer,
                ),
            },
        ),
    }
302
303
304
@pytest.fixture
def model(pre_trained_models_1):
    """Pretrained NST model used by the tests; currently always the toy model.

    For diagnostics, prints the configured ``PRODUCTION_IMAGE_MODEL`` value and
    whether it points to an existing file.
    """
    model_file_exists = os.path.isfile(PRODUCTION_IMAGE_MODEL)
    print(f"\n -- PROD IM MODEL: {PRODUCTION_IMAGE_MODEL}")
    print(f"Selected Prod?: {model_file_exists}")

    # Selecting the production 'vgg' model when the file exists is disabled:
    # return {
    #     True: pre_trained_models_1['vgg'],
    #     False: pre_trained_models_1['toy'],
    # }[os.path.isfile(PRODUCTION_IMAGE_MODEL)]
    return pre_trained_models_1["toy"]
316
317
318
# CONSTANT DATA representing Layers Information (i.e. weight values) of the Toy Network
319
@pytest.fixture
def toy_model_data():
    """Constant layer ids and weights of a 3-layer toy network.

    Returns:
        tuple: ``(load_layers, model_layers)`` where ``model_layers`` is
        ``("conv1_1", "relu1", "maxpool1")``. ``load_layers(*args)`` ignores
        its arguments and returns ``{"layers": [[layer_1, layer_2, layer_3]]}``.
        Each layer has the nesting ``[[[[layer_id], "unused", [params]]]]``.
    """
    import numpy as np

    # This data format emulates the format the production pretrained VGG layer
    # IDs are stored in
    model_layers = (
        "conv1_1",
        "relu1",
        "maxpool1",
    )
    convo_w_weights_shape = (3, 3, 3, 4)

    def _layer(layer_id, params):
        """Wrap one layer's id and params in the emulated nested layout."""
        return [[[[layer_id], "unused", [params]]]]

    def load_layers(*args):
        """Load Layers of 3-layered Toy Neural Net, emulating prod VGG format.

        It emulates what the production implementation (scipy.io.loadmat) does,
        by returning an object following the same interface as the one returned
        by scipy.io.loadmat, when called on the file storing the production
        pretrained VGG model. Fresh arrays are built on every call.
        """
        # here we use pytest to emit some text, leveraging pytest, so that the test code using this fixture
        # can somehow verify that the text appeared in the expected place (ie console, log or sth)
        print("VGG Mat Weights Mock Loader Called")

        # 'A' matrix weights tensor with shape (3, 3, 3, 4), i.e. 3*3*3*4 = 108 values;
        # for this toy Conv Layer the values are 1, 2, 3, ..., 108 (laid out in C order)
        nb_weights = int(np.prod(convo_w_weights_shape))
        conv_weights = np.arange(1, nb_weights + 1, dtype=np.float32).reshape(convo_w_weights_shape)
        # 'b' bias vector of shape (1,); for this toy Conv Layer the bias value is 5
        conv_bias = np.array([5], dtype=np.float32)

        return {
            "layers": [
                [
                    # 1st Layer: conv1_1
                    _layer(model_layers[0], [conv_weights, conv_bias]),
                    # 2nd Layer: relu1 and 3rd Layer: maxpool1 carry placeholder weights,
                    # which are not expected to be used because they are not Conv layers
                    _layer(model_layers[1], ["W", "b"]),
                    _layer(model_layers[2], ["W", "b"]),
                ]
            ]
        }

    return load_layers, model_layers
397
398
399
# MONKEYPATCH PROD NST ALGO at RUNTIME with Algo using Toy Network
400
@pytest.fixture
def toy_nst_algorithm(toy_model_data, toy_network_design, monkeypatch):
    """Provide a callable that patches the production NST algorithm to use the toy network.

    Calling the returned ``_monkeypatch()`` does the following:

    1. replaces ``scipy.io.loadmat`` with the toy layer loader;
    2. replaces ``vgg.VggModelHandler`` with a handler whose routines load the
       toy layers;
    3. sets the ``AA_VGG_19`` environment variable to a placeholder value;
    4. replaces ``NetworkDesign.from_default_vgg`` with a factory building the
       toy network design;
    5. returns ``ModelHandlerFacility.create("vgg")``.

    Every patch, including the environment variable, goes through pytest's
    ``monkeypatch``, so it is undone at test teardown.
    """
    # the toy loader emulates scipy.io.loadmat: it returns a dict, not an array
    toy_layers_loader: t.Callable[..., t.Dict[str, t.Any]] = toy_model_data[0]

    def _monkeypatch():
        import scipy.io

        # equip Handler Facility Factory with the 'vgg' implementation
        from artificial_artwork.pre_trained_models import vgg
        from artificial_artwork.pretrained_model import ModelHandlerFacility
        from artificial_artwork.production_networks import NetworkDesign

        # if prod VGG Handler tries to load VGG Prod Weights, return Toy Weights instead
        # 1st we patch the scipy.io.loadmat, which is used by the production VGG Handler
        monkeypatch.setattr(scipy.io, "loadmat", toy_layers_loader)  # Patch/replace-with-mock

        from artificial_artwork.pre_trained_models.vgg import (
            VggModelHandler,
            VggModelRoutines,
        )

        class ToyModelRoutines(VggModelRoutines):
            # override only critical operations integrating with Prod Pretrained Stored Layers/Weights
            def load_layers(self, file_path: str):
                return toy_layers_loader(file_path)

        toy_model_routines = ToyModelRoutines()

        class ToyModelHandler(VggModelHandler):
            def _load_model_layers(self):
                return toy_model_routines.load_layers("")["layers"][0]

            @property
            def model_routines(self):
                return toy_model_routines

        monkeypatch.setattr(vgg, "VggModelHandler", ToyModelHandler)  # Patch/replace-with-mock

        # 2nd we patch the AA_VGG_19 env var which the code strictly requires to find.
        # monkeypatch.setenv (instead of writing os.environ directly) restores the
        # previous state at teardown, so the placeholder does not leak into other tests.
        monkeypatch.setenv("AA_VGG_19", "unit-tests-toy-value")  # Patch/replace-with-mock

        # Prod Code uses the 'default' factory (classmethod) method of class
        # NetworkDesign, in order to instantiate a NetworkDesign object
        # according to the 'Original' NST Algorithm (which layers to pick for
        # creating ReLUs from their pretrained Conv A, b weights, or which is the Output Layer)

        # Monkey patching objects used in the 'default' factory method
        monkeypatch.setattr(
            NetworkDesign,
            "from_default_vgg",
            lambda: NetworkDesign(
                toy_network_design.network_layers,  # full list of layer IDs available in Pretrained Model
                toy_network_design.style_layers,  # list of tuples with layer IDs and coefficients governing their proportional contribution to the Style Cost/Loss formula
                toy_network_design.output_layer,  # layer ID to be used for Content Loss (ie last layer of Pretrained Model/Network)
            ),
        )
        # for convenience, construct here a ModelHandler instance, equipped with
        # handling all operations (of ModelHandlerInterface) with mocked Toy operations
        # when needed and provide it to test code
        # TODO remove the need for that
        toy_model_handler = ModelHandlerFacility.create("vgg")  # handler instances are stateless, and lightweight
        return toy_model_handler

    return _monkeypatch
472