Completed
Push — main ( 0bdbdf...808540 )
by
unknown
19s queued 12s
created

deepreg.model.network.RegistrationModel.__init__()   A

Complexity

Conditions 1

Size

Total Lines 39
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 23
dl 0
loc 39
rs 9.328
c 0
b 0
f 0
cc 1
nop 9

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more or different data.

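There are several approaches to avoid long parameter lists, for example introducing a parameter object that groups related arguments (such as bundling image sizes, batch size and device count into one configuration object):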

1
import logging
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
from abc import abstractmethod
3
from copy import deepcopy
4
from typing import Dict, Optional
5
6
import tensorflow as tf
7
8
from deepreg.model import layer, layer_util
9
from deepreg.model.backbone import GlobalNet
10
from deepreg.registry import REGISTRY
11
12
13
def dict_without(d: dict, key) -> dict:
    """
    Deep-copy a dict and drop one entry from the copy.

    The input dict is never modified; nested values in the result are
    independent deep copies.

    :param d: dict to be copied.
    :param key: key to be removed, must be present in ``d``.
    :return: a deep copy of ``d`` without ``key``.
    :raises KeyError: if ``key`` is not in ``d``.
    """
    result = deepcopy(d)
    del result[key]
    return result
24
25
26
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Subclasses implement build_model (which returns the inner tf.keras.Model
    and fills self._inputs / self._outputs), build_loss and postprocess.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        build_model and build_loss are called at the end of this method,
        so every attribute they read is assigned before them.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices used to identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
            It is stored by reference, not copied.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        # every loss is divided by this value in _build_loss
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # build_model is expected to set self._inputs / self._outputs,
        # which build_loss then reads
        self._model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """
        Return the config dictionary for recreating this class.

        The ``config`` entry is the same dict object held by this model,
        so it reflects any in-place changes made to it after construction.
        """
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    # NOTE(review): tf.keras.Model does not appear to use abc.ABCMeta, so
    # @abstractmethod here likely documents intent rather than being
    # enforced at instantiation — confirm.
    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        :return: dict of inputs with keys moving_image, fixed_image, indices,
            plus moving_label and fixed_label when self.labeled is True.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Moving tensors are resized to self.fixed_image_size so that all
        tensors share the fixed spatial shape before concatenation.

        :param moving_image: registration source, (batch, m_dim1, m_dim2, m_dim3)
        :param fixed_image: registration target, (batch, f_dim1, f_dim2, f_dim3)
        :param moving_label: optional, only used for conditional model.
        :return: tensor of shape (batch, f_dim1, f_dim2, f_dim3, 2 or 3),
            3 channels when moving_label is given.
        """
        images = []

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = layer_util.resize3d(
            image=moving_image, size=self.fixed_image_size
        )
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = layer_util.resize3d(
                image=moving_label, size=self.fixed_image_size
            )
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add one or more weighted losses together with the metrics.

        ``self.config["loss"][name]`` may be a single loss config dict or a
        list of them. For each config, a missing ``"weight"`` defaults to 1.0
        (written back into the config dict). A config with zero weight is
        skipped; the remaining configs in the list are still built.

        :param name: name of loss, e.g. "image", "label" or "regularization"
        :param inputs_dict: keyword arguments passed to the built loss
        """
        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                "The configuration for loss %s is not defined. Loss is not used.",
                name,
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1; note this mutates self.config
                logging.warning(
                    "The weight for loss %s is not defined. "
                    "Default weight = 1.0 is used.",
                    name,
                )
                loss_config["weight"] = 1.0

            # build loss
            weight = loss_config["weight"]

            if weight == 0:
                # skip only this config: returning here would silently drop
                # every remaining config of the same loss
                logging.warning(
                    "The weight for loss %s is zero. Loss is not used.", name
                )
                continue

            loss_cls = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            loss = loss_cls(**inputs_dict) / self.global_batch_size
            weighted_loss = loss * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metric, both unweighted and weighted
            self._model.add_metric(
                loss, name=f"loss/{name}_{loss_cls.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_cls.name}_weighted",
                aggregation="mean",
            )

    @abstractmethod
    def build_loss(self):
        """Build losses according to configs."""

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: dict of output tensors produced by self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        """
266
267
268
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    def _resize_interpolate(self, field, control_points):
        """
        Resample a predicted field through control points.

        Applies layer.ResizeCPTransform followed by layer.BSplines3DTransform
        (built for self.fixed_image_size) to the field.

        :param field: field predicted by the backbone (DDF or DVF).
        :param control_points: value of the backbone config's
            "control_points" entry, passed to both layers.
        :return: the transformed field.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Sets self._inputs and self._outputs. Outputs always contain ddf and
        pred_fixed_image, plus theta for a GlobalNet backbone and
        pred_fixed_label when the data is labeled.

        :return: a tf.keras.Model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # "control_points" is not forwarded to the backbone. Remove it from a
        # shallow copy so self.config (returned by get_config) stays intact and
        # can be reused to build another model.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
            # control_points are not applied to the affine-derived DDF
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            if control_points:
                ddf = self._resize_interpolate(ddf, control_points)
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # presumably (batch, f_dim1, f_dim2, f_dim3) — TODO confirm Warping output
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the same DDF warps the moving label
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Build losses according to configs.

        Adds the "regularization" loss on the DDF, the "image" loss between
        fixed and predicted fixed image, and, when labeled, the "label" loss.
        """
        fixed_image = self._inputs["fixed_image"]
        ddf = self._outputs["ddf"]
        pred_fixed_image = self._outputs["pred_fixed_image"]

        # ddf
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))

        # image
        self._build_loss(
            name="image", inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image)
        )

        # label
        if self.labeled:
            fixed_label = self._inputs["fixed_label"]
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)

        if not self.labeled:
            return indices, processed

        # label entries are placed first, image entries keep their order
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
398
399
400
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The backbone predicts a DVF. When "control_points" is set in the
        backbone config, the DVF goes through self._resize_interpolate. The DVF
        is then converted to a DDF with layer.IntDVF. Sets self._inputs and
        self._outputs (dvf, ddf, pred_fixed_image, and pred_fixed_label when
        labeled).

        :return: a tf.keras.Model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # "control_points" is not forwarded to the backbone. Remove it from a
        # shallow copy so self.config (returned by get_config) stays intact and
        # can be reused to build another model.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build dvf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        # NOTE(review): unlike DDFModel.build_model, a GlobalNet backbone
        # (which returns a (ddf, theta) pair there) is not special-cased here.
        dvf = backbone(inputs=backbone_inputs)
        if control_points:
            dvf = self._resize_interpolate(dvf, control_points)
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # presumably (batch, f_dim1, f_dim2, f_dim3) — TODO confirm Warping output
        pred_fixed_image = warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the same DDF warps the moving label
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        Same as DDFModel.postprocess, plus the predicted DVF.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
466
467
468
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.

    The backbone receives the moving image, fixed image and moving label
    stacked as channels and directly outputs the predicted fixed label
    (single channel, sigmoid activation by default).
    """

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Requires labeled data. Sets self._inputs and self._outputs, the latter
        holding only pred_fixed_label.

        :return: a tf.keras.Model mapping self._inputs to self._outputs.
        """
        # NOTE(review): `assert` is stripped under `python -O`; consider an
        # explicit exception once callers' expectations are checked.
        assert self.labeled

        self._inputs = self.build_inputs()
        model_inputs = self._inputs

        # channels: moving image, fixed image, moving label
        stacked = self.concat_images(
            model_inputs["moving_image"],
            model_inputs["fixed_image"],
            model_inputs["moving_label"],
        )
        backbone_defaults = dict(
            image_size=self.fixed_image_size,
            out_channels=1,
            out_kernel_initializer="glorot_uniform",
            out_activation="sigmoid",
        )
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"], default_args=backbone_defaults
        )

        # drop the single output channel -> (batch, f_dim1, f_dim2, f_dim3)
        label_prediction = tf.squeeze(backbone(inputs=stacked), axis=4)

        self._outputs = dict(pred_fixed_label=label_prediction)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Add the "label" loss between the fixed label and its prediction."""
        self._build_loss(
            name="label",
            inputs_dict=dict(
                y_true=self._inputs["fixed_label"],
                y_pred=self._outputs["pred_fixed_label"],
            ),
        )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Collect the tensors to be saved for this model.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        sample_indices = inputs["indices"]
        return sample_indices, {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            # predicted label is normalized, unlike the ground-truth labels
            "pred_fixed_label": (outputs["pred_fixed_label"], True, True),
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
        }
537