Passed
Push — master ( 338098...c3d045 )
by Konstantinos
02:29 queued 01:15
created

PretrainedModelRoutines.load_layers()   A

Complexity

Conditions 1

Size

Total Lines 14
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 14
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
"""This module defines the interface which must be implemented in order to
utilize a pretrained model and its weights."""
3
4
from abc import ABC, abstractmethod
5
from typing import Tuple, Dict
6
from numpy.typing import NDArray
7
8
9
class PretrainedModelRoutines(ABC):
    """Set of routines that are required in order to use a pretrained model for nst.

    Every method is declared with ``@abstractmethod``, so a concrete subclass
    has to override all of them before it can be instantiated.
    """

    @abstractmethod
    def load_layers(self, file_path: str) -> NDArray:
        """Load a pretrained model from disk.

        Loads the model parameters, given the path to a file on the disk that
        indicates where the pretrained model is.

        Args:
            file_path (str): the path corresponding to a file on the disk

        Returns:
            NDArray: the model parameters as a numpy array
        """
        raise NotImplementedError

    @abstractmethod
    def get_id(self, layer: NDArray) -> str:
        """Get the id of a model's network layer.

        The pretrained model, being a neural network, has a specific
        architecture and each layer should have a unique string id by which it
        can be referenced.

        Args:
            layer (NDArray): the layer of a pretrained neural network model

        Returns:
            str: the layer id
        """
        raise NotImplementedError

    @abstractmethod
    def get_layers_dict(self, layers: NDArray) -> Dict[str, NDArray]:
        """Get a dict mapping strings to pretrained model layers.

        Args:
            layers (NDArray): the pretrained model layers

        Returns:
            Dict[str, NDArray]: the dictionary mapping strings to layers
        """
        raise NotImplementedError

    @abstractmethod
    def get_weights(self, layer: NDArray) -> Tuple[NDArray, NDArray]:
        """Get the values of the weights of a given network layer.

        Each pretrained model network layer has "learned" certain parameters in
        the form of "weights" (i.e. the matrices A and b in the equation
        Ax + b).

        Call this method to get a tuple of the A and b mathematical matrices.

        Args:
            layer (NDArray): the layer of a pretrained neural network model

        Returns:
            Tuple[NDArray, NDArray]: the weights in matrix A and b
        """
        raise NotImplementedError
70