Passed
Push — master ( 338098...c3d045 )
by Konstantinos
02:29 queued 01:15
created

test_cv_model.graph_factory()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 0
1
import os
2
import pytest
3
4
5
# Absolute path of the directory holding this test module (realpath resolves
# any symlinks before the final path component is stripped).
# NOTE(review): not referenced by the tests visible in this file — confirm it is
# still needed before removing it.
my_dir = os.path.split(os.path.realpath(__file__))[0]
6
7
8
@pytest.fixture
def graph_factory():
    """Provide ``graph_factory`` from ``artificial_artwork.style_model``.

    The import lives inside the fixture body, so it is executed only when a
    test requests this fixture rather than when the test module is imported.
    """
    from artificial_artwork.style_model import graph_factory as factory
    return factory
12
13
14
def test_pretrained_model(model, graph_factory):
    """Check the pretrained model's layers and the graph built from them.

    Loads the layers through ``model.pretrained_model.handler``, verifies that
    their count and order match ``model.pretrained_model.expected_layers``, then
    builds a graph with ``graph_factory.create`` and checks that its keys are
    exactly ``'input'`` plus the entries of
    ``model.network_design.network_layers``.

    Args:
        model: fixture defined outside this module (presumably in a
            ``conftest.py`` — confirm) exposing ``pretrained_model`` and
            ``network_design``.
        graph_factory: fixture providing the object whose ``create`` is tested.
    """
    from types import SimpleNamespace

    pretrained = model.pretrained_model
    expected_layers = pretrained.expected_layers
    layers = pretrained.handler.load_model_layers()

    # Lightweight attribute bag standing in for the image specifications.
    image_specs = SimpleNamespace(width=400, height=300, color_channels=3)

    assert len(layers) == len(expected_layers), (
        f'loaded {len(layers)} layers, expected {len(expected_layers)}'
    )
    for index, name in enumerate(expected_layers):
        # Each loaded entry nests its name several levels deep; presumably the
        # structure produced by scipy.io.loadmat — confirm against the handler.
        loaded_name = layers[index][0][0][0][0]
        assert loaded_name == name, (
            f'layer {index}: expected {name!r}, got {loaded_name!r}'
        )

    # NOTE(review): graph_factory.create presumably reads the loaded layers
    # through handler.reporter — confirm.
    pretrained.handler.reporter = layers
    model_design = SimpleNamespace(
        pretrained_model=pretrained.handler,
        network_design=model.network_design,
    )
    graph = graph_factory.create(image_specs, model_design)

    expected_keys = {'input', *model.network_design.network_layers}
    assert set(graph.keys()) == expected_keys
34