Passed
Pull Request — main (#656)
by Yunguan
02:34
created

deepreg.model.layer.Warping.get_config()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 5
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""This module defines custom layers."""
2
import itertools
3
4
import numpy as np
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
import deepreg.model.layer_util as layer_util
9
10
11
class Deconv3d(tfkl.Layer):
    """
    Transposed 3D convolution whose output padding is resolved in ``build``.

    The wrapped ``Conv3DTranspose`` is created only once the input shape is
    known, so that the padding required to reach a requested output shape
    can be computed.
    """

    def __init__(
        self,
        filters: int,
        output_shape: (tuple, None) = None,
        kernel_size: int = 3,
        strides: int = 1,
        padding: str = "same",
        name="deconv3d",
        **kwargs,
    ):
        """
        Store the arguments, the wrapped layer is created later in ``build``.

        :param filters: number of channels of the output
        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param padding: same or valid.
        :param name: name of the layer.
        :param kwargs: additional arguments for Conv3DTranspose.
        """
        super().__init__(name=name)
        # arguments are kept for build() and get_config()
        self._filters = filters
        self._output_shape = output_shape
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._kwargs = kwargs
        # created in build() once the input shape is available
        self._deconv3d = None

    def build(self, input_shape):
        # pylint: disable-next=line-too-long
        """
        Create the wrapped layer, computing the output padding on the fly.

        https://github.com/tensorflow/tensorflow/blob/1cf0898dd4331baf93fe77205550f2c2e6c90ee5/tensorflow/python/keras/utils/conv_utils.py#L139-L185

        :param input_shape: shape of input
        """
        super().build(input_shape)

        def to_triple(value):
            """Repeat an int for the three axes, pass other values through."""
            return [value] * 3 if isinstance(value, int) else value

        kernel_size = to_triple(self._kernel_size)
        strides = to_triple(self._strides)

        output_padding = None
        if self._output_shape is not None:
            assert self._padding == "same"
            output_padding = []
            for axis in range(3):
                kernel = kernel_size[axis]
                # length reached along this axis before any output padding,
                # following the formula of the function linked above
                unpadded = (
                    (input_shape[1 + axis] - 1) * strides[axis]
                    + kernel
                    - 2 * (kernel // 2)
                )
                output_padding.append(self._output_shape[axis] - unpadded)

        self._deconv3d = tfkl.Conv3DTranspose(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
            output_padding=output_padding,
            **self._kwargs,
        )

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Apply the wrapped transposed convolution.

        :param inputs: input tensor.
        :param kwargs: additional arguments, not used.
        :return: output of the wrapped Conv3DTranspose.
        """
        return self._deconv3d(inputs=inputs)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            filters=self._filters,
            output_shape=self._output_shape,
            kernel_size=self._kernel_size,
            strides=self._strides,
            padding=self._padding,
        )
        # extra Conv3DTranspose arguments are stored flat, next to the others
        config.update(self._kwargs)
        return config
111
112
113
class NormBlock(tfkl.Layer):
    """
    A wrapped layer followed by a normalization and an activation.
    """

    # lookup tables from the names accepted by __init__ to the layer classes
    layer_cls_dict = {"conv3d": tfkl.Conv3D, "deconv3d": Deconv3d}
    norm_cls_dict = {
        "batch": tfkl.BatchNormalization,
        "layer": tfkl.LayerNormalization,
    }

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Build the three sub-layers.

        :param layer_name: class of the layer to be wrapped.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        # the name is only recorded in the config below,
        # it is not forwarded to the base layer here
        super().__init__()
        self._config = {
            "layer_name": layer_name,
            "norm_name": norm_name,
            "activation": activation,
            "name": name,
            **kwargs,
        }
        wrapped_cls = self.layer_cls_dict[layer_name]
        norm_cls = self.norm_cls_dict[norm_name]
        # no bias in the wrapped layer, a normalization layer follows it
        self._layer = wrapped_cls(use_bias=False, **kwargs)
        self._norm = norm_cls()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Apply layer, normalization and activation in this order.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: activated output.
        """
        wrapped = self._layer(inputs=inputs)
        normalized = self._norm(inputs=wrapped, training=training)
        return self._act(normalized)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
169
170
171
class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)

    def get_config(self) -> dict:
        """
        Return the config dictionary for recreating this class.

        ``layer_name`` is fixed by this subclass and passed explicitly to the
        parent constructor. It is therefore removed from the config,
        otherwise ``cls(**config)`` would provide ``layer_name`` twice.

        :return: config of the parent without the ``layer_name`` entry.
        """
        config = super().get_config()
        config.pop("layer_name", None)
        return config
188
189
190
class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="deconv3d", name=name, **kwargs)

    def get_config(self) -> dict:
        """
        Return the config dictionary for recreating this class.

        ``layer_name`` is fixed by this subclass and passed explicitly to the
        parent constructor. It is therefore removed from the config,
        otherwise ``cls(**config)`` would provide ``layer_name`` twice.

        :return: config of the parent without the ``layer_name`` entry.
        """
        config = super().get_config()
        config.pop("layer_name", None)
        return config
207
208
209
class Resize3d(tfkl.Layer):
    """
    Resize image in two folds.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
        **kwargs,
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        :param kwargs: additional arguments for the base layer. Accepting
            them lets the layer be rebuilt from the dictionary returned by
            ``get_config``, which also holds the entries of the base layer.
        """
        super().__init__(name=name, **kwargs)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform two fold resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                                     or (batch, dim1, dim2, dim3)
                                     or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                                or (batch, dim1, dim2, dim3)
                                or (dim1, dim2, dim3)
        :raises ValueError: if the input does not have 3, 4 or 5 dimensions.
        """
        image = inputs
        image_dim = len(image.shape)

        # find out which of batch / channel axes are present
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape{}".format(image.shape)
            )

        # no need of resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config
