Passed
Push — master ( c3d045...3525ae )
by Konstantinos
01:55 queued 43s
created

artificial_artwork._main   A

Complexity

Total Complexity 7

Size/Duplication

Total Lines 127
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 69
dl 0
loc 127
rs 10
c 0
b 0
f 0
wmc 7

5 Functions

Rating   Name   Duplication   Size   Complexity  
A _read_algorithm_input() 0 17 1
A _create_termination_condition() 0 8 1
A create_algo_runner() 0 20 2
A _create_algo_runner() 0 42 2
A _load_algorithm_architecture() 0 8 1
1
import os
2
import sys
3
4
from .disk_operations import Disk
5
from .styling_observer import StylingObserver
6
from .algorithm import NSTAlgorithm, AlogirthmParameters
7
from .nst_tf_algorithm import NSTAlgorithmRunner
8
from .termination_condition_adapter_factory import TerminationConditionAdapterFactory
9
from .nst_image import noisy, convert_to_uint8
10
from .production_networks import NetworkDesign
11
from .pretrained_model import ModelHandlerFacility
12
from .utils import load_pretrained_model_functions, read_images
13
14
from artificial_artwork import __version__
15
16
# Directory that contains this module: the path is made absolute first and
# then any symlinks are resolved before taking the parent directory.
this_file_location = os.path.dirname(
    os.path.realpath(os.path.abspath(__file__))
)

# Names exported by ``from artificial_artwork._main import *``.
__all__ = ["create_algo_runner"]
20
21
22
def create_algo_runner(
    iterations=100,
    output_folder='gui-output-folder',
    noisy_ratio=0.6,
):
    """Build a Neural Style Transfer runner and expose it as a small API.

    The termination condition and the algorithm runner are created once, up
    front. The returned mapping holds two callables that share them:

    - ``'run'(content_image, style_image)``: reads both images, loads the
      model architecture and runs the algorithm. Returns ``None``.
    - ``'subscribe'(observer)``: registers ``observer`` on the runner's
      ``progress_subject`` and returns whatever ``add`` returns.

    Args:
        iterations: iteration budget for the default termination condition.
        output_folder: location stored in the algorithm parameters of every
            run (presumably where generated images are written — confirm).
        noisy_ratio: ratio forwarded to ``noisy`` by the runner's noise
            callback.

    Returns:
        dict with the keys ``'run'`` and ``'subscribe'``.
    """
    stop_condition = _create_termination_condition(iterations)
    runner = _create_algo_runner(stop_condition, noisy_ratio=noisy_ratio)

    def run(content_image, style_image):
        # Inputs are read before the architecture is loaded; this keeps the
        # original call order.
        algorithm = _read_algorithm_input(
            content_image, style_image, stop_condition, output_folder
        )
        # NOTE(review): the model architecture is reloaded on every call.
        architecture = _load_algorithm_architecture()
        runner.run(algorithm, architecture)

    def subscribe(observer):
        return runner.progress_subject.add(observer)

    return {'run': run, 'subscribe': subscribe}
