Passed
Pull Request — main (#656)
by Yunguan
02:36
created

deepreg.model.layer.NormBlock.get_config()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""This module defines custom layers."""
2
import itertools
3
4
import numpy as np
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
import deepreg.model.layer_util as layer_util
9
10
11
class Deconv3d(tfkl.Layer):
    """
    Transposed 3D convolution with output padding derived from the input shape.

    The inner Conv3DTranspose is created in build(), once the input shape is
    known, so that a requested spatial output shape can be translated into the
    corresponding ``output_padding`` argument.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        name="deconv3d",
        **kwargs,
    ):
        """
        Record the arguments; the wrapped layer itself is created in build().

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3), or None to pass
            no explicit output padding to Conv3DTranspose
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param name: name of the layer.
        :param kwargs: additional arguments for Conv3DTranspose.
        """
        super().__init__(name=name)
        # keep the raw arguments, they are needed in build() and get_config()
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._kwargs = kwargs
        # placeholder for the wrapped layer, assigned in build()
        self._deconv3d = None

    def build(self, input_shape):
        # pylint: disable-next=line-too-long
        """
        Create the Conv3DTranspose, computing the output padding if required.

        The padding follows the transposed-convolution length rule in
        https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185

        :param input_shape: shape of input, spatial sizes are read from
            indices 1 to 3
        """
        super().build(input_shape)

        def _per_axis(value):
            # a single int applies to all three spatial axes
            return [value] * 3 if isinstance(value, int) else value

        kernels = _per_axis(self._kernel_size)
        steps = _per_axis(self._strides)

        if self._output_shape is None:
            output_padding = None
        else:
            assert self._padding == "same"
            output_padding = []
            for axis in range(3):
                # output length obtained without any output padding
                unpadded = (
                    (input_shape[1 + axis] - 1) * steps[axis]
                    + kernels[axis]
                    - 2 * (kernels[axis] // 2)
                )
                output_padding.append(self._output_shape[axis] - unpadded)

        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=output_padding,
            **self._kwargs,
        )

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Apply the wrapped transposed convolution.

        :param inputs: input tensor.
        :param kwargs: additional arguments, not used.
        :return: output of the wrapped Conv3DTranspose.
        """
        deconv = self._deconv3d
        return deconv(inputs=inputs)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            {
                "filters": self._filters,
                "output_shape": self._output_shape,
                "kernel_size": self._kernel_size,
                "strides": self._strides,
                "padding": self._padding,
            }
        )
        # extra Conv3DTranspose arguments are stored flat, next to the others
        config.update(self._kwargs)
        return config
114
115
116
class NormBlock(tfkl.Layer):
    """
    A block applying a wrapped layer, then normalization, then activation.

    The wrapped layer and the normalization layer are selected by name from
    ``layer_cls_dict`` and ``norm_cls_dict``.
    """

    layer_cls_dict = {"conv3d": tfkl.Conv3D, "deconv3d": Deconv3d}
    norm_cls_dict = {
        "batch": tfkl.BatchNormalization,
        "layer": tfkl.LayerNormalization,
    }

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Build the three sub-layers.

        :param layer_name: key of the layer to be wrapped, see layer_cls_dict.
        :param norm_name: key of the normalization layer, see norm_cls_dict.
        :param activation: name of activation.
        :param name: name of the block layer. It is recorded in the config
            only; it is not forwarded to the base Layer constructor.
        :param kwargs: additional arguments, forwarded to the wrapped layer.
        """
        super().__init__()
        self._config = {
            "layer_name": layer_name,
            "norm_name": norm_name,
            "activation": activation,
            "name": name,
            **kwargs,
        }
        layer_cls = self.layer_cls_dict[layer_name]
        norm_cls = self.norm_cls_dict[norm_name]
        # the bias is disabled on the wrapped layer, a normalization follows
        self._layer = layer_cls(use_bias=False, **kwargs)
        self._norm = norm_cls()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Run layer, normalization and activation in sequence.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: activated output tensor.
        """
        features = self._layer(inputs=inputs)
        normalized = self._norm(inputs=features, training=training)
        return self._act(normalized)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        base_config = super().get_config()
        base_config.update(self._config)
        return base_config
172
173
174
class Conv3dBlock(NormBlock):
    """
    NormBlock specialised to conv3d: conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Create a NormBlock whose wrapped layer is "conv3d".

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to NormBlock.
        """
        super().__init__(name=name, layer_name="conv3d", **kwargs)
191
192
193
class Deconv3dBlock(NormBlock):
    """
    NormBlock specialised to deconv3d: deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Create a NormBlock whose wrapped layer is "deconv3d".

        :param name: name of the layer
        :param kwargs: additional arguments, forwarded to NormBlock.
        """
        super().__init__(name=name, layer_name="deconv3d", **kwargs)
210
211
212
class Residual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block.

    1. conved = conv3d(conv3d_block(inputs))
    2. out = act(norm(conved) + inputs)
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        Create the sub-layers of the residual block.

        :param filters: int, number of filters in the convolutional layers
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d_block = Conv3dBlock(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
        )
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Apply the residual block.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        # The inner block contains a normalization layer as well, so the
        # training flag is forwarded explicitly, like in the other blocks.
        hidden = self._conv3d_block(inputs=inputs, training=training)
        hidden = self._conv3d(inputs=hidden)
        hidden = self._norm(inputs=hidden, training=training)
        # skip connection is added before the final activation
        return self._act(hidden + inputs)
265
266
267
class DownSampleResnetBlock(tfkl.Layer):
    """
    A down-sampling resnet conv3d block, with max-pooling or conv3d.

    1. conved = conv3d_block(inputs)  # adjust channel
    2. skip = residual_block(conved)  # develop feature
    3. pooled = pool(skip) # down-sample
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        pooling: bool = True,
        **kwargs,
    ):
        """
        Create the sub-layers; only one of the two down-samplers is built.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param pooling: if True, use max pooling to downsample, otherwise use conv.
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._pooling = pooling
        # init layer variables
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)
        if pooling:
            self._max_pool3d = tfkl.MaxPool3D(
                pool_size=(2, 2, 2), strides=(2, 2, 2), padding="same"
            )
            self._conv3d_block3 = None
        else:
            self._max_pool3d = None
            # strided convolution block used instead of pooling
            self._conv3d_block3 = NormBlock(
                layer_name="conv3d",
                filters=filters,
                kernel_size=kernel_size,
                strides=2,
                padding="same",
            )

    def call(self, inputs, training=None, **kwargs) -> (tf.Tensor, tf.Tensor):
        """
        Compute the skip features and their down-sampled version.

        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: (pooled, skip)

          - downsampled, shape = (batch, in_dim1//2, in_dim2//2, in_dim3//2, channels)
          - skipped, shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        # adjust channel
        conved = self._conv3d_block(inputs=inputs, training=training)
        # develop feature
        skip = self._residual_block(inputs=conved, training=training)
        # downsample
        if self._pooling:
            pooled = self._max_pool3d(inputs=skip)
        else:
            pooled = self._conv3d_block3(inputs=skip, training=training)
        return pooled, skip
330
331
332
class UpSampleResnetBlock(tfkl.Layer):
    """
    An up-sampling resnet conv3d block, with deconv3d.

    The deconvolution block is created in build(), because its output shape
    is taken from the shape of the skip input.
    """

    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        Create the shape-independent sub-layers.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool,specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._concat = concat
        # init layer variables, the deconv block is assigned in build()
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)

    def build(self, input_shape):
        """
        Create the deconvolution block targeting the skip tensor's size.

        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        skip_image_shape = input_shape[1]
        # the deconv block always uses kernel size 3 and stride 2
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=skip_image_shape[1:4],
            kernel_size=3,
            strides=2,
            padding="same",
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        Up-sample the first input and merge it with the skip input.

        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        down_sampled, skip = inputs[0], inputs[1]
        # up sample and change channel
        up_sampled = self._deconv3d_block(inputs=down_sampled, training=training)
        # combine
        if self._concat:
            merged = tf.concat([up_sampled, skip], axis=4)
        else:
            merged = up_sampled + skip
        # adjust channel
        adjusted = self._conv3d_block(inputs=merged, training=training)
        # conv
        return self._residual_block(inputs=adjusted, training=training)
391
392
393
class Conv3dWithResize(tfkl.Layer):
    """
    A layer contains conv3d - resize3d.
    """

    def __init__(
        self,
        output_shape: tuple,
        filters: int,
        kernel_initializer: str = "glorot_uniform",
        activation: (str, None) = None,
        **kwargs,
    ):
        """
        Create the convolution and remember the target size.

        :param output_shape: tuple, (out_dim1, out_dim2, out_dim3)
        :param filters: int, number of channels of the output
        :param kernel_initializer: str, defines the initialization method
        :param activation: str, defines the activation function
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._output_shape = output_shape
        # init layer variables
        # if not zero, with init NN, ddf may be too large
        conv_args = dict(
            filters=filters,
            kernel_size=3,
            strides=1,
            padding="same",
            kernel_initializer=kernel_initializer,
            activation=activation,
        )
        self._conv3d = tfkl.Conv3D(**conv_args)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Convolve the input and resize the result to the stored output shape.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments, not used.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
        """
        conved = self._conv3d(inputs=inputs)
        return layer_util.resize3d(image=conved, size=self._output_shape)
433
434
435
class Warping(tfkl.Layer):
    """
    A layer warps an image using DDF.

    Reference:

    - transform of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

      where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, **kwargs):
        """
        Pre-compute the reference grid of the fixed image.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)
        reference_grid = layer_util.get_reference_grid(grid_size=fixed_image_size)
        # add a leading batch axis, shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = reference_grid[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resample the image at the reference grid displaced by the DDF.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
          - image, shape = (batch, m_dim1, m_dim2, m_dim3), dtype = float32
        :param kwargs: additional arguments, not used.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        image = inputs[1]
        sample_locations = self.grid_ref + inputs[0]
        return layer_util.resample(vol=image, loc=sample_locations)
466
467
468
class IntDVF(tfkl.Layer):
    """
    Layer calculates DDF from DVF.

    The input field is divided by 2 ** num_steps and then, num_steps times,
    the field warped by itself is added to it.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(self, fixed_image_size: tuple, num_steps: int = 7, **kwargs):
        """
        Create the warping layer used for the self-composition.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)
        self._warping = Warping(fixed_image_size=fixed_image_size)
        self._num_steps = num_steps

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Integrate the velocity field into a displacement field.

        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), type = float32
        :param kwargs: additional arguments, not used.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        num_steps = self._num_steps
        scale = 2 ** num_steps
        ddf = inputs / scale
        step = 0
        while step < num_steps:
            # the field is used both as displacement and as warped image
            ddf += self._warping(inputs=[ddf, ddf])
            step += 1
        return ddf
496
497
498
class AdditiveUpSampling(tfkl.Layer):
    """
    Layer up-samples 3d tensor and reduce channels using split and sum.
    """

    def __init__(self, output_shape: tuple, stride: (int, list) = 2, **kwargs):
        """
        Store the target size and the channel reduction factor.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: int, 1-D Tensor or list
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._stride = stride
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resize the input, then sum its channel groups.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments, not used.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        :raise ValueError: if channels is not a multiple of stride.
        """
        num_channels = inputs.shape[4]
        if num_channels % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        resized = layer_util.resize3d(image=inputs, size=self._output_shape)
        # a list of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        groups = tf.split(resized, num_or_size_splits=self._stride, axis=4)
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride, stride)
        stacked = tf.stack(groups, axis=5)
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        return tf.reduce_sum(stacked, axis=5)
526
527
528
class LocalNetResidual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block, simpler than Residual3dBlock.

    It takes two inputs: the first one is convolved and normalized, the
    second one is added to the result before the activation.
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        Create the convolution, normalization and activation layers.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)
        # init layer variables
        conv_args = dict(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._conv3d = tfkl.Conv3D(**conv_args)
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Compute act(norm(conv3d(inputs[0])) + inputs[1]).

        :param inputs: a pair, the tensor to convolve and the tensor to add
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: activated sum of the two branches.
        """
        conved = self._conv3d(inputs=inputs[0])
        normalized = self._norm(inputs=conved, training=training)
        return self._act(normalized + inputs[1])