324
325
326
class Warping(tfkl.Layer):
    """
    Resample an image at the locations given by a reference grid plus a DDF.

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Create the reference grid used by every call.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        reference = layer_util.get_reference_grid(grid_size=fixed_image_size)
        # a leading axis is added in front of the grid,
        # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = reference[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Warp the image with the DDF.

        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        displacement, moving = inputs
        sample_locations = self.grid_ref + displacement
        return layer_util.resample(vol=moving, loc=sample_locations)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(fixed_image_size=self._fixed_image_size)
        return config
368
369
370
class ResidualBlock(tfkl.Layer):
    """
    Stacked layer - norm - activation stages with a skip link from the input.
    """

    # lookup tables from the names accepted by __init__ to the layer classes
    layer_cls_dict = {"conv3d": tfkl.Conv3D, "deconv3d": Deconv3d}
    norm_cls_dict = {
        "batch": tfkl.BatchNormalization,
        "layer": tfkl.LayerNormalization,
    }

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Build ``num_layers`` stages of layer, normalization and activation.

        :param layer_name: class of the layer to be wrapped.
        :param num_layers: number of layers/blocks.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        # the name is only recorded in the config below,
        # it is not forwarded to the base layer here
        super().__init__()
        self._num_layers = num_layers
        self._config = {
            "layer_name": layer_name,
            "num_layers": num_layers,
            "norm_name": norm_name,
            "activation": activation,
            "name": name,
            **kwargs,
        }
        wrapped_cls = self.layer_cls_dict[layer_name]
        norm_cls = self.norm_cls_dict[norm_name]
        stages = range(num_layers)
        self._layers = [wrapped_cls(use_bias=False, **kwargs) for _ in stages]
        self._norms = [norm_cls() for _ in stages]
        self._acts = [tfkl.Activation(activation=activation) for _ in stages]

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Run all stages, adding the input before the last activation.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output of the last activation.
        """
        last = self._num_layers - 1
        hidden = inputs
        stages = zip(self._layers, self._norms, self._acts)
        for idx, (layer, norm, act) in enumerate(stages):
            hidden = layer(inputs=hidden)
            hidden = norm(inputs=hidden, training=training)
            if idx == last:
                # skip link, applied before the activation of the last stage
                hidden = hidden + inputs
            hidden = act(hidden)
        return hidden

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config
439
440
441
class ResidualConv3dBlock(ResidualBlock):
    """
    A conv3d residual block
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)

    def get_config(self) -> dict:
        """
        Return the config dictionary for recreating this class.

        ``layer_name`` is fixed by this subclass and passed explicitly to the
        parent constructor. It is therefore removed from the config,
        otherwise ``cls(**config)`` would provide ``layer_name`` twice.

        :return: config of the parent without the ``layer_name`` entry.
        """
        config = super().get_config()
        config.pop("layer_name", None)
        return config
458
459
460
class UpSampleResnetBlock(tfkl.Layer):
    """
    An up-sampling resnet conv3d block, with deconv3d.

    The down-sampled input goes through a deconv3d block, is combined with
    the skipped input, and then goes through a conv3d block and a residual
    conv3d block.
    """

    def __init__(self, filters, kernel_size=3, concat=False, **kwargs):
        """
        Init.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param concat: bool,specify how to combine input and skip connection images.
            If True, use concatenation, otherwise use sum (default=False).
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._kernel_size = kernel_size
        self._concat = concat
        # init layer variables
        self._deconv3d_block = None  # created in build()
        self._conv3d_block = Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )
        self._residual_block = ResidualConv3dBlock(
            filters=filters, kernel_size=kernel_size, padding="same"
        )

    def build(self, input_shape):
        """
        Create the deconv3d block targeting the shape of the skipped input.

        :param input_shape: tuple, (downsampled_image_shape, skip_image_shape)
        """
        super().build(input_shape)
        skip_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=skip_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        r"""
        Up-sample, combine with the skipped input and convolve.

        :param inputs: tuple

          - down-sampled
          - skipped

        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: shape = (batch, \*skip_connection_image_shape, filters]
        """
        up_sampled, skip = inputs[0], inputs[1]
        up_sampled = self._deconv3d_block(
            inputs=up_sampled, training=training
        )  # up sample and change channel
        up_sampled = (
            tf.concat([up_sampled, skip], axis=4) if self._concat else up_sampled + skip
        )  # combine
        up_sampled = self._conv3d_block(
            inputs=up_sampled, training=training
        )  # adjust channel
        up_sampled = self._residual_block(inputs=up_sampled, training=training)  # conv
        return up_sampled

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            filters=self._filters,
            kernel_size=self._kernel_size,
            concat=self._concat,
        )
        return config
