Passed
Pull Request — main (#662)
by
unknown
05:24 queued 01:54
created

RegistrationModel._build_loss()   B

Complexity

Conditions 6

Size

Total Lines 56
Code Lines 29

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 29
dl 0
loc 56
rs 8.2506
c 0
b 0
f 0
cc 6
nop 3

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign that the commented part should be extracted into a new method; use the comment as a starting point when choosing a good name for that new method.

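Commonly applied refactorings include Extract Method: moving a coherent, self-describing part of the body into a new, well-named method and calling it from the original one.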

1
import logging
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
from abc import abstractmethod
3
from copy import deepcopy
4
from typing import Dict, Optional
5
6
import tensorflow as tf
7
8
from deepreg.model import layer, layer_util
9
from deepreg.model.backbone import GlobalNet
10
from deepreg.registry import REGISTRY
11
12
13
def dict_without(d: dict, key) -> dict:
    """
    Deep-copy a dict and drop one entry from the copy.

    The input dict is left untouched.

    :param d: dict to be copied.
    :param key: key (any hashable) to drop; it must be present in ``d``,
        otherwise a KeyError is raised.
    :return: deep copy of ``d`` without ``key``.
    """
    result = deepcopy(d)
    del result[key]
    return result
24
25
26
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    ``__init__`` stores the configuration, calls :meth:`build_model` (its result
    is kept as ``self._model``) and then :meth:`build_loss`. Subclasses implement
    ``build_model``, ``build_loss`` and ``postprocess``; ``build_model`` is
    expected to populate ``self._inputs`` and ``self._outputs``, which the
    ``build_loss`` implementations read.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices used to identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch per device
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPUs used,
            global_batch_size = batch_size * num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        self._model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Images and labels are declared with their spatial size only, so together
        with the batch dimension they have shape (batch, dim1, dim2, dim3).

        :return: dict of inputs with keys moving_image, fixed_image, indices and,
            if self.labeled, also moving_label and fixed_label.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Each tensor gets a trailing channel axis; moving tensors are resized to
        the fixed image size before concatenation along that channel axis.

        :param moving_image: registration source, (batch, m_dim1, m_dim2, m_dim3)
        :param fixed_image: registration target, (batch, f_dim1, f_dim2, f_dim3)
        :param moving_label: optional, only used for conditional model,
            (batch, m_dim1, m_dim2, m_dim3).
        :return: concatenated tensor, (batch, f_dim1, f_dim2, f_dim3, 2 or 3),
            with 3 channels when moving_label is given.
        """
        images = []

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = layer_util.resize3d(
            image=moving_image, size=self.fixed_image_size
        )
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = layer_util.resize3d(
                image=moving_label, size=self.fixed_image_size
            )
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add the weighted loss(es) configured under ``name`` with metrics.

        ``self.config["loss"][name]`` may be a single loss config or a list of
        configs; every entry is handled independently by
        :meth:`_build_single_loss`, so a zero-weighted entry does not prevent
        the remaining entries from being added.

        :param name: name of loss, e.g. "image", "label" or "regularization".
        :param inputs_dict: keyword arguments passed to each loss function.
        """
        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                "The configuration for loss %s is not defined. Loss is not used.",
                name,
            )
            return

        loss_list = self.config["loss"][name]
        if not isinstance(loss_list, list):
            loss_list = [loss_list]

        for loss_config in loss_list:
            self._build_single_loss(
                name=name, loss_config=loss_config, inputs_dict=inputs_dict
            )

    def _build_single_loss(self, name: str, loss_config: dict, inputs_dict: dict):
        """
        Build one loss from its config, add it to self._model and log metrics.

        The loss value is divided by the global batch size; both the raw and the
        weighted values are registered as mean-aggregated metrics.

        :param name: name of the loss group this config belongs to.
        :param loss_config: config of a single loss, optionally with "weight".
            A missing weight defaults to 1.0 and is written back into this dict
            (and thus into self.config).
        :param inputs_dict: keyword arguments passed to the loss function.
        """
        if "weight" not in loss_config:
            logging.warning(
                "The weight for loss %s is not defined. Default weight = 1.0 is used.",
                name,
            )
            loss_config["weight"] = 1.0

        weight = loss_config["weight"]
        if weight == 0:
            # skip only this entry; other configured losses are still built
            logging.warning("The weight for loss %s is zero. Loss is not used.", name)
            return

        loss_fn = REGISTRY.build_loss(config=dict_without(d=loss_config, key="weight"))
        loss = loss_fn(**inputs_dict) / self.global_batch_size
        weighted_loss = loss * weight

        self._model.add_loss(weighted_loss)
        self._model.add_metric(
            loss, name=f"loss/{name}_{loss_fn.name}", aggregation="mean"
        )
        self._model.add_metric(
            weighted_loss,
            name=f"loss/{name}_{loss_fn.name}_weighted",
            aggregation="mean",
        )

    @abstractmethod
    def build_loss(self):
        """Build losses according to configs."""

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: bool or None, training or not.
        :param mask: optional mask for inputs, forwarded to self._model.
        :return: dict of output tensors produced by self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
266
267
268
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    Registration model whose backbone predicts a dense displacement field (DDF).

    With a GlobalNet backbone, the backbone returns affine transformation
    parameters ``theta`` together with the DDF computed from them; both are
    exposed as model outputs.
    """

    def build_model(self):
        """
        Assemble the Keras model to be saved as self._model.

        The moving and fixed images are concatenated and fed to the backbone;
        the resulting DDF warps the moving image (and, for labeled data, the
        moving label) into the fixed image space.

        :return: tf.keras.Model mapping self._inputs to self._outputs.
        """
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        backbone_inputs = self.concat_images(
            moving_image, self._inputs["fixed_image"]
        )

        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args={
                "image_size": self.fixed_image_size,
                "out_channels": 3,
                "out_kernel_initializer": "zeros",
                "out_activation": None,
            },
        )

        # presumably ddf is (batch, f_dim1, f_dim2, f_dim3, 3) given
        # out_channels=3 and image_size=fixed_image_size — confirm
        if isinstance(backbone, GlobalNet):
            # GlobalNet yields the DDF together with affine parameters theta
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = {"ddf": ddf, "theta": theta}
        else:
            ddf = backbone(inputs=backbone_inputs)
            self._outputs = {"ddf": ddf}

        # one Warping layer is shared by the image and the label
        warp = layer.Warping(fixed_image_size=self.fixed_image_size)
        self._outputs["pred_fixed_image"] = warp(inputs=[ddf, moving_image])

        if self.labeled:
            self._outputs["pred_fixed_label"] = warp(
                inputs=[ddf, self._inputs["moving_label"]]
            )

        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Register the configured losses.

        Regularization acts on the DDF, the image loss compares the warped
        moving image to the fixed image, and, for labeled data, the label loss
        compares the warped moving label to the fixed label.
        """
        loss_specs = [
            ("regularization", dict(inputs=self._outputs["ddf"])),
            (
                "image",
                dict(
                    y_true=self._inputs["fixed_image"],
                    y_pred=self._outputs["pred_fixed_image"],
                ),
            ),
        ]
        if self.labeled:
            loss_specs.append(
                (
                    "label",
                    dict(
                        y_true=self._inputs["fixed_label"],
                        y_pred=self._outputs["pred_fixed_label"],
                    ),
                )
            )

        for loss_name, loss_inputs in loss_specs:
            self._build_loss(name=loss_name, inputs_dict=loss_inputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
            theta (affine backbone only) is stored with (None, None) flags.
            Label entries, when present, come first in the dict.
        """
        indices = inputs["indices"]
        entries = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "ddf": (outputs["ddf"], True, False),
            "pred_fixed_image": (outputs["pred_fixed_image"], True, False),
        }
        if "theta" in outputs:
            entries["theta"] = (outputs["theta"], None, None)

        if self.labeled:
            label_entries = {
                "moving_label": (inputs["moving_label"], False, True),
                "fixed_label": (inputs["fixed_label"], False, True),
                "pred_fixed_label": (outputs["pred_fixed_label"], False, True),
            }
            entries = {**label_entries, **entries}

        return indices, entries
385
386
387
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    Registration model whose backbone predicts a dense velocity field (DVF).

    The DDF used for warping is obtained from the DVF via ``layer.IntDVF``.
    Losses are inherited from DDFModel.
    """

    def build_model(self):
        """
        Assemble the Keras model to be saved as self._model.

        The backbone maps the concatenated moving/fixed images to a DVF, which
        ``layer.IntDVF`` converts into a DDF; the DDF then warps the moving
        image (and the moving label when labeled) into the fixed image space.

        :return: tf.keras.Model mapping self._inputs to self._outputs.
        """
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        backbone_inputs = self.concat_images(
            moving_image, self._inputs["fixed_image"]
        )

        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args={
                "image_size": self.fixed_image_size,
                "out_channels": 3,
                "out_kernel_initializer": "zeros",
                "out_activation": None,
            },
        )
        # NOTE(review): DDFModel unpacks a GlobalNet output as (ddf, theta);
        # here the backbone output goes straight into IntDVF, so a GlobalNet
        # backbone is presumably unsupported for this model — confirm.
        dvf = backbone(inputs=backbone_inputs)
        integrate = layer.IntDVF(fixed_image_size=self.fixed_image_size)
        ddf = integrate(dvf)

        # one Warping layer is shared by the image and the label
        warp = layer.Warping(fixed_image_size=self.fixed_image_size)
        self._outputs = {
            "dvf": dvf,
            "ddf": ddf,
            "pred_fixed_image": warp(inputs=[ddf, moving_image]),
        }

        if self.labeled:
            self._outputs["pred_fixed_label"] = warp(
                inputs=[ddf, self._inputs["moving_label"]]
            )

        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        Extends the DDFModel result with the predicted DVF.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        result = super().postprocess(inputs=inputs, outputs=outputs)
        indices, entries = result
        entries["dvf"] = (outputs["dvf"], True, False)
        return indices, entries
451
452
453
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.

    Requires labeled data: the moving label is part of the backbone input.
    """

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The backbone receives the moving image, fixed image and moving label
        (concatenated along the channel axis) and directly outputs the predicted
        fixed label through a sigmoid activation.

        :return: tf.keras.Model mapping self._inputs to self._outputs.
        :raises AssertionError: if the data is not labeled.
        """
        if not self.labeled:
            # explicit raise instead of ``assert`` so the check is not stripped
            # under ``python -O``; AssertionError kept for backward compatibility
            raise AssertionError("ConditionalModel requires labeled data.")

        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        moving_label = self._inputs["moving_label"]

        # build the backbone predicting the fixed label (no DDF in this model)
        backbone_inputs = self.concat_images(moving_image, fixed_image, moving_label)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # drop the single output channel: (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = backbone(inputs=backbone_inputs)
        pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)

        self._outputs = dict(pred_fixed_label=pred_fixed_label)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build the label loss between the fixed label and its prediction."""
        fixed_label = self._inputs["fixed_label"]
        pred_fixed_label = self._outputs["pred_fixed_label"]

        self._build_loss(
            name="label",
            inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
        )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            pred_fixed_label=(outputs["pred_fixed_label"], True, True),
            moving_label=(inputs["moving_label"], False, True),
            fixed_label=(inputs["fixed_label"], False, True),
        )

        return indices, processed
522