|
1
|
|
|
"""This module defines the interface that must be implemented in order to
utilize a pretrained model and its weights."""
|
3
|
|
|
|
|
4
|
|
|
from abc import ABC, abstractmethod |
|
5
|
|
|
from typing import Tuple, Dict |
|
6
|
|
|
from numpy.typing import NDArray |
|
7
|
|
|
|
|
8
|
|
|
|
|
9
|
|
|
class PretrainedModelRoutines(ABC):
    """Set of routines that are required in order to use a pretrained model for NST.

    Concrete subclasses adapt a specific pretrained network to the rest of
    the code base by implementing every abstract method below. Because all
    methods are declared with ``@abstractmethod``, this class cannot be
    instantiated directly (``TypeError``); only a subclass overriding all of
    them can.
    """

    @abstractmethod
    def load_layers(self, file_path: str) -> NDArray:
        """Load a pretrained model from disk.

        Loads the model parameters, given the path to a file on disk that
        indicates where the pretrained model is stored.

        Args:
            file_path (str): the path to the file on disk holding the
                pretrained model

        Returns:
            NDArray: the model parameters as a numpy array

        Raises:
            NotImplementedError: if this base implementation is reached,
                e.g. through ``super().load_layers(...)`` in a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def get_id(self, layer: NDArray) -> str:
        """Get the id of a model's network layer.

        The pretrained model, being a neural network, has a specific
        architecture, and each layer should have a unique string id by which
        it can be referenced.

        Args:
            layer (NDArray): the layer of a pretrained neural network model

        Returns:
            str: the layer id

        Raises:
            NotImplementedError: if this base implementation is reached,
                e.g. through ``super().get_id(...)`` in a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def get_layers_dict(self, layers: NDArray) -> Dict[str, NDArray]:
        """Get a dict mapping strings to pretrained model layers.

        NOTE(review): the keys are presumably the ids produced by
        ``get_id`` -- confirm against the concrete implementations.

        Args:
            layers (NDArray): the pretrained model layers

        Returns:
            Dict[str, NDArray]: the dictionary mapping strings to layers

        Raises:
            NotImplementedError: if this base implementation is reached,
                e.g. through ``super().get_layers_dict(...)`` in a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def get_weights(self, layer: NDArray) -> Tuple[NDArray, NDArray]:
        """Get the values of the weights of a given network layer.

        Each layer of the pretrained network has "learned" certain parameters
        in the form of "weights" (i.e. the weight matrix A and the additive
        term b in the expression Ax + b).

        Call this method to get the pair (A, b) as a tuple.

        Args:
            layer (NDArray): the layer of a pretrained neural network model

        Returns:
            Tuple[NDArray, NDArray]: the weight matrix A and the term b,
                in that order

        Raises:
            NotImplementedError: if this base implementation is reached,
                e.g. through ``super().get_weights(...)`` in a subclass.
        """
        raise NotImplementedError
|
70
|
|
|
|