Passed
Pull Request — main (#719)
by Yunguan
01:41
created

deepreg.model.network.DVFModel.build_loss()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import logging
2
import os
3
from abc import abstractmethod
4
from copy import deepcopy
5
from typing import Dict, Optional, Tuple
6
7
import tensorflow as tf
8
9
from deepreg.loss.image import LocalNormalizedCrossCorrelation
10
from deepreg.loss.label import DiceScore, compute_centroid_distance
11
from deepreg.model import layer, layer_util
12
from deepreg.model.backbone import GlobalNet
13
from deepreg.registry import REGISTRY
14
15
16
def dict_without(d: dict, key) -> dict:
    """
    Build a deep copy of a dict with one entry removed.

    The input dict itself is never modified.

    :param d: dict to copy.
    :param key: key whose entry is dropped from the copy; it must be present,
        otherwise a KeyError is raised.
    :return: deep copy of ``d`` without ``key``.
    """
    result = deepcopy(d)
    del result[key]
    return result
27
28
29
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Subclasses implement ``build_model``, which must set ``self._inputs`` and
    ``self._outputs`` (dicts of symbolic tensors) and return the wrapped
    ``tf.keras.Model``, and ``postprocess``. Losses and monitoring metrics are
    attached to the wrapped model by ``build_loss``.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        # stored as a float; every loss is divided by it in _build_loss
        self.global_batch_size = num_devices * batch_size * 1.0
        assert self.global_batch_size > 0

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # reference grid of the fixed image with a leading axis added,
        # used for the TRE metric in build_loss
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]
        # build_model must run first: build_loss reads self._inputs/self._outputs
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Images and labels have no channel axis; it is added in concat_images.

        :return: dict of inputs, with moving/fixed labels only if labeled.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        A channel axis is appended to each tensor, and moving tensors are
        resized to the fixed image size before concatenating along channels.

        :param moving_image: registration source, (batch, m_dim1, m_dim2, m_dim3)
        :param fixed_image: registration target, (batch, f_dim1, f_dim2, f_dim3)
        :param moving_label: optional, only used for conditional model.
        :return: (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        """
        images = []

        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add weighted losses together with their metrics.

        ``self.config["loss"][name]`` may be a single loss config or a list of
        them. Each one is built via the registry, divided by the global batch
        size, multiplied by its weight and added to ``self._model``. A config
        with zero weight is skipped without affecting the other configs.

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """
        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1, written back into the config
                logging.warning(
                    f"The weight for loss {name} is not defined. "
                    f"Default weight = 1.0 is used."
                )
                loss_config["weight"] = 1.0

            weight = loss_config["weight"]
            if weight == 0:
                # skip only this loss; other configs under the same name still apply
                logging.warning(f"The weight for loss {name} is zero. Loss is not used.")
                continue

            # build loss
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            loss_value = loss_layer(**inputs_dict) / self.global_batch_size
            weighted_loss = loss_value * weight if weight != 1 else loss_value

            # add loss
            self._model.add_loss(weighted_loss)

            # the numerics check guards the logged (unweighted) loss metric
            loss_value = tf.debugging.check_numerics(
                loss_value,
                f"loss {name}_{loss_layer.name} inf/nan",
                name=f"op_loss_{name}_{loss_layer.name}",
            )

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    def build_loss(self):
        """
        Build losses and metrics shared by all registration models.

        Logs statistics of the input images. Adds the image loss when
        ``pred_fixed_image`` is an output. For labeled data, also adds the label
        loss plus TRE and binary Dice metrics. Subclasses may extend this via
        ``super().build_loss()``.
        """
        # input metrics
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # image loss, conditional model does not have this
        if "pred_fixed_image" in self._outputs:
            pred_fixed_image = self._outputs["pred_fixed_image"]
            # debug-only statistics of the LNCC numerator / denominator;
            # relies on the private _call of the loss class
            num, denom = LocalNormalizedCrossCorrelation()._call(
                y_true=fixed_image, y_pred=pred_fixed_image
            )
            self.log_tensor_stats(num, name="debug-lncc-num")
            self.log_tensor_stats(denom, name="debug-lncc-denom")
            self._build_loss(
                name="image",
                inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image),
            )

        if self.labeled:
            # input metrics
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

            # label loss
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

            # additional label metrics
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self.grid_ref
            )
            dice_binary = DiceScore(binary=True)(
                y_true=fixed_label, y_pred=pred_fixed_label
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")
            self._model.add_metric(
                dice_binary, name="metric/BinaryDiceScore", aggregation="mean"
            )

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: dict of output tensors of self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Log the model summary and save the model structure in png.

        :param output_dir: path to the output dir; the image is written to
            ``<output_dir>/<model name>.png``.
        """
        # summary() prints line by line and returns None, so capture the lines
        summary_lines = []
        self._model.summary(print_fn=lambda line, **_: summary_lines.append(line))
        logging.info("\n".join(summary_lines))
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logging.error(
                "Failed to plot model structure. "
                "Please check if graphviz is installed.\n"
                "Error message is: "
                f"{err}"
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        The tensor is reshaped to (self.batch_size, -1). The per-sample mean,
        min and max are added as metrics on self._model, aggregated with
        mean, min and max respectively.

        :param tensor: tensor to monitor.
        :param name: name of the tensor.
        """
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
370
371
372
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.

    Otherwise, if ``control_points`` is set in the backbone config, the
    backbone output is passed through ``_resize_interpolate``.
    """

    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Resize a field to control points, then B-spline interpolate it back.

        :param field: field predicted by the backbone.
        :param control_points: control-point setting from the backbone config,
            passed as-is to ``layer.ResizeCPTransform`` and
            ``layer.BSplines3DTransform``.
        :return: interpolated field, presumably at the fixed image size.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Sets self._inputs and self._outputs. The outputs contain ``ddf`` and
        ``pred_fixed_image``, plus ``theta`` for a GlobalNet backbone and
        ``pred_fixed_label`` for labeled data.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # work on a copy so self.config (also returned by get_config) keeps
        # control_points and the same config can rebuild an identical model
        backbone_config = deepcopy(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # ddf (batch, f_dim1, f_dim2, f_dim3, 3) and affine parameters theta;
            # control_points is not used in this branch
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (batch, f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            if control_points:
                ddf = self._resize_interpolate(ddf, control_points)
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # warped moving image, presumably (batch, f_dim1, f_dim2, f_dim3)
        self._outputs["pred_fixed_image"] = warping(inputs=[ddf, moving_image])

        if self.labeled:
            # the same warping layer is reused for the moving label
            moving_label = self._inputs["moving_label"]
            self._outputs["pred_fixed_label"] = warping(inputs=[ddf, moving_label])

        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Build losses according to configs.

        Adds the shared image/label losses and metrics, then the
        "regularization" loss on the ddf and statistics of the ddf.
        """
        super().build_loss()

        # ddf loss and metrics
        ddf = self._outputs["ddf"]
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries are placed first in the resulting dict
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
490
491
492
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF by ``layer.IntDVF``. The regularization
    loss inherited from ``DDFModel.build_loss`` is applied to that DDF.
    """

    name = "DVFModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Sets self._inputs and self._outputs. The outputs contain ``dvf``,
        ``ddf`` and ``pred_fixed_image``, plus ``pred_fixed_label`` for labeled
        data.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # work on a copy so self.config (also returned by get_config) keeps
        # control_points and the same config can rebuild an identical model
        backbone_config = deepcopy(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build dvf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        # (batch, f_dim1, f_dim2, f_dim3, 3)
        dvf = backbone(inputs=backbone_inputs)
        if control_points:
            dvf = self._resize_interpolate(dvf, control_points)
        # convert the dvf into a ddf (presumably by integration, per IntDVF)
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs; the warping layer is kept on self._warping
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # warped moving image, presumably (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if self.labeled:
            moving_label = self._inputs["moving_label"]
            self._outputs["pred_fixed_label"] = self._warping(
                inputs=[ddf, moving_label]
            )

        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Build losses according to configs.

        Adds everything from ``DDFModel.build_loss``: the shared image/label
        losses and metrics, plus regularization and statistics of the ddf.
        Also logs statistics of the dvf.
        """
        super().build_loss()

        # dvf metrics
        self.log_tensor_stats(tensor=self._outputs["dvf"], name="dvf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        Same as ``DDFModel.postprocess`` with the dvf added.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
568
569
570
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model that predicts the fixed image label directly.

    No DDF or DVF is produced. The backbone receives the moving image, fixed
    image and moving label stacked along the channel axis, and predicts a
    single-channel fixed label (sigmoid activation by default). Only labeled
    data is supported.
    """

    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Sets self._inputs and self._outputs; the only output is
        ``pred_fixed_label``.
        """
        assert self.labeled

        # build inputs
        self._inputs = self.build_inputs()
        stacked = self.concat_images(
            self._inputs["moving_image"],
            self._inputs["fixed_image"],
            self._inputs["moving_label"],
        )

        # build the label predictor
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # drop the single output channel -> (batch, f_dim1, f_dim2, f_dim3)
        prediction = tf.squeeze(backbone(inputs=stacked), axis=4)

        self._outputs = {"pred_fixed_label": prediction}
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Collect the tensors to be saved, together with the sample indices.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        sample_indices = inputs["indices"]
        to_save = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "pred_fixed_label": (outputs["pred_fixed_label"], True, True),
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
        }
        return sample_indices, to_save
631