Passed
Pull Request — main (#656)
by Yunguan, created 02:46

deepreg.model.layer (rating: C)

Complexity

Total Complexity 57

Size/Duplication

Total Lines 748
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 57
eloc 363
dl 0
loc 748
rs 5.04
c 0
b 0
f 0

36 Methods

Rating   Name   Duplication   Size   Complexity  
A AdditiveUpSampling.call() 0 14 2
A DownSampleResnetBlock.__init__() 0 34 3
A UpSampleResnetBlock.call() 0 23 2
A Deconv3dBlock.__init__() 0 33 1
A Residual3dBlock.__init__() 0 34 1
A UpSampleResnetBlock.build() 0 8 1
A LocalNetResidual3dBlock.__init__() 0 31 1
B BSplines3DTransform.build() 0 54 7
A Conv3dBlock.call() 0 11 1
A LocalNetUpSampleResnetBlock.__init__() 0 17 1
A BSplines3DTransform.__init__() 0 11 2
A Norm.call() 0 2 1
A Deconv3d.call() 0 2 1
A ResizeCPTransform.__init__() 0 16 2
A AdditiveUpSampling.__init__() 0 12 1
A UpSampleResnetBlock.__init__() 0 18 1
A Deconv3dBlock.call() 0 11 1
A LocalNetUpSampleResnetBlock.call() 0 15 2
A Norm.__init__() 0 15 3
A ResizeCPTransform.call() 0 5 1
A BSplines3DTransform.interpolate() 0 17 1
A Conv3dBlock.__init__() 0 30 1
A Deconv3d.__init__() 0 34 1
A Deconv3d.build() 0 36 4
A DownSampleResnetBlock.call() 0 18 2
A Warping.__init__() 0 18 1
A ResizeCPTransform.build() 0 9 1
A IntDVF.__init__() 0 16 1
A Conv3dWithResize.call() 0 9 1
A IntDVF.call() 0 10 2
A BSplines3DTransform.call() 0 14 1
A Conv3dWithResize.__init__() 0 28 1
A LocalNetUpSampleResnetBlock.build() 0 12 2
A LocalNetResidual3dBlock.call() 0 4 1
A Warping.call() 0 11 1
A Residual3dBlock.call() 0 13 1

How to fix: Complexity

Complex classes (or, as here, modules) like deepreg.model.layer often do many different things. To break one down, identify a cohesive component within it. A common approach is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.

# Scrutinizer (introduced): Missing module docstring
import itertools

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util as layer_util


# Scrutinizer (introduced): Missing class docstring
class Norm(tf.keras.layers.Layer):
    # Scrutinizer (introduced): "ValueError" not documented as being raised
    def __init__(self, name: str = "batch_norm", axis: int = -1, **kwargs):
        """
        Layer that applies either batch normalization or layer normalization.

        :param name: str, batch_norm or layer_norm
        :param axis: int
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        if name == "batch_norm":
            self._norm = tf.keras.layers.BatchNormalization(axis=axis, **kwargs)
        elif name == "layer_norm":
            self._norm = tf.keras.layers.LayerNormalization(axis=axis)
        else:
            raise ValueError("Unknown normalization type")

    # Scrutinizer bug (introduced): Parameters differ from overridden 'call' method
    def call(self, inputs, training=None, **kwargs):
        return self._norm(inputs=inputs, training=training)


# Scrutinizer (introduced): Missing class docstring
class Deconv3d(tf.keras.layers.Layer):
    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        use_bias: bool = True,
        **kwargs,
    ):
        """
        Layer that wraps tf.keras.layers.Conv3DTranspose
        and does not require the input shape at initialization.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param use_bias: use bias for Conv3DTranspose or not.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._use_bias = use_bias
        self._kwargs = kwargs
        # init layer variables
        self._output_padding = None
        self._deconv3d = None

    def build(self, input_shape):
        super().build(input_shape)

        if isinstance(self._kernel_size, int):
            self._kernel_size = [self._kernel_size] * 3
        if isinstance(self._strides, int):
            self._strides = [self._strides] * 3

        if self._output_shape is not None:
            # Scrutinizer unused code (introduced): This string statement has no effect and could be removed.
            # pylint: disable-next=line-too-long
            """
            https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185
            When the output shape is defined, the padding should be calculated manually
            if padding == 'same':
                pad = filter_size // 2
                length = ((input_length - 1) * stride + filter_size
                         - 2 * pad + output_padding)
            """
            self._padding = "same"
            self._output_padding = [
                self._output_shape[i]
                - (
                    (input_shape[1 + i] - 1) * self._strides[i]
                    + self._kernel_size[i]
                    - 2 * (self._kernel_size[i] // 2)
                )
                for i in range(3)
            ]
        self._deconv3d = tf.keras.layers.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=self._output_padding,
            use_bias=self._use_bias,
            **self._kwargs,
        )

    def call(self, inputs, **kwargs):
        return self._deconv3d(inputs=inputs)
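The output_padding arithmetic in Deconv3d.build can be checked without TensorFlow. The sketch below is a plain-Python restatement of that list comprehension; the helper name deconv_output_padding is hypothetical, and it takes the spatial shape only, whereas build indexes input_shape[1 + i] to skip the batch axis. With strides=2 and kernel_size=3, doubling an even size needs an output_padding of 1, while reaching the odd size one below it needs 0.

```python
def deconv_output_padding(input_shape, output_shape, kernel_size=3, strides=1):
    """Per-axis output_padding so a "same" transposed conv hits output_shape.

    Restates the formula used in Deconv3d.build:
        length = (in - 1) * stride + kernel - 2 * (kernel // 2) + output_padding
    input_shape and output_shape are spatial shapes (dim1, dim2, dim3).
    """
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size] * 3
    if isinstance(strides, int):
        strides = [strides] * 3
    return [
        output_shape[i]
        - ((input_shape[i] - 1) * strides[i] + kernel_size[i] - 2 * (kernel_size[i] // 2))
        for i in range(3)
    ]


# Up-sampling a (4, 5, 6) feature map by 2 to match an (8, 10, 12) skip connection.
print(deconv_output_padding((4, 5, 6), (8, 10, 12), kernel_size=3, strides=2))  # [1, 1, 1]
print(deconv_output_padding((4, 5, 6), (7, 9, 11), kernel_size=3, strides=2))  # [0, 0, 0]
```

Keras' Conv3DTranspose requires each output_padding to be smaller than the corresponding stride, so this approach only works when output_shape lies within one stride of the size the transposed convolution would produce without padding.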


# Scrutinizer (introduced): Missing class docstring
class Conv3dBlock(tf.keras.layers.Layer):
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        activation: str = "relu",
        **kwargs,
    ):
        """
        A conv3d block having conv3d - norm - activation.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: str, same or valid
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        self._norm = Norm()
        self._act = tfkl.Activation(activation=activation)

    # Scrutinizer bug (introduced): Parameters differ from overridden 'call' method
    # Scrutinizer (introduced): "inputs, training" missing in parameter type documentation
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        output = self._conv3d(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output


# Scrutinizer (introduced): Missing class docstring
class Deconv3dBlock(tf.keras.layers.Layer):
    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        activation: str = "relu",
        **kwargs,
    ):
        """
        A deconv3d block having deconv3d - norm - activation.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: str, same or valid
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._deconv3d = Deconv3d(
            filters=filters,
            output_shape=output_shape,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        self._norm = Norm()
        self._act = tfkl.Activation(activation=activation)

    # Scrutinizer bug (introduced): Parameters differ from overridden 'call' method
    # Scrutinizer (introduced): "inputs, training" missing in parameter type documentation
    # Scrutinizer (introduced): Missing return documentation
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return output: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        output = self._deconv3d(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output


# Scrutinizer (introduced): Missing class docstring
class Residual3dBlock(tf.keras.layers.Layer):
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block.

        1. conved = conv3d(conv3d_block(inputs))
        2. out = act(norm(conved) + inputs)

        :param filters: int, number of filters in the convolutional layers
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, strides=strides
        )
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = Norm()
        self._act = tfkl.Activation(activation=activation)

    # Scrutinizer bug (introduced): Parameters differ from overridden 'call' method
    # Scrutinizer (introduced): "inputs, training" missing in parameter type documentation
    # Scrutinizer (introduced): Missing return documentation
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return output: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        return self._act(
            self._norm(
                inputs=self._conv3d(inputs=self._conv3d_block(inputs)),
                training=training,
            )
            + inputs
        )


# Scrutinizer (introduced): Missing class docstring
class DownSampleResnetBlock(tf.keras.layers.Layer):
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        pooling: bool = True,
        **kwargs,
    ):
        """
        A down-sampling resnet conv3d block, with max-pooling or conv3d.

        1. conved = conv3d_block(inputs)  # adjust channel
        2. skip = residual_block(conved)  # develop feature
        3. pooled = pool(skip) # down-sample

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param pooling: if True, use max pooling to downsample, otherwise use conv.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._pooling = pooling
        # init layer variables
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=kernel_size)
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)
        self._max_pool3d = (
            tfkl.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding="same")
            if pooling
            else None
        )
        self._conv3d_block3 = (
            None
            if pooling
            else Conv3dBlock(filters=filters, kernel_size=kernel_size, strides=2)
        )

    # Scrutinizer bug (introduced): Parameters differ from overridden 'call' method
    # Scrutinizer (introduced): "inputs, training" missing in parameter type documentation
    def call(self, inputs, training=None, **kwargs) -> (tf.Tensor, tf.Tensor):
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: (pooled, skip)

          - downsampled, shape = (batch, in_dim1//2, in_dim2//2, in_dim3//2, channels)
          - skipped, shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        conved = self._conv3d_block(inputs=inputs, training=training)  # adjust channel
        skip = self._residual_block(inputs=conved, training=training)  # develop feature
        pooled = (
            self._max_pool3d(inputs=skip)
            if self._pooling
            # Scrutinizer bug (introduced): self._conv3d_block3 does not seem to be callable.
            else self._conv3d_block3(inputs=skip, training=training)
        )  # downsample
        return pooled, skip
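The docstring of DownSampleResnetBlock.call gives the pooled shape as in_dim//2. With padding="same", TensorFlow pooling and strided convolutions produce ceil(in_dim / stride) instead, so the two agree only for even sizes. A small sketch (same_padding_output_size is a hypothetical helper) makes the difference visible:

```python
import math


def same_padding_output_size(in_dim: int, stride: int = 2) -> int:
    """Output length of a TensorFlow pooling/convolution with padding="same"."""
    return math.ceil(in_dim / stride)


for in_dim in (8, 9):
    # "same"-padding size next to the floor division used in the docstring
    print(in_dim, same_padding_output_size(in_dim), in_dim // 2)
```

For an odd size such as 9, the down-sampled map has 5 voxels. UpSampleResnetBlock can still recover 9, because it passes the skip connection's spatial shape to Deconv3dBlock as output_shape: (5 - 1) * 2 + 3 - 2 * 1 = 9, so the required output_padding is 0.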


# Scrutinizer (introduced): Missing class docstring
class UpSampleResnetBlock(tf.keras.layers.Layer):
    # Scrutinizer (introduced): "concat, filters, kernel_size" missing in parameter type documentation
    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        An up-sampling resnet conv3d block, with deconv3d.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool, specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._concat = concat
        # init layer variables
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=kernel_size)
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)

    # Scrutinizer (introduced): "input_shape" missing in parameter type documentation
    def build(self, input_shape):
        """
        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        skip_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters, output_shape=skip_shape, strides=2
        )

    # Scrutinizer bug (introduced): Parameters differ from overridden 'call' method
    # Scrutinizer (introduced): "inputs, training" missing in parameter type documentation
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        up_sampled, skip = inputs[0], inputs[1]
        up_sampled = self._deconv3d_block(
            inputs=up_sampled, training=training
        )  # up sample and change channel
        up_sampled = (
            tf.concat([up_sampled, skip], axis=4) if self._concat else up_sampled + skip
        )  # combine
        up_sampled = self._conv3d_block(
            inputs=up_sampled, training=training
        )  # adjust channel
        up_sampled = self._residual_block(inputs=up_sampled, training=training)  # conv
        return up_sampled


# Scrutinizer (introduced): Missing class docstring
class Conv3dWithResize(tf.keras.layers.Layer):
    def __init__(
        self,
        output_shape: tuple,
        filters: int,
        kernel_initializer: str = "glorot_uniform",
        activation: (str, None) = None,
        **kwargs,
    ):
        """
        A layer that applies conv3d followed by resize3d.

        :param output_shape: tuple, (out_dim1, out_dim2, out_dim3)
        :param filters: int, number of channels of the output
        :param kernel_initializer: str, defines the initialization method
        :param activation: str, defines the activation function
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._output_shape = output_shape
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding="same",
            kernel_initializer=kernel_initializer,
            activation=activation,
        )  # if not zero, with init NN, ddf may be too large

    # Scrutinizer (introduced): "inputs" missing in parameter type documentation
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
        """
        output = self._conv3d(inputs=inputs)
        output = layer_util.resize3d(image=output, size=self._output_shape)
        return output


# Scrutinizer (introduced): Missing class docstring
class Warping(tf.keras.layers.Layer):
    def __init__(self, fixed_image_size: tuple, **kwargs):
        """
        A layer that warps an image using a DDF.

        Reference:

        - transform of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

          where vol = image, loc_shift = ddf

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self.grid_ref = tf.expand_dims(
            layer_util.get_reference_grid(grid_size=fixed_image_size), axis=0
        )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)

    # Scrutinizer (introduced): "inputs" missing in parameter type documentation
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
          - image, shape = (batch, m_dim1, m_dim2, m_dim3), dtype = float32
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        return layer_util.warp_image_ddf(
            image=inputs[1], ddf=inputs[0], grid_ref=self.grid_ref
        )


# Scrutinizer (introduced): Missing class docstring
class IntDVF(tf.keras.layers.Layer):
    def __init__(self, fixed_image_size: tuple, num_steps: int = 7, **kwargs):
        """
        Layer that integrates a DVF into a DDF by scaling and squaring.

        Reference:

        - integrate_vec of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._warping = Warping(fixed_image_size=fixed_image_size)
        self._num_steps = num_steps

    # Scrutinizer (introduced): "inputs" missing in parameter type documentation
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf
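IntDVF.call implements scaling and squaring: the velocity field is divided by 2 ** num_steps and the resulting small displacement is then composed with itself num_steps times. The NumPy sketch below reproduces that loop on a 1D field. warp_1d is a hypothetical stand-in for Warping built on np.interp, so its border behaviour (clamping to the edge value) is an assumption and may differ from layer_util.warp_image_ddf.

```python
import numpy as np


def warp_1d(values: np.ndarray, ddf: np.ndarray) -> np.ndarray:
    """Hypothetical 1D stand-in for Warping: sample `values` at x + ddf(x).

    Uses linear interpolation; np.interp clamps samples outside the grid to
    the border value (assumed here, not necessarily warp_image_ddf's behaviour).
    """
    grid = np.arange(values.shape[0], dtype=np.float64)
    return np.interp(grid + ddf, grid, values)


def integrate_dvf_1d(dvf: np.ndarray, num_steps: int = 7) -> np.ndarray:
    """Scaling and squaring on a 1D field, following the loop in IntDVF.call."""
    ddf = dvf / (2 ** num_steps)  # scaling: a small displacement per step
    for _ in range(num_steps):
        ddf = ddf + warp_1d(ddf, ddf)  # squaring: compose the displacement with itself
    return ddf


dvf = np.full(10, 0.5)  # a constant velocity field, i.e. a pure translation
print(integrate_dvf_1d(dvf))  # every entry is 0.5
```

A constant velocity field integrates to a displacement equal to itself, which makes it a convenient sanity check; each of the seven doublings is exact in floating point because 0.5 / 2 ** 7 is a power of two.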


# Scrutinizer (introduced): Missing class docstring
class AdditiveUpSampling(tf.keras.layers.Layer):
    def __init__(self, output_shape: tuple, stride: (int, list) = 2, **kwargs):
        """
        Layer that up-samples a 3d tensor and reduces its channels using split and sum.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: int, 1-D Tensor or list
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._stride = stride
        self._output_shape = output_shape

    # Scrutinizer (introduced): "inputs" missing in parameter type documentation
    # Scrutinizer (introduced): "ValueError" not documented as being raised
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        """
        if inputs.shape[4] % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        output = layer_util.resize3d(image=inputs, size=self._output_shape)
        # Scrutinizer unused code (introduced): Argument 'axis' passed by position and keyword in function call
        # a list of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.split(output, num_or_size_splits=self._stride, axis=4)
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.reduce_sum(tf.stack(output, axis=5), axis=5)
        return output
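The channel reduction in AdditiveUpSampling.call (the resize step aside) is easiest to see on a single voxel. In this NumPy sketch, the hypothetical helper split_and_sum uses np.split in place of tf.split: both cut the last axis into stride equal groups, which are then stacked and summed. The call in the layer passes axis only by keyword, so the "passed by position and keyword" report on that line looks like a false positive.

```python
import numpy as np


def split_and_sum(x: np.ndarray, stride: int = 2) -> np.ndarray:
    """Channel reduction as in AdditiveUpSampling.call, without the resize.

    Splits the last axis into `stride` equal groups and sums them element-wise,
    mirroring tf.split followed by tf.reduce_sum(tf.stack(..., axis=5), axis=5).
    """
    if x.shape[-1] % stride != 0:
        raise ValueError("The channel dimension can not be divided by the stride")
    groups = np.split(x, stride, axis=-1)  # `stride` arrays with channels // stride each
    return np.sum(np.stack(groups, axis=-1), axis=-1)


x = np.arange(4, dtype=np.float32).reshape(1, 1, 1, 1, 4)  # channels [0, 1, 2, 3]
print(split_and_sum(x).reshape(-1))  # [0 + 2, 1 + 3] -> [2. 4.]
```

Channel k of the output is the sum of channels k, k + C/stride, k + 2C/stride, and so on, which is why the channel count must be divisible by the stride.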


# Scrutinizer (introduced): Missing class docstring
class LocalNetResidual3dBlock(tf.keras.layers.Layer):
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block, simpler than Residual3dBlock.

        1. conved = conv3d(inputs[0])
        2. out = act(norm(conved) + inputs[1])

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = Norm()
        self._act = tfkl.Activation(activation=activation)

    # Scrutinizer bug (introduced): Parameters differ from overridden 'call' method
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        return self._act(
            self._norm(inputs=self._conv3d(inputs=inputs[0]), training=training)
            + inputs[1]
        )


# Scrutinizer (introduced): Missing class docstring
class LocalNetUpSampleResnetBlock(tf.keras.layers.Layer):
    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Layer that up-samples a tensor using two inputs (skipped and down-sampled).

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool, whether to use additive upsampling
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # init layer variables
        self._deconv3d_block = None
        self._additive_upsampling = None
        self._conv3d_block = Conv3dBlock(filters=filters)
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    # Scrutinizer (introduced): "input_shape" missing in parameter type documentation
    def build(self, input_shape):
        """
        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        output_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters, output_shape=output_shape, strides=2
        )
        if self._use_additive_upsampling:
            self._additive_upsampling = AdditiveUpSampling(output_shape=output_shape)

    # Scrutinizer bug (introduced): Parameters differ from overridden 'call' method
    # Scrutinizer (introduced): "inputs, training" missing in parameter type documentation
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, skip_dim1, skip_dim2, skip_dim3, filters)
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        h0 = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            h0 += self._additive_upsampling(inputs=inputs_nonskip)
        r1 = h0 + inputs_skip
        r2 = self._conv3d_block(inputs=h0, training=training)
        h1 = self._residual_block(inputs=[r2, r1], training=training)
        return h1


class ResizeCPTransform(tf.keras.layers.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        :param control_point_spacing: list, tuple or int
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 ~= ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None

    def build(self, input_shape):
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = [
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        return layer_util.resize3d(image=output, size=self._output_shape)
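ResizeCPTransform.build sizes the control-point grid as ceil(v / c) + 3 per spatial axis, and __init__ sets the anti-aliasing Gaussian sigma to 0.44 * c, with 0.44 approximating ln(4)/pi. The plain-Python sketch below (num_control_points is a hypothetical helper taking the spatial shape only, whereas build uses input_shape[1:-1]) reproduces the grid-size formula so it can be checked by hand.

```python
import math


def num_control_points(image_shape, cp_spacing):
    """Control-point grid size per axis, mirroring ResizeCPTransform.build.

    image_shape is the spatial shape (dim1, dim2, dim3); cp_spacing is an int
    or a sequence of three ints, as accepted by ResizeCPTransform.
    """
    if isinstance(cp_spacing, int):
        cp_spacing = [cp_spacing] * 3
    return [math.ceil(v / c) + 3 for v, c in zip(image_shape, cp_spacing)]


print(num_control_points((64, 64, 60), 5))  # [16, 16, 15]
print(math.log(4) / math.pi)  # about 0.441, rounded to 0.44 in __init__
```

With n = ceil(v / c) + 3 control points along an axis, BSplines3DTransform.interpolate produces (n - 1) * c + 4 * c = (n + 3) * c voxels, so the crop of v voxels starting at 3 * c always fits inside the interpolated field.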


class BSplines3DTransform(tf.keras.layers.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline filters.
    It assumes a full-sized image from which it:
    1. computes the control point values by downsampling the initial image,
    2. performs the interpolation,
    3. crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing is used for all dimensions.
    :param output_shape: (dim0, dim1, dim2), spatial shape of the high resolution
        deformation field.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):

        super().__init__(**kwargs)

        self.filters = []
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    # Scrutinizer (introduced): Redundant returns documentation
    def build(self, input_shape: tuple):
        """
        :param input_shape: tuple with the input shape
        :return: None
        """

        super().build(input_shape=input_shape)

        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    # Scrutinizer (introduced): "field" missing in parameter type documentation
    def interpolate(self, field) -> tf.Tensor:
        """
        :param field: tf.Tensor with shape=number_of_control_points_per_dim
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    # Scrutinizer (introduced): "inputs" missing in parameter type documentation
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor, the high resolution field cropped to output_shape
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]
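Along each axis, the filter built in BSplines3DTransform.build consists of four segments of cp_spacing samples taken from the cubic B-spline basis functions b[0] to b[3], and call() crops the transposed-convolution output starting at 3 * cp_spacing. A 1D NumPy sketch (bspline_basis and bspline_interpolate_1d are hypothetical helpers) illustrates why: the basis functions sum to one, and from offset 3 * c onward every voxel receives contributions from four control points, so a constant field is reproduced exactly and a linear field stays linear. The sketch models the strided transposed convolution as a scatter-add; the kernel is symmetric, so whether tf.nn.conv3d_transpose flips it would not change the result.

```python
import numpy as np


def bspline_basis(u):
    """The four cubic B-spline basis functions used in BSplines3DTransform.build."""
    return [
        (1 - u) ** 3 / 6,
        (3 * u ** 3 - 6 * u ** 2 + 4) / 6,
        (-3 * u ** 3 + 3 * u ** 2 + 3 * u + 1) / 6,
        u ** 3 / 6,
    ]


def bspline_interpolate_1d(field, cp_spacing, output_length):
    """1D sketch of interpolate() followed by the crop in call().

    The stride-`cp_spacing` transposed convolution is modelled as a scatter-add:
    control point i adds field[i] * kernel to full[i * c : i * c + 4 * c].
    """
    c = cp_spacing
    u = 1 - np.arange(1 / (2 * c), 1, 1 / c)  # same sampling offsets as build()
    kernel = np.concatenate(bspline_basis(u))  # length 4 * c, segments b0..b3
    full = np.zeros((len(field) - 1) * c + 4 * c)  # length formula from interpolate()
    for i, value in enumerate(field):
        full[i * c : i * c + 4 * c] += value * kernel
    start = 3 * c  # from here on, every voxel is covered by four kernel segments
    return full[start : start + output_length]


# Partition of unity: the four basis functions sum to one on [0, 1].
u = np.linspace(0.0, 1.0, 11)
print(np.allclose(sum(bspline_basis(u)), 1.0))  # True

# 10 voxels with spacing 2 use ceil(10 / 2) + 3 = 8 control points.
ones = bspline_interpolate_1d(np.ones(8), cp_spacing=2, output_length=10)
ramp = bspline_interpolate_1d(np.arange(8.0), cp_spacing=2, output_length=10)
print(np.allclose(ones, 1.0))  # True: a constant field is reproduced
print(np.allclose(np.diff(ramp), 0.5))  # True: a linear field stays linear
```

Cropping earlier than 3 * c would include voxels that receive fewer than four contributions, where the constant field would no longer come out as one.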