Passed
Pull Request — main (#656)
by Yunguan
03:01
created

deepreg.model.layer.Resize3d.call()   C

Complexity

Conditions 9

Size

Total Lines 82
Code Lines 42

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 42
dl 0
loc 82
rs 6.5386
c 0
b 0
f 0
cc 9
nop 3

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

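Commonly applied refactorings include Extract Method: move a coherent part of the body (often a part introduced by a comment) into a new, well-named method.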

1
"""This module defines custom layers."""
2
import itertools
3
4
import numpy as np
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
import deepreg.model.layer_util as layer_util
9
10
11
class Deconv3d(tfkl.Layer):
    """
    Conv3DTranspose wrapper that derives ``output_padding`` when the layer is built.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        name="deconv3d",
        **kwargs,
    ):
        """
        Store the arguments; the wrapped Conv3DTranspose is created in build().

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3); when None,
            output_padding=None is passed to Conv3DTranspose
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param name: name of the layer.
        :param kwargs: additional arguments for Conv3DTranspose.
        """
        super().__init__(name=name)
        # keep the constructor arguments for build() and get_config()
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._kwargs = kwargs
        # placeholder, replaced by a Conv3DTranspose in build()
        self._deconv3d = None

    @staticmethod
    def _as_triple(value):
        """Broadcast an int to a list of three copies; return anything else as is."""
        return [value] * 3 if isinstance(value, int) else value

    def build(self, input_shape):
        # pylint: disable-next=line-too-long
        """
        Create the wrapped Conv3DTranspose, calculating output padding on the fly.

        https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185

        With "same" padding, for every spatial axis i:
        output_padding[i] = output_shape[i] - (
            (input_shape[1 + i] - 1) * strides[i] + kernel[i] - 2 * (kernel[i] // 2)
        )

        :param input_shape: shape of input
        """
        super().build(input_shape)
        kernels = self._as_triple(self._kernel_size)
        strides = self._as_triple(self._strides)

        output_padding = None
        if self._output_shape is not None:
            assert self._padding == "same"
            output_padding = []
            for axis in range(3):
                # length produced by the transposed conv without extra padding
                unpadded_length = (
                    (input_shape[1 + axis] - 1) * strides[axis]
                    + kernels[axis]
                    - 2 * (kernels[axis] // 2)
                )
                output_padding.append(self._output_shape[axis] - unpadded_length)

        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=output_padding,
            **self._kwargs,
        )

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Apply the wrapped transposed convolution.

        :param inputs: input tensor.
        :param kwargs: additional arguments (unused).
        :return: output of the Conv3DTranspose layer.
        """
        deconv = self._deconv3d
        return deconv(inputs=inputs)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        own_args = {
            "filters": self._filters,
            "output_shape": self._output_shape,
            "kernel_size": self._kernel_size,
            "strides": self._strides,
            "padding": self._padding,
        }
        config.update(own_args)
        config.update(self._kwargs)
        return config
111
112
113
class NormBlock(tfkl.Layer):
    """
    A block applying layer -> normalization -> activation, in that order.
    """

    layer_cls_dict = {"conv3d": tfkl.Conv3D, "deconv3d": Deconv3d}
    norm_cls_dict = {
        "batch": tfkl.BatchNormalization,
        "layer": tfkl.LayerNormalization,
    }

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Build the three sub-layers.

        :param layer_name: key in layer_cls_dict selecting the wrapped layer class.
        :param norm_name: key in norm_cls_dict selecting the normalization class.
        :param activation: name of activation.
        :param name: name of the block layer, stored in the config.
        :param kwargs: additional arguments passed to the wrapped layer.
        """
        # NOTE(review): `name` is kept in the config but not forwarded to the
        # base Layer, so Keras assigns its own name here -- confirm intended.
        super().__init__()
        self._config = {
            "layer_name": layer_name,
            "norm_name": norm_name,
            "activation": activation,
            "name": name,
            **kwargs,
        }
        layer_cls = self.layer_cls_dict[layer_name]
        # bias is disabled because a normalization layer follows
        self._layer = layer_cls(use_bias=False, **kwargs)
        norm_cls = self.norm_cls_dict[norm_name]
        self._norm = norm_cls()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments (unused).
        :return: activated, normalized output of the wrapped layer.
        """
        hidden = self._layer(inputs=inputs)
        hidden = self._norm(inputs=hidden, training=training)
        return self._act(hidden)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
169
170
171
class Conv3dBlock(NormBlock):
    """
    A NormBlock wrapping tfkl.Conv3D: conv3d -> norm -> activation.
    """

    def __init__(self, name: str = "conv3d_block", **kwargs):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments for NormBlock (e.g. norm_name,
            activation); the rest are passed to tfkl.Conv3D.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)
188
189
190
class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments for NormBlock (e.g. norm_name,
            activation); the rest are passed to the wrapped Deconv3d layer.
        """
        super().__init__(layer_name="deconv3d", name=name, **kwargs)
207
208
209
class Resize3d(tfkl.Layer):
    """
    Resize image in two folds.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        """
        super().__init__(name=name)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    @staticmethod
    def _inspect_input(image: tf.Tensor) -> tuple:
        """
        Determine which optional axes the input has and its spatial shape.

        :param image: tensor of rank 3, 4 or 5
        :return: (has_batch, has_channel, spatial_shape)
        :raises ValueError: if the rank is not 3, 4 or 5
        """
        image_dim = len(image.shape)
        if image_dim == 5:  # (batch, dim1, dim2, dim3, channels)
            return True, True, image.shape[1:4]
        if image_dim == 4:  # (batch, dim1, dim2, dim3)
            return True, False, image.shape[1:4]
        if image_dim == 3:  # (dim1, dim2, dim3)
            return False, False, image.shape[0:3]
        raise ValueError(
            "Resize3d takes input image of dimension 3 or 4 or 5, "
            "corresponding to (dim1, dim2, dim3) "
            "or (batch, dim1, dim2, dim3) "
            "or (batch, dim1, dim2, dim3, channels), "
            "got image shape {}".format(image.shape)
        )

    def _resize_5d(self, image: tf.Tensor) -> tf.Tensor:
        """
        Resize a 5D tensor with two successive 2D resizes.

        :param image: shape = (batch, dim1, dim2, dim3, channels)
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
        """
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # split the merged axis back into (out_dim3, channels)
        return tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform two fold resize.

        If the spatial shape already equals the target shape, the input is
        returned unchanged.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                                     or (batch, dim1, dim2, dim3)
                                     or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                                or (batch, out_dim1, out_dim2, out_dim3)
                                or (out_dim1, out_dim2, out_dim3)
        :raises ValueError: if the input rank is not 3, 4 or 5
        """
        image = inputs
        has_batch, has_channel, input_image_shape = self._inspect_input(image)

        # no need of resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)

        output = self._resize_5d(image)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config
324
325
326
class Warping(tfkl.Layer):
    """
    Warp an image with a DDF by resampling it at (reference grid + DDF).

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Init and precompute the reference grid.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        base_grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        # add a leading batch axis, shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = base_grid[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resample the image at the displaced reference grid.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        sample_locations = self.grid_ref + ddf
        return layer_util.resample(vol=image, loc=sample_locations)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        return config
368
369
370
class Residual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block: act(norm(conv3d(conv3d_block(x))) + x).
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block.

        1. conved = conv3d(conv3d_block(inputs))
        2. out = act(norm(conved) + inputs)

        :param filters: int, number of filters in the convolutional layers
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        conv_args = dict(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
        )
        # sub-layers, created in this order
        self._conv3d_block = Conv3dBlock(**conv_args)
        self._conv3d = tfkl.Conv3D(use_bias=False, **conv_args)
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        conved = self._conv3d(inputs=self._conv3d_block(inputs))
        normed = self._norm(inputs=conved, training=training)
        # residual connection before the final activation
        return self._act(normed + inputs)
423
424
425
class DownSampleResnetBlock(tfkl.Layer):
    """
    Down-sampling resnet block: conv3d block, residual block, then pooling or strided conv.
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        pooling: bool = True,
        **kwargs,
    ):
        """
        A down-sampling resnet conv3d block, with max-pooling or conv3d.

        1. conved = conv3d_block(inputs)  # adjust channel
        2. skip = residual_block(conved)  # develop feature
        3. pooled = pool(skip) # down-sample

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param pooling: if True, use max pooling to downsample, otherwise use conv.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._pooling = pooling
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)
        # exactly one of the two down-sampling layers is created
        if pooling:
            self._max_pool3d = tfkl.MaxPool3D(
                pool_size=(2, 2, 2), strides=(2, 2, 2), padding="same"
            )
            self._conv3d_block3 = None
        else:
            self._max_pool3d = None
            self._conv3d_block3 = NormBlock(
                layer_name="conv3d",
                filters=filters,
                kernel_size=kernel_size,
                strides=2,
                padding="same",
            )

    def call(self, inputs, training=None, **kwargs) -> (tf.Tensor, tf.Tensor):
        """
        Forward.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: (pooled, skip)

          - downsampled, shape = (batch, in_dim1//2, in_dim2//2, in_dim3//2, channels)
          - skipped, shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        conved = self._conv3d_block(inputs=inputs, training=training)  # adjust channel
        skip = self._residual_block(inputs=conved, training=training)  # develop feature
        if self._pooling:
            pooled = self._max_pool3d(inputs=skip)
        else:
            pooled = self._conv3d_block3(inputs=skip, training=training)
        return pooled, skip
488
489
490
class UpSampleResnetBlock(tfkl.Layer):
    """
    Up-sampling resnet block: deconv3d block, merge with skip, conv3d block, residual block.
    """

    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        An up-sampling resnet conv3d block, with deconv3d.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool,specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._filters = filters
        self._concat = concat
        # created in build(), once the skip tensor shape is known
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)

    def build(self, input_shape):
        """
        Create the deconv3d block whose output matches the skip tensor's spatial shape.

        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        skip_spatial_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=skip_spatial_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, skip_dim1, skip_dim2, skip_dim3, filters)
        """
        down_sampled, skip = inputs[0], inputs[1]
        # up sample and change channel
        up_sampled = self._deconv3d_block(inputs=down_sampled, training=training)
        # combine with the skip connection
        if self._concat:
            merged = tf.concat([up_sampled, skip], axis=4)
        else:
            merged = up_sampled + skip
        # adjust channel, then conv
        merged = self._conv3d_block(inputs=merged, training=training)
        return self._residual_block(inputs=merged, training=training)
549
550
551
class IntDVF(tfkl.Layer):
    """
    Integrate a dense velocity field (DVF) into a dense displacement field (DDF).
    """

    def __init__(self, fixed_image_size: tuple, num_steps: int = 7, **kwargs):
        """
        Layer calculates DDF from DVF.

        Reference:

        - integrate_vec of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._warping = Warping(fixed_image_size=fixed_image_size)
        self._num_steps = num_steps

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Integrate the DVF by scaling and squaring.

        The DVF is first divided by 2 ** num_steps. The resulting field is then
        composed with itself num_steps times: ddf <- ddf + warp(ddf, ddf).

        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), type = float32
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            # resample the current field at grid + ddf and accumulate
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf
579
580
581
class AdditiveUpSampling(tfkl.Layer):
    """
    Up-sample a 3D tensor by resizing, then reduce its channels by split-and-sum.
    """

    def __init__(self, output_shape: tuple, stride: int = 2, **kwargs):
        """
        Layer up-samples 3d tensor and reduce channels using split and sum.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: int, number of equal channel groups to split into and
            sum; the input channel count must be divisible by it. A list is
            not supported, because the divisibility check in call() needs an int.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._stride = stride
        self._resize = Resize3d(output_shape)
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resize the spatial dims, then sum `stride` equal channel groups.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        :raises ValueError: if channels is not divisible by stride
        """
        if inputs.shape[4] % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        output = self._resize(inputs)
        # a list of `stride` tensors, each (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.split(output, num_or_size_splits=self._stride, axis=4)
        # element-wise sum of the groups:
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.reduce_sum(tf.stack(output, axis=5), axis=5)
        return output
610
611
612
class LocalNetResidual3dBlock(tfkl.Layer):
    """
    LocalNet residual block: act(norm(conv3d(x)) + residual), with x and residual given separately.
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block, simpler than Residual3dBlock.

        1. conved = conv3d(inputs)
        2. out = act(norm(conved) + inputs)

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # bias-free conv because batch norm follows
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: pair (conv_input, residual); conv_input is convolved and
            normalized, and residual is added before the activation.
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: act(norm(conv3d(conv_input)) + residual)
        """
        conv_input, residual = inputs[0], inputs[1]
        normed = self._norm(inputs=self._conv3d(inputs=conv_input), training=training)
        return self._act(normed + residual)
650
651
652
class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """
    LocalNet up-sampling block combining a non-skip (coarse) and a skip tensor.
    """

    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Layer up-samples tensor with two inputs (skipped and down-sampled).

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool to used additive upsampling
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # these two depend on the skip tensor shape and are created in build()
        self._deconv3d_block = None
        self._additive_upsampling = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        """
        Create the up-sampling layers targeting the skip tensor's spatial shape.

        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        target_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=target_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        if self._use_additive_upsampling:
            self._additive_upsampling = AdditiveUpSampling(output_shape=target_shape)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output of the LocalNetResidual3dBlock, computed as
            residual_block([conv3d_block(up), up + inputs_skip]), where up is the
            (optionally additively) up-sampled inputs_nonskip.
        """
        nonskip, skip = inputs[0], inputs[1]
        upsampled = self._deconv3d_block(inputs=nonskip, training=training)
        if self._use_additive_upsampling:
            upsampled += self._additive_upsampling(inputs=nonskip)
        summed = upsampled + skip
        conved = self._conv3d_block(inputs=upsampled, training=training)
        return self._residual_block(inputs=[conved, summed], training=training)
703
704
705
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of a image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        Init.

        :param control_point_spacing: list or int; an int is used for all
            three spatial dimensions
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        """
        Create the Gaussian kernel and the resize layer for the control-point grid.

        :param input_shape: (batch, dim1, dim2, dim3, channels)
        """
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        # per spatial axis: ceil(dim / spacing) + 3 control points.
        # Plain python ints (instead of int32 tensors) keep the resize target,
        # and hence the output, statically shaped.
        output_shape = [
            int(np.ceil(v / c)) + 3
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Smooth the input with the Gaussian kernel, then resize to the control-point grid.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, cp_dim1, cp_dim2, cp_dim3, channels)
        """
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)
        return output
746
747
748
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for B-spline interpolation with precomputed cubic spline filters.

    The input holds the control point values of a low resolution deformation
    field (for instance obtained by downsampling a full-sized field). The layer:

     1. interpolates them with a transposed convolution using cubic B-spline
        filters,
     2. crops the result around the valid values to ``output_shape``.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing is used for all dimensions.
    :param output_shape: (dim0, dim1, dim2), the spatial shape of the high
        resolution deformation field returned by ``call``.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):
        super().__init__(**kwargs)

        # ``filters`` is kept for backward compatibility only; the kernel built
        # in ``build`` is stored in ``filter``.
        self.filters = []
        self.filter = None
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        Precompute the cubic B-spline kernel used by ``interpolate``.

        :param input_shape: tuple with the input shape
        """
        super().build(input_shape=input_shape)

        # cubic B-spline basis functions, one per neighbouring control point
        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        # Per axis: sample every basis function at the centres of the ``spacing``
        # equal sub-intervals of [0, 1] (u decreasing from 1 - 0.5/spacing to
        # 0.5/spacing), and concatenate the four pieces into a 1-D kernel of
        # length 4 * spacing.
        kernels_1d = []
        for spacing in self.cp_spacing:
            u = 1 - np.arange(1 / (2 * spacing), 1, 1 / spacing)
            kernels_1d.append(np.concatenate([b[k](u) for k in range(4)]))

        # The 3-D kernel is separable: the outer product of the three 1-D kernels.
        kernel_3d = (
            kernels_1d[0][:, None, None]
            * kernels_1d[1][None, :, None]
            * kernels_1d[2][None, None, :]
        )

        # Diagonal in the channel axes: each of the three channels is
        # interpolated independently with the same kernel.
        filters = np.zeros(kernel_3d.shape + (3, 3), dtype=np.float32)
        for it_dim in range(3):
            filters[..., it_dim, it_dim] = kernel_3d

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        Interpolate control point values with a transposed convolution.

        :param field: tf.Tensor of shape (batch, n0, n1, n2, 3) holding the
            control point values; the spatial sizes must be statically known.
        :return: interpolated field of shape
            (batch, (n0 + 3) * s0, (n1 + 3) * s1, (n2 + 3) * s2, 3),
            where si is ``cp_spacing[i]``.
        """
        image_shape = tuple(
            (a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)
        )

        batch_size = field.shape[0]
        if batch_size is None:
            # static batch size unknown (e.g. symbolic input): use the dynamic one
            batch_size = tf.shape(field)[0]

        output_shape = (batch_size,) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Interpolate the control points and crop to ``output_shape``.

        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated field of shape
            (batch, output_shape[0], output_shape[1], output_shape[2], 3)
        """
        high_res_field = self.interpolate(inputs)

        # drop the leading border of 3 control-point intervals per axis so the
        # crop is taken around the valid values
        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]
865