Passed
Pull Request: main (#656), by Yunguan, created 02:51

deepreg.model.layer.ResidualBlock.call()   A

Complexity: Conditions 3
Size: Total Lines 19, Code Lines 9
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0

Metric  Value
eloc    9
dl      0
loc     19
rs      9.95
c       0
b       0
f       0
cc      3
nop     4
"""This module defines custom layers."""
import itertools
from typing import List, Tuple

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util as layer_util

LAYER_DICT = dict(conv3d=tfkl.Conv3D, deconv3d=tfkl.Conv3DTranspose)
NORM_DICT = dict(batch=tfkl.BatchNormalization, layer=tfkl.LayerNormalization)


class NormBlock(tfkl.Layer):
    """
    A block with layer - norm - activation.
    """

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._config = dict(
            layer_name=layer_name,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layer = LAYER_DICT[layer_name](use_bias=False, **kwargs)
        self._norm = NORM_DICT[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer.
        :param training: training flag for normalization layers (default: None).
        :param kwargs: additional arguments.
        :return: output tensor after the layer, normalization, and activation.
        """
        output = self._layer(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config


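As a quick illustration (a hedged sketch, not part of the module; the shapes and Conv3D arguments below are invented for the example), NormBlock forwards its keyword arguments to the wrapped layer and records them for get_config():

# Example sketch (illustrative values, assumes the module's names are in scope).
block = NormBlock(layer_name="conv3d", filters=8, kernel_size=3, padding="same")
x = tf.ones((2, 16, 16, 16, 1))  # (batch, dim1, dim2, dim3, channels)
y = block(x, training=True)      # conv3d -> batch norm -> relu, shape (2, 16, 16, 16, 8)
config = block.get_config()      # layer_name, norm_name, activation, name, plus the Conv3D kwargs
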
class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)


class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="deconv3d", name=name, **kwargs)


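The two subclasses only fix layer_name. For instance, a strided Conv3dBlock halves each spatial dimension and a Deconv3dBlock with the same arguments restores it (a hedged sketch with illustrative shapes, not from the PR):

# Example sketch (illustrative values).
down = Conv3dBlock(filters=4, kernel_size=3, strides=2, padding="same")
up = Deconv3dBlock(filters=4, kernel_size=3, strides=2, padding="same")
x = tf.ones((1, 8, 8, 8, 4))
encoded = down(x, training=False)      # (1, 4, 4, 4, 4)
decoded = up(encoded, training=False)  # (1, 8, 8, 8, 4)
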
class Resize3d(tfkl.Layer):
    """
    Resize an image in two passes.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        """
        super().__init__(name=name)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform the two-pass resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                              or (batch, dim1, dim2, dim3)
                              or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                              or (batch, out_dim1, out_dim2, out_dim3)
                              or (out_dim1, out_dim2, out_dim3)
        :raises ValueError: if the input is not 3-D, 4-D, or 5-D.
        """
        # sanity check
        image = inputs
        image_dim = len(image.shape)

        # init
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape {}".format(image.shape)
            )

        # no need to resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config


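Because of the dimension handling above, the same layer accepts 3-D, 4-D, or 5-D inputs (a hedged sketch with illustrative shapes):

# Example sketch (illustrative values).
resize = Resize3d(shape=(16, 16, 16))
vol = tf.random.uniform((2, 8, 12, 10, 3))  # (batch, dim1, dim2, dim3, channels)
out = resize(vol)                           # (2, 16, 16, 16, 3)
mask = tf.ones((8, 12, 10))                 # no batch, no channel
out_mask = resize(mask)                     # (16, 16, 16)
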
class Warping(tfkl.Layer):
    """
    Warps an image with DDF.

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Init.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)
        self.grid_ref = self.grid_ref[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        return layer_util.resample(vol=image, loc=self.grid_ref + ddf)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        return config


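The reference grid is precomputed at init, so calling the layer with a zero DDF simply resamples the moving image onto the fixed grid (a hedged sketch with illustrative shapes):

# Example sketch (illustrative values).
fixed_size = (4, 4, 4)
warping = Warping(fixed_image_size=fixed_size)
image = tf.random.uniform((1, 4, 4, 4))  # moving image, (batch, m_dim1, m_dim2, m_dim3)
ddf = tf.zeros((1, *fixed_size, 3))      # dense displacement field on the fixed grid
warped = warping(inputs=[ddf, image])    # (1, 4, 4, 4)
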
class ResidualBlock(tfkl.Layer):
    """
    A block with skip links and layer - norm - activation.
    """

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param num_layers: number of layers/blocks.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._num_layers = num_layers
        self._config = dict(
            layer_name=layer_name,
            num_layers=num_layers,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layers = [
            LAYER_DICT[layer_name](use_bias=False, **kwargs) for _ in range(num_layers)
        ]
        self._norms = [NORM_DICT[norm_name]() for _ in range(num_layers)]
        self._acts = [tfkl.Activation(activation=activation) for _ in range(num_layers)]

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer.
        :param training: training flag for normalization layers (default: None).
        :param kwargs: additional arguments.
        :return: output tensor, with the skip connection added before the last activation.
        """

        output = inputs
        for i in range(self._num_layers):
            output = self._layers[i](inputs=output)
            output = self._norms[i](inputs=output, training=training)
            if i == self._num_layers - 1:
                # last block
                output = output + inputs
            output = self._acts[i](output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config


class ResidualConv3dBlock(ResidualBlock):
    """
    A conv3d residual block.
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)


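Since the skip connection adds the raw inputs to the normalized output of the last layer, the wrapped layers must preserve the input shape; in the conv3d case that means filters equal to the input channels and padding="same" (a hedged sketch with illustrative values):

# Example sketch (illustrative values).
res = ResidualConv3dBlock(filters=4, kernel_size=3, padding="same")
x = tf.ones((1, 8, 8, 8, 4))
y = res(x, training=False)  # (1, 8, 8, 8, 4): two conv-norm steps, skip added before the final relu
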
class IntDVF(tfkl.Layer):
    """
    Integrate DVF to get DDF.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(
        self,
        fixed_image_size: tuple,
        num_steps: int = 7,
        name: str = "int_dvf",
        **kwargs,
    ):
        """
        Init.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        assert len(fixed_image_size) == 3
        self._fixed_image_size = fixed_image_size
        self._num_steps = num_steps
        self._warping = Warping(fixed_image_size=fixed_image_size)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        config["num_steps"] = self._num_steps
        return config


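IntDVF integrates a stationary velocity field by scaling and squaring: the field is divided by 2**num_steps and then composed with itself num_steps times via the Warping layer. A hedged sketch with illustrative shapes:

# Example sketch (illustrative values).
int_dvf = IntDVF(fixed_image_size=(4, 4, 4), num_steps=4)
dvf = tf.zeros((1, 4, 4, 4, 3))  # a zero velocity field integrates to a zero displacement field
ddf = int_dvf(dvf)               # (1, 4, 4, 4, 3)
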
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before down-sampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        :param control_point_spacing: list or int
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = [
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)
        return output


class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline kernels.

    It assumes a full sized image from which:

    1. it computes the control point values by down-sampling the initial image
    2. performs the interpolation
    3. crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing is used in all dimensions.
    :param output_shape: (dim0, dim1, dim2) spatial shape of the high resolution
        deformation fields.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):

        super().__init__(**kwargs)

        self.filters = []
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        :param input_shape: tuple with the input shape
        """

        super().build(input_shape=input_shape)

        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        :param field: tf.Tensor with shape=number_of_control_points_per_dim
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor of shape=self.input_shape
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]


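The two layers above are meant to be chained: ResizeCPTransform smooths a dense field and down-samples it to a control-point grid of size ceil(dim / spacing) + 3 per dimension, and BSplines3DTransform expands that grid back to the full-resolution DDF. A hedged sketch with illustrative shapes, assuming (as the conv3d call in ResizeCPTransform expects) that layer_util.gaussian_filter_3d returns a kernel matching the field's three channels:

# Example sketch (illustrative values, eager execution).
cp_spacing = 5
fixed_shape = (20, 20, 20)
to_cp = ResizeCPTransform(control_point_spacing=cp_spacing)
to_ddf = BSplines3DTransform(cp_spacing=cp_spacing, output_shape=fixed_shape)
field = tf.random.uniform((1, *fixed_shape, 3))
cp = to_cp(field)  # (1, 7, 7, 7, 3): ceil(20 / 5) + 3 = 7 control points per dimension
ddf = to_ddf(cp)   # (1, 20, 20, 20, 3)
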
class Extraction(tfkl.Layer):
    """
    Extract feature maps from multiple levels, resize them to the image size,
    and average them.
    """

    def __init__(
        self,
        image_size: Tuple[int],
        extract_levels: Tuple[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "Extraction",
    ):
        """
        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: number of extraction levels.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.extract_levels = extract_levels
        self.max_level = max(extract_levels)
        self.layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    Resize3d(shape=image_size),
                ]
            )
            for _ in extract_levels
        ]

    def call(self, inputs: List[tf.Tensor], **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: a list of tensors
        :param kwargs: additional arguments.
        :return: the average of the extracted and resized tensors.
        """
        outputs = [
            self.layers[idx](inputs=inputs[self.max_level - level])
            for idx, level in enumerate(self.extract_levels)
        ]
        if len(self.extract_levels) == 1:
            return outputs[0]
        return tf.add_n(outputs) / len(self.extract_levels)
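Extraction projects each selected level to out_channels with a 3x3x3 convolution, resizes it to image_size, and averages the results; the feature map for a given level is taken from inputs[max_level - level]. A hedged sketch with illustrative shapes and initializer:

# Example sketch (illustrative values).
extraction = Extraction(
    image_size=(8, 8, 8),
    extract_levels=(0, 1, 2),
    out_channels=3,
    out_kernel_initializer="glorot_uniform",
    out_activation="linear",
)
feats = [
    tf.random.uniform((1, 2, 2, 2, 16)),  # level 2 (coarsest), used as inputs[max_level - 2]
    tf.random.uniform((1, 4, 4, 4, 8)),   # level 1
    tf.random.uniform((1, 8, 8, 8, 4)),   # level 0 (finest)
]
out = extraction(feats)  # (1, 8, 8, 8, 3): average of the three projected and resized maps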