Passed
Pull Request — main (#656)
by Yunguan
03:08
created

deepreg.model.layer.Deconv3d.__init__()   A

Complexity

Conditions 1

Size

Total Lines 33
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 33
rs 9.45
c 0
b 0
f 0
cc 1
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
import itertools
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
import numpy as np
4
import tensorflow as tf
5
import tensorflow.keras.layers as tfkl
6
7
import deepreg.model.layer_util as layer_util
8
9
10
class Deconv3d(tfkl.Layer):
    """
    Wrap Conv3DTranspose to allow dynamic output padding calculation.

    The wrapped Conv3DTranspose is only created in ``build``, once the input
    shape is known, so that ``output_padding`` can be derived from the
    requested ``output_shape``.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        use_bias: bool = True,
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param use_bias: use bias for Conv3DTranspose or not.
        :param kwargs: additional arguments, forwarded to Conv3DTranspose
            when it is created in ``build``.
        """
        super().__init__()
        # save arguments
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._use_bias = use_bias
        self._kwargs = kwargs
        # init layer variables, both are filled in by build()
        self._output_padding = None
        self._deconv3d = None

    def build(self, input_shape):
        # pylint: disable-next=line-too-long
        """
        Calculate output padding on the fly and create the Conv3DTranspose.

        https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185
        When the output shape is defined, the padding should be calculated manually
        if padding == 'same':
            pad = filter_size // 2
            length = ((input_length - 1) * stride + filter_size
                     - 2 * pad + output_padding)

        :param input_shape: shape of the input, entries 1 to 3 are read as
            the sizes of the three spatial dimensions.
        """
        super().build(input_shape)

        # normalise scalar arguments to one value per spatial dimension
        if isinstance(self._kernel_size, int):
            self._kernel_size = [self._kernel_size] * 3
        if isinstance(self._strides, int):
            self._strides = [self._strides] * 3

        if self._output_shape is not None:
            # the formula below is only valid for "same" padding,
            # so the padding argument is overridden
            self._padding = "same"
            self._output_padding = [
                self._output_shape[i]
                - (
                    (input_shape[1 + i] - 1) * self._strides[i]
                    + self._kernel_size[i]
                    - 2 * (self._kernel_size[i] // 2)
                )
                for i in range(3)
            ]
        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=self._output_padding,
            use_bias=self._use_bias,
            **self._kwargs,
        )

    def call(self, inputs: tf.Tensor, **kwargs):
        """
        Forward.

        :param inputs: input tensor.
        :param kwargs: additional arguments, not used.
        :return: output of the wrapped Conv3DTranspose layer.
        """
        return self._deconv3d(inputs=inputs)

    def get_config(self) -> dict:
        """
        Return the config dictionary for recreating this class.

        :return: the base layer config, extended with the constructor arguments
            and the extra keyword arguments saved in ``__init__``.
        """
        # get_config must be called, the bound method itself has no update()
        config = super().get_config()
        config.update(
            dict(
                filters=self._filters,
                output_shape=self._output_shape,
                kernel_size=self._kernel_size,
                strides=self._strides,
                padding=self._padding,
                use_bias=self._use_bias,
            )
        )
        config.update(self._kwargs)
        return config
113
114
115
class Conv3dBlock(tfkl.Layer):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        activation: str = "relu",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: str, same or valid
        :param activation: name of activation
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        # save arguments, they are reported by get_config()
        self._filters = filters
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._activation = activation

        # init layer variables
        # the convolution has no bias as it is followed by batch normalization
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs: tf.Tensor, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output of conv3d - norm - activation applied to the inputs
        """
        output = self._conv3d(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output

    def get_config(self) -> dict:
        """
        Return the config dictionary for recreating this class.

        Only constructor arguments are reported; ``use_bias`` is fixed to
        False in ``__init__`` and is not a constructor argument.

        :return: the base layer config extended with the constructor arguments.
        """
        config = super().get_config()
        config.update(
            dict(
                filters=self._filters,
                kernel_size=self._kernel_size,
                strides=self._strides,
                padding=self._padding,
                activation=self._activation,
            )
        )
        return config
183
184
185
class Deconv3dBlock(tfkl.Layer):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        activation: str = "relu",
        **kwargs,
    ):
        """
        Build the three sub-layers of the block.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: str, same or valid
        :param activation: name of activation
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        # the transposed convolution has no bias, a normalization follows it
        deconv_args = dict(
            filters=filters,
            output_shape=output_shape,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        self._deconv3d = Deconv3d(**deconv_args)
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Apply deconv3d, normalization and activation in sequence.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: activated tensor produced from the deconvolved inputs
        """
        deconved = self._deconv3d(inputs=inputs)
        normed = self._norm(inputs=deconved, training=training)
        return self._act(normed)
231
232
233
class Residual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block.

    1. conved = conv3d(conv3d_block(inputs))
    2. out = act(norm(conved) + inputs)
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        Create the sub-layers of the residual block.

        :param filters: int, number of filters in the convolutional layers
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        # both convolutions share filters, kernel size and strides
        shared_args = dict(filters=filters, kernel_size=kernel_size, strides=strides)
        self._conv3d_block = Conv3dBlock(**shared_args)
        self._conv3d = tfkl.Conv3D(padding="same", use_bias=False, **shared_args)
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Run the two convolutions and add the inputs before the activation.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: activated sum of the normalized convolution output and inputs
        """
        hidden = self._conv3d_block(inputs)
        conved = self._conv3d(inputs=hidden)
        normed = self._norm(inputs=conved, training=training)
        return self._act(normed + inputs)
283
284
285
class DownSampleResnetBlock(tfkl.Layer):
    """
    A down-sampling resnet conv3d block, with max-pooling or conv3d.

    1. conved = conv3d_block(inputs)  # adjust channel
    2. skip = residual_block(conved)  # develop feature
    3. pooled = pool(skip) # down-sample
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        pooling: bool = True,
        **kwargs,
    ):
        """
        Create the convolution, residual and down-sampling sub-layers.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param pooling: if True, use max pooling to downsample, otherwise use conv.
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._pooling = pooling
        # init layer variables
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=kernel_size)
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)
        # only one of the two down-sampling layers is created
        self._max_pool3d = None
        self._conv3d_block3 = None
        if pooling:
            self._max_pool3d = tfkl.MaxPool3D(
                pool_size=(2, 2, 2), strides=(2, 2, 2), padding="same"
            )
        else:
            self._conv3d_block3 = Conv3dBlock(
                filters=filters, kernel_size=kernel_size, strides=2
            )

    def call(self, inputs, training=None, **kwargs) -> (tf.Tensor, tf.Tensor):
        """
        Compute the skip tensor and its down-sampled version.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: (pooled, skip)

          - downsampled, shape = (batch, in_dim1//2, in_dim2//2, in_dim3//2, channels)
          - skipped, shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        # adjust channel
        adjusted = self._conv3d_block(inputs=inputs, training=training)
        # develop feature
        skip = self._residual_block(inputs=adjusted, training=training)
        # downsample
        if self._pooling:
            pooled = self._max_pool3d(inputs=skip)
        else:
            pooled = self._conv3d_block3(inputs=skip, training=training)
        return pooled, skip
340
341
342
class UpSampleResnetBlock(tfkl.Layer):
    """
    An up-sampling resnet conv3d block, with deconv3d.
    """

    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        Create the sub-layers that do not depend on the input shape.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool,specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._concat = concat
        # the deconv block needs the skip shape, it is created in build()
        self._deconv3d_block = None
        block_args = dict(filters=filters, kernel_size=kernel_size)
        self._conv3d_block = Conv3dBlock(**block_args)
        self._residual_block = Residual3dBlock(**block_args)

    def build(self, input_shape):
        """
        Create the deconv block targeting the skip tensor's spatial shape.

        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        skip_image_shape = input_shape[1]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=skip_image_shape[1:4],
            strides=2,
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        Up-sample the first input, combine it with the skip and convolve.

        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        down_sampled, skip = inputs[0], inputs[1]
        # up sample and change channel
        up_sampled = self._deconv3d_block(inputs=down_sampled, training=training)
        # combine
        if self._concat:
            combined = tf.concat([up_sampled, skip], axis=4)
        else:
            combined = up_sampled + skip
        # adjust channel
        adjusted = self._conv3d_block(inputs=combined, training=training)
        # conv
        return self._residual_block(inputs=adjusted, training=training)
