Passed
Pull Request — main (#656)
by Yunguan
02:46
created

deepreg.model.layer.IntDVF.__init__()   A

Complexity

Conditions 1

Size

Total Lines 20
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 11
dl 0
loc 20
rs 9.85
c 0
b 0
f 0
cc 1
nop 5
1
"""This module defines custom layers."""
2
import itertools
3
from typing import Tuple, Union
4
5
import numpy as np
6
import tensorflow as tf
7
import tensorflow.keras.layers as tfkl
8
9
import deepreg.model.layer_util as layer_util
10
11
# Keras layer classes that NormBlock / ResidualBlock can wrap, keyed by layer_name.
LAYER_DICT = {"conv3d": tfkl.Conv3D, "deconv3d": tfkl.Conv3DTranspose}
# Keras normalization layer classes, keyed by norm_name.
NORM_DICT = {"batch": tfkl.BatchNormalization, "layer": tfkl.LayerNormalization}
13
14
15
def _deconv_output_padding(
    input_shape: int, output_shape: int, kernel_size: int, stride: int, padding: str
) -> int:
    """
    Calculate output padding for Conv3DTranspose in 1D.

    A transposed convolution produces

    - output_shape = (input_shape - 1)*stride + kernel_size - 2*pad + output_padding

    hence the padding required to reach a target length is

    - output_padding = output_shape - ((input_shape - 1)*stride + kernel_size - 2*pad)

    with pad = kernel_size // 2 for "same", 0 for "valid" and
    kernel_size - 1 for "full".

    Reference:

    - https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/python/keras/utils/conv_utils.py#L140

    :param input_shape: shape of input tensor, without batch or channel
    :param output_shape: shape of out tensor, without batch or channel
    :param kernel_size: kernel size of Conv3DTranspose layer
    :param stride: stride of Conv3DTranspose layer
    :param padding: padding of Conv3DTranspose layer, "same", "valid" or "full"
    :return: output_padding
    :raises ValueError: if padding is not "same", "valid" or "full"
    """
    if padding == "same":
        pad = kernel_size // 2
    elif padding == "valid":
        pad = 0
    elif padding == "full":
        pad = kernel_size - 1
    else:
        raise ValueError(f"Unknown padding {padding} in deconv_output_padding")

    length_without_output_padding = (input_shape - 1) * stride + kernel_size - 2 * pad
    return output_shape - length_without_output_padding
44
45
46
def deconv_output_padding(
    input_shape: Union[Tuple[int, ...], int],
    output_shape: Union[Tuple[int, ...], int],
    kernel_size: Union[Tuple[int, ...], int],
    stride: Union[Tuple[int, ...], int],
    padding: str,
) -> Union[Tuple[int, ...], int]:
    """
    Calculate output padding for Conv3DTranspose in any dimension.

    When input_shape is an int the 1D result is returned as an int; otherwise
    one output padding per dimension is returned as a tuple. An int
    kernel_size / stride is broadcast to all dimensions.

    :param input_shape: shape of input tensor, without batch or channel
    :param output_shape: shape of out tensor, without batch or channel
    :param kernel_size: kernel size of Conv3DTranspose layer
    :param stride: stride of Conv3DTranspose layer
    :param padding: padding of Conv3DTranspose layer
    :return: output_padding, int or tuple with one entry per dimension
    :raises ValueError: if padding is not "same", "valid" or "full"
    """
    if isinstance(input_shape, int):
        return _deconv_output_padding(
            input_shape=input_shape,
            output_shape=output_shape,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
    assert len(input_shape) == len(output_shape), (
        "input_shape and output_shape must have the same number of dimensions, "
        f"got {input_shape} and {output_shape}"
    )
    dim = len(input_shape)
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size] * dim
    if isinstance(stride, int):
        stride = [stride] * dim
    return tuple(
        _deconv_output_padding(
            input_shape=input_shape[d],
            output_shape=output_shape[d],
            kernel_size=kernel_size[d],
            stride=stride[d],
            padding=padding,
        )
        for d in range(dim)
    )
89
90
91
class NormBlock(tfkl.Layer):
    """
    A block with layer - norm - activation.

    The wrapped layer is created with ``use_bias=False``.
    """

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: key of LAYER_DICT selecting the wrapped layer class,
            "conv3d" or "deconv3d".
        :param norm_name: key of NORM_DICT selecting the normalization layer,
            "batch" or "layer".
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments forwarded to the wrapped layer.
        """
        # Forward ``name`` so the layer is actually called ``name``; previously it
        # was only echoed by get_config() while Keras auto-named the layer.
        super().__init__(name=name)
        self._config = dict(
            layer_name=layer_name,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layer = LAYER_DICT[layer_name](use_bias=False, **kwargs)
        self._norm = NORM_DICT[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments (unused).
        :return: output of layer -> norm -> activation
        """
        output = self._layer(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
144
145
146
class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments forwarded to NormBlock
            (e.g. norm_name, activation, and Conv3D arguments such as filters).
        :raises ValueError: if kwargs contains a layer_name other than "conv3d".
        """
        # NormBlock.get_config() records layer_name, so recreating the block from
        # its config passes it back in; drop it to avoid a duplicate keyword.
        layer_name = kwargs.pop("layer_name", "conv3d")
        if layer_name != "conv3d":
            raise ValueError(
                f"Conv3dBlock always wraps a conv3d layer, got layer_name={layer_name}"
            )
        super().__init__(layer_name="conv3d", name=name, **kwargs)
163
164
165
class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments forwarded to NormBlock
            (e.g. norm_name, activation, and Conv3DTranspose arguments).
        :raises ValueError: if kwargs contains a layer_name other than "deconv3d".
        """
        # NormBlock.get_config() records layer_name, so recreating the block from
        # its config passes it back in; drop it to avoid a duplicate keyword.
        layer_name = kwargs.pop("layer_name", "deconv3d")
        if layer_name != "deconv3d":
            raise ValueError(
                "Deconv3dBlock always wraps a deconv3d layer, "
                f"got layer_name={layer_name}"
            )
        super().__init__(layer_name="deconv3d", name=name, **kwargs)
182
183
184
class Resize3d(tfkl.Layer):
    """
    Resize image in two folds.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
        **kwargs,
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        :param kwargs: additional arguments passed to tfkl.Layer; accepting them
            lets the layer be recreated from get_config(), which also contains
            base-layer keys such as trainable and dtype.
        """
        super().__init__(name=name, **kwargs)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform two fold resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                                     or (batch, dim1, dim2, dim3)
                                     or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                                or (batch, out_dim1, out_dim2, out_dim3)
                                or (out_dim1, out_dim2, out_dim3)
        :raises ValueError: if inputs is not of rank 3, 4 or 5
        """
        image = inputs
        image_dim = len(image.shape)

        # work out which optional axes are present
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape{}".format(image.shape)
            )

        # no need of resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config
299
300
301
class Warping(tfkl.Layer):
    """
    Warps an image with DDF.

    The image is resampled at (reference grid + ddf).

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Init.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments passed to tfkl.Layer.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        # prepend a batch axis of size 1: shape = (1, f_dim1, f_dim2, f_dim3, 3)
        reference_grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        self.grid_ref = reference_grid[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resample the image at the locations displaced by the ddf.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        sample_locations = self.grid_ref + ddf
        return layer_util.resample(vol=image, loc=sample_locations)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {
            **super().get_config(),
            "fixed_image_size": self._fixed_image_size,
        }
343
344
345
class ResidualBlock(tfkl.Layer):
    """
    A block with skip links and layer - norm - activation.

    The block input is added to the output of the last normalization, right
    before the last activation, so the block presumably requires the wrapped
    layers to preserve the input shape.
    """

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: key of LAYER_DICT selecting the wrapped layer class,
            "conv3d" or "deconv3d".
        :param num_layers: number of layers/blocks.
        :param norm_name: key of NORM_DICT selecting the normalization layer,
            "batch" or "layer".
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments forwarded to every wrapped layer.
        """
        # Forward ``name`` so the layer is actually called ``name``; previously it
        # was only echoed by get_config() while Keras auto-named the layer.
        super().__init__(name=name)
        self._num_layers = num_layers
        self._config = dict(
            layer_name=layer_name,
            num_layers=num_layers,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layers = [
            LAYER_DICT[layer_name](use_bias=False, **kwargs) for _ in range(num_layers)
        ]
        self._norms = [NORM_DICT[norm_name]() for _ in range(num_layers)]
        self._acts = [tfkl.Activation(activation=activation) for _ in range(num_layers)]

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments (unused).
        :return: output of the stacked layer -> norm -> activation blocks,
            with the skip connection added before the last activation
        """
        output = inputs
        last_index = self._num_layers - 1
        for i, (layer, norm, act) in enumerate(
            zip(self._layers, self._norms, self._acts)
        ):
            output = layer(inputs=output)
            output = norm(inputs=output, training=training)
            if i == last_index:
                # skip connection on the last block
                output = output + inputs
            output = act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
410
411
412
class ResidualConv3dBlock(ResidualBlock):
    """
    A conv3d residual block.
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments forwarded to ResidualBlock
            (e.g. num_layers, norm_name, activation, Conv3D arguments).
        :raises ValueError: if kwargs contains a layer_name other than "conv3d".
        """
        # ResidualBlock.get_config() records layer_name, so recreating the block
        # from its config passes it back in; drop it to avoid a duplicate keyword.
        layer_name = kwargs.pop("layer_name", "conv3d")
        if layer_name != "conv3d":
            raise ValueError(
                "ResidualConv3dBlock always wraps conv3d layers, "
                f"got layer_name={layer_name}"
            )
        super().__init__(layer_name="conv3d", name=name, **kwargs)
429
430
431
class IntDVF(tfkl.Layer):
    """
    Integrate DVF to get DDF.

    The DVF is scaled down by 2 ** num_steps, then the resulting displacement
    is composed with itself num_steps times (ddf <- ddf + warp(ddf, ddf)).

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(
        self,
        fixed_image_size: tuple,
        num_steps: int = 7,
        name: str = "int_dvf",
        **kwargs,
    ):
        """
        Init.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param name: name of the layer
        :param kwargs: additional arguments passed to tfkl.Layer.
        """
        super().__init__(name=name, **kwargs)
        assert len(fixed_image_size) == 3
        self._fixed_image_size = fixed_image_size
        self._num_steps = num_steps
        self._warping = Warping(fixed_image_size=fixed_image_size)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Integrate the velocity field.

        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        scale = 2 ** self._num_steps
        ddf = inputs / scale
        for _ in range(self._num_steps):
            # compose the current displacement with itself
            ddf = ddf + self._warping(inputs=[ddf, ddf])
        return ddf

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {
            **super().get_config(),
            "fixed_image_size": self._fixed_image_size,
            "num_steps": self._num_steps,
        }
479
480
481
class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """
    Up-sampling block used by LocalNet.

    The non-skip input is up-sampled with a stride-2 deconv3d block,
    optionally plus an additive up-sampling path, then the skip input is
    added and the sum goes through a conv3d residual block.
    """

    def __init__(
        self,
        filters: int,
        output_padding: tuple,
        output_shape: tuple,
        use_additive_upsampling: bool = True,
        **kwargs,
    ):
        """
        Layer up-samples tensor with two inputs (skipped and down-sampled).

        :param filters: int, number of output channels
        :param output_padding: output padding for deconv block
        :param output_shape: spatial shape of the output, (dim1, dim2, dim3)
        :param use_additive_upsampling: bool to used additive upsampling
        :param kwargs: additional arguments passed to tfkl.Layer.
        """
        super().__init__(**kwargs)
        # save parameters, also used by get_config
        self._use_additive_upsampling = use_additive_upsampling
        self._config = dict(
            filters=filters,
            output_padding=output_padding,
            output_shape=output_shape,
            use_additive_upsampling=use_additive_upsampling,
        )
        # init layer variables
        self._deconv3d_block = Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        if self._use_additive_upsampling:
            self._resize = Resize3d(shape=output_shape)
        # NOTE(review): this block is never called in ``call``; kept only so the
        # attribute remains available. Confirm whether it can be removed.
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = ResidualConv3dBlock(
            filters=filters, kernel_size=3, padding="same"
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments (unused).
        :return: output of the residual block applied to
            (up-sampled inputs_nonskip + inputs_skip)
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        h0 = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            # resize to output_shape, then sum the two channel halves; presumably
            # inputs_nonskip has 2 * filters channels so the sum matches h0
            upsampled = self._resize(inputs=inputs_nonskip)
            halves = tf.split(upsampled, num_or_size_splits=2, axis=4)
            h0 = h0 + tf.add_n(halves)
        r1 = h0 + inputs_skip
        return self._residual_block(inputs=r1, training=training)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
534
535
536
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of a image-to-image network.

    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: Union[list, tuple, int], **kwargs):
        """
        Init.

        :param control_point_spacing: list or int, control point spacing per
            spatial dimension; an int is used for all three dimensions.
        :param kwargs: additional arguments passed to tfkl.Layer.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        # Gaussian sigma per dimension; 0.44 = ln(4)/pi
        self.kernel_sigma = [0.44 * cp for cp in control_point_spacing]
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        """
        Create the Gaussian kernel and the resize layer.

        :param input_shape: shape of the input, (batch, dim1, dim2, dim3, channels)
        """
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        # ceil(dim / spacing) + 3 control points per dimension, kept as plain
        # Python ints so the target shape is static (graph-safe) and serializable
        output_shape = [
            int(np.ceil(v / c)) + 3
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Smooth the input with the Gaussian kernel, then resize it.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: smoothed tensor resized so its spatial shape is the control
            point grid computed in build()
        """
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["control_point_spacing"] = self.cp_spacing
        return config
577
578
579
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline kernel.

    The input control-point values are expected to come from downsampling a
    full sized field (not done by this layer). The layer then:

    1. performs the interpolation with a strided transposed convolution whose
       kernel holds the cubic B-spline weights;
    2. crops the interpolated field around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing to all dimensions is used
    :param output_shape: (dim0, dim1, dim2), the spatial shape of the high
        resolution deformation field returned by ``call``.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: Union[int, tuple], output_shape: tuple, **kwargs):
        """
        Init.

        :param cp_spacing: int or tuple of three ints, control point spacing.
        :param output_shape: (dim0, dim1, dim2), spatial shape of the output.
        :param kwargs: additional arguments passed to tfkl.Layer.
        """
        super().__init__(**kwargs)

        # not read anywhere in this class; the kernel is stored in self.filter
        self.filters = []
        # cubic B-spline kernel, created in build()
        self.filter = None
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        Precompute the cubic B-spline kernel.

        :param input_shape: tuple with the input shape
        """
        super().build(input_shape=input_shape)

        # cubic B-spline basis functions B0..B3
        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        # local coordinates: one minus the centres of the cp_spacing sub-intervals of [0, 1]
        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        # each of the 4x4x4 sub-cubes of the kernel holds the product of one
        # basis function per axis; the channel mapping is diagonal so every
        # displacement component is interpolated independently
        for f_idx in itertools.product(range(4), repeat=3):
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        Interpolate the control-point field with the B-spline kernel.

        :param field: tf.Tensor, shape = (batch, cp_dim0, cp_dim1, cp_dim2, 3)
        :return: interpolated_field: tf.Tensor, shape =
            (batch, (cp_dim0 + 3) * spacing0, (cp_dim1 + 3) * spacing1,
            (cp_dim2 + 3) * spacing2, 3)
        """
        # VALID transposed conv output length: (n - 1) * stride + kernel,
        # with kernel = 4 * stride
        image_shape = tuple(
            (a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)
        )

        # dynamic batch size, so this also works when the static batch is None
        batch_size = tf.shape(field)[0]
        output_shape = [batch_size, *image_shape, 3]
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Interpolate then crop the field.

        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor of shape
            (batch, *self._output_shape, 3)
        """
        high_res_field = self.interpolate(inputs)

        # skip the first 3 control-point spacings along each axis
        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["cp_spacing"] = self.cp_spacing
        config["output_shape"] = self._output_shape
        return config
696