Passed
Pull Request — main (#656)
by Yunguan
02:46
created

deepreg.model.layer.ResidualBlock.call()   A

Complexity

Conditions 3

Size

Total Lines 19
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 19
rs 9.95
c 0
b 0
f 0
cc 3
nop 4
1
"""This module defines custom layers."""
2
import itertools
3
from typing import Tuple, Union
4
5
import numpy as np
6
import tensorflow as tf
7
import tensorflow.keras.layers as tfkl
8
9
import deepreg.model.layer_util as layer_util
10
11
# Keras layer classes selectable through the ``layer_name`` argument of the blocks below.
LAYER_DICT = {"conv3d": tfkl.Conv3D, "deconv3d": tfkl.Conv3DTranspose}
# Keras normalization classes selectable through the ``norm_name`` argument.
NORM_DICT = {"batch": tfkl.BatchNormalization, "layer": tfkl.LayerNormalization}
13
14
15
def _deconv_output_padding(
0 ignored issues
show
introduced by
"ValueError" not documented as being raised
Loading history...
16
    input_shape: int, output_shape: int, kernel_size: int, stride: int, padding: str
17
) -> int:
18
    """
19
    Calculate output padding for Conv3DTranspose in 1D.
20
21
    - output_shape = (input_shape - 1)*stride + kernel_size - 2*pad + output_padding
22
    - output_padding = output_shape - ((input_shape - 1)*stride + kernel_size - 2*pad)
23
24
    Reference:
25
26
    - https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/python/keras/utils/conv_utils.py#L140
27
28
    :param input_shape: shape of input tensor, without batch or channel
29
    :param output_shape: shape of out tensor, without batch or channel
30
    :param kernel_size: kernel size of Conv3DTranspose layer
31
    :param stride: stride of Conv3DTranspose layer
32
    :param padding: padding of Conv3DTranspose layer
33
    :return: output_padding
34
    """
35
    if padding == "same":
36
        pad = kernel_size // 2
37
    elif padding == "valid":
38
        pad = 0
39
    elif padding == "full":
40
        pad = kernel_size - 1
41
    else:
42
        raise ValueError(f"Unknown padding {padding} in deconv_output_padding")
43
    return output_shape - ((input_shape - 1) * stride + kernel_size - 2 * pad)
44
45
46
def deconv_output_padding(
    input_shape: Union[Tuple[int], int],
    output_shape: Union[Tuple[int], int],
    kernel_size: Union[Tuple[int], int],
    stride: Union[Tuple[int], int],
    padding: str,
) -> Union[Tuple[int], int]:
    """
    Calculate output padding for Conv3DTranspose in any dimension.

    When ``input_shape`` is an int, all arguments are treated as 1D scalars and
    an int is returned. Otherwise one padding is computed per dimension; an int
    ``kernel_size`` or ``stride`` is broadcast to every dimension.

    :param input_shape: shape of input tensor, without batch or channel
    :param output_shape: shape of out tensor, without batch or channel
    :param kernel_size: kernel size of Conv3DTranspose layer
    :param stride: stride of Conv3DTranspose layer
    :param padding: padding of Conv3DTranspose layer
    :return: output_padding, an int or a tuple with one entry per dimension
    :raises ValueError: propagated from the 1D helper for an unknown padding
    """
    if not isinstance(input_shape, int):
        assert len(input_shape) == len(output_shape)
        num_dims = len(input_shape)
        kernels = [kernel_size] * num_dims if isinstance(kernel_size, int) else kernel_size
        strides = [stride] * num_dims if isinstance(stride, int) else stride
        return tuple(
            _deconv_output_padding(
                input_shape=input_shape[axis],
                output_shape=output_shape[axis],
                kernel_size=kernels[axis],
                stride=strides[axis],
                padding=padding,
            )
            for axis in range(num_dims)
        )
    return _deconv_output_padding(
        input_shape=input_shape,
        output_shape=output_shape,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
89
90
91
class NormBlock(tfkl.Layer):
    """
    A block with layer - norm - activation.

    The wrapped layer is ``LAYER_DICT[layer_name]`` built with ``use_bias=False``,
    followed by ``NORM_DICT[norm_name]`` and a ``tfkl.Activation``.
    """

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Create the wrapped layer, the normalization and the activation.

        :param layer_name: key of ``LAYER_DICT``, e.g. "conv3d" or "deconv3d".
        :param norm_name: key of ``NORM_DICT``, e.g. "batch" or "layer".
        :param activation: name of activation passed to ``tfkl.Activation``.
        :param name: name of the block layer, recorded in the config.
        :param kwargs: forwarded to the wrapped layer's constructor.
        """
        # NOTE(review): ``name`` is not forwarded to tfkl.Layer.__init__, so Keras
        # chooses the layer name; ``name`` only appears in ``get_config``.
        super().__init__()
        # arguments needed to re-create this block from its config
        self._config = {
            "layer_name": layer_name,
            "norm_name": norm_name,
            "activation": activation,
            "name": name,
            **kwargs,
        }
        self._layer = LAYER_DICT[layer_name](use_bias=False, **kwargs)
        self._norm = NORM_DICT[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Apply the wrapped layer, the normalization and the activation in turn.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments (unused).
        :return: activated output of the normalization layer
        """
        features = self._layer(inputs=inputs)
        normalized = self._norm(inputs=features, training=training)
        return self._act(normalized)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
144
145
146
class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to NormBlock.
            A ``layer_name`` entry (as written by ``get_config``) is accepted
            only if it equals "conv3d".
        :raises ValueError: if ``layer_name`` is given and is not "conv3d".
        """
        # get_config() stores layer_name="conv3d"; accept it here so that
        # from_config(get_config()) does not pass layer_name twice.
        layer_name = kwargs.pop("layer_name", "conv3d")
        if layer_name != "conv3d":
            raise ValueError(
                f"Conv3dBlock only supports layer_name='conv3d', got {layer_name}"
            )
        super().__init__(layer_name=layer_name, name=name, **kwargs)
163
164
165
class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to NormBlock.
            A ``layer_name`` entry (as written by ``get_config``) is accepted
            only if it equals "deconv3d".
        :raises ValueError: if ``layer_name`` is given and is not "deconv3d".
        """
        # get_config() stores layer_name="deconv3d"; accept it here so that
        # from_config(get_config()) does not pass layer_name twice.
        layer_name = kwargs.pop("layer_name", "deconv3d")
        if layer_name != "deconv3d":
            raise ValueError(
                f"Deconv3dBlock only supports layer_name='deconv3d', got {layer_name}"
            )
        super().__init__(layer_name=layer_name, name=name, **kwargs)
182
183
184
class Resize3d(tfkl.Layer):
    """
    Resize image in two folds.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
        **kwargs,
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        :param kwargs: additional arguments forwarded to ``tfkl.Layer``, e.g. the
            ``trainable``/``dtype`` entries emitted by ``get_config``, so that
            ``Resize3d.from_config(layer.get_config())`` works.
        """
        super().__init__(name=name, **kwargs)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform two fold resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                                     or (batch, dim1, dim2, dim3)
                                     or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                                or (batch, out_dim1, out_dim2, out_dim3)
                                or (out_dim1, out_dim2, out_dim3)
        :raises ValueError: if inputs is not of rank 3, 4 or 5
        """
        image = inputs
        image_dim = len(image.shape)

        # infer which optional axes (batch, channels) are present from the rank
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape{}".format(image.shape)
            )

        # no need of resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config
299
300
301
class Warping(tfkl.Layer):
    """
    Warps an image with DDF.

    The image is resampled at the fixed-image reference grid shifted by the DDF.

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Init.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        # reference grid with a leading batch axis of size 1,
        # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resample the image at the reference grid displaced by the DDF.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        sample_locations = self.grid_ref + ddf
        return layer_util.resample(vol=image, loc=sample_locations)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(fixed_image_size=self._fixed_image_size)
        return config
343
344
345
class ResidualBlock(tfkl.Layer):
    """
    A block with skip links and layer - norm - activation.

    ``num_layers`` repetitions of layer -> norm -> activation, where the block
    input is added to the last normalized output right before the last activation.
    """

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Create ``num_layers`` wrapped layers, normalizations and activations.

        :param layer_name: key of ``LAYER_DICT``, e.g. "conv3d".
        :param num_layers: number of layers/blocks.
        :param norm_name: key of ``NORM_DICT``, e.g. "batch" or "layer".
        :param activation: name of activation.
        :param name: name of the block layer, recorded in the config.
        :param kwargs: forwarded to every wrapped layer's constructor.
        """
        # NOTE(review): ``name`` is not forwarded to tfkl.Layer.__init__, so Keras
        # chooses the layer name; ``name`` only appears in ``get_config``.
        super().__init__()
        self._num_layers = num_layers
        # arguments needed to re-create this block from its config
        self._config = {
            "layer_name": layer_name,
            "num_layers": num_layers,
            "norm_name": norm_name,
            "activation": activation,
            "name": name,
            **kwargs,
        }
        # NOTE(review): ``_layers`` may clash with the attribute tf.keras uses
        # internally to track sub-layers - confirm; renaming it would change
        # checkpoint keys, so it is left as is.
        layer_cls = LAYER_DICT[layer_name]
        self._layers = [layer_cls(use_bias=False, **kwargs) for _ in range(num_layers)]
        norm_cls = NORM_DICT[norm_name]
        self._norms = [norm_cls() for _ in range(num_layers)]
        self._acts = [tfkl.Activation(activation=activation) for _ in range(num_layers)]

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Run the stacked layers with a skip connection around the whole stack.

        With ``num_layers == 0`` the inputs are returned unchanged.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments (unused).
        :return: output of the last activation
        """
        last_index = self._num_layers - 1
        output = inputs
        stages = zip(self._layers, self._norms, self._acts)
        for index, (layer, norm, act) in enumerate(stages):
            output = norm(inputs=layer(inputs=output), training=training)
            if index == last_index:
                # skip connection, added before the final activation
                output = output + inputs
            output = act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
410
411
412
class ResidualConv3dBlock(ResidualBlock):
    """
    A conv3d residual block.
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to ResidualBlock.
            A ``layer_name`` entry (as written by ``get_config``) is accepted
            only if it equals "conv3d".
        :raises ValueError: if ``layer_name`` is given and is not "conv3d".
        """
        # get_config() stores layer_name="conv3d"; accept it here so that
        # from_config(get_config()) does not pass layer_name twice.
        layer_name = kwargs.pop("layer_name", "conv3d")
        if layer_name != "conv3d":
            raise ValueError(
                f"ResidualConv3dBlock only supports layer_name='conv3d', got {layer_name}"
            )
        super().__init__(layer_name=layer_name, name=name, **kwargs)
429
430
431
class UpSampleResnetBlock(tfkl.Layer):
    """
    An up-sampling resnet conv3d block, with deconv3d.

    deconv3d block (kernel 3, stride 2) -> combine with skip connection
    (sum or concatenation) -> conv3d block -> residual conv3d block.
    """

    def __init__(
        self,
        filters: int,
        output_padding: tuple,
        kernel_size: int = 3,
        concat: bool = False,
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param output_padding: output padding for deconv block
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3;
            used by the conv3d and residual blocks, the deconv block always uses 3
        :param concat: bool, specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._concat = concat
        self._deconv3d_block = Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        same_conv_args = dict(filters=filters, kernel_size=kernel_size, padding="same")
        self._conv3d_block = Conv3dBlock(**same_conv_args)
        self._residual_block = ResidualConv3dBlock(**same_conv_args)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        Up-sample the coarse tensor and fuse it with the skip connection.

        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        down_sampled, skip = inputs[0], inputs[1]
        # up-sample and change channel
        up_sampled = self._deconv3d_block(inputs=down_sampled, training=training)
        # combine with the skip connection along the channel axis or by sum
        if self._concat:
            merged = tf.concat([up_sampled, skip], axis=4)
        else:
            merged = up_sampled + skip
        adjusted = self._conv3d_block(inputs=merged, training=training)
        return self._residual_block(inputs=adjusted, training=training)
491
492
493
class IntDVF(tfkl.Layer):
    """
    Integrate DVF to get DDF.

    The DVF is scaled down by 2**num_steps, then composed with itself
    num_steps times through a Warping layer.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(
        self,
        fixed_image_size: tuple,
        num_steps: int = 7,
        name: str = "int_dvf",
        **kwargs,
    ):
        """
        Init.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        assert len(fixed_image_size) == 3
        self._fixed_image_size = fixed_image_size
        self._num_steps = num_steps
        self._warping = Warping(fixed_image_size=fixed_image_size)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Integrate the DVF by repeated self-composition.

        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        field = inputs / (2 ** self._num_steps)
        for _step in range(self._num_steps):
            # field <- field + field warped by itself
            field += self._warping(inputs=[field, field])
        return field

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            fixed_image_size=self._fixed_image_size, num_steps=self._num_steps
        )
        return config
541
542
543
class LocalNetResidual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block, simpler than ResidualConv3dBlock.

    Given inputs = [x, skip]:

    1. conved = conv3d(x)
    2. out = act(norm(conved) + skip)
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Convolve and normalize the first input, add the second, then activate.

        :param inputs: [x, skip]; x goes through conv3d and normalization,
            skip is added to the normalized result.
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments (unused).
        :return: act(norm(conv3d(x)) + skip)
        """
        conv_input, residual = inputs[0], inputs[1]
        normalized = self._norm(inputs=self._conv3d(inputs=conv_input), training=training)
        return self._act(normalized + residual)
581
582
583
class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """
    LocalNet up-sampling block with two inputs (skipped and down-sampled).

    The down-sampled input is up-sampled by a deconv3d block, optionally plus an
    additive up-sampling path (resize, split channels in two halves, sum them).
    The result feeds a conv3d block and a LocalNetResidual3dBlock whose skip
    input is the up-sampled tensor plus the skipped input.
    """

    def __init__(
        self,
        filters: int,
        output_padding: tuple,
        output_shape: tuple,
        use_additive_upsampling: bool = True,
        **kwargs,
    ):
        """
        Layer up-samples tensor with two inputs (skipped and down-sampled).

        :param filters: int, number of output channels
        :param output_padding: output padding for deconv block
        :param output_shape: shape of the output; only used to build the resize
            layer when use_additive_upsampling is True
        :param use_additive_upsampling: bool to use additive upsampling
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._use_additive_upsampling = use_additive_upsampling
        self._deconv3d_block = Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        if use_additive_upsampling:
            self._resize = Resize3d(shape=output_shape)
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Up-sample the non-skip input and fuse it with the skip input.

        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output of the LocalNetResidual3dBlock
        """
        nonskip, skip = inputs[0], inputs[1]
        up_sampled = self._deconv3d_block(inputs=nonskip, training=training)
        if self._use_additive_upsampling:
            resized = self._resize(inputs=nonskip)
            # split channels into two halves and sum them
            halves = tf.split(resized, num_or_size_splits=2, axis=4)
            up_sampled = up_sampled + tf.add_n(halves)
        residual = up_sampled + skip
        conved = self._conv3d_block(inputs=up_sampled, training=training)
        return self._residual_block(inputs=[conved, residual], training=training)
635
636
637
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        Init.

        :param control_point_spacing: list or int; an int is used for all three axes
        :param kwargs: additional arguments forwarded to ``tfkl.Layer``.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        # Gaussian sigma per axis, proportional to the spacing; 0.44 = ln(4)/pi
        self.kernel_sigma = [0.44 * cp for cp in control_point_spacing]
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        """
        Create the Gaussian kernel and the resize layer.

        :param input_shape: input shape, (batch, dim1, dim2, dim3, channels);
            the three spatial dims must be known.
        """
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        # ceil(dim / spacing) + 3 control points per axis, computed as plain
        # Python ints so Resize3d gets a static (and serializable) target shape
        # instead of int32 tensors.
        output_shape = [
            int(np.ceil(v / c)) + 3
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Smooth the input with the Gaussian kernel, then resize to the control point grid.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments (unused).
        :return: filtered tensor resized spatially to ``self._output_shape``
        """
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["control_point_spacing"] = self.cp_spacing
        return config
678
679
680
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for B-splines interpolation with a precomputed cubic spline kernel.

    It assumes a full sized image from which:

    1. it computes the control point values by downsampling the initial image
    2. performs the interpolation
    3. crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing to all dimensions is used
    :param output_shape: spatial shape (dim1, dim2, dim3) of the high resolution
        deformation field; ``call`` uses its first three entries as crop lengths
        along the three spatial axes.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):
        """
        Init.

        :param cp_spacing: int or tuple of three ints, control point spacing
        :param output_shape: spatial shape (dim1, dim2, dim3) of the output field
        :param kwargs: additional arguments forwarded to ``tfkl.Layer``.
        """
        super().__init__(**kwargs)

        # NOTE(review): not read anywhere in this class (the kernel lives in
        # ``self.filter``); kept for backward compatibility.
        self.filters = []
        # cubic B-spline kernel, created in ``build``
        self.filter = None
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        Precompute the cubic B-spline kernel used by ``interpolate``.

        :param input_shape: tuple with the input shape
        """
        super().build(input_shape=input_shape)

        # the four cubic B-spline basis functions
        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        # kernel of shape (4*s0, 4*s1, 4*s2, 3, 3); each of the 3 displacement
        # components is only mapped onto itself (diagonal in the last two axes)
        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        # sampling positions within one control point interval, per axis
        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        # fill each of the 4x4x4 sub-blocks with the separable product of bases
        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        Up-sample control point values with the cubic B-spline kernel.

        :param field: tf.Tensor of shape (batch, cp_dim1, cp_dim2, cp_dim3, 3)
        :return: interpolated_field: tf.Tensor of shape
            (batch, (cp_dim1 + 3) * s0, (cp_dim2 + 3) * s1, (cp_dim3 + 3) * s2, 3)
            where (s0, s1, s2) is the control point spacing
        """
        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        # use the static batch size when known, otherwise the dynamic one,
        # since conv3d_transpose cannot take None in output_shape
        batch_size = field.shape[0]
        if batch_size is None:
            batch_size = tf.shape(field)[0]
        output_shape = (batch_size,) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Interpolate the control point field and crop it to the output shape.

        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor cropped along the spatial axes,
            starting at 3 * cp_spacing, to lengths output_shape[0:3]
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["cp_spacing"] = self.cp_spacing
        config["output_shape"] = self._output_shape
        return config
797