Passed
Pull Request — main (#656)
by Yunguan
31:19
created

LocalNet.build_bottom_block() — rated A

Complexity
  Conditions: 1

Size
  Total Lines: 16
  Code Lines: 4

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0
Metric  Value
eloc    4
dl      0
loc     16
rs      10
c       0
b       0
f       0
cc      1
nop     4
# coding=utf-8
[issue] Missing module docstring

from typing import List, Tuple, Union

import tensorflow as tf
import tensorflow.keras.layers as tfkl

from deepreg.model import layer, layer_util
from deepreg.model.backbone.interface import Backbone
from deepreg.registry import REGISTRY


class AdditiveUpsampling(tfkl.Layer):
[issue] Missing class docstring
    def __init__(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
        name: str = "AdditiveUpsampling",
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :param name: name of the layer.
        """
        super().__init__(name=name)
        self.deconv3d = layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        deconved = self.deconv3d(inputs)
        resized = self.resize(inputs)
        resized = tf.add_n(tf.split(resized, num_or_size_splits=2, axis=4))
[issue: unused code] Argument 'axis' passed by position and keyword in function call
        return deconved + resized


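A minimal sketch of the channel arithmetic behind the split-and-add step in AdditiveUpsampling.call, assuming an input whose channel dimension is even; the tensor and shapes below are illustrative, not part of the file:

# Splitting a (batch, d1, d2, d3, 2c) tensor into two halves along the channel
# axis and summing them yields (batch, d1, d2, d3, c), matching the channel
# count produced by the deconvolution branch, so the two can be added.
x = tf.ones((1, 4, 4, 4, 8))                                  # 8 channels
halved = tf.add_n(tf.split(x, num_or_size_splits=2, axis=4))  # shape (1, 4, 4, 4, 4)
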
class Extraction(tfkl.Layer):
[issue] Missing class docstring
    def __init__(
        self,
        image_size: Tuple[int],
        extract_levels: List[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "Extraction",
    ):
        """
        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: list of levels from which the outputs are extracted.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.extract_levels = extract_levels
        self.max_level = max(extract_levels)
        self.layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    layer.Resize3d(shape=image_size),
                ]
            )
            for _ in extract_levels
        ]

    def call(self, inputs: List[tf.Tensor], **kwargs) -> tf.Tensor:
        """
        Convolve and resize the tensor at each extraction level, then average.

        :param inputs: a list of tensors, where inputs[i] corresponds to level max_level - i
        :param kwargs: additional arguments.
        :return: tensor of shape (batch, dim1, dim2, dim3, out_channels)
        """

        return tf.add_n(
            [
                self.layers[idx](inputs=inputs[self.max_level - level])
                for idx, level in enumerate(self.extract_levels)
            ]
        ) / len(self.extract_levels)


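A brief usage sketch of Extraction to make the indexing above concrete; the arguments and shapes are hypothetical, not part of the file. With extract_levels=[0, 1, 2], max_level is 2, so inputs[max_level - level] picks inputs[2] for level 0, inputs[1] for level 1 and inputs[0] for level 2; each tensor is projected to out_channels channels, resized to image_size, and the results are averaged.

extraction = Extraction(
    image_size=(8, 8, 8),
    extract_levels=[0, 1, 2],
    out_channels=3,
    out_kernel_initializer="glorot_uniform",
    out_activation=None,
)
feats = [
    tf.ones((1, 2, 2, 2, 16)),  # level 2 (most down-sampled)
    tf.ones((1, 4, 4, 4, 8)),   # level 1
    tf.ones((1, 8, 8, 8, 4)),   # level 0
]
out = extraction(feats)  # shape (1, 8, 8, 8, 3)
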
@REGISTRY.register_backbone(name="local")
class LocalNet(Backbone):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical Image Analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration."
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        use_additive_upsampling: bool = True,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        The image is encoded gradually from level 0 to E,
        then decoded gradually from level E back down to D.
        Some of the decoded levels are used for generating extractions.

        The values in extract_levels therefore lie within [0, E],
        with E = max(extract_levels) and D = min(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of initial channels.
        :param extract_levels: list of levels from which the outputs are extracted.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param use_additive_upsampling: whether to use additive up-sampling.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        self._extract_levels = extract_levels
        self._use_additive_upsampling = use_additive_upsampling
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # init layer variables
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        kernel_sizes = [
            7 if level == 0 else 3 for level in range(self._extract_max_level + 1)
        ]
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]
        for i in range(self._extract_max_level):
            downsample_conv = self.build_conv_block(
                filters=num_channels[i], kernel_size=kernel_sizes[i], padding="same"
            )
            downsample_pool = self.build_down_sampling_block(
                kernel_size=2, strides=2, padding="same"
            )
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        self._bottom_block = self.build_bottom_block(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )  # level E

        self._upsample_deconvs = []
        self._upsample_convs = []
        for level in range(
            self._extract_max_level - 1, self._extract_min_level - 1, -1
        ):  # level E-1 to D
            padding = layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[level + 1],
                output_shape=self._tensor_shapes[level],
                kernel_size=kernel_sizes[level],
                stride=2,
                padding="same",
            )
            upsample_deconv = self.build_up_sampling_block(
                filters=num_channels[level],
                output_padding=padding,
                kernel_size=3,
                strides=2,
                padding="same",
                output_shape=self._tensor_shapes[level],
            )
            upsample_conv = self.build_conv_block(
                filters=num_channels[level], kernel_size=3, padding="same"
            )
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)
        self._output = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

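To make the E/D bookkeeping in the constructor concrete, here is its level arithmetic traced with hypothetical settings; the values are illustrative only, not part of the file:

# Hypothetical settings: image_size=(16, 16, 16), num_channel_initial=4,
# extract_levels=[0, 1, 2, 3], hence E = 3 and D = 0.
E = 3
num_channels = [4 * (2 ** level) for level in range(E + 1)]        # [4, 8, 16, 32]
kernel_sizes = [7 if level == 0 else 3 for level in range(E + 1)]  # [7, 3, 3, 3]
shapes = [(16, 16, 16)]
for _ in range(E):
    shapes.append(tuple((x + 1) // 2 for x in shapes[-1]))
# shapes == [(16, 16, 16), (8, 8, 8), (4, 4, 4), (2, 2, 2)]
# i.e. three conv+pool stages down, one bottom block at level E = 3, and
# three deconv+conv stages decoding from level 2 back down to level D = 0.
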
    def build_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling or up-sampling.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_down_sampling_block(
        self, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param kernel_size: arg for pool3d
        :param padding: arg for pool3d
        :param strides: arg for pool3d
        :return: a block consisting of one or multiple layers
        """
        return tfkl.MaxPool3D(pool_size=kernel_size, strides=strides, padding=padding)

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for the bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return layer.Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """

        if self._use_additive_upsampling:
            return AdditiveUpsampling(
                filters=filters,
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=output_shape,
            )

        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_skip_block(self) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for combining the skipped tensor and the up-sampled one.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        The input to this block is a list of tensors.

        :return: a block consisting of one or multiple layers
        """
        return tfkl.Add()

    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: List[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for the output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: list of levels from which the outputs are extracted.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
[issue] "mask, training" missing in parameter type documentation
        """
        Build LocalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # down-sample from level 0 to E
        # outputs used for decoding: skips[i] corresponds to level i
        # only levels 0 to E-1 are stored
        skips = []
        down_sampled = inputs
        for level in range(self._extract_max_level):  # level 0 to E - 1
            skip = self._downsample_convs[level](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[level](inputs=skip, training=training)
            skips.append(skip)
        up_sampled = self._bottom_block(
            inputs=down_sampled, training=training
        )  # level E of encoding/decoding

        # up-sample from level E to D
        outs = [up_sampled]  # level E
        for idx, level in enumerate(
            range(self._extract_max_level - 1, self._extract_min_level - 1, -1)
        ):  # level E-1 to D
            up_sampled = self._upsample_deconvs[idx](
                inputs=up_sampled, training=training
            )
            up_sampled = self.build_skip_block()([up_sampled, skips[level]])
            up_sampled = self._upsample_convs[idx](inputs=up_sampled, training=training)
            outs.append(up_sampled)

        # output
        output = self._output(outs)

        return output
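A minimal forward-pass sketch for the class above, assuming the constructor arguments and input shapes below; they are illustrative only, and in DeepReg the backbone is normally selected via the registered name "local" and a training config rather than constructed by hand:

# Hypothetical configuration; real values come from the experiment config.
net = LocalNet(
    image_size=(16, 16, 16),
    out_channels=3,
    num_channel_initial=4,
    extract_levels=[0, 1, 2, 3],
    out_kernel_initializer="glorot_uniform",
    out_activation=None,
)
images = tf.random.uniform((2, 16, 16, 16, 2))  # batch of 2, 2 input channels
ddf = net(images)  # expected shape (2, 16, 16, 16, 3)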