43
44
45
def _create_algo_runner(termination_condition, noisy_ratio=0.6):
    """Create an ``NSTAlgorithmRunner`` backed by a TF1-style session and wire its observers.

    The TensorFlow setup happens in this order: reset the default graph,
    disable eager execution, then open an ``InteractiveSession`` wrapped in a
    ``TensorflowSessionRunner``.

    Args:
        termination_condition: adapter that is subscribed to the runner's
            progress updates. Its ``termination_condition.max_iterations`` is
            also handed to the ``StylingObserver``.
        noisy_ratio: ratio passed to ``noisy`` by the runner's noise callback.

    Returns:
        The configured ``NSTAlgorithmRunner``.
    """
    # These dependencies are imported only when a runner is actually built.
    import tensorflow as tf
    from artificial_artwork.tf_session_runner import (
        TensorflowSessionRunnerSubject,
        TensorflowSessionRunner,
    )
    # NOTE(review): this local import shadows the module-level ``noisy``
    # imported from ``.nst_image``; confirm which implementation is intended.
    from artificial_artwork.image import noisy

    tf.compat.v1.reset_default_graph()
    tf.compat.v1.disable_eager_execution()
    session = tf.compat.v1.InteractiveSession()
    session_runner = TensorflowSessionRunner(
        TensorflowSessionRunnerSubject(session)
    )

    def add_noise(matrix):
        return noisy(matrix, noisy_ratio)

    runner = NSTAlgorithmRunner(session_runner, add_noise)

    # Subscribe the termination condition so that it receives the runner's
    # progress broadcasts. According to the original notes, the runner
    # broadcasts at a steady frequency and always on the first and last
    # iteration (e.g. 0, 20, 40, 60, 80, 100 for a 100-iteration run). Each
    # event carries a progress dict; see ``NSTAlgorithmRunner._progress`` in
    # ``artificial_artwork/nst_tf_algorithm.py`` for its keys.
    runner.progress_subject.add(termination_condition)

    # Subscribe a persistence observer so snapshots of the generated image
    # are saved to disk.
    runner.persistance_subject.add(
        StylingObserver(
            Disk.save_image,
            convert_to_uint8,
            termination_condition.termination_condition.max_iterations,
        )
    )
    return runner
87
88
# Kind of termination condition requested from the factory.
DEFAULT_TERMINATION_CONDITION = 'max-iterations'


def _create_termination_condition(nb_iterations_to_perform):
    """Create the default (``'max-iterations'``) termination-condition adapter.

    The wrapped condition is printed to stdout as a side effect.

    Args:
        nb_iterations_to_perform: forwarded to
            ``TerminationConditionAdapterFactory.create``.

    Returns:
        The adapter built by the factory.
    """
    adapter = TerminationConditionAdapterFactory.create(
        DEFAULT_TERMINATION_CONDITION, nb_iterations_to_perform
    )
    print(f' -- Termination Condition: {adapter.termination_condition}')
    return adapter
98
99
100
def _load_algorithm_architecture():
    """Load the pretrained VGG model together with the default VGG network design.

    ``load_pretrained_model_functions()`` is called first. The function then
    builds an ad-hoc ``ModelDesign`` class with the three-argument ``type()``
    call. That class has two class attributes:

    - ``pretrained_model``: ``ModelHandlerFacility.create('vgg')``
    - ``network_design``: ``NetworkDesign.from_default_vgg()``

    Finally, the pretrained model's layers are loaded.

    Returns:
        The ``ModelDesign`` class object itself (not an instance of it).
    """
    load_pretrained_model_functions()
    vgg_model = ModelHandlerFacility.create('vgg')
    vgg_design = NetworkDesign.from_default_vgg()
    attributes = {'pretrained_model': vgg_model, 'network_design': vgg_design}
    model_design = type('ModelDesign', (), attributes)
    model_design.pretrained_model.load_model_layers()
    return model_design
108
109
110
def _read_algorithm_input(content_image, style_image, termination_condition, location):
    """Read the two input images and bundle them into an ``NSTAlgorithm``.

    Args:
        content_image: content image reference handed to ``read_images``
            (per the original comment, a file path on disk).
        style_image: style image reference handed to ``read_images``.
        termination_condition: termination-condition adapter stored in the
            algorithm parameters. It is passed through unchanged.
        location: location stored in the algorithm parameters;
            ``create_algo_runner`` passes its ``output_folder`` here.

    Returns:
        ``NSTAlgorithm(AlogirthmParameters(content, style,
        termination_condition, location))``, where content and style are the
        loaded images.
    """
    loaded_content, loaded_style = read_images(content_image, style_image)
    parameters = AlogirthmParameters(
        loaded_content,
        loaded_style,
        termination_condition,
        location,
    )
    return NSTAlgorithm(parameters)
128