deepreg.model.network.RegistrationModel.__init__() — rating: A
last analyzed

Complexity

Conditions 1

Size

Total Lines 39
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 39
rs 9.352
c 0
b 0
f 0
cc 1
nop 8

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

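There are several approaches to avoid long parameter lists, such as grouping related parameters into a single parameter object (for example a dataclass), or passing an existing object instead of several of its individual fields.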

1
import os
2
from abc import abstractmethod
3
from copy import deepcopy
4
from typing import Dict, Optional, Tuple
5
6
import tensorflow as tf
7
8
from deepreg import log
9
from deepreg.loss.label import compute_centroid_distance
10
from deepreg.model import layer, layer_util
11
from deepreg.model.backbone import GlobalNet
12
from deepreg.registry import REGISTRY
13
14
# module-level logger obtained from deepreg's ``log`` helper
logger = log.get(__name__)
15
16
17
def dict_without(d: dict, key) -> dict:
    """
    Build a deep copy of ``d`` that lacks ``key``.

    The input dict itself is left untouched.

    :param d: dict to be copied.
    :param key: key to be removed; it must be present in ``d``.
    :return: a deep copy of ``d`` with ``key`` removed.
    :raises KeyError: if ``key`` is not in ``d``.
    """
    result = deepcopy(d)
    del result[key]
    return result
28
29
30
class RegistrationModel(tf.keras.Model):
    """Interface for registration model."""

    def __init__(
        self,
        moving_image_size: Tuple,
        fixed_image_size: Tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        Builds the underlying Keras model with ``self.build_model()`` and then
        attaches losses and metrics with ``self.build_loss()``.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices used to identify each sample
        :param labeled: if the data is labeled
        :param batch_size: total number of samples consumed per step, over all devices.
            When using multiple devices, TensorFlow automatically split the tensors.
            Therefore, input shapes should be defined over batch_size.
        :param config: config for method, backbone, and loss.
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.config = config
        self.batch_size = batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # reference grid over the fixed image with a leading axis of size 1
        # (added by [None, ...]); used for the centroid-distance (TRE) metric
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]
        # build_model() must run first: it fills self._inputs / self._outputs,
        # which build_loss() reads
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Images (and labels, if labeled) are created without a channel axis;
        every input has a fixed batch dimension of ``self.batch_size``.

        :return: dict of inputs, with keys moving_image, fixed_image, indices
            and, if labeled, moving_label and fixed_label.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Moving tensors get a channel axis and are resized with
        ``layer.Resize3d(shape=self.fixed_image_size)``; the fixed image only
        gets a channel axis. All are concatenated along the last axis.

        :param moving_image: registration source, (batch, m_dim1, m_dim2, m_dim3)
        :param fixed_image: registration target, (batch, f_dim1, f_dim2, f_dim3)
        :param moving_label: optional, only used for conditional model.
        :return: (batch, f_dim1, f_dim2, f_dim3, 2 or 3), 3 channels when
            moving_label is given.
        """
        images = []

        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add the weighted loss(es) of one category together with metrics.

        ``self.config["loss"][name]`` may be a single loss config or a list of
        them; every entry with a non-zero weight is built, weighted and added.

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """

        if name not in self.config["loss"]:
            # loss config is not defined
            logger.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1; note this writes into self.config in place
                logger.warning(
                    f"The weight for loss {name} is not defined. "
                    f"Default weight = 1.0 is used."
                )
                loss_config["weight"] = 1.0

            # build loss
            weight = loss_config["weight"]

            if weight == 0:
                logger.warning(
                    f"The weight for loss {name} is zero. " f"Loss is not used."
                )
                # skip only this entry; remaining configs in the list still apply
                continue

            # do not perform reduction over batch axis for supporting multi-device
            # training, model.fit() will average over global batch size automatically
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight"),
                default_args={"reduction": tf.keras.losses.Reduction.NONE},
            )
            loss_value = loss_layer(**inputs_dict)
            weighted_loss = loss_value * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    def build_loss(self):
        """
        Build losses according to configs.

        Concrete base implementation (deliberately not abstract): logs input
        statistics, adds the image loss when ``pred_fixed_image`` is an output,
        and, if labeled, adds the label loss and the TRE metric. Subclasses
        extend it via ``super().build_loss()``.
        """

        # input metrics
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # image loss, conditional model does not have this
        if "pred_fixed_image" in self._outputs:
            pred_fixed_image = self._outputs["pred_fixed_image"]
            self._build_loss(
                name="image",
                inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image),
            )

        if self.labeled:
            # input metrics
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

            # label loss
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

            # additional label metrics
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self.grid_ref
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: dict of output tensors produced by self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Save model structure in png.

        The summary is always logged at debug level; the png is written to
        ``<output_dir>/<model name>.png``. A failed plot (ImportError) is logged
        rather than raised.

        :param output_dir: path to the output dir.
        """
        self._model.summary(print_fn=logger.debug)
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logger.error(
                "Failed to plot model structure. "
                "Please check if graphviz is installed. "
                "Error message is: %s.",
                err,
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        The tensor is flattened per sample to (batch_size, -1); the per-sample
        mean / min / max are added as metrics aggregated by mean / min / max.

        :param tensor: tensor to monitor.
        :param name: name of the tensor.
        """
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
352
353
354
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    # NOTE(review): class-level ``name``; confirm how it interacts with the
    # ``name`` argument forwarded to tf.keras.Model.__init__.
    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Resample a dense field onto control points and interpolate it back with B-splines.

        :param field: dense field predicted by the backbone.
        :param control_points: control-point setting read from the backbone config,
            forwarded to ``layer.ResizeCPTransform`` and ``layer.BSplines3DTransform``.
        :return: the interpolated field (``BSplines3DTransform`` is constructed
            with ``self.fixed_image_size``).
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Sets ``self._inputs`` and ``self._outputs``. Outputs contain ``ddf`` and
        ``pred_fixed_image``, plus ``theta`` for a GlobalNet backbone and
        ``pred_fixed_label`` when labeled.

        :return: the functional Keras model mapping inputs to outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]  # (batch, m_dim1, m_dim2, m_dim3)
        fixed_image = self._inputs["fixed_image"]  # (batch, f_dim1, f_dim2, f_dim3)

        # build ddf
        # Pop "control_points" from a shallow copy: popping it from self.config
        # would drop it from get_config(), so a model re-created from that config
        # would silently skip the control-point interpolation.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            if control_points:
                ddf = self._resize_interpolate(ddf, control_points)
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # (f_dim1, f_dim2, f_dim3)
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # (f_dim1, f_dim2, f_dim3)
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, adding the DDF regularization loss and stats."""
        super().build_loss()

        # ddf loss and metrics
        ddf = self._outputs["ddf"]
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries come first in the resulting dict
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
472
473
474
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    # NOTE(review): class-level ``name``; confirm how it interacts with the
    # ``name`` argument forwarded to tf.keras.Model.__init__.
    name = "DVFModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The backbone predicts a DVF, which is integrated into a DDF by
        ``layer.IntDVF``; the DDF then warps the moving image (and label).
        Outputs contain ``dvf``, ``ddf``, ``pred_fixed_image`` and, when
        labeled, ``pred_fixed_label``.

        :return: the functional Keras model mapping inputs to outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # Pop "control_points" from a shallow copy: popping it from self.config
        # would drop it from get_config(), so a model re-created from that config
        # would silently skip the control-point interpolation.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build dvf, then integrate it into a ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        if control_points:
            dvf = self._resize_interpolate(dvf, control_points)
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # (f_dim1, f_dim2, f_dim3)
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # (f_dim1, f_dim2, f_dim3)
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, additionally logging DVF statistics."""
        super().build_loss()

        # dvf metrics
        dvf = self._outputs["dvf"]
        self.log_tensor_stats(tensor=dvf, name="dvf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        Extends the DDF model's output with the predicted DVF.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
550
551
552
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.
    """

    # NOTE(review): class-level ``name``; confirm how it interacts with the
    # ``name`` argument forwarded to tf.keras.Model.__init__.
    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The backbone receives moving image, fixed image and moving label
        concatenated along the channel axis and directly predicts the fixed
        label through a single sigmoid-activated output channel.

        :return: the functional Keras model mapping inputs to outputs.
        :raises AssertionError: if the model is not labeled.
        """
        # Explicit check instead of a bare ``assert`` so it is not stripped
        # under ``python -O``; AssertionError is kept for backward compatibility.
        if not self.labeled:
            raise AssertionError("ConditionalModel requires labeled data.")

        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        moving_label = self._inputs["moving_label"]

        # build the label-predicting backbone (no ddf is involved)
        backbone_inputs = self.concat_images(moving_image, fixed_image, moving_label)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # (batch, f_dim1, f_dim2, f_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = backbone(inputs=backbone_inputs)
        pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)

        self._outputs = dict(pred_fixed_label=pred_fixed_label)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            pred_fixed_label=(outputs["pred_fixed_label"], True, True),
            moving_label=(inputs["moving_label"], False, True),
            fixed_label=(inputs["fixed_label"], False, True),
        )

        return indices, processed
613