Completed
Push — main ( d3edf2...e8f714 )
by Yunguan
21s queued 14s
created

deepreg.config.v011.parse_data()   A

Complexity

Conditions 4

Size

Total Lines 25
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 25
rs 9.7
c 0
b 0
f 0
cc 4
nop 1
1
"""Support backcompatibility for configs at v0.1.1."""
2
3
from copy import deepcopy
4
from typing import Dict
5
6
from deepreg.constant import KNOWN_DATA_SPLITS
7
8
9
def parse_v011(old_config: Dict) -> Dict:
    """
    Convert a configuration written in the v0.1.1 layout to the latest layout.

    The input is deep-copied first, so the caller's dictionary is left
    untouched. Each section is then handed to its dedicated parser:
    ``dataset``, ``train.model`` (whose content is merged into ``train``),
    ``train.loss``, ``train.preprocess`` and ``train.optimizer``.

    :param old_config: configuration in the v0.1.1 layout
    :return: transformed config
    """
    config = deepcopy(old_config)

    config["dataset"] = parse_data(data_config=config["dataset"])

    train = config["train"]

    # the v0.1.1 "model" sub-section no longer exists; its parsed content
    # (method / backbone) lives directly under "train"
    legacy_model = train.pop("model", None)
    if legacy_model is not None:
        train.update(parse_model(model_config=legacy_model))

    train["loss"] = parse_loss(loss_config=train["loss"])
    train["preprocess"] = parse_preprocess(preprocess_config=train["preprocess"])
    train["optimizer"] = parse_optimizer(opt_config=train["optimizer"])

    return config
39
40
41
def parse_data(data_config: Dict) -> Dict:
    """
    Parse the data configuration.

    In the outdated layout, ``format`` and ``labeled`` sit at the top level
    and ``dir`` is indexed by split name. In the latest layout, every known
    split present in ``dir`` gets its own ``{"dir", "format", "labeled"}``
    entry instead.

    Note: ``data_config`` is modified in place and also returned.

    :param data_config: potentially outdated config
    :return: latest config
    """
    if "format" not in data_config:
        # up-to-date: format is already stored per split
        return data_config

    # local names chosen so the builtins ``dir`` and ``format`` stay usable
    split_dirs = data_config.pop("dir")
    data_format = data_config.pop("format")
    labeled = data_config.pop("labeled")

    # create one entry per known split that has a directory configured
    for split in KNOWN_DATA_SPLITS:
        if split in split_dirs:
            data_config[split] = {
                "dir": split_dirs[split],
                "format": data_format,
                "labeled": labeled,
            }
    return data_config
66
67
68
def parse_model(model_config: Dict) -> Dict:
    """
    Parse the model configuration.

    In the outdated layout, ``backbone`` is a string and the backbone's
    arguments live in a sibling section keyed by that string, e.g.
    ``{"method": ..., "backbone": "global", "global": {...}}``. In the latest
    layout they are nested as ``{"method": ..., "backbone": {"name": ..., ...}}``.
    All other keys of the outdated model config are dropped.

    :param model_config: potentially outdated config, optionally wrapped in an
        extra ``"model"`` layer
    :return: latest config
    """
    # remove model layer
    if "model" in model_config:
        model_config = model_config["model"]

    if isinstance(model_config["backbone"], dict):
        # up-to-date
        return model_config

    backbone_name = model_config["backbone"]

    # tolerate a missing (or empty, i.e. None) backbone-specific section
    # instead of raising KeyError/TypeError; the backbone then only gets a name
    backbone_kwargs = model_config.get(backbone_name) or {}
    backbone_config = {"name": backbone_name, **backbone_kwargs}

    if backbone_name == "global" and "extract_levels" in backbone_config:
        # global net uses depth instead of extract_levels
        extract_levels = backbone_config.pop("extract_levels")
        backbone_config["depth"] = max(extract_levels)

    return {"method": model_config["method"], "backbone": backbone_config}
96
97
98
def parse_image_loss(loss_config: Dict) -> Dict:
    """
    Parse the image loss part in loss configuration.

    An outdated image loss keeps its arguments in a sub-section named after
    the loss, e.g. ``{"name": "lncc", "lncc": {...}}``. Those arguments are
    lifted next to ``name`` and ``weight`` (weight defaults to 1.0).

    :param loss_config: potentially outdated config, updated in place
    :return: latest config
    """
    if "image" not in loss_config or isinstance(loss_config["image"], list):
        # either no image loss, or already a list of losses (up-to-date)
        return loss_config

    old_image = loss_config["image"]
    loss_name = old_image["name"]
    if loss_name not in old_image:
        # no nested argument section: config up-to-date
        return loss_config

    new_image = {"name": loss_name, "weight": old_image.get("weight", 1.0)}
    new_image.update(old_image[loss_name])
    loss_config["image"] = new_image
    return loss_config
