|
1
|
|
|
from typing import Tuple, Dict |
|
2
|
|
|
from numpy.typing import NDArray |
|
3
|
|
|
import scipy.io |
|
4
|
|
|
|
|
5
|
|
|
|
|
6
|
|
|
from artificial_artwork.pretrained_model.model_routines import PretrainedModelRoutines |
|
7
|
|
|
from artificial_artwork.pretrained_model import ModelHandlerFacility, Modelhandler |
|
8
|
|
|
|
|
9
|
|
|
|
|
10
|
|
|
class VggModelRoutines(PretrainedModelRoutines):
    """Routines for reading layers and weights from a VGG model stored in a MATLAB ``.mat`` file."""

    def load_layers(self, file_path: str) -> NDArray:
        """Load the layer records from the ``.mat`` file at *file_path*.

        Reads the ``'layers'`` entry of the file with :func:`scipy.io.loadmat`
        and returns its first row (index ``0``).
        """
        return scipy.io.loadmat(file_path)['layers'][0]

    def get_id(self, layer: NDArray) -> str:
        """Return the identifier stored in *layer*.

        The value is read from ``layer[0][0][0][0]``. Presumably this is the
        layer's name field in the MATLAB struct (TODO: confirm against the
        .mat layout).
        """
        return layer[0][0][0][0]

    def get_layers_dict(self, layers: NDArray) -> Dict[str, NDArray]:
        """Map each layer's identifier (as returned by :meth:`get_id`) to the layer.

        If several layers share an identifier, the last one wins.
        """
        # Each iterated element already is the layer; no need to re-index.
        return {self.get_id(layer): layer for layer in layers}

    def get_weights(self, layer: NDArray) -> Tuple[NDArray, NDArray]:
        """Return the two weight arrays stored in *layer*.

        Both come from ``layer[0][0][2][0]``, at items ``0`` and ``1``.
        Presumably these are the kernel and the bias (TODO: confirm against
        the .mat layout).
        """
        weights = layer[0][0][2][0]
        return weights[0], weights[1]


# Single shared instance, returned by VggModelHandler.model_routines.
vgg_model_routines = VggModelRoutines()
|
26
|
|
|
|
|
27
|
|
|
|
|
28
|
|
|
@ModelHandlerFacility.factory.register_as_subclass('vgg')
class VggModelHandler(Modelhandler):
    """Model handler for the pretrained VGG model, registered under the ``'vgg'`` key."""

    @property
    def environment_variable(self) -> str:
        """Name of the environment variable expected to hold the path to the model's ``.mat`` file."""
        return 'AA_VGG_19'

    @property
    def model_routines(self) -> VggModelRoutines:
        """The module-level :class:`VggModelRoutines` instance used for this model."""
        return vgg_model_routines

    @property
    def model_load_exception_text(self) -> str:
        """Message text for the case where the pretrained model cannot be located.

        The message embeds the name from :attr:`environment_variable`.
        """
        env_var_name = self.environment_variable
        return (
            'No pretrained image model found. '
            f'Please download it and set the {env_var_name} '
            'environment variable with the path where you stored the model '
            '(*.mat file), to instruct the program where to locate and load it'
        )
|
45
|
|
|
|