Passed
Pull Request: main (#656), by Yunguan, created 02:37

deepreg.model.layer (rating: C)

Complexity

Total Complexity: 53

Size/Duplication

Total Lines: 725
Duplicated Lines: 0 %

Importance

Changes: 0

Metric   Value
wmc      53
eloc     352
dl       0
loc      725
rs       6.96
c        0
b        0
f        0

34 Methods

Rating   Name                                    Duplication   Size   Complexity
A        DownSampleResnetBlock.__init__()        0             34     3
A        Deconv3dBlock.__init__()                0             33     1
A        UpSampleResnetBlock.call()              0             23     2
A        Residual3dBlock.__init__()              0             34     1
A        UpSampleResnetBlock.build()             0             8      1
A        Conv3dBlock.call()                      0             11     1
A        Deconv3d.call()                         0             2      1
A        UpSampleResnetBlock.__init__()          0             18     1
A        Deconv3dBlock.call()                    0             11     1
A        Conv3dBlock.__init__()                  0             30     1
A        Deconv3d.__init__()                     0             34     1
A        Deconv3d.build()                        0             36     4
A        DownSampleResnetBlock.call()            0             18     2
A        Warping.__init__()                      0             18     1
A        Conv3dWithResize.call()                 0             9      1
A        Conv3dWithResize.__init__()             0             28     1
A        Residual3dBlock.call()                  0             13     1
A        AdditiveUpSampling.call()               0             14     2
A        LocalNetResidual3dBlock.__init__()      0             31     1
B        BSplines3DTransform.build()             0             54     7
A        LocalNetUpSampleResnetBlock.__init__()  0             17     1
A        BSplines3DTransform.__init__()          0             11     2
A        ResizeCPTransform.__init__()            0             16     2
A        AdditiveUpSampling.__init__()           0             12     1
A        LocalNetUpSampleResnetBlock.call()      0             15     2
A        ResizeCPTransform.call()                0             5      1
A        BSplines3DTransform.interpolate()       0             17     1
A        ResizeCPTransform.build()               0             9      1
A        IntDVF.__init__()                       0             16     1
A        IntDVF.call()                           0             10     2
A        BSplines3DTransform.call()              0             14     1
A        LocalNetUpSampleResnetBlock.build()     0             12     2
A        LocalNetResidual3dBlock.call()          0             4      1
A        Warping.call()                          0             10     1

How to fix: Complexity

Complex classes like deepreg.model.layer often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

import itertools
# Issue: Missing module docstring

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util as layer_util


class Deconv3d(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        use_bias: bool = True,
        **kwargs,
    ):
        """
        Layer that wraps tfkl.Conv3DTranspose
        and does not require the input shape when initializing.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param use_bias: use bias for Conv3DTranspose or not.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._use_bias = use_bias
        self._kwargs = kwargs
        # init layer variables
        self._output_padding = None
        self._deconv3d = None

    def build(self, input_shape):
        super().build(input_shape)

        if isinstance(self._kernel_size, int):
            self._kernel_size = [self._kernel_size] * 3
        if isinstance(self._strides, int):
            self._strides = [self._strides] * 3

        if self._output_shape is not None:
            # pylint: disable-next=line-too-long
            """
            https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185
            When the output shape is defined, the padding should be calculated manually
            if padding == 'same':
                pad = filter_size // 2
                length = ((input_length - 1) * stride + filter_size
                         - 2 * pad + output_padding)
            """
            # Issue (Unused Code): This string statement has no effect and could be removed.
            self._padding = "same"
            self._output_padding = [
                self._output_shape[i]
                - (
                    (input_shape[1 + i] - 1) * self._strides[i]
                    + self._kernel_size[i]
                    - 2 * (self._kernel_size[i] // 2)
                )
                for i in range(3)
            ]
        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=self._output_padding,
            use_bias=self._use_bias,
            **self._kwargs,
        )

    def call(self, inputs, **kwargs):
        return self._deconv3d(inputs=inputs)


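The output-padding arithmetic in Deconv3d.build can be checked without TensorFlow. Below is a plain-Python restatement, assuming 'same' padding and dilation 1; `deconv_output_padding` is a hypothetical helper, not part of DeepReg, and the `0 <= padding < stride` check is added here to mirror the constraint Keras places on `output_padding`.

```python
from typing import List, Sequence, Union


def deconv_output_padding(
    input_spatial: Sequence[int],
    output_shape: Sequence[int],
    kernel_size: Union[int, Sequence[int]] = 3,
    strides: Union[int, Sequence[int]] = 1,
) -> List[int]:
    """Per-axis output_padding that makes a 'same' transposed conv hit output_shape."""
    # Broadcast scalars to the three spatial axes, as Deconv3d.build does.
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size] * 3
    if isinstance(strides, int):
        strides = [strides] * 3
    paddings = []
    for n_in, n_out, k, s in zip(input_spatial, output_shape, kernel_size, strides):
        # Length of a 'same' transposed conv before any output padding is added.
        base = (n_in - 1) * s + k - 2 * (k // 2)
        pad = n_out - base
        # Keras rejects output_padding >= stride; a negative value is unreachable too.
        if not 0 <= pad < s:
            raise ValueError(f"cannot reach {n_out} from {n_in} with stride {s}")
        paddings.append(pad)
    return paddings


# Up-sampling an 8^3 feature map to 16^3 with kernel 3 and stride 2.
print(deconv_output_padding((8, 8, 8), (16, 16, 16), kernel_size=3, strides=2))  # [1, 1, 1]
```

The padding is computed per axis because the three spatial sizes of a medical volume rarely match.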
class Conv3dBlock(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        activation: str = "relu",
        **kwargs,
    ):
        """
        A conv3d block having conv3d - norm - activation.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: str, same or valid
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Issue (Bug): Parameters differ from overridden 'call' method
        # Issue: "inputs, training" missing in parameter type documentation
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        output = self._conv3d(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output


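Conv3dBlock's docstring gives the spatial output shape as unchanged, which holds for the default strides=1 with 'same' padding. For other settings, the usual Keras length rule applies; below is a plain-Python sketch of it, assuming dilation 1 (`conv_output_length` is a hypothetical helper name). The same rule also covers pooling windows.

```python
def conv_output_length(n: int, kernel_size: int, stride: int, padding: str) -> int:
    """Spatial length after a convolution or pooling window (dilation 1 assumed)."""
    if padding == "same":
        length = n  # 'same' pads so that stride 1 keeps the length
    elif padding == "valid":
        length = n - kernel_size + 1  # only windows that fit entirely inside
    else:
        raise ValueError(f"unknown padding: {padding!r}")
    return (length + stride - 1) // stride  # ceil(length / stride)


# Conv3dBlock defaults (kernel 3, stride 1, 'same') keep each dimension unchanged.
print(conv_output_length(9, 3, 1, "same"))   # 9
print(conv_output_length(9, 3, 2, "same"))   # 5
print(conv_output_length(9, 3, 2, "valid"))  # 4
```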
class Deconv3dBlock(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        padding: str = "same",
        activation: str = "relu",
        **kwargs,
    ):
        """
        A deconv3d block having deconv3d - norm - activation.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: str, same or valid
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._deconv3d = Deconv3d(
            filters=filters,
            output_shape=output_shape,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Issue (Bug): Parameters differ from overridden 'call' method
        # Issue: "inputs, training" missing in parameter type documentation
        # Issue: Missing return documentation
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return output: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        output = self._deconv3d(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output


class Residual3dBlock(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block.

        1. conved = conv3d(conv3d_block(inputs))
        2. out = act(norm(conved) + inputs)

        :param filters: int, number of filters in the convolutional layers
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, strides=strides
        )
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Issue (Bug): Parameters differ from overridden 'call' method
        # Issue: "inputs, training" missing in parameter type documentation
        # Issue: Missing return documentation
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return output: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        return self._act(
            self._norm(
                inputs=self._conv3d(inputs=self._conv3d_block(inputs)),
                training=training,
            )
            + inputs
        )


class DownSampleResnetBlock(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        pooling: bool = True,
        **kwargs,
    ):
        """
        A down-sampling resnet conv3d block, with max-pooling or conv3d.

        1. conved = conv3d_block(inputs)  # adjust channel
        2. skip = residual_block(conved)  # develop feature
        3. pooled = pool(skip)  # down-sample

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param pooling: if True, use max pooling to downsample, otherwise use conv.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._pooling = pooling
        # init layer variables
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=kernel_size)
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)
        self._max_pool3d = (
            tfkl.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding="same")
            if pooling
            else None
        )
        self._conv3d_block3 = (
            None
            if pooling
            else Conv3dBlock(filters=filters, kernel_size=kernel_size, strides=2)
        )

    def call(self, inputs, training=None, **kwargs) -> (tf.Tensor, tf.Tensor):
        # Issue (Bug): Parameters differ from overridden 'call' method
        # Issue: "inputs, training" missing in parameter type documentation
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: (pooled, skip)

          - downsampled, shape = (batch, ceil(in_dim1/2), ceil(in_dim2/2), ceil(in_dim3/2), channels)
          - skipped, shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        conved = self._conv3d_block(inputs=inputs, training=training)  # adjust channel
        skip = self._residual_block(inputs=conved, training=training)  # develop feature
        pooled = (
            self._max_pool3d(inputs=skip)
            if self._pooling
            else self._conv3d_block3(inputs=skip, training=training)
            # Issue (Bug): self._conv3d_block3 does not seem to be callable.
        )  # downsample
        return pooled, skip


class UpSampleResnetBlock(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        # Issue: "concat, filters, kernel_size" missing in parameter type documentation
        """
        An up-sampling resnet conv3d block, with deconv3d.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool, specifies how to combine the input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._concat = concat
        # init layer variables
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=kernel_size)
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)

    def build(self, input_shape):
        # Issue: "input_shape" missing in parameter type documentation
        """
        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        skip_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters, output_shape=skip_shape, strides=2
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Issue (Bug): Parameters differ from overridden 'call' method
        # Issue: "inputs, training" missing in parameter type documentation
        r"""
        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        up_sampled, skip = inputs[0], inputs[1]
        up_sampled = self._deconv3d_block(
            inputs=up_sampled, training=training
        )  # up sample and change channel
        up_sampled = (
            tf.concat([up_sampled, skip], axis=4) if self._concat else up_sampled + skip
        )  # combine
        up_sampled = self._conv3d_block(
            inputs=up_sampled, training=training
        )  # adjust channel
        up_sampled = self._residual_block(inputs=up_sampled, training=training)  # conv
        return up_sampled


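UpSampleResnetBlock builds its Deconv3dBlock with `output_shape` set to the skip tensor's spatial shape, which is what lets odd sizes round-trip through DownSampleResnetBlock. A plain-Python check of that, assuming kernel size 3, stride 2, and that the skip tensor is the one DownSampleResnetBlock returned (both helper names are hypothetical):

```python
import math


def downsampled_length(n: int) -> int:
    # MaxPool3D(pool_size=2, strides=2, padding="same") and a stride-2 'same'
    # Conv3dBlock both give ceil(n / 2) per axis.
    return math.ceil(n / 2)


def upsample_output_padding(skip: int, kernel_size: int = 3, stride: int = 2) -> int:
    """output_padding Deconv3d would derive when output_shape is the skip length."""
    down = downsampled_length(skip)
    base = (down - 1) * stride + kernel_size - 2 * (kernel_size // 2)
    return skip - base


for n in (5, 6, 7, 8):
    print(n, downsampled_length(n), upsample_output_padding(n))
# 5 3 0
# 6 3 1
# 7 4 0
# 8 4 1
```

The derived padding is always 0 or 1, below the stride of 2, so it is valid for any skip size. Without an explicit `output_shape`, a 'same' stride-2 transposed convolution would return 2 * ceil(n / 2), overshooting odd skip sizes by one.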
class Conv3dWithResize(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(
        self,
        output_shape: tuple,
        filters: int,
        kernel_initializer: str = "glorot_uniform",
        activation: (str, None) = None,
        **kwargs,
    ):
        """
        A layer consisting of conv3d followed by resize3d.

        :param output_shape: tuple, (out_dim1, out_dim2, out_dim3)
        :param filters: int, number of channels of the output
        :param kernel_initializer: str, defines the initialization method
        :param activation: str, defines the activation function
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._output_shape = output_shape
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding="same",
            kernel_initializer=kernel_initializer,
            activation=activation,
        )  # if not zero-initialized, the initial DDF may be too large

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Issue: "inputs" missing in parameter type documentation
        """
        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
        """
        output = self._conv3d(inputs=inputs)
        output = layer_util.resize3d(image=output, size=self._output_shape)
        return output


class Warping(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(self, fixed_image_size: tuple, **kwargs):
        """
        A layer that warps an image using a DDF.

        Reference:

        - transform of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

          where vol = image, loc_shift = ddf

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]  # shape = (1, f_dim1, f_dim2, f_dim3, 3)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Issue: "inputs" missing in parameter type documentation
        """
        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
          - image, shape = (batch, m_dim1, m_dim2, m_dim3), dtype = float32
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        return layer_util.resample(vol=inputs[1], loc=self.grid_ref + inputs[0])


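Warping.call samples the moving image at `grid_ref + ddf`. Below is a 1-D, pure-Python analogue of that idea, assuming linear interpolation with clamping at the border. DeepReg's `layer_util.resample` may handle boundaries differently, and the function names here are hypothetical.

```python
import math
from typing import List, Sequence


def linear_resample_1d(vol: Sequence[float], loc: Sequence[float]) -> List[float]:
    """Sample vol at fractional positions loc with linear interpolation.

    Out-of-range positions are clamped to the border (an assumption of this sketch).
    """
    last = len(vol) - 1
    out = []
    for x in loc:
        x = min(max(x, 0.0), float(last))
        lo = int(math.floor(x))
        hi = min(lo + 1, last)
        w = x - lo
        out.append((1.0 - w) * vol[lo] + w * vol[hi])
    return out


def warp_1d(image: Sequence[float], ddf: Sequence[float]) -> List[float]:
    # grid_ref is the identity grid 0..n-1; the DDF is added to it before sampling.
    grid_ref = range(len(ddf))
    return linear_resample_1d(image, [g + d for g, d in zip(grid_ref, ddf)])


print(warp_1d([0.0, 10.0, 20.0, 30.0], [0.5, 0.5, 0.5, 0.5]))  # [5.0, 15.0, 25.0, 30.0]
```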
class IntDVF(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(self, fixed_image_size: tuple, num_steps: int = 7, **kwargs):
        """
        Layer calculates DDF from DVF by scaling and squaring.

        Reference:

        - integrate_vec of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._warping = Warping(fixed_image_size=fixed_image_size)
        self._num_steps = num_steps

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Issue: "inputs" missing in parameter type documentation
        """
        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf


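IntDVF divides the velocity field by 2**num_steps and then composes the displacement with itself num_steps times. Below is a 1-D pure-Python sketch of the same loop. It uses a hypothetical linear resampler with border clamping, which is an assumption and not DeepReg's resampler.

```python
import math
from typing import List, Sequence


def warp_1d(image: Sequence[float], ddf: Sequence[float]) -> List[float]:
    """Linear resampling of image at i + ddf[i], clamped to the border (assumption)."""
    last = len(image) - 1
    out = []
    for i, d in enumerate(ddf):
        x = min(max(i + d, 0.0), float(last))
        lo = int(math.floor(x))
        hi = min(lo + 1, last)
        w = x - lo
        out.append((1.0 - w) * image[lo] + w * image[hi])
    return out


def integrate_dvf_1d(dvf: Sequence[float], num_steps: int = 7) -> List[float]:
    # Scaling: start from a small displacement dvf / 2**num_steps.
    ddf = [v / 2 ** num_steps for v in dvf]
    # Squaring: compose the field with itself num_steps times,
    # ddf <- ddf + ddf(x + ddf), mirroring `ddf += warping([ddf, ddf])`.
    for _ in range(num_steps):
        ddf = [d + w for d, w in zip(ddf, warp_1d(ddf, ddf))]
    return ddf


# A constant velocity integrates to (approximately) the same constant displacement.
print(integrate_dvf_1d([1.5] * 6))
```

Each squaring step doubles the integration time, so 7 steps cover 2**7 = 128 small sub-steps at the cost of only 7 warps.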
class AdditiveUpSampling(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(self, output_shape: tuple, stride: (int, list) = 2, **kwargs):
        """
        Layer that up-samples a 3d tensor and reduces channels using split and sum.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: int, 1-D Tensor or list
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._stride = stride
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Issue: "inputs" missing in parameter type documentation
        # Issue: "ValueError" not documented as being raised
        """
        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        """
        if inputs.shape[4] % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        output = layer_util.resize3d(image=inputs, size=self._output_shape)
        # a list of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.split(output, num_or_size_splits=self._stride, axis=4)
        # Issue (Unused Code): Argument 'axis' passed by position and keyword in function call
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.reduce_sum(tf.stack(output, axis=5), axis=5)
        return output


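The channel split and sum in AdditiveUpSampling happen after resize3d and act independently at every voxel, so the reduction can be illustrated on a single voxel's channel vector. A plain-Python sketch (the helper name is hypothetical):

```python
from typing import List, Sequence


def split_and_sum_channels(channels: Sequence[float], stride: int = 2) -> List[float]:
    """Channel reduction of AdditiveUpSampling for a single voxel."""
    if len(channels) % stride != 0:
        raise ValueError("The channel dimension can not be divided by the stride")
    size = len(channels) // stride
    # tf.split(x, num_or_size_splits=stride, axis=4) -> `stride` contiguous chunks.
    chunks = [channels[i * size : (i + 1) * size] for i in range(stride)]
    # tf.stack(..., axis=5) then tf.reduce_sum(..., axis=5) -> element-wise sum.
    return [sum(group) for group in zip(*chunks)]


print(split_and_sum_channels([1, 2, 3, 4], stride=2))        # [4, 6]
print(split_and_sum_channels([1, 2, 3, 4, 5, 6], stride=3))  # [9, 12]
```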
class LocalNetResidual3dBlock(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block, simpler than Residual3dBlock.

        1. conved = conv3d(inputs[0])
        2. out = act(norm(conved) + inputs[1])

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Issue (Bug): Parameters differ from overridden 'call' method
        return self._act(
            self._norm(inputs=self._conv3d(inputs=inputs[0]), training=training)
            + inputs[1]
        )


class LocalNetUpSampleResnetBlock(tfkl.Layer):
    # Issue: Missing class docstring
    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Layer up-samples tensor with two inputs (skipped and down-sampled).

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool, whether to use additive upsampling
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # init layer variables
        self._deconv3d_block = None
        self._additive_upsampling = None
        self._conv3d_block = Conv3dBlock(filters=filters)
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        # Issue: "input_shape" missing in parameter type documentation
        """
        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        output_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters, output_shape=output_shape, strides=2
        )
        if self._use_additive_upsampling:
            self._additive_upsampling = AdditiveUpSampling(output_shape=output_shape)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        # Issue (Bug): Parameters differ from overridden 'call' method
        # Issue: "inputs, training" missing in parameter type documentation
        """
        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, filters),
            where (out_dim1, out_dim2, out_dim3) is the spatial shape of inputs_skip
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        h0 = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            h0 += self._additive_upsampling(inputs=inputs_nonskip)
        r1 = h0 + inputs_skip
        r2 = self._conv3d_block(inputs=h0, training=training)
        h1 = self._residual_block(inputs=[r2, r1], training=training)
        return h1


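LocalNetUpSampleResnetBlock adds several tensors together, so their channel counts have to agree. Reading the code above, and assuming AdditiveUpSampling keeps its default stride of 2, the constraints can be written as a small hypothetical checker. It requires exact equality and ignores TensorFlow's size-1 broadcasting.

```python
def local_net_upsample_channels_ok(
    nonskip_channels: int,
    skip_channels: int,
    filters: int,
    use_additive_upsampling: bool = True,
) -> bool:
    """Whether the additions in LocalNetUpSampleResnetBlock.call line up channel-wise."""
    # Deconv3dBlock maps any input channel count to `filters` (this is h0).
    if use_additive_upsampling:
        # AdditiveUpSampling (stride 2) needs an even count and yields C // 2,
        # which is added to h0.
        if nonskip_channels % 2 != 0 or nonskip_channels // 2 != filters:
            return False
    # r1 = h0 + inputs_skip requires the skip tensor to carry `filters` channels.
    return skip_channels == filters


print(local_net_upsample_channels_ok(64, 32, 32))         # True
print(local_net_upsample_channels_ok(48, 32, 32))         # False
print(local_net_upsample_channels_ok(48, 32, 32, False))  # True
```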
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        :param control_point_spacing: list or int
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None

    def build(self, input_shape):
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = [
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        return layer_util.resize3d(image=output, size=self._output_shape)


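ResizeCPTransform smooths with a Gaussian whose sigma is 0.44 times the control-point spacing (the code comment gives 0.44 = ln(4)/pi, about 0.441). It then resizes to ceil(dim / spacing) + 3 control points per axis. Both numbers can be reproduced in plain Python; `cp_transform_params` is a hypothetical helper.

```python
import math
from typing import List, Sequence, Tuple, Union


def cp_transform_params(
    spatial_shape: Sequence[int], control_point_spacing: Union[int, Sequence[int]]
) -> Tuple[List[float], List[int]]:
    """Gaussian sigmas and control-point grid size as computed by ResizeCPTransform."""
    if isinstance(control_point_spacing, int):
        control_point_spacing = [control_point_spacing] * 3
    sigma = [0.44 * cp for cp in control_point_spacing]  # 0.44 ~ ln(4) / pi
    grid = [math.ceil(v / c) + 3 for v, c in zip(spatial_shape, control_point_spacing)]
    return sigma, grid


sigma, grid = cp_transform_params((64, 64, 60), 5)
print(grid)  # [16, 16, 15]
print(round(math.log(4) / math.pi, 4))  # 0.4413
```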
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline filters.
    It assumes a full sized image from which:
    1. it computes the control point values by downsampling the initial image
    2. performs the interpolation
    3. crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing is used for all dimensions
    :param output_shape: (dim0, dim1, dim2), the spatial shape of the high
        resolution deformation field.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):

        super().__init__(**kwargs)

        self.filters = []
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        # Issue: Redundant returns documentation
        """
        :param input_shape: tuple with the input shape
        :return: None
        """

        super().build(input_shape=input_shape)

        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        # Issue: "field" missing in parameter type documentation
        """
        :param field: tf.Tensor with shape=number_of_control_points_per_dim
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # Issue: "inputs" missing in parameter type documentation
        """
        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor of shape (batch, dim0, dim1, dim2, 3),
            cropped to output_shape
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]
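Two properties of BSplines3DTransform can be checked in plain Python. First, the four cubic basis pieces in `build` sum to 1 at every sampled offset (partition of unity). Second, the crop window `[3c, 3c + v)` used by `call` fits inside the interpolated field. The second check assumes the control-point grid comes from ResizeCPTransform's `ceil(v / c) + 3` rule, which the listing does not itself connect. All names below are hypothetical.

```python
import math
from typing import Callable, Dict, List

# Cubic B-spline basis pieces as written in BSplines3DTransform.build (without numpy).
BASIS: Dict[int, Callable[[float], float]] = {
    0: lambda u: (1 - u) ** 3 / 6,
    1: lambda u: (3 * u ** 3 - 6 * u ** 2 + 4) / 6,
    2: lambda u: (-3 * u ** 3 + 3 * u ** 2 + 3 * u + 1) / 6,
    3: lambda u: u ** 3 / 6,
}


def sample_offsets(spacing: int) -> List[float]:
    # Same offsets (up to float rounding) as 1 - np.arange(1 / (2 * spacing), 1, 1 / spacing).
    return [1 - (k + 0.5) / spacing for k in range(spacing)]


def interpolated_length(num_cp: int, spacing: int) -> int:
    # Per-axis length produced by interpolate(): (a - 1) * b + 4 * b.
    return (num_cp - 1) * spacing + 4 * spacing


# The four weights sum to 1 at every offset.
print([round(sum(BASIS[i](u) for i in range(4)), 12) for u in sample_offsets(4)])
# [1.0, 1.0, 1.0, 1.0]

# With a grid of ceil(v / c) + 3 control points, the crop [3c, 3c + v) fits.
v, c = 60, 5
num_cp = math.ceil(v / c) + 3
print(interpolated_length(num_cp, c), 3 * c + v)  # 90 75
```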