Completed
Push — main ( 72b597...0bdbdf )
by Yunguan
23s queued 12s
created

deepreg.model.network.RegistrationModel.__init__()   A

Complexity

Conditions 1

Size

Total Lines 38
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 23
dl 0
loc 38
rs 9.328
c 0
b 0
f 0
cc 1
nop 9

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
import logging
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
from abc import abstractmethod
3
from copy import deepcopy
4
from typing import Dict, Optional
5
6
import tensorflow as tf
7
8
from deepreg.model import layer, layer_util
9
from deepreg.model.backbone import GlobalNet
10
from deepreg.registry import REGISTRY
11
12
13
def dict_without(d: dict, key) -> dict:
    """
    Return a deep copy of the given dict without a certain key.

    The input dict itself is never modified, the key is removed from the
    copy only.

    :param d: dict to be copied.
    :param key: key to be removed, it has to exist in d.
    :return: the deep copy without the key.
    :raises KeyError: if the key is not in d.
    """
    copied = deepcopy(d)
    copied.pop(key)
    return copied
24
25
26
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Sub-classes have to implement build_model, build_loss and postprocess.
    The constructor calls build_model and then build_loss, so both can rely
    on the attributes assigned in __init__.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict
        # build_model has to run first, build_loss adds losses to self._model
        self._model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        The label inputs are only created if self.labeled is True.

        :return: dict of inputs.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        A channel axis is appended to each tensor, the moving tensors are
        passed through layer_util.resize3d with size=self.fixed_image_size,
        and everything is concatenated on the channel axis.

        :param moving_image: registration source
        :param fixed_image: registration target
        :param moving_label: optional, only used for conditional model.
        :return: concatenated tensor with 2 channels,
            or 3 channels if moving_label is given.
        """
        images = []

        # (batch, m_dim1, m_dim2, m_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        # presumably (batch, f_dim1, f_dim2, f_dim3, 1) after resizing
        moving_image = layer_util.resize3d(
            image=moving_image, size=self.fixed_image_size
        )
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        if moving_label is not None:
            # (batch, m_dim1, m_dim2, m_dim3, 1)
            moving_label = tf.expand_dims(moving_label, axis=4)
            # presumably (batch, f_dim1, f_dim2, f_dim3, 1) after resizing
            moving_label = layer_util.resize3d(
                image=moving_label, size=self.fixed_image_size
            )
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add one weighted loss together with the metrics.

        Nothing is added if the loss is not configured or its weight is zero.
        If the weight is not configured, the default 1.0 is written into
        the loss config.

        :param name: name of loss
        :param inputs_dict: inputs for loss function
        """
        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                "The configuration for loss %s is not defined. Loss is not used.",
                name,
            )
            return

        loss_config = self.config["loss"][name]

        if "weight" not in loss_config:
            # default loss weight 1
            logging.warning(
                "The weight for loss %s is not defined. "
                "Default weight = 1.0 is used.",
                name,
            )
            loss_config["weight"] = 1.0

        # build loss
        weight = loss_config["weight"]

        if weight == 0:
            logging.warning(
                "The weight for loss %s is zero. Loss is not used.", name
            )
            return

        # weight is handled here, so it is not forwarded to the registry
        loss_layer = REGISTRY.build_loss(
            config=dict_without(d=loss_config, key="weight")
        )
        # divide by the global batch size (batch_size * num_devices)
        loss = loss_layer(**inputs_dict) / self.global_batch_size
        weighted_loss = loss * weight

        # add loss
        self._model.add_loss(weighted_loss)

        # add metric, both the raw and the weighted values are tracked
        self._model.add_metric(
            loss, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
        )
        self._model.add_metric(
            weighted_loss,
            name=f"loss/{name}_{loss_layer.name}_weighted",
            aggregation="mean",
        )

    @abstractmethod
    def build_loss(self):
        """Build losses according to configs."""

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: the outputs of self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
256
257
258
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    def _resize_interpolate(self, field, control_points):
        """
        Apply the control point resizing and the B-spline transform to a field.

        The field is passed through layer.ResizeCPTransform and then through
        layer.BSplines3DTransform built with self.fixed_image_size.

        :param field: tensor predicted by the backbone.
        :param control_points: control point setting forwarded to both layers.
        :return: the transformed field.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: a keras model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # build ddf
        # control_points is not a backbone argument, so it is removed from
        # a copy of the backbone config, self.config stays intact and
        # get_config can still be used to recreate the same model
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
            # control_points is not used in this branch
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            if control_points:
                ddf = self._resize_interpolate(ddf, control_points)
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # warp the moving image with the ddf
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # warp the moving label with the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs."""
        fixed_image = self._inputs["fixed_image"]
        ddf = self._outputs["ddf"]
        pred_fixed_image = self._outputs["pred_fixed_image"]

        # ddf
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))

        # image
        self._build_loss(
            name="image", inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image)
        )

        # label
        if self.labeled:
            fixed_label = self._inputs["fixed_label"]
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)

        if not self.labeled:
            return indices, processed

        # label entries are placed before the entries collected above
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
388
389
390
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: a keras model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # control_points is not a backbone argument, so it is removed from
        # a copy of the backbone config, self.config stays intact and
        # get_config can still be used to recreate the same model
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        if control_points:
            dvf = self._resize_interpolate(dvf, control_points)
        # the ddf is obtained from the dvf by layer.IntDVF
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # warp the moving image with the ddf
        pred_fixed_image = warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # warp the moving label with the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        The dict of the parent class is extended with the dvf.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
456
457
458
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.

    The model requires labeled data as the moving label is a backbone input.
    """

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: a keras model mapping self._inputs to self._outputs.
        """
        assert self.labeled, "ConditionalModel requires labeled data."

        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        moving_label = self._inputs["moving_label"]

        # build backbone, the moving label is part of the backbone inputs
        backbone_inputs = self.concat_images(moving_image, fixed_image, moving_label)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # predict the fixed label directly and drop axis 4 of the output
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = backbone(inputs=backbone_inputs)
        pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)

        self._outputs = dict(pred_fixed_label=pred_fixed_label)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, only the label loss is used."""
        fixed_label = self._inputs["fixed_label"]
        pred_fixed_label = self._outputs["pred_fixed_label"]

        self._build_loss(
            name="label",
            inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
        )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            pred_fixed_label=(outputs["pred_fixed_label"], True, True),
            moving_label=(inputs["moving_label"], False, True),
            fixed_label=(inputs["fixed_label"], False, True),
        )

        return indices, processed
527