Passed
Pull Request — master (#6)
by Konstantinos
03:56
created

test_algorithm.test_nst_runner()   A

Complexity

Conditions 1

Size

Total Lines 38
Code Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 26
dl 0
loc 38
rs 9.256
c 0
b 0
f 0
cc 1
nop 9

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you later need more or different data.

There are several approaches to avoid long parameter lists:

1
import pytest
2
3
from click.testing import CliRunner
4
from artificial_artwork.cli import cli
5
from unittest.mock import patch
6
7
8
# Module-level click CliRunner instance.
# NOTE(review): neither `runner` nor the imported `cli` / `patch` is referenced
# by any fixture or test visible in this module — confirm before removing.
runner = CliRunner()
9
10
@pytest.fixture
def image_file_names():
    """File names of the content and style images used by the test.

    The fixture value is a dynamically created class named ``Images`` whose
    ``content`` and ``style`` class attributes hold the two file names.
    """
    names = {
        'content': 'canoe_water_w300-h225.jpg',
        'style': 'blue-red_w300-h225.jpg',
    }
    return type('Images', (), names)
16
17
18
@pytest.fixture
def image_manager(image_manager_class):
    """Production ImageManager instance.

    The manager is constructed with a two-step preprocessing pipeline:

    1. prepend a leading axis of length 1 to the image matrix
       (presumably the batch axis — confirm against ImageManager usage);
    2. subtract a fixed per-channel value from the last axis.
    """
    import numpy as np
    from artificial_artwork.image.image_operations import reshape_image, subtract

    # Shape (1, 1, 1, 3) so the values broadcast over the last (channel) axis.
    # NOTE(review): these look like the commonly used ImageNet/VGG RGB means — confirm.
    channel_means = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))

    def prepend_unit_axis(matrix):
        return reshape_image(matrix, (1,) + matrix.shape)

    def subtract_channel_means(matrix):
        # input image must have 3 channels!
        return subtract(matrix, channel_means)

    pipeline = [prepend_unit_axis, subtract_channel_means]
    return image_manager_class(pipeline)
28
29
30
@pytest.fixture
def max_iterations_adapter_factory_method():
    """Provide a factory for 'max-iterations' termination condition adapters.

    The returned callable takes an iteration count and forwards it to
    ``TerminationConditionAdapterFactory.create('max-iterations', ...)``.
    """
    from artificial_artwork.termination_condition_adapter_factory import (
        TerminationConditionAdapterFactory,
    )

    def _make_adapter(iterations):
        return TerminationConditionAdapterFactory.create('max-iterations', iterations)

    return _make_adapter
36
37
38
@pytest.fixture
def algorithm_parameters_class():
    """Expose the algorithm parameters class from ``artificial_artwork.algorithm``.

    NOTE: the import uses the project's spelling ``AlogirthmParameters``;
    it must match the name defined in that module.
    """
    from artificial_artwork.algorithm import AlogirthmParameters as parameters_class
    return parameters_class
42
43
44
@pytest.fixture
def algorithm(algorithm_parameters_class):
    """Provide a factory that builds ``NSTAlgorithm`` instances.

    The factory forwards its positional arguments unchanged to the algorithm
    parameters class and wraps the resulting object in ``NSTAlgorithm``.
    """
    from artificial_artwork.algorithm import NSTAlgorithm

    def _build(*parameters):
        algorithm_parameters = algorithm_parameters_class(*parameters)
        return NSTAlgorithm(algorithm_parameters)

    return _build
50
51
52
@pytest.fixture
def create_algorithm(algorithm, tmpdir):
    """Provide a factory that builds an NST algorithm from loaded images.

    The factory reads ``content_image`` and ``style_image`` from the given
    image manager and passes them, the termination condition adapter and
    pytest's ``tmpdir`` (in that order) to the ``algorithm`` factory.
    ``tmpdir`` is presumably the output directory; test_nst_runner checks
    for result images there.
    """
    def _create_algorithm(image_manager, termination_condition_adapter):
        content_image = image_manager.content_image
        style_image = image_manager.style_image
        return algorithm(content_image, style_image, termination_condition_adapter, tmpdir)

    return _create_algorithm
