Passed
Push — master ( 338098...c3d045 )
by Konstantinos
02:29 queued 01:15
created

artificial_artwork.pretrained_model.model_handler   A

Complexity

Total Complexity 9

Size/Duplication

Total Lines 63
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 9
eloc 42
dl 0
loc 63
rs 10
c 0
b 0
f 0

7 Methods

Rating   Name   Duplication   Size   Complexity  
A Modelhandler.load_model_layers() 0 4 1
A ReporterProtocol.get_weights() 0 1 1
A Modelhandler._load_model_layers() 0 6 2
A ModelHandlerFacility.create() 0 3 1
A Modelhandler._create_reporter() 0 4 1
A Modelhandler.reporter() 0 3 1
A Modelhandler.__init__() 0 2 1
1
import os
2
from typing import Optional, Protocol, Tuple
3
from numpy.typing import NDArray
4
5
from artificial_artwork.utils import SubclassRegistry
6
from .model_handler_interface import ModelHandlerInterface
7
from .layers_getter import ModelReporter
8
9
10
class ReporterProtocol(Protocol):
    """Structural (duck-typed) interface for objects that report layer weights.

    Any object with a matching ``get_weights`` method satisfies this protocol;
    no explicit subclassing is required.
    """

    def get_weights(self, layer_id: str) -> Tuple[NDArray, NDArray]:
        """Return the two weight arrays associated with ``layer_id``."""
12
13
14
class Modelhandler(ModelHandlerInterface):
    """Load a pretrained model's layers and keep a reporter built from them.

    NOTE(review): this class reads ``self.model_routines``,
    ``self.environment_variable`` and ``self.model_load_exception_text``
    but does not define them. Presumably they are supplied by subclasses
    or by ``ModelHandlerInterface``; confirm.
    """

    # Reporter built from the most recently loaded layers. It stays None until
    # ``load_model_layers`` runs or layers are assigned to ``reporter``.
    _reporter: Optional[ReporterProtocol]

    def __init__(self):
        self._reporter = None

    @property
    def reporter(self) -> Optional[ReporterProtocol]:
        """Return the current reporter, or None if none has been built yet."""
        return self._reporter

    @reporter.setter
    def reporter(self, layers) -> None:
        """Build a new reporter from ``layers`` and store it.

        The assigned value is the model *layers*, not a reporter object.
        A reporter is created from it via ``_create_reporter``.
        """
        self._reporter = self._create_reporter(layers)

    def _create_reporter(self, layers: NDArray) -> ReporterProtocol:
        """Wrap ``layers`` in a ``ModelReporter``.

        The reporter receives the layers dict produced by
        ``model_routines.get_layers_dict`` and the routine's ``get_weights``
        callable.
        """
        return ModelReporter(
            self.model_routines.get_layers_dict(layers),
            self.model_routines.get_weights,
        )

    def load_model_layers(self) -> NDArray:
        """Load the model layers, refresh the reporter, and return the layers.

        Raises:
            NoImageModelSpesifiedError: if the environment variable named by
                ``self.environment_variable`` is not set.
        """
        layers = self._load_model_layers()
        self._reporter = self._create_reporter(layers)
        return layers

    def _load_model_layers(self) -> NDArray:
        """Read the configured environment variable and load layers from it.

        Raises:
            NoImageModelSpesifiedError: if the environment variable named by
                ``self.environment_variable`` is not set.
        """
        # Keep only the environment lookup inside the ``try``. A KeyError
        # raised by ``load_layers`` itself is a different failure and must not
        # be reported as "no model specified".
        try:
            configured_value = os.environ[self.environment_variable]
        except KeyError as variable_not_found:
            raise NoImageModelSpesifiedError(self.model_load_exception_text) \
                from variable_not_found
        return self.model_routines.load_layers(configured_value)
44
45
46
class NoImageModelSpesifiedError(Exception):
    """Raised by ``Modelhandler`` when its configured environment variable is unset."""


# Correctly spelled alias. The original (misspelled) name is kept so that
# existing ``except NoImageModelSpesifiedError`` clauses keep working.
NoImageModelSpecifiedError = NoImageModelSpesifiedError
47
48
49
class ModelHandlerFactoryMeta(SubclassRegistry[Modelhandler]):
    """Metaclass for ``ModelHandlerFactory``, parameterised with ``Modelhandler``.

    NOTE(review): all behaviour comes from the project's ``SubclassRegistry``.
    Presumably it tracks ``Modelhandler`` subclasses and provides the
    ``create`` used by ``ModelHandlerFacility``; confirm against
    ``SubclassRegistry``.
    """
50
51
52
class ModelHandlerFactory(metaclass=ModelHandlerFactoryMeta):
    """Factory class whose behaviour is supplied by ``ModelHandlerFactoryMeta``.

    NOTE(review): the class body defines nothing itself. The ``create`` called
    on it elsewhere presumably comes from the metaclass; confirm.
    """
53
54
55
class ModelHandlerFacility:
    """Facade that builds model handlers by delegating to ``factory``."""

    handler_class: type = Modelhandler
    factory = ModelHandlerFactory

    @classmethod
    def create(cls, handler_type, *args, **kwargs) -> Modelhandler:
        """Build a handler via ``cls.factory.create``, forwarding all arguments.

        Args:
            handler_type: selector passed through unchanged to the factory.
                Its accepted values are defined by the factory, not here.
            *args, **kwargs: forwarded verbatim to the factory.
        """
        build_handler = cls.factory.create
        return build_handler(handler_type, *args, **kwargs)
63