| Total Complexity | 3 |
| Total Lines | 34 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | import os |
||
| 2 | import pytest |
||
| 3 | |||
| 4 | |||
# Absolute path of the directory holding this file, with symlinks resolved.
# os.path.split(p)[0] is the same value os.path.dirname(p) returns.
my_dir = os.path.split(os.path.realpath(__file__))[0]
||
| 6 | |||
| 7 | |||
@pytest.fixture
def graph_factory():
    """Supply ``graph_factory`` from ``artificial_artwork.style_model`` to tests.

    The import happens when the fixture runs, not when this module is
    collected. Tests use the returned object through its ``create`` attribute.
    """
    # Keep the ``from ... import`` form: it also resolves ``graph_factory``
    # if it happens to be a submodule rather than a plain attribute.
    from artificial_artwork.style_model import graph_factory as factory
    return factory
||
| 12 | |||
| 13 | |||
def test_pretrained_model(model, graph_factory):
    """Check the pretrained model's layers and the graph built from them.

    First, the layers returned by the pretrained-model handler must match
    ``model.pretrained_model.expected_layers`` in count and in order (by name).
    Second, the test builds a graph with ``graph_factory.create``. Its keys
    must be exactly ``'input'`` plus the entries of
    ``model.network_design.network_layers``.

    Args:
        model: fixture defined outside this view (presumably in a conftest).
            It exposes ``pretrained_model`` (with ``handler`` and
            ``expected_layers``) and ``network_design`` (with
            ``network_layers``).
        graph_factory: fixture returning the object that provides ``create``.
    """
    from types import SimpleNamespace

    pretrained = model.pretrained_model
    expected_layers = pretrained.expected_layers
    layers = pretrained.handler.load_model_layers()

    # Minimal stand-in for an image-spec object: only these three attributes
    # are supplied, presumably all that create() reads from it.
    image_specs = SimpleNamespace(width=400, height=300, color_channels=3)

    assert len(layers) == len(expected_layers)
    for index, (layer, expected_name) in enumerate(zip(layers, expected_layers)):
        # NOTE(review): the [0][0][0][0] path into each layer entry mirrors the
        # nested layout of the loaded weights file — confirm against the loader.
        actual_name = layer[0][0][0][0]
        assert actual_name == expected_name, (
            f"layer {index}: expected {expected_name!r}, got {actual_name!r}"
        )

    # NOTE(review): this assigns onto the handler owned by the `model` fixture
    # and is never undone; if that fixture is shared beyond this test, later
    # tests will observe this `reporter` value.
    pretrained.handler.reporter = layers

    # Same attribute-style stand-in as image_specs, for the model design.
    model_design = SimpleNamespace(
        pretrained_model=pretrained.handler,
        network_design=model.network_design,
    )
    graph = graph_factory.create(image_specs, model_design)

    expected_keys = {'input', *model.network_design.network_layers}
    assert set(graph.keys()) == expected_keys
| 34 |