Passed
Pull Request: main (#656) by Yunguan, created 02:47

deepreg.model.layer (rating: C)

Complexity

Total Complexity 53

Size/Duplication

Total Lines 727
Duplicated Lines 0 %

Importance

Changes 0

Metric  Value
wmc     53
eloc    353
dl      0
loc     727
rs      6.96
c       0
b       0
f       0

34 Methods

Rating  Name                                    Duplication  Size  Complexity
A       AdditiveUpSampling.call()               0            14    2
A       DownSampleResnetBlock.__init__()        0            34    3
A       Deconv3dBlock.__init__()                0            33    1
A       UpSampleResnetBlock.call()              0            23    2
A       Residual3dBlock.__init__()              0            34    1
A       UpSampleResnetBlock.build()             0            8     1
A       LocalNetResidual3dBlock.__init__()      0            31    1
B       BSplines3DTransform.build()             0            54    7
A       Conv3dBlock.call()                      0            11    1
A       LocalNetUpSampleResnetBlock.__init__()  0            17    1
A       BSplines3DTransform.__init__()          0            11    2
A       Deconv3d.call()                         0            2     1
A       ResizeCPTransform.__init__()            0            16    2
A       AdditiveUpSampling.__init__()           0            12    1
A       UpSampleResnetBlock.__init__()          0            18    1
A       Deconv3dBlock.call()                    0            11    1
A       LocalNetUpSampleResnetBlock.call()      0            15    2
A       ResizeCPTransform.call()                0            5     1
A       BSplines3DTransform.interpolate()       0            17    1
A       Conv3dBlock.__init__()                  0            30    1
A       Deconv3d.__init__()                     0            34    1
A       Deconv3d.build()                        0            36    4
A       DownSampleResnetBlock.call()            0            18    2
A       Warping.__init__()                      0            18    1
A       ResizeCPTransform.build()               0            9     1
A       IntDVF.__init__()                       0            16    1
A       Conv3dWithResize.call()                 0            9     1
A       IntDVF.call()                           0            10    2
A       BSplines3DTransform.call()              0            14    1
A       Conv3dWithResize.__init__()             0            28    1
A       LocalNetUpSampleResnetBlock.build()     0            12    2
A       LocalNetResidual3dBlock.call()          0            4     1
A       Warping.call()                          0            11    1
A       Residual3dBlock.call()                  0            13    1

How to fix: Complexity

Complex classes like deepreg.model.layer often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
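The sketch below is a worked example of the prefix heuristic applied to this report. It groups the 34 entries of the method table above by their class-name prefix and sums the reported complexity for each class. The `(name, complexity)` pairs are copied from the table, and `complexity_by_class` is a hypothetical helper. The per-class totals add up to the reported Total Complexity of 53. `BSplines3DTransform` (11, mostly from `build()`) and `Deconv3d` (6) are the largest contributors.

```python
from collections import defaultdict

# (method, complexity) pairs copied from the "34 Methods" table
METHODS = [
    ("AdditiveUpSampling.call", 2),
    ("DownSampleResnetBlock.__init__", 3),
    ("Deconv3dBlock.__init__", 1),
    ("UpSampleResnetBlock.call", 2),
    ("Residual3dBlock.__init__", 1),
    ("UpSampleResnetBlock.build", 1),
    ("LocalNetResidual3dBlock.__init__", 1),
    ("BSplines3DTransform.build", 7),
    ("Conv3dBlock.call", 1),
    ("LocalNetUpSampleResnetBlock.__init__", 1),
    ("BSplines3DTransform.__init__", 2),
    ("Deconv3d.call", 1),
    ("ResizeCPTransform.__init__", 2),
    ("AdditiveUpSampling.__init__", 1),
    ("UpSampleResnetBlock.__init__", 1),
    ("Deconv3dBlock.call", 1),
    ("LocalNetUpSampleResnetBlock.call", 2),
    ("ResizeCPTransform.call", 1),
    ("BSplines3DTransform.interpolate", 1),
    ("Conv3dBlock.__init__", 1),
    ("Deconv3d.__init__", 1),
    ("Deconv3d.build", 4),
    ("DownSampleResnetBlock.call", 2),
    ("Warping.__init__", 1),
    ("ResizeCPTransform.build", 1),
    ("IntDVF.__init__", 1),
    ("Conv3dWithResize.call", 1),
    ("IntDVF.call", 2),
    ("BSplines3DTransform.call", 1),
    ("Conv3dWithResize.__init__", 1),
    ("LocalNetUpSampleResnetBlock.build", 2),
    ("LocalNetResidual3dBlock.call", 1),
    ("Warping.call", 1),
    ("Residual3dBlock.call", 1),
]


def complexity_by_class(methods):
    """Group "Class.method" names by their class prefix and sum their complexity."""
    totals = defaultdict(int)
    for qualified_name, complexity in methods:
        class_name = qualified_name.split(".", 1)[0]
        totals[class_name] += complexity
    return dict(totals)


totals = complexity_by_class(METHODS)
# print classes from most to least complex
for class_name, total in sorted(totals.items(), key=lambda item: -item[1]):
    print(f"{class_name:<30} {total}")
```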

import itertools
# Scrutinizer: Missing module docstring

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util as layer_util


class Deconv3d(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        use_bias: bool = True,
        **kwargs,
    ):
        """
        Layer that wraps tfkl.Conv3DTranspose
        and does not require the input shape at initialization.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param use_bias: use bias for Conv3DTranspose or not.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._use_bias = use_bias
        self._kwargs = kwargs
        # init layer variables
        self._output_padding = None
        self._deconv3d = None

    def build(self, input_shape):
        super().build(input_shape)

        if isinstance(self._kernel_size, int):
            self._kernel_size = [self._kernel_size] * 3
        if isinstance(self._strides, int):
            self._strides = [self._strides] * 3

        if self._output_shape is not None:
            # pylint: disable-next=line-too-long
            """
            https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185
            When the output shape is defined, the padding should be calculated manually
            if padding == 'same':
                pad = filter_size // 2
                length = ((input_length - 1) * stride + filter_size
                         - 2 * pad + output_padding)
            """
            # Scrutinizer (Unused Code): This string statement has no effect and could be removed.
            self._padding = "same"
            self._output_padding = [
                self._output_shape[i]
                - (
                    (input_shape[1 + i] - 1) * self._strides[i]
                    + self._kernel_size[i]
                    - 2 * (self._kernel_size[i] // 2)
                )
                for i in range(3)
            ]
        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=self._output_padding,
            use_bias=self._use_bias,
            **self._kwargs,
        )

    def call(self, inputs, **kwargs):
        return self._deconv3d(inputs=inputs)

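The per-dimension output padding that `Deconv3d.build` computes can be checked by hand. The plain-Python sketch below reproduces that formula for one dimension with `padding="same"`, which `build` forces whenever `output_shape` is given. `same_output_padding` is a hypothetical helper and is not part of DeepReg. The range check is an added assumption. As we understand it, Keras transposed convolutions reject an output padding that is not smaller than the stride. A requested output length must therefore lie within `stride - 1` of the length obtained with zero output padding.

```python
def same_output_padding(input_length, output_length, kernel_size, stride):
    """Per-dimension output_padding, following the formula in Deconv3d.build."""
    pad = kernel_size // 2
    # length produced by a "same" transposed convolution with zero output padding
    natural_length = (input_length - 1) * stride + kernel_size - 2 * pad
    output_padding = output_length - natural_length
    # assumption: Keras rejects output_padding >= stride; negative values are
    # also treated as invalid in this sketch
    if not 0 <= output_padding < stride:
        raise ValueError(
            f"output length {output_length} is not reachable from input length "
            f"{input_length} with stride {stride}"
        )
    return output_padding


# up-sampling a length-4 tensor back to an even (8) or odd (7) skip length
print(same_output_padding(4, 8, kernel_size=3, stride=2))  # 1
print(same_output_padding(4, 7, kernel_size=3, stride=2))  # 0
```

With stride 2, both parities of the skip-connection length are therefore reachable, differing only in the output padding.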

class Conv3dBlock(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        activation: str = "relu",
        **kwargs,
    ):
        """
        A conv3d block having conv3d - norm - activation.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: str, same or valid
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Scrutinizer (Bug): Parameters differ from overridden 'call' method
        # Scrutinizer: "inputs, training" missing in parameter type documentation
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        output = self._conv3d(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output


class Deconv3dBlock(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        activation: str = "relu",
        **kwargs,
    ):
        """
        A deconv3d block having deconv3d - norm - activation.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: str, same or valid
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._deconv3d = Deconv3d(
            filters=filters,
            output_shape=output_shape,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Scrutinizer (Bug): Parameters differ from overridden 'call' method
        # Scrutinizer: "inputs, training" missing in parameter type documentation
        # Scrutinizer: Missing return documentation
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return output: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        output = self._deconv3d(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output


class Residual3dBlock(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block.

        1. conved = conv3d(conv3d_block(inputs))
        2. out = act(norm(conved) + inputs)

        :param filters: int, number of filters in the convolutional layers
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, strides=strides
        )
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Scrutinizer (Bug): Parameters differ from overridden 'call' method
        # Scrutinizer: "inputs, training" missing in parameter type documentation
        # Scrutinizer: Missing return documentation
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return output: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        return self._act(
            self._norm(
                inputs=self._conv3d(inputs=self._conv3d_block(inputs)),
                training=training,
            )
            + inputs
        )


class DownSampleResnetBlock(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        pooling: bool = True,
        **kwargs,
    ):
        """
        A down-sampling resnet conv3d block, with max-pooling or conv3d.

        1. conved = conv3d_block(inputs)  # adjust channel
        2. skip = residual_block(conved)  # develop feature
        3. pooled = pool(skip)  # down-sample

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param pooling: if True, use max pooling to downsample, otherwise use conv.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._pooling = pooling
        # init layer variables
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=kernel_size)
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)
        self._max_pool3d = (
            tfkl.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding="same")
            if pooling
            else None
        )
        self._conv3d_block3 = (
            None
            if pooling
            else Conv3dBlock(filters=filters, kernel_size=kernel_size, strides=2)
        )

    def call(self, inputs, training=None, **kwargs) -> (tf.Tensor, tf.Tensor):
        # Scrutinizer (Bug): Parameters differ from overridden 'call' method
        # Scrutinizer: "inputs, training" missing in parameter type documentation
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: (pooled, skip)

          - downsampled, shape = (batch, ceil(in_dim1/2), ceil(in_dim2/2), ceil(in_dim3/2), filters)
          - skipped, shape = (batch, in_dim1, in_dim2, in_dim3, filters)
        """
        conved = self._conv3d_block(inputs=inputs, training=training)  # adjust channel
        skip = self._residual_block(inputs=conved, training=training)  # develop feature
        pooled = (
            self._max_pool3d(inputs=skip)
            if self._pooling
            else self._conv3d_block3(inputs=skip, training=training)
            # Scrutinizer (Bug): self._conv3d_block3 does not seem to be callable.
        )  # downsample
        return pooled, skip

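With `padding="same"` and stride 2, TensorFlow pooling and strided convolutions produce `ceil(d / 2)` outputs along an axis of length `d`. That equals `d // 2` only when `d` is even. The NumPy sketch below reproduces the 2x2x2, stride-2 max pooling used when `pooling=True`. It assumes that "same" padding for this window and stride adds one element at the end of each odd-length axis. `max_pool3d_2x_same` is a hypothetical helper.

```python
import numpy as np


def max_pool3d_2x_same(x):
    """2x2x2 max pooling with stride 2 and "same" padding on (batch, d1, d2, d3, ch)."""
    x = np.asarray(x, dtype=float)
    # assumption: "same" padding adds one trailing element to each odd spatial axis;
    # -inf padding never wins the max, so padded cells are ignored
    pads = [(0, 0)] + [(0, d % 2) for d in x.shape[1:4]] + [(0, 0)]
    x = np.pad(x, pads, mode="constant", constant_values=-np.inf)
    b, d1, d2, d3, ch = x.shape
    # split every spatial axis into (d / 2, 2) blocks and take the max over each block
    blocks = x.reshape(b, d1 // 2, 2, d2 // 2, 2, d3 // 2, 2, ch)
    return blocks.max(axis=(2, 4, 6))


x = np.random.default_rng(0).normal(size=(1, 5, 4, 3, 2))
print(max_pool3d_2x_same(x).shape)  # (1, 3, 2, 2, 2)
```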

class UpSampleResnetBlock(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        # Scrutinizer: "concat, filters, kernel_size" missing in parameter type documentation
        """
        An up-sampling resnet conv3d block, with deconv3d.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool, specifies how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._concat = concat
        # init layer variables
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=kernel_size)
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)

    def build(self, input_shape):
        # Scrutinizer: "input_shape" missing in parameter type documentation
        """
        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        skip_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters, output_shape=skip_shape, strides=2
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Scrutinizer (Bug): Parameters differ from overridden 'call' method
        # Scrutinizer: "inputs, training" missing in parameter type documentation
        r"""
        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        up_sampled, skip = inputs[0], inputs[1]
        up_sampled = self._deconv3d_block(
            inputs=up_sampled, training=training
        )  # up sample and change channel
        up_sampled = (
            tf.concat([up_sampled, skip], axis=4) if self._concat else up_sampled + skip
        )  # combine
        up_sampled = self._conv3d_block(
            inputs=up_sampled, training=training
        )  # adjust channel
        up_sampled = self._residual_block(inputs=up_sampled, training=training)  # conv
        return up_sampled


class Conv3dWithResize(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(
        self,
        output_shape: tuple,
        filters: int,
        kernel_initializer: str = "glorot_uniform",
        activation: (str, None) = None,
        **kwargs,
    ):
        """
        A layer consisting of conv3d followed by resize3d.

        :param output_shape: tuple, (out_dim1, out_dim2, out_dim3)
        :param filters: int, number of channels of the output
        :param kernel_initializer: str, defines the initialization method
        :param activation: str, defines the activation function
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._output_shape = output_shape
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding="same",
            kernel_initializer=kernel_initializer,
            activation=activation,
        )  # a non-zero initializer may make the initial network's DDF too large

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Scrutinizer: "inputs" missing in parameter type documentation
        """
        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
        """
        output = self._conv3d(inputs=inputs)
        output = layer_util.resize3d(image=output, size=self._output_shape)
        return output


class Warping(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(self, fixed_image_size: tuple, **kwargs):
        """
        A layer that warps an image using a DDF.

        Reference:

        - transform of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

          where vol = image, loc_shift = ddf

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self.grid_ref = tf.expand_dims(
            layer_util.get_reference_grid(grid_size=fixed_image_size), axis=0
        )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Scrutinizer: "inputs" missing in parameter type documentation
        """
        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
          - image, shape = (batch, m_dim1, m_dim2, m_dim3), dtype = float32
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        return layer_util.warp_image_ddf(
            image=inputs[1], ddf=inputs[0], grid_ref=self.grid_ref
        )

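The NumPy sketch below illustrates what warping with a DDF means. Each output voxel samples the moving image at its reference-grid coordinate plus the displacement, using trilinear interpolation. It is not `layer_util.warp_image_ddf` itself. It rests on three assumptions that may differ from DeepReg's implementation:

- the reference grid holds voxel indices;
- the input is a single unbatched, single-channel volume;
- out-of-range samples are clamped to the border.

`warp_image_ddf_np` is a hypothetical name.

```python
import itertools

import numpy as np


def warp_image_ddf_np(image, ddf):
    """
    Warp a single-channel volume `image` of shape (m1, m2, m3) with a DDF of shape
    (f1, f2, f3, 3): output voxel p is `image` trilinearly sampled at p + ddf[p].
    """
    fixed_shape = ddf.shape[:3]
    # reference grid of voxel indices on the fixed grid, shape (f1, f2, f3, 3)
    axes = [np.arange(n) for n in fixed_shape]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
    loc = grid + ddf
    base = np.floor(loc).astype(int)
    frac = loc - base
    output = np.zeros(fixed_shape)
    # accumulate the 8 corner contributions of each trilinear sample
    for corner in itertools.product((0, 1), repeat=3):
        weight = np.ones(fixed_shape)
        index = []
        for axis, offset in enumerate(corner):
            weight *= frac[..., axis] if offset else 1 - frac[..., axis]
            # assumption: out-of-range neighbours are clamped to the border voxel
            index.append(np.clip(base[..., axis] + offset, 0, image.shape[axis] - 1))
        output += weight * image[tuple(index)]
    return output


image = np.arange(4 * 3 * 2, dtype=float).reshape(4, 3, 2)
shift = np.zeros((4, 3, 2, 3))
shift[..., 0] = 1.0  # sample one voxel further along the first axis
warped = warp_image_ddf_np(image, shift)
print(np.array_equal(warped[:-1], image[1:]))  # True
```

A zero DDF returns the image unchanged. A half-voxel shift averages neighbouring voxels.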

class IntDVF(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(self, fixed_image_size: tuple, num_steps: int = 7, **kwargs):
        """
        Layer that calculates a DDF from a DVF by integration (scaling and squaring).

        Reference:

        - integrate_vec of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._warping = Warping(fixed_image_size=fixed_image_size)
        self._num_steps = num_steps

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Scrutinizer: "inputs" missing in parameter type documentation
        """
        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf

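The loop in `IntDVF.call` follows the scaling-and-squaring scheme. First the velocity field is divided by `2 ** num_steps`, so that it can be treated as a small displacement. The resulting transform is then composed with itself `num_steps` times via `ddf += warp(ddf, ddf)`. The sketch below is a 1-D NumPy analogue. It assumes linear interpolation via `np.interp`, which holds the edge values outside the grid. `warp_1d` and `integrate_dvf_1d` are hypothetical names. For a constant velocity field, each composition exactly doubles the displacement, so the result equals the input field.

```python
import numpy as np


def warp_1d(field, ddf):
    """Sample `field` at x + ddf(x) by linear interpolation (edge values are held)."""
    x = np.arange(field.shape[0], dtype=float)
    return np.interp(x + ddf, x, field)


def integrate_dvf_1d(dvf, num_steps=7):
    """1-D analogue of IntDVF.call: integrate a stationary velocity field into a DDF."""
    ddf = dvf / (2**num_steps)  # scaling: a small step approximates the flow
    for _ in range(num_steps):
        ddf = ddf + warp_1d(ddf, ddf)  # squaring: compose the transform with itself
    return ddf


constant = np.full(16, 0.3)
print(np.allclose(integrate_dvf_1d(constant), constant))  # True
```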

class AdditiveUpSampling(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(self, output_shape: tuple, stride: (int, list) = 2, **kwargs):
        """
        Layer that up-samples a 3D tensor and reduces channels using split and sum.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: int, 1-D Tensor or list
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._stride = stride
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Scrutinizer: "inputs" missing in parameter type documentation
        # Scrutinizer: "ValueError" not documented as being raised
        """
        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        """
        if inputs.shape[4] % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        output = layer_util.resize3d(image=inputs, size=self._output_shape)
        # a list of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.split(output, num_or_size_splits=self._stride, axis=4)
        # Scrutinizer (Unused Code): Argument 'axis' passed by position and keyword in function call
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.reduce_sum(tf.stack(output, axis=5), axis=5)
        return output

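The split-and-sum step of `AdditiveUpSampling` can be reproduced in NumPy. In this sketch, nearest-neighbour repetition stands in for `layer_util.resize3d`. That substitution keeps the example self-contained, and DeepReg's resize may interpolate differently. `additive_upsampling_np` and its `scale` parameter are hypothetical. The channels are split into `stride` contiguous groups and the groups are summed. Output channel `j` is therefore the sum of input channels `j`, `j + C/stride`, `j + 2C/stride`, and so on, where `C` is the input channel count.

```python
import numpy as np


def additive_upsampling_np(inputs, stride=2, scale=2):
    """(batch, d1, d2, d3, C) -> (batch, scale*d1, scale*d2, scale*d3, C // stride)."""
    if inputs.shape[4] % stride != 0:
        raise ValueError("The channel dimension can not be divided by the stride")
    output = inputs
    for axis in (1, 2, 3):  # nearest-neighbour stand-in for layer_util.resize3d
        output = np.repeat(output, scale, axis=axis)
    groups = np.split(output, stride, axis=4)  # `stride` arrays with C // stride channels
    return np.sum(np.stack(groups, axis=5), axis=5)


x = np.arange(1 * 2 * 2 * 2 * 4, dtype=float).reshape(1, 2, 2, 2, 4)
y = additive_upsampling_np(x)
print(y.shape)  # (1, 4, 4, 4, 2)
```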

class LocalNetResidual3dBlock(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block, simpler than Residual3dBlock.

        1. conved = conv3d(inputs[0])
        2. out = act(norm(conved) + inputs[1])

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Scrutinizer (Bug): Parameters differ from overridden 'call' method
        return self._act(
            self._norm(inputs=self._conv3d(inputs=inputs[0]), training=training)
            + inputs[1]
        )


class LocalNetUpSampleResnetBlock(tfkl.Layer):
    # Scrutinizer: Missing class docstring
    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Layer that up-samples a tensor using two inputs (skipped and down-sampled).

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool, whether to use additive upsampling
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # init layer variables
        self._deconv3d_block = None
        self._additive_upsampling = None
        self._conv3d_block = Conv3dBlock(filters=filters)
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        # Scrutinizer: "input_shape" missing in parameter type documentation
        """
        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        output_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters, output_shape=output_shape, strides=2
        )
        if self._use_additive_upsampling:
            self._additive_upsampling = AdditiveUpSampling(output_shape=output_shape)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Scrutinizer (Bug): Parameters differ from overridden 'call' method
        # Scrutinizer: "inputs, training" missing in parameter type documentation
        """
        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, skip_dim1, skip_dim2, skip_dim3, filters)
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        h0 = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            h0 += self._additive_upsampling(inputs=inputs_nonskip)
        r1 = h0 + inputs_skip
        r2 = self._conv3d_block(inputs=h0, training=training)
        h1 = self._residual_block(inputs=[r2, r1], training=training)
        return h1


class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        :param control_point_spacing: list or int
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None

    def build(self, input_shape):
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = [
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        return layer_util.resize3d(image=output, size=self._output_shape)

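Two numbers in `ResizeCPTransform` can be checked by hand:

- The Gaussian width per axis is `0.44 * spacing`, where 0.44 approximates ln(4)/π ≈ 0.4413.
- The control-point grid has `ceil(size / spacing) + 3` points per axis.

The plain-Python sketch below computes both for a hypothetical (64, 64, 60) image with spacing 5. It only reproduces the arithmetic in `__init__` and `build`, not the filtering or resizing. The helper names are hypothetical.

```python
import math


def control_point_grid_shape(image_shape, cp_spacing):
    """Control points per axis, as computed in ResizeCPTransform.build."""
    if isinstance(cp_spacing, int):
        cp_spacing = [cp_spacing] * 3
    return [math.ceil(size / spacing) + 3 for size, spacing in zip(image_shape, cp_spacing)]


def kernel_sigma(cp_spacing):
    """Anti-aliasing Gaussian sigma per axis, as in ResizeCPTransform.__init__."""
    if isinstance(cp_spacing, int):
        cp_spacing = [cp_spacing] * 3
    return [0.44 * spacing for spacing in cp_spacing]


print(math.log(4) / math.pi)  # approximately 0.4413
print(control_point_grid_shape((64, 64, 60), 5))  # [16, 16, 15]
print(kernel_sigma(5))
```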
610
611
class BSplines3DTransform(tfkl.Layer):
612
    """
613
     Layer for BSplines interpolation with precomputed cubic spline filters.
614
     It assumes a full sized image from which:
615
     1. it compute the contol points values by downsampling the initial image
616
     2. performs the interpolation
617
     3. crops the image around the valid values.
618
619
    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
620
        in each dimension. When a single int is used,
621
        the same spacing to all dimensions is used
622
    :param output_shape: (batch_size, dim0, dim1, dim2, 3) of the high resolution
623
        deformation fields.
624
    :param kwargs: additional arguments.
625
    """
626
627
    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):
628
629
        super().__init__(**kwargs)
630
631
        self.filters = []
632
        self._output_shape = output_shape
633
634
        if isinstance(cp_spacing, int):
635
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
636
        else:
637
            self.cp_spacing = cp_spacing
638
639
    def build(self, input_shape: tuple):
0 ignored issues
show
introduced by
Redundant returns documentation
Loading history...
640
        """
641
        :param input_shape: tuple with the input shape
642
        :return: None
643
        """
644
645
        super().build(input_shape=input_shape)
646
647
        b = {
648
            0: lambda u: np.float64((1 - u) ** 3 / 6),
649
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
650
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
651
            3: lambda u: np.float64(u ** 3 / 6),
652
        }
653
654
        filters = np.zeros(
655
            (
656
                4 * self.cp_spacing[0],
657
                4 * self.cp_spacing[1],
658
                4 * self.cp_spacing[2],
659
                3,
660
                3,
661
            ),
662
            dtype=np.float32,
663
        )
664
665
        u_arange = 1 - np.arange(
666
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
667
        )
668
        v_arange = 1 - np.arange(
669
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
670
        )
671
        w_arange = 1 - np.arange(
672
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
673
        )
674
675
        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
676
        filter_coord = list(itertools.product(*filter_idx))
677
678
        for f_idx in filter_coord:
679
            for it_dim in range(3):
680
                filters[
681
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
682
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
683
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
684
                    it_dim,
685
                    it_dim,
686
                ] = (
687
                    b[f_idx[0]](u_arange)[:, None, None]
688
                    * b[f_idx[1]](v_arange)[None, :, None]
689
                    * b[f_idx[2]](w_arange)[None, None, :]
690
                )
691
692
        self.filter = tf.convert_to_tensor(filters)
693
694
    def interpolate(self, field) -> tf.Tensor:
0 ignored issues
show
introduced by
"field" missing in parameter type documentation
Loading history...
695
        """
696
        :param field: tf.Tensor with shape=number_of_control_points_per_dim
697
        :return: interpolated_field: tf.Tensor
698
        """
699
700
        image_shape = tuple(
701
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
702
        )
703
704
        output_shape = (field.shape[0],) + image_shape + (3,)
705
        return tf.nn.conv3d_transpose(
706
            field,
707
            self.filter,
708
            output_shape=output_shape,
709
            strides=self.cp_spacing,
710
            padding="VALID",
711
        )
712
713
    def call(self, inputs, **kwargs) -> tf.Tensor:
0 ignored issues
show
introduced by
"inputs" missing in parameter type documentation
Loading history...
714
        """
715
        :param inputs: tf.Tensor defining a low resolution free-form deformation field
716
        :param kwargs: additional arguments.
717
        :return: interpolated_field: tf.Tensor of shape=self.input_shape
718
        """
719
        high_res_field = self.interpolate(inputs)
720
721
        index = [int(3 * c) for c in self.cp_spacing]
722
        return high_res_field[
723
            :,
724
            index[0] : index[0] + self._output_shape[0],
725
            index[1] : index[1] + self._output_shape[1],
726
            index[2] : index[2] + self._output_shape[2],
727
        ]
728