521
522
523
class IntDVF(tfkl.Layer):
    """
    Integrate a DVF into a DDF by repeated self-warping.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(
        self,
        fixed_image_size: tuple,
        num_steps: int = 7,
        name: str = "int_dvf",
        **kwargs,
    ):
        """
        Save the arguments and create the warping layer.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        assert len(fixed_image_size) == 3
        self._num_steps = num_steps
        self._fixed_image_size = fixed_image_size
        self._warping = Warping(fixed_image_size=fixed_image_size)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Scale the DVF down and compose it with itself ``num_steps`` times.

        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        num_steps = self._num_steps
        ddf = inputs / (2 ** num_steps)
        for _ in range(num_steps):
            # the current field is warped by itself and accumulated
            warped = self._warping(inputs=[ddf, ddf])
            ddf += warped
        return ddf

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            fixed_image_size=self._fixed_image_size,
            num_steps=self._num_steps,
        )
        return config
570
571
572
class AdditiveUpSampling(tfkl.Layer):
    """
    Layer up-samples 3d tensor and reduce channels using split and sum.
    """

    def __init__(self, output_shape: tuple, stride: (int, list) = 2, **kwargs):
        """
        Init.

        :param output_shape: (out_dim1, out_dim2, out_dim3)
        :param stride: int, 1-D Tensor or list
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._stride = stride
        self._resize = Resize3d(output_shape)
        self._output_shape = output_shape

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Resize the input, then split the channels and sum the splits.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
        :param kwargs: additional arguments.
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels//stride]
        :raises ValueError: if the channel dimension is not a multiple of
            the stride.
        """
        if inputs.shape[4] % self._stride != 0:
            raise ValueError("The channel dimension can not be divided by the stride")
        output = self._resize(inputs)
        # a list of (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.split(output, num_or_size_splits=self._stride, axis=4)
        # (batch, out_dim1, out_dim2, out_dim3, channels//stride)
        output = tf.reduce_sum(tf.stack(output, axis=5), axis=5)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(output_shape=self._output_shape, stride=self._stride)
        return config
