Passed
Pull Request — main (#656)
by Yunguan
02:47
created

deepreg.model.layer.Warping.__init__()   A

Complexity

Conditions 1

Size

Total Lines 14
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 14
rs 10
c 0
b 0
f 0
cc 1
nop 4
1
"""This module defines custom layers."""
2
import itertools
3
4
import numpy as np
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
import deepreg.model.layer_util as layer_util
9
10
11
class Deconv3d(tfkl.Layer):
    """
    Transposed 3D convolution with output padding computed at build time.

    The wrapped Conv3DTranspose is only instantiated in ``build``, once the
    input shape is known, so that a requested output shape can be matched.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        name="deconv3d",
        **kwargs,
    ):
        """
        Store the arguments, the inner layer is created later in build.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3) or None
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param name: name of the layer.
        :param kwargs: additional arguments for Conv3DTranspose.
        """
        super().__init__(name=name)
        # arguments needed by build() and get_config()
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._kwargs = kwargs
        # placeholder for the wrapped layer, set in build()
        self._deconv3d = None

    def build(self, input_shape):
        # pylint: disable-next=line-too-long
        """
        Create the wrapped Conv3DTranspose with the required output padding.

        https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185

        :param input_shape: shape of input, the spatial sizes are read
            from indices 1, 2 and 3.
        """
        super().build(input_shape)

        def _per_axis(value):
            """Repeat an int for the three spatial axes, keep others as is."""
            return [value] * 3 if isinstance(value, int) else value

        kernel_size = _per_axis(self._kernel_size)
        strides = _per_axis(self._strides)

        if self._output_shape is None:
            output_padding = None
        else:
            assert self._padding == "same"
            output_padding = []
            for axis in range(3):
                # length obtained with zero output padding,
                # following the formula of the linked conv_utils code
                unpadded_size = (
                    (input_shape[1 + axis] - 1) * strides[axis]
                    + kernel_size[axis]
                    - 2 * (kernel_size[axis] // 2)
                )
                output_padding.append(self._output_shape[axis] - unpadded_size)

        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=output_padding,
            **self._kwargs,
        )

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Apply the wrapped transposed convolution.

        :param inputs: input tensor.
        :param kwargs: additional arguments, not used.
        :return: output of the wrapped Conv3DTranspose.
        """
        return self._deconv3d(inputs=inputs)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["filters"] = self._filters
        config["output_shape"] = self._output_shape
        config["kernel_size"] = self._kernel_size
        config["strides"] = self._strides
        config["padding"] = self._padding
        # extra Conv3DTranspose arguments are merged last
        config.update(self._kwargs)
        return config
111
112
113
class NormBlock(tfkl.Layer):
    """
    Chain of a wrapped layer, a normalization layer and an activation.
    """

    # lookup tables from the string identifiers to the layer classes
    layer_cls_dict = {"conv3d": tfkl.Conv3D, "deconv3d": Deconv3d}
    norm_cls_dict = {
        "batch": tfkl.BatchNormalization,
        "layer": tfkl.LayerNormalization,
    }

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Build the three sub-layers.

        The name is only stored in the config, it is not forwarded to the
        base layer constructor.

        :param layer_name: key of layer_cls_dict, layer to be wrapped.
        :param norm_name: key of norm_cls_dict, normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments, passed to the wrapped layer.
        """
        super().__init__()
        self._config = dict(
            layer_name=layer_name,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        wrapped_cls = self.layer_cls_dict[layer_name]
        norm_cls = self.norm_cls_dict[norm_name]
        # the wrapped layer has no bias as it is followed by a normalization
        self._layer = wrapped_cls(use_bias=False, **kwargs)
        self._norm = norm_cls()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Apply layer, normalization and activation in sequence.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: output of the activation.
        """
        features = self._layer(inputs=inputs)
        normalized = self._norm(inputs=features, training=training)
        return self._act(normalized)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
169
170
171
class Conv3dBlock(NormBlock):
    """
    NormBlock specialised to conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Select conv3d as the wrapped layer of the NormBlock.

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to NormBlock.
        """
        super().__init__(name=name, layer_name="conv3d", **kwargs)
188
189
190
class Deconv3dBlock(NormBlock):
    """
    NormBlock specialised to deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Select deconv3d as the wrapped layer of the NormBlock.

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to NormBlock.
        """
        super().__init__(name=name, layer_name="deconv3d", **kwargs)
