Passed
Pull Request — main (#713)
by Yunguan
01:55
created

deepreg.model.network.RegistrationModel.__init__()   A

Complexity

Conditions 1

Size

Total Lines 39
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 23
dl 0
loc 39
rs 9.328
c 0
b 0
f 0
cc 1
nop 9

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as soon as you need more or different data.

There are several approaches to avoid long parameter lists:

1
import logging
2
import os
3
from abc import abstractmethod
4
from copy import deepcopy
5
from typing import Dict, Optional, Tuple
6
7
import tensorflow as tf
8
9
from deepreg.loss.label import compute_centroid_distance
10
from deepreg.model import layer
11
from deepreg.model.backbone import GlobalNet
12
from deepreg.registry import REGISTRY
13
14
15
def dict_without(d: dict, key) -> dict:
    """
    Deep-copy a dict and drop one key from the copy.

    The given dict itself is not modified.

    :param d: dict to be copied.
    :param key: key to be removed, it has to exist in d.
    :return: the deep copy of d without the key.
    :raises KeyError: if key is not in d.
    """
    remaining = deepcopy(d)
    del remaining[key]
    return remaining
26
27
28
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Sub-classes provide build_model, build_loss and postprocess.
    __init__ stores its arguments and then calls build_model followed
    by build_loss, so both can rely on the stored attributes.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices for identify each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict

        # build_loss adds losses to self._model, so the model is built first
        self._model: tf.keras.Model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """
        Return the config dictionary for recreating this class.

        :return: dict whose keys are the arguments of __init__.
        """
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Labels are only part of the inputs if self.labeled is True.

        :return: dict of inputs.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        A new axis 4 is added to each tensor, the moving tensors are passed
        through layer.Resize3d built with the fixed image size, and all
        tensors are concatenated on axis 4.

        :param moving_image: registration source
        :param fixed_image: registration target
        :param moving_label: optional, only used for conditional model.
        :return: the concatenated tensor, with 2 entries on axis 4,
            or 3 if moving_label is given.
        """
        images = []

        # the same layer instance is used for moving image and moving label
        resize_layer = layer.Resize3d(shape=self.fixed_image_size)

        # moving image: add axis 4, then resize
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = resize_layer(moving_image)
        images.append(moving_image)

        # fixed image: only axis 4 is added, no resizing
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # moving label: handled like the moving image
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = resize_layer(moving_label)
            images.append(moving_label)

        return tf.concat(images, axis=4)

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add one weighted loss together with the metrics.

        self.config["loss"][name] is either one loss config or a list of
        loss configs. Each config with a non-zero weight is built via
        REGISTRY.build_loss and added to self._model as a loss and as
        two metrics (unweighted and weighted).

        :param name: name of loss, image / label / regularization.
        :param inputs_dict: inputs for loss function
        """
        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                f"The configuration for loss {name} is not defined. "
                f"Therefore it is not used."
            )
            return

        loss_configs = self.config["loss"][name]
        if not isinstance(loss_configs, list):
            loss_configs = [loss_configs]

        for loss_config in loss_configs:
            if "weight" not in loss_config:
                logging.warning(
                    f"The weight for loss {name} is not defined. "
                    f"Default weight = 1.0 is used."
                )
                # the default is written back into the config entry
                loss_config["weight"] = 1.0

            weight = loss_config["weight"]

            if weight == 0:
                logging.warning(
                    f"The weight for loss {name} is zero. Loss is not used."
                )
                # skip this entry only, later entries still have to be built
                continue

            # weight is not an argument of the loss, so it is removed
            loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
                config=dict_without(d=loss_config, key="weight")
            )
            loss_value = loss_layer(**inputs_dict) / self.global_batch_size
            weighted_loss = loss_value * weight

            # add loss
            self._model.add_loss(weighted_loss)

            # add metric
            self._model.add_metric(
                loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"
            )
            self._model.add_metric(
                weighted_loss,
                name=f"loss/{name}_{loss_layer.name}_weighted",
                aggregation="mean",
            )

    @abstractmethod
    def build_loss(self):
        """Build losses according to configs."""

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not.
        :param mask: mask for inputs.
        :return: the output of self._model for the given inputs.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """

    def plot_model(self, output_dir: str):
        """
        Save model structure in png.

        The model summary is logged first. If plotting raises an
        ImportError, the error is logged and not raised.

        :param output_dir: path to the output dir.
        """
        # summary() returns None, the text is only available through print_fn
        self._model.summary(print_fn=logging.info)
        try:
            tf.keras.utils.plot_model(
                self._model,
                to_file=os.path.join(output_dir, f"{self.name}.png"),
                dpi=96,
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        except ImportError as err:  # pragma: no cover
            logging.error(
                "Failed to plot model structure. "
                "Please check if graphviz is installed.\n"
                "Error message is: "
                f"{err}"
            )

    def log_tensor_stats(self, tensor: tf.Tensor, name: str):
        """
        Log statistics of a given tensor.

        Mean, min and max are calculated per sample and added to
        self._model as metrics.

        :param tensor: tensor to monitor.
        :param name: name of the tensor.
        """
        # (batch, -1), one row per sample
        flatten = tf.reshape(tensor, shape=(self.batch_size, -1))
        self._model.add_metric(
            tf.reduce_mean(flatten, axis=1),
            name=f"metric/{name}_mean",
            aggregation="mean",
        )
        self._model.add_metric(
            tf.reduce_min(flatten, axis=1),
            name=f"metric/{name}_min",
            aggregation="min",
        )
        self._model.add_metric(
            tf.reduce_max(flatten, axis=1),
            name=f"metric/{name}_max",
            aggregation="max",
        )
314
315
316
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    name = "DDFModel"

    def _resize_interpolate(self, field, control_points):
        """
        Pass the field through the two control point layers.

        layer.ResizeCPTransform is applied first, followed by
        layer.BSplines3DTransform built with the fixed image size.

        :param field: output of the backbone.
        :param control_points: value of "control_points" in backbone config.
        :return: the output of the second layer.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: keras model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # control_points is not passed to the backbone.
        # It is removed from a shallow copy, so that self.config,
        # which is returned by get_config(), still contains it.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # global net returns the ddf and the affine parameters theta,
            # control_points is not used in this case
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            ddf = backbone(inputs=backbone_inputs)
            ddf = (
                self._resize_interpolate(ddf, control_points) if control_points else ddf
            )
            self._outputs = dict(ddf=ddf)

        # build outputs
        # the layer is kept as attribute, build_loss reads its grid_ref
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the moving label is warped with the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """
        Build losses according to configs.

        Besides the losses, statistics of inputs and ddf are logged,
        and the centroid distance is added as metric for labeled data.
        """
        fixed_image = self._inputs["fixed_image"]
        moving_image = self._inputs["moving_image"]
        ddf = self._outputs["ddf"]
        pred_fixed_image = self._outputs["pred_fixed_image"]

        # inputs
        self.log_tensor_stats(tensor=moving_image, name="moving_image")
        self.log_tensor_stats(tensor=fixed_image, name="fixed_image")

        # ddf
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))
        self.log_tensor_stats(tensor=ddf, name="ddf")

        # image
        self._build_loss(
            name="image", inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image)
        )

        # label
        if self.labeled:
            fixed_label = self._inputs["fixed_label"]
            moving_label = self._inputs["moving_label"]
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )
            tre = compute_centroid_distance(
                y_true=fixed_label, y_pred=pred_fixed_label, grid=self._warping.grid_ref
            )
            self._model.add_metric(tre, name="metric/TRE", aggregation="mean")
            self.log_tensor_stats(tensor=moving_label, name="moving_label")
            self.log_tensor_stats(tensor=fixed_label, name="fixed_label")

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
            The entry for theta is (tensor, None, None).
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)  # type: ignore

        if not self.labeled:
            return indices, processed

        # label entries are placed in front of the image entries
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
461
462
463
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    name = "DVFModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        :return: keras model mapping self._inputs to self._outputs.
        """
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # control_points is not passed to the backbone.
        # It is removed from a shallow copy, so that self.config,
        # which is returned by get_config(), still contains it.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        dvf = self._resize_interpolate(dvf, control_points) if control_points else dvf
        # the ddf is derived from the dvf by layer.IntDVF
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        # the layer is kept as attribute, the inherited build_loss reads it
        self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        pred_fixed_image = self._warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the moving label is warped with the same ddf
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = self._warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        The dict of the parent class is extended by the dvf.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
531
532
533
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    A registration model which outputs the fixed label directly.

    Neither DDF nor DVF is predicted.
    """

    name = "ConditionalModel"

    def build_model(self):
        """
        Build the model to be saved as self._model.

        The model requires labeled data, as the moving label
        is part of the backbone inputs.

        :return: keras model mapping self._inputs to self._outputs.
        """
        assert self.labeled

        # inputs of the backbone: both images and the moving label
        self._inputs = self.build_inputs()
        stacked = self.concat_images(
            self._inputs["moving_image"],
            self._inputs["fixed_image"],
            self._inputs["moving_label"],
        )

        # values used if the backbone config does not define them
        backbone_defaults = {
            "image_size": self.fixed_image_size,
            "out_channels": 1,
            "out_kernel_initializer": "glorot_uniform",
            "out_activation": "sigmoid",
        }
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"], default_args=backbone_defaults
        )

        # axis 4 of the backbone output is removed
        prediction = tf.squeeze(backbone(inputs=stacked), axis=4)

        self._outputs = {"pred_fixed_label": prediction}
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build the label loss according to configs."""
        label_pair = {
            "y_true": self._inputs["fixed_label"],
            "y_pred": self._outputs["pred_fixed_label"],
        }
        self._build_loss(name="label", inputs_dict=label_pair)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> Tuple[tf.Tensor, Dict]:
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        processed = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "pred_fixed_label": (outputs["pred_fixed_label"], True, True),
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
        }
        return inputs["indices"], processed
604