Passed
Pull Request — main (#656)
by Yunguan
02:56
created

deepreg.model.layer.Warping.get_config()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""This module defines custom layers."""
2
import itertools
3
4
import numpy as np
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
import deepreg.model.layer_util as layer_util
9
10
11
class Deconv3d(tfkl.Layer):
    """
    Wrap Conv3DTranspose to allow dynamic output padding calculation.

    The transposed convolution is only created in ``build``, because the
    output padding required to reach ``output_shape`` depends on the input shape.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        name="deconv3d",
        **kwargs,
    ):
        """
        Init, only record the arguments.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3), or None to skip the
            output padding calculation
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid; must be same when output_shape is given.
        :param name: name of the layer.
        :param kwargs: additional arguments for Conv3DTranspose.
        """
        super().__init__(name=name)
        # arguments are needed again in build() and get_config()
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._kwargs = kwargs
        # the wrapped Conv3DTranspose, created in build()
        self._deconv3d = None

    def build(self, input_shape):
        # pylint: disable-next=line-too-long
        """
        Create the wrapped Conv3DTranspose, calculating output padding on the fly.

        The padding inverts the transposed-convolution output length formula of
        https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185

        :param input_shape: shape of input, spatial sizes are read from axes 1 to 3
        """
        super().build(input_shape)

        def _as_triple(value):
            # a single int applies to all three spatial dimensions
            return [value] * 3 if isinstance(value, int) else value

        kernel_sizes = _as_triple(self._kernel_size)
        stride_sizes = _as_triple(self._strides)

        output_padding = None
        if self._output_shape is not None:
            # the formula below assumes "same" padding, i.e. pad = kernel // 2
            assert self._padding == "same"
            output_padding = []
            for axis in range(3):
                kernel = kernel_sizes[axis]
                # output length of the transposed conv without output padding
                unpadded_size = (
                    (input_shape[1 + axis] - 1) * stride_sizes[axis]
                    + kernel
                    - 2 * (kernel // 2)
                )
                output_padding.append(self._output_shape[axis] - unpadded_size)

        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=output_padding,
            **self._kwargs,
        )

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Forward through the wrapped Conv3DTranspose.

        :param inputs: input tensor.
        :param kwargs: additional arguments, unused.
        :return: output of the transposed convolution.
        """
        deconv = self._deconv3d
        return deconv(inputs=inputs)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            filters=self._filters,
            output_shape=self._output_shape,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
        )
        # extra Conv3DTranspose arguments are stored flat, like in __init__
        config.update(self._kwargs)
        return config
111
112
113
class NormBlock(tfkl.Layer):
    """
    A block with layer - norm - activation.
    """

    layer_cls_dict = dict(conv3d=tfkl.Conv3D, deconv3d=Deconv3d)
    norm_cls_dict = dict(batch=tfkl.BatchNormalization, layer=tfkl.LayerNormalization)

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: key of ``layer_cls_dict`` selecting the wrapped layer,
            "conv3d" or "deconv3d".
        :param norm_name: key of ``norm_cls_dict`` selecting the normalization
            layer, "batch" or "layer".
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments for the wrapped layer; it is always
            created with ``use_bias=False``, so ``use_bias`` must not be passed.
        :raises KeyError: if layer_name or norm_name is not a known key.
        """
        # forward the name, otherwise the documented ``name`` argument is ignored
        super().__init__(name=name)
        # every constructor argument, returned by get_config()
        self._config = dict(
            layer_name=layer_name,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        # bias disabled, presumably because the normalization right after it
        # would cancel a bias anyway
        self._layer = self.layer_cls_dict[layer_name](use_bias=False, **kwargs)
        self._norm = self.norm_cls_dict[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: activation(norm(layer(inputs)))
        """
        output = self._layer(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
169
170
171
class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to NormBlock.
        :raises ValueError: if ``layer_name`` is given and is not "conv3d".
        """
        # NormBlock.get_config() stores layer_name, so from_config() passes it
        # back in; drop it to avoid giving layer_name twice to NormBlock.
        layer_name = kwargs.pop("layer_name", "conv3d")
        if layer_name != "conv3d":
            raise ValueError(
                "Conv3dBlock only wraps conv3d layers, "
                "got layer_name={}".format(layer_name)
            )
        super().__init__(layer_name="conv3d", name=name, **kwargs)
188
189
190
class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to NormBlock.
        :raises ValueError: if ``layer_name`` is given and is not "deconv3d".
        """
        # NormBlock.get_config() stores layer_name, so from_config() passes it
        # back in; drop it to avoid giving layer_name twice to NormBlock.
        layer_name = kwargs.pop("layer_name", "deconv3d")
        if layer_name != "deconv3d":
            raise ValueError(
                "Deconv3dBlock only wraps deconv3d layers, "
                "got layer_name={}".format(layer_name)
            )
        super().__init__(layer_name="deconv3d", name=name, **kwargs)
207
208
209
class Resize3d(tfkl.Layer):
    """
    Resize image in two folds.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
        **kwargs,
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3), the target spatial shape
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        :param kwargs: additional arguments for tfkl.Layer, e.g. the
            ``trainable`` and ``dtype`` entries returned by ``get_config``,
            so that ``from_config`` can recreate the layer.
        """
        super().__init__(name=name, **kwargs)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform two fold resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                                     or (batch, dim1, dim2, dim3)
                                     or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                                or (batch, out_dim1, out_dim2, out_dim3)
                                or (out_dim1, out_dim2, out_dim3)
        :raises ValueError: if inputs do not have 3, 4 or 5 dimensions.
        """
        # sanity check
        image = inputs
        image_dim = len(image.shape)

        # init; a 4D input is treated as batched without a channel axis
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape {}".format(image.shape)
            )

        # no need of resize, the input is returned unchanged
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config
324
325
326
class Warping(tfkl.Layer):
    """
    Warps an image with DDF.

    Sampling locations are the reference grid of the fixed image shifted by
    the DDF.

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Init, precompute the reference grid.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        # add a leading batch axis: shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = grid[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resample the image at the reference grid displaced by the DDF.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        sample_locations = self.grid_ref + ddf
        return layer_util.resample(vol=image, loc=sample_locations)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        base_config = super().get_config()
        return {**base_config, "fixed_image_size": self._fixed_image_size}
368
369
370
class ResidualBlock(tfkl.Layer):
    """
    A block with skip links and layer - norm - activation.

    The block input is added to the last normalized output, before the last
    activation.
    """

    layer_cls_dict = dict(conv3d=tfkl.Conv3D, deconv3d=Deconv3d)
    norm_cls_dict = dict(batch=tfkl.BatchNormalization, layer=tfkl.LayerNormalization)

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: key of ``layer_cls_dict`` selecting the wrapped layer,
            "conv3d" or "deconv3d".
        :param num_layers: number of layers/blocks.
        :param norm_name: key of ``norm_cls_dict`` selecting the normalization
            layer, "batch" or "layer".
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments for every wrapped layer; they are
            always created with ``use_bias=False``.
        :raises KeyError: if layer_name or norm_name is not a known key.
        """
        # forward the name, otherwise the documented ``name`` argument is ignored
        super().__init__(name=name)
        self._num_layers = num_layers
        # every constructor argument, returned by get_config()
        self._config = dict(
            layer_name=layer_name,
            num_layers=num_layers,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        # NOTE(review): some tf.keras versions also use ``_layers`` internally
        # to track sub-layers; renaming it would avoid the clash but changes
        # checkpoint attribute paths, so it is kept as is for now.
        self._layers = [
            self.layer_cls_dict[layer_name](use_bias=False, **kwargs)
            for _ in range(num_layers)
        ]
        self._norms = [self.norm_cls_dict[norm_name]() for _ in range(num_layers)]
        self._acts = [tfkl.Activation(activation=activation) for _ in range(num_layers)]

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer; they are added to the last
            normalized output, so the shapes must match.
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output tensor of the block.
        """
        output = inputs
        for i in range(self._num_layers):
            output = self._layers[i](inputs=output)
            output = self._norms[i](inputs=output, training=training)
            if i == self._num_layers - 1:
                # last block: add the skip connection before the activation
                output = output + inputs
            output = self._acts[i](output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
439
440
441
class ResidualConv3dBlock(ResidualBlock):
    """
    A conv3d residual block
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to ResidualBlock.
        :raises ValueError: if ``layer_name`` is given and is not "conv3d".
        """
        # ResidualBlock.get_config() stores layer_name, so from_config() passes
        # it back in; drop it to avoid giving layer_name twice to ResidualBlock.
        layer_name = kwargs.pop("layer_name", "conv3d")
        if layer_name != "conv3d":
            raise ValueError(
                "ResidualConv3dBlock only wraps conv3d layers, "
                "got layer_name={}".format(layer_name)
            )
        super().__init__(layer_name="conv3d", name=name, **kwargs)
458
459
460
class UpSampleResnetBlock(tfkl.Layer):
    """
    An up-sampling resnet conv3d block, with deconv3d.

    deconv3d block (up-sample to the skip tensor spatial shape)
    -> combine with the skip tensor (concatenation or sum)
    -> conv3d block -> residual conv3d block.
    """

    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        Init.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3, used by
            the conv3d and residual blocks; the deconv3d block always uses 3.
        :param concat: bool,specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments for tfkl.Layer.
        """
        super().__init__(**kwargs)
        # save parameters, also returned by get_config()
        self._filters = filters
        self._kernel_size = kernel_size
        self._concat = concat
        # init layer variables; the deconv3d block needs the skip shape, see build()
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = ResidualConv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )

    def build(self, input_shape):
        """
        Create the deconv3d block, whose output shape matches the skip tensor.

        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        # spatial shape (dim1, dim2, dim3) of the skip tensor
        skip_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=skip_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        Forward.

        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        up_sampled, skip = inputs[0], inputs[1]
        up_sampled = self._deconv3d_block(
            inputs=up_sampled, training=training
        )  # up sample and change channel
        # with sum, skip presumably must have ``filters`` channels — confirm at callers
        up_sampled = (
            tf.concat([up_sampled, skip], axis=4) if self._concat else up_sampled + skip
        )  # combine
        up_sampled = self._conv3d_block(
            inputs=up_sampled, training=training
        )  # adjust channel
        up_sampled = self._residual_block(inputs=up_sampled, training=training)  # conv
        return up_sampled

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            filters=self._filters,
            kernel_size=self._kernel_size,
            concat=self._concat,
        )
        return config
521
522
523
class IntDVF(tfkl.Layer):
    """
    Integrate DVF to get DDF.

    The DVF is scaled down by ``2 ** num_steps`` and then composed with itself
    ``num_steps`` times (scaling and squaring).

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(
        self,
        fixed_image_size: tuple,
        num_steps: int = 7,
        name: str = "int_dvf",
        **kwargs,
    ):
        """
        Init.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        assert len(fixed_image_size) == 3
        self._fixed_image_size = fixed_image_size
        self._num_steps = num_steps
        # warps the current displacement field with itself at each step
        self._warping = Warping(fixed_image_size=fixed_image_size)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Integrate the DVF by scaling and squaring.

        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        scaling = 2 ** self._num_steps
        ddf = inputs / scaling
        for _step in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            fixed_image_size=self._fixed_image_size,
            num_steps=self._num_steps,
        )
        return config
570
571
572
class LocalNetResidual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block, simpler than ResidualConv3dBlock.

    Given inputs = [x, skip]:

    1. conved = conv3d(x)
    2. out = act(norm(conved) + skip)
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments for tfkl.Layer.
        """
        super().__init__(**kwargs)
        # arguments kept for get_config()
        self._filters = filters
        self._kernel_size = kernel_size
        self._strides = strides
        self._activation = activation
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: [x, skip]; x goes through conv3d and norm, skip is added
            before the activation, so shapes must match after the conv.
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: act(norm(conv3d(x)) + skip)
        """
        conved = self._conv3d(inputs=inputs[0])
        normed = self._norm(inputs=conved, training=training)
        return self._act(normed + inputs[1])

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            activation=self._activation,
        )
        return config
610
611
612
class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """
    Layer up-samples tensor with two inputs (skipped and down-sampled).

    The non-skip tensor is up-sampled by a deconv3d block to the spatial shape
    of the skip tensor, optionally plus an additive up-sampling branch, then
    merged with the skip tensor through a conv3d block and a residual block.
    """

    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Init.

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool to used additive upsampling
        :param kwargs: additional arguments for tfkl.Layer.
        """
        super().__init__(**kwargs)
        # save parameters, also returned by get_config()
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # init layer variables; the deconv and resize layers need the skip
        # shape and are created in build()
        self._deconv3d_block = None
        self._resize = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        """
        Create the layers depending on the skip tensor shape.

        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        # spatial shape (dim1, dim2, dim3) of the skip tensor
        output_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=output_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        if self._use_additive_upsampling:
            self._resize = Resize3d(shape=output_shape)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output of the residual block, with ``filters`` channels
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        h0 = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            # resize, then sum the two channel halves; presumably the non-skip
            # input has 2 * filters channels so the sum matches h0 — confirm
            upsampled = self._resize(inputs=inputs_nonskip)
            upsampled = tf.split(upsampled, num_or_size_splits=2, axis=4)
            upsampled = tf.add_n(upsampled)
            h0 = h0 + upsampled
        r1 = h0 + inputs_skip
        r2 = self._conv3d_block(inputs=h0, training=training)
        h1 = self._residual_block(inputs=[r2, r1], training=training)
        return h1

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            filters=self._filters,
            use_additive_upsampling=self._use_additive_upsampling,
        )
        return config
666
667
668
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of a image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        Init.

        :param control_point_spacing: list or int, spacing between control
            points along each spatial dimension; an int applies to all three.
        :param kwargs: additional arguments for tfkl.Layer.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        """
        Create the Gaussian kernel and the resize layer.

        :param input_shape: shape of input, spatial sizes are read from axes 1 to -2
        """
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        # ceil(size / spacing) + 3 control points per dimension; computed as
        # plain Python ints so the target shape stays static and serializable
        output_shape = [
            int(np.ceil(v / c)) + 3
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Smooth the inputs with the Gaussian kernel, then resize them.

        :param inputs: 5D tensor, (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: resized tensor with the control point grid spatial shape
        """
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["control_point_spacing"] = self.cp_spacing
        return config
