Passed
Push — master ( 338098...c3d045 )
by Konstantinos
02:29 queued 01:15
created

VggModelHandler.model_routines()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
from typing import Tuple, Dict
2
from numpy.typing import NDArray
3
import scipy.io
4
5
6
from artificial_artwork.pretrained_model.model_routines import PretrainedModelRoutines
7
from artificial_artwork.pretrained_model import ModelHandlerFacility, Modelhandler
8
9
10
class VggModelRoutines(PretrainedModelRoutines):
    """Routines for reading the VGG pretrained model stored in a MATLAB ``.mat`` file."""

    def load_layers(self, file_path: str) -> NDArray:
        """Load the layer entries from a ``.mat`` file.

        Args:
            file_path: path to the ``.mat`` file holding the model.

        Returns:
            Row 0 of the file's ``'layers'`` variable.
        """
        # loadmat returns a dict of variables and, by default (squeeze_me=False),
        # keeps MATLAB's 2-D shape, hence the [0] to take the first row.
        return scipy.io.loadmat(file_path)['layers'][0]

    def get_id(self, layer: NDArray) -> str:
        """Return the identifier of ``layer``.

        The value is read four levels deep in the nested MATLAB struct
        (``layer[0][0][0][0]``); presumably the layer's name field — confirm
        against the ``.mat`` layout.
        """
        return layer[0][0][0][0]

    def get_layers_dict(self, layers: NDArray) -> Dict[str, NDArray]:
        """Map each layer's id (as returned by ``get_id``) to the layer itself.

        If two layers share an id, the later one overwrites the earlier one.
        """
        # Iterate the layers directly: indexing ``layers[index]`` inside an
        # ``enumerate`` loop yielded the very same element as ``layer``.
        return {self.get_id(layer): layer for layer in layers}

    def get_weights(self, layer: NDArray) -> Tuple[NDArray, NDArray]:
        """Return the two parameter arrays stored in ``layer``.

        Both are read from field index 2 of the layer struct; presumably they
        are (weights, bias) — confirm against the ``.mat`` layout.
        """
        # Resolve the shared nested path once instead of walking it twice.
        params = layer[0][0][2][0]
        return params[0], params[1]


# Single shared instance of the routines, used by this module's handler.
vgg_model_routines = VggModelRoutines()
26
27
28
@ModelHandlerFacility.factory.register_as_subclass('vgg')
class VggModelHandler(Modelhandler):
    """Model handler for the pretrained VGG image model.

    Registered with ``ModelHandlerFacility.factory`` under the key ``'vgg'``.
    """

    @property
    def environment_variable(self) -> str:
        """Name of the environment variable the user sets to the model's path."""
        return 'AA_VGG_19'

    @property
    def model_routines(self) -> VggModelRoutines:
        """The module-level ``VggModelRoutines`` instance shared by all handlers."""
        return vgg_model_routines

    @property
    def model_load_exception_text(self) -> str:
        """Text explaining how to provide the pretrained model when it is not found."""
        env_var = self.environment_variable
        return (
            'No pretrained image model found. '
            f'Please download it and set the {env_var} '
            'environment variable with the path where you stored the model '
            '(*.mat file), to instruct the program where to locate and load it'
        )
45