566
567
568
class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """
    Layer up-samples tensor with two inputs (skipped and down-sampled).

    The shape-dependent sub-layers are created in build().
    """

    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Create the shape-independent sub-layers.

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool to used additive upsampling
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # init layer variables, the first two are assigned in build()
        self._deconv3d_block = None
        self._additive_upsampling = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        """
        Create the up-sampling layers targeting the skip tensor's size.

        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        skip_tensor_shape = input_shape[1]
        target_size = skip_tensor_shape[1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=target_size,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        if not self._use_additive_upsampling:
            return
        self._additive_upsampling = AdditiveUpSampling(output_shape=target_size)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Up-sample the non-skip input and fuse it with the skip input.

        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments, not used.
        :return: output of the residual block.
        """
        inputs_nonskip = inputs[0]
        inputs_skip = inputs[1]
        upsampled = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            upsampled += self._additive_upsampling(inputs=inputs_nonskip)
        # the residual branch carries the skip input, the conv branch does not
        residual = upsampled + inputs_skip
        conved = self._conv3d_block(inputs=upsampled, training=training)
        return self._residual_block(inputs=[conved, residual], training=training)
619
620
621
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of a image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        Derive the Gaussian sigma of each axis from the control point spacing.

        :param control_point_spacing: list or int
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)

        spacing = control_point_spacing
        if isinstance(spacing, int):
            # the same spacing is used for the three axes
            spacing = [spacing] * 3

        # 0.44 = ln(4)/pi
        self.kernel_sigma = [0.44 * cp for cp in spacing]
        self.cp_spacing = spacing
        # both are assigned in build()
        self.kernel = None
        self._output_shape = None

    def build(self, input_shape):
        """
        Create the Gaussian kernel and compute the control point grid size.

        :param input_shape: shape of input, the entries between the first and
            the last one are treated as spatial sizes
        """
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        grid_size = []
        for image_dim, spacing in zip(input_shape[1:-1], self.cp_spacing):
            # ceil(dim / spacing) control points, plus 3 extra ones
            num_points = tf.math.ceil(image_dim / spacing) + 3
            grid_size.append(tf.cast(num_points, tf.int32))
        self._output_shape = grid_size

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Smooth the input with the Gaussian kernel, then resize it.

        :param inputs: tensor to be filtered and resized
        :param kwargs: additional arguments, not used.
        :return: tensor resized to the control point grid size.
        """
        smoothed = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        return layer_util.resize3d(image=smoothed, size=self._output_shape)
659
660
661
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline filters.

    Given a field of control point values, it

    1. interpolates the field with a transposed convolution whose kernel
       holds the cubic B-spline weights,
    2. crops the interpolated field to the requested output size.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):
        """
        Store the control point spacing and the size of the cropped output.

        :param cp_spacing: int or tuple of three ints specifying the spacing
            (in pixels) in each dimension. When a single int is used,
            the same spacing to all dimensions is used
        :param output_shape: the first three entries are used as the spatial
            sizes (dim0, dim1, dim2) of the cropped high resolution field.
        :param kwargs: additional arguments, forwarded to the base Layer.
        """
        super().__init__(**kwargs)

        # NOTE: `filters` is not read anywhere in this class; it is kept only
        # so that external code accessing the attribute keeps working.
        self.filters = []
        # interpolation kernel, assigned in build()
        self.filter = None
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        Precompute the cubic B-spline kernel of the transposed convolution.

        The kernel has shape (4*s0, 4*s1, 4*s2, 3, 3), s being the control
        point spacing, and only its channel-diagonal entries are non-zero,
        i.e. the three components of the field are interpolated independently.

        :param input_shape: tuple with the input shape
        """
        super().build(input_shape=input_shape)

        # the four cubic B-spline basis functions
        basis = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        spacing = [self.cp_spacing[axis] for axis in range(3)]
        kernel = np.zeros(
            (4 * spacing[0], 4 * spacing[1], 4 * spacing[2], 3, 3),
            dtype=np.float32,
        )

        # basis values per axis, evaluated once at the sample offsets
        # 1 - arange(1 / (2 * s), 1, 1 / s) of one control point interval
        basis_values = []
        for s in spacing:
            offsets = 1 - np.arange(1 / (2 * s), 1, 1 / s)
            basis_values.append([basis[idx](offsets) for idx in range(4)])

        for i, j, k in itertools.product(range(4), repeat=3):
            # separable weights of the (i, j, k) cell of the kernel
            weights = (
                basis_values[0][i][:, None, None]
                * basis_values[1][j][None, :, None]
                * basis_values[2][k][None, None, :]
            )
            cell = (
                slice(i * spacing[0], (i + 1) * spacing[0]),
                slice(j * spacing[1], (j + 1) * spacing[1]),
                slice(k * spacing[2], (k + 1) * spacing[2]),
            )
            # same weights for each component, no mixing between components
            for channel in range(3):
                kernel[cell + (channel, channel)] = weights

        self.filter = tf.convert_to_tensor(kernel)

    def interpolate(self, field) -> tf.Tensor:
        """
        Interpolate the control point field with the precomputed kernel.

        :param field: tf.Tensor with shape=number_of_control_points_per_dim
        :return: interpolated_field: tf.Tensor
        """
        image_shape = tuple(
            (a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)
        )

        batch_size = field.shape[0]
        if batch_size is None:
            # the static batch size is unknown, use the dynamic one
            batch_size = tf.shape(field)[0]

        output_shape = (batch_size,) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Interpolate the input field and crop it to the output size.

        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments, not used.
        :return: interpolated_field: tf.Tensor, whose spatial sizes are the
            first three entries of the output shape given at construction
        """
        high_res_field = self.interpolate(inputs)

        # the crop starts 3 control point intervals after the border
        start = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            start[0] : start[0] + self._output_shape[0],
            start[1] : start[1] + self._output_shape[1],
            start[2] : start[2] + self._output_shape[2],
        ]
778