Passed
Pull Request — master (#1)
by Konstantinos
59s
created

conftest.image_factory()   A

Complexity

Conditions 1

Size

Total Lines 12
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 12
rs 10
c 0
b 0
f 0
cc 1
nop 0
1
import pytest
2
3
4
@pytest.fixture
def test_suite():
    """Return the directory that contains this conftest file.

    The file location is first resolved with ``os.path.realpath`` and the
    directory part of that resolved path is returned as a string.
    """
    from os import path

    this_file = path.realpath(__file__)
    return path.dirname(this_file)
9
10
11
@pytest.fixture
def test_image(test_suite):
    """Provide a helper that builds paths to files in the suite's 'data' folder.

    Args:
        test_suite: directory path supplied by the ``test_suite`` fixture.

    Returns:
        callable: ``get_image_file_path(file_name)`` which joins
        ``test_suite``, ``'data'`` and ``file_name`` into a single path.
    """
    from os import path

    # Leading path components shared by every lookup.
    data_dir_parts = (test_suite, 'data')

    def get_image_file_path(file_name):
        return path.join(*data_dir_parts, file_name)

    return get_image_file_path
17
18
19
@pytest.fixture
def disk():
    """Return the ``Disk`` class from ``artificial_artwork.disk_operations``.

    The import happens inside the fixture, so the package is only loaded by
    tests that actually request this fixture.
    """
    from artificial_artwork.disk_operations import Disk as disk_class

    return disk_class
23
24
25
@pytest.fixture
def session():
    """Provide a Tensorflow v1 session wrapper whose seed is chosen at runtime.

    The fixture returns the ``MySession`` class itself (not an instance); a
    test picks the seed by calling it and uses the result as a context
    manager:

    >>> import tensorflow as tf
    >>> with session(2) as test:
    ...  a_C = tf.compat.v1.random_normal([1, 4, 4, 3], mean=1, stddev=4)
    ...  a_G = tf.compat.v1.random_normal([1, 4, 4, 3], mean=1, stddev=4)
    ...  J_content = compute_cost(a_C, a_G)
    ...  assert abs(J_content.eval() - 7.0738883) < 1e-5

    Returns:
        type: the ``MySession`` class; ``MySession(seed)`` is a context
        manager wrapping a ``tf.compat.v1.Session`` with the given seed set
        on entry.
    """
    import tensorflow as tf

    class MySession:
        """Context manager around a newly created ``tf.compat.v1.Session``."""

        def __init__(self, seed):
            # The default graph is reset before the session is created.
            tf.compat.v1.reset_default_graph()
            self.tf_session = tf.compat.v1.Session()
            self.seed = seed

        def __enter__(self):
            # Enter the wrapped session first, then apply the seed.
            entering_output = self.tf_session.__enter__()
            tf.compat.v1.set_random_seed(self.seed)
            return entering_output

        def __exit__(self, exc_type, exc_value, traceback):
            # Delegate to the wrapped session and forward its return value so
            # the context-manager protocol behaves exactly like the wrapped
            # object's (parameter renamed to avoid shadowing builtin `type`).
            return self.tf_session.__exit__(exc_type, exc_value, traceback)

    return MySession
54
55
56
# @pytest.fixture
57
# def default_image_processing_config():
58
#     from artificial_artwork.image import ImageProcessingConfig
59
#     return ImageProcessingConfig.from_image_dimensions()
60
61
62
@pytest.fixture
def image_factory():
    """Build an ``ImageFactory`` wired to the production image loader.

    The factory is constructed with ``Disk.load_image`` as its only argument.
    NOTE(review): the original note says the factory exposes
    ``from_disk(file_path, preprocess=True)``; that method is not visible in
    this file - confirm against ``ImageFactory``.

    Returns:
        ImageFactory: an instance of the ImageFactory class
    """
    from artificial_artwork.disk_operations import Disk
    from artificial_artwork.image.image_factory import ImageFactory

    image_loader = Disk.load_image
    return ImageFactory(image_loader)
74
75
76
@pytest.fixture
def termination_condition_module():
    """Expose the termination-condition facility and interface as one object.

    Before returning, the fixture asserts that the facility's class registry
    holds exactly the three expected implementations.

    Returns:
        type: a class named ``M`` with two attributes, ``facility``
        (``TerminationConditionFacility``) and ``interface``
        (``TerminationConditionInterface``).
    """
    from artificial_artwork.termination_condition.termination_condition import (
        Convergence,
        MaxIterations,
        TerminationConditionFacility,
        TerminationConditionInterface,
        TimeLimit,
    )

    expected_registry = {
        'max-iterations': MaxIterations,
        'time-limit': TimeLimit,
        'convergence': Convergence,
    }
    # all tests require that the Facility already contains some implementations of TerminationCondition
    assert TerminationConditionFacility.class_registry.subclasses == expected_registry

    namespace = {
        'facility': TerminationConditionFacility,
        'interface': TerminationConditionInterface,
    }
    return type('M', (), namespace)
91
92
93
@pytest.fixture
def termination_condition(termination_condition_module):
    """Provide a factory function for termination-condition objects.

    Args:
        termination_condition_module: object from the fixture of the same
            name, carrying ``facility`` and ``interface`` attributes.

    Returns:
        callable: ``create_termination_condition(term_cond_type, *args,
        **kwargs)`` which forwards all arguments to ``facility.create``.
    """
    def create_termination_condition(term_cond_type: str, *args, **kwargs) -> termination_condition_module.interface:
        facility = termination_condition_module.facility
        return facility.create(term_cond_type, *args, **kwargs)

    return create_termination_condition
98
 
99
100
@pytest.fixture
def subscribe():
    """Provide a helper that attaches several listeners to a broadcaster.

    Returns:
        callable: a function taking ``(broadcaster, listeners)`` that calls
        ``broadcaster.add`` with the listeners unpacked as positional
        arguments and returns ``None``.
    """
    def _attach_listeners(broadcaster, listeners):
        """Pass every item of ``listeners`` to ``broadcaster.add`` in one call."""
        broadcaster.add(*listeners)

    return _attach_listeners
105
106
107
108
@pytest.fixture
def broadcaster_class():
    """Provide a small iterating class that notifies a subject on every step.

    Returns:
        type: the ``TestSubject`` class, constructed as
        ``TestSubject(subject, done_callback)``.
    """
    class TestSubject:
        """Run iterations until ``done_callback`` reports completion."""

        def __init__(self, subject, done_callback):
            self.subject = subject
            self.done = done_callback

        def iterate(self):
            """Loop until ``self.done()`` is truthy and return the iteration count.

            On each pass the current index is printed, ``subject.state`` is
            replaced with a new class whose ``metrics`` dict records the
            number of completed iterations, and ``subject.notify()`` is called.
            """
            completed = 0
            while not self.done():
                # do something in the current iteration
                print('Iteration with index', completed)

                # this iteration is finished: publish the updated count
                completed += 1
                state_attributes = {
                    'metrics': {'iterations': completed},
                }
                self.subject.state = type('Subject', (), state_attributes)
                self.subject.notify()
            return completed

    return TestSubject
130