deepreg.registry   A
last analyzed

Complexity

Total Complexity 28

Size/Duplication

Total Lines 311
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 28
eloc 104
dl 0
loc 311
rs 10
c 0
b 0
f 0

18 Methods

Rating   Name   Duplication   Size   Complexity  
A Registry.register_data_loader() 0 13 1
A Registry.build_data_loader() 0 12 1
A Registry.build_data_augmentation() 0 12 1
A Registry.register_loss() 0 12 1
A Registry.register_file_loader() 0 13 1
A Registry.register_data_augmentation() 0 13 1
A Registry.build_loss() 0 10 1
B Registry.build_from_config() 0 29 6
A Registry.register() 0 27 2
A Registry._register() 0 22 4
A Registry.get() 0 11 2
A Registry.build_backbone() 0 10 1
A Registry.contains() 0 9 1
A Registry.__init__() 0 3 1
A Registry.build_model() 0 10 1
A Registry.register_backbone() 0 12 1
A Registry.register_model() 0 12 1
A Registry.copy() 0 5 1
1
from copy import deepcopy
2
from typing import Any, Callable, Dict, Optional
3
4
# Category identifiers: the first element of every `(category, key)` pair
# stored by the registry.
BACKBONE_CLASS = "backbone_class"
LOSS_CLASS = "loss_class"
METRIC_CLASS = "metric_class"
MODEL_CLASS = "model_class"
DATA_AUGMENTATION_CLASS = "da_class"
DATA_LOADER_CLASS = "data_loader_class"
FILE_LOADER_CLASS = "file_loader_class"

