Passed
Pull Request — main (#675)
by Yunguan
created 03:26

deepreg.model.layer.ResizeCPTransform.__init__()   A

Complexity

Conditions 2

Size

Total Lines 19
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 11
dl 0
loc 19
rs 9.85
c 0
b 0
f 0
cc 2
nop 3
"""This module defines custom layers."""
import itertools
from typing import List, Tuple, Union

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util as layer_util

LAYER_DICT = dict(conv3d=tfkl.Conv3D, deconv3d=tfkl.Conv3DTranspose)
NORM_DICT = dict(batch=tfkl.BatchNormalization, layer=tfkl.LayerNormalization)


class NormBlock(tfkl.Layer):
    """
    A block with layer - norm - activation.
    """

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._config = dict(
            layer_name=layer_name,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layer = LAYER_DICT[layer_name](use_bias=False, **kwargs)
        self._norm = NORM_DICT[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output of the layer - norm - activation chain
        """
        output = self._layer(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config


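A minimal usage sketch for NormBlock (shapes and parameter values here are illustrative, not taken from any DeepReg config): the extra keyword arguments are forwarded to the wrapped Keras layer, so a conv3d block takes the usual Conv3D parameters.

import tensorflow as tf

block = NormBlock(layer_name="conv3d", filters=8, kernel_size=3, padding="same")
x = tf.random.uniform((2, 16, 16, 16, 1))  # (batch, dim1, dim2, dim3, channels)
y = block(x, training=True)                # conv3d -> batch norm -> relu
assert y.shape == (2, 16, 16, 16, 8)
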
class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)


class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="deconv3d", name=name, **kwargs)


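The same forwarding applies to the two subclasses; a rough sketch (shapes invented) showing how strides turn them into down- and up-sampling blocks:

import tensorflow as tf

down = Conv3dBlock(filters=4, kernel_size=3, strides=2, padding="same")
up = Deconv3dBlock(filters=4, kernel_size=3, strides=2, padding="same")
x = tf.random.uniform((1, 8, 8, 8, 2))
assert down(x).shape == (1, 4, 4, 4, 4)   # Conv3D with stride 2 halves the spatial dims
assert up(x).shape == (1, 16, 16, 16, 4)  # Conv3DTranspose with stride 2 doubles them
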
class Resize3d(tfkl.Layer):
    """
    Resize an image in two steps.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        """
        super().__init__(name=name)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform the two-step resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                       or (batch, dim1, dim2, dim3)
                       or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                 or (batch, out_dim1, out_dim2, out_dim3)
                 or (out_dim1, out_dim2, out_dim3)
        :raise: ValueError if the input is not 3, 4 or 5 dimensional
        """
        # sanity check
        image = inputs
        image_dim = len(image.shape)

        # init
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape {}".format(image.shape)
            )

        # no need to resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config


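A short sketch of Resize3d (shapes invented): the same layer accepts 3-D, 4-D and 5-D inputs and only the three spatial dimensions are resized.

import tensorflow as tf

resize = Resize3d(shape=(8, 8, 8))
vol = tf.random.uniform((2, 16, 12, 10, 3))  # (batch, dim1, dim2, dim3, channels)
assert resize(vol).shape == (2, 8, 8, 8, 3)
assert resize(tf.random.uniform((16, 12, 10))).shape == (8, 8, 8)
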
class Warping(tfkl.Layer):
    """
    Warps an image with DDF.

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Init.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)
        self.grid_ref = self.grid_ref[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        return layer_util.resample(vol=image, loc=self.grid_ref + ddf)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        return config


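A sketch of Warping (sizes invented): with a zero DDF the image is resampled at the reference grid itself, so the output reproduces the input.

import tensorflow as tf

warping = Warping(fixed_image_size=(8, 8, 8))
image = tf.random.uniform((2, 8, 8, 8))  # moving image, no channel axis
ddf = tf.zeros((2, 8, 8, 8, 3))          # dense displacement field
warped = warping([ddf, image])           # shape (2, 8, 8, 8); here it matches `image`
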
class ResidualBlock(tfkl.Layer):
    """
    A block with skip links and layer - norm - activation.
    """

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param num_layers: number of layers/blocks.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._num_layers = num_layers
        self._config = dict(
            layer_name=layer_name,
            num_layers=num_layers,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layers = [
            LAYER_DICT[layer_name](use_bias=False, **kwargs) for _ in range(num_layers)
        ]
        self._norms = [NORM_DICT[norm_name]() for _ in range(num_layers)]
        self._acts = [tfkl.Activation(activation=activation) for _ in range(num_layers)]

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output with the skip connection added before the last activation
        """

        output = inputs
        for i in range(self._num_layers):
            output = self._layers[i](inputs=output)
            output = self._norms[i](inputs=output, training=training)
            if i == self._num_layers - 1:
                # last block
                output = output + inputs
            output = self._acts[i](output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config


class ResidualConv3dBlock(ResidualBlock):
    """
    A conv3d residual block.
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)


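A sketch of the residual variant (values illustrative): the skip connection adds the block input to the last pre-activation output, so the input channel count must match `filters`.

import tensorflow as tf

res = ResidualConv3dBlock(filters=4, kernel_size=3, padding="same")
x = tf.random.uniform((1, 8, 8, 8, 4))  # channels must equal filters=4
y = res(x, training=True)               # same shape as x
assert y.shape == x.shape
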
class IntDVF(tfkl.Layer):
    """
    Integrate DVF to get DDF.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(
        self,
        fixed_image_size: tuple,
        num_steps: int = 7,
        name: str = "int_dvf",
        **kwargs,
    ):
        """
        Init.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        assert len(fixed_image_size) == 3
        self._fixed_image_size = fixed_image_size
        self._num_steps = num_steps
        self._warping = Warping(fixed_image_size=fixed_image_size)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        config["num_steps"] = self._num_steps
        return config


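IntDVF follows the scaling-and-squaring idea: the velocity field is scaled by 1 / 2**num_steps and then composed with itself num_steps times via Warping. A sketch with invented sizes:

import tensorflow as tf

int_dvf = IntDVF(fixed_image_size=(8, 8, 8), num_steps=4)
dvf = tf.random.uniform((1, 8, 8, 8, 3), minval=-0.5, maxval=0.5)
ddf = int_dvf(dvf)  # shape (1, 8, 8, 8, 3)
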
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before down-sampling.
    """

    def __init__(
        self, control_point_spacing: Union[List[int], Tuple[int, ...], int], **kwargs
    ):
        """
        :param control_point_spacing: spacing between control points,
            a list/tuple of three ints or a single int used for all dimensions.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = tuple(
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        )
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        # smooth with the Gaussian kernel before down-sampling
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        # down-sample to the control-point grid
        output = self._resize(inputs=output)  # type: ignore
        return output


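A sketch of ResizeCPTransform (spacing and shapes illustrative): the dense field is smoothed with the Gaussian kernel, then down-sampled to a control-point grid whose side length is ceil(dim / spacing) + 3.

import tensorflow as tf

to_cp = ResizeCPTransform(control_point_spacing=4)
field = tf.random.uniform((1, 16, 16, 16, 3))  # e.g. raw network output
cp = to_cp(field)                              # spatial dims: ceil(16 / 4) + 3 = 7
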
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline kernels.
    It assumes a full sized image from which:
    1. it computes the control point values by down-sampling the initial image
    2. it performs the interpolation
    3. it crops the image around the valid values.
    """

    def __init__(
        self,
        cp_spacing: Union[Tuple[int, ...], int],
        output_shape: Tuple[int, ...],
        **kwargs,
    ):
        """
        Init.

        :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
            in each dimension. When a single int is used,
            the same spacing is used for all dimensions.
        :param output_shape: (dim0, dim1, dim2) of the high resolution
            deformation fields.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        self._output_shape = output_shape
        if isinstance(cp_spacing, int):
            cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        :param input_shape: tuple with the input shape
        """

        super().build(input_shape=input_shape)

        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        :param field: tf.Tensor of control point values,
            shape = (batch, n_cp1, n_cp2, n_cp3, 3)
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            (a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor of shape = (batch, dim0, dim1, dim2, 3)
            as given by output_shape
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]


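A sketch pairing the two layers (a 16 x 16 x 16 volume and a spacing of 4 are illustrative values): ResizeCPTransform produces the control-point grid and BSplines3DTransform interpolates it back to a smooth full-resolution DDF.

import tensorflow as tf

image_size = (16, 16, 16)
to_cp = ResizeCPTransform(control_point_spacing=4)
bspline = BSplines3DTransform(cp_spacing=4, output_shape=image_size)
dense = tf.random.uniform((1, *image_size, 3))  # raw dense field from a network
cp = to_cp(dense)                               # (1, 7, 7, 7, 3) control points
ddf = bspline(cp)                               # (1, 16, 16, 16, 3) smooth DDF
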
class Extraction(tfkl.Layer):
    """
    Extract features at the given levels, project them to the output channels,
    resize them to the image size and average.
    """

    def __init__(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "Extraction",
    ):
        """
        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: indices of the levels to extract from.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.extract_levels = extract_levels
        self.max_level = max(extract_levels)
        self.layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    Resize3d(shape=image_size),
                ]
            )
            for _ in extract_levels
        ]

    def call(self, inputs: List[tf.Tensor], **kwargs) -> tf.Tensor:
        """
        :param inputs: list of feature tensors, one per level
        :param kwargs: additional arguments.
        :return: averaged extraction, shape = (batch, dim1, dim2, dim3, out_channels)
        """
        outputs = [
            self.layers[idx](inputs=inputs[self.max_level - level])
            for idx, level in enumerate(self.extract_levels)
        ]
        if len(self.extract_levels) == 1:
            return outputs[0]
        return tf.add_n(outputs) / len(self.extract_levels)
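
A sketch of Extraction with invented shapes: each selected pyramid level is projected to out_channels, resized to image_size, and the results are averaged. The input list is indexed as inputs[max_level - level], i.e. the coarsest level comes first.

import tensorflow as tf

extract = Extraction(
    image_size=(16, 16, 16),
    extract_levels=(0, 1, 2),
    out_channels=3,
    out_kernel_initializer="glorot_uniform",
    out_activation="linear",
)
feats = [
    tf.random.uniform((1, 4, 4, 4, 8)),     # level 2 (coarsest)
    tf.random.uniform((1, 8, 8, 8, 8)),     # level 1
    tf.random.uniform((1, 16, 16, 16, 8)),  # level 0 (finest)
]
ddf = extract(feats)                        # (1, 16, 16, 16, 3)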