Passed
Pull Request — main (#662)
by
unknown
03:45
created

deepreg.model.network.RegistrationModel.__init__()   A

Complexity

Conditions 1

Size

Total Lines 39
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 23
dl 0
loc 39
rs 9.328
c 0
b 0
f 0
cc 1
nop 9

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more or different data.

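There are several approaches to avoid long parameter lists, for example grouping related parameters into a single parameter object (such as a config dict or dataclass), or passing an existing object instead of several of its individual fields.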

1
import logging
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
from abc import abstractmethod
3
from copy import deepcopy
4
from typing import Dict, Optional
5
6
import tensorflow as tf
7
8
from deepreg.model import layer, layer_util
9
from deepreg.model.backbone import GlobalNet
10
from deepreg.registry import REGISTRY
11
12
13
def dict_without(d: dict, key) -> dict:
    """
    Return a deep copy of the given dict with one key removed.

    The input dict is never modified, and nested values are copied too, so
    mutating the result does not affect ``d``.

    :param d: dict to be copied.
    :param key: key to be removed; it must be present in ``d``.
    :return: a deep copy of ``d`` without ``key``.
    :raises KeyError: if ``key`` is not in ``d``.
    """
    result = deepcopy(d)
    del result[key]
    return result
24
25
26
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Subclasses implement ``build_model``, ``build_loss`` and ``postprocess``.
    Both ``build_model`` and ``build_loss`` are called from ``__init__``:
    ``build_model`` is expected to fill ``self._inputs`` and ``self._outputs``
    (dicts of tensors), which ``build_loss`` then reads.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices used to identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss;
            the "backbone" and "loss" entries are read by this class family.
        :param num_devices: number of devices (e.g. GPUs) used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # build_model populates self._inputs / self._outputs,
        # which build_loss relies on, so the call order matters
        self._model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build and return the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        The dict always has the keys "moving_image", "fixed_image" and
        "indices"; "moving_label" and "fixed_label" are added when
        ``self.labeled`` is True.

        :return: dict of inputs.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Each input gets a trailing channel axis; moving tensors are resized
        to the fixed image size before concatenation along that axis.

        :param moving_image: registration source, (batch, m_dim1, m_dim2, m_dim3)
        :param fixed_image: registration target, (batch, f_dim1, f_dim2, f_dim3)
        :param moving_label: optional, only used for conditional model.
        :return: concatenated tensor, (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        """
        images = []

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = layer_util.resize3d(
            image=moving_image, size=self.fixed_image_size
        )
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = layer_util.resize3d(
                image=moving_label, size=self.fixed_image_size
            )
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add weighted losses together with their metrics.

        ``self.config["loss"][name]`` may be a single loss config or a list of
        them; every entry is handled independently. An entry without a
        "weight" uses 1.0, and an entry with weight 0 is skipped without
        affecting the other entries.

        :param name: name of loss, e.g. "image", "label" or "regularization".
        :param inputs_dict: keyword arguments passed to each loss callable.
        """
        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                "The configuration for loss %s is not defined. Loss is not used.",
                name,
            )
            return

        loss_list = self.config["loss"][name]
        if not isinstance(loss_list, list):
            loss_list = [loss_list]

        for loss_config in loss_list:
            if "weight" not in loss_config:
                # default loss weight 1;
                # note: this writes the default back into self.config
                logging.warning(
                    "The weight for loss %s is not defined. "
                    "Default weight = 1.0 is used.",
                    name,
                )
                loss_config["weight"] = 1.0

            weight = loss_config["weight"]

            if weight == 0:
                # skip only this entry; other losses in the list still apply
                logging.warning(
                    "The weight for loss %s is zero. Loss is not used.", name
                )
                continue

            # "weight" is consumed here, not by the loss class itself
            loss_cls = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            # scaled by 1 / (batch_size * num_devices)
            loss = loss_cls(**inputs_dict) / self.global_batch_size
            weighted_loss = loss * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metrics for both the raw and the weighted loss
            self._model.add_metric(
                loss, name=f"loss/{name}_{loss_cls.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_cls.name}_weighted",
                aggregation="mean",
            )

    @abstractmethod
    def build_loss(self):
        """Build losses according to configs."""

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: dict of output tensors produced by self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return the indices and a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
266
267
268
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    Registration model whose backbone predicts a dense displacement field (DDF).

    With a GlobalNet backbone, the backbone returns affine transformation
    parameters (``theta``) alongside the DDF derived from them, and both are
    exposed as outputs.
    """

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Fills self._inputs and self._outputs. The outputs always hold "ddf"
        and "pred_fixed_image"; "theta" is added for a GlobalNet backbone and
        "pred_fixed_label" when the data is labeled.

        :return: a tf.keras.Model mapping self._inputs to self._outputs.
        """
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # the backbone sees moving and fixed images stacked on the channel axis
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # ddf: (batch, f_dim1, f_dim2, f_dim3, 3);
            # theta: affine parameters, (4, 3) per sample -- TODO confirm
            ddf, theta = backbone(inputs=backbone_inputs)
            outputs = dict(ddf=ddf, theta=theta)
        else:
            # (batch, f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            outputs = dict(ddf=ddf)

        # one warping layer is shared by the image and the label
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # presumably (batch, f_dim1, f_dim2, f_dim3) -- TODO confirm
        outputs["pred_fixed_image"] = warping(inputs=[ddf, moving_image])
        if self.labeled:
            outputs["pred_fixed_label"] = warping(
                inputs=[ddf, self._inputs["moving_label"]]
            )

        self._outputs = outputs
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Build losses according to configs.

        Adds the "regularization" loss on the DDF and the "image" loss on the
        warped image, plus the "label" loss when the data is labeled.
        """
        fixed_image = self._inputs["fixed_image"]
        ddf = self._outputs["ddf"]
        pred_fixed_image = self._outputs["pred_fixed_image"]

        for loss_name, loss_inputs in (
            ("regularization", dict(inputs=ddf)),
            ("image", dict(y_true=fixed_image, y_pred=pred_fixed_image)),
        ):
            self._build_loss(name=loss_name, inputs_dict=loss_inputs)

        if not self.labeled:
            return
        self._build_loss(
            name="label",
            inputs_dict=dict(
                y_true=self._inputs["fixed_label"],
                y_pred=self._outputs["pred_fixed_label"],
            ),
        )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return the indices and a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
            Label entries, when present, come first; "theta" is stored with
            (None, None) flags.
        """
        indices = inputs["indices"]

        image_entries = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "ddf": (outputs["ddf"], True, False),
            "pred_fixed_image": (outputs["pred_fixed_image"], True, False),
        }
        # only affine (GlobalNet) models expose theta
        if "theta" in outputs:
            image_entries["theta"] = (outputs["theta"], None, None)

        if not self.labeled:
            return indices, image_entries

        processed = {
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
            "pred_fixed_label": (outputs["pred_fixed_label"], False, True),
        }
        processed.update(image_entries)
        return indices, processed
385
386
387
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    Registration model whose backbone predicts a dense velocity field (DVF).

    The DDF is derived from the DVF through ``layer.IntDVF`` and then used to
    warp the moving image (and label); "dvf" is exposed as an extra output.
    """

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Fills self._inputs and self._outputs. The outputs hold "dvf", "ddf"
        and "pred_fixed_image", plus "pred_fixed_label" when labeled.

        :return: a tf.keras.Model mapping self._inputs to self._outputs.
        """
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # the backbone sees moving and fixed images stacked on the channel axis
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        integrate = layer.IntDVF(fixed_image_size=self.fixed_image_size)
        ddf = integrate(dvf)

        # one warping layer is shared by the image and the label
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        outputs = dict(
            dvf=dvf,
            ddf=ddf,
            pred_fixed_image=warping(inputs=[ddf, moving_image]),
        )
        if self.labeled:
            outputs["pred_fixed_label"] = warping(
                inputs=[ddf, self._inputs["moving_label"]]
            )

        self._outputs = outputs
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return the indices and a dict used for saving inputs and outputs.

        Same as DDFModel.postprocess, with an extra "dvf" entry.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed.update(dvf=(outputs["dvf"], True, False))
        return indices, processed
451
452
453
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.

    The backbone receives the moving image, the fixed image and the moving
    label, and outputs the predicted fixed label through a sigmoid
    activation. Labeled data is required.
    """

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: a tf.keras.Model mapping self._inputs to self._outputs,
            where the outputs hold only "pred_fixed_label".
        :raises AssertionError: if the model is not labeled.
        """
        # Explicit check rather than ``assert`` so it is not stripped under
        # ``python -O``; AssertionError is kept for existing callers.
        if not self.labeled:
            raise AssertionError(
                "ConditionalModel requires labeled data, got labeled=False."
            )

        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        moving_label = self._inputs["moving_label"]

        # predict the fixed label directly from images and moving label
        backbone_inputs = self.concat_images(moving_image, fixed_image, moving_label)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # (batch, f_dim1, f_dim2, f_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = backbone(inputs=backbone_inputs)
        pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)

        self._outputs = dict(pred_fixed_label=pred_fixed_label)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build the "label" loss according to configs."""
        fixed_label = self._inputs["fixed_label"]
        pred_fixed_label = self._outputs["pred_fixed_label"]

        self._build_loss(
            name="label",
            inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
        )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return the indices and a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            pred_fixed_label=(outputs["pred_fixed_label"], True, True),
            moving_label=(inputs["moving_label"], False, True),
            fixed_label=(inputs["fixed_label"], False, True),
        )

        return indices, processed
522