Passed
Pull Request — main (#719)
by Yunguan
01:13
created

deepreg.model.network.RegistrationModel.__init__()   A

Complexity

Conditions 1

Size

Total Lines 43
Code Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 26
dl 0
loc 43
rs 9.256
c 0
b 0
f 0
cc 1
nop 9

How to fix   Many Parameters   

Many Parameters

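Methods with many parameters are not only hard to understand, but their parameters also tend to become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists, for example grouping related parameters into a single parameter object (such as a dataclass or a config dict), or passing a whole object instead of several of its attributes.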

1
import logging
2
import os
3
from abc import abstractmethod
4
from copy import deepcopy
5
from typing import Dict, Optional, Tuple
6
7
import tensorflow as tf
8
9
from deepreg.loss.label import DiceScore, compute_centroid_distance
10
from deepreg.model import layer, layer_util
11
from deepreg.model.backbone import GlobalNet
12
from deepreg.registry import REGISTRY
13
14
15
def dict_without(d: dict, key) -> dict:
    """
    Return a deep copy of ``d`` from which ``key`` has been removed.

    The input dict itself is never modified.

    :param d: dict to be copied.
    :param key: key to be removed; it must be present in ``d``.
    :return: deep copy of ``d`` without ``key``.
    :raises KeyError: if ``key`` is not in ``d``.
    """
    result = deepcopy(d)
    del result[key]
    return result
26
27
28
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    A subclass builds a functional ``tf.keras.Model`` in :meth:`build_model`
    (stored as ``self._model``) and keeps its inputs and outputs as dicts in
    ``self._inputs`` / ``self._outputs``. Losses and metrics are then attached
    to that inner model by :meth:`build_loss`, called at the end of ``__init__``.
    """

    def __init__(
        self,
        moving_image_size: Tuple,
        fixed_image_size: Tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        # stored as a float; every loss is divided by it in _build_loss
        self.global_batch_size = num_devices * batch_size * 1.0
        assert self.global_batch_size > 0

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # reference grid of the fixed image with a leading axis of size 1,
        # used by the TRE metric in build_loss
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Images and labels are created without a channel axis; one is added
        later by :meth:`concat_images`.

        :return: dict of inputs with keys moving_image, fixed_image, indices
            and, if the data is labeled, also moving_label and fixed_label.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Moving tensors are resized to the fixed image size so that all
        tensors can be stacked along a new channel axis.

        :param moving_image: registration source
        :param fixed_image: registration target
        :param moving_label: optional, only used for conditional model.
        :return: tensor of shape (batch, f_dim1, f_dim2, f_dim3, 2 or 3),
            3 channels when moving_label is given.
        """
        images = []

        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # (batch, m_dim1, m_dim2, m_dim3) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add weighted losses together with their metrics.

        ``self.config["loss"][name]`` may be a single loss config or a list of
        loss configs; every entry is built and added separately. An entry
        without ``weight`` gets a default weight of 1.0 (written back into the
        config); an entry with weight 0 is skipped.

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """

        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1
                logging.warning(
                    f"The weight for loss {name} is not defined. "
                    f"Default weight = 1.0 is used."
                )
                loss_config["weight"] = 1.0

            weight = loss_config["weight"]

            if weight == 0:
                logging.warning(
                    f"The weight for loss {name} is zero. " f"Loss is not used."
                )
                # skip only this entry; the remaining configs listed under the
                # same loss name must still be built
                continue

            # build loss
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            loss_value = loss_layer(**inputs_dict) / self.global_batch_size
            weighted_loss = loss_value * weight if weight != 1 else loss_value

            # add loss
            self._model.add_loss(weighted_loss)

            # raise with a readable message if the unweighted loss is inf/nan;
            # the checked tensor is the one reported as a metric below
            loss_value = tf.debugging.check_numerics(
                loss_value,
                f"loss {name}_{loss_layer.name} inf/nan",
                name=f"op_loss_{name}_{loss_layer.name}",
            )

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    def build_loss(self):
        """
        Build the losses and metrics shared by all registration models.

        Subclasses extend this by calling ``super().build_loss()``; models
        without extra losses (e.g. the conditional model) use it directly,
        which is why it is not abstract. Reads ``self._inputs`` and
        ``self._outputs`` as populated by :meth:`build_model`.
        """

        # input metrics
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # image loss, conditional model does not have this
        if "pred_fixed_image" in self._outputs:
            pred_fixed_image = self._outputs["pred_fixed_image"]
            self._build_loss(
                name="image",
                inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image),
            )

        if self.labeled:
            # input metrics
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

            # label loss
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

            # additional label metrics
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self.grid_ref
            )
            dice_binary = DiceScore(binary=True)(
                y_true=fixed_label, y_pred=pred_fixed_label
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")
            self._model.add_metric(
                dice_binary, name="metric/BinaryDiceScore", aggregation="mean"
            )

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: dict of output tensors produced by self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Log the model summary and save the model structure in png.

        :param output_dir: path to the output dir.
        """
        # Model.summary() prints its table and returns None, so logging its
        # return value would only log "None"; collect the lines instead.
        summary_lines = []
        self._model.summary(
            print_fn=lambda line, *_args, **_kwargs: summary_lines.append(line)
        )
        logging.info("\n".join(summary_lines))
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logging.error(
                "Failed to plot model structure. "
                "Please check if graphviz is installed.\n"
                "Error message is: "
                f"{err}"
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        The tensor is flattened per sample; the per-sample mean, min and max
        are added as metrics aggregated with mean, min and max respectively.

        :param tensor: tensor to monitor, with a leading axis of size
            self.batch_size.
        :param name: name of the tensor.
        """
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
365
366
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.

    For other backbones, if ``config["backbone"]`` has a truthy
    ``control_points`` entry, the predicted DDF is passed through a
    control-point resize and a B-spline transform.
    """

    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Resize a field onto control points and B-spline interpolate it.

        :param field: dense field predicted by the backbone.
        :param control_points: value of ``config["backbone"]["control_points"]``,
            forwarded to ``layer.ResizeCPTransform`` and
            ``layer.BSplines3DTransform``.
        :return: the field returned by the B-spline transform, which is
            configured with the fixed image size.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Populates ``self._inputs`` and ``self._outputs`` (ddf, optional theta,
        pred_fixed_image and, if labeled, pred_fixed_label).

        :return: a ``tf.keras.Model`` mapping ``self._inputs`` to ``self._outputs``.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # build ddf
        # Read control_points without popping it: self.config must stay intact so
        # that get_config() round-trips and the same config can build another model.
        backbone_config = self.config["backbone"]
        control_points = backbone_config.get("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config={
                key: value
                for key, value in backbone_config.items()
                if key != "control_points"
            },
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # ddf: (batch, f_dim1, f_dim2, f_dim3, 3)
            # theta: affine parameters, (4, 3) per sample - TODO confirm batch axis
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (batch, f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            ddf = (
                self._resize_interpolate(ddf, control_points) if control_points else ddf
            )
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # presumably (batch, f_dim1, f_dim2, f_dim3), compared with fixed_image
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # presumably (batch, f_dim1, f_dim2, f_dim3), compared with fixed_label
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, plus the DDF regularization."""
        super().build_loss()

        # ddf loss and metrics
        ddf = self._outputs["ddf"]
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries come first in the returned dict
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
484
485
486
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    name = "DVFModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Populates ``self._inputs`` and ``self._outputs`` (dvf, ddf,
        pred_fixed_image and, if labeled, pred_fixed_label).

        :return: a ``tf.keras.Model`` mapping ``self._inputs`` to ``self._outputs``.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # Read control_points without popping it: self.config must stay intact so
        # that get_config() round-trips and the same config can build another model.
        backbone_config = self.config["backbone"]
        control_points = backbone_config.get("control_points", False)

        # build dvf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config={
                key: value
                for key, value in backbone_config.items()
                if key != "control_points"
            },
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        dvf = self._resize_interpolate(dvf, control_points) if control_points else dvf
        # compute the DDF from the DVF
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # presumably (batch, f_dim1, f_dim2, f_dim3), compared with fixed_image
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # presumably (batch, f_dim1, f_dim2, f_dim3), compared with fixed_label
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, plus DVF statistics."""
        super().build_loss()

        # dvf metrics
        dvf = self._outputs["dvf"]
        self.log_tensor_stats(tensor=dvf, name="dvf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        Extends the DDF model's output with the predicted DVF.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
562
563
564
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model that predicts the fixed image label directly.

    No DDF or DVF is produced: the backbone receives the moving image, the
    fixed image and the moving label stacked as channels, and outputs the
    predicted fixed label through a sigmoid. Only labeled data is supported.
    """

    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: a ``tf.keras.Model`` mapping ``self._inputs`` to a dict with
            the single key ``pred_fixed_label``.
        """
        # the prediction is conditioned on the moving label,
        # so this model cannot be used with unlabeled data
        assert self.labeled

        inputs = self.build_inputs()
        self._inputs = inputs

        # stacked input channels for the backbone: moving image, fixed image,
        # moving label
        stacked = self.concat_images(
            inputs["moving_image"], inputs["fixed_image"], inputs["moving_label"]
        )
        default_args = dict(
            image_size=self.fixed_image_size,
            out_channels=1,
            out_kernel_initializer="glorot_uniform",
            out_activation="sigmoid",
        )
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"], default_args=default_args
        )

        # drop the single output channel: (batch, f_dim1, f_dim2, f_dim3)
        label_pred = tf.squeeze(backbone(inputs=stacked), axis=4)

        self._outputs = {"pred_fixed_label": label_pred}
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Collect the tensors to be saved after prediction.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        to_save = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "pred_fixed_label": (outputs["pred_fixed_label"], True, True),
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
        }
        return inputs["indices"], to_save
625