Completed
Push — main ( 0bdbdf...808540 )
by
unknown
19s queued 12s
created

RegistrationModel._build_loss()   B

Complexity

Conditions 6

Size

Total Lines 56
Code Lines 29

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 29
dl 0
loc 56
rs 8.2506
c 0
b 0
f 0
cc 6
nop 3

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
import logging
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
from abc import abstractmethod
3
from copy import deepcopy
4
from typing import Dict, Optional
5
6
import tensorflow as tf
7
8
from deepreg.model import layer, layer_util
9
from deepreg.model.backbone import GlobalNet
10
from deepreg.registry import REGISTRY
11
12
13
def dict_without(d: dict, key) -> dict:
    """
    Build a deep copy of a dict that lacks one key.

    The input dict itself is never modified.

    :param d: dict to be copied.
    :param key: key to be removed; it must be present in ``d``.
    :type key: hashable
    :return: a deep copy of ``d`` without ``key``.
    :raises KeyError: if ``key`` is not in ``d``.
    """
    result = deepcopy(d)
    # `del` raises KeyError for a missing key, exactly like `dict.pop(key)`
    del result[key]
    return result
24
25
26
class RegistrationModel(tf.keras.Model):
    """Interface for registration model."""

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # build_model is expected to fill self._inputs / self._outputs,
        # which build_loss then reads, so the order of these calls matters.
        self._model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Label inputs are only created when ``self.labeled`` is True.

        :return: dict of inputs with keys moving_image, fixed_image, indices
            and, if labeled, moving_label and fixed_label.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Each input gets a trailing channel axis. Moving tensors are resized to
        ``self.fixed_image_size`` so that all of them share the fixed grid.

        :param moving_image: registration source, (batch, m_dim1, m_dim2, m_dim3)
        :param fixed_image: registration target, (batch, f_dim1, f_dim2, f_dim3)
        :param moving_label: optional, only used for conditional model.
        :return: tensor of shape (batch, f_dim1, f_dim2, f_dim3, 2 or 3), channels
            ordered as moving image, fixed image and, if given, moving label.
        """
        images = []

        # (batch, f_dim1, f_dim2, f_dim3, 1) after resizing onto the fixed grid
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = layer_util.resize3d(
            image=moving_image, size=self.fixed_image_size
        )
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        if moving_label is not None:
            # (batch, f_dim1, f_dim2, f_dim3, 1) after resizing onto the fixed grid
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = layer_util.resize3d(
                image=moving_label, size=self.fixed_image_size
            )
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        return tf.concat(images, axis=4)

    def _loss_configs_for(self, name: str) -> list:
        """
        Return the list of loss configs registered under ``name``.

        :param name: name of loss, looked up in ``self.config["loss"]``.
        :return: list of loss config dicts; empty if ``name`` is not configured.
        """
        if name not in self.config["loss"]:
            logging.warning(
                "The configuration for loss %s is not defined. Loss is not used.",
                name,
            )
            return []

        loss_configs = self.config["loss"][name]
        # a single loss may be configured as a dict instead of a list of dicts
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]
        return loss_configs

    @staticmethod
    def _loss_weight(name: str, loss_config: dict):
        """
        Return the weight of one loss config, defaulting it to 1.0.

        :param name: name of loss, only used in the warning message.
        :param loss_config: config of one loss; updated in place with
            ``weight = 1.0`` when no weight is given.
        :return: the loss weight.
        """
        if "weight" not in loss_config:
            logging.warning(
                "The weight for loss %s is not defined. Default weight = 1.0 is used.",
                name,
            )
            # written back because dict_without(key="weight") requires the key
            loss_config["weight"] = 1.0
        return loss_config["weight"]

    def _add_weighted_loss(
        self, name: str, loss_config: dict, weight, inputs_dict: dict
    ):
        """
        Build one loss, add its weighted value to the model, and log metrics.

        :param name: name of loss, used to prefix the metric names.
        :param loss_config: config of one loss, including its "weight" key.
        :param weight: multiplier applied to the loss before adding it.
        :param inputs_dict: keyword arguments for the loss callable.
        """
        loss_cls = REGISTRY.build_loss(
            config=dict_without(d=loss_config, key="weight")
        )
        # divided by the global batch size (batch_size * num_devices)
        loss = loss_cls(**inputs_dict) / self.global_batch_size
        weighted_loss = loss * weight

        self._model.add_loss(weighted_loss)

        self._model.add_metric(
            loss, name=f"loss/{name}_{loss_cls.name}", aggregation="mean"
        )
        self._model.add_metric(
            weighted_loss,
            name=f"loss/{name}_{loss_cls.name}_weighted",
            aggregation="mean",
        )

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add one weighted loss together with the metrics.

        ``self.config["loss"][name]`` may hold one loss config or a list of them.
        Each config with a non-zero weight is added independently. A config with
        weight zero is skipped without affecting the other configs in the list.

        :param name: name of loss
        :param inputs_dict: inputs for loss function
        """
        for loss_config in self._loss_configs_for(name):
            weight = self._loss_weight(name=name, loss_config=loss_config)
            if weight == 0:
                logging.warning(
                    "The weight for loss %s is zero. Loss is not used.", name
                )
                # skip only this config; the rest of the list must still be built
                continue
            self._add_weighted_loss(
                name=name,
                loss_config=loss_config,
                weight=weight,
                inputs_dict=inputs_dict,
            )

    @abstractmethod
    def build_loss(self):
        """Build losses according to configs."""

    def call(
        self,
        inputs: Dict[str, tf.Tensor],
        training: Optional[bool] = None,
        mask=None,
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :type training: bool or None
        :param mask: mask for inputs, forwarded unchanged to self._model.
        :return: dict of output tensors produced by self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
266
267
268
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    def _resize_interpolate(self, field, control_points):
        """
        Resample a predicted field through a control-point grid.

        The field goes through ``layer.ResizeCPTransform`` and then through
        ``layer.BSplines3DTransform``, which is given ``self.fixed_image_size``.

        :param field: dense field predicted by the backbone.
        :param control_points: the "control_points" value from the backbone
            config. Presumably it is the control-point grid spacing; TODO confirm
            against layer.ResizeCPTransform.
        :return: the interpolated field.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """Build the model to be saved as self._model."""
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # "control_points" is consumed here, not by the backbone. Remove it from
        # a copy: popping it from self.config would drop it from get_config()
        # and from any config dict the caller reuses to build another model.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # affine model: the backbone returns both the ddf and the affine
            # parameters theta (noted as (4, 3) in the original code).
            # control_points is not applied in this branch.
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # dense ddf with out_channels=3 on the fixed image grid
            ddf = backbone(inputs=backbone_inputs)
            if control_points:
                ddf = self._resize_interpolate(ddf, control_points)
            self._outputs = dict(ddf=ddf)

        # build outputs: warp the moving image with the ddf
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        self._outputs["pred_fixed_image"] = warping(inputs=[ddf, moving_image])

        if self.labeled:
            # the same warping is applied to the moving label
            moving_label = self._inputs["moving_label"]
            self._outputs["pred_fixed_label"] = warping(inputs=[ddf, moving_label])

        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Build losses according to configs.

        Three loss names are used: "regularization" on the ddf, "image" between
        the fixed image and the warped moving image, and, only when labeled,
        "label" between the fixed label and the warped moving label.
        """
        fixed_image = self._inputs["fixed_image"]
        ddf = self._outputs["ddf"]
        pred_fixed_image = self._outputs["pred_fixed_image"]

        # ddf
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))

        # image
        self._build_loss(
            name="image", inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image)
        )

        # label
        if self.labeled:
            fixed_label = self._inputs["fixed_label"]
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)

        if not self.labeled:
            return indices, processed

        # label entries are placed before the image entries
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
398
399
400
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    def build_model(self):
        """Build the model to be saved as self._model."""
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # "control_points" is consumed here, not by the backbone. Remove it from
        # a copy: popping it from self.config would drop it from get_config()
        # and from any config dict the caller reuses to build another model.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build dvf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        # NOTE(review): unlike DDFModel, a GlobalNet backbone (which returns a
        # (ddf, theta) pair there) is not special-cased here; confirm it is
        # unsupported for this model.
        dvf = backbone(inputs=backbone_inputs)
        if control_points:
            dvf = self._resize_interpolate(dvf, control_points)
        # the ddf is derived from the dvf by layer.IntDVF
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs: warp the moving image with the ddf
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        pred_fixed_image = warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if self.labeled:
            # the same warping is applied to the moving label
            moving_label = self._inputs["moving_label"]
            self._outputs["pred_fixed_label"] = warping(inputs=[ddf, moving_label])

        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        Extends DDFModel.postprocess with the predicted dvf.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