395
396
397
class Conv3dWithResize(tfkl.Layer):
    """
    A layer contains conv3d - resize3d.
    """

    def __init__(
        self,
        output_shape: tuple,
        filters: int,
        kernel_initializer: str = "glorot_uniform",
        activation: (str, None) = None,
        **kwargs,
    ):
        """
        Create the convolution and remember the target size for resizing.

        :param output_shape: tuple, (out_dim1, out_dim2, out_dim3)
        :param filters: int, number of channels of the output
        :param kernel_initializer: str, defines the initialization method
        :param activation: str, defines the activation function
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._output_shape = output_shape
        # init layer variables
        # if not zero, with init NN, ddf may be too large
        conv_args = dict(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding="same",
            kernel_initializer=kernel_initializer,
            activation=activation,
        )
        self._conv3d = tfkl.Conv3D(**conv_args)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Convolve the inputs, then resize the result to the stored output shape.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
        """
        conved = self._conv3d(inputs=inputs)
        return layer_util.resize3d(image=conved, size=self._output_shape)
437
438
439
class Warping(tfkl.Layer):
    """
    A layer warps an image using DDF.

    Reference:

    - transform of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, **kwargs):
        """
        Pre-compute the reference grid of the fixed image.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        reference_grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        # add a leading axis, shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = reference_grid[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resample the image at the reference grid shifted by the DDF.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
          - image, shape = (batch, m_dim1, m_dim2, m_dim3), dtype = float32
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs[0], inputs[1]
        sampling_locations = self.grid_ref + ddf
        return layer_util.resample(vol=image, loc=sampling_locations)
