Passed
Pull Request — master (#6)
by Konstantinos
03:56
created

StylingObserver.update()   B

Complexity

Conditions 6

Size

Total Lines 28
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 18
dl 0
loc 28
rs 8.5666
c 0
b 0
f 0
cc 6
nop 3
1
import os
2
from typing import Callable
3
from attr import define
4
import numpy as np
5
import numpy.typing as npt
6
7
from .utils import Observer
8
9
10
@define
class StylingObserver(Observer):
    """Persist a PNG snapshot of the image under construction.

    Each call to :meth:`update` reads the notifying subject's ``state``,
    prepares its ``matrix`` for saving, and passes it to
    ``save_on_disk_callback``. The output path is built from the content and
    style image file names and the number of completed iterations.

    Attributes:
        save_on_disk_callback: Invoked as
            ``save_on_disk_callback(matrix, output_file_path, save_format='png')``.
        convert_to_unit8: Applied to the matrix when its dtype is not
            ``uint8``. The field name keeps its original spelling so that
            existing constructor calls keep working.
    """

    # Called positionally with (matrix, output_file_path) plus a
    # ``save_format`` keyword, so a fixed ``Callable[[...], None]`` parameter
    # list cannot describe it accurately.
    save_on_disk_callback: Callable[..., None]
    convert_to_unit8: Callable[[npt.NDArray], npt.NDArray]

    def update(self, *args, **kwargs):
        """Save the subject's current image under ``state.output_path``.

        Args:
            *args: ``args[0]`` is the notifying subject. Its ``state`` must
                expose ``output_path``, ``content_image_path``,
                ``style_image_path``, ``metrics['iterations']`` and
                ``matrix``.
            **kwargs: Unused.

        Raises:
            ImageDataValueError: If the matrix (after any ``uint8``
                conversion) contains a negative value or a value >= 256.
        """
        state = args[0].state
        output_dir = state.output_path
        content_image_path = state.content_image_path
        style_image_path = state.style_image_path
        iterations_completed = state.metrics['iterations']
        matrix = state.matrix

        # TODO: Implement handling of the "request to persist" with a chain of
        # responsibility design pattern. It suits this case since we do not
        # know how many checks and/or image transformations will be required
        # before saving on disk.

        output_file_path = os.path.join(
            output_dir,
            f'{os.path.basename(content_image_path)}+'
            f'{os.path.basename(style_image_path)}-{iterations_completed}.png',
        )

        # Drop a leading length-1 axis from a 4-D matrix. Presumably this is a
        # batch dimension, e.g. (1, H, W, C) -> (H, W, C); confirm with the
        # code that produces ``state.matrix``.
        if matrix.ndim == 4 and matrix.shape[0] == 1:
            matrix = np.reshape(matrix, tuple(matrix.shape[1:]))

        if matrix.dtype != np.uint8:
            matrix = self.convert_to_unit8(matrix)

        # Defensive range checks: the conversion callback is external, so make
        # sure every pixel fits in [0, 256) before writing the image.
        if np.nanmin(matrix) < 0:
            raise ImageDataValueError('Generated Image has pixel(s) with negative values.')
        if np.nanmax(matrix) >= 2 ** 8:
            raise ImageDataValueError('Generated Image has pixel(s) with value >= 256.')

        self.save_on_disk_callback(matrix, output_file_path, save_format='png')
48
49
class ImageDataValueError(Exception):
    """Raised when an image about to be saved has pixel values outside [0, 256)."""
50