207
208
209
class Resize3d(tfkl.Layer):
    """
    Resize a 3D image with two consecutive 2D resizes.

    - the first one resizes dim2 and dim3
    - the second one resizes dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
    ):
        """
        Check and store the arguments.

        :param shape: (dim1, dim2, dim3), the target spatial shape
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        """
        super().__init__(name=name)
        assert len(shape) == 3
        self._method = method
        self._shape = shape

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Resize the input to the stored shape.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                                     or (batch, dim1, dim2, dim3)
                                     or (dim1, dim2, dim3)
        :param kwargs: additional arguments, not used.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                                or (batch, out_dim1, out_dim2, out_dim3)
                                or (out_dim1, out_dim2, out_dim3)
        :raises ValueError: if the input has not 3, 4 or 5 dimensions.
        """
        image = inputs
        rank = len(image.shape)

        # sanity check
        if rank not in (3, 4, 5):
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape{}".format(image.shape)
            )

        # which axes exist in the input
        has_batch = rank >= 4
        has_channel = rank == 5
        if has_batch:
            input_image_shape = image.shape[1:4]
        else:
            input_image_shape = image.shape[0:3]

        # the input already has the target shape
        if input_image_shape == tuple(self._shape):
            return image

        # add the missing axes to get five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        dynamic_shape = tf.shape(image)
        in_dim1 = dynamic_shape[1]
        num_channels = dynamic_shape[4]

        # (batch * dim1, dim2, dim3, channels)
        merged_batch = tf.reshape(
            image, (-1, dynamic_shape[2], dynamic_shape[3], num_channels)
        )

        # (batch * dim1, out_dim2, out_dim3, channels)
        resized_23 = tf.image.resize(
            images=merged_batch, size=self._shape[1:3], method=self._method
        )

        # (batch, dim1, out_dim2, out_dim3 * channels)
        merged_channel = tf.reshape(
            resized_23,
            shape=(-1, in_dim1, self._shape[1], self._shape[2] * num_channels),
        )

        # (batch, out_dim1, out_dim2, out_dim3 * channels)
        resized_12 = tf.image.resize(
            images=merged_channel, size=self._shape[:2], method=self._method
        )

        # (batch, out_dim1, out_dim2, out_dim3, channels)
        output = tf.reshape(resized_12, shape=[-1, *self._shape, num_channels])

        # remove the axes which were added above
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(shape=self._shape, method=self._method)
        return config
