1
|
|
|
import re |
2
|
|
|
|
3
|
|
|
import pandas as pd |
4
|
|
|
import pytest |
5
|
|
|
|
6
|
|
|
from deepreg.registry import BACKBONE_CLASS, KNOWN_CATEGORIES, LOSS_CLASS, REGISTRY |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
class TestRegistry:
    """Tests for the global class registry ``REGISTRY``."""

    @pytest.fixture()
    def reg(self):
        """Return a fresh copy of the global registry for each test.

        Tests register into the copy rather than ``REGISTRY`` itself
        (assumes ``copy()`` duplicates the underlying mapping — TODO confirm).
        """
        return REGISTRY.copy()

    @pytest.mark.parametrize(
        "category,key,err_msg",
        [
            ("unknown_category", "key", "Unknown category"),
            (BACKBONE_CLASS, "unet", "has been registered"),
        ],
    )
    def test_register_err(self, category, key, err_msg, reg):
        """Registering under an unknown category or a taken key raises ValueError."""
        with pytest.raises(ValueError) as err_info:
            reg.register(category=category, name=key, cls=0)
        assert err_msg in str(err_info.value)

    @pytest.mark.parametrize(
        "category,key,force",
        [
            (BACKBONE_CLASS, "unet", True),
            (BACKBONE_CLASS, "vnet", False),
            (LOSS_CLASS, "dice", True),
            (LOSS_CLASS, "Dice", False),
        ],
    )
    def test_register(self, category, key, force, reg):
        """A new key, or an existing key with ``force=True``, gets registered."""
        value = 0
        reg.register(category=category, name=key, cls=value, force=force)
        assert reg._dict[(category, key)] == value
        assert reg.get(category, key) == value

    @pytest.mark.parametrize(
        "category,key",
        [
            (BACKBONE_CLASS, "unet"),
            (LOSS_CLASS, "dice"),
        ],
    )
    def test_get(self, category, key, reg):
        """Looking up a registered key succeeds."""
        # no error means the key has been registered
        _ = reg.get(category, key)

    def test_get_err(self, reg):
        """Looking up an unregistered key raises ValueError."""
        with pytest.raises(ValueError) as err_info:
            reg.get(BACKBONE_CLASS, "wrong_key")
        assert "has not been registered" in str(err_info.value)

    @pytest.mark.parametrize(
        "category,config,err_msg",
        [
            (BACKBONE_CLASS, [], "config must be a dict"),
            (BACKBONE_CLASS, {}, "`config` must contain the key `name`"),
            (
                BACKBONE_CLASS,
                {"name": "unet"},
                "Configuration is not compatible for Class",
            ),
        ],
    )
    def test_build_from_config_err(self, category, config, err_msg, reg):
        """Invalid configs are rejected with a descriptive ValueError."""
        with pytest.raises(ValueError) as err_info:
            reg.build_from_config(category=category, config=config)
        assert err_msg in str(err_info.value)

    @pytest.mark.parametrize(
        "category,config",
        [
            (
                BACKBONE_CLASS,
                dict(
                    name="unet",
                    image_size=[1, 2, 3],
                    out_channels=3,
                    num_channel_initial=3,
                    depth=5,
                    out_kernel_initializer="he_normal",
                    out_activation="softmax",
                ),
            ),
            (LOSS_CLASS, dict(name="dice")),
        ],
    )
    def test_build_from_config(self, category, config, reg):
        """Valid configs build an instance without raising."""
        _ = reg.build_from_config(category=category, config=config)

    @staticmethod
    def _normalize_markdown(text: str) -> str:
        """Strip markdown alignment markers (``:---``) and all spaces.

        Used on both the file content and the generated content so that
        table padding/alignment differences do not affect the comparison.
        """
        return re.sub(r":-+", "", text).replace(" ", "")

    def test_doc(self):
        """Test the doc maintaining the list of registered classes are correct."""

        filename = "docs/source/docs/registered_classes.md"

        # section title in the doc -> registry category
        name_to_category = {
            "Backbone": "backbone_class",
            "Model": "model_class",
            "Loss": "loss_class",
            "Data Augmentation": "da_class",
            "Data Loader": "data_loader_class",
            "File Loader": "file_loader_class",
        }
        # every known category must get its own section in the doc
        for category in KNOWN_CATEGORIES:
            assert category in name_to_category.values()

        # one row per registered (category, key) -> class
        records = [
            {
                "category": category,
                "key": f'"{key}"',
                "value": f"`{value.__module__}.{value.__name__}`",
            }
            for (category, key), value in REGISTRY._dict.items()
        ]
        # explicit columns keep the frame well-formed even if nothing is registered
        df = pd.DataFrame(records, columns=["category", "key", "value"])
        df = df.sort_values(["category", "key"])

        # generate the expected document content
        sections = [
            "# Registered Classes\n\n"
            "> This file is generated automatically.\n\n"
            "The following tables contain all registered classes "
            "with their categories and keys."
        ]
        for category_name, category in name_to_category.items():
            df_cat = df[df.category == category]
            sections.append(
                f"\n\n## {category_name}\n\n"
                f"The category is `{category}`. "
                "Registered keys and values are as following.\n\n"
                + df_cat[["key", "value"]].to_markdown(index=False)
            )
        content = "".join(sections)

        # check file content; the doc file is expected to end with a newline
        with open(filename, "r", encoding="utf-8") as f:
            got = self._normalize_markdown(f.read())
        expected = self._normalize_markdown(content) + "\n"

        assert got == expected

        # rewrite the file
        # if test failed, only need to temporarily comment out the assert
        # then regenerate the file
        if got != expected:
            with open(filename, "w", encoding="utf-8") as f:
                # include the trailing newline the comparison above expects,
                # otherwise the regenerated file would still fail this test
                f.write(content + "\n")