ConditionalModel.postprocess()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 25
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 25
rs 9.8
c 0
b 0
f 0
cc 1
nop 3
1
import os
2
from abc import abstractmethod
3
from copy import deepcopy
4
from typing import Dict, Optional, Tuple
5
6
import tensorflow as tf
7
8
from deepreg import log
9
from deepreg.loss.label import compute_centroid_distance
10
from deepreg.model import layer, layer_util
11
from deepreg.model.backbone import GlobalNet
12
from deepreg.registry import REGISTRY
13
14
# module-level logger obtained from deepreg.log, named after this module
logger = log.get(__name__)
15
16
17
def dict_without(d: dict, key) -> dict:
    """
    Deep-copy a dict and drop a single entry from the copy.

    The input dict itself is never modified.

    :param d: dict to be copied.
    :param key: key to be removed; it must be present in ``d``,
        otherwise a KeyError is raised.
    :return: a deep copy of ``d`` that no longer contains ``key``.
    """
    result = deepcopy(d)
    del result[key]
    return result
28
29
30
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Subclasses implement :meth:`build_model`, which must populate
    ``self._inputs`` / ``self._outputs`` and return the wrapped
    ``tf.keras.Model``, and :meth:`postprocess`. They may extend
    :meth:`build_loss` (calling ``super().build_loss()``) to add
    model-specific losses and metrics.
    """

    def __init__(
        self,
        moving_image_size: Tuple,
        fixed_image_size: Tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: total number of samples consumed per step, over all devices.
            When using multiple devices, TensorFlow automatically split the tensors.
            Therefore, input shapes should be defined over batch_size.
        :param config: config for method, backbone, and loss.
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.config = config
        self.batch_size = batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # reference grid of the fixed image with a leading batch axis of size 1
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]
        # build_model must run before build_loss: the losses read
        # self._inputs / self._outputs populated while building the model
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        :return: dict of inputs; contains moving_image, fixed_image and indices,
            plus moving_label and fixed_label when the data is labeled.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Moving tensors are resized to the fixed image size so that all inputs
        share the same spatial shape before concatenation on the channel axis.

        :param moving_image: registration source
        :param fixed_image: registration target
        :param moving_label: optional, only used for conditional model.
        :return: tensor of shape (batch, f_dim1, f_dim2, f_dim3, 2), or
            (batch, f_dim1, f_dim2, f_dim3, 3) when moving_label is given.
        """
        images = []

        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add weighted losses together with the metrics.

        ``self.config["loss"][name]`` may be a single loss config or a list of
        them. Every entry is built through the registry, weighted, added to
        ``self._model`` as a loss and logged as a raw and a weighted metric.
        An entry without ``weight`` gets weight 1.0 (written back into the
        config); an entry with weight 0 is skipped while the remaining entries
        are still built.

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """

        if name not in self.config["loss"]:
            # loss config is not defined
            logger.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1
                logger.warning(
                    f"The weight for loss {name} is not defined. "
                    f"Default weight = 1.0 is used."
                )
                loss_config["weight"] = 1.0

            # build loss
            weight = loss_config["weight"]

            if weight == 0:
                # skip only this entry: other configs in the list must still be built
                logger.warning(
                    f"The weight for loss {name} is zero. Loss is not used."
                )
                continue

            # do not perform reduction over batch axis for supporting multi-device
            # training, model.fit() will average over global batch size automatically
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight"),
                default_args={"reduction": tf.keras.losses.Reduction.NONE},
            )
            loss_value = loss_layer(**inputs_dict)
            weighted_loss = loss_value * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    def build_loss(self):
        """
        Build losses according to configs.

        This default covers the image loss (when the model outputs
        ``pred_fixed_image``), the label loss and TRE metric (when labeled),
        and statistics of the inputs. Subclasses may extend it by calling
        ``super().build_loss()``; it is intentionally not abstract because
        models such as the conditional one use it unchanged.
        """

        # input metrics
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # image loss, conditional model does not have this
        if "pred_fixed_image" in self._outputs:
            pred_fixed_image = self._outputs["pred_fixed_image"]
            self._build_loss(
                name="image",
                inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image),
            )

        if self.labeled:
            # input metrics
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

            # label loss
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

            # additional label metrics
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self.grid_ref
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: dict of output tensors produced by self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Save model structure in png.

        The summary is always logged at debug level; plotting failures caused
        by a missing graphviz installation are logged instead of raised.

        :param output_dir: path to the output dir.
        """
        self._model.summary(print_fn=logger.debug)
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logger.error(
                "Failed to plot model structure. "
                "Please check if graphviz is installed. "
                "Error message is: %s.",
                err,
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        Per-sample mean, min and max are computed over all non-batch axes and
        added to self._model as metrics.

        :param tensor: tensor to monitor.
        :param name: name of the tensor.
        """
        # (batch, -1): one row per sample
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
353
354
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Resample a predicted field through control points and B-spline interpolation.

        :param field: dense field predicted by the backbone.
        :param control_points: control point setting taken from the backbone
            config, passed unchanged to ``layer.ResizeCPTransform`` and
            ``layer.BSplines3DTransform``.
        :return: the field after ``layer.BSplines3DTransform`` built with
            ``self.fixed_image_size``.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Outputs contain ``ddf`` and ``pred_fixed_image``, plus ``theta`` for a
        GlobalNet backbone and ``pred_fixed_label`` when the data is labeled.

        :return: a ``tf.keras.Model`` mapping ``self._inputs`` to ``self._outputs``.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]  # (batch, m_dim1, m_dim2, m_dim3)
        fixed_image = self._inputs["fixed_image"]  # (batch, f_dim1, f_dim2, f_dim3)

        # build ddf
        # control_points is consumed here rather than by the backbone; strip it
        # from a copy so self.config (returned by get_config) keeps the setting
        backbone_config = deepcopy(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            ddf = (
                self._resize_interpolate(ddf, control_points) if control_points else ddf
            )
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # (batch, f_dim1, f_dim2, f_dim3)
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, adding DDF regularization and stats."""
        super().build_loss()

        # ddf loss and metrics
        ddf = self._outputs["ddf"]
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries are placed first, followed by the image/ddf entries
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
473
474
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    name = "DVFModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The backbone predicts a DVF, which ``layer.IntDVF`` converts into the
        DDF used to warp the moving image (and the moving label if labeled).

        :return: a ``tf.keras.Model`` mapping ``self._inputs`` to ``self._outputs``.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # control_points is consumed here rather than by the backbone; strip it
        # from a copy so self.config (returned by get_config) keeps the setting
        backbone_config = deepcopy(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        dvf = self._resize_interpolate(dvf, control_points) if control_points else dvf
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # (batch, f_dim1, f_dim2, f_dim3)
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, adding DVF statistics."""
        super().build_loss()

        # dvf metrics
        dvf = self._outputs["dvf"]
        self.log_tensor_stats(tensor=dvf, name="dvf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
550
551
552
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.

    The backbone receives the moving image, fixed image and moving label
    stacked on the channel axis and outputs the predicted fixed label through
    a sigmoid activation. Only labeled data is supported.
    """

    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: a ``tf.keras.Model`` whose only output is ``pred_fixed_label``.
        """
        assert self.labeled

        # build inputs
        self._inputs = self.build_inputs()
        model_inputs = self._inputs

        # stack moving image, fixed image and moving label as backbone channels
        backbone_inputs = self.concat_images(
            model_inputs["moving_image"],
            model_inputs["fixed_image"],
            model_inputs["moving_label"],
        )
        backbone_defaults = dict(
            image_size=self.fixed_image_size,
            out_channels=1,
            out_kernel_initializer="glorot_uniform",
            out_activation="sigmoid",
        )
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"], default_args=backbone_defaults
        )

        # drop the single output channel -> (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = tf.squeeze(backbone(inputs=backbone_inputs), axis=4)

        self._outputs = {"pred_fixed_label": pred_fixed_label}
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Collect the tensors to be saved together with the sample indices.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        # entries are listed in the order they should be saved
        processed = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "pred_fixed_label": (outputs["pred_fixed_label"], True, True),
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
        }
        return indices, processed
613