Passed
Pull Request — master (#6)
by Konstantinos
03:56
created

test_algorithm   A

Complexity

Total Complexity 13

Size/Duplication

Total Lines 144
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 13
eloc 100
dl 0
loc 144
rs 10
c 0
b 0
f 0

10 Functions

Rating   Name   Duplication   Size   Complexity  
A get_algorithm_runner() 0 8 1
A image_file_names() 0 5 1
A algorithm_parameters_class() 0 4 1
A create_production_algorithm_runner() 0 21 2
A get_model_design() 0 7 1
A image_manager() 0 9 3
A max_iterations_adapter_factory_method() 0 6 1
A create_algorithm() 0 10 1
A algorithm() 0 6 1
A test_nst_runner() 0 38 1
1
import pytest
2
3
from click.testing import CliRunner
4
from artificial_artwork.cli import cli
5
from unittest.mock import patch
6
7
8
# Module-level Click CLI test runner.
# NOTE(review): no fixture or test visible in this section references it — confirm usage before removing.
runner = CliRunner()
9
10
@pytest.fixture
def image_file_names():
    """Provide the file names of the content and style test images.

    Returns a dynamically created ``Images`` class whose ``content`` and
    ``style`` class attributes hold the respective image file names.
    """
    file_names = {
        'content': 'canoe_water_w300-h225.jpg',
        'style': 'blue-red_w300-h225.jpg',
    }
    return type('Images', (), file_names)
16
17
18
@pytest.fixture
def image_manager(image_manager_class):
    """Production ImageManager instance.

    Instantiates the class supplied by the ``image_manager_class`` fixture
    with a list of two preprocessing callables, in this order:

    1. reshape the matrix to its own shape prefixed with a length-1 dimension
    2. subtract the ``means`` array from the matrix
    """
    import numpy as np
    from artificial_artwork.image.image_operations import reshape_image, subtract

    # NOTE(review): presumably per-channel mean pixel values of the
    # pretrained network's training data — confirm.
    means = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))

    def prepend_unit_dimension(matrix):
        return reshape_image(matrix, ((1,) + matrix.shape))

    def subtract_means(matrix):
        # input image must have 3 channels!
        return subtract(matrix, means)

    preprocessing_pipeline = [prepend_unit_dimension, subtract_means]
    return image_manager_class(preprocessing_pipeline)
28
29
30
@pytest.fixture
def max_iterations_adapter_factory_method():
    """Provide a factory of 'max-iterations' termination condition adapters.

    The returned callable takes the number of iterations and forwards it to
    ``TerminationConditionAdapterFactory.create`` together with the
    'max-iterations' key.
    """
    from artificial_artwork.termination_condition_adapter_factory import (
        TerminationConditionAdapterFactory as AdapterFactory,
    )

    def _max_iterations_adapter(iterations):
        adapter = AdapterFactory.create('max-iterations', iterations)
        return adapter

    return _max_iterations_adapter
36
37
38
@pytest.fixture
def algorithm_parameters_class():
    """Provide the algorithm parameters class of ``artificial_artwork.algorithm``."""
    # The misspelled identifier is the name imported from the package;
    # it must be referenced exactly as written.
    from artificial_artwork.algorithm import (
        AlogirthmParameters as parameters_class,
    )

    return parameters_class
42
43
44
@pytest.fixture
def algorithm(algorithm_parameters_class):
    """Provide a factory that builds ``NSTAlgorithm`` objects.

    The returned callable passes all its positional arguments to
    ``algorithm_parameters_class`` and wraps the resulting parameters object
    in an ``NSTAlgorithm``.
    """
    from artificial_artwork.algorithm import NSTAlgorithm

    def _build_algorithm(*parameters):
        algorithm_parameters = algorithm_parameters_class(*parameters)
        return NSTAlgorithm(algorithm_parameters)

    return _build_algorithm
50
51
52
@pytest.fixture
def create_algorithm(algorithm, tmpdir):
    """Provide a factory that creates an algorithm from an image manager.

    The returned callable invokes the ``algorithm`` fixture with, in order:
    the image manager's content image, its style image, the given termination
    condition adapter and the test's ``tmpdir``.
    """
    def _algorithm_from(image_manager, termination_condition_adapter):
        positional_parameters = (
            image_manager.content_image,
            image_manager.style_image,
            termination_condition_adapter,
            tmpdir,
        )
        return algorithm(*positional_parameters)

    return _algorithm_from
