Completed
Push — main ( 72b597...0bdbdf )
by Yunguan
23s queued 12s
created

DDFModel._resize_interpolate()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
import logging
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
from abc import abstractmethod
3
from copy import deepcopy
4
from typing import Dict, Optional
5
6
import tensorflow as tf
7
8
from deepreg.model import layer, layer_util
9
from deepreg.model.backbone import GlobalNet
10
from deepreg.registry import REGISTRY
11
12
13
def dict_without(d: dict, key) -> dict:
    """
    Build a deep copy of a dict with one entry removed.

    The input dict is left untouched; nested values are deep-copied too, so
    mutating the result never affects ``d``.

    :param d: source dict, it is not modified.
    :param key: any hashable key that must be present in ``d``;
        a missing key raises ``KeyError``.
    :return: a deep copy of ``d`` lacking ``key``.
    """
    result = deepcopy(d)
    del result[key]
    return result
24
25
26
class RegistrationModel(tf.keras.Model):
    """
    Interface for registration model.

    Subclasses implement ``build_model`` (whose return value is stored as
    ``self._model``), ``build_loss`` and ``postprocess``. ``build_model`` is
    expected to fill ``self._inputs`` and ``self._outputs`` (dicts of tensors),
    which ``build_loss`` then reads.

    NOTE(review): ``@abstractmethod`` is only enforced by ``abc.ABCMeta``;
    confirm whether the metaclass of ``tf.keras.Model`` enforces it before
    relying on instantiation failing for incomplete subclasses.
    """

    def __init__(
        self,
        moving_image_size: tuple,
        fixed_image_size: tuple,
        index_size: int,
        labeled: bool,
        batch_size: int,
        config: dict,
        num_devices: int = 1,
        name: str = "RegistrationModel",
    ):
        """
        Init.

        :param moving_image_size: (m_dim1, m_dim2, m_dim3)
        :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
        :param index_size: number of indices identifying each sample
        :param labeled: if the data is labeled
        :param batch_size: size of mini-batch
        :param config: config for method, backbone, and loss.
        :param num_devices: number of GPU used,
            global_batch_size = batch_size*num_devices
        :param name: name of the model
        """
        super().__init__(name=name)
        self.moving_image_size = moving_image_size
        self.fixed_image_size = fixed_image_size
        self.index_size = index_size
        self.labeled = labeled
        self.batch_size = batch_size
        self.config = config
        self.num_devices = num_devices
        self.global_batch_size = num_devices * batch_size

        self._inputs = None  # save inputs of self._model as dict
        self._outputs = None  # save outputs of self._model as dict
        # build_model/build_loss run last so they can use every attribute above
        self._model = self.build_model()
        self.build_loss()

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return dict(
            moving_image_size=self.moving_image_size,
            fixed_image_size=self.fixed_image_size,
            index_size=self.index_size,
            labeled=self.labeled,
            batch_size=self.batch_size,
            config=self.config,
            num_devices=self.num_devices,
            name=self.name,
        )

    @abstractmethod
    def build_model(self):
        """Build the model to be saved as self._model."""

    def build_inputs(self) -> Dict[str, tf.keras.layers.Input]:
        """
        Build input tensors.

        Images and labels have no channel axis here; ``concat_images`` adds it.

        :return: dict of inputs with keys moving_image, fixed_image, indices,
            plus moving_label and fixed_label when ``self.labeled``.
        """
        # (batch, m_dim1, m_dim2, m_dim3)
        moving_image = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_image",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_image = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_image",
        )
        # (batch, index_size)
        indices = tf.keras.Input(
            shape=(self.index_size,),
            batch_size=self.batch_size,
            name="indices",
        )

        if not self.labeled:
            return dict(
                moving_image=moving_image, fixed_image=fixed_image, indices=indices
            )

        # (batch, m_dim1, m_dim2, m_dim3)
        moving_label = tf.keras.Input(
            shape=self.moving_image_size,
            batch_size=self.batch_size,
            name="moving_label",
        )
        # (batch, f_dim1, f_dim2, f_dim3)
        fixed_label = tf.keras.Input(
            shape=self.fixed_image_size,
            batch_size=self.batch_size,
            name="fixed_label",
        )
        return dict(
            moving_image=moving_image,
            fixed_image=fixed_image,
            moving_label=moving_label,
            fixed_label=fixed_label,
            indices=indices,
        )

    def concat_images(
        self,
        moving_image: tf.Tensor,
        fixed_image: tf.Tensor,
        moving_label: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Adjust image shape and concatenate them together.

        A channel axis is appended to every tensor and moving tensors are
        resized to ``self.fixed_image_size`` before concatenation.

        :param moving_image: registration source, (batch, m_dim1, m_dim2, m_dim3)
        :param fixed_image: registration target, (batch, f_dim1, f_dim2, f_dim3)
        :param moving_label: optional, only used for conditional model.
        :return: (batch, f_dim1, f_dim2, f_dim3, 2 or 3), concatenated on the
            last axis in the order moving image, fixed image, moving label.
        """
        images = []

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        moving_image = tf.expand_dims(moving_image, axis=4)
        moving_image = layer_util.resize3d(
            image=moving_image, size=self.fixed_image_size
        )
        images.append(moving_image)

        # (batch, f_dim1, f_dim2, f_dim3, 1)
        fixed_image = tf.expand_dims(fixed_image, axis=4)
        images.append(fixed_image)

        # (batch, m_dim1, m_dim2, m_dim3, 1) -> (batch, f_dim1, f_dim2, f_dim3, 1)
        if moving_label is not None:
            moving_label = tf.expand_dims(moving_label, axis=4)
            moving_label = layer_util.resize3d(
                image=moving_label, size=self.fixed_image_size
            )
            images.append(moving_label)

        # (batch, f_dim1, f_dim2, f_dim3, 2 or 3)
        images = tf.concat(images, axis=4)
        return images

    def _build_loss(self, name: str, inputs_dict: dict):
        """
        Build and add one weighted loss together with the metrics.

        The loss is skipped (with a warning) when ``config["loss"][name]`` is
        missing or its weight is zero. A missing weight defaults to 1.0 and is
        written back into ``self.config``. The loss value is divided by
        ``self.global_batch_size`` before weighting.

        :param name: name of loss, i.e. its key under ``config["loss"]``
        :param inputs_dict: inputs for loss function
        """
        if name not in self.config["loss"]:
            # loss config is not defined
            logging.warning(
                "The configuration for loss %s is not defined. Loss is not used.",
                name,
            )
            return

        loss_config = self.config["loss"][name]

        if "weight" not in loss_config:
            # default loss weight 1
            logging.warning(
                "The weight for loss %s is not defined. "
                "Default weight = 1.0 is used.",
                name,
            )
            loss_config["weight"] = 1.0

        # build loss
        weight = loss_config["weight"]

        if weight == 0:
            logging.warning("The weight for loss %s is zero. Loss is not used.", name)
            return

        # "weight" is not an argument of the loss itself, so strip it
        loss_cls = REGISTRY.build_loss(config=dict_without(d=loss_config, key="weight"))
        loss = loss_cls(**inputs_dict) / self.global_batch_size
        weighted_loss = loss * weight

        # add loss
        self._model.add_loss(weighted_loss)

        # add metric, both the raw and the weighted value are tracked
        self._model.add_metric(
            loss, name=f"loss/{name}_{loss_cls.name}", aggregation="mean"
        )
        self._model.add_metric(
            weighted_loss,
            name=f"loss/{name}_{loss_cls.name}_weighted",
            aggregation="mean",
        )

    @abstractmethod
    def build_loss(self):
        """Build losses according to configs."""

    def call(
        self, inputs: Dict[str, tf.Tensor], training=None, mask=None
    ) -> Dict[str, tf.Tensor]:
        """
        Call the self._model.

        :param inputs: a dict of tensors.
        :param training: training or not, forwarded to self._model.
        :param mask: mask for inputs, forwarded to self._model.
        :return: the outputs of self._model.
        """
        return self._model(inputs, training=training, mask=mask)  # pragma: no cover

    @abstractmethod
    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
256
257
258
@REGISTRY.register_model(name="ddf")
class DDFModel(RegistrationModel):
    """
    A registration model predicts DDF.

    When using global net as backbone,
    the model predicts an affine transformation parameters,
    and a DDF is calculated based on that.
    """

    def _resize_interpolate(self, field, control_points):
        """
        Pass a predicted field through the control-point resize and spline layers.

        :param field: field predicted by the backbone (a DDF here, a DVF in
            the DVF subclass).
        :param control_points: the ``control_points`` value from the backbone
            config, forwarded to ``layer.ResizeCPTransform`` and
            ``layer.BSplines3DTransform``.
        :return: the field after ``layer.ResizeCPTransform`` followed by
            ``layer.BSplines3DTransform`` targeting ``self.fixed_image_size``.
            NOTE(review): presumably a B-spline control-point
            parameterisation; confirm in deepreg.model.layer.
        """
        resize = layer.ResizeCPTransform(control_points)
        field = resize(field)

        interpolate = layer.BSplines3DTransform(control_points, self.fixed_image_size)
        field = interpolate(field)

        return field

    def build_model(self):
        """Build the model to be saved as self._model."""
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]

        # build ddf
        # "control_points" is not a backbone argument, so remove it from a
        # shallow copy. Popping from self.config in place would drop it from
        # get_config() and from any later model built from the same config.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )

        if isinstance(backbone, GlobalNet):
            # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
            # control_points is not applied to the affine (GlobalNet) branch
            ddf, theta = backbone(inputs=backbone_inputs)
            self._outputs = dict(ddf=ddf, theta=theta)
        else:
            # (f_dim1, f_dim2, f_dim3, 3)
            ddf = backbone(inputs=backbone_inputs)
            ddf = (
                self._resize_interpolate(ddf, control_points) if control_points else ddf
            )
            self._outputs = dict(ddf=ddf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # NOTE(review): presumably (batch, f_dim1, f_dim2, f_dim3);
        # confirm against layer.Warping
        pred_fixed_image = warping(inputs=[ddf, moving_image])
        self._outputs["pred_fixed_image"] = pred_fixed_image

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the same warping layer and DDF are applied to the moving label
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build losses according to configs."""
        fixed_image = self._inputs["fixed_image"]
        ddf = self._outputs["ddf"]
        pred_fixed_image = self._outputs["pred_fixed_image"]

        # ddf
        self._build_loss(name="regularization", inputs_dict=dict(inputs=ddf))

        # image
        self._build_loss(
            name="image", inputs_dict=dict(y_true=fixed_image, y_pred=pred_fixed_image)
        )

        # label
        if self.labeled:
            fixed_label = self._inputs["fixed_label"]
            pred_fixed_label = self._outputs["pred_fixed_label"]
            self._build_loss(
                name="label",
                inputs_dict=dict(y_true=fixed_label, y_pred=pred_fixed_label),
            )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices = inputs["indices"]
        processed = dict(
            moving_image=(inputs["moving_image"], True, False),
            fixed_image=(inputs["fixed_image"], True, False),
            ddf=(outputs["ddf"], True, False),
            pred_fixed_image=(outputs["pred_fixed_image"], True, False),
        )

        # save theta for affine model
        if "theta" in outputs:
            processed["theta"] = (outputs["theta"], None, None)

        if not self.labeled:
            return indices, processed

        # label entries are placed before the image entries in the dict order
        processed = {
            **dict(
                moving_label=(inputs["moving_label"], False, True),
                fixed_label=(inputs["fixed_label"], False, True),
                pred_fixed_label=(outputs["pred_fixed_label"], False, True),
            ),
            **processed,
        }

        return indices, processed
388
389
390
@REGISTRY.register_model(name="dvf")
class DVFModel(DDFModel):
    """
    A registration model predicts DVF.

    DDF is calculated based on DVF.
    """

    def build_model(self):
        """Build the model to be saved as self._model."""
        # build inputs
        self._inputs = self.build_inputs()
        moving_image = self._inputs["moving_image"]
        fixed_image = self._inputs["fixed_image"]
        # "control_points" is not a backbone argument, so remove it from a
        # shallow copy. Popping from self.config in place would drop it from
        # get_config() and from any later model built from the same config.
        backbone_config = dict(self.config["backbone"])
        control_points = backbone_config.pop("control_points", False)

        # build ddf
        backbone_inputs = self.concat_images(moving_image, fixed_image)
        backbone = REGISTRY.build_backbone(
            config=backbone_config,
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=3,
                out_kernel_initializer="zeros",
                out_activation=None,
            ),
        )
        dvf = backbone(inputs=backbone_inputs)
        dvf = self._resize_interpolate(dvf, control_points) if control_points else dvf
        # the DDF is obtained from the DVF through layer.IntDVF
        ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

        # build outputs
        warping = layer.Warping(fixed_image_size=self.fixed_image_size)
        # NOTE(review): presumably (batch, f_dim1, f_dim2, f_dim3);
        # confirm against layer.Warping
        pred_fixed_image = warping(inputs=[ddf, moving_image])

        self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

        if not self.labeled:
            return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

        # the same warping layer and DDF are applied to the moving label
        moving_label = self._inputs["moving_label"]
        pred_fixed_label = warping(inputs=[ddf, moving_label])

        self._outputs["pred_fixed_label"] = pred_fixed_label
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Return a dict used for saving inputs and outputs.

        Extends DDFModel.postprocess with the predicted DVF.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        indices, processed = super().postprocess(inputs=inputs, outputs=outputs)
        processed["dvf"] = (outputs["dvf"], True, False)
        return indices, processed