601
602
603
class LocalNetResidual3dBlock(tfkl.Layer):
    """
    A resnet conv3d block, simpler than ResidualConv3dBlock.

    1. conved = conv3d(inputs[0])
    2. out = act(norm(conved) + inputs[1])
    """

    def __init__(
        self,
        filters: int,
        kernel_size: (int, tuple) = 3,
        strides: (int, tuple) = 1,
        activation: str = "relu",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels of the output
        :param kernel_size: int or tuple of 3 ints, e.g. (3,3,3) or 3
        :param strides: int or tuple of 3 ints, e.g. (1,1,1) or 1
        :param activation: name of activation
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters, needed by get_config()
        self._filters = filters
        self._kernel_size = kernel_size
        self._strides = strides
        self._activation = activation
        # init layer variables
        self._conv3d = tfkl.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self._norm = tfkl.BatchNormalization()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Convolve the first input and add the second before the activation.

        :param inputs: a pair, inputs[0] is convolved and normalized,
            inputs[1] is added to the result.
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output of the activation.
        """
        return self._act(
            self._norm(inputs=self._conv3d(inputs=inputs[0]), training=training)
            + inputs[1]
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            filters=self._filters,
            kernel_size=self._kernel_size,
            strides=self._strides,
            activation=self._activation,
        )
        return config
641
642
643
class LocalNetUpSampleResnetBlock(tfkl.Layer):
    """
    Layer up-samples tensor with two inputs (skipped and down-sampled).
    """

    def __init__(self, filters: int, use_additive_upsampling: bool = True, **kwargs):
        """
        Init.

        :param filters: int, number of output channels
        :param use_additive_upsampling: bool to used additive upsampling
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)
        # save parameters
        self._filters = filters
        self._use_additive_upsampling = use_additive_upsampling
        # init layer variables, the first two are created in build()
        self._deconv3d_block = None
        self._additive_upsampling = None
        self._conv3d_block = Conv3dBlock(filters=filters, kernel_size=3, padding="same")
        self._residual_block = LocalNetResidual3dBlock(filters=filters, strides=1)

    def build(self, input_shape):
        """
        Create the up-sampling layers targeting the shape of the skipped input.

        :param input_shape: tuple (nonskip_tensor_shape, skip_tensor_shape)
        """
        super().build(input_shape)

        output_shape = input_shape[1][1:4]
        self._deconv3d_block = Deconv3dBlock(
            filters=self._filters,
            output_shape=output_shape,
            kernel_size=3,
            strides=2,
            padding="same",
        )
        if self._use_additive_upsampling:
            self._additive_upsampling = AdditiveUpSampling(output_shape=output_shape)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Up-sample the non-skipped input and merge it with the skipped one.

        :param inputs: list = [inputs_nonskip, inputs_skip]
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output of the residual block.
        """
        inputs_nonskip, inputs_skip = inputs[0], inputs[1]
        h0 = self._deconv3d_block(inputs=inputs_nonskip, training=training)
        if self._use_additive_upsampling:
            h0 += self._additive_upsampling(inputs=inputs_nonskip)
        r1 = h0 + inputs_skip
        r2 = self._conv3d_block(inputs=h0, training=training)
        h1 = self._residual_block(inputs=[r2, r1], training=training)
        return h1

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            filters=self._filters,
            use_additive_upsampling=self._use_additive_upsampling,
        )
        return config
694
695
696
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of a image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        Init.

        :param control_point_spacing: list or int
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        # the following are created in build()
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        """
        Create the Gaussian kernel and the resize layer.

        :param input_shape: shape of the input, the entries between the
            first and the last one are divided by the control point spacing.
        """
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = [
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Filter the input with the Gaussian kernel, then resize it.

        :param inputs: input tensor of the 3D convolution.
        :param kwargs: additional arguments.
        :return: resized tensor.
        """
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["control_point_spacing"] = self.cp_spacing
        return config
737
738
739
class BSplines3DTransform(tfkl.Layer):
    """
     Layer for BSplines interpolation with precomputed cubic spline filters.
     It assumes a full sized image from which:
     1. it compute the control points values by downsampling the initial image
     2. performs the interpolation
     3. crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing to all dimensions is used
    :param output_shape: its first three entries are used in ``call`` as the
        sizes (dim0, dim1, dim2) of the cropped high resolution
        deformation fields.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):
        """
        Save the arguments, the spline filter is created in ``build``.

        :param cp_spacing: int or tuple of three ints, control point spacing.
        :param output_shape: sizes used for cropping in ``call``.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        # NOTE(review): build() stores the filter in ``self.filter``,
        # this attribute is not read anywhere in this class.
        self.filters = []
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        Precompute the cubic B-spline filter for the transposed convolution.

        :param input_shape: tuple with the input shape
        """

        super().build(input_shape=input_shape)

        # the four cubic B-spline basis functions
        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        for f_idx in filter_coord:
            # separable product of the basis functions, it does not depend
            # on the channel, so it is computed once per sub-block
            weights = (
                b[f_idx[0]](u_arange)[:, None, None]
                * b[f_idx[1]](v_arange)[None, :, None]
                * b[f_idx[2]](w_arange)[None, None, :]
            )
            # only the diagonal of the last two axes is filled,
            # i.e. the three channels are not mixed
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = weights

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        Interpolate the control points with a transposed convolution.

        :param field: tf.Tensor with shape=number_of_control_points_per_dim
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        Interpolate the field and crop it to the requested sizes.

        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor, cropped along the three
            axes after the first one, starting at 3 * cp_spacing.
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            cp_spacing=self.cp_spacing,
            output_shape=self._output_shape,
        )
        return config
856