324
325
326
class Warping(tfkl.Layer):
    """
    Resample an image at the locations given by a reference grid plus a DDF.

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Store the size and pre-compute the reference grid.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to the base layer.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        reference_grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        # add a leading axis so that the grid broadcasts over the batch,
        # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = reference_grid[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Warp the image with the DDF.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments, not used.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        sample_locations = self.grid_ref + ddf
        return layer_util.resample(vol=image, loc=sample_locations)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(fixed_image_size=self._fixed_image_size)
        return config
368
369
370
class ResidualBlock(tfkl.Layer):
    """
    Stack of layer - norm - activation with a skip link around the stack.
    """

    # lookup tables from the string identifiers to the layer classes
    layer_cls_dict = {"conv3d": tfkl.Conv3D, "deconv3d": Deconv3d}
    norm_cls_dict = {
        "batch": tfkl.BatchNormalization,
        "layer": tfkl.LayerNormalization,
    }

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Build num_layers triplets of layer, normalization and activation.

        The name is only stored in the config, it is not forwarded to the
        base layer constructor.

        :param layer_name: key of layer_cls_dict, layer to be wrapped.
        :param num_layers: number of layers/blocks.
        :param norm_name: key of norm_cls_dict, normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments, passed to each wrapped layer.
        """
        super().__init__()
        self._num_layers = num_layers
        self._config = dict(
            layer_name=layer_name,
            num_layers=num_layers,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        wrapped_cls = self.layer_cls_dict[layer_name]
        norm_cls = self.norm_cls_dict[norm_name]
        self._layers = [
            wrapped_cls(use_bias=False, **kwargs) for _ in range(num_layers)
        ]
        self._norms = [norm_cls() for _ in range(num_layers)]
        self._acts = [tfkl.Activation(activation=activation) for _ in range(num_layers)]

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Apply the stack and add the inputs before the last activation.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: output of the last activation,
            the inputs are returned if there is no layer.
        """
        last_index = self._num_layers - 1
        output = inputs
        triplets = zip(self._layers, self._norms, self._acts)
        for index, (layer, norm, act) in enumerate(triplets):
            output = layer(inputs=output)
            output = norm(inputs=output, training=training)
            if index == last_index:
                # skip link, added before the last activation
                output = output + inputs
            output = act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
439
440
441
class ResidualConv3dBlock(ResidualBlock):
    """
    ResidualBlock specialised to conv3d layers.
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Select conv3d as the wrapped layer of the ResidualBlock.

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to ResidualBlock.
        """
        super().__init__(name=name, layer_name="conv3d", **kwargs)
458
459
460
class UpSampleResnetBlock(tfkl.Layer):
    """
    Up-sample with deconv3d, combine with the skipped tensor, then convolve.
    """

    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        An up-sampling resnet conv3d block, with deconv3d.

        The deconv3d block is created in build, as it needs the shape of
        the skipped tensor.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3,
            used by the conv3d and residual blocks.
        :param concat: bool,specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments, forwarded to the base layer.
        """
        super().__init__(**kwargs)
        # arguments needed after the construction
        self._concat = concat
        self._filters = filters
        # sub-layers, the deconv3d block is set in build()
        conv_kwargs = dict(filters=filters, kernel_size=kernel_size, padding="same")
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(**conv_kwargs)
        self._residual_block = ResidualConv3dBlock(**conv_kwargs)

    def build(self, input_shape):
        """
        Create the deconv3d block producing the shape of the skipped tensor.

        The deconvolution always uses kernel size 3 and stride 2.

        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        _, skip_image_shape = input_shape[0], input_shape[1]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=skip_image_shape[1:4],
            kernel_size=3,
            strides=2,
            padding="same",
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        Up-sample the first input and merge it with the second one.

        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: shape = (batch, \*skip_connection_image_shape, filters]
        """
        down_sampled, skip = inputs[0], inputs[1]
        # up sample and change channel
        up_sampled = self._deconv3d_block(inputs=down_sampled, training=training)
        # combine
        if self._concat:
            combined = tf.concat([up_sampled, skip], axis=4)
        else:
            combined = up_sampled + skip
        # adjust channel
        adjusted = self._conv3d_block(inputs=combined, training=training)
        # conv
        return self._residual_block(inputs=adjusted, training=training)
521
522
523
class IntDVF(tfkl.Layer):
    """
    Integrate a DVF into a DDF by repeated self-warping.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(
        self,
        fixed_image_size: tuple,
        num_steps: int = 7,
        name: str = "int_dvf",
        **kwargs,
    ):
        """
        Store the arguments and create the warping layer.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to the base layer.
        """
        super().__init__(name=name, **kwargs)
        assert len(fixed_image_size) == 3
        self._num_steps = num_steps
        self._fixed_image_size = fixed_image_size
        self._warping = Warping(fixed_image_size=fixed_image_size)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Scale the DVF by 1 / 2**num_steps, then compose it with itself
        num_steps times.

        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        :param kwargs: additional arguments, not used.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        num_steps = self._num_steps
        ddf = inputs / (2 ** num_steps)
        for _ in range(num_steps):
            # the field is warped by itself and accumulated
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            fixed_image_size=self._fixed_image_size,
            num_steps=self._num_steps,
        )
        return config
571
572
573
class LocalNetResidual3dBlock(tfkl.Layer):
    """
    Conv3d - batch norm, summed with a second tensor, then activation.
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block, simpler than ResidualConv3dBlock.

        1. conved = conv3d(inputs[0])
        2. out = act(norm(conved) + inputs[1])

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments, forwarded to the base layer.
        """
        super().__init__(**kwargs)
        # sub-layers, the convolution has no bias as a normalization follows
        conv_config = dict(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._conv3d = tfkl.Conv3D(**conv_config)
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Convolve the first input and add the second one before activation.

        :param inputs: pair, the first item is convolved and normalized,
            the second item is added to the result.
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: output of the activation.
        """
        to_convolve, to_add = inputs[0], inputs[1]
        conved = self._conv3d(inputs=to_convolve)
        normalized = self._norm(inputs=conved, training=training)
        return self._act(normalized + to_add)
611
612
613
class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """
    Up-sample a tensor to the spatial shape of a skipped tensor.

    The block takes two inputs, the down-sampled (non-skipped) tensor and
    the skipped tensor.
    """

    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Layer up-samples tensor with two inputs (skipped and down-sampled).

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool to used additive upsampling
        :param kwargs: additional arguments, forwarded to the base layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # layers depending on the input shape, created in build()
        self._deconv3d_block = None
        self._resize = None
        # not read within this class, kept in case it is accessed elsewhere
        self._additive_upsampling = None
        # layers independent of the input shape
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        """
        Create the layers which need the shape of the skipped tensor.

        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        # spatial shape of the skipped tensor
        output_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=output_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        if self._use_additive_upsampling:
            self._resize = Resize3d(shape=output_shape)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Up-sample the non-skipped input and merge it with the skipped one.

        With additive upsampling, the non-skipped input is also resized,
        split into two halves along the channel axis which are summed,
        and the sum is added to the deconvolved tensor. This presumably
        requires the input to have twice as many channels as ``filters``.

        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: output of the residual block.
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        h0 = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            upsampled = self._resize(inputs=inputs_nonskip)
            upsampled = tf.split(upsampled, num_or_size_splits=2, axis=4)
            upsampled = tf.add_n(upsampled)
            h0 = h0 + upsampled
        r1 = h0 + inputs_skip
        r2 = self._conv3d_block(inputs=h0, training=training)
        h1 = self._residual_block(inputs=[r2, r1], training=training)
        return h1

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["filters"] = self._filters
        config["use_additive_upsampling"] = self._use_additive_upsampling
        return config
667
668
669
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of a image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        Store the spacing and derive the sigma of the Gaussian kernel.

        :param control_point_spacing: list or int,
            an int is repeated for the three dimensions.
        :param kwargs: additional arguments, forwarded to the base layer.
        """
        super().__init__(**kwargs)

        spacing = control_point_spacing
        if isinstance(spacing, int):
            spacing = [spacing] * 3

        # 0.44 = ln(4)/pi
        self.kernel_sigma = [0.44 * step for step in spacing]
        self.cp_spacing = spacing
        # the following attributes are set in build()
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        """
        Create the Gaussian kernel and the resize layer.

        The target size along each dimension is ceil(size / spacing) + 3.

        :param input_shape: shape of the input,
            the first and last items are not spatial sizes.
        """
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = []
        for size, step in zip(input_shape[1:-1], self.cp_spacing):
            num_points = tf.math.ceil(size / step) + 3
            output_shape.append(tf.cast(num_points, tf.int32))
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Smooth the input with the Gaussian kernel, then resize it.

        :param inputs: tensor to be filtered and resized.
        :param kwargs: additional arguments, not used.
        :return: resized tensor.
        """
        smoothed = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        return self._resize(inputs=smoothed)
710
711
712
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline filters.

    Given a field defined on control points, the layer:

    1. interpolates it with a transposed convolution using the filters
    2. crops the result around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing to all dimensions is used
    :param output_shape: (dim0, dim1, dim2), spatial shape of the high
        resolution deformation fields, used to crop the interpolated field.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):
        """
        Store the arguments, the filter is computed in build.

        :param cp_spacing: int or tuple of three ints, spacing in pixels.
        :param output_shape: (dim0, dim1, dim2)
        :param kwargs: additional arguments, forwarded to the base layer.
        """
        super().__init__(**kwargs)

        # not read within this class, kept in case it is accessed elsewhere
        self.filters = []
        # interpolation filter, set in build()
        self.filter = None
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        Pre-compute the cubic B-spline filter.

        The filter has shape (4 * s0, 4 * s1, 4 * s2, 3, 3) for a spacing
        (s0, s1, s2) and only maps each channel onto itself.

        :param input_shape: tuple with the input shape
        """
        super().build(input_shape=input_shape)

        # the four cubic B-spline basis functions
        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        spacing = self.cp_spacing
        filters = np.zeros(
            (4 * spacing[0], 4 * spacing[1], 4 * spacing[2], 3, 3),
            dtype=np.float32,
        )

        # positions at which the basis functions are evaluated, per dimension
        u_arange, v_arange, w_arange = [
            1 - np.arange(1 / (2 * step), 1, 1 / step) for step in spacing
        ]

        for i, j, k in itertools.product(range(4), repeat=3):
            # the weights do not depend on the channel, compute them once
            weights = (
                b[i](u_arange)[:, None, None]
                * b[j](v_arange)[None, :, None]
                * b[k](w_arange)[None, None, :]
            )
            for it_dim in range(3):
                filters[
                    i * spacing[0] : (i + 1) * spacing[0],
                    j * spacing[1] : (j + 1) * spacing[1],
                    k * spacing[2] : (k + 1) * spacing[2],
                    it_dim,
                    it_dim,
                ] = weights

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        Interpolate the field with a transposed convolution.

        :param field: tf.Tensor with shape=number_of_control_points_per_dim
        :return: interpolated_field: tf.Tensor
        """
        image_shape = tuple(
            (a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)
        )

        batch_size = field.shape[0]
        if batch_size is None:
            # the batch size is not known statically, use the dynamic one
            batch_size = tf.shape(field)[0]

        output_shape = (batch_size,) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Interpolate the field and crop it to the output shape.

        The crop starts at 3 * spacing along each spatial dimension.

        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments, not used.
        :return: interpolated_field: tf.Tensor with the spatial shape
            given by output_shape
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["cp_spacing"] = self.cp_spacing
        config["output_shape"] = self._output_shape
        return config
829