# Categories accepted at registration time; any other category is rejected
# with a ValueError. NOTE(review): METRIC_CLASS is defined above but is not
# listed here -- confirm whether that omission is intentional.
KNOWN_CATEGORIES = [
    BACKBONE_CLASS,
    LOSS_CLASS,
    MODEL_CLASS,
    DATA_AUGMENTATION_CLASS,
    DATA_LOADER_CLASS,
    FILE_LOADER_CLASS,
]
20
21
22
class Registry:
    """
    Registry maintains a dictionary which maps `(category, key)` to `value`.

    Multiple __init__.py files have been modified so that the classes are registered
    when executing:

    .. code-block:: python

        from deepreg.registry import REGISTRY

    References:

    - https://github.com/ray-project/ray/blob/00ef1179c012719a17c147a5c3b36d6bdbe97195/python/ray/tune/registry.py#L108
    - https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/builder.py
    - https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py
    - https://towardsdatascience.com/whats-init-for-me-d70a312da583
    """

    def __init__(self) -> None:
        """Init registry with empty dict."""
        # maps (category, key) -> registered callable
        self._dict = {}

    def _register(self, category: str, key: str, value: Callable, force: bool):
        """
        Registers the value with the registry.

        :param category: name of the class category
        :param key: unique identity
        :param value: class to be registered
        :param force: if True, overwrite the existing value
            in case the key has been registered.
        :raises ValueError: if the category is not among `KNOWN_CATEGORIES`,
            or if the key is already registered and `force` is False.
        """
        # sanity check
        if category not in KNOWN_CATEGORIES:
            raise ValueError(
                f"Unknown category {category} not among {KNOWN_CATEGORIES}"
            )
        if not force and self.contains(category=category, key=key):
            raise ValueError(
                f"Key {key} in category {category} has been registered"
                f" with {self._dict[(category, key)]}"
            )
        # register value
        self._dict[(category, key)] = value

    def contains(self, category: str, key: str) -> bool:
        """
        Verify if the key has been registered for the category.

        :param category: category name.
        :param key: value name.
        :return: `True` if registered.
        """
        return (category, key) in self._dict

    def get(self, category: str, key: str) -> Callable:
        """
        Return the registered class.

        :param category: category name.
        :param key: value name.
        :return: registered value.
        :raises ValueError: if nothing is registered under `(category, key)`.
        """
        if self.contains(category=category, key=key):
            return self._dict[(category, key)]
        raise ValueError(f"Key {key} in category {category} has not been registered.")

    def register(
        self,
        category: str,
        name: str,
        cls: Optional[Callable] = None,
        force: bool = False,
    ) -> Callable:
        """
        Register a py class.

        A record will be added to `self._dict`, whose key is
        `(category, name)` and value is the class itself.
        It can be used as a decorator or a normal function.

        :param category: The type of the category.
        :param name: The name under which the class is registered.
        :param cls: Class to be registered. If None, a decorator is returned
            which registers the decorated class.
        :param force: Whether to override an existing class with the same name.
        :return: The given class or a decorator.
        """
        # use it as a normal method: x.register(category, name, cls=SomeClass)
        if cls is not None:
            self._register(category=category, key=name, value=cls, force=force)
            return cls

        # use it as a decorator: @x.register(category, name)
        def decorator(c: Callable) -> Callable:
            self._register(category=category, key=name, value=c, force=force)
            return c

        return decorator

    def build_from_config(
        self, category: str, config: Dict, default_args: Optional[dict] = None
    ) -> Any:
        """
        Build a class instance from config dict.

        The registered class is looked up by `config["name"]` and called with
        the remaining items of `config` as keyword arguments. `config` itself
        is not modified, a deep copy is used.

        :param category: category name.
        :param config: a dict which must contain the key "name".
        :param default_args: optionally some default arguments, they are only
            used for keys which are absent from `config`.
        :return: the instantiated class.
        :raises ValueError: if `config` is not a dict, has no key "name",
            the name is not registered, or the class call raises a TypeError.
        """
        if not isinstance(config, dict):
            raise ValueError(f"config must be a dict, but got {type(config)}")
        if "name" not in config:
            raise ValueError(f"`config` must contain the key `name`, but got {config}")
        args = deepcopy(config)

        # insert key, value pairs if key is not in args
        if default_args is not None:
            for arg_name, arg_value in default_args.items():
                args.setdefault(arg_name, arg_value)

        name = args.pop("name")
        cls = self.get(category=category, key=name)
        try:
            return cls(**args)
        except TypeError as err:
            # chain the original error so the real traceback is not lost
            raise ValueError(
                f"Configuration is not compatible "
                f"for Class {cls} of category {category}.\n"
                f"Potentially an outdated configuration has been used.\n"
                f"Please check the latest documentation of the class "
                f"and arrange the required keys at the same level"
                f" as `name` in configuration file.\n"
                f"{err}"
            ) from err

    def copy(self):
        """
        Deep copy the registry.

        :return: a new Registry whose dict is a deep copy of this one.
        """
        copied = Registry()
        copied._dict = deepcopy(self._dict)
        return copied

    def register_model(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a model class.

        :param name: model name
        :param cls: model class
        :param force: whether overwrite if already registered
        :return: the registered class, or a decorator if `cls` is None
        """
        return self.register(category=MODEL_CLASS, name=name, cls=cls, force=force)

    def build_model(self, config: Dict, default_args: Optional[dict] = None) -> Any:
        """
        Instantiate a registered model class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a model instance
        """
        return self.build_from_config(
            category=MODEL_CLASS, config=config, default_args=default_args
        )

    def register_backbone(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a backbone class.

        :param name: backbone name
        :param cls: backbone class
        :param force: whether overwrite if already registered
        :return: the registered class, or a decorator if `cls` is None
        """
        return self.register(category=BACKBONE_CLASS, name=name, cls=cls, force=force)

    def build_backbone(self, config: Dict, default_args: Optional[dict] = None) -> Any:
        """
        Instantiate a registered backbone class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a backbone instance
        """
        return self.build_from_config(
            category=BACKBONE_CLASS, config=config, default_args=default_args
        )

    def register_loss(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a loss class.

        :param name: loss name
        :param cls: loss class
        :param force: whether overwrite if already registered
        :return: the registered class, or a decorator if `cls` is None
        """
        return self.register(category=LOSS_CLASS, name=name, cls=cls, force=force)

    def build_loss(self, config: Dict, default_args: Optional[dict] = None) -> Callable:
        """
        Instantiate a registered loss class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a loss instance
        """
        return self.build_from_config(
            category=LOSS_CLASS, config=config, default_args=default_args
        )

    def register_data_loader(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a data loader class.

        :param name: data loader name
        :param cls: data loader class
        :param force: whether overwrite if already registered
        :return: the registered class, or a decorator if `cls` is None
        """
        return self.register(
            category=DATA_LOADER_CLASS, name=name, cls=cls, force=force
        )

    def build_data_loader(
        self, config: Dict, default_args: Optional[dict] = None
    ) -> Any:
        """
        Instantiate a registered data loader class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a data loader instance
        """
        return self.build_from_config(
            category=DATA_LOADER_CLASS, config=config, default_args=default_args
        )

    def register_data_augmentation(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a data augmentation class.

        :param name: data augmentation name
        :param cls: data augmentation class
        :param force: whether overwrite if already registered
        :return: the registered class, or a decorator if `cls` is None
        """
        return self.register(
            category=DATA_AUGMENTATION_CLASS, name=name, cls=cls, force=force
        )

    def build_data_augmentation(
        self, config: Dict, default_args: Optional[dict] = None
    ) -> Callable:
        """
        Instantiate a registered data augmentation class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a data augmentation instance
        """
        return self.build_from_config(
            category=DATA_AUGMENTATION_CLASS, config=config, default_args=default_args
        )

    def register_file_loader(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a file loader class.

        :param name: file loader name
        :param cls: file loader class
        :param force: whether overwrite if already registered
        :return: the registered class, or a decorator if `cls` is None
        """
        return self.register(
            category=FILE_LOADER_CLASS, name=name, cls=cls, force=force
        )
308
309
310
# Module-level Registry instance, created empty when this module is imported.
# The Registry docstring shows it being imported as
# `from deepreg.registry import REGISTRY`.
REGISTRY = Registry()
311