Completed
Push — main ( 3b5a47...3ab0c2 )
by
unknown
30s queued 13s
created

deepreg.model.network.RegistrationModel.__init__()   A

Complexity

Conditions 1

Size

Total Lines 42
Code Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 25
dl 0
loc 42
rs 9.28
c 0
b 0
f 0
cc 1
nop 9

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

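There are several approaches to avoid long parameter lists, for example:
- replace a parameter with a method call when the method can compute or fetch the value itself;
- introduce a parameter object that groups related parameters;
- pass the whole object instead of several values extracted from it.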

1
import logging
2
import os
3
from abc import abstractmethod
4
from copy import deepcopy
5
from typing import Dict, Optional, Tuple
6
7
import tensorflow as tf
8
9
from deepreg.loss.label import DiceScore, compute_centroid_distance
10
from deepreg.model import layer, layer_util
11
from deepreg.model.backbone import GlobalNet
12
from deepreg.registry import REGISTRY
13
14
15
def dict_without(d: dict, key) -> dict:
    """
    Deep-copy a dict and drop one entry from the copy.

    The input dict is left untouched and nested values are copied as well,
    so the result can be modified freely.

    :param d: source dict; it is not modified.
    :param key: entry to drop; a KeyError is raised if it is absent.
    :return: a deep copy of ``d`` without ``key``.
    """
    result = deepcopy(d)
    del result[key]
    return result
26
27
28
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Subclasses implement ``build_model`` and ``postprocess``. ``build_loss``
    reads ``self._inputs`` and ``self._outputs`` and is called right after
    ``build_model`` in ``__init__``, so ``build_model`` is expected to set both.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        # losses are divided by this so that they are averaged over all replicas
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # reference grid of the fixed image with a leading axis of size 1,
        # used by the centroid-distance (TRE) metric in build_loss
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Labels are only created when ``self.labeled`` is True.

        :return: dict of inputs.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Moving tensors are resized to ``self.fixed_image_size`` so that every
        input shares the fixed grid before channel-wise concatenation.

        :param moving_image: registration source, (batch, m_dim1, m_dim2, m_dim3)
        :param fixed_image: registration target, (batch, f_dim1, f_dim2, f_dim3)
        :param moving_label: optional, only used for conditional model,
            (batch, m_dim1, m_dim2, m_dim3)
        :return: tensor of shape (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        """
        images = []

        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add weighted losses together with their metrics.

        ``self.config["loss"][name]`` may be a single loss config or a list of
        them. Each config with a non-zero weight is built through the registry;
        its value is divided by the global batch size, multiplied by the weight
        and added to the model. Both the unweighted and the weighted value are
        also logged as metrics.

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """
        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:
            if "weight" not in loss_config:
                # default loss weight 1; note this writes back into self.config
                logging.warning(
                    f"The weight for loss {name} is not defined. "
                    f"Default weight = 1.0 is used."
                )
                loss_config["weight"] = 1.0

            weight = loss_config["weight"]

            if weight == 0:
                logging.warning(
                    f"The weight for loss {name} is zero. Loss is not used."
                )
                # skip only this loss; the remaining configs listed under the
                # same name must still be built
                continue

            # build loss from the config without the non-registry "weight" key
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            loss_value = loss_layer(**inputs_dict) / self.global_batch_size
            weighted_loss = loss_value * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    def build_loss(self):
        """
        Build losses according to configs.

        This is a concrete default used as-is by models that do not override
        it; subclasses extend it by calling ``super().build_loss()``.
        """
        # input metrics
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # image loss, conditional model does not have this
        if "pred_fixed_image" in self._outputs:
            pred_fixed_image = self._outputs["pred_fixed_image"]
            self._build_loss(
                name="image",
                inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image),
            )

        if self.labeled:
            # input metrics
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

            # label loss
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

            # additional label metrics
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self.grid_ref
            )
            dice_binary = DiceScore(binary=True)(
                y_true=fixed_label, y_pred=pred_fixed_label
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")
            self._model.add_metric(
                dice_binary, name="metric/BinaryDiceScore", aggregation="mean"
            )

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: dict of output tensors of self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Log the model summary and save the model structure in png.

        :param output_dir: path to the output dir.
        """
        # Model.summary() prints and returns None, so collect its lines
        # through print_fn instead of logging its return value.
        summary_lines = []
        self._model.summary(
            print_fn=lambda line, *_args, **_kwargs: summary_lines.append(line)
        )
        logging.info("\n".join(summary_lines))
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logging.error(
                "Failed to plot model structure. "
                "Please check if graphviz is installed.\n"
                f"Error message is: {err}"
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        The tensor is flattened per sample to (batch_size, -1), and the
        per-sample mean, min and max are added as metrics aggregated with
        mean, min and max respectively.

        :param tensor: tensor to monitor.
        :param name: name of the tensor.
        """
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
357
358
359
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Pass a dense field through the control-point layers.

        The field goes through ``layer.ResizeCPTransform(control_points)`` and
        then ``layer.BSplines3DTransform(control_points, fixed_image_size)``
        (presumably a B-spline parameterisation of the field — confirm against
        the layer implementations).

        :param field: field predicted by the backbone.
        :param control_points: value of the ``control_points`` entry of the
            backbone config.
        :return: the transformed field.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: a tf.keras.Model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # build ddf
        # Remove "control_points" from a shallow copy so that self.config,
        # which get_config returns for recreating the model, keeps it.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # global net returns both the ddf and the affine parameters theta
            # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            ddf = (
                self._resize_interpolate(ddf, control_points) if control_points else ddf
            )
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # presumably (batch, f_dim1, f_dim2, f_dim3) — confirm against layer.Warping
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # labels are warped with the same ddf as the image
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, plus the ddf regularization."""
        super().build_loss()

        # ddf loss and metrics
        ddf = self._outputs["ddf"]
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries are placed before the image entries
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
477
478
479
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    name = "DVFModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: a tf.keras.Model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # Remove "control_points" from a shallow copy so that self.config,
        # which get_config returns for recreating the model, keeps it.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build dvf, then integrate it into a ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        dvf = self._resize_interpolate(dvf, control_points) if control_points else dvf
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # presumably (batch, f_dim1, f_dim2, f_dim3) — confirm against layer.Warping
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # labels are warped with the same ddf as the image
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, plus dvf statistics."""
        super().build_loss()

        # dvf metrics
        dvf = self._outputs["dvf"]
        self.log_tensor_stats(tensor=dvf, name="dvf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        Same as the DDF model, with the dvf added.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
555
556
557
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model that predicts the fixed image label directly.

    No DDF or DVF is produced; the backbone maps the concatenated moving
    image, fixed image and moving label to a fixed-label prediction.
    """

    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Labeled data is required because the moving label is a backbone input.

        :return: a tf.keras.Model mapping self._inputs to self._outputs.
        """
        assert self.labeled

        self._inputs = self.build_inputs()
        tensors = self._inputs

        # channels: moving image, fixed image, moving label
        stacked = self.concat_images(
            tensors["moving_image"], tensors["fixed_image"], tensors["moving_label"]
        )

        # single output channel, sigmoid-activated by default
        defaults = dict(
            image_size=self.fixed_image_size,
            out_channels=1,
            out_kernel_initializer="glorot_uniform",
            out_activation="sigmoid",
        )
        net = REGISTRY.build_backbone(
            config=self.config["backbone"], default_args=defaults
        )

        # drop the trailing channel axis -> (batch, f_dim1, f_dim2, f_dim3)
        prediction = tf.squeeze(net(inputs=stacked), axis=4)

        self._outputs = {"pred_fixed_label": prediction}
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Collect the tensors to be saved for this model.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        to_save = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "pred_fixed_label": (outputs["pred_fixed_label"], True, True),
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
        }
        return inputs["indices"], to_save
618