Passed
Push — master ( 338098...c3d045 )
by Konstantinos
02:29 queued 01:15
created

artificial_artwork.pre_trained_models.vgg   A

Complexity

Total Complexity 7

Size/Duplication

Total Lines 42
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 7
eloc 28
dl 0
loc 42
rs 10
c 0
b 0
f 0

7 Methods

Rating   Name   Duplication   Size   Complexity  
A VggModelRoutines.get_weights() 0 2 1
A VggModelHandler.environment_variable() 0 3 1
A VggModelHandler.model_load_exception_text() 0 3 1
A VggModelHandler.model_routines() 0 3 1
A VggModelRoutines.get_id() 0 2 1
A VggModelRoutines.get_layers_dict() 0 2 1
A VggModelRoutines.load_layers() 0 2 1
1
from typing import Tuple, Dict
2
from numpy.typing import NDArray
3
import scipy.io
4
5
6
from artificial_artwork.pretrained_model.model_routines import PretrainedModelRoutines
7
from artificial_artwork.pretrained_model import ModelHandlerFacility, Modelhandler
8
9
10
class VggModelRoutines(PretrainedModelRoutines):
    """Routines for reading a VGG model stored as a MATLAB ``.mat`` file.

    The nested ``[0]`` indexing below follows the object arrays that
    ``scipy.io.loadmat`` produces for the file's ``layers`` variable.
    NOTE(review): the field positions used (0 = id, 2 = weights) are taken
    from this code, not from a documented file spec — confirm against the
    model file if it ever changes.
    """

    def load_layers(self, file_path: str) -> NDArray:
        """Load the ``layers`` variable from the ``.mat`` file at *file_path*.

        Returns element ``[0]`` of the loaded ``layers`` array; presumably
        this holds one entry per network layer (TODO confirm).
        """
        return scipy.io.loadmat(file_path)['layers'][0]

    def get_id(self, layer: NDArray) -> str:
        """Return the identifier stored at ``layer[0][0][0][0]``.

        Assumes the first struct field of a layer is its name — verify
        against the model file.
        """
        return layer[0][0][0][0]

    def get_layers_dict(self, layers: NDArray) -> Dict[str, NDArray]:
        """Map each layer's id (as returned by ``get_id``) to the layer itself.

        If two layers share an id, the later one wins.
        """
        # Iterating yields the same elements as ``layers[i]``, so there is no
        # need to enumerate and re-index.
        return {self.get_id(layer): layer for layer in layers}

    def get_weights(self, layer: NDArray) -> Tuple[NDArray, NDArray]:
        """Return the pair of parameter arrays stored in *layer*'s third field.

        The pair is entries 0 and 1 of ``layer[0][0][2][0]``; presumably
        (weights, bias) — TODO confirm.
        """
        # Walk the nested path once instead of twice.
        params = layer[0][0][2][0]
        return params[0], params[1]
23
24
25
# Module-level instance returned by VggModelHandler.model_routines, so every
# handler instance shares this single VggModelRoutines object.
vgg_model_routines: VggModelRoutines = VggModelRoutines()
26
27
28
@ModelHandlerFacility.factory.register_as_subclass('vgg')
class VggModelHandler(Modelhandler):
    """Model handler for the VGG network, registered in the factory as ``'vgg'``."""

    @property
    def environment_variable(self) -> str:
        """Name of the environment variable expected to hold the model's path."""
        return 'AA_VGG_19'

    @property
    def model_routines(self) -> VggModelRoutines:
        """The shared module-level ``VggModelRoutines`` instance."""
        return vgg_model_routines

    @property
    def model_load_exception_text(self) -> str:
        """Message telling the user how to point the program at the ``.mat`` model."""
        env_name = self.environment_variable
        return (
            'No pretrained image model found. '
            f'Please download it and set the {env_name} '
            'environment variable with the path where you stored the model '
            '(*.mat file), to instruct the program where to locate and load it'
        )
45