62
63
64
@pytest.fixture
def create_production_algorithm_runner():
    """Provide a factory of fully wired ``NSTAlgorithmRunner`` objects.

    The returned callable takes a termination condition adapter and:

    * builds a runner through ``NSTAlgorithmRunner.default``, passing a
      callable that applies ``noisy`` with a ratio of 0.6
    * adds the adapter to the runner's ``progress_subject``
    * adds a ``StylingObserver`` (constructed with ``Disk.save_image`` and
      ``convert_to_uint8``) to the runner's ``persistance_subject``
    """
    from artificial_artwork.nst_tf_algorithm import NSTAlgorithmRunner
    from artificial_artwork.image.image_operations import noisy, convert_to_uint8
    from artificial_artwork.styling_observer import StylingObserver
    from artificial_artwork.disk_operations import Disk

    noisy_ratio = 0.6

    def _build_runner(termination_condition_adapter):
        def apply_noise(matrix):
            return noisy(matrix, noisy_ratio)

        nst_runner = NSTAlgorithmRunner.default(apply_noise)
        nst_runner.progress_subject.add(termination_condition_adapter)

        styling_observer = StylingObserver(Disk.save_image, convert_to_uint8)
        nst_runner.persistance_subject.add(styling_observer)
        return nst_runner

    return _build_runner
85
86
87
@pytest.fixture
def get_algorithm_runner(create_production_algorithm_runner):
    """Provide a callable that returns an algorithm runner for an adapter.

    The returned callable simply delegates to the
    ``create_production_algorithm_runner`` fixture, forwarding the given
    termination condition adapter.
    """
    def _runner_for(termination_condition_adapter):
        return create_production_algorithm_runner(termination_condition_adapter)

    return _runner_for
95
96
97
@pytest.fixture
def get_model_design():
    """Provide a factory of ad-hoc ``ModelDesign`` classes.

    The returned callable creates a new class named ``ModelDesign`` whose
    ``pretrained_model`` class attribute is the given handler and whose
    ``network_design`` class attribute is the given network design.
    """
    def _make_model_design(handler, network_design):
        class_attributes = {
            'pretrained_model': handler,
            'network_design': network_design,
        }
        return type('ModelDesign', (), class_attributes)

    return _make_model_design
104
105
106
def test_nst_runner(
    get_algorithm_runner,
    create_algorithm,
    image_file_names,
    get_model_design,
    max_iterations_adapter_factory_method,
    image_manager,
    test_image,
    model,
    tmpdir,
):
    """Test nst algorithm runner.

    Runs a simple 'smoke test' by iterating only 3 times.

    Loads the content and style images, runs the algorithm and then verifies
    that an output file exists in ``tmpdir`` for iteration 1 and for the
    final iteration.
    """
    import os

    ITERATIONS = 3

    # Load the content image first, then the style image.
    for image_role in ('content', 'style'):
        image_path = test_image(getattr(image_file_names, image_role))
        image_manager.load_from_disk(image_path, image_role)

    # Explicit comparison kept on purpose: a bare truthiness check would also
    # pass for any truthy non-bool value.
    assert image_manager.images_compatible == True  # noqa: E712

    termination_condition_adapter = max_iterations_adapter_factory_method(ITERATIONS)
    algorithm_runner = get_algorithm_runner(termination_condition_adapter)
    algorithm = create_algorithm(image_manager, termination_condition_adapter)

    pretrained_model_handler = model.pretrained_model.handler
    model_design = get_model_design(pretrained_model_handler, model.network_design)
    model_design.pretrained_model.load_model_layers()
    algorithm_runner.run(algorithm, model_design)

    def expected_output_path(iteration):
        file_name = '{}+{}-{}.png'.format(
            image_file_names.content,
            image_file_names.style,
            iteration,
        )
        return os.path.join(tmpdir, file_name)

    assert os.path.isfile(expected_output_path(1))
    assert os.path.isfile(expected_output_path(ITERATIONS))
144