Passed
Pull Request — main (#719)
by Yunguan
01:13
created

deepreg.model.network.DVFModel.build_model()   A

Complexity

Conditions 3

Size

Total Lines 39
Code Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 25
dl 0
loc 39
rs 9.28
c 0
b 0
f 0
cc 3
nop 1
1
import logging
2
import os
3
from abc import abstractmethod
4
from copy import deepcopy
5
from typing import Dict, Optional, Tuple
6
7
import tensorflow as tf
8
9
from deepreg.loss.label import DiceScore, compute_centroid_distance
10
from deepreg.model import layer, layer_util
11
from deepreg.model.backbone import GlobalNet
12
from deepreg.registry import REGISTRY
13
14
15
def dict_without(d: dict, key) -> dict:
    """
    Return a copy of the given dict without a certain key.

    The input is deep-copied first, so neither the input dict
    nor any object nested in it is modified or shared with the result.

    :param d: dict to be copied.
    :param key: key to be removed.
    :return: the copy without a key
    :raises KeyError: if key is not in d.
    """
    copied = deepcopy(d)
    # pop without a default, a missing key is reported to the caller
    copied.pop(key)
    return copied
26
27
28
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    The constructor calls build_model to create the wrapped keras model
    stored in self._model, then calls build_loss to attach losses and metrics.
    Sub-classes implement build_model and postprocess.
    """

    def __init__(
        self,
        moving_image_size: Tuple,
        fixed_image_size: Tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        # multiplied by 1.0 to get a float, it is used to divide loss values
        self.global_batch_size = num_devices * batch_size * 1.0
        assert self.global_batch_size > 0

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # a leading axis of size one is added to the reference grid
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]
        # build_model has to run before build_loss,
        # as build_loss reads self._inputs, self._outputs and self._model
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """
        Build the model to be saved as self._model.

        Implementations are expected to set self._inputs and self._outputs
        and return the keras model, as they are read by build_loss.
        """

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Label inputs are only created when self.labeled is True.

        :return: dict of inputs.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        A channel axis is appended to each tensor. The moving tensors are
        passed through a Resize3d layer configured with the fixed image size
        before all tensors are concatenated along the channel axis.

        :param moving_image: registration source
        :param fixed_image: registration target
        :param moving_label: optional, only used for conditional model.
        :return: the concatenated tensor, with 2 channels,
            or 3 channels if moving_label is given.
        """
        images = []

        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # (batch, m_dim1, m_dim2, m_dim3, 1) before resizing
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) before resizing
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict) -> None:
        """
        Build and add one weighted loss together with the metrics.

        The config self.config["loss"][name] is either one loss config
        or a list of them. Each entry with a non-zero weight is built,
        added to self._model as a loss, and tracked by two metrics,
        the un-weighted and the weighted values.
        Entries with zero weight are skipped individually.

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """

        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1
                logging.warning(
                    f"The weight for loss {name} is not defined. "
                    f"Default weight = 1.0 is used."
                )
                # the default is stored in the config itself
                loss_config["weight"] = 1.0

            # build loss
            weight = loss_config["weight"]

            if weight == 0:
                logging.warning(
                    f"The weight for loss {name} is zero. Loss is not used."
                )
                # skip this entry only,
                # the remaining entries of the list still have to be built
                continue

            # weight is not an argument of the loss, remove it from the config
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            loss_value = loss_layer(**inputs_dict) / self.global_batch_size
            weighted_loss = loss_value * weight if weight != 1 else loss_value

            # add loss
            self._model.add_loss(weighted_loss)

            loss_value = tf.debugging.check_numerics(
                loss_value,
                f"loss {name}_{loss_layer.name} inf/nan",
                name=f"op_loss_{name}_{loss_layer.name}",
            )

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    @abstractmethod
    def build_loss(self):
        """
        Build losses according to configs.

        Although decorated as abstract, this method holds the shared logic
        for image loss, label loss, and label metrics.
        It reads self._inputs and self._outputs set by build_model.
        """

        # input metrics
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # image loss, conditional model does not have this
        if "pred_fixed_image" in self._outputs:
            pred_fixed_image = self._outputs["pred_fixed_image"]
            self._build_loss(
                name="image",
                inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image),
            )

        if self.labeled:
            # input metrics
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

            # label loss
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

            # additional label metrics
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self.grid_ref
            )
            dice_binary = DiceScore(binary=True)(
                y_true=fixed_label, y_pred=pred_fixed_label
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")
            self._model.add_metric(
                dice_binary, name="metric/BinaryDiceScore", aggregation="mean"
            )

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: the outputs of self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Log the model summary and save model structure in png.

        If plotting raises ImportError, the error is logged
        and no exception is propagated.

        :param output_dir: path to the output dir.
        """
        # summary() returns None and sends the text to print_fn,
        # so the lines are collected and logged as one message.
        # Extra arguments passed to print_fn, if any, are ignored.
        summary_lines = []
        self._model.summary(
            print_fn=lambda line, *args, **kwargs: summary_lines.append(line)
        )
        logging.info("\n".join(summary_lines))
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logging.error(
                "Failed to plot model structure. "
                "Please check if graphviz is installed.\n"
                "Error message is: "
                f"{err}"
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        The mean, min, and max are calculated per sample
        and added to self._model as metrics.

        :param tensor: tensor to monitor.
        :param name: name of the tensor.
        """
        # (batch, -1), all non-batch axes are merged
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
364
365
366
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Pass the field through ResizeCPTransform and BSplines3DTransform.

        :param field: output of the backbone.
        :param control_points: argument forwarded to both layers.
        :return: the output of the BSplines3DTransform layer,
            which is built with self.fixed_image_size.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The backbone config may contain the optional key control_points.
        It is read here and not passed to the backbone.
        self.config is not modified.

        :return: keras model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # build ddf
        # pop from a shallow copy, so that self.config keeps control_points
        # and the model can be re-built from the same config
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # global net returns the ddf and the affine parameters theta,
            # control_points is not used in this case
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # the backbone output has 3 channels (out_channels=3)
            ddf = backbone(inputs=backbone_inputs)
            ddf = (
                self._resize_interpolate(ddf, control_points) if control_points else ddf
            )
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # presumably (batch, f_dim1, f_dim2, f_dim3) as it is compared
        # with fixed_image in the image loss - TODO confirm with layer.Warping
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the label is warped by the same layer with the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Build losses according to configs.

        On top of the losses of the parent class,
        the regularization loss on ddf and the ddf statistics are added.
        """
        super().build_loss()

        # ddf loss and metrics
        ddf = self._outputs["ddf"]
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
            The entry theta, if present, has None for both flags.
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries are placed before the existing entries
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
484
485
486
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    name = "DVFModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The backbone config may contain the optional key control_points.
        It is read here and not passed to the backbone.
        self.config is not modified.

        :return: keras model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # pop from a shallow copy, so that self.config keeps control_points
        # and the model can be re-built from the same config
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build dvf, then ddf from dvf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        # the backbone output has 3 channels (out_channels=3)
        dvf = backbone(inputs=backbone_inputs)
        dvf = self._resize_interpolate(dvf, control_points) if control_points else dvf
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # presumably (batch, f_dim1, f_dim2, f_dim3) as it is compared
        # with fixed_image in the image loss - TODO confirm with layer.Warping
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the label is warped by the same layer with the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Build losses according to configs.

        On top of the losses of the parent class, the dvf statistics are added.
        """
        super().build_loss()

        # dvf metrics
        dvf = self._outputs["dvf"]
        self.log_tensor_stats(tensor=dvf, name="dvf")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        The dict of the parent class is extended with the dvf.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
562
563
564
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.

    The model requires labeled data, as the moving label is a backbone input.
    """

    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: keras model mapping self._inputs to self._outputs.
        """
        assert self.labeled

        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        moving_label = self._inputs["moving_label"]

        # build the label prediction, the moving label is a third input channel
        backbone_inputs = self.concat_images(moving_image, fixed_image, moving_label)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # the backbone output has one channel (out_channels=1),
        # the channel axis is removed to get (batch, f_dim1, f_dim2, f_dim3)
        pred_fixed_label = backbone(inputs=backbone_inputs)
        pred_fixed_label = tf.squeeze(pred_fixed_label, axis=4)

        self._outputs = dict(pred_fixed_label=pred_fixed_label)
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            pred_fixed_label=(outputs["pred_fixed_label"], True, True),
            moving_label=(inputs["moving_label"], False, True),
            fixed_label=(inputs["fixed_label"], False, True),
        )

        return indices, processed
625