Passed

Pull Request — main (#656), by Yunguan, created, 03:03

deepreg.model.layer.Warping.get_config()   rating: A

Complexity:  Conditions 1
Size:        Total Lines 5, Code Lines 4
Duplication: Lines 0, Ratio 0 %
Importance:  Changes 0

Metric  Value
eloc    4
dl      0
loc     5
rs      10
c       0
b       0
f       0
cc      1
nop     1
"""This module defines custom layers."""
import itertools

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util as layer_util


class Deconv3d(tfkl.Layer):
    """
    Wrap Conv3DTranspose to allow dynamic output padding calculation.
    """

    # review: "name" missing in parameter type documentation
    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        name="deconv3d",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param name: name of the layer.
        :param kwargs: additional arguments for Conv3DTranspose.
        """
        super().__init__(name=name)
        # save arguments
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._kwargs = kwargs
        # init layer variables
        self._deconv3d = None

    # review: "input_shape" missing in parameter type documentation
    def build(self, input_shape):
        # pylint: disable-next=line-too-long
        """
        Calculate output padding on the fly.

        https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185

        :param input_shape: shape of input
        """
        super().build(input_shape)
        kernel_size = self._kernel_size
        strides = self._strides
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * 3

        if isinstance(strides, int):
            strides = [strides] * 3

        output_padding = None
        if self._output_shape is not None:
            assert self._padding == "same"
            output_padding = [
                self._output_shape[i]
                - (
                    (input_shape[1 + i] - 1) * strides[i]
                    + kernel_size[i]
                    - 2 * (kernel_size[i] // 2)
                )
                for i in range(3)
            ]
        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=output_padding,
            **self._kwargs,
        )

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: input tensor.
        :param kwargs: additional arguments
        :return:
        """
        return self._deconv3d(inputs=inputs)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            dict(
                filters=self._filters,
                output_shape=self._output_shape,
                kernel_size=self._kernel_size,
                strides=self._strides,
                padding=self._padding,
            )
        )
        config.update(self._kwargs)
        return config
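
The point of the dynamic calculation: with padding="same" and stride 2, a plain
Conv3DTranspose always doubles the spatial extent, so an odd-sized encoder feature
map cannot be recovered exactly. A minimal sketch of the round trip, using the class
above with illustrative shapes:

    import tensorflow as tf
    import tensorflow.keras.layers as tfkl

    x = tf.ones((2, 9, 9, 9, 4))
    down = tfkl.Conv3D(filters=8, kernel_size=3, strides=2, padding="same")(x)
    print(down.shape)  # (2, 5, 5, 5, 8), since ceil(9 / 2) = 5

    # Deconv3d derives output_padding = 9 - ((5 - 1) * 2 + 3 - 2 * 1) = 0,
    # so the transposed convolution lands exactly on (9, 9, 9).
    up = Deconv3d(filters=4, output_shape=(9, 9, 9), strides=2)(down)
    print(up.shape)  # (2, 9, 9, 9, 4)
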
class NormBlock(tfkl.Layer):
    """
    A block with layer - norm - activation.
    """

    layer_cls_dict = dict(conv3d=tfkl.Conv3D, deconv3d=Deconv3d)
    norm_cls_dict = dict(batch=tfkl.BatchNormalization, layer=tfkl.LayerNormalization)

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._config = dict(
            layer_name=layer_name,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layer = self.layer_cls_dict[layer_name](use_bias=False, **kwargs)
        self._norm = self.norm_cls_dict[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    # review (bug): parameters differ from overridden 'call' method
    # review: "inputs, training" missing in parameter type documentation
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return:
        """
        output = self._layer(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
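
NormBlock dispatches on layer_name and norm_name through the two class-level dicts,
so one wrapper covers both the conv and deconv variants. A small sketch, with
illustrative shapes:

    import tensorflow as tf

    block = NormBlock(layer_name="conv3d", filters=8, kernel_size=3, padding="same")
    out = block(tf.zeros((2, 8, 8, 8, 1)))
    print(out.shape)  # (2, 8, 8, 8, 8): Conv3D -> BatchNormalization -> ReLU
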
class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)


class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="deconv3d", name=name, **kwargs)
class Resize3d(tfkl.Layer):
    """
    Resize image in two stages.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        """
        super().__init__(name=name)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    # review: "ValueError" not documented as being raised
    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform the two-stage resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
            or (batch, dim1, dim2, dim3)
            or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
            or (batch, out_dim1, out_dim2, out_dim3)
            or (out_dim1, out_dim2, out_dim3)
        """
        # sanity check
        image = inputs
        image_dim = len(image.shape)

        # init
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape {}".format(image.shape)
            )

        # no need of resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config
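
The two reshapes are what let tf.image.resize, a 2-D op, cover all three spatial
dimensions: dim2/dim3 are resized with batch and dim1 merged, then dim1/dim2 with
dim3 folded into the channel axis. A usage sketch with illustrative shapes:

    import tensorflow as tf

    resize = Resize3d(shape=(8, 16, 16))
    out = resize(tf.zeros((2, 4, 8, 8, 1)))
    print(out.shape)  # (2, 8, 16, 16, 1)
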
class Warping(tfkl.Layer):
    """
    Warps an image with DDF.

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Init.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)
        self.grid_ref = self.grid_ref[None, ...]

    # review: "inputs" missing in parameter type documentation
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        return layer_util.resample(vol=image, loc=self.grid_ref + ddf)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        return config
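
Warping resamples the moving image at grid_ref + ddf, so a zero DDF should return
the input unchanged (up to interpolation). A minimal sketch:

    import tensorflow as tf

    warping = Warping(fixed_image_size=(4, 4, 4))
    image = tf.random.uniform((1, 4, 4, 4))
    ddf = tf.zeros((1, 4, 4, 4, 3))
    warped = warping(inputs=[ddf, image])
    print(warped.shape)  # (1, 4, 4, 4); zero DDF leaves the image as-is
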
class ResidualBlock(tfkl.Layer):
    """
    A block with skip links and layer - norm - activation.
    """

    layer_cls_dict = dict(conv3d=tfkl.Conv3D, deconv3d=Deconv3d)
    norm_cls_dict = dict(batch=tfkl.BatchNormalization, layer=tfkl.LayerNormalization)

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param num_layers: number of layers/blocks.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._num_layers = num_layers
        self._config = dict(
            layer_name=layer_name,
            num_layers=num_layers,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layers = [
            self.layer_cls_dict[layer_name](use_bias=False, **kwargs)
            for _ in range(num_layers)
        ]
        self._norms = [self.norm_cls_dict[norm_name]() for _ in range(num_layers)]
        self._acts = [tfkl.Activation(activation=activation) for _ in range(num_layers)]

    # review (bug): parameters differ from overridden 'call' method
    # review: "inputs, training" missing in parameter type documentation
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return:
        """

        output = inputs
        for i in range(self._num_layers):
            output = self._layers[i](inputs=output)
            output = self._norms[i](inputs=output, training=training)
            if i == self._num_layers - 1:
                # last block
                output = output + inputs
            output = self._acts[i](output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config


class ResidualConv3dBlock(ResidualBlock):
    """
    A conv3d residual block.
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)
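
Because the skip is a bare addition (output + inputs in the last layer), the block
requires the input to already have `filters` channels; the surrounding networks
guarantee this by placing a channel-adjusting Conv3dBlock first. Sketch:

    import tensorflow as tf

    block = ResidualConv3dBlock(filters=4, kernel_size=3, padding="same")
    out = block(tf.zeros((2, 8, 8, 8, 4)))  # input channels must equal filters
    print(out.shape)  # (2, 8, 8, 8, 4)
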
class DownSampleResnetBlock(tfkl.Layer):
    # review: missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        pooling: bool = True,
        **kwargs,
    ):
        """
        A down-sampling resnet conv3d block, with max-pooling or conv3d.

        1. conved = conv3d_block(inputs)  # adjust channel
        2. skip = residual_block(conved)  # develop feature
        3. pooled = pool(skip)  # down-sample

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param pooling: if True, use max pooling to downsample, otherwise use conv.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._pooling = pooling
        # init layer variables
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = ResidualConv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._max_pool3d = (
            tfkl.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding="same")
            if pooling
            else None
        )
        self._conv3d_block3 = (
            None
            if pooling
            else NormBlock(
                layer_name="conv3d",
                filters=filters,
                kernel_size=kernel_size,
                strides=2,
                padding="same",
            )
        )

    # review (bug): parameters differ from overridden 'call' method
    # review: "inputs, training" missing in parameter type documentation
    def call(self, inputs, training=None, **kwargs) -> (tf.Tensor, tf.Tensor):
        """
        :param inputs: shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: (pooled, skip)

          - pooled, shape = (batch, in_dim1//2, in_dim2//2, in_dim3//2, channels)
          - skip, shape = (batch, in_dim1, in_dim2, in_dim3, channels)
        """
        conved = self._conv3d_block(inputs=inputs, training=training)  # adjust channel
        skip = self._residual_block(inputs=conved, training=training)  # develop feature
        pooled = (
            self._max_pool3d(inputs=skip)
            if self._pooling
            # review (bug): self._conv3d_block3 does not seem to be callable
            else self._conv3d_block3(inputs=skip, training=training)
        )  # downsample
        return pooled, skip
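
One down-sampling step therefore halves each spatial dimension and also returns the
pre-pooling tensor for the skip connection. Sketch with illustrative shapes:

    import tensorflow as tf

    block = DownSampleResnetBlock(filters=8)
    pooled, skip = block(tf.zeros((1, 8, 8, 8, 4)))
    print(pooled.shape)  # (1, 4, 4, 4, 8)
    print(skip.shape)    # (1, 8, 8, 8, 8)
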
class UpSampleResnetBlock(tfkl.Layer):
    # review: missing class docstring
    # review: "concat, filters, kernel_size" missing in parameter type documentation
    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        An up-sampling resnet conv3d block, with deconv3d.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool, specifies how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._concat = concat
        # init layer variables
        self._deconv3d_block = None
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = ResidualConv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )

    # review: "input_shape" missing in parameter type documentation
    def build(self, input_shape):
        """
        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        skip_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=skip_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )

    # review (bug): parameters differ from overridden 'call' method
    # review: "inputs, training" missing in parameter type documentation
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, \*skip_connection_image_shape, filters)
        """
        up_sampled, skip = inputs[0], inputs[1]
        up_sampled = self._deconv3d_block(
            inputs=up_sampled, training=training
        )  # up sample and change channel
        up_sampled = (
            tf.concat([up_sampled, skip], axis=4) if self._concat else up_sampled + skip
        )  # combine
        up_sampled = self._conv3d_block(
            inputs=up_sampled, training=training
        )  # adjust channel
        up_sampled = self._residual_block(inputs=up_sampled, training=training)  # conv
        return up_sampled
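
This is the mirror-image block: build() reads the skip tensor's spatial shape so that
Deconv3d can target it exactly, and sum (or concatenation, with concat=True) merges
the two paths. Sketch with illustrative shapes:

    import tensorflow as tf

    block = UpSampleResnetBlock(filters=4)
    down = tf.zeros((1, 4, 4, 4, 8))
    skip = tf.zeros((1, 8, 8, 8, 4))
    out = block([down, skip])
    print(out.shape)  # (1, 8, 8, 8, 4)
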
class IntDVF(tfkl.Layer):
    # review: missing class docstring
    def __init__(self, fixed_image_size: tuple, num_steps: int = 7, **kwargs):
        """
        Layer calculates DDF from DVF.

        Reference:

        - integrate_vec of neuron
          https://github.com/adalca/neurite/blob/legacy/neuron/utils.py

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        self._warping = Warping(fixed_image_size=fixed_image_size)
        self._num_steps = num_steps

    # review: "inputs" missing in parameter type documentation
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3), type = float32
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf
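
The loop implements the scaling-and-squaring scheme of the referenced integrate_vec:
the velocity field is scaled down by 2^num_steps and then composed with itself
num_steps times, since self._warping(inputs=[ddf, ddf]) evaluates the current field
at its own displaced locations. With v the input DVF and T = num_steps:

    u_0 = v / 2^T
    u_{t+1}(x) = u_t(x) + u_t(x + u_t(x)),    t = 0, ..., T - 1

and u_T is returned as the DDF.
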
class AdditiveUpSampling(tfkl.Layer):
    # review: missing class docstring
    def __init__(self, output_shape: tuple, stride: (int, list) = 2, **kwargs):
        """
        Layer up-samples a 3d tensor and reduces channels using split and sum.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: int, 1-D Tensor or list
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._stride = stride
        self._resize = Resize3d(output_shape)
        self._output_shape = output_shape

    # review: "inputs" missing in parameter type documentation
    # review: "ValueError" not documented as being raised
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        """
        if inputs.shape[4] % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        output = self._resize(inputs)
        # a list of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        # review (unused code): argument 'axis' passed by position and keyword in call
        output = tf.split(output, num_or_size_splits=self._stride, axis=4)
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.reduce_sum(tf.stack(output, axis=5), axis=5)
        return output
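
Channel reduction here is resize-then-split-and-sum: after resizing, the channel axis
is cut into `stride` groups that are added together, hence the divisibility check.
Sketch with illustrative shapes:

    import tensorflow as tf

    layer = AdditiveUpSampling(output_shape=(8, 8, 8), stride=2)
    out = layer(tf.zeros((1, 4, 4, 4, 6)))
    print(out.shape)  # (1, 8, 8, 8, 3): 6 channels split into 2 groups of 3, summed
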
class LocalNetResidual3dBlock(tfkl.Layer):
    # review: missing class docstring
    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        A resnet conv3d block, simpler than ResidualConv3dBlock.

        1. conved = conv3d(inputs)
        2. out = act(norm(conved) + inputs)

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    # review (bug): parameters differ from overridden 'call' method
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        return self._act(
            self._norm(inputs=self._conv3d(inputs=inputs[0]), training=training)
            + inputs[1]
        )


class LocalNetUpSampleResnetBlock(tfkl.Layer):
    # review: missing class docstring
    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Layer up-samples tensor with two inputs (skipped and down-sampled).

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool, whether to use additive upsampling
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # init layer variables
        self._deconv3d_block = None
        self._additive_upsampling = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    # review: "input_shape" missing in parameter type documentation
    def build(self, input_shape):
        """
        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        output_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=output_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        if self._use_additive_upsampling:
            self._additive_upsampling = AdditiveUpSampling(output_shape=output_shape)

    # review (bug): parameters differ from overridden 'call' method
    # review: "inputs, training" missing in parameter type documentation
    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return:
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        h0 = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            h0 += self._additive_upsampling(inputs=inputs_nonskip)
        r1 = h0 + inputs_skip
        r2 = self._conv3d_block(inputs=h0, training=training)
        h1 = self._residual_block(inputs=[r2, r1], training=training)
        return h1
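
In the LocalNet decoder the up-sampled path can additionally receive the
additive-upsampling shortcut before merging with the skip. Sketch with illustrative
shapes:

    import tensorflow as tf

    block = LocalNetUpSampleResnetBlock(filters=4)
    nonskip = tf.zeros((1, 4, 4, 4, 8))
    skip = tf.zeros((1, 8, 8, 8, 4))
    out = block([nonskip, skip])
    print(out.shape)  # (1, 8, 8, 8, 4)
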
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        :param control_point_spacing: list or int
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = [
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)
        return output
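
The layer first blurs the dense field with a Gaussian whose sigma is proportional to
the control-point spacing (anti-aliasing before the downsample), then resizes to the
control-point grid, which gets three extra points per dimension for the cubic
B-spline support. Sketch, assuming an illustrative spacing and field size:

    import tensorflow as tf

    layer = ResizeCPTransform(control_point_spacing=4)
    out = layer(tf.zeros((1, 16, 16, 16, 3)))
    print(out.shape)  # (1, 7, 7, 7, 3): ceil(16 / 4) + 3 = 7 points per dim
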
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline filters.
    It assumes a full-sized image from which:

    1. it computes the control point values by downsampling the initial image
    2. performs the interpolation
    3. crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing is used in all dimensions.
    :param output_shape: (dim0, dim1, dim2) of the high resolution
        deformation fields.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):

        super().__init__(**kwargs)

        self.filters = []
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    # review: redundant returns documentation
    def build(self, input_shape: tuple):
        """
        :param input_shape: tuple with the input shape
        :return: None
        """

        super().build(input_shape=input_shape)

        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)
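
    # The lookup table `b` above is the uniform cubic B-spline basis: for u in [0, 1),
    #   B0(u) = (1 - u)^3 / 6
    #   B1(u) = (3u^3 - 6u^2 + 4) / 6
    #   B2(u) = (-3u^3 + 3u^2 + 3u + 1) / 6
    #   B3(u) = u^3 / 6
    # These four values sum to 1 for every u, so each voxel of the interpolated field
    # is a convex combination of its 4 x 4 x 4 neighbouring control points; build()
    # bakes those weights into a fixed transposed-convolution filter.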

    # review: "field" missing in parameter type documentation
    def interpolate(self, field) -> tf.Tensor:
        """
        :param field: tf.Tensor of control point values,
            shape = (batch, n_cp1, n_cp2, n_cp3, 3)
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    # review: "inputs" missing in parameter type documentation
    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor of
            shape = (batch, out_dim1, out_dim2, out_dim3, 3)
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]
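
End to end, the two layers pair up: ResizeCPTransform extracts the control points and
BSplines3DTransform interpolates them back to a dense field, cropping the VALID
transposed-convolution margin of 3 * cp_spacing voxels on each leading side. A sketch
consistent with the sizes used above:

    import tensorflow as tf

    layer = BSplines3DTransform(cp_spacing=4, output_shape=(16, 16, 16))
    field = tf.zeros((1, 7, 7, 7, 3))  # e.g. the output of ResizeCPTransform
    ddf = layer(field)
    print(ddf.shape)  # (1, 16, 16, 16, 3)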