Passed
Pull Request — main (#656)
by Yunguan
03:01
created

deepreg.model.layer.MaxPool3d.__init__()   A

Complexity

Conditions 1

Size

Total Lines 18
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 18
rs 9.95
c 0
b 0
f 0
cc 1
nop 5
1
"""This module defines custom layers."""
2
import itertools
3
4
import numpy as np
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
import deepreg.model.layer_util as layer_util
9
10
11
class Deconv3d(tfkl.Layer):
    """
    Wrap Conv3DTranspose to allow dynamic output padding calculation.

    The wrapped ``tfkl.Conv3DTranspose`` is created lazily in :meth:`build`,
    because the output padding needed to reach ``output_shape`` depends on
    the spatial shape of the input.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        name: str = "deconv3d",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3); if None, no output
            padding is requested and Conv3DTranspose decides the output size
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid; must be "same" when output_shape is given
        :param name: name of the layer.
        :param kwargs: additional arguments for Conv3DTranspose.
        """
        super().__init__(name=name)
        # save arguments
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._kwargs = kwargs
        # init layer variables, the wrapped layer is created in build()
        self._deconv3d = None

    @staticmethod
    def _to_triple(value):
        """Expand an int into a list of three copies; return anything else unchanged."""
        if isinstance(value, int):
            return [value] * 3
        return value

    def build(self, input_shape):
        # pylint: disable-next=line-too-long
        """
        Calculate output padding on the fly.

        https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185

        With "same" padding the transposed convolution yields, per spatial axis,
        ``(in_dim - 1) * stride + kernel - 2 * (kernel // 2) + output_padding``
        elements; the output padding is solved from that relation.

        :param input_shape: shape of input, (batch, in_dim1, in_dim2, in_dim3, channels)
        :raises ValueError: if output_shape is given but padding is not "same",
            or if a spatial dimension of the input is unknown.
        """
        super().build(input_shape)

        kernel_size = self._to_triple(self._kernel_size)
        strides = self._to_triple(self._strides)

        output_padding = None
        if self._output_shape is not None:
            # explicit check instead of `assert`, which is stripped under -O
            if self._padding != "same":
                raise ValueError(
                    "Deconv3d only supports output_shape with padding='same', "
                    f"got padding={self._padding!r}."
                )
            spatial_shape = [input_shape[1 + i] for i in range(3)]
            if any(dim is None for dim in spatial_shape):
                raise ValueError(
                    "Deconv3d needs known spatial input dimensions to compute "
                    f"output padding, got input_shape={input_shape}."
                )
            output_padding = [
                self._output_shape[i]
                - (
                    (spatial_shape[i] - 1) * strides[i]
                    + kernel_size[i]
                    - 2 * (kernel_size[i] // 2)
                )
                for i in range(3)
            ]
        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=output_padding,
            **self._kwargs,
        )

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: input tensor.
        :param kwargs: additional arguments, unused.
        :return: output of the wrapped Conv3DTranspose.
        """
        return self._deconv3d(inputs=inputs)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            dict(
                filters=self._filters,
                output_shape=self._output_shape,
                kernel_size=self._kernel_size,
                strides=self._strides,
                padding=self._padding,
            )
        )
        config.update(self._kwargs)
        return config
114
115
116
class NormBlock(tfkl.Layer):
    """
    A block chaining a wrapped layer, a normalization layer and an activation.

    The wrapped layer is always constructed with ``use_bias=False``.
    """

    layer_cls_dict = {"conv3d": tfkl.Conv3D, "deconv3d": Deconv3d}
    norm_cls_dict = {
        "batch": tfkl.BatchNormalization,
        "layer": tfkl.LayerNormalization,
    }

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: key of the layer class to wrap, "conv3d" or "deconv3d".
        :param norm_name: key of the normalization class, "batch" or "layer".
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments for the wrapped layer.
        """
        super().__init__()
        # NOTE(review): `name` is recorded in the config but not forwarded to the
        # base Layer, so the layer keeps its auto-generated name — confirm intent.
        self._config = {
            "layer_name": layer_name,
            "norm_name": norm_name,
            "activation": activation,
            "name": name,
            **kwargs,
        }
        # unknown keys raise KeyError from the class-level lookup tables
        self._layer = self.layer_cls_dict[layer_name](use_bias=False, **kwargs)
        self._norm = self.norm_cls_dict[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward: layer -> norm -> activation.

        :param inputs: inputs for the wrapped layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, unused.
        :return: activated, normalized output of the wrapped layer.
        """
        return self._act(
            self._norm(inputs=self._layer(inputs=inputs), training=training)
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        return {**super().get_config(), **self._config}
172
173
174
class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.

    Thin specialisation of NormBlock with ``layer_name="conv3d"``.
    """

    def __init__(self, name: str = "conv3d_block", **kwargs):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments forwarded to NormBlock
            (and from there to tfkl.Conv3D).
        """
        super().__init__(
            layer_name="conv3d",
            name=name,
            **kwargs,
        )
191
192
193
class Deconv3dBlock(tfkl.Layer):
    """
    A deconv3d block having deconv3d - norm - activation.

    The transposed convolution has no bias and is followed by batch
    normalization and the activation.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        activation: str = "relu",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3), or None to let the
            transposed convolution decide the output size
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: str, same or valid
        :param activation: name of activation
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        # sub-layers are built in the order they are applied
        self._deconv3d = Deconv3d(
            filters=filters,
            output_shape=output_shape,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward: deconv3d -> batch norm -> activation.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, unused.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, filters); the
            spatial dims equal output_shape when it was given
        """
        return self._act(
            self._norm(inputs=self._deconv3d(inputs=inputs), training=training)
        )
239
240
241
class Residual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block.

    1. conved = conv3d(conv3d_block(inputs))
    2. out = act(norm(conved) + inputs)
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        Init.

        The skip addition requires the conv output to have the same shape as
        the inputs, so presumably strides=1 and filters equals the input
        channel count — confirm at call sites.

        :param filters: int, number of filters in the convolutional layers
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        conv_args = dict(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
        )
        self._conv3d_block = Conv3dBlock(**conv_args)
        self._conv3d = tfkl.Conv3D(use_bias=False, **conv_args)
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for the normalization layer (default: None)
        :param kwargs: additional arguments, unused.
        :return: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        features = self._conv3d_block(inputs)
        features = self._conv3d(inputs=features)
        normalized = self._norm(inputs=features, training=training)
        return self._act(normalized + inputs)
294
295
296
class DownSampleResnetBlock(tfkl.Layer):
    """
    A down-sampling resnet conv3d block, with max-pooling or strided conv3d.

    1. conved = conv3d_block(inputs)  # adjust channel
    2. skip = residual_block(conved)  # develop feature
    3. pooled = pool(skip)  # down-sample
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        pooling: bool = True,
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param pooling: if True, use max pooling to downsample, otherwise use a
            stride-2 conv3d - norm - activation block.
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        self._pooling = pooling
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)
        # exactly one of the two down-samplers is instantiated
        if pooling:
            self._max_pool3d = tfkl.MaxPool3D(
                pool_size=(2, 2, 2), strides=(2, 2, 2), padding="same"
            )
            self._conv3d_block3 = None
        else:
            self._max_pool3d = None
            self._conv3d_block3 = NormBlock(
                layer_name="conv3d",
                filters=filters,
                kernel_size=kernel_size,
                strides=2,
                padding="same",
            )

    def call(self, inputs, training=None, **kwargs) -> (tf.Tensor, tf.Tensor):
        """
        Forward.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, unused.
        :return: (pooled, skip)

          - pooled, shape = (batch, ceil(in_dim1/2), ceil(in_dim2/2), ceil(in_dim3/2), filters)
          - skip, shape = (batch, in_dim1, in_dim2, in_dim3, filters)
        """
        adjusted = self._conv3d_block(inputs=inputs, training=training)
        skip = self._residual_block(inputs=adjusted, training=training)
        if self._pooling:
            pooled = self._max_pool3d(inputs=skip)
        else:
            pooled = self._conv3d_block3(inputs=skip, training=training)
        return pooled, skip
