Passed
Pull Request — main (#656) by Yunguan
11:15, queued 50s, created

deepreg.model.layer.Deconv3d.build()   A

Complexity
    Conditions: 4

Size
    Total Lines: 36
    Code Lines: 22

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
eloc      22
dl        0
loc       36
rs        9.352
c         0
b         0
f         0
cc        4
nop       2
"""This module defines custom layers."""
import itertools

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util as layer_util

LAYER_DICT = dict(conv3d=tfkl.Conv3D, deconv3d=tfkl.Conv3DTranspose)
NORM_DICT = dict(batch=tfkl.BatchNormalization, layer=tfkl.LayerNormalization)


class NormBlock(tfkl.Layer):
    """
    A block with layer - norm - activation.
    """

    def __init__(
        self,
        layer_name: str,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "norm_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._config = dict(
            layer_name=layer_name,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layer = LAYER_DICT[layer_name](use_bias=False, **kwargs)
        self._norm = NORM_DICT[norm_name]()
        self._act = tfkl.Activation(activation=activation)

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
Issue (bug): Parameters differ from overridden 'call' method.
Issue: "inputs, training" missing in parameter type documentation.
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return:
        """
        output = self._layer(inputs=inputs)
        output = self._norm(inputs=output, training=training)
        output = self._act(output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config


class Conv3dBlock(NormBlock):
    """
    A conv3d block having conv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "conv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)


class Deconv3dBlock(NormBlock):
    """
    A deconv3d block having deconv3d - norm - activation.
    """

    def __init__(
        self,
        name: str = "deconv3d_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="deconv3d", name=name, **kwargs)


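A minimal usage sketch of these wrappers, assuming the module is importable as deepreg.model.layer; the shapes and filter counts are arbitrary illustrative values, and any extra keyword arguments are forwarded to the wrapped tfkl.Conv3D / tfkl.Conv3DTranspose:

import tensorflow as tf

from deepreg.model.layer import Conv3dBlock, Deconv3dBlock

# toy input: (batch, dim1, dim2, dim3, channels)
images = tf.random.uniform((2, 8, 8, 8, 1))

# conv3d - norm - activation; kwargs go to tfkl.Conv3D
conv = Conv3dBlock(filters=4, kernel_size=3, padding="same")
features = conv(images, training=True)  # (2, 8, 8, 8, 4)

# deconv3d - norm - activation; strides=2 doubles each spatial dimension
deconv = Deconv3dBlock(filters=4, kernel_size=3, strides=2, padding="same")
upsampled = deconv(features, training=True)  # (2, 16, 16, 16, 4)
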
class Resize3d(tfkl.Layer):
    """
    Resize an image in two steps:

    - resize dim2 and dim3
    - resize dim1 and dim2
    """

    def __init__(
        self,
        shape: tuple,
        method: str = tf.image.ResizeMethod.BILINEAR,
        name: str = "resize3d",
    ):
        """
        Init, save arguments.

        :param shape: (dim1, dim2, dim3)
        :param method: tf.image.ResizeMethod
        :param name: name of the layer
        """
        super().__init__(name=name)
        assert len(shape) == 3
        self._shape = shape
        self._method = method

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
Issue: "ValueError" not documented as being raised.
        """
        Perform the two-step resize.

        :param inputs: shape = (batch, dim1, dim2, dim3, channels)
                                     or (batch, dim1, dim2, dim3)
                                     or (dim1, dim2, dim3)
        :param kwargs: additional arguments
        :return: shape = (batch, out_dim1, out_dim2, out_dim3, channels)
                                or (batch, out_dim1, out_dim2, out_dim3)
                                or (out_dim1, out_dim2, out_dim3)
        """
        # sanity check
        image = inputs
        image_dim = len(image.shape)

        # init
        if image_dim == 5:
            has_channel = True
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 4:
            has_channel = False
            has_batch = True
            input_image_shape = image.shape[1:4]
        elif image_dim == 3:
            has_channel = False
            has_batch = False
            input_image_shape = image.shape[0:3]
        else:
            raise ValueError(
                "Resize3d takes input image of dimension 3 or 4 or 5, "
                "corresponding to (dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3) "
                "or (batch, dim1, dim2, dim3, channels), "
                "got image shape{}".format(image.shape)
            )

        # no need of resize
        if input_image_shape == tuple(self._shape):
            return image

        # expand to five dimensions
        if not has_batch:
            image = tf.expand_dims(image, axis=0)
        if not has_channel:
            image = tf.expand_dims(image, axis=-1)
        assert len(image.shape) == 5  # (batch, dim1, dim2, dim3, channels)
        image_shape = tf.shape(image)

        # merge axis 0 and 1
        output = tf.reshape(
            image, (-1, image_shape[2], image_shape[3], image_shape[4])
        )  # (batch * dim1, dim2, dim3, channels)

        # resize dim2 and dim3
        output = tf.image.resize(
            images=output, size=self._shape[1:3], method=self._method
        )  # (batch * dim1, out_dim2, out_dim3, channels)

        # split axis 0 and merge axis 3 and 4
        output = tf.reshape(
            output,
            shape=(-1, image_shape[1], self._shape[1], self._shape[2] * image_shape[4]),
        )  # (batch, dim1, out_dim2, out_dim3 * channels)

        # resize dim1 and dim2
        output = tf.image.resize(
            images=output, size=self._shape[:2], method=self._method
        )  # (batch, out_dim1, out_dim2, out_dim3 * channels)

        # reshape
        output = tf.reshape(
            output, shape=[-1, *self._shape, image_shape[4]]
        )  # (batch, out_dim1, out_dim2, out_dim3, channels)

        # squeeze to original dimension
        if not has_batch:
            output = tf.squeeze(output, axis=0)
        if not has_channel:
            output = tf.squeeze(output, axis=-1)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["shape"] = self._shape
        config["method"] = self._method
        return config


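A usage sketch of Resize3d with illustrative shapes; the layer accepts 5-D, 4-D, or 3-D inputs as documented above:

import tensorflow as tf

from deepreg.model.layer import Resize3d

resize = Resize3d(shape=(8, 16, 16))  # target (dim1, dim2, dim3)

# 5-D input: (batch, dim1, dim2, dim3, channels)
volume = tf.random.uniform((2, 4, 8, 8, 1))
resized = resize(volume)  # (2, 8, 16, 16, 1)

# 3-D input without batch and channel axes is also accepted
resized_3d = resize(tf.random.uniform((4, 8, 8)))  # (8, 16, 16)
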
class Warping(tfkl.Layer):
    """
    Warps an image with DDF.

    Reference:

    https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    where vol = image, loc_shift = ddf
    """

    def __init__(self, fixed_image_size: tuple, name: str = "warping", **kwargs):
        """
        Init.

        :param fixed_image_size: shape = (f_dim1, f_dim2, f_dim3)
             or (f_dim1, f_dim2, f_dim3, ch) with the last channel for features
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        self._fixed_image_size = fixed_image_size
        # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        self.grid_ref = layer_util.get_reference_grid(grid_size=fixed_image_size)
        self.grid_ref = self.grid_ref[None, ...]

    def call(self, inputs, **kwargs) -> tf.Tensor:
Issue: "inputs" missing in parameter type documentation.
        """
        :param inputs: (ddf, image)

          - ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
          - image, shape = (batch, m_dim1, m_dim2, m_dim3)
        :param kwargs: additional arguments.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3)
        """
        ddf, image = inputs
        return layer_util.resample(vol=image, loc=self.grid_ref + ddf)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        return config


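A usage sketch of Warping with illustrative shapes; with a zero DDF the sampling locations equal the reference grid, so the output reproduces the input image up to interpolation:

import tensorflow as tf

from deepreg.model.layer import Warping

warping = Warping(fixed_image_size=(4, 4, 4))

image = tf.random.uniform((2, 4, 4, 4))  # moving image, (batch, m_dim1, m_dim2, m_dim3)
ddf = tf.zeros((2, 4, 4, 4, 3))          # zero dense displacement field
warped = warping([ddf, image])           # (2, 4, 4, 4), approximately equal to image
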
class ResidualBlock(tfkl.Layer):
    """
    A block with skip links and layer - norm - activation.
    """

    def __init__(
        self,
        layer_name: str,
        num_layers: int = 2,
        norm_name: str = "batch",
        activation: str = "relu",
        name: str = "res_block",
        **kwargs,
    ):
        """
        Init.

        :param layer_name: class of the layer to be wrapped.
        :param num_layers: number of layers/blocks.
        :param norm_name: class of the normalization layer.
        :param activation: name of activation.
        :param name: name of the block layer.
        :param kwargs: additional arguments.
        """
        super().__init__()
        self._num_layers = num_layers
        self._config = dict(
            layer_name=layer_name,
            num_layers=num_layers,
            norm_name=norm_name,
            activation=activation,
            name=name,
            **kwargs,
        )
        self._layers = [
            LAYER_DICT[layer_name](use_bias=False, **kwargs) for _ in range(num_layers)
        ]
        self._norms = [NORM_DICT[norm_name]() for _ in range(num_layers)]
        self._acts = [tfkl.Activation(activation=activation) for _ in range(num_layers)]

    def call(self, inputs, training=None, **kwargs) -> tf.Tensor:
Issue (bug): Parameters differ from overridden 'call' method.
Issue: "inputs, training" missing in parameter type documentation.
        """
        Forward.

        :param inputs: inputs for the layer
        :param training: training flag for normalization layers (default: None)
        :param kwargs: additional arguments.
        :return:
        """

        output = inputs
        for i in range(self._num_layers):
            output = self._layers[i](inputs=output)
            output = self._norms[i](inputs=output, training=training)
            if i == self._num_layers - 1:
                # last block
                output = output + inputs
            output = self._acts[i](output)
        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(self._config)
        return config


class ResidualConv3dBlock(ResidualBlock):
    """
    A conv3d residual block.
    """

    def __init__(
        self,
        name: str = "conv3d_res_block",
        **kwargs,
    ):
        """
        Init.

        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(layer_name="conv3d", name=name, **kwargs)


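A usage sketch of ResidualConv3dBlock with illustrative values. Because the skip connection adds the block input to the output of the last convolution, filters must match the input channel count and the convolutions must preserve the spatial shape (for example padding="same" with stride 1):

import tensorflow as tf

from deepreg.model.layer import ResidualConv3dBlock

x = tf.random.uniform((2, 8, 8, 8, 4))

# two conv3d - norm layers by default (num_layers=2); x is added back before the final activation
res_block = ResidualConv3dBlock(filters=4, kernel_size=3, padding="same")
y = res_block(x, training=True)  # (2, 8, 8, 8, 4)
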
class IntDVF(tfkl.Layer):
    """
    Integrate DVF to get DDF.

    Reference:

    - integrate_vec of neuron
      https://github.com/adalca/neurite/blob/legacy/neuron/utils.py
    """

    def __init__(
        self,
        fixed_image_size: tuple,
        num_steps: int = 7,
        name: str = "int_dvf",
        **kwargs,
    ):
        """
        Init.

        :param fixed_image_size: tuple, (f_dim1, f_dim2, f_dim3)
        :param num_steps: int, number of steps for integration
        :param name: name of the layer
        :param kwargs: additional arguments.
        """
        super().__init__(name=name, **kwargs)
        assert len(fixed_image_size) == 3
        self._fixed_image_size = fixed_image_size
        self._num_steps = num_steps
        self._warping = Warping(fixed_image_size=fixed_image_size)

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """
        :param inputs: dvf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        :param kwargs: additional arguments.
        :return: ddf, shape = (batch, f_dim1, f_dim2, f_dim3, 3)
        """
        ddf = inputs / (2 ** self._num_steps)
        for _ in range(self._num_steps):
            ddf += self._warping(inputs=[ddf, ddf])
        return ddf

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["fixed_image_size"] = self._fixed_image_size
        config["num_steps"] = self._num_steps
        return config


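IntDVF implements scaling and squaring: the velocity field is scaled by 1 / 2**num_steps and the resulting displacement is then composed with itself num_steps times via the Warping layer. A usage sketch with illustrative shapes (a zero velocity field integrates to a zero displacement field):

import tensorflow as tf

from deepreg.model.layer import IntDVF

int_dvf = IntDVF(fixed_image_size=(4, 4, 4), num_steps=7)

dvf = tf.zeros((2, 4, 4, 4, 3))  # stationary dense velocity field (DVF)
ddf = int_dvf(dvf)               # dense displacement field (DDF), (2, 4, 4, 4, 3)
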
class ResizeCPTransform(tfkl.Layer):
    """
    Layer for getting the control points from the output of an image-to-image network.
    It uses an anti-aliasing Gaussian filter before downsampling.
    """

    def __init__(self, control_point_spacing: (list, tuple, int), **kwargs):
        """
        :param control_point_spacing: list or int
        :param kwargs: additional arguments.
        """
        super().__init__(**kwargs)

        if isinstance(control_point_spacing, int):
            control_point_spacing = [control_point_spacing] * 3

        self.kernel_sigma = [
            0.44 * cp for cp in control_point_spacing
        ]  # 0.44 = ln(4)/pi
        self.cp_spacing = control_point_spacing
        self.kernel = None
        self._output_shape = None
        self._resize = None

    def build(self, input_shape):
        super().build(input_shape=input_shape)

        self.kernel = layer_util.gaussian_filter_3d(self.kernel_sigma)
        output_shape = [
            tf.cast(tf.math.ceil(v / c) + 3, tf.int32)
            for v, c in zip(input_shape[1:-1], self.cp_spacing)
        ]
        self._output_shape = output_shape
        self._resize = Resize3d(output_shape)

    def call(self, inputs, **kwargs) -> tf.Tensor:
        output = tf.nn.conv3d(
            inputs, self.kernel, strides=(1, 1, 1, 1, 1), padding="SAME"
        )
        output = self._resize(inputs=output)
        return output


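A usage sketch of ResizeCPTransform, assuming a 3-channel dense field such as a DDF predicted by a network; the field is smoothed with the Gaussian kernel and resized to roughly ceil(dim / spacing) + 3 control points per dimension:

import tensorflow as tf

from deepreg.model.layer import ResizeCPTransform

to_cp = ResizeCPTransform(control_point_spacing=4)

field = tf.random.uniform((2, 16, 16, 16, 3))  # dense field from a network
cp_field = to_cp(field)                        # about (2, 7, 7, 7, 3): ceil(16 / 4) + 3 = 7
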
class BSplines3DTransform(tfkl.Layer):
    """
    Layer for BSplines interpolation with precomputed cubic spline kernels.
    It assumes a full sized image from which:
    1. it computes the control point values by downsampling the initial image
    2. it performs the interpolation
    3. it crops the image around the valid values.

    :param cp_spacing: int or tuple of three ints specifying the spacing (in pixels)
        in each dimension. When a single int is used,
        the same spacing is used for all dimensions.
    :param output_shape: (batch_size, dim0, dim1, dim2, 3) of the high resolution
        deformation fields.
    :param kwargs: additional arguments.
    """

    def __init__(self, cp_spacing: (int, tuple), output_shape: tuple, **kwargs):

        super().__init__(**kwargs)

        self.filters = []
        self._output_shape = output_shape

        if isinstance(cp_spacing, int):
            self.cp_spacing = (cp_spacing, cp_spacing, cp_spacing)
        else:
            self.cp_spacing = cp_spacing

    def build(self, input_shape: tuple):
Issue: Redundant returns documentation.
        """
        :param input_shape: tuple with the input shape
        :return: None
        """

        super().build(input_shape=input_shape)

        b = {
            0: lambda u: np.float64((1 - u) ** 3 / 6),
            1: lambda u: np.float64((3 * (u ** 3) - 6 * (u ** 2) + 4) / 6),
            2: lambda u: np.float64((-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6),
            3: lambda u: np.float64(u ** 3 / 6),
        }

        filters = np.zeros(
            (
                4 * self.cp_spacing[0],
                4 * self.cp_spacing[1],
                4 * self.cp_spacing[2],
                3,
                3,
            ),
            dtype=np.float32,
        )

        u_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[0]), 1, 1 / self.cp_spacing[0]
        )
        v_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[1]), 1, 1 / self.cp_spacing[1]
        )
        w_arange = 1 - np.arange(
            1 / (2 * self.cp_spacing[2]), 1, 1 / self.cp_spacing[2]
        )

        filter_idx = [[0, 1, 2, 3] for _ in range(3)]
        filter_coord = list(itertools.product(*filter_idx))

        for f_idx in filter_coord:
            for it_dim in range(3):
                filters[
                    f_idx[0] * self.cp_spacing[0] : (f_idx[0] + 1) * self.cp_spacing[0],
                    f_idx[1] * self.cp_spacing[1] : (f_idx[1] + 1) * self.cp_spacing[1],
                    f_idx[2] * self.cp_spacing[2] : (f_idx[2] + 1) * self.cp_spacing[2],
                    it_dim,
                    it_dim,
                ] = (
                    b[f_idx[0]](u_arange)[:, None, None]
                    * b[f_idx[1]](v_arange)[None, :, None]
                    * b[f_idx[2]](w_arange)[None, None, :]
                )

        self.filter = tf.convert_to_tensor(filters)

    def interpolate(self, field) -> tf.Tensor:
Issue: "field" missing in parameter type documentation.
        """
        :param field: tf.Tensor with shape=number_of_control_points_per_dim
        :return: interpolated_field: tf.Tensor
        """

        image_shape = tuple(
            [(a - 1) * b + 4 * b for a, b in zip(field.shape[1:-1], self.cp_spacing)]
        )

        output_shape = (field.shape[0],) + image_shape + (3,)
        return tf.nn.conv3d_transpose(
            field,
            self.filter,
            output_shape=output_shape,
            strides=self.cp_spacing,
            padding="VALID",
        )

    def call(self, inputs, **kwargs) -> tf.Tensor:
Issue: "inputs" missing in parameter type documentation.
        """
        :param inputs: tf.Tensor defining a low resolution free-form deformation field
        :param kwargs: additional arguments.
        :return: interpolated_field: tf.Tensor of shape=self.input_shape
        """
        high_res_field = self.interpolate(inputs)

        index = [int(3 * c) for c in self.cp_spacing]
        return high_res_field[
            :,
            index[0] : index[0] + self._output_shape[0],
            index[1] : index[1] + self._output_shape[1],
            index[2] : index[2] + self._output_shape[2],
        ]
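A usage sketch of BSplines3DTransform with illustrative shapes. Here output_shape is taken to be the spatial shape of the dense field, which is how self._output_shape is indexed in call above, and the control-point grid size matches what ResizeCPTransform would produce for the same spacing:

import tensorflow as tf

from deepreg.model.layer import BSplines3DTransform

bspline = BSplines3DTransform(cp_spacing=(4, 4, 4), output_shape=(16, 16, 16))

cp_field = tf.random.uniform((2, 7, 7, 7, 3))  # low-resolution control-point field
dense_field = bspline(cp_field)                # (2, 16, 16, 16, 3) after cropping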