test.unit.test_registry   A
last analyzed

Complexity

Total Complexity 17

Size/Duplication

Total Lines 155
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 17
eloc 109
dl 0
loc 155
rs 10
c 0
b 0
f 0

8 Methods

Rating   Name   Duplication   Size   Complexity  
A TestRegistry.test_build_from_config() 0 20 1
A TestRegistry.test_register() 0 14 1
A TestRegistry.reg() 0 3 1
A TestRegistry.test_get() 0 10 1
A TestRegistry.test_register_err() 0 11 2
A TestRegistry.test_get_err() 0 4 2
A TestRegistry.test_build_from_config_err() 0 16 2
B TestRegistry.test_doc() 0 60 7
1
import re
2
3
import pandas as pd
4
import pytest
5
6
from deepreg.registry import BACKBONE_CLASS, KNOWN_CATEGORIES, LOSS_CLASS, REGISTRY
7
8
9
class TestRegistry:
    """Unit tests for the global ``REGISTRY`` and its generated documentation page."""

    @pytest.fixture()
    def reg(self):
        """Return ``REGISTRY.copy()`` so tests operate on a copy rather than
        on the shared global registry object."""
        return REGISTRY.copy()

    @pytest.mark.parametrize(
        "category,key,err_msg",
        [
            ("unknown_category", "key", "Unknown category"),
            (BACKBONE_CLASS, "unet", "has been registered"),
        ],
    )
    def test_register_err(self, category, key, err_msg, reg):
        """Registering into an unknown category, or under an already-registered
        key with the default ``force``, raises ``ValueError``."""
        with pytest.raises(ValueError) as err_info:
            reg.register(category=category, name=key, cls=0)
        assert err_msg in str(err_info.value)

    @pytest.mark.parametrize(
        "category,key,force",
        [
            (BACKBONE_CLASS, "unet", True),
            (BACKBONE_CLASS, "vnet", False),
            (LOSS_CLASS, "dice", True),
            (LOSS_CLASS, "Dice", False),
        ],
    )
    def test_register(self, category, key, force, reg):
        """A registered value is stored under ``(category, key)`` and returned by ``get``.

        ``force=True`` is used for keys that are already registered ("unet",
        "dice"; see ``test_get``). "Dice" registering with ``force=False``
        suggests keys are case-sensitive.
        """
        value = 0
        reg.register(category=category, name=key, cls=value, force=force)
        assert reg._dict[(category, key)] == value
        assert reg.get(category, key) == value

    @pytest.mark.parametrize(
        "category,key",
        [
            (BACKBONE_CLASS, "unet"),
            (LOSS_CLASS, "dice"),
        ],
    )
    def test_get(self, category, key, reg):
        """Keys registered by the package can be looked up."""
        # no error means the key has been registered
        _ = reg.get(category, key)

    def test_get_err(self, reg):
        """Looking up an unregistered key raises ``ValueError``."""
        with pytest.raises(ValueError) as err_info:
            reg.get(BACKBONE_CLASS, "wrong_key")
        assert "has not been registered" in str(err_info.value)

    @pytest.mark.parametrize(
        "category,config,err_msg",
        [
            (BACKBONE_CLASS, [], "config must be a dict"),
            (BACKBONE_CLASS, {}, "`config` must contain the key `name`"),
            (
                BACKBONE_CLASS,
                {"name": "unet"},
                "Configuration is not compatible for Class",
            ),
        ],
    )
    def test_build_from_config_err(self, category, config, err_msg, reg):
        """Invalid configs raise ``ValueError``.

        The cases are a non-dict config, a dict without ``name``, and a config
        that the class rejects (presumably missing constructor arguments).
        """
        with pytest.raises(ValueError) as err_info:
            reg.build_from_config(category=category, config=config)
        assert err_msg in str(err_info.value)

    @pytest.mark.parametrize(
        "category,config",
        [
            (
                BACKBONE_CLASS,
                dict(
                    name="unet",
                    image_size=[1, 2, 3],
                    out_channels=3,
                    num_channel_initial=3,
                    depth=5,
                    out_kernel_initializer="he_normal",
                    out_activation="softmax",
                ),
            ),
            (LOSS_CLASS, dict(name="dice")),
        ],
    )
    def test_build_from_config(self, category, config, reg):
        """Valid configs are built without raising."""
        _ = reg.build_from_config(category=category, config=config)

    @staticmethod
    def _normalize_markdown(text: str) -> str:
        """Drop table alignment markers (``:---``) and all spaces so that only
        the table content, not its formatting, is compared."""
        return re.sub(r":-+", "", text).replace(" ", "")

    def test_doc(self):
        """Test the doc maintaining the list of registered classes are correct.

        The expected markdown is regenerated from ``REGISTRY`` and compared to
        the committed file, ignoring spaces and alignment markers. If they
        differ, the file is rewritten with the regenerated content and the test
        fails, so the updated file only needs to be reviewed and committed.
        """
        filename = "docs/source/docs/registered_classes.md"

        # section title in the doc -> registry category
        name_to_category = {
            "Backbone": "backbone_class",
            "Model": "model_class",
            "Loss": "loss_class",
            "Data Augmentation": "da_class",
            "Data Loader": "data_loader_class",
            "File Loader": "file_loader_class",
        }
        # every known category must have a section in the doc
        for category in KNOWN_CATEGORIES:
            assert category in name_to_category.values()

        # generate dataframe of (category, quoted key, fully-qualified value)
        data = dict(category=[], key=[], value=[])
        for (category, key), value in REGISTRY._dict.items():
            data["category"].append(category)
            data["key"].append(f'"{key}"')
            data["value"].append(f"`{value.__module__}.{value.__name__}`")
        df = pd.DataFrame(data).sort_values(["category", "key"])

        # generate lines
        lines = (
            "# Registered Classes\n\n"
            "> This file is generated automatically.\n\n"
            "The following tables contain all registered classes "
            "with their categories and keys."
        )
        for category_name, category in name_to_category.items():
            df_cat = df[df.category == category]
            lines += f"\n\n## {category_name}\n\n"
            lines += (
                f"The category is `{category}`. "
                f"Registered keys and values are as following.\n\n"
            )
            lines += df_cat[["key", "value"]].to_markdown(index=False)
        # the committed file ends with a newline; include it so a regenerated
        # file passes the comparison below
        lines += "\n"

        # check file content
        with open(filename, "r", encoding="utf-8") as f:
            got = f.read()

        if self._normalize_markdown(got) != self._normalize_markdown(lines):
            # rewrite the file so it only needs to be reviewed and committed
            with open(filename, "w", encoding="utf-8") as f:
                f.write(lines)
            pytest.fail(
                f"{filename} is out of date; it has been regenerated, "
                "please review and commit the changes."
            )
155