Passed — Pull Request, main (#656), by Yunguan, created 02:58

deepreg.model.layer.BSplines3DTransform.__init__() — rated A
Complexity: 2 conditions. Size: 11 total lines, 7 code lines. Duplication: 0 lines (0 %). Importance: 0 changes.
Metrics: eloc 7, dl 0, loc 11, rs 10, c 0, b 0, f 0, cc 2, nop 4.
"""This module defines custom layers."""
import itertools

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util as layer_util


class Deconv3d(tfkl.Layer):
    """
    Wrap Conv3DTranspose to allow dynamic output padding calculation.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        name: str = "deconv3d",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param name: name of the layer.
        :param kwargs: additional arguments for Conv3DTranspose.
        """
        super().__init__(name=name)
        # save arguments
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._kwargs = kwargs
        # init layer variables
        self._deconv3d = None

    def build(self, input_shape):
        # pylint: disable-next=line-too-long
        """
        Calculate output padding on the fly.

        https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185

        :param input_shape: shape of input
        """
        super().build(input_shape)

        if isinstance(self._kernel_size, int):
            kernel_size = [self._kernel_size] * 3
        else:
            kernel_size = self._kernel_size

        if isinstance(self._strides, int):
            strides = [self._strides] * 3
        else:
            strides = self._strides

        output_padding = None
        if self._output_shape is not None:
            assert self._padding == "same"
            output_padding = [
                self._output_shape[i]
                - (
                    (input_shape[1 + i] - 1) * strides[i]
                    + kernel_size[i]
                    - 2 * (kernel_size[i] // 2)
                )
                for i in range(3)
            ]
        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=output_padding,
            **self._kwargs,
        )

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: input tensor.
        :param kwargs: additional arguments
        :return: output tensor.
        """
        return self._deconv3d(inputs=inputs)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            dict(
                filters=self._filters,
                output_shape=self._output_shape,
                kernel_size=self._kernel_size,
                strides=self._strides,
                padding=self._padding,
            )
        )
        config.update(self._kwargs)
        return config
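
# Usage sketch (illustrative, not part of the original module): with "same"
# padding, the layer derives output_padding as
# out_dim - ((in_dim - 1) * stride + kernel - 2 * (kernel // 2)), so a
# (4, 5, 6) input can be mapped exactly onto a requested (8, 10, 12) output:
#
#     deconv = Deconv3d(filters=4, output_shape=(8, 10, 12), strides=2)
#     x = tf.ones((1, 4, 5, 6, 1))
#     y = deconv(x)  # shape (1, 8, 10, 12, 4), output_padding = [1, 1, 1]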


class NormBlock(tfkl.Layer):
    """
    A block with layer - norm - activation.
    """

    layer_cls_dict = dict(conv3d=tfkl.Conv3D, deconv3d=Deconv3d)
    norm_cls_dict = dict(batch=tfkl.BatchNormalization, layer=tfkl.LayerNormalization)

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._config = dict(
            layer_name=layer_name,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layer = self.layer_cls_dict[layer_name](use_bias=False, **kwargs)
        self._norm = self.norm_cls_dict[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output tensor.
        """
        output = self._layer(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config


class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)
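
# Usage sketch (illustrative, not part of the original module): NormBlock and
# its subclasses forward extra keyword arguments to the wrapped layer, so a
# conv3d - batch norm - relu block is built as:
#
#     block = Conv3dBlock(filters=8, kernel_size=3, padding="same")
#     x = tf.random.uniform((1, 4, 4, 4, 2))
#     y = block(x)  # shape (1, 4, 4, 4, 8)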


class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="deconv3d", name=name, **kwargs)


class Resize3d(tfkl.Layer):
    """
    Resize image in two folds.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        """
        super().__init__(name=name)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform two fold resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
            or (batch, dim1, dim2, dim3)
            or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :raises ValueError: if the input is not 3, 4, or 5 dimensional.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
            or (batch, out_dim1, out_dim2, out_dim3)
            or (out_dim1, out_dim2, out_dim3)
        """
        # sanity check
        image = inputs
        image_dim = len(image.shape)

        # init
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape {}".format(image.shape)
            )

        # no need of resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self):
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config
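
# Usage sketch (illustrative, not part of the original module): the layer
# accepts 3-D, 4-D, or 5-D inputs and resizes the three spatial dimensions in
# two tf.image.resize passes:
#
#     resize = Resize3d(shape=(8, 16, 24))
#     vol = tf.random.uniform((2, 4, 8, 12, 1))
#     out = resize(vol)  # shape (2, 8, 16, 24, 1)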


class Residual3dBlock(tfkl.Layer):
    """A resnet conv3d block with a residual connection."""

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block.

        1. conved = conv3d(conv3d_block(inputs))
        2. out = act(norm(conved) + inputs)

        :param filters: int, number of filters in the convolutional layers
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d_block = Conv3dBlock(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
        )
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return output: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        return self._act(
            self._norm(
                inputs=self._conv3d(inputs=self._conv3d_block(inputs)),
                training=training,
            )
            + inputs
        )
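
# Usage sketch (illustrative, not part of the original module): because the
# input is added back to the normalized convolution output, the input channel
# count must equal `filters`:
#
#     res = Residual3dBlock(filters=4)
#     x = tf.random.uniform((1, 4, 4, 4, 4))
#     y = res(x)  # shape (1, 4, 4, 4, 4)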


class DownSampleResnetBlock(tfkl.Layer):
    """A down-sampling resnet conv3d block, using max-pooling or conv3d."""

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        pooling: bool = True,
        **kwargs,
    ):
        """
        A down-sampling resnet conv3d block, with max-pooling or conv3d.

        1. conved = conv3d_block(inputs)  # adjust channel
        2. skip = residual_block(conved)  # develop feature
        3. pooled = pool(skip)  # down-sample

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param pooling: if True, use max pooling to downsample, otherwise use conv.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._pooling = pooling
        # init layer variables
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)
        self._max_pool3d = (
            tfkl.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding="same")
            if pooling
            else None
        )
        self._conv3d_block3 = (
            None
            if pooling
            else NormBlock(
                layer_name="conv3d",
                filters=filters,
                kernel_size=kernel_size,
                strides=2,
                padding="same",
            )
        )

    def call(self, inputs, training=None, **kwargs) -> (tf.Tensor, tf.Tensor):
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: (pooled, skip)

          - downsampled, shape = (batch, in_dim1//2, in_dim2//2, in_dim3//2, channels)
          - skipped, shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        conved = self._conv3d_block(inputs=inputs, training=training)  # adjust channel
        skip = self._residual_block(inputs=conved, training=training)  # develop feature
        pooled = (
            self._max_pool3d(inputs=skip)
            if self._pooling
            else self._conv3d_block3(inputs=skip, training=training)
        )  # downsample
        return pooled, skip
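
# Usage sketch (illustrative, not part of the original module): the block
# returns both the down-sampled tensor and the full-resolution skip tensor
# for a later decoder stage:
#
#     down = DownSampleResnetBlock(filters=8)
#     x = tf.random.uniform((1, 8, 8, 8, 2))
#     pooled, skip = down(x)  # pooled: (1, 4, 4, 4, 8), skip: (1, 8, 8, 8, 8)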


class UpSampleResnetBlock(tfkl.Layer):
    """An up-sampling resnet conv3d block, using deconv3d."""

    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        An up-sampling resnet conv3d block, with deconv3d.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool, specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._concat = concat
        # init layer variables
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = Residual3dBlock(filters=filters, kernel_size=kernel_size)

    def build(self, input_shape):
        """
        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        skip_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=skip_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        up_sampled, skip = inputs[0], inputs[1]
        up_sampled = self._deconv3d_block(
            inputs=up_sampled, training=training
        )  # up sample and change channel
        up_sampled = (
            tf.concat([up_sampled, skip], axis=4) if self._concat else up_sampled + skip
        )  # combine
        up_sampled = self._conv3d_block(
            inputs=up_sampled, training=training
        )  # adjust channel
        up_sampled = self._residual_block(inputs=up_sampled, training=training)  # conv
        return up_sampled
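
# Usage sketch (illustrative, not part of the original module): the first
# input is up-sampled to the skip tensor's spatial shape, then combined with
# it by sum (or concatenation if concat=True):
#
#     up = UpSampleResnetBlock(filters=4)
#     down_sampled = tf.random.uniform((1, 4, 4, 4, 8))
#     skip = tf.random.uniform((1, 8, 8, 8, 4))
#     y = up([down_sampled, skip])  # shape (1, 8, 8, 8, 4)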


class Warping(tfkl.Layer):
    """A layer that warps an image using a DDF."""

    def __init__(self, fixed_image_size: tuple, **kwargs):
        """
        A layer warps an image using DDF.

        Reference:

        - transform of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

          where vol = image, loc_shift = ddf

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]  # shape = (1, f_dim1, f_dim2, f_dim3, 3)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), dtype = float32
          - image, shape = (batch, m_dim1, m_dim2, m_dim3), dtype = float32
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        return layer_util.resample(vol=inputs[1], loc=self.grid_ref + inputs[0])
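
# Usage sketch (illustrative, not part of the original module): with an
# all-zero DDF the sampling locations coincide with the reference grid, so
# the moving image should be resampled onto itself:
#
#     warping = Warping(fixed_image_size=(4, 4, 4))
#     ddf = tf.zeros((1, 4, 4, 4, 3))
#     image = tf.random.uniform((1, 4, 4, 4))
#     warped = warping([ddf, image])  # shape (1, 4, 4, 4)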


class IntDVF(tfkl.Layer):
    """Integrate a DVF to get a DDF."""

    def __init__(self, fixed_image_size: tuple, num_steps: int = 7, **kwargs):
        """
        Layer calculates DDF from DVF.

        Reference:

        - integrate_vec of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._warping = Warping(fixed_image_size=fixed_image_size)
        self._num_steps = num_steps

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), type = float32
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf
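
# Usage sketch (illustrative, not part of the original module): the velocity
# field is scaled by 1 / 2**num_steps and then repeatedly composed with
# itself (scaling and squaring) to approximate the displacement field:
#
#     int_dvf = IntDVF(fixed_image_size=(4, 4, 4), num_steps=7)
#     dvf = tf.zeros((1, 4, 4, 4, 3))
#     ddf = int_dvf(dvf)  # shape (1, 4, 4, 4, 3)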


class AdditiveUpSampling(tfkl.Layer):
    """A layer that up-samples a 3d tensor and reduces channels by split and sum."""

    def __init__(self, output_shape: tuple, stride: (int, list) = 2, **kwargs):
        """
        Layer up-samples 3d tensor and reduces channels using split and sum.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: int, 1-D Tensor or list
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._stride = stride
        self._resize = Resize3d(output_shape)
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :raises ValueError: if the channel dimension is not divisible by the stride.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        """
        if inputs.shape[4] % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        output = self._resize(inputs)
        # a list of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.split(output, num_or_size_splits=self._stride, axis=4)
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.reduce_sum(tf.stack(output, axis=5), axis=5)
        return output
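
# Usage sketch (illustrative, not part of the original module): the resized
# tensor is split into `stride` groups along the channel axis and the groups
# are summed, so the channel count shrinks by the stride factor:
#
#     up = AdditiveUpSampling(output_shape=(8, 8, 8), stride=2)
#     x = tf.random.uniform((1, 4, 4, 4, 6))
#     y = up(x)  # shape (1, 8, 8, 8, 3)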


class LocalNetResidual3dBlock(tfkl.Layer):
    """A simpler resnet conv3d block used for LocalNet."""

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block, simpler than Residual3dBlock.

        1. conved = conv3d(inputs)
        2. out = act(norm(conved) + inputs)

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        :param inputs: (conved_tensor, residual_tensor)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: tensor with the same shape as the residual input
        """
        return self._act(
            self._norm(inputs=self._conv3d(inputs=inputs[0]), training=training)
            + inputs[1]
        )


class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """An up-sampling block for LocalNet with skip connections."""

    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Layer up-samples tensor with two inputs (skipped and down-sampled).

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool to use additive upsampling
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # init layer variables
        self._deconv3d_block = None
        self._additive_upsampling = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        """
        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        output_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=output_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        if self._use_additive_upsampling:
            self._additive_upsampling = AdditiveUpSampling(output_shape=output_shape)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: up-sampled tensor with `filters` channels and the skip tensor's
            spatial shape
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        h0 = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            h0 += self._additive_upsampling(inputs=inputs_nonskip)
        r1 = h0 + inputs_skip
        r2 = self._conv3d_block(inputs=h0, training=training)
        h1 = self._residual_block(inputs=[r2, r1], training=training)
        return h1
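
# Usage sketch (illustrative, not part of the original module): the non-skip
# input is up-sampled by a deconv block (optionally plus additive upsampling),
# added to the skip input, and refined by a residual block:
#
#     localnet_up = LocalNetUpSampleResnetBlock(filters=4)
#     nonskip = tf.random.uniform((1, 4, 4, 4, 8))
#     skip = tf.random.uniform((1, 8, 8, 8, 4))
#     y = localnet_up([nonskip, skip])  # shape (1, 8, 8, 8, 4)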


class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        :param control_point_spacing: list or int
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = [
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)
        return output
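
# Usage sketch (illustrative, not part of the original module): a dense field
# is smoothed with a Gaussian kernel and then resized to roughly one control
# point per `control_point_spacing` voxels, plus a 3-point margin:
#
#     resize_cp = ResizeCPTransform(control_point_spacing=2)
#     dense_field = tf.zeros((1, 8, 8, 8, 3))
#     cp_field = resize_cp(dense_field)  # shape (1, 7, 7, 7, 3)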


class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline filters.
    It assumes a full sized image from which:

    1. it computes the control point values by downsampling the initial image
    2. performs the interpolation
    3. crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing to all dimensions is used
    :param output_shape: (dim0, dim1, dim2) of the high resolution
        deformation fields.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):

        super().__init__(**kwargs)

        self.filters = []
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        :param input_shape: tuple with the input shape
        """

        super().build(input_shape=input_shape)

        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        :param field: tf.Tensor with shape=number_of_control_points_per_dim
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor of shape=self.input_shape
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]
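
# Usage sketch (illustrative, not part of the original module): a control-point
# field with spacing 2 is interpolated by transposed convolution with the
# precomputed cubic B-spline filters, then cropped to the requested shape:
#
#     bspline = BSplines3DTransform(cp_spacing=2, output_shape=(8, 8, 8))
#     low_res = tf.zeros((1, 7, 7, 7, 3))  # control-point displacements
#     ddf = bspline(low_res)  # shape (1, 8, 8, 8, 3)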