Passed
Pull Request — main (#719)
by Yunguan
01:41
created

deepreg.model.network.RegistrationModel.__init__()   A

Complexity

Conditions 1

Size

Total Lines 43
Code Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 26
dl 0
loc 43
rs 9.256
c 0
b 0
f 0
cc 1
nop 9

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

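There are several approaches to avoid long parameter lists: group related parameters into a single parameter object (for example a dataclass or config dict), pass a whole object instead of several of its fields, or derive a value inside the method instead of passing it in.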

1
import logging
2
import os
3
from abc import abstractmethod
4
from copy import deepcopy
5
from typing import Dict, Optional, Tuple
6
7
import tensorflow as tf
8
9
from deepreg.loss.image import LocalNormalizedCrossCorrelation
10
from deepreg.loss.label import DiceScore, compute_centroid_distance
11
from deepreg.model import layer, layer_util
12
from deepreg.model.backbone import GlobalNet
13
from deepreg.registry import REGISTRY
14
15
16
def dict_without(d: dict, key) -> dict:
    """
    Build a deep copy of ``d`` that lacks ``key``.

    The input dict itself is left untouched.

    :param d: dict to be copied.
    :param key: key to be removed; it must be present in ``d``.
    :return: the copy without the key.
    :raises KeyError: if ``key`` is not in ``d``.
    """
    result = deepcopy(d)
    del result[key]
    return result
27
28
29
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Subclasses implement ``build_model`` (returning the functional model stored
    as ``self._model``) and ``postprocess``. Losses and metrics are attached to
    ``self._model`` in ``build_loss``, which subclasses may extend via
    ``super().build_loss()``.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        # kept as a float; every loss value is divided by it in _build_loss
        self.global_batch_size = num_devices * batch_size * 1.0
        assert self.global_batch_size > 0

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # reference grid of the fixed image, with a leading axis of size 1
        # added by [None, ...]; used by compute_centroid_distance in build_loss
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]
        # build_model must run before build_loss, which reads
        # self._inputs / self._outputs and attaches losses to self._model
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Images and labels share the size of their respective image, so a
        label input has the same shape as the corresponding image input.

        :return: dict of inputs, with keys moving_image, fixed_image, indices
            and, if self.labeled, also moving_label and fixed_label.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Moving tensors are resized to the fixed image size so that all
        tensors can be stacked along a new channel axis.

        :param moving_image: registration source, (batch, m_dim1, m_dim2, m_dim3)
        :param fixed_image: registration target, (batch, f_dim1, f_dim2, f_dim3)
        :param moving_label: optional, only used for conditional model.
        :return: concatenated tensor, (batch, f_dim1, f_dim2, f_dim3, 2 or 3);
            the last axis has 3 channels when moving_label is given.
        """
        images = []

        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add weighted losses together with their metrics.

        ``self.config["loss"][name]`` may be a single loss config or a list of
        them. Each config is built through ``REGISTRY.build_loss`` and added to
        ``self._model`` separately; configs with a zero weight are skipped.

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """
        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                f"The configuration for loss {name} is not defined. "
                "Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1; note this mutates self.config in
                # place, so get_config() reports the default weight
                logging.warning(
                    f"The weight for loss {name} is not defined. "
                    "Default weight = 1.0 is used."
                )
                loss_config["weight"] = 1.0

            weight = loss_config["weight"]

            if weight == 0:
                # skip only this entry: the remaining configs of the list
                # must still be built
                logging.warning(
                    f"The weight for loss {name} is zero. Loss is not used."
                )
                continue

            # build loss; "weight" is not an argument of the loss layer itself
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            loss_value = loss_layer(**inputs_dict) / self.global_batch_size
            weighted_loss = loss_value * weight if weight != 1 else loss_value

            # add loss
            self._model.add_loss(weighted_loss)

            # only the unweighted value reported as a metric below passes
            # through check_numerics; the tensor given to add_loss does not
            loss_value = tf.debugging.check_numerics(
                loss_value,
                f"loss {name}_{loss_layer.name} inf/nan",
                name=f"op_loss_{name}_{loss_layer.name}",
            )

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    def build_loss(self):
        """
        Build losses according to configs.

        This base implementation logs input statistics, adds the image loss
        (when the model outputs pred_fixed_image) and, for labeled data, the
        label loss plus TRE / binary Dice metrics. Subclasses extend it by
        calling ``super().build_loss()``.
        """

        # input metrics
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # image loss, conditional model does not have this
        if "pred_fixed_image" in self._outputs:
            pred_fixed_image = self._outputs["pred_fixed_image"]
            # debug statistics of the LNCC numerator / denominator
            num, denom = LocalNormalizedCrossCorrelation()._call(
                y_true=fixed_image, y_pred=pred_fixed_image
            )
            self.log_tensor_stats(num, name="debug-lncc-num")
            self.log_tensor_stats(denom, name="debug-lncc-denom")
            self._build_loss(
                name="image",
                inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image),
            )

        if self.labeled:
            # input metrics
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

            # label loss
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

            # additional label metrics
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self.grid_ref
            )
            dice_binary = DiceScore(binary=True)(
                y_true=fixed_label, y_pred=pred_fixed_label
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")
            self._model.add_metric(
                dice_binary, name="metric/BinaryDiceScore", aggregation="mean"
            )

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: dict of output tensors produced by self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Log the model summary and save the model structure in png.

        :param output_dir: path to the output dir; the figure is written to
            ``<output_dir>/<self.name>.png``.
        """
        # Model.summary() prints and returns None, so logging its return
        # value only logs "None"; collect the printed lines via print_fn.
        summary_lines = []
        self._model.summary(
            print_fn=lambda line, *args, **kwargs: summary_lines.append(line)
        )
        logging.info("\n".join(summary_lines))
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logging.error(
                "Failed to plot model structure. "
                "Please check if graphviz is installed.\n"
                "Error message is: "
                f"{err}"
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        Adds per-sample mean / min / max of the tensor as metrics of
        self._model, aggregated with "mean" / "min" / "max" respectively.

        :param tensor: tensor to monitor; its leading dimension is assumed to
            be self.batch_size (it is reshaped to (batch_size, -1)).
        :param name: name of the tensor.
        """
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
370
371
372
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Resample a predicted field onto control points and interpolate it back.

        :param field: dense field predicted by the backbone.
        :param control_points: value of config["backbone"]["control_points"];
            its expected format is defined by layer.ResizeCPTransform and
            layer.BSplines3DTransform (not visible here).
        :return: field after ResizeCPTransform followed by BSplines3DTransform
            onto self.fixed_image_size.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """Build the model to be saved as self._model."""
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # build ddf
        # pop control_points from a copy: it is not a backbone argument, and
        # popping from self.config directly would drop it from get_config(),
        # so a model recreated from that config would lose the setting
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # GlobalNet returns the ddf and the affine parameters theta
            # ddf: (batch, f_dim1, f_dim2, f_dim3, 3); theta: (4, 3) per sample
            # (theta shape taken from the original comment, TODO confirm)
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (batch, f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            ddf = (
                self._resize_interpolate(ddf, control_points) if control_points else ddf
            )
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # warped moving image, presumably (batch, f_dim1, f_dim2, f_dim3)
        # to match fixed_image in the image loss
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the same warping layer is reused to warp the moving label
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, plus DDF regularization."""
        super().build_loss()

        # ddf loss and metrics
        ddf = self._outputs["ddf"]
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries are placed first in the resulting dict
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
490
491
492
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF via layer.IntDVF.
    """

    name = "DVFModel"

    def build_model(self):
        """Build the model to be saved as self._model."""
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # pop control_points from a copy: it is not a backbone argument, and
        # popping from self.config directly would drop it from get_config(),
        # so a model recreated from that config would lose the setting
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build dvf, then integrate it into a ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        dvf = self._resize_interpolate(dvf, control_points) if control_points else dvf
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # warped moving image, presumably (batch, f_dim1, f_dim2, f_dim3)
        # to match fixed_image in the image loss
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the same warping layer is reused to warp the moving label
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, plus DVF statistics."""
        super().build_loss()

        # dvf metrics
        dvf = self._outputs["dvf"]
        self.log_tensor_stats(tensor=dvf, name="dvf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
568
569
570
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.
    """

    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Requires labeled data: the moving label is concatenated with both
        images as backbone input, and the backbone output is used directly as
        the predicted fixed label.
        """
        assert self.labeled, "ConditionalModel requires labeled data."

        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        moving_label = self._inputs["moving_label"]

        # build the backbone that predicts the fixed label
        # (batch, f_dim1, f_dim2, f_dim3, 3)
        backbone_inputs = self.concat_images(moving_image, fixed_image, moving_label)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # drop the single output channel:
        # (batch, f_dim1, f_dim2, f_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = backbone(inputs=backbone_inputs)
        pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)

        self._outputs = dict(pred_fixed_label=pred_fixed_label)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            pred_fixed_label=(outputs["pred_fixed_label"], True, True),
            moving_label=(inputs["moving_label"], False, True),
            fixed_label=(inputs["fixed_label"], False, True),
        )

        return indices, processed
631