Passed
Pull Request — master (#6)
by Konstantinos
03:56
created

artificial_artwork.cli.cli()   A

Complexity

Conditions 2

Size

Total Lines 48
Code Lines 31

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 31
dl 0
loc 48
rs 9.1359
c 0
b 0
f 0
cc 2
nop 4
1
import sys
2
import click
3
import numpy as np
4
5
from .disk_operations import Disk
6
from .styling_observer import StylingObserver
7
from .algorithm import NSTAlgorithm, AlogirthmParameters
8
from .nst_tf_algorithm import NSTAlgorithmRunner
9
from .termination_condition_adapter_factory import TerminationConditionAdapterFactory
10
from .nst_image import ImageManager, noisy, convert_to_uint8
11
from .production_networks import NetworkDesign
12
from .pretrained_model import ModelHandlerFacility
13
14
15
def load_pretrained_model_functions():
    """Import the bundled pretrained-model module(s) and return the VGG one.

    Only the ``vgg`` module of the ``pre_trained_models`` subpackage is
    imported today.

    Returns:
        module: the ``pre_trained_models.vgg`` module object.
    """
    # Future work: discover the modules inside the pre_trained_models
    # package dynamically instead of naming ``vgg`` explicitly.
    from .pre_trained_models import vgg as vgg_module

    return vgg_module
20
21
22
def _incompatible_images_report(content, style, content_shape, style_shape):
    """Build the multi-line report printed when the two images' shapes differ.

    Args:
        content: path/identifier of the content image as given by the user.
        style: path/identifier of the style image as given by the user.
        content_shape: shape of the loaded content image matrix.
        style_shape: shape of the loaded style image matrix.

    Returns:
        str: four newline-separated lines, ending with ``'Exiting..'``.
    """
    # Both halves of each message are f-strings so that the image paths are
    # actually interpolated (they were previously printed as literal braces).
    return '\n'.join([
        f"Given CONTENT image '{content}' has 'height' x 'width' x "
        f"'color_channels': {content_shape}",
        f"Given STYLE image '{style}' has 'height' x 'width' x "
        f"'color_channels': {style_shape}",
        'Expected to receive images (matrices) of identical shape',
        'Exiting..',
    ])


def read_images(content, style):
    """Load the content and style images and verify that they are compatible.

    Args:
        content: content image location, passed to
            ``ImageManager.load_from_disk`` under the ``'content'`` role.
        style: style image location, passed to
            ``ImageManager.load_from_disk`` under the ``'style'`` role.

    Returns:
        tuple: ``(content_image, style_image)`` as held by the image manager.

    Exits:
        Prints a report and calls ``sys.exit(1)`` when the image manager
        reports the two images as not compatible.
    """
    # TODO: find the means dynamically instead of hard-coding them.
    # NOTE(review): presumably one mean per color channel, handed to the
    # ImageManager for normalisation -- confirm against ImageManager.default.
    means = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))

    image_manager = ImageManager.default(means)

    # Possible future work: load each image in a separate thread and join.
    image_manager.load_from_disk(content, 'content')
    image_manager.load_from_disk(style, 'style')

    if not image_manager.images_compatible:
        print(_incompatible_images_report(
            content,
            style,
            image_manager.content_image.matrix.shape,
            image_manager.style_image.matrix.shape,
        ))
        sys.exit(1)

    return image_manager.content_image, image_manager.style_image
42
43
44
@click.command()
@click.argument('content_image')
@click.argument('style_image')
@click.option('--iterations', '-it', type=int, default=100, show_default=True,
              help='Maximum number of iterations to run the algorithm for.')
@click.option('--location', '-l', type=str, default='.')
def cli(content_image, style_image, iterations, location):
    """Apply the style of STYLE_IMAGE to CONTENT_IMAGE (Neural Style Transfer).

    \f
    Developer notes (hidden from ``--help`` by the form feed above):

    Args:
        content_image: content image location, forwarded to ``read_images``.
        style_image: style image location, forwarded to ``read_images``.
        iterations: iteration budget for the 'max-iterations' termination
            condition.
        location: forwarded to ``AlogirthmParameters``.
    """
    termination_condition_type = 'max-iterations'

    # Distinct names: the arguments are locations, these are loaded images.
    content, style = read_images(content_image, style_image)

    # Called only for its side effect: importing the vgg module is presumably
    # what makes ModelHandlerFacility.create('vgg') below work -- confirm.
    load_pretrained_model_functions()

    # An ad-hoc class used purely as an attribute namespace for the runner.
    model_design = type('ModelDesign', (), {
        'pretrained_model': ModelHandlerFacility.create('vgg'),
        'network_design': NetworkDesign.from_default_vgg(),
    })
    model_design.pretrained_model.load_model_layers()

    termination_condition = TerminationConditionAdapterFactory.create(
        termination_condition_type,
        iterations,
    )

    print(f' -- Termination Condition: {termination_condition.termination_condition}')

    algorithm_parameters = AlogirthmParameters(
        content,
        style,
        termination_condition,
        location,
    )

    algorithm = NSTAlgorithm(algorithm_parameters)

    # Ratio handed to noisy(); presumably controls how much noise is mixed
    # into the starting image matrix -- confirm against nst_image.noisy.
    noisy_ratio = 0.6

    algorithm_runner = NSTAlgorithmRunner.default(
        lambda matrix: noisy(matrix, noisy_ratio),
    )

    # The termination condition observes progress so it can stop the run.
    algorithm_runner.progress_subject.add(
        termination_condition,
    )
    algorithm_runner.persistance_subject.add(
        StylingObserver(Disk.save_image, convert_to_uint8)
    )

    algorithm_runner.run(algorithm, model_design)
92