1
|
|
|
"""Support backcompatibility for configs at v0.1.1.""" |
2
|
|
|
|
3
|
|
|
from copy import deepcopy |
4
|
|
|
from typing import Dict |
5
|
|
|
|
6
|
|
|
from deepreg.constant import KNOWN_DATA_SPLITS |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
def parse_v011(old_config: Dict) -> Dict:
    """
    Convert a configuration written for v0.1.1 into the latest layout.

    The input is deep-copied first, so the caller's dictionary is left
    untouched; every change is applied to the copy.

    :param old_config: configuration that may use the v0.1.1 layout,
        must contain the "dataset" and "train" sections
    :return: transformed config
    """
    converted = deepcopy(old_config)
    converted["dataset"] = parse_data(data_config=converted["dataset"])

    train = converted["train"]

    # the dedicated "model" section is dropped and its parsed content
    # is merged directly into the "train" section
    legacy_model = train.pop("model", None)
    if legacy_model is not None:
        train.update(parse_model(model_config=legacy_model))

    train["loss"] = parse_loss(loss_config=train["loss"])
    train["preprocess"] = parse_preprocess(preprocess_config=train["preprocess"])
    train["optimizer"] = parse_optimizer(opt_config=train["optimizer"])
    return converted
39
|
|
|
|
40
|
|
|
|
41
|
|
|
def parse_data(data_config: dict) -> Dict:
    """
    Bring the data configuration to the latest layout.

    A config without a top-level "format" key is treated as up-to-date
    and returned unchanged. Otherwise the shared "dir", "format" and
    "labeled" entries are removed and, for each known split listed in
    "dir", replaced by a per-split entry holding its own dir/format/labeled.
    The given dictionary is modified in place.

    :param data_config: potentially outdated config
    :return: latest config
    """
    if "format" in data_config:
        # outdated layout: the settings are shared by all splits
        split_dirs = data_config.pop("dir")
        data_format = data_config.pop("format")
        is_labeled = data_config.pop("labeled")

        # only the splits that have a directory get an entry
        for split in KNOWN_DATA_SPLITS:
            if split in split_dirs:
                data_config[split] = {
                    "dir": split_dirs[split],
                    "format": data_format,
                    "labeled": is_labeled,
                }
    return data_config
66
|
|
|
|
67
|
|
|
|
68
|
|
|
def parse_model(model_config: Dict) -> Dict:
    """
    Bring the model configuration to the latest layout.

    An optional outer "model" layer is unwrapped first. If "backbone" is
    already a dictionary the config is returned as is. Otherwise
    "backbone" holds the backbone name, and the options stored under
    that name are combined with it into a single backbone dictionary.

    :param model_config: potentially outdated config
    :return: latest config
    """
    # unwrap the outer "model" layer when present
    model_config = model_config.get("model", model_config)

    backbone = model_config["backbone"]
    if isinstance(backbone, dict):
        # nothing to convert
        return model_config

    # here `backbone` is the backbone name; its options live under that key
    backbone_options = {
        "name": backbone,
        **model_config[backbone],
    }

    # the global net is configured with depth instead of extract_levels
    if backbone == "global" and "extract_levels" in backbone_options:
        levels = backbone_options.pop("extract_levels")
        backbone_options["depth"] = max(levels)

    return {"method": model_config["method"], "backbone": backbone_options}
96
|
|
|
|
97
|
|
|
|
98
|
|
|
def parse_image_loss(loss_config: Dict) -> Dict:
    """
    Bring the image loss part of the loss configuration to the latest layout.

    Nothing is changed when there is no "image" entry, when it is a
    list, or when it has no sub-section named after the selected loss.
    Otherwise the sub-section of the selected loss is flattened into a
    single dictionary together with the loss name and weight
    (weight defaults to 1.0). The given dictionary is modified in place.

    :param loss_config: potentially outdated config
    :return: latest config
    """
    if "image" not in loss_config or isinstance(loss_config["image"], list):
        # no image loss, or already in list form
        return loss_config

    legacy = loss_config["image"]
    selected = legacy["name"]

    if selected in legacy:
        # outdated: options are nested under the name of the loss
        flattened = {"name": selected, "weight": legacy.get("weight", 1.0)}
        flattened.update(legacy[selected])
        loss_config["image"] = flattened

    return loss_config