470
471
472
class IntDVF(tfkl.Layer):
    """
    Layer calculates DDF from DVF.

    The DVF is divided by 2 ** num_steps, then for num_steps iterations the
    field warped by itself is added to the field.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(self, fixed_image_size: tuple, num_steps: int = 7, **kwargs):
        """
        Create the warping layer used by the integration.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        self._warping = Warping(fixed_image_size=fixed_image_size)
        self._num_steps = num_steps

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Integrate the DVF into a DDF.

        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), type = float32
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        num_steps = self._num_steps
        ddf = inputs / (2 ** num_steps)
        step = 0
        while step < num_steps:
            # the field is used both as the displacement and the warped image
            ddf = ddf + self._warping(inputs=[ddf, ddf])
            step += 1
        return ddf
500
501
502
class AdditiveUpSampling(tfkl.Layer):
    """
    Layer up-samples 3d tensor and reduce channels using split and sum.
    """

    def __init__(self, output_shape: tuple, stride: (int, list) = 2, **kwargs):
        """
        Store the target shape and the channel reduction factor.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: int, 1-D Tensor or list
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._output_shape = output_shape
        self._stride = stride

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resize the inputs, split the channels into groups and sum the groups.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        :raises ValueError: if the channel dimension is not divisible by stride
        """
        num_channels = inputs.shape[4]
        if num_channels % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        resized = layer_util.resize3d(image=inputs, size=self._output_shape)
        # a list of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        groups = tf.split(resized, num_or_size_splits=self._stride, axis=4)
        # stack the groups on a new last axis and sum over it,
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        stacked = tf.stack(groups, axis=5)
        return tf.reduce_sum(stacked, axis=5)