466
467
468
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model that directly predicts the fixed image label.

    No DDF or DVF is involved: the backbone maps the concatenated moving image,
    fixed image and moving label to a label on the fixed image grid.
    """

    def build_model(self):
        """Build the model to be saved as self._model."""
        # the moving label is a backbone input, so labeled data is required
        assert self.labeled

        self._inputs = self.build_inputs()
        backbone_inputs = self.concat_images(
            self._inputs["moving_image"],
            self._inputs["fixed_image"],
            self._inputs["moving_label"],
        )

        # single sigmoid output channel for the predicted label
        default_args = dict(
            image_size=self.fixed_image_size,
            out_channels=1,
            out_kernel_initializer="glorot_uniform",
            out_activation="sigmoid",
        )
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=default_args,
        )

        # drop the channel axis: (batch, f_dim1, f_dim2, f_dim3)
        label_prediction = tf.squeeze(backbone(inputs=backbone_inputs), axis=4)

        self._outputs = dict(pred_fixed_label=label_prediction)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build the "label" loss between the fixed label and its prediction."""
        self._build_loss(
            name="label",
            inputs_dict=dict(
                y_true=self._inputs["fixed_label"],
                y_pred=self._outputs["pred_fixed_label"],
            ),
        )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Collect the tensors to save for one batch.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        to_save = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "pred_fixed_label": (outputs["pred_fixed_label"], True, True),
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
        }
        return inputs["indices"], to_save
537