Passed
Push — master ( c3d045...3525ae )
by Konstantinos
01:55 queued 43s
created

conftest.toy_nst_algorithm()   A

Complexity

Conditions 2

Size

Total Lines 63
Code Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 34
nop 3
dl 0
loc 63
rs 9.064
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

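Commonly applied refactorings include Extract Method and, when many parameters or temporary variables get in the way, Replace Method with Method Object.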

1
import typing as t
2
import pytest
3
4
from artificial_artwork.pretrained_model import model_handler
5
6
7
@pytest.fixture
def test_suite():
    """Return the absolute directory that contains this conftest module.

    Symbolic links in the module path are resolved before taking the directory.
    """
    from os import path as os_path

    this_module_file = os_path.realpath(__file__)
    return os_path.dirname(this_module_file)
12
13
14
@pytest.fixture
def test_image(test_suite):
    """Provide a callable mapping an image file name to its path under ``<test_suite>/data``."""
    import os

    data_directory = os.path.join(test_suite, 'data')

    def get_image_file_path(file_name):
        return os.path.join(data_directory, file_name)

    return get_image_file_path
20
21
22
@pytest.fixture
def disk():
    """Expose the production ``Disk`` class itself (not an instance) to tests."""
    from artificial_artwork.disk_operations import Disk as disk_class

    return disk_class
26
27
28
@pytest.fixture
def session():
    """Provide a factory for TensorFlow v1 sessions whose random seed is set at runtime.

    The fixture returns the ``MySession`` *class*. ``MySession(seed)`` resets the
    default graph and creates a ``tf.compat.v1.Session``; entering it with ``with``
    enters that session and then sets the graph-level random seed to ``seed``.

    >>> import tensorflow as tf
    >>> with session(2) as test:
    ...  a_C = tf.compat.v1.random_normal([1, 4, 4, 3], mean=1, stddev=4)
    ...  a_G = tf.compat.v1.random_normal([1, 4, 4, 3], mean=1, stddev=4)
    ...  J_content = compute_cost(a_C, a_G)
    ...  assert abs(J_content.eval() - 7.0738883) < 1e-5

    (In the example, ``compute_cost`` stands for the cost function under test.)

    Returns:
        type: the ``MySession`` class. Its instances are context managers whose
        ``__enter__`` returns the entered ``tf.compat.v1.Session``.
    """
    import tensorflow as tf

    class MySession:
        def __init__(self, seed):
            # start from a fresh default graph before creating the session
            tf.compat.v1.reset_default_graph()
            self.tf_session = tf.compat.v1.Session()
            self.seed = seed

        def __enter__(self):
            entering_output = self.tf_session.__enter__()
            # set the seed only after entering, i.e. once the session's graph is the default
            tf.compat.v1.set_random_seed(self.seed)
            return entering_output

        def __exit__(self, exc_type, exc_value, traceback):
            # delegate cleanup, and any exception-suppression decision, to the wrapped session
            return self.tf_session.__exit__(exc_type, exc_value, traceback)

    return MySession
57
58
59
@pytest.fixture
def image_factory():
    """Build the production image factory, wired to load images from disk.

    The factory exposes ``from_disk(file_path, preprocess=True)``.

    Returns:
        ImageFactory: an ``ImageFactory`` whose loader is ``Disk.load_image``.
    """
    from artificial_artwork.image.image_factory import ImageFactory
    from artificial_artwork.disk_operations import Disk

    disk_image_loader = Disk.load_image
    factory = ImageFactory(disk_image_loader)
    return factory
71
72
73
@pytest.fixture
def termination_condition_module():
    """Expose the termination-condition Facility and Interface on a namespace-like class.

    Before returning, this asserts that the Facility registry holds exactly the
    'max-iterations', 'time-limit' and 'convergence' implementations.

    Returns:
        type: a class with ``facility`` and ``interface`` attributes.
    """
    from artificial_artwork.termination_condition.termination_condition import (
        TerminationConditionFacility,
        TerminationConditionInterface,
        MaxIterations,
        TimeLimit,
        Convergence,
    )

    expected_registry = {
        'max-iterations': MaxIterations,
        'time-limit': TimeLimit,
        'convergence': Convergence,
    }
    # every test relies on the Facility already knowing these TerminationCondition implementations
    assert TerminationConditionFacility.class_registry.subclasses == expected_registry

    namespace = {
        'facility': TerminationConditionFacility,
        'interface': TerminationConditionInterface,
    }
    return type('M', (), namespace)
88
89
90
@pytest.fixture
def termination_condition(termination_condition_module):
    """Provide a factory that creates termination conditions by registered type name.

    The returned callable forwards ``term_cond_type`` and any extra arguments to
    ``termination_condition_module.facility.create``.
    """
    tc_module = termination_condition_module

    def create_termination_condition(term_cond_type: str, *args, **kwargs) -> tc_module.interface:
        factory = tc_module.facility
        return factory.create(term_cond_type, *args, **kwargs)

    return create_termination_condition
95
 
96
97
@pytest.fixture
def subscribe():
    """Provide a helper that registers every listener on a broadcaster with one ``add`` call."""
    def _attach_all(broadcaster, listeners):
        subscribers = tuple(listeners)
        broadcaster.add(*subscribers)

    return _attach_all
102
103
104
105
@pytest.fixture
def broadcaster_class():
    """Return the ``TestSubject`` class, which drives a subject through iterations.

    ``TestSubject(subject, done_callback).iterate()`` works as follows. Until
    ``done_callback()`` is truthy, each pass prints the iteration index. It then sets
    ``subject.state`` to an object whose ``metrics['iterations']`` is the number of
    completed iterations, and calls ``subject.notify()``. It returns that count.
    """
    class TestSubject:
        def __init__(self, subject, done_callback):
            self.subject = subject
            self.done = done_callback

        def iterate(self):
            completed = 0
            while True:
                if self.done():
                    return completed
                # do something in the current iteration
                print('Iteration with index', completed)
                completed += 1
                # publish progress: `completed` iterations are now finished
                self.subject.state = type('Subject', (), {
                    'metrics': {'iterations': completed},
                })
                self.subject.notify()

    return TestSubject
127
128
129
@pytest.fixture
def toy_network_design():
    """Describe which layers of the 1-layer Toy network the NST algorithm uses.

    Returns:
        type: a class exposing ``network_layers`` (tuple of layer IDs),
        ``style_layers`` (list of ``(layer_id, weight)`` pairs with equal weights) and
        ``output_layer`` (the layer ID used for the Content cost).
    """
    # layers we pick to use for our Neural Network; the Toy Network has 1 Layer
    network_layers = ('conv1_1',)
    # equally weight all Style Layers
    weight = 1.0 / len(network_layers)
    # for the Toy Network Design, select all Network Layers to be Style Layers
    style_layers = [(layer_id, weight) for layer_id in network_layers]
    return type('ModelDesign', (), {
        # reuse the tuple above rather than repeating the literal, so they cannot drift apart
        'network_layers': network_layers,
        'style_layers': style_layers,
        # the last network layer doubles as the output (Content) layer
        'output_layer': network_layers[-1],
    })
143
144
145
@pytest.fixture
def image_manager_class():
    """Expose the ``ImageManager`` class itself (not an instance) to tests."""
    from artificial_artwork.nst_image import ImageManager as image_manager_cls

    return image_manager_cls
149
150
151
## Supported pretrained models and their expected layers
152
153
@pytest.fixture
def vgg_layers():
    """Return the layer IDs of the production VGG image model, in network order.

    The table pairs each layer's position in the network with its ID. Trailing
    comments record the conv / fully-connected weight shapes. Only the IDs are
    returned, as a tuple of 43 strings.
    """
    indexed_vgg_layers = (
        (0, 'conv1_1'),   # (3, 3, 3, 64)
        (1, 'relu1_1'),
        (2, 'conv1_2'),   # (3, 3, 64, 64)
        (3, 'relu1_2'),
        (4, 'pool1'),     # maxpool
        (5, 'conv2_1'),   # (3, 3, 64, 128)
        (6, 'relu2_1'),
        (7, 'conv2_2'),   # (3, 3, 128, 128)
        (8, 'relu2_2'),
        (9, 'pool2'),
        (10, 'conv3_1'),  # (3, 3, 128, 256)
        (11, 'relu3_1'),
        (12, 'conv3_2'),  # (3, 3, 256, 256)
        (13, 'relu3_2'),
        (14, 'conv3_3'),  # (3, 3, 256, 256)
        (15, 'relu3_3'),
        (16, 'conv3_4'),  # (3, 3, 256, 256)
        (17, 'relu3_4'),
        (18, 'pool3'),
        (19, 'conv4_1'),  # (3, 3, 256, 512)
        (20, 'relu4_1'),
        (21, 'conv4_2'),  # (3, 3, 512, 512)
        (22, 'relu4_2'),
        (23, 'conv4_3'),  # (3, 3, 512, 512)
        (24, 'relu4_3'),
        (25, 'conv4_4'),  # (3, 3, 512, 512)
        (26, 'relu4_4'),
        (27, 'pool4'),
        (28, 'conv5_1'),  # (3, 3, 512, 512)
        (29, 'relu5_1'),
        (30, 'conv5_2'),  # (3, 3, 512, 512)
        (31, 'relu5_2'),
        (32, 'conv5_3'),  # (3, 3, 512, 512)
        (33, 'relu5_3'),
        (34, 'conv5_4'),  # (3, 3, 512, 512)
        (35, 'relu5_4'),
        (36, 'pool5'),
        (37, 'fc6'),      # fullyconnected (7, 7, 512, 4096)
        (38, 'relu6'),
        (39, 'fc7'),      # fullyconnected (1, 1, 4096, 4096)
        (40, 'relu7'),
        (41, 'fc8'),      # fullyconnected (1, 1, 4096, 1000)
        (42, 'prob'),     # softmax
    )

    # drop the positional index, keep only the layer ID
    return tuple(layer_id for _, layer_id in indexed_vgg_layers)
0 ignored issues
show
Comprehensibility Best Practice introduced by
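The variable layer_id does not seem to be defined. (This is a false positive: `layer_id` is the loop variable bound by the generator expression `for _, layer_id in ...` in the `vgg_layers` fixture's return statement.)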
Loading history...
203
204
205
import os

# Configured location of the production pretrained image model, taken from the
# AA_VGG_19 environment variable. It falls back to a sentinel string when the
# variable is unset. It is read once, when this module is imported, so later
# changes to the environment do not affect it.
PRODUCTION_IMAGE_MODEL = os.getenv('AA_VGG_19', 'PRETRAINED_MODEL_NOT_FOUND')
207
208
209
@pytest.fixture
def pre_trained_models_1(vgg_layers, toy_model_data, toy_network_design):
    """Register a 'toy' model handler and describe the NST models available to tests.

    Importing ``artificial_artwork.pre_trained_models.vgg`` equips
    ``ModelHandlerFacility`` with the 'vgg' implementation. A 'toy' handler is then
    registered. It subclasses ``VggModelHandler``, and its model routines load layers
    through the ``toy_model_data`` loader instead of the stored production weights.

    Args:
        vgg_layers: production VGG layer IDs; unused while the 'vgg' entry is disabled.
        toy_model_data: ``(layers_loader, layer_ids)`` pair describing the Toy network.
        toy_network_design: provides ``network_layers``, ``style_layers`` and
            ``output_layer`` for the Toy network.

    Returns:
        dict: maps ``'toy'`` to a class with a ``pretrained_model`` attribute and a
        ``network_design`` attribute. ``pretrained_model`` is itself a class with
        ``expected_layers``, ``id`` and ``handler``.
    """
    from numpy.typing import NDArray
    from artificial_artwork.production_networks import NetworkDesign
    from artificial_artwork.pretrained_model import ModelHandlerFacility

    load_toy_layers: t.Callable[..., NDArray] = toy_model_data[0]
    toy_layer_ids: t.List[str] = toy_model_data[1]

    # imported for its side effect: equips ModelHandlerFacility with the 'vgg' implementation
    from artificial_artwork.pre_trained_models import vgg  # noqa: F401
    # the 'toy' implementation below builds on the vgg routines and handler
    from artificial_artwork.pre_trained_models.vgg import VggModelRoutines, VggModelHandler

    class ToyModelRoutines(VggModelRoutines):
        # override only the operation that reads the Prod Pretrained Stored Layers/Weights
        def load_layers(self, file_path: str):
            return load_toy_layers(file_path)

    routines = ToyModelRoutines()

    @ModelHandlerFacility.factory.register_as_subclass('toy')
    class ToyModelHandler(VggModelHandler):
        def _load_model_layers(self):
            return routines.load_layers('')['layers'][0]

        @property
        def model_routines(self):
            return routines

    toy_pretrained_model = type('PTM', (), {
        'expected_layers': toy_layer_ids,  # t.List[str]
        'id': 'toy',
        'handler': ModelHandlerFacility.create('toy'),
    })
    toy_design = NetworkDesign(
        toy_network_design.network_layers,
        toy_network_design.style_layers,
        toy_network_design.output_layer,
    )
    # A 'vgg' entry (the production handler plus NetworkDesign.from_default_vgg())
    # is currently disabled; only the 'toy' model is offered.
    return {
        'toy': type('NSTModel', (), {
            'pretrained_model': toy_pretrained_model,
            'network_design': toy_design,
        }),
    }
261
@pytest.fixture
def model(pre_trained_models_1):
    """Return the NST model description used by the tests: always the 'toy' entry.

    As diagnostic output, this prints the configured production model path and
    whether a file exists there. That check does not influence which model is
    returned.
    """
    import os

    print(f"\n -- PROD IM MODEL: {PRODUCTION_IMAGE_MODEL}")
    prod_model_found = os.path.isfile(PRODUCTION_IMAGE_MODEL)
    print(f"Selected Prod?: {prod_model_found}")

    # Choosing pre_trained_models_1['vgg'] when prod_model_found is true is currently disabled.
    return pre_trained_models_1['toy']
272
273
274
275
# CONSTANT DATA representing the layers (ie weight values) of the Toy Network
@pytest.fixture
def toy_model_data():
    """Provide a loader for the 3-layer Toy network's weights, plus its layer IDs.

    Returns:
        tuple: ``(load_layers, model_layers)``. ``load_layers(*args)`` returns the Toy
        weights nested the same way as the production pretrained VGG data.
        ``model_layers`` is the tuple of Toy layer IDs.
    """
    import numpy as np

    # This data format emulates the format the production pretrained VGG layer
    # IDs are stored in
    model_layers = (
        'conv1_1',
        'relu1',
        'maxpool1',
    )
    convo_w_weights_shape = (3, 3, 3, 4)

    def load_layers(*args):
        """Load Layers of 3-layered Toy Neural Net, emulating prod VGG format.

        It emulates what the production implementation (scipy.io.loadmat) does,
        by returning an object following the same interface as the one returned
        by scipy.io.loadmat, when called on the file storing the production
        pretrained VGG model. Positional arguments (eg a file path) are ignored.
        New arrays are built on every call.
        """
        # print a marker, so that test code using this fixture can verify (eg via
        # pytest's captured output) that the mock loader was called
        print("VGG Mat Weights Mock Loader Called")

        # 'A' Matrix weights tensor with shape (3, 3, 3, 4), ie 3*3*3*4 = 108 values;
        # for this toy Conv Layer the tensor values are 1, 2, 3, ..., 108
        conv_weights = np.arange(
            1, np.prod(convo_w_weights_shape) + 1, dtype=np.float32,
        ).reshape(convo_w_weights_shape)
        # 'b' bias vector of shape (1,); for this toy Conv Layer the bias value is 5
        conv_bias = np.array([5], dtype=np.float32)

        return {
            'layers': [[
                # 1st Layer: conv1_1
                [[[[model_layers[0]], 'unused', [[conv_weights, conv_bias]]]]],
                # 2nd Layer: relu1 -- weights not expected to be used (not a Conv layer)
                [[[[model_layers[1]], 'unused', [['W', 'b']]]]],
                # 3rd Layer: maxpool1 -- weights not expected to be used (not a Conv layer)
                [[[[model_layers[2]], 'unused', [['W', 'b']]]]],
            ]]
        }

    return load_layers, model_layers
321
322
323
def _make_toy_model_handler_class(layers_loader, routines_base, handler_base):
    """Return a ``handler_base`` subclass whose pretrained layers come from ``layers_loader``.

    Only the operations that read the stored production layers/weights are
    overridden. Everything else is inherited from ``routines_base`` /
    ``handler_base``.
    """
    class ToyModelRoutines(routines_base):
        # override only critical operations integrating with Prod Pretrained Stored Layers/Weights
        def load_layers(self, file_path: str):
            return layers_loader(file_path)

    toy_model_routines = ToyModelRoutines()

    class ToyModelHandler(handler_base):
        def _load_model_layers(self):
            return toy_model_routines.load_layers('')['layers'][0]

        @property
        def model_routines(self):
            return toy_model_routines

    return ToyModelHandler


# MONKEYPATCH the PROD NST ALGO at RUNTIME with an Algo using the Toy Network
@pytest.fixture
def toy_nst_algorithm(toy_model_data, toy_network_design, monkeypatch):
    """Provide a callable that patches the production NST algorithm to use the Toy network.

    Calling the returned ``_monkeypatch()``:

    1. replaces ``scipy.io.loadmat`` (used by the production VGG handler to read the
       pretrained weights) with the Toy layers loader;
    2. replaces ``vgg.VggModelHandler`` with a subclass serving the Toy layers;
    3. sets the ``AA_VGG_19`` environment variable, which the production code requires;
    4. replaces ``NetworkDesign.from_default_vgg`` so that it builds the Toy design;
    5. returns a handler created with ``ModelHandlerFacility.create('vgg')``.

    Every patch goes through ``monkeypatch`` and is undone when the test finishes.
    """
    from numpy.typing import NDArray

    toy_layers_loader: t.Callable[..., NDArray] = toy_model_data[0]

    def _monkeypatch():
        from artificial_artwork.production_networks import NetworkDesign
        from artificial_artwork.pretrained_model import ModelHandlerFacility
        # equip the Handler Facility Factory with the 'vgg' implementation
        from artificial_artwork.pre_trained_models import vgg
        import scipy.io

        # 1. if the prod VGG Handler tries to load the prod VGG weights, return Toy weights instead
        monkeypatch.setattr(scipy.io, 'loadmat', toy_layers_loader)  # Patch/replace-with-mock

        # 2. swap the prod VGG Handler class for one serving the Toy layers
        from artificial_artwork.pre_trained_models.vgg import VggModelRoutines, VggModelHandler
        toy_handler_class = _make_toy_model_handler_class(
            toy_layers_loader, VggModelRoutines, VggModelHandler)
        monkeypatch.setattr(vgg, 'VggModelHandler', toy_handler_class)  # Patch/replace-with-mock

        # 3. the code strictly requires the AA_VGG_19 env var; setenv restores it after the test
        monkeypatch.setenv('AA_VGG_19', 'unit-tests-toy-value')  # Patch/replace-with-mock

        # 4. Prod Code uses the 'default' factory (classmethod) of NetworkDesign to build a
        # NetworkDesign according to the 'Original' NST Algorithm (which layers to pick for
        # creating ReLUs from their pretrained Conv A, b weights, and which is the Output Layer)
        monkeypatch.setattr(NetworkDesign, 'from_default_vgg',
            lambda: NetworkDesign(
                toy_network_design.network_layers,  # full list of layer IDs available in Pretrained Model
                toy_network_design.style_layers,  # (layer ID, coefficient) pairs weighting each layer's Style Cost contribution
                toy_network_design.output_layer,  # layer ID used for Content Loss
            )
        )

        # 5. for convenience, construct a ModelHandler instance here and hand it to test code.
        # NOTE(review): whether create('vgg') picks up the patched VggModelHandler depends on
        # how the facility resolves 'vgg' -- TODO confirm, and remove the need for this.
        toy_model_handler = ModelHandlerFacility.create('vgg')  # handler instances are stateless, and lightweight
        return toy_model_handler

    return _monkeypatch
387