Passed
Pull Request — main (#651)
by Yunguan
03:48
created

deepreg.config.v011.parse_optimizer()   A

Complexity

Conditions 2

Size

Total Lines 19
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 19
rs 9.9
c 0
b 0
f 0
cc 2
nop 1
1
"""Support backcompatibility for configs at v0.1.1."""
2
3
from copy import deepcopy
4
5
6
def parse_v011(old_config: dict) -> dict:
    """
    Transform configuration from V0.1.1 format to the latest format.

    The input is deep-copied first, so ``old_config`` itself is never
    modified; the section parsers below may then mutate the copy in place.

    The ``loss``, ``preprocess`` and ``optimizer`` sections of ``train`` are
    each converted only when present. An absent section stays absent rather
    than raising ``KeyError`` here.

    :param old_config: configuration in V0.1.1 or latest format;
        it must contain a ``train`` section.
    :return: transformed config
    """
    new_config = deepcopy(old_config)
    train_config = new_config["train"]

    # the outdated "model" section is flattened into "train"
    model_config = train_config.pop("model", None)
    if model_config is not None:
        train_config.update(parse_model(model_config=model_config))

    if "loss" in train_config:
        train_config["loss"] = parse_loss(loss_config=train_config["loss"])

    if "preprocess" in train_config:
        train_config["preprocess"] = parse_preprocess(
            preprocess_config=train_config["preprocess"]
        )

    if "optimizer" in train_config:
        train_config["optimizer"] = parse_optimizer(
            opt_config=train_config["optimizer"]
        )

    return new_config
34
35
36
def parse_model(model_config: dict) -> dict:
    """
    Parse the model configuration.

    Outdated form names the backbone with a string and keeps its arguments
    in a sibling section of the same name::

        {"method": ..., "backbone": "<name>", "<name>": {...}}

    Latest form nests them under ``backbone``::

        {"method": ..., "backbone": {"name": "<name>", ...}}

    :param model_config: potentially outdated config; a redundant ``model``
        nesting level is unwrapped first if present
    :return: latest config (the same object when already up-to-date)
    """
    # unwrap a redundant {"model": {...}} layer
    config = model_config.get("model", model_config)

    backbone = config["backbone"]
    if not isinstance(backbone, dict):
        # outdated: backbone is a name whose arguments live under that name
        backbone = {"name": backbone, **config[backbone]}
        config = {"method": config["method"], "backbone": backbone}
    return config
56
57
58
def parse_image_loss(loss_config: dict) -> dict:
    """
    Parse the image loss part in loss configuration.

    Outdated form: ``{"image": {"name": n, "weight": w, n: {...kwargs}}}``.
    Latest form:   ``{"image": {"name": n, "weight": w, ...kwargs}}``.
    When converting, a missing weight defaults to 1.0. ``loss_config`` is
    updated in place and returned.

    :param loss_config: potentially outdated config
    :return: latest config
    """
    if "image" in loss_config:
        image_config = loss_config["image"]
        loss_name = image_config["name"]
        if loss_name in image_config:
            # outdated: lift the nested kwargs up next to name/weight
            flattened = dict(name=loss_name, weight=image_config.get("weight", 1.0))
            flattened.update(image_config[loss_name])
            loss_config["image"] = flattened
    return loss_config
82
83
84
def parse_label_loss(loss_config: dict) -> dict:
    """
    Parse the label loss part in loss configuration.

    Outdated configs set ``name`` to ``single_scale`` or ``multi_scale`` and
    keep the real loss in that sub-section's ``loss_type`` (plus
    ``loss_scales`` for multi-scale). These are flattened to
    ``{"name": loss_type, "weight": w}`` and, for multi-scale, ``"scales"``.
    A missing weight defaults to 1.0. Afterwards the renamed losses are
    mapped: ``mean-squared`` -> ``ssd`` and ``dice_generalized`` -> ``dice``.
    ``loss_config`` is updated in place and returned.

    :param loss_config: potentially outdated config
    :return: latest config
    """
    if "label" not in loss_config:
        # no label loss
        return loss_config

    label_config = loss_config["label"]
    scale_mode = label_config["name"]
    if scale_mode in ("single_scale", "multi_scale"):
        scale_config = label_config[scale_mode]
        converted = {
            "name": scale_config["loss_type"],
            "weight": label_config.get("weight", 1.0),
        }
        if scale_mode == "multi_scale":
            converted["scales"] = scale_config["loss_scales"]
        loss_config["label"] = converted

    # losses renamed in the latest format
    label_config = loss_config["label"]
    for old_name, new_name in (("mean-squared", "ssd"), ("dice_generalized", "dice")):
        if label_config["name"] == old_name:
            label_config["name"] = new_name

    return loss_config
117
118
119
def parse_reg_loss(loss_config: dict) -> dict:
    """
    Parse the regularization loss part in loss configuration.

    Outdated configs select the regularizer with ``energy_type``. The latest
    format uses ``name``, plus ``l1`` for the gradient regularizer:

    - ``"bending"``     -> ``{"name": "bending"}``
    - ``"gradient-l2"`` -> ``{"name": "gradient", "l1": False}``
    - ``"gradient-l1"`` -> ``{"name": "gradient", "l1": True}``

    The weight is kept (default 1.0). Any other key of the outdated
    regularization section is dropped. ``loss_config`` is updated in place
    and returned.

    :param loss_config: potentially outdated config
    :return: latest config
    :raises ValueError: if an outdated config uses an ``energy_type`` other
        than the three listed above, instead of producing a regularization
        section without a ``name``.
    """
    if "regularization" not in loss_config:
        # no regularization loss
        return loss_config

    old_reg_config = loss_config["regularization"]
    if "energy_type" not in old_reg_config:
        # up-to-date
        return loss_config

    energy_type = old_reg_config["energy_type"]
    reg_config = {"weight": old_reg_config.get("weight", 1.0)}
    if energy_type == "bending":
        reg_config["name"] = "bending"
    elif energy_type in ("gradient-l2", "gradient-l1"):
        reg_config["name"] = "gradient"
        reg_config["l1"] = energy_type == "gradient-l1"
    else:
        # previously fell through silently, leaving reg_config without "name"
        raise ValueError(
            "Unknown regularization energy_type {!r}; expected one of "
            "'bending', 'gradient-l2', 'gradient-l1'.".format(energy_type)
        )
    loss_config["regularization"] = reg_config

    return loss_config
147
148
149
def parse_loss(loss_config: dict) -> dict:
    """
    Parse the loss configuration.

    First the outdated ``dissimilarity`` layer is removed by moving its
    entries up into ``loss_config``. Then the image, label and
    regularization parts are normalised, in that order.

    :param loss_config: potentially outdated config, updated in place
    :return: latest config
    """
    # hoist the entries of the outdated "dissimilarity" layer one level up
    if "dissimilarity" in loss_config:
        loss_config.update(loss_config.pop("dissimilarity"))

    for parse_part in (parse_image_loss, parse_label_loss, parse_reg_loss):
        loss_config = parse_part(loss_config=loss_config)

    return loss_config
166
167
168
def parse_preprocess(preprocess_config: dict) -> dict:
    """
    Parse the preprocess configuration.

    A config without a ``data_augmentation`` entry gets the default
    ``{"name": "affine"}``. An existing entry, whatever its value, is left
    untouched. The dict is modified in place and returned.

    :param preprocess_config: potentially outdated config
    :return: latest config
    """
    preprocess_config.setdefault("data_augmentation", {"name": "affine"})
    return preprocess_config
178
179
180
def parse_optimizer(opt_config: dict) -> dict:
    """
    Parse the optimizer configuration.

    Outdated form: ``{"name": "adam", "adam": {...kwargs}}``, using the
    lowercase aliases ``adam``, ``sgd`` and ``rms``.
    Latest form:   ``{"name": "Adam", ...kwargs}``, using ``Adam``, ``SGD``
    and ``RMSprop``.

    A config is treated as outdated when it has a section named after its
    own ``name``. Any other config is returned unchanged (the same object).

    :param opt_config: potentially outdated config
    :return: latest config; a new dict when the input was outdated
    :raises KeyError: if an outdated config uses a name other than the
        three aliases above
    """
    old_name = opt_config["name"]
    if old_name in opt_config:
        # outdated: translate the alias and lift the nested kwargs up
        aliases = {"adam": "Adam", "sgd": "SGD", "rms": "RMSprop"}
        return {"name": aliases[old_name], **opt_config[old_name]}
    # up-to-date
    return opt_config
199