456
457
458
@REGISTRY.register_model(name="conditional")
class ConditionalModel(RegistrationModel):
    """
    Registration model that predicts the fixed label directly.

    The backbone receives the moving image, fixed image and moving label, and
    outputs the fixed label; no DDF or DVF is produced.
    """

    def build_model(self):
        """Build the model to be saved as self._model."""
        # the moving label is a backbone input, so labeled data is required
        assert self.labeled

        inputs = self.build_inputs()
        self._inputs = inputs

        # (batch, f_dim1, f_dim2, f_dim3, 3)
        backbone_inputs = self.concat_images(
            inputs["moving_image"], inputs["fixed_image"], inputs["moving_label"]
        )
        # single-channel sigmoid output, i.e. values in (0, 1)
        backbone = REGISTRY.build_backbone(
            config=self.config["backbone"],
            default_args=dict(
                image_size=self.fixed_image_size,
                out_channels=1,
                out_kernel_initializer="glorot_uniform",
                out_activation="sigmoid",
            ),
        )
        # drop the channel axis -> (batch, f_dim1, f_dim2, f_dim3)
        prediction = tf.squeeze(backbone(inputs=backbone_inputs), axis=4)

        self._outputs = {"pred_fixed_label": prediction}
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    def build_loss(self):
        """Build the label loss only; this model has no image or DDF loss."""
        self._build_loss(
            name="label",
            inputs_dict=dict(
                y_true=self._inputs["fixed_label"],
                y_pred=self._outputs["pred_fixed_label"],
            ),
        )

    def postprocess(
        self,
        inputs: Dict[str, tf.Tensor],
        outputs: Dict[str, tf.Tensor],
    ) -> (tf.Tensor, Dict):
        """
        Collect the tensors to be saved for a batch.

        :param inputs: dict of model inputs
        :param outputs: dict of model outputs
        :return: tuple, indices and a dict.
            In the dict, each value is (tensor, normalize, on_label), where
            - normalize = True if the tensor need to be normalized to [0, 1]
            - on_label = True if the tensor depends on label
        """
        processed = {
            "moving_image": (inputs["moving_image"], True, False),
            "fixed_image": (inputs["fixed_image"], True, False),
            "pred_fixed_label": (outputs["pred_fixed_label"], True, True),
            "moving_label": (inputs["moving_label"], False, True),
            "fixed_label": (inputs["fixed_label"], False, True),
        }
        return inputs["indices"], processed
527