deepreg.model.layer.NormBlock.__init__()   A
last analyzed

Complexity
  Conditions: 1

Size
  Total Lines: 28
  Code Lines: 17

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric   Value
eloc     17
dl       0
loc      28
rs       9.55
c        0
b        0
f        0
cc       1
nop      6
"""This module defines custom layers."""
import itertools
from typing import List, Tuple, Union

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

from deepreg.model import layer_util

LAYER_DICT = dict(conv3d=tfkl.Conv3D, deconv3d=tfkl.Conv3DTranspose)
NORM_DICT = dict(batch=tfkl.BatchNormalization, layer=tfkl.LayerNormalization)

class NormBlock(tfkl.Layer):
    """
    A block with layer - norm - activation.
    """

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._config = dict(
            layer_name=layer_name,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layer = LAYER_DICT[layer_name](use_bias=False, **kwargs)
        self._norm = NORM_DICT[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output tensor after layer, normalization and activation.
        """
        output = self._layer(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config

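# --- Illustrative usage sketch (added for this report; not part of the original module) ---
# NormBlock wraps any layer from LAYER_DICT with a normalization layer from
# NORM_DICT and an activation; extra kwargs are forwarded to the wrapped layer.
def _example_norm_block() -> tf.Tensor:
    """Minimal sketch: a conv3d - norm - activation block on a random volume."""
    block = NormBlock(layer_name="conv3d", filters=8, kernel_size=3, padding="same")
    volume = tf.random.uniform((2, 16, 16, 16, 1))  # (batch, dim1, dim2, dim3, ch)
    return block(volume)  # shape = (2, 16, 16, 16, 8)
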
class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)

class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="deconv3d", name=name, **kwargs)

class Resize3d(tfkl.Layer):
    """
    Resize an image in two steps.

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        """
        super().__init__(name=name)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        Perform the two-step resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                            or (batch, dim1, dim2, dim3)
                            or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                      or (batch, out_dim1, out_dim2, out_dim3)
                      or (out_dim1, out_dim2, out_dim3)
        """
        # sanity check
        image = inputs
        image_dim = len(image.shape)

        # init
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape {}".format(image.shape)
            )

        # no need to resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config

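# --- Illustrative usage sketch (added for this report; not part of the original module) ---
# Resize3d implements a 3D resize with two 2D tf.image.resize calls: first over
# (dim2, dim3) with batch and dim1 merged into one axis, then over (dim1, dim2)
# with dim3 and channels merged into one axis.
def _example_resize3d() -> tf.Tensor:
    """Minimal sketch: upsample a (2, 8, 8, 8, 1) volume to shape (16, 16, 16)."""
    resize = Resize3d(shape=(16, 16, 16))
    volume = tf.random.uniform((2, 8, 8, 8, 1))
    return resize(volume)  # shape = (2, 16, 16, 16, 1)
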
class Warping(tfkl.Layer):
    """
    Warps an image with DDF.

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Init.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)[
            None, ...
        ]

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        return layer_util.resample(vol=image, loc=self.grid_ref + ddf)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        return config

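# --- Illustrative usage sketch (added for this report; not part of the original module) ---
# Warping adds the DDF to a reference grid of voxel coordinates and resamples
# the moving image at the displaced locations; an all-zero DDF returns the
# image essentially unchanged.
def _example_warping() -> tf.Tensor:
    """Minimal sketch: warp a moving image with an all-zero DDF."""
    warping = Warping(fixed_image_size=(8, 8, 8))
    image = tf.random.uniform((2, 8, 8, 8))  # (batch, m_dim1, m_dim2, m_dim3)
    ddf = tf.zeros((2, 8, 8, 8, 3))  # (batch, f_dim1, f_dim2, f_dim3, 3)
    return warping([ddf, image])  # shape = (2, 8, 8, 8)
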
class ResidualBlock(tfkl.Layer):
    """
    A block with skip links and layer - norm - activation.
    """

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param num_layers: number of layers/blocks.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._num_layers = num_layers
        self._config = dict(
            layer_name=layer_name,
            num_layers=num_layers,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layers = [
            LAYER_DICT[layer_name](use_bias=False, **kwargs) for _ in range(num_layers)
        ]
        self._norms = [NORM_DICT[norm_name]() for _ in range(num_layers)]
        self._acts = [tfkl.Activation(activation=activation) for _ in range(num_layers)]

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return: output tensor of the same shape as the inputs.
        """

        output = inputs
        for i in range(self._num_layers):
            output = self._layers[i](inputs=output)
            output = self._norms[i](inputs=output, training=training)
            if i == self._num_layers - 1:
                # last sub-block: add the skip connection before the activation
                output = output + inputs
            output = self._acts[i](output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config

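# --- Illustrative usage sketch (added for this report; not part of the original module) ---
# ResidualBlock stacks num_layers (layer - norm) pairs and adds the input back
# just before the final activation, so the wrapped layer must preserve the
# input shape (e.g. padding="same" and filters equal to the input channels).
def _example_residual_block() -> tf.Tensor:
    """Minimal sketch: a two-layer conv3d residual block preserving shape."""
    block = ResidualBlock(layer_name="conv3d", filters=4, kernel_size=3, padding="same")
    volume = tf.random.uniform((2, 8, 8, 8, 4))  # channels match filters
    return block(volume)  # shape = (2, 8, 8, 8, 4)
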
class ResidualConv3dBlock(ResidualBlock):
    """
    A conv3d residual block.
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)

class IntDVF(tfkl.Layer):
    """
    Integrate DVF to get DDF.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(
        self,
        fixed_image_size: tuple,
        num_steps: int = 7,
        name: str = "int_dvf",
        **kwargs,
    ):
        """
        Init.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        assert len(fixed_image_size) == 3
        self._fixed_image_size = fixed_image_size
        self._num_steps = num_steps
        self._warping = Warping(fixed_image_size=fixed_image_size)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        config["num_steps"] = self._num_steps
        return config

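# --- Illustrative note and usage sketch (added for this report; not part of the original module) ---
# IntDVF integrates a stationary velocity field by scaling and squaring: the
# DVF is divided by 2 ** num_steps and then composed with itself num_steps
# times via ddf <- ddf + warp(ddf, ddf), approximating a diffeomorphic DDF.
def _example_int_dvf() -> tf.Tensor:
    """Minimal sketch: integrate a random DVF on an 8**3 fixed image grid."""
    int_dvf = IntDVF(fixed_image_size=(8, 8, 8), num_steps=7)
    dvf = tf.random.uniform((2, 8, 8, 8, 3))
    return int_dvf(dvf)  # ddf, shape = (2, 8, 8, 8, 3)
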
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before down-sampling.
    """

    def __init__(
        self, control_point_spacing: Union[List[int], Tuple[int, ...], int], **kwargs
    ):
        """
        :param control_point_spacing: spacing between control points,
            as an int or a list/tuple of three ints
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 ≈ ln(4) / pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = tuple(
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        )
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)  # type: ignore
        return output

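# --- Illustrative usage sketch (added for this report; not part of the original module) ---
# ResizeCPTransform blurs the dense field with a Gaussian kernel whose sigma is
# proportional to the control point spacing, then resizes it to the control
# point grid of ceil(dim / spacing) + 3 points per dimension. The input is
# assumed here to be a 3-channel displacement field, matching how DeepReg uses it.
def _example_resize_cp_transform() -> tf.Tensor:
    """Minimal sketch: downsample a dense field to its control point grid."""
    resize_cp = ResizeCPTransform(control_point_spacing=4)
    field = tf.random.uniform((2, 16, 16, 16, 3))
    return resize_cp(field)  # shape = (2, 7, 7, 7, 3) since ceil(16 / 4) + 3 = 7
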
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for B-spline interpolation with a precomputed cubic B-spline kernel.
    It assumes a full-sized image from which it:

    1. computes the control point values by down-sampling the initial image,
    2. performs the interpolation,
    3. crops the image around the valid values.
    """

    def __init__(
        self,
        cp_spacing: Union[Tuple[int, ...], int],
        output_shape: Tuple[int, ...],
        **kwargs,
    ):
        """
        Init.

        :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
            in each dimension. When a single int is used,
            the same spacing is used for all dimensions.
        :param output_shape: (dim1, dim2, dim3) spatial shape of the high resolution
            deformation field.
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        self._output_shape = output_shape
        if isinstance(cp_spacing, int):
            cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
        """
        :param input_shape: tuple with the input shape
        :return: None
        """

        super().build(input_shape=input_shape)

        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
        """
        :param field: tf.Tensor of control point values,
            shape = (batch, n_cp1, n_cp2, n_cp3, 3)
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
        """
        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated field, cropped to shape (batch, dim1, dim2, dim3, 3)
            with (dim1, dim2, dim3) = output_shape
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]

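# --- Illustrative usage sketch (added for this report; not part of the original module) ---
# The four cubic B-spline basis functions in b sum to one for any u, so a
# constant control point field is reproduced exactly. The transposed conv3d
# upsamples the control point grid by cp_spacing before cropping to
# output_shape; the grid size below is chosen to match ceil(dim / spacing) + 3.
def _example_bsplines_3d_transform() -> tf.Tensor:
    """Minimal sketch: interpolate an 8**3 control point grid to a 20**3 field."""
    bspline = BSplines3DTransform(cp_spacing=4, output_shape=(20, 20, 20))
    control_points = tf.random.uniform((2, 8, 8, 8, 3))  # ceil(20 / 4) + 3 = 8
    return bspline(control_points)  # shape = (2, 20, 20, 20, 3)
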
class Extraction(tfkl.Layer):
    """
    Extract features from selected levels, resize them to the image size,
    and average them.
    """

    def __init__(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "Extraction",
    ):
        """
        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which to extract, e.g. (0, 1, 2)
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.extract_levels = extract_levels
        self.max_level = max(extract_levels)
        self.layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    Resize3d(shape=image_size),
                ]
            )
            for _ in extract_levels
        ]

    def call(self, inputs: List[tf.Tensor], **kwargs) -> tf.Tensor:
        """
        Calculate the mean over some selected inputs.

        :param inputs: a list of tensors
        :param kwargs: additional arguments.
        :return: mean of the extracted and resized tensors.
        """
        outputs = [
            self.layers[idx](inputs=inputs[self.max_level - level])
            for idx, level in enumerate(self.extract_levels)
        ]
        if len(self.extract_levels) == 1:
            return outputs[0]
        return tf.add_n(outputs) / len(self.extract_levels)

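# --- Illustrative usage sketch (added for this report; not part of the original module) ---
# Extraction applies one conv3d + Resize3d head per entry of extract_levels,
# picking inputs[max_level - level] for each level, and averages the resized
# outputs (or returns the single output when only one level is selected).
def _example_extraction() -> tf.Tensor:
    """Minimal sketch: average extractions from levels 0 and 1 of a toy pyramid."""
    extraction = Extraction(
        image_size=(8, 8, 8),
        extract_levels=(0, 1),
        out_channels=3,
        out_kernel_initializer="glorot_uniform",
        out_activation="linear",
    )
    # inputs[0] is used for level 1 (= max_level), inputs[1] for level 0
    pyramid = [tf.random.uniform((2, 4, 4, 4, 8)), tf.random.uniform((2, 8, 8, 8, 8))]
    return extraction(pyramid)  # shape = (2, 8, 8, 8, 3)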