Passed
Pull Request — main (#656)
by Yunguan
18:04 queued 32s
created

deepreg.model.backbone.local_net.LocalNet.call()   A

Complexity

Conditions 3

Size

Total Lines 33
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 16
dl 0
loc 33
rs 9.6
c 0
b 0
f 0
cc 3
nop 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Tuple, Union
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
from tensorflow.python.keras.utils import conv_utils
8
9
from deepreg.model import layer, layer_util
10
from deepreg.model.backbone.interface import Backbone
11
from deepreg.registry import REGISTRY
12
13
14
class AdditiveUpsampling(tfkl.Layer):
    """
    Additive up-sampling layer.

    The output is the sum of two branches computed from the same input:

    - a transposed-convolution branch (``layer.Deconv3dBlock``), and
    - a resize branch: the input is resized to ``output_shape``, then its
      channels are split into two equal halves which are summed.

    The two branches are added element-wise. This assumes that the input has
    twice as many channels as ``filters`` and that the deconv branch produces
    a spatial shape of ``output_shape`` (TODO confirm for custom settings).
    """

    def __init__(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
        name: str = "AdditiveUpsampling",
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :param name: name of the layer.
        """
        super().__init__(name=name)
        self.deconv3d = layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        """
        Up-sample ``inputs`` by summing a deconv branch and a resize branch.

        :param inputs: tensor of shape (batch, dim1, dim2, dim3, ch); ch is
            expected to be 2 * filters so the resize branch can be halved.
        :param kwargs: additional arguments, not used directly.
        :return: up-sampled tensor with ``filters`` channels.
        """
        deconved = self.deconv3d(inputs)
        resized = self.resize(inputs)
        # halve the channel count of the resize branch by summing the two
        # channel halves, so it matches the deconv branch before addition
        resized = tf.add_n(tf.split(resized, num_or_size_splits=2, axis=4))
        return deconved + resized
51
52
53
class Extraction(tfkl.Layer):
    """
    Combine decoded feature maps from several levels into a single output.

    Every level listed in ``extract_levels`` owns a dedicated block: a 3x3x3
    ``Conv3D`` with ``out_channels`` filters followed by a resize to
    ``image_size``. The outputs of all blocks are averaged.
    """

    def __init__(
        self,
        image_size: Tuple[int, ...],
        extract_levels: List[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "Extraction",
    ):
        """
        Init.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which the extractions are built;
            the maximum value is treated as the deepest level.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.extract_levels = extract_levels
        self.max_level = max(extract_levels)
        # one conv + resize block per extraction level, same order as
        # extract_levels
        self.layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    layer.Resize3d(shape=image_size),
                ]
            )
            for _ in extract_levels
        ]

    def call(self, inputs: List[tf.Tensor], **kwargs) -> tf.Tensor:
        """
        Average the per-level extractions.

        :param inputs: feature maps ordered from the deepest level downwards,
            i.e. ``inputs[i]`` is the tensor of level ``max_level - i``.
        :param kwargs: additional arguments, not used.
        :return: mean of the extractions, each resized to ``image_size`` and
            having ``out_channels`` channels.
        """
        extractions = [
            block(inputs=inputs[self.max_level - level])
            for block, level in zip(self.layers, self.extract_levels)
        ]
        return tf.add_n(extractions) / len(self.extract_levels)
