Completed
Push — main ( 3b5a47...3ab0c2 )
by
unknown
30s queued 13s
created

deepreg.model.network.DVFModel.build_loss()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import logging
2
import os
3
from abc import abstractmethod
4
from copy import deepcopy
5
from typing import Dict, Optional, Tuple
6
7
import tensorflow as tf
8
9
from deepreg.loss.label import DiceScore, compute_centroid_distance
10
from deepreg.model import layer, layer_util
11
from deepreg.model.backbone import GlobalNet
12
from deepreg.registry import REGISTRY
13
14
15
def dict_without(d: dict, key) -> dict:
    """
    Build a deep copy of ``d`` that omits ``key``.

    The input dict itself is never modified.

    :param d: dict to be copied.
    :param key: key to drop from the copy; it must be present in ``d``.
    :return: a deep copy of ``d`` without ``key``.
    :raises KeyError: if ``key`` is not in ``d``.
    """
    result = deepcopy(d)
    del result[key]
    return result
26
27
28
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Subclasses implement ``build_model`` and ``postprocess``. ``build_model``
    is expected to set ``self._inputs`` / ``self._outputs`` (dicts of Keras
    tensors) and return the inner ``tf.keras.Model``; ``build_loss`` then reads
    those dicts to attach losses and metrics. Subclasses may extend
    ``build_loss`` by calling ``super().build_loss()``.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # reference grid over the fixed image, with a leading axis added
        # by [None, ...]
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]
        # build_model must run first: build_loss reads self._inputs/_outputs
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Images and labels have no channel axis; ``concat_images`` adds it.

        :return: dict of inputs, with label inputs only if ``self.labeled``.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        Each input gets a trailing channel axis; moving inputs are resized to
        the fixed image size before concatenation along that channel axis.

        :param moving_image: registration source
        :param fixed_image: registration target
        :param moving_label: optional, only used for conditional model.
        :return: concatenated tensor with 2 channels, or 3 if moving_label
            is given.
        """
        images = []

        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> resized to the fixed size
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> resized to the fixed size
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add weighted loss(es) together with the metrics.

        ``self.config["loss"][name]`` may be a single loss config or a list of
        them; each entry is handled independently. An entry without a weight
        gets weight 1.0 (written back into the config), and an entry with
        weight 0 is skipped without affecting the other entries.

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """

        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1
                logging.warning(
                    f"The weight for loss {name} is not defined. "
                    f"Default weight = 1.0 is used."
                )
                loss_config["weight"] = 1.0

            # build loss
            weight = loss_config["weight"]

            if weight == 0:
                # skip only this entry; later entries in the list still apply
                logging.warning(
                    f"The weight for loss {name} is zero. Loss is not used."
                )
                continue

            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            # normalised by the global batch size (batch_size * num_devices);
            # presumably so per-replica losses sum to a mean under a
            # distribution strategy — confirm
            loss_value = loss_layer(**inputs_dict) / self.global_batch_size
            weighted_loss = loss_value * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    def build_loss(self):
        """
        Build losses according to configs.

        Shared implementation: subclasses extend it via ``super().build_loss()``
        and ``ConditionalModel`` uses it as-is, so it is not abstract.
        """

        # input metrics
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # image loss, conditional model does not have this
        if "pred_fixed_image" in self._outputs:
            pred_fixed_image = self._outputs["pred_fixed_image"]
            self._build_loss(
                name="image",
                inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image),
            )

        if self.labeled:
            # input metrics
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

            # label loss
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

            # additional label metrics
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self.grid_ref
            )
            dice_binary = DiceScore(binary=True)(
                y_true=fixed_label, y_pred=pred_fixed_label
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")
            self._model.add_metric(
                dice_binary, name="metric/BinaryDiceScore", aggregation="mean"
            )

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: dict of output tensors produced by ``self._model``.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Log the model summary and save the model structure in png.

        :param output_dir: path to the output dir.
        """
        # Keras' summary() emits lines through print_fn and returns None, so
        # logging its return value would only log "None". Collect the lines
        # instead; extra arguments are accepted in case the Keras version
        # passes any to print_fn.
        summary_lines = []
        self._model.summary(
            print_fn=lambda line, *args, **kwargs: summary_lines.append(line)
        )
        logging.info("\n".join(summary_lines))
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logging.error(
                "Failed to plot model structure. "
                "Please check if graphviz is installed.\n"
                "Error message is: "
                f"{err}"
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        Per-sample mean/min/max are added as metrics aggregated with
        mean/min/max respectively.

        :param tensor: tensor to monitor.
        :param name: name of the tensor.
        """
        # one row per sample so the reductions below are per sample
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
357
358
359
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Apply the control-point resize then the B-spline transform to a field.

        :param field: backbone output to be transformed.
        :param control_points: control point setting taken from the backbone
            config.
        :return: the transformed field.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        Sets ``self._inputs`` and ``self._outputs`` (``ddf``, optional
        ``theta``, ``pred_fixed_image`` and, if labeled, ``pred_fixed_label``).
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # build ddf
        # Work on a copy: popping from self.config["backbone"] directly would
        # remove control_points from the stored config, so get_config() could
        # no longer recreate this model faithfully.
        backbone_config = deepcopy(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            if control_points:
                # the affine branch never applies control points
                logging.warning(
                    "control_points is ignored when the backbone is GlobalNet."
                )
            # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            if control_points:
                ddf = self._resize_interpolate(ddf, control_points)
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # moving image warped by the ddf; presumably
        # (batch, f_dim1, f_dim2, f_dim3) — confirm against layer.Warping
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # moving label warped by the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs, plus DDF regularization/metrics."""
        super().build_loss()

        # ddf loss and metrics
        ddf = self._outputs["ddf"]
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries first, then the image entries above
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
477
478
479
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    name = "DVFModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The backbone predicts a DVF (optionally passed through the control
        point transform), which ``layer.IntDVF`` turns into the DDF used for
        warping. Sets ``self._inputs`` and ``self._outputs`` (``dvf``, ``ddf``,
        ``pred_fixed_image`` and, if labeled, ``pred_fixed_label``).
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # Work on a copy: popping from self.config["backbone"] directly would
        # remove control_points from the stored config, so get_config() could
        # no longer recreate this model faithfully.
        backbone_config = deepcopy(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build dvf, then integrate it into a ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        if control_points:
            dvf = self._resize_interpolate(dvf, control_points)
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # moving image warped by the ddf; presumably
        # (batch, f_dim1, f_dim2, f_dim3) — confirm against layer.Warping
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # moving label warped by the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs (DDF ones included) and DVF metrics."""
        super().build_loss()

        # dvf metrics
        dvf = self._outputs["dvf"]
        self.log_tensor_stats(tensor=dvf, name="dvf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        Same as ``DDFModel.postprocess`` with the ``dvf`` entry added.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
555
556
557
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.

    The backbone receives the moving image, fixed image and moving label as
    input channels and directly outputs the predicted fixed label, so this
    model requires labeled data. No ``pred_fixed_image`` is produced, hence
    the shared ``build_loss`` skips the image loss.
    """

    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :raises AssertionError: if the model is not labeled.
        """
        # Explicit check instead of a bare ``assert`` so that it still runs
        # under ``python -O``; AssertionError is kept for backward compatibility.
        if not self.labeled:
            raise AssertionError(
                "ConditionalModel requires labeled data, but labeled=False."
            )

        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        moving_label = self._inputs["moving_label"]

        # build the backbone that predicts the fixed label directly
        backbone_inputs = self.concat_images(moving_image, fixed_image, moving_label)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # single output channel is squeezed away: (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = backbone(inputs=backbone_inputs)
        pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)

        self._outputs = dict(pred_fixed_label=pred_fixed_label)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            pred_fixed_label=(outputs["pred_fixed_label"], True, True),
            moving_label=(inputs["moving_label"], False, True),
            fixed_label=(inputs["fixed_label"], False, True),
        )

        return indices, processed
618