126
|
|
|
|
127
|
|
|
|
128
|
|
|
def parse_label_loss(loss_config: Dict) -> Dict:
    """
    Bring the label loss part of the loss configuration to the latest layout.

    Nothing is changed when there is no "label" entry or when it is a
    list. A "single_scale" / "multi_scale" entry is replaced by a flat
    dictionary named after its "loss_type" (weight defaults to 1.0,
    multi-scale additionally keeps its scales). Afterwards outdated loss
    names and the "neg_weight" key are renamed.
    The given dictionary is modified in place.

    :param loss_config: potentially outdated config
    :return: latest config
    """
    if "label" not in loss_config or isinstance(loss_config["label"], list):
        # no label loss, or already in list form
        return loss_config

    label_cfg = loss_config["label"]
    scale_mode = label_cfg["name"]

    if scale_mode in ("single_scale", "multi_scale"):
        # outdated: the actual loss is nested under the scale mode
        nested = label_cfg[scale_mode]
        flattened = {
            "name": nested["loss_type"],
            "weight": label_cfg.get("weight", 1.0),
        }
        if scale_mode == "multi_scale":
            flattened["scales"] = nested["loss_scales"]
        label_cfg = flattened
        loss_config["label"] = label_cfg

    if label_cfg["name"] == "mean-squared":
        # mean-squared renamed to ssd
        label_cfg["name"] = "ssd"
    elif label_cfg["name"] == "dice_generalized":
        # dice_generalized merged into dice
        label_cfg["name"] = "dice"

    # neg_weight renamed to background_weight
    if "neg_weight" in label_cfg:
        label_cfg["background_weight"] = label_cfg.pop("neg_weight")

    return loss_config
170
|
|
|
|
171
|
|
|
|
172
|
|
|
def parse_reg_loss(loss_config: Dict) -> Dict:
    """
    Bring the regularization part of the loss configuration to the latest layout.

    Nothing is changed when there is no "regularization" entry, when it
    is a list, or when it has no "energy_type" key. Otherwise the entry
    is replaced by a dictionary with the weight (default 1.0) and the
    name/options derived from the energy type. An energy type other than
    "bending", "gradient-l2" or "gradient-l1" yields a dictionary that
    only holds the weight. The given dictionary is modified in place.

    :param loss_config: potentially outdated config
    :return: latest config
    """
    if "regularization" not in loss_config:
        # no regularization loss
        return loss_config

    legacy = loss_config["regularization"]
    if isinstance(legacy, list) or "energy_type" not in legacy:
        # already up-to-date
        return loss_config

    energy_type = legacy["energy_type"]
    converted = {"weight": legacy.get("weight", 1.0)}

    if energy_type == "bending":
        converted["name"] = "bending"
    elif energy_type in ("gradient-l2", "gradient-l1"):
        # both gradient variants share one name, told apart by the l1 flag
        converted["name"] = "gradient"
        converted["l1"] = energy_type == "gradient-l1"

    loss_config["regularization"] = converted
    return loss_config
204
|
|
|
|
205
|
|
|
|
206
|
|
|
def parse_loss(loss_config: Dict) -> Dict:
    """
    Bring the whole loss configuration to the latest layout.

    The content of an optional "dissimilarity" layer is first moved up
    into the loss config itself, then the image, label and
    regularization parts are parsed in that order.

    :param loss_config: potentially outdated config
    :return: latest config
    """
    # remove dissimilarity layer by lifting its content one level up
    if "dissimilarity" in loss_config:
        loss_config.update(loss_config.pop("dissimilarity"))

    for parse_part in (parse_image_loss, parse_label_loss, parse_reg_loss):
        loss_config = parse_part(loss_config=loss_config)

    return loss_config
223
|
|
|
|
224
|
|
|
|
225
|
|
|
def parse_preprocess(preprocess_config: Dict) -> Dict:
    """
    Bring the preprocess configuration to the latest layout.

    When no "data_augmentation" entry exists, one naming the "affine"
    augmentation is added; an existing entry is left untouched.
    The given dictionary is modified in place.

    :param preprocess_config: potentially outdated config
    :return: latest config
    """
    # only fills the gap, never overrides an existing entry
    preprocess_config.setdefault("data_augmentation", {"name": "affine"})
    return preprocess_config
235
|
|
|
|
236
|
|
|
|
237
|
|
|
def parse_optimizer(opt_config: Dict) -> Dict:
    """
    Bring the optimizer configuration to the latest layout.

    When the config has no sub-section named after the selected
    optimizer it is returned as is. Otherwise a new dictionary is built
    from the renamed optimizer and the options of that sub-section.

    :param opt_config: potentially outdated config
    :return: latest config
    :raises KeyError: if the outdated optimizer name is not one of
        "adam", "sgd" or "rms"
    """
    legacy_name = opt_config["name"]

    if legacy_name in opt_config:
        # outdated: options are nested under the (lower-case) optimizer name
        renamed = {"adam": "Adam", "sgd": "SGD", "rms": "RMSprop"}
        return {"name": renamed[legacy_name], **opt_config[legacy_name]}

    # already up-to-date
    return opt_config
256
|
|
|
|