Passed
Pull Request — main (#713)
by Yunguan
01:55
created

RegistrationModel.log_tensor_stats()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 22
rs 9.7
c 0
b 0
f 0
cc 1
nop 3
1
import logging
2
import os
3
from abc import abstractmethod
4
from copy import deepcopy
5
from typing import Dict, Optional, Tuple
6
7
import tensorflow as tf
8
9
from deepreg.loss.label import compute_centroid_distance
10
from deepreg.model import layer
11
from deepreg.model.backbone import GlobalNet
12
from deepreg.registry import REGISTRY
13
14
15
def dict_without(d: dict, key) -> dict:
    """
    Build a deep copy of a dict that omits one key.

    The given dict is not modified.

    :param d: dict to be copied.
    :param key: key to be left out, it must exist in d.
    :return: deep copy of d without the key.
    :raises KeyError: if key is not in d.
    """
    remaining = deepcopy(d)
    del remaining[key]
    return remaining
26
27
28
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Sub-classes implement build_model, build_loss and postprocess.
    The constructor calls build_model and then build_loss,
    the resulting keras model is stored in self._model.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # build_model is expected to fill self._inputs and self._outputs,
        # which build_loss reads afterwards, so the order matters.
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        The label inputs are only created if the model is labeled.

        :return: dict of inputs.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        A channel axis is appended to each tensor, the moving tensors are
        passed through a Resize3d layer built with the fixed image size,
        and all tensors are concatenated on the channel axis.

        :param moving_image: registration source
        :param fixed_image: registration target
        :param moving_label: optional, only used for conditional model.
        :return: concatenated tensor, with 2 channels, or 3 if
            moving_label is given.
        """
        images = []

        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # (batch, m_dim1, m_dim2, m_dim3, 1) before resizing
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) before resizing
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add one weighted loss together with the metrics.

        The config under self.config["loss"][name] can be a single loss
        config or a list of them. Each config with a non-zero weight
        is built, added as a loss and logged as two metrics
        (un-weighted and weighted).

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """

        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:

            if "weight" not in loss_config:
                # default loss weight 1
                logging.warning(
                    f"The weight for loss {name} is not defined."
                    f"Default weight = 1.0 is used."
                )
                # NOTE: this writes the default back into self.config
                loss_config["weight"] = 1.0

            # build loss
            weight = loss_config["weight"]

            if weight == 0:
                logging.warning(
                    f"The weight for loss {name} is zero." f"Loss is not used."
                )
                # skip this loss only,
                # the remaining configs of the list still have to be built
                continue

            # weight is not an argument of the loss class, so it is removed
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            loss_value = loss_layer(**inputs_dict) / self.global_batch_size
            weighted_loss = loss_value * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    @abstractmethod
    def build_loss(self):
        """Build losses according to configs."""

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: outputs of self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Log the model summary and save model structure in png.

        If the plot cannot be produced because of an ImportError,
        the error is logged and not raised.

        :param output_dir: path to the output dir.
        """
        # Model.summary() prints the table and returns None,
        # so the lines are routed to the logger via print_fn
        # instead of logging the return value.
        # Extra keyword arguments passed by some keras versions are ignored.
        self._model.summary(print_fn=lambda line, **_: logging.info(line))
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logging.error(
                "Failed to plot model structure."
                "Please check if graphviz is installed.\n"
                "Error message is:"
                f"{err}"
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        The per-sample mean, min and max are added to self._model
        as metrics named metric/{name}_mean, _min and _max.

        :param tensor: tensor to monitor.
        :param name: name of the tensor.
        """
        # (batch, -1), one row per sample
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
314
315
316
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Apply ResizeCPTransform then BSplines3DTransform to a field.

        :param field: tensor predicted by the backbone.
        :param control_points: value of control_points in backbone config,
            passed to both layers.
        :return: the transformed field.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: keras model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # build ddf
        # control_points is not passed to the backbone,
        # it is removed from a copy so that self.config stays intact
        # and the model can be rebuilt from get_config()
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            ddf = (
                self._resize_interpolate(ddf, control_points) if control_points else ddf
            )
            self._outputs = dict(ddf=ddf)

        # build outputs
        # the warping layer is kept as build_loss reads its grid_ref
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # warp the moving label with the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Build losses according to configs.

        Regularization loss is on ddf, image loss is on the warped image,
        label loss and the TRE metric are only added if the model is labeled.
        Statistics of inputs and ddf are logged as metrics.
        """
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        ddf = self._outputs["ddf"]
        pred_fixed_image = self._outputs["pred_fixed_image"]

        # inputs
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # ddf
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

        # image
        self._build_loss(
            name="image", inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image)
        )

        # label
        if self.labeled:
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self._warping.grid_ref
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries are placed before the image entries
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
461
462
463
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    name = "DVFModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: keras model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # control_points is not passed to the backbone,
        # it is removed from a copy so that self.config stays intact
        # and the model can be rebuilt from get_config()
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        dvf = self._resize_interpolate(dvf, control_points) if control_points else dvf
        # ddf is obtained from dvf by the IntDVF layer
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        # the warping layer is kept as the inherited build_loss reads its grid_ref
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # warp the moving label with the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        The dict of DDFModel is extended with the dvf.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
531
532
533
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model predicts fixed image label without DDF or DVF.

    The backbone takes moving image, fixed image and moving label,
    and outputs the label in the fixed image space directly.
    """

    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The model requires labeled data.

        :return: keras model mapping self._inputs to self._outputs.
        """
        assert self.labeled

        # build inputs
        self._inputs = self.build_inputs()

        # moving image, fixed image and moving label form the backbone input
        stacked = self.concat_images(
            self._inputs["moving_image"],
            self._inputs["fixed_image"],
            self._inputs["moving_label"],
        )

        # one output channel with sigmoid activation
        head_defaults = {
            "image_size": self.fixed_image_size,
            "out_channels": 1,
            "out_kernel_initializer": "glorot_uniform",
            "out_activation": "sigmoid",
        }
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"], default_args=head_defaults
        )

        # remove the channel axis, (batch, f_dim1, f_dim2, f_dim3)
        prediction = tf.squeeze(backbone(inputs=stacked), axis=4)

        self._outputs = {"pred_fixed_label": prediction}
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build the label loss according to configs, the only loss used."""
        self._build_loss(
            name="label",
            inputs_dict={
                "y_true": self._inputs["fixed_label"],
                "y_pred": self._outputs["pred_fixed_label"],
            },
        )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Collect the tensors to be saved, with their saving options.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        to_save = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "pred_fixed_label": (outputs["pred_fixed_label"], True, True),
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
        }
        return inputs["indices"], to_save
604