105
106
107
@REGISTRY.register_backbone(name="local")
class LocalNet(Backbone):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        use_additive_upsampling: bool = True,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        Image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D], where D = max(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: levels from which the extractions are
            generated; the maximum value defines the depth D of the network.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :param use_additive_upsampling: whether use additive up-sampling.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        # (must happen before build_layers, which reads them)
        self._extract_levels = extract_levels
        self._use_additive_upsampling = use_additive_upsampling
        self._depth = max(self._extract_levels)  # D

        # init layers
        # all lists start with d = 0
        self._downsample_convs = None
        self._downsample_pools = None
        self._bottom_block = None
        self._upsample_deconvs = None
        self._upsample_skips = None
        self._upsample_convs = None
        self._output_block = None

        # build layers
        self.build_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=self._depth,
            extract_levels=self._extract_levels,
            downsample_kernel_sizes=[7] + [3] * self._depth,
            upsample_kernel_sizes=3,
            strides=2,
            padding="same",
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    def build_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: List[int],
        downsample_kernel_sizes: Union[int, List[int]],
        upsample_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers that will be used in call.

        After this method, the up-sampling lists (deconvs, skips, convs) all
        have length ``depth``; entries for levels below min(extract_levels)
        are None because those levels are never decoded.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built,
            all values must be non-negative.
        :param downsample_kernel_sizes: kernel size for the down-sampling conv
            blocks and the bottom block, an int or a list of depth + 1 ints.
        :param upsample_kernel_sizes: kernel size for up-sampling,
            an int or a list of depth ints.
        :param strides: strides (and pool size) for down-sampling,
            also used as the up-sampling stride.
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        # init params
        min_extract_level = min(extract_levels)
        assert min_extract_level >= 0, (
            f"extract_levels must be non-negative, got {extract_levels}"
        )
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
        if isinstance(downsample_kernel_sizes, int):
            downsample_kernel_sizes = [downsample_kernel_sizes] * (depth + 1)
        assert len(downsample_kernel_sizes) == depth + 1, (
            f"downsample_kernel_sizes must have {depth + 1} elements, "
            f"got {len(downsample_kernel_sizes)}"
        )
        if isinstance(upsample_kernel_sizes, int):
            upsample_kernel_sizes = [upsample_kernel_sizes] * depth
        assert len(upsample_kernel_sizes) == depth, (
            f"upsample_kernel_sizes must have {depth} elements, "
            f"got {len(upsample_kernel_sizes)}"
        )

        # down-sampling
        # tensor_shapes[d] is the spatial shape of the tensor at level d
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        tensor_shapes = [tensor_shape]
        for d in range(depth):
            downsample_conv = self.build_conv_block(
                filters=num_channels[d],
                kernel_size=downsample_kernel_sizes[d],
                padding=padding,
            )
            downsample_pool = self.build_down_sampling_block(
                kernel_size=strides, strides=strides, padding=padding
            )
            tensor_shape = tuple(
                conv_utils.conv_output_length(
                    input_length=x,
                    filter_size=strides,
                    padding=padding,
                    stride=strides,
                    dilation=1,
                )
                for x in tensor_shape
            )
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            tensor_shapes.append(tensor_shape)

        # bottom layer
        self._bottom_block = self.build_bottom_block(
            filters=num_channels[depth],
            kernel_size=downsample_kernel_sizes[depth],
            padding=padding,
        )

        # up-sampling
        # blocks are built from d = depth - 1 upwards and prepended, so that
        # index d of each list refers to level d
        self._upsample_deconvs = []
        self._upsample_skips = []
        self._upsample_convs = []
        for d in range(depth - 1, min_extract_level - 1, -1):
            kernel_size = upsample_kernel_sizes[d]
            output_padding = layer_util.deconv_output_padding(
                input_shape=tensor_shapes[d + 1],
                output_shape=tensor_shapes[d],
                kernel_size=kernel_size,
                stride=strides,
                padding=padding,
            )
            upsample_deconv = self.build_up_sampling_block(
                filters=num_channels[d],
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=tensor_shapes[d],
            )
            # skip blocks are built once here instead of inside call, so no
            # new layer is created on every forward pass
            upsample_skip = self.build_skip_block()
            upsample_conv = self.build_conv_block(
                filters=num_channels[d], kernel_size=kernel_size, padding=padding
            )
            self._upsample_deconvs = [upsample_deconv] + self._upsample_deconvs
            self._upsample_skips = [upsample_skip] + self._upsample_skips
            self._upsample_convs = [upsample_conv] + self._upsample_convs
        if min_extract_level > 0:
            # add Nones to make lists have length depth
            self._upsample_deconvs = [None] * min_extract_level + self._upsample_deconvs
            self._upsample_skips = [None] * min_extract_level + self._upsample_skips
            self._upsample_convs = [None] * min_extract_level + self._upsample_convs

        # extraction
        self._output_block = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def build_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling or up-sampling.

        This block do not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consists of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_down_sampling_block(
        self, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param kernel_size: arg for pool3d
        :param padding: arg for pool3d
        :param strides: arg for pool3d
        :return: a block consists of one or multiple layers
        """
        return tfkl.MaxPool3D(pool_size=kernel_size, strides=strides, padding=padding)

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block do not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consists of one or multiple layers
        """
        return layer.Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth)
        and sets the number of channels to ``filters``.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consists of one or multiple layers
        """
        if self._use_additive_upsampling:
            return AdditiveUpsampling(
                filters=filters,
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=output_shape,
            )

        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_skip_block(self) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for combining skipped tensor and up-sampled one.

        This block do not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        The input to this block is a list of tensors.

        :return: a block consists of one or multiple layers
        """
        return tfkl.Add()

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: List[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which the extractions are built.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consists of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build LocalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool, forwarded to the sub-blocks.
        :param mask: None or tf.Tensor, not used.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """
        # down-sampling, skips[d] is the level-d feature map before pooling
        skips = []
        down_sampled = inputs
        for d in range(self._depth):
            skip = self._downsample_convs[d](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[d](inputs=skip, training=training)
            skips.append(skip)

        # bottom
        up_sampled = self._bottom_block(inputs=down_sampled, training=training)

        # up-sampling, outs[i] is the decoded tensor of level depth - i
        outs = [up_sampled]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            up_sampled = self._upsample_deconvs[d](inputs=up_sampled, training=training)
            up_sampled = self._upsample_skips[d]([up_sampled, skips[d]])
            up_sampled = self._upsample_convs[d](inputs=up_sampled, training=training)
            outs.append(up_sampled)

        # output
        return self._output_block(outs)
488