709
710
711
class BSplines3DTransform(tfkl.Layer):
    """
     Layer for BSplines interpolation with precomputed cubic spline filters.
     It assumes a full sized image from which:
     1. it compute the contol points values by downsampling the initial image
     2. performs the interpolation
     3. crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing to all dimensions is used
    :param output_shape: (dim0, dim1, dim2), the spatial shape of the high
        resolution deformation field; entries 0 to 2 are used as crop sizes.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):

        super().__init__(**kwargs)

        # interpolation kernel, created in build()
        self.filter = None
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        Precompute the B-spline interpolation kernel.

        :param input_shape: tuple with the input shape
        """

        super().build(input_shape=input_shape)

        # the four uniform cubic B-spline basis functions
        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        # kernel of 4 * cp_spacing per spatial axis, 3 input and 3 output channels
        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        # cp_spacing sample positions at voxel centres within one interval, as 1 - t
        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        # each cp_spacing-sized sub-block holds the product of one basis function
        # per axis; the kernel is diagonal over channels, so each displacement
        # component is interpolated independently
        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        Interpolate the control point field with a transposed convolution.

        :param field: tf.Tensor with shape=number_of_control_points_per_dim,
            i.e. (batch, n0, n1, n2, 3)
        :return: interpolated_field: tf.Tensor of shape
            (batch, (n0 + 3) * cp0, (n1 + 3) * cp1, (n2 + 3) * cp2, 3)
        """

        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Interpolate the control points and crop to the requested shape.

        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor of shape
            (batch, output_shape[0], output_shape[1], output_shape[2], 3)
        """
        high_res_field = self.interpolate(inputs)

        # skip the first 3 control point intervals of each axis
        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["cp_spacing"] = self.cp_spacing
        config["output_shape"] = self._output_shape
        return config
828