62
63
64
@pytest.fixture
def create_production_algorithm_runner():
    """Provide a factory for a fully wired ``NSTAlgorithmRunner``.

    Each runner built by the factory:

    * is created via ``NSTAlgorithmRunner.default`` with a callable applying
      ``noisy(matrix, 0.6)`` (presumably used to initialise the generated
      image — confirm in nst_tf_algorithm);
    * has the given termination condition adapter added to its
      ``progress_subject``;
    * has a ``StylingObserver(Disk.save_image, convert_to_uint8)`` added to
      its ``persistance_subject``.
    """
    from artificial_artwork.nst_tf_algorithm import NSTAlgorithmRunner
    from artificial_artwork.image.image_operations import noisy, convert_to_uint8
    from artificial_artwork.styling_observer import StylingObserver
    from artificial_artwork.disk_operations import Disk

    noise_ratio = 0.6

    def add_noise(matrix):
        return noisy(matrix, noise_ratio)

    def _create_production_algorithm_runner(termination_condition_adapter):
        nst_runner = NSTAlgorithmRunner.default(add_noise)
        nst_runner.progress_subject.add(termination_condition_adapter)
        styling_observer = StylingObserver(Disk.save_image, convert_to_uint8)
        nst_runner.persistance_subject.add(styling_observer)
        return nst_runner

    return _create_production_algorithm_runner
85
86
87
@pytest.fixture
def get_algorithm_runner(create_production_algorithm_runner):
    """Provide a callable that returns a production algorithm runner.

    It delegates directly to ``create_production_algorithm_runner`` with the
    supplied termination condition adapter.
    """
    def _get_algorithm_runner(termination_condition_adapter):
        return create_production_algorithm_runner(termination_condition_adapter)

    return _get_algorithm_runner
95
96
97
@pytest.fixture
def get_model_design():
    """Provide a factory for lightweight model-design objects.

    The factory returns a dynamically created ``ModelDesign`` class whose
    ``pretrained_model`` class attribute is *handler* and whose
    ``network_design`` class attribute is *network_design*.
    """
    def _get_model_design(handler, network_design):
        attributes = {
            'pretrained_model': handler,
            'network_design': network_design,
        }
        return type('ModelDesign', (), attributes)

    return _get_model_design
104
105
106
def test_nst_runner(
    get_algorithm_runner,
    create_algorithm,
    image_file_names,
    get_model_design,
    max_iterations_adapter_factory_method,
    image_manager,
    test_image,
    model,
    tmpdir):
    """Test nst algorithm runner.

    Runs a simple 'smoke test' by iterating only 3 times. It then checks that
    an output image exists in ``tmpdir`` for the first and the last iteration.

    The long parameter list is dictated by pytest: every parameter is a
    fixture injected by name, so it cannot be shortened without introducing
    composite fixtures.
    """
    import os
    ITERATIONS = 3

    def expected_output_path(iteration):
        """Return the path of the image expected after *iteration* steps."""
        # Pass the file names as format *arguments* (not as part of the
        # template) so braces inside a file name cannot break formatting.
        file_name = '{}+{}-{}.png'.format(
            image_file_names.content, image_file_names.style, iteration)
        return os.path.join(tmpdir, file_name)

    image_manager.load_from_disk(test_image(image_file_names.content), 'content')
    image_manager.load_from_disk(test_image(image_file_names.style), 'style')

    # PEP 8: assert truthiness instead of comparing to True with ==.
    assert image_manager.images_compatible

    termination_condition_adapter = max_iterations_adapter_factory_method(ITERATIONS)

    algorithm_runner = get_algorithm_runner(termination_condition_adapter)

    algorithm = create_algorithm(image_manager, termination_condition_adapter)

    model_design = get_model_design(
        model.pretrained_model.handler,
        model.network_design,
    )
    model_design.pretrained_model.load_model_layers()
    algorithm_runner.run(algorithm, model_design)

    for iteration in (1, ITERATIONS):
        output_path = expected_output_path(iteration)
        assert os.path.isfile(output_path), 'missing output image: {}'.format(output_path)
144