359
360
361
class UpSampleResnetBlock(tfkl.Layer):
    """
    An up-sampling resnet conv3d block, with deconv3d.

    1. up = deconv3d_block(down_sampled)  # spatial size of the skip tensor
    2. combined = concat(up, skip) or up + skip
    3. out = residual_block(conv3d_block(combined))
    """

    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        Init.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool, specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        self._filters = filters
        self._concat = concat
        # the deconv block needs the skip tensor's shape, so it is made in build()
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)

    def build(self, input_shape):
        """
        Create the deconv block targeting the skip tensor's spatial shape.

        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=input_shape[1][1:4],
            strides=2,
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        Forward.

        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, unused.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        down_sampled, skip = inputs[0], inputs[1]
        # up sample and change channel
        features = self._deconv3d_block(inputs=down_sampled, training=training)
        if self._concat:
            features = tf.concat([features, skip], axis=4)
        else:
            features = features + skip
        features = self._conv3d_block(inputs=features, training=training)
        return self._residual_block(inputs=features, training=training)
416
417
418
class Conv3dWithResize(tfkl.Layer):
    """
    A layer contains conv3d - resize3d.

    A 3x3x3, stride-1, "same"-padded convolution followed by a resize of the
    spatial dims to a fixed output shape.
    """

    def __init__(
        self,
        output_shape: tuple,
        filters: int,
        kernel_initializer: str = "glorot_uniform",
        activation: (str, None) = None,
        **kwargs,
    ):
        """
        Init.

        :param output_shape: tuple, (out_dim1, out_dim2, out_dim3)
        :param filters: int, number of channels of the output
        :param kernel_initializer: str, defines the initialization method
        :param activation: str, defines the activation function
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        self._output_shape = output_shape
        # author's note: with a non-zero initializer the initial DDF may be too large
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding="same",
            kernel_initializer=kernel_initializer,
            activation=activation,
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments, unused.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, filters)
        """
        return layer_util.resize3d(
            image=self._conv3d(inputs=inputs), size=self._output_shape
        )
458
459
460
class Warping(tfkl.Layer):
    """
    A layer warps an image using DDF.

    Reference:

    - transform of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, **kwargs):
        """
        Init.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        # add a leading batch axis: shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = grid[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resample the image at the reference grid shifted by the DDF.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
          - image, shape = (batch, m_dim1, m_dim2, m_dim3), dtype = float32
        :param kwargs: additional arguments, unused.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs[0], inputs[1]
        return layer_util.resample(vol=image, loc=self.grid_ref + ddf)
491
492
493
class IntDVF(tfkl.Layer):
    """
    Layer calculates DDF from DVF by scaling and squaring.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(self, fixed_image_size: tuple, num_steps: int = 7, **kwargs):
        """
        Init.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        self._warping = Warping(fixed_image_size=fixed_image_size)
        self._num_steps = num_steps

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Integrate the DVF.

        The DVF is scaled down by 2**num_steps, then the field is composed with
        itself num_steps times (ddf <- ddf + warp(ddf, ddf)).

        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), type = float32
        :param kwargs: additional arguments, unused.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        scale = 2 ** self._num_steps
        ddf = inputs / scale
        for _ in range(self._num_steps):
            ddf = ddf + self._warping(inputs=[ddf, ddf])
        return ddf
521
522
523
class AdditiveUpSampling(tfkl.Layer):
    """
    Layer up-samples 3d tensor and reduce channels using split and sum.

    The resized tensor's channels are split into `stride` equal groups, which
    are summed element-wise.
    """

    def __init__(self, output_shape: tuple, stride: (int, list) = 2, **kwargs):
        """
        Init.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: number of channel groups to sum; `call` applies `%` to it,
            so in practice an int is expected — TODO confirm list support
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        self._stride = stride
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments, unused.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        :raises ValueError: if channels is not divisible by stride.
        """
        num_channels = inputs.shape[4]
        if num_channels % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        resized = layer_util.resize3d(image=inputs, size=self._output_shape)
        # list of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        groups = tf.split(resized, num_or_size_splits=self._stride, axis=4)
        stacked = tf.stack(groups, axis=5)
        return tf.reduce_sum(stacked, axis=5)
551
552
553
class LocalNetResidual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block, simpler than Residual3dBlock.

    1. conved = conv3d(inputs[0])
    2. out = act(norm(conved) + inputs[1])
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: [to_convolve, residual]; the conv output of the first
            must match the shape of the second for the addition
        :param training: training flag for the normalization layer (default: None)
        :param kwargs: additional arguments, unused.
        :return: act(norm(conv3d(inputs[0])) + inputs[1])
        """
        conved = self._conv3d(inputs=inputs[0])
        normalized = self._norm(inputs=conved, training=training)
        return self._act(normalized + inputs[1])
