1
|
|
|
from copy import deepcopy |
2
|
|
|
from typing import Any, Callable, Dict, Optional |
3
|
|
|
|
4
|
|
|
# Category names; each is used as the first element of a registry key
# `(category, key)`.
BACKBONE_CLASS = "backbone_class"
LOSS_CLASS = "loss_class"
METRIC_CLASS = "metric_class"
MODEL_CLASS = "model_class"
DATA_AUGMENTATION_CLASS = "da_class"
DATA_LOADER_CLASS = "data_loader_class"
FILE_LOADER_CLASS = "file_loader_class"

# Categories accepted by `Registry._register`; registering under any other
# category raises ValueError. Every category constant declared above must be
# listed here, otherwise it cannot be used for registration.
KNOWN_CATEGORIES = [
    BACKBONE_CLASS,
    LOSS_CLASS,
    METRIC_CLASS,
    MODEL_CLASS,
    DATA_AUGMENTATION_CLASS,
    DATA_LOADER_CLASS,
    FILE_LOADER_CLASS,
]
20
|
|
|
|
21
|
|
|
|
22
|
|
|
class Registry:
    """
    Registry maintains a dictionary which maps `(category, key)` to `value`.

    Multiple __init__.py files have been modified so that the classes are registered
    when executing:

    .. code-block:: python

        from deepreg.registry import REGISTRY

    References:

    - https://github.com/ray-project/ray/blob/00ef1179c012719a17c147a5c3b36d6bdbe97195/python/ray/tune/registry.py#L108
    - https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/builder.py
    - https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py
    - https://towardsdatascience.com/whats-init-for-me-d70a312da583
    """

    def __init__(self):
        """Init registry with empty dict."""
        # maps (category, key) -> registered callable (usually a class)
        self._dict = {}

    def _register(self, category: str, key: str, value: Callable, force: bool):
        """
        Registers the value with the registry.

        :param category: name of the class category, must be in KNOWN_CATEGORIES
        :param key: unique identity within the category
        :param value: class to be registered
        :param force: if True, overwrite the existing value
            in case the key has been registered.
        :raises ValueError: if the category is unknown, or if the key is
            already registered and force is False.
        """
        # sanity check
        if category not in KNOWN_CATEGORIES:
            raise ValueError(
                f"Unknown category {category} not among {KNOWN_CATEGORIES}"
            )
        if not force and self.contains(category=category, key=key):
            raise ValueError(
                f"Key {key} in category {category} has been registered"
                f" with {self._dict[(category, key)]}"
            )
        # register value
        self._dict[(category, key)] = value

    def contains(self, category: str, key: str) -> bool:
        """
        Verify if the key has been registered for the category.

        :param category: category name.
        :param key: value name.
        :return: `True` if registered.
        """
        return (category, key) in self._dict

    def get(self, category: str, key: str) -> Callable:
        """
        Return the registered class.

        :param category: category name.
        :param key: value name.
        :return: registered value.
        :raises ValueError: if nothing is registered under `(category, key)`.
        """
        try:
            return self._dict[(category, key)]
        except KeyError:
            # the KeyError adds no information beyond the message below
            raise ValueError(
                f"Key {key} in category {category} has not been registered."
            ) from None

    def register(
        self,
        category: str,
        name: str,
        cls: Optional[Callable] = None,
        force: bool = False,
    ) -> Callable:
        """
        Register a py class.

        A record will be added to `self._dict`, whose key is `(category, name)`
        and whose value is the class itself.
        It can be used as a decorator or a normal function.

        :param category: The type of the category.
        :param name: The name under which the class is registered (required).
        :param cls: Class to be registered. If None, a decorator is returned
            which registers the decorated class.
        :param force: Whether to override an existing class with the same name.
        :return: The given class or a decorator.
        """
        # use it as a normal method: x.register(category, name, cls=SomeClass)
        if cls is not None:
            self._register(category=category, key=name, value=cls, force=force)
            return cls

        # use it as a decorator: @x.register(category, name)
        def decorator(c: Callable) -> Callable:
            self._register(category=category, key=name, value=c, force=force)
            return c

        return decorator

    def build_from_config(
        self, category: str, config: Dict, default_args: Optional[dict] = None
    ) -> Any:
        """
        Build a class instance from config dict.

        The config is deep-copied, so the caller's dict is never modified.
        Keys of `default_args` are only used when absent from `config`.

        :param category: category name.
        :param config: a dict which must contain the key "name".
        :param default_args: optionally some default arguments.
        :return: the instantiated class.
        :raises ValueError: if config is not a dict, lacks "name", refers to an
            unregistered class, or does not match the class constructor.
        """
        if not isinstance(config, dict):
            raise ValueError(f"config must be a dict, but got {type(config)}")
        if "name" not in config:
            raise ValueError(f"`config` must contain the key `name`, but got {config}")
        args = deepcopy(config)

        # insert key, value pairs if key is not in args
        if default_args is not None:
            for arg_name, value in default_args.items():
                args.setdefault(arg_name, value)

        name = args.pop("name")
        cls = self.get(category=category, key=name)
        try:
            return cls(**args)
        except TypeError as err:
            # most likely unexpected / missing keyword arguments in the config
            raise ValueError(
                "Configuration is not compatible "
                f"for Class {cls} of category {category}.\n"
                "Potentially an outdated configuration has been used.\n"
                "Please check the latest documentation of the class "
                "and arrange the required keys at the same level"
                " as `name` in configuration file.\n"
                f"{err}"
            ) from err

    def copy(self):
        """Deep copy the registry."""
        copied = Registry()
        # deepcopy treats classes/functions as atomic, so values stay identical
        # while the mapping itself becomes independent
        copied._dict = deepcopy(self._dict)
        return copied

    def register_model(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a model class.

        :param name: model name
        :param cls: model class
        :param force: whether overwrite if already registered
        :return: the registered class
        """
        return self.register(category=MODEL_CLASS, name=name, cls=cls, force=force)

    def build_model(self, config: Dict, default_args: Optional[dict] = None) -> Any:
        """
        Instantiate a registered model class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a model instance
        """
        return self.build_from_config(
            category=MODEL_CLASS, config=config, default_args=default_args
        )

    def register_backbone(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a backbone class.

        :param name: backbone name
        :param cls: backbone class
        :param force: whether overwrite if already registered
        :return: the registered class
        """
        return self.register(category=BACKBONE_CLASS, name=name, cls=cls, force=force)

    def build_backbone(self, config: Dict, default_args: Optional[dict] = None) -> Any:
        """
        Instantiate a registered backbone class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a backbone instance
        """
        return self.build_from_config(
            category=BACKBONE_CLASS, config=config, default_args=default_args
        )

    def register_loss(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a loss class.

        :param name: loss name
        :param cls: loss class
        :param force: whether overwrite if already registered
        :return: the registered class
        """
        return self.register(category=LOSS_CLASS, name=name, cls=cls, force=force)

    def build_loss(self, config: Dict, default_args: Optional[dict] = None) -> Callable:
        """
        Instantiate a registered loss class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a loss instance
        """
        return self.build_from_config(
            category=LOSS_CLASS, config=config, default_args=default_args
        )

    def register_data_loader(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a data loader class.

        :param name: data loader name
        :param cls: data loader class
        :param force: whether overwrite if already registered
        :return: the registered class
        """
        return self.register(
            category=DATA_LOADER_CLASS, name=name, cls=cls, force=force
        )

    def build_data_loader(
        self, config: Dict, default_args: Optional[dict] = None
    ) -> Any:
        """
        Instantiate a registered data loader class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a data loader instance
        """
        return self.build_from_config(
            category=DATA_LOADER_CLASS, config=config, default_args=default_args
        )

    def register_data_augmentation(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a data augmentation class.

        :param name: data augmentation name
        :param cls: data augmentation class
        :param force: whether overwrite if already registered
        :return: the registered class
        """
        return self.register(
            category=DATA_AUGMENTATION_CLASS, name=name, cls=cls, force=force
        )

    def build_data_augmentation(
        self, config: Dict, default_args: Optional[dict] = None
    ) -> Callable:
        """
        Instantiate a registered data augmentation class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a data augmentation instance
        """
        return self.build_from_config(
            category=DATA_AUGMENTATION_CLASS, config=config, default_args=default_args
        )

    def register_file_loader(
        self, name: str, cls: Optional[Callable] = None, force: bool = False
    ) -> Callable:
        """
        Register a file loader class.

        :param name: file loader name
        :param cls: file loader class
        :param force: whether overwrite if already registered
        :return: the registered class
        """
        return self.register(
            category=FILE_LOADER_CLASS, name=name, cls=cls, force=force
        )

    def build_file_loader(
        self, config: Dict, default_args: Optional[dict] = None
    ) -> Any:
        """
        Instantiate a registered file loader class.

        :param config: config having key `name`.
        :param default_args: optionally some default arguments.
        :return: a file loader instance
        """
        return self.build_from_config(
            category=FILE_LOADER_CLASS, config=config, default_args=default_args
        )
308
|
|
|
|
309
|
|
|
|
310
|
|
|
# Module-level registry instance. Python caches imported modules, so every
# `from ... import REGISTRY` receives this same object.
REGISTRY: Registry = Registry()
311
|
|
|
|