530
531
532
class LocalNetResidual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block, simpler than Residual3dBlock.

    1. conved = conv3d(inputs)
    2. out = act(norm(conved) + inputs)
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        Create the convolution, normalization and activation sub-layers.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        # the convolution has no bias, a normalization follows it
        conv_args = dict(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._conv3d = tfkl.Conv3D(**conv_args)
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Convolve and normalize the first input, add the second, then activate.

        :param inputs: a pair, inputs[0] is convolved and inputs[1] is added
            to the normalized result
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: activated sum of the normalized convolution and inputs[1]
        """
        conved = self._conv3d(inputs=inputs[0])
        normed = self._norm(inputs=conved, training=training)
        return self._act(normed + inputs[1])
570
571
572
class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """
    Layer up-samples tensor with two inputs (skipped and down-sampled).
    """

    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Create the sub-layers that do not depend on the input shape.

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool to used additive upsampling
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # these two need the skip tensor shape, they are created in build()
        self._deconv3d_block = None
        self._additive_upsampling = None
        # init layer variables
        self._conv3d_block = Conv3dBlock(filters=filters)
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        """
        Create the up-sampling layers targeting the skip tensor's spatial shape.

        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        skip_tensor_shape = input_shape[1]
        spatial_shape = skip_tensor_shape[1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters, output_shape=spatial_shape, strides=2
        )
        if not self._use_additive_upsampling:
            return
        self._additive_upsampling = AdditiveUpSampling(output_shape=spatial_shape)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Up-sample the non-skip input and merge it with the skip input.

        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output of the residual block
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        up_sampled = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            up_sampled = up_sampled + self._additive_upsampling(inputs=inputs_nonskip)
        # the skip input only enters through the residual branch
        residual = up_sampled + inputs_skip
        conved = self._conv3d_block(inputs=up_sampled, training=training)
        return self._residual_block(inputs=[conved, residual], training=training)
619
620
621
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of a image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        Store the spacing and derive the Gaussian sigma of each dimension.

        :param control_point_spacing: list or int
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)

        # a single int is used for all three dimensions
        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        # 0.44 = ln(4)/pi
        self.kernel_sigma = [0.44 * spacing for spacing in control_point_spacing]
        self.cp_spacing = control_point_spacing
        # both are set in build()
        self.kernel = None
        self._output_shape = None

    def build(self, input_shape):
        """
        Create the Gaussian kernel and compute the control point grid size.

        :param input_shape: shape of the input, the entries between the first
            and the last one are divided by the control point spacing.
        """
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        control_grid_shape = []
        for dim_size, spacing in zip(input_shape[1:-1], self.cp_spacing):
            num_points = tf.math.ceil(dim_size / spacing) + 3
            control_grid_shape.append(tf.cast(num_points, tf.int32))
        self._output_shape = control_grid_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Filter the inputs with the Gaussian kernel and resize the result.

        :param inputs: tensor passed to tf.nn.conv3d with the Gaussian kernel
        :param kwargs: additional arguments.
        :return: filtered tensor resized to the size computed in build()
        """
        smoothed = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        return layer_util.resize3d(image=smoothed, size=self._output_shape)
659
660
661
class BSplines3DTransform(tfkl.Layer):
    """
     Layer for BSplines interpolation with precomputed cubic spline filters.
     It assumes a full sized image from which:
     1. it compute the contol points values by downsampling the initial image
     2. performs the interpolation
     3. crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing to all dimensions is used
    :param output_shape: (batch_size, dim0, dim1, dim2, 3) of the high resolution
        deformation fields.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):
        """
        Store the control point spacing and the output shape.

        :param cp_spacing: int or tuple of three ints, a single int is
            expanded to the same spacing for all three dimensions.
        :param output_shape: sizes used in call() to crop the interpolated
            field, entries 0 to 2 are read.
        :param kwargs: additional arguments, passed to the base layer.
        """
        super().__init__(**kwargs)

        # kept for backward compatibility, not used by this class
        self.filters = []
        # the interpolation filter, created in build()
        self.filter = None
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        Precompute the cubic B-spline filter used by the transposed convolution.

        :param input_shape: tuple with the input shape
        """

        super().build(input_shape=input_shape)

        # the four cubic B-spline basis functions
        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        spacing = self.cp_spacing
        filters = np.zeros(
            (4 * spacing[0], 4 * spacing[1], 4 * spacing[2], 3, 3),
            dtype=np.float32,
        )

        # positions at which the basis functions are evaluated, per dimension
        u_arange = 1 - np.arange(1 / (2 * spacing[0]), 1, 1 / spacing[0])
        v_arange = 1 - np.arange(1 / (2 * spacing[1]), 1, 1 / spacing[1])
        w_arange = 1 - np.arange(1 / (2 * spacing[2]), 1, 1 / spacing[2])

        for f_idx in itertools.product(range(4), repeat=3):
            # the weights do not depend on the channel, compute them once
            weights = (
                b[f_idx[0]](u_arange)[:, None, None]
                * b[f_idx[1]](v_arange)[None, :, None]
                * b[f_idx[2]](w_arange)[None, None, :]
            )
            block = tuple(
                slice(idx * step, (idx + 1) * step)
                for idx, step in zip(f_idx, spacing)
            )
            # only the diagonal of the last two axes is filled,
            # each channel is interpolated independently
            for it_dim in range(3):
                filters[block + (it_dim, it_dim)] = weights

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        Interpolate a control point field with a transposed convolution.

        :param field: tf.Tensor with shape=number_of_control_points_per_dim
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            (a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Interpolate the inputs and crop the result to the output shape.

        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor, cropped along dimensions 1 to 3
            to the sizes stored in output_shape
        """
        high_res_field = self.interpolate(inputs)

        # the crop starts at three times the spacing in each dimension
        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]
778