591
592
593
class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """
    Layer up-samples tensor with two inputs (skipped and down-sampled).

    The non-skip tensor is up-sampled by a deconv block (optionally plus
    additive up-sampling), then combined with the skip tensor through a
    conv3d block and a LocalNetResidual3dBlock.
    """

    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Init.

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool, whether to add additive up-sampling
            of the non-skip input to the deconv output
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # shape-dependent sub-layers are created in build()
        self._deconv3d_block = None
        self._additive_upsampling = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        """
        Create the up-sampling layers targeting the skip tensor's spatial shape.

        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        target_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters, output_shape=target_shape, strides=2
        )
        if self._use_additive_upsampling:
            self._additive_upsampling = AdditiveUpSampling(output_shape=target_shape)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, unused.
        :return: output of the residual block, with `filters` channels
        """
        nonskip, skip = inputs[0], inputs[1]
        upsampled = self._deconv3d_block(inputs=nonskip, training=training)
        if self._use_additive_upsampling:
            upsampled = upsampled + self._additive_upsampling(inputs=nonskip)
        residual = upsampled + skip
        conved = self._conv3d_block(inputs=upsampled, training=training)
        return self._residual_block(inputs=[conved, residual], training=training)
640
641
642
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of a image-to-image network.

    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        Init.

        :param control_point_spacing: int, or list/tuple with one spacing per
            spatial axis; an int is used for all three axes
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        # Gaussian sigma proportional to the spacing; 0.44 = ln(4)/pi
        self.kernel_sigma = [0.44 * cp for cp in control_point_spacing]
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None

    def build(self, input_shape):
        """
        Build the smoothing kernel and the control-point grid size.

        :param input_shape: shape of input, (batch, dim1, dim2, dim3, channels)
        """
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        # ceil(dim / spacing) + 3 per spatial axis, computed exactly on the host
        # as plain ints (previously via float32 TF ops and an int32 cast).
        # The +3 extra points presumably pad the cubic B-spline support — confirm.
        self._output_shape = [
            int(-(-int(dim) // spacing)) + 3
            for dim, spacing in zip(input_shape[1:-1], self.cp_spacing)
        ]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Smooth the inputs with the Gaussian kernel, then resize to the grid.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments, unused.
        :return: shape = (batch, cp_dim1, cp_dim2, cp_dim3, out_channels), where
            cp_dim_i = ceil(dim_i / spacing_i) + 3 and out_channels is set by the
            kernel from layer_util.gaussian_filter_3d
        """
        smoothed = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        return layer_util.resize3d(image=smoothed, size=self._output_shape)
680
681
682
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline filters.

    It takes a low-resolution field of control-point values and:

    1. performs the interpolation with a transposed convolution whose filter
       is the separable cubic B-spline kernel,
    2. crops the result around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing to all dimensions is used
    :param output_shape: (dim0, dim1, dim2), spatial shape of the high resolution
        deformation field returned by `call`.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):
        """
        Init.

        :param cp_spacing: int or tuple of three ints, control-point spacing in pixels
        :param output_shape: (dim0, dim1, dim2), spatial shape of the output
        :param kwargs: additional arguments for the base Layer.
        """
        super().__init__(**kwargs)

        # interpolation filter, created in build()
        self.filter = None
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        Precompute the cubic B-spline interpolation filter.

        The filter has shape (4*c0, 4*c1, 4*c2, 3, 3) and is diagonal in the
        channel axes, so each of the 3 field components is interpolated
        independently.

        :param input_shape: tuple with the input shape
        """
        super().build(input_shape=input_shape)

        # uniform cubic B-spline basis functions, one per support segment
        basis = (
            lambda u: (1 - u) ** 3 / 6,
            lambda u: (3 * (u ** 3) - 6 * (u ** 2) + 4) / 6,
            lambda u: (-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6,
            lambda u: u ** 3 / 6,
        )

        # 1-D kernel per axis: the 4 basis functions sampled at the c pixel
        # centres u = 1 - (k + 0.5) / c, k = 0..c-1, concatenated (length 4c).
        # Integer arange avoids numpy's float-step length pitfalls.
        kernels_1d = []
        for spacing in self.cp_spacing:
            samples = 1 - (np.arange(spacing) + 0.5) / spacing
            kernels_1d.append(np.concatenate([b(samples) for b in basis]))
        k0, k1, k2 = kernels_1d

        # the 3-D filter is separable: outer product of the 1-D kernels
        kernel_3d = k0[:, None, None] * k1[None, :, None] * k2[None, None, :]

        filters = np.zeros(kernel_3d.shape + (3, 3), dtype=np.float32)
        for dim in range(3):
            filters[..., dim, dim] = kernel_3d

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        Interpolate control-point values to full resolution.

        :param field: tf.Tensor of shape (batch, n0, n1, n2, 3), with n_i the
            number of control points along axis i
        :return: interpolated_field: tf.Tensor of shape
            (batch, (n0 + 3) * c0, (n1 + 3) * c1, (n2 + 3) * c2, 3)
        """
        # transposed-conv output length with VALID padding: (n - 1) * stride + kernel
        image_shape = tuple(
            (a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)
        )

        batch_size = field.shape[0]
        if batch_size is None:
            # batch unknown while tracing symbolically: use the dynamic size
            batch_size = tf.shape(field)[0]

        output_shape = (batch_size,) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Interpolate the control-point field and crop it to the output shape.

        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments, unused.
        :return: interpolated_field: tf.Tensor of shape
            (batch, dim0, dim1, dim2, 3) with (dim0, dim1, dim2) = output_shape
        """
        high_res_field = self.interpolate(inputs)

        # skip the first 3 control-point intervals on each axis
        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]
799