Passed
Push — master ( ee1e78...515b92 )
by Konstantinos
01:14
created

artificial_artwork.styling_observer   A

Complexity

Total Complexity 8

Size/Duplication

Total Lines 60
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 42
dl 0
loc 60
rs 10
c 0
b 0
f 0
wmc 8

2 Methods

Rating   Name   Duplication   Size   Complexity  
A StylingObserver._convert_to_uint8() 0 20 4
A StylingObserver.update() 0 20 4
1
from attr import define 
2
from .utils.notification import Observer
3
from typing import Callable
4
import numpy.typing as npt
5
import numpy as np
6
import os
7
8
9
@define
class StylingObserver(Observer):
    """Store a snapshot of the image under construction each time it is notified.

    On every ``update`` the observer:
    1. reads the state of the notifying subject;
    2. builds an output file name from the content/style image names and the
       number of completed iterations;
    3. converts the image matrix to ``uint8`` if needed;
    4. hands it to ``save_on_disk_callback`` to be written as a PNG.

    Attributes:
        save_on_disk_callback: called as
            ``save_on_disk_callback(matrix, output_file_path, format='png')``.
    """

    # Called as callback(image_array, file_path, format='png'). The keyword
    # argument cannot be expressed in a ``Callable[[...], ...]`` parameter
    # list, hence ``...``. (The previous annotation also had the positional
    # argument order reversed relative to the actual call below.)
    save_on_disk_callback: Callable[..., None]

    # Maps a bit depth to the unsigned integer dtype used to store it.
    # Left un-annotated so attrs keeps it as a plain class attribute, not a field.
    bit_2_data_type = {8: np.uint8}

    def update(self, *args, **kwargs):
        """Save the subject's current image to disk.

        Args:
            *args: the first positional argument is the notifying subject.
                Its ``state`` must expose ``output_path``,
                ``content_image_path``, ``style_image_path``,
                ``metrics['iterations']`` and ``matrix``.
            **kwargs: ignored.
        """
        state = args[0].state
        output_dir = state.output_path
        content_image_path = state.content_image_path
        style_image_path = state.style_image_path
        iterations_completed = state.metrics['iterations']
        matrix = state.matrix

        output_file_path = os.path.join(
            output_dir,
            f'{os.path.basename(content_image_path)}'
            f'+{os.path.basename(style_image_path)}'
            f'-{iterations_completed}.png',
        )

        if matrix.ndim == 4 and matrix.shape[0] == 1:
            # Drop a leading singleton axis: (1, d1, d2, channels) -> (d1, d2, channels).
            # NOTE(review): presumably a batch axis; confirm the d1/d2 ordering.
            matrix = np.reshape(matrix, tuple(matrix.shape[1:]))

        if matrix.dtype != np.uint8:
            matrix = self._convert_to_uint8(matrix)

        self.save_on_disk_callback(matrix, output_file_path, format='png')

    def _convert_to_uint8(self, im):
        """Linearly stretch ``im`` onto the full 8-bit range and cast to uint8.

        The minimum value maps to 0 and the maximum to 255. NaNs are ignored
        when computing the range. A constant image is cast directly, without
        rescaling.

        Args:
            im: numeric array-like image.

        Returns:
            The image as an array of dtype ``uint8``.

        Raises:
            ValueError: if the minimum or maximum value is not finite.
        """
        bitdepth = 8
        out_type = type(self).bit_2_data_type[bitdepth]
        mi = np.nanmin(im)
        ma = np.nanmax(im)
        if not np.isfinite(mi):
            raise ValueError("Minimum image value is not finite")
        if not np.isfinite(ma):
            raise ValueError("Maximum image value is not finite")
        if ma == mi:
            # No range to stretch, so cast as-is.
            # NOTE(review): a constant value outside [0, 255] is not clipped here.
            return im.astype(out_type)

        # Scale a float64 copy so the input is untouched and integer inputs
        # cannot overflow or truncate during the arithmetic.
        im = im.astype("float64")
        max_value = np.power(2.0, bitdepth) - 1
        # Map [mi, ma] onto [0, max_value]. The +0.499999999 offset turns the
        # truncating astype below into round-to-nearest.
        im = (im - mi) / (ma - mi) * max_value + 0.499999999
        assert np.nanmin(im) >= 0
        assert np.nanmax(im) < np.power(2.0, bitdepth)
        return im.astype(out_type)
60