126
127
128
def parse_label_loss(loss_config: Dict) -> Dict:
    """
    Parse the label loss part in loss configuration.

    Outdated label losses were named ``single_scale`` or ``multi_scale``, with
    the actual loss given by ``loss_type`` inside that sub-section (plus
    ``loss_scales`` for multi-scale). Afterwards, legacy loss names and the
    ``neg_weight`` argument are renamed to their latest equivalents.

    :param loss_config: potentially outdated config, updated in place
    :return: latest config
    """
    if "label" not in loss_config or isinstance(loss_config["label"], list):
        # either no label loss, or already a list of losses (up-to-date)
        return loss_config

    old_label = loss_config["label"]
    scale_mode = old_label["name"]
    if scale_mode in ("single_scale", "multi_scale"):
        scale_section = old_label[scale_mode]
        new_label = {
            "name": scale_section["loss_type"],
            "weight": old_label.get("weight", 1.0),
        }
        if scale_mode == "multi_scale":
            new_label["scales"] = scale_section["loss_scales"]
        loss_config["label"] = new_label

    label = loss_config["label"]

    # mean-squared was renamed to ssd; dice_generalized was merged into dice
    for legacy_name, latest_name in (
        ("mean-squared", "ssd"),
        ("dice_generalized", "dice"),
    ):
        if label["name"] == legacy_name:
            label["name"] = latest_name

    # neg_weight was renamed to background_weight
    if "neg_weight" in label:
        label["background_weight"] = label.pop("neg_weight")

    return loss_config
170
171
172
def parse_reg_loss(loss_config: Dict) -> Dict:
    """
    Parse the regularization loss part in loss configuration.

    The outdated layout selects the regularizer via ``energy_type``. The
    latest layout uses ``name`` (plus an ``l1`` flag for the gradient
    regularizer). The weight defaults to 1.0.

    :param loss_config: potentially outdated config, updated in place
    :return: latest config
    :raises ValueError: if ``energy_type`` is not one of ``"bending"``,
        ``"gradient-l2"`` or ``"gradient-l1"``
    """
    if "regularization" not in loss_config:
        # no regularization loss
        return loss_config

    if isinstance(loss_config["regularization"], list):
        # config up-to-date
        return loss_config

    if "energy_type" not in loss_config["regularization"]:
        # up-to-date
        return loss_config

    energy_type = loss_config["regularization"]["energy_type"]
    energy_to_config = {
        "bending": {"name": "bending"},
        "gradient-l2": {"name": "gradient", "l1": False},
        "gradient-l1": {"name": "gradient", "l1": True},
    }
    if energy_type not in energy_to_config:
        # previously this silently produced a config without "name",
        # which only failed later and far from the cause
        raise ValueError(
            f"Unknown regularization energy_type {energy_type!r}, "
            f"expected one of {sorted(energy_to_config)}."
        )

    # keep "weight" as the first key, as before
    reg_config = {"weight": loss_config["regularization"].get("weight", 1.0)}
    reg_config.update(energy_to_config[energy_type])
    loss_config["regularization"] = reg_config

    return loss_config
204
205
206
def parse_loss(loss_config: Dict) -> Dict:
    """
    Parse the loss configuration.

    The content of an outdated ``dissimilarity`` wrapper is merged into the
    top level. The image, label and regularization parts are then parsed
    one after another.

    :param loss_config: potentially outdated config
    :return: latest config
    """
    if "dissimilarity" in loss_config:
        # flatten: dissimilarity's keys (e.g. image/label) move up one level
        loss_config.update(loss_config.pop("dissimilarity"))

    for parse_part in (parse_image_loss, parse_label_loss, parse_reg_loss):
        loss_config = parse_part(loss_config=loss_config)

    return loss_config
223
224
225
def parse_preprocess(preprocess_config: Dict) -> Dict:
    """
    Parse the preprocess configuration.

    A config without a ``data_augmentation`` entry receives
    ``{"name": "affine"}``. An existing entry is left unchanged.

    :param preprocess_config: potentially outdated config, updated in place
    :return: latest config
    """
    preprocess_config.setdefault("data_augmentation", {"name": "affine"})
    return preprocess_config
235
236
237
def parse_optimizer(opt_config: Dict) -> Dict:
    """
    Parse the optimizer configuration.

    An outdated config names the optimizer in lowercase shorthand (``adam``,
    ``sgd``, ``rms``) and keeps its arguments in a sub-section with that same
    name. The latest config uses ``Adam``/``SGD``/``RMSprop`` and keeps the
    arguments at the top level, next to ``name``.

    :param opt_config: potentially outdated config
    :return: latest config (a new dict when conversion happens)
    """
    legacy_name = opt_config["name"]
    if legacy_name not in opt_config:
        # no sub-section keyed by the name: up-to-date
        return opt_config

    latest_name = {"adam": "Adam", "sgd": "SGD", "rms": "RMSprop"}[legacy_name]
    return {"name": latest_name, **opt_config[legacy_name]}
256