Passed
Pull Request — main (#656)
by Yunguan
11:15 queued 50s
created

deepreg.model.backbone.local_net.LocalNet.call()   A

Complexity

Conditions 3

Size

Total Lines 50
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 50
rs 9.304
c 0
b 0
f 0
cc 3
nop 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Tuple, Union
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
from deepreg.model import layer
9
from deepreg.model.backbone.u_net import AbstractUNet
10
from deepreg.registry import REGISTRY
11
12
13
class AdditiveUpsampling(tfkl.Layer):
    """
    Up-sampling layer that sums a deconvolution branch and a resize branch.

    The same input is fed to a ``layer.Deconv3dBlock`` and to a
    ``layer.Resize3d``. The resized tensor is split into two parts along
    axis 4, the two parts are summed, and that sum is added to the
    deconvolution output.
    """

    def __init__(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
        name: str = "AdditiveUpsampling",
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor,
            forwarded to ``layer.Resize3d`` as ``shape``
        :param name: name of the layer.
        """
        super().__init__(name=name)
        deconv_kwargs = dict(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.deconv3d = layer.Deconv3dBlock(**deconv_kwargs)
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        """
        Up-sample the input through both branches and add the results.

        :param inputs: tensor to up-sample; the split on axis 4 assumes
            a 5-D tensor whose last axis can be halved -- TODO confirm
            against callers
        :param kwargs: additional arguments, unused.
        :return: sum of the deconvolution output and the folded resize output
        """
        deconv_branch = self.deconv3d(inputs)
        resize_branch = self.resize(inputs)
        # fold axis 4 in half and sum the two halves element-wise
        halves = tf.split(resize_branch, num_or_size_splits=2, axis=4)
        return deconv_branch + tf.add_n(halves)
50
51
52
class Extraction(tfkl.Layer):
    """
    Output layer averaging extractions taken from several levels.

    One branch (a ``Conv3D`` followed by ``layer.Resize3d``) is created per
    entry of ``extract_levels``. Each branch is applied to the input tensor
    selected for its level, and the branch outputs are averaged.
    """

    def __init__(
        self,
        image_size: Tuple[int],
        extract_levels: List[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "Extraction",
    ):
        """
        Init.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which extractions are taken.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.extract_levels = extract_levels
        self.max_level = max(extract_levels)

        # one independent conv + resize branch per extraction level
        branches = []
        for _ in extract_levels:
            conv = tfkl.Conv3D(
                filters=out_channels,
                kernel_size=3,
                strides=1,
                padding="same",
                kernel_initializer=out_kernel_initializer,
                activation=out_activation,
            )
            resize = layer.Resize3d(shape=image_size)
            branches.append(tf.keras.Sequential([conv, resize]))
        self.layers = branches

    def call(self, inputs: List[tf.Tensor], **kwargs) -> tf.Tensor:
        """
        Average the per-level extractions.

        :param inputs: a list of tensors; the tensor for a given level
            is read at index ``max_level - level``
        :param kwargs: additional arguments, unused.
        :return: sum of the branch outputs divided by the number of levels
        """
        extractions = []
        for idx, level in enumerate(self.extract_levels):
            branch = self.layers[idx]
            selected = inputs[self.max_level - level]
            extractions.append(branch(inputs=selected))
        return tf.add_n(extractions) / len(self.extract_levels)
104
105
106
@REGISTRY.register_backbone(name="local")
class LocalNet(AbstractUNet):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        use_additive_upsampling: bool = True,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        Image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].
        The depth D passed to the parent class is max(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :param use_additive_upsampling: whether use additive up-sampling.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=max(extract_levels),
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            name=name,
            **kwargs,
        )

        # must be stored before build_layers runs,
        # because build_up_sampling_block reads it
        self._use_additive_upsampling = use_additive_upsampling

        # self._depth and self._extract_levels are not assigned in this
        # class; presumably set by the parent constructor -- TODO confirm
        # first down-sampling kernel is 7, the remaining ones are 3
        down_kernels = [7] + [3] * self._depth
        self.build_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=self._depth,
            extract_levels=self._extract_levels,
            downsample_kernel_sizes=down_kernels,
            upsample_kernel_sizes=3,
            strides=2,
            padding="same",
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    def build_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling or up-sampling.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a sequential model made of a Conv3dBlock
            followed by a ResidualConv3dBlock
        """
        shared_kwargs = dict(
            filters=filters,
            kernel_size=kernel_size,
            padding=padding,
        )
        conv = layer.Conv3dBlock(**shared_kwargs)
        residual = layer.ResidualConv3dBlock(**shared_kwargs)
        return tf.keras.Sequential([conv, residual])

    def build_down_sampling_block(
        self, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param kernel_size: arg for pool3d, used as the pool size
        :param padding: arg for pool3d
        :param strides: arg for pool3d
        :return: a max-pooling layer
        """
        pooling = tfkl.MaxPool3D(
            pool_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        return pooling

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a single Conv3dBlock
        """
        bottom = layer.Conv3dBlock(
            filters=filters,
            kernel_size=kernel_size,
            padding=padding,
        )
        return bottom

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth).
        An AdditiveUpsampling layer is returned when additive up-sampling
        is enabled, otherwise a plain Deconv3dBlock is returned.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor,
            only used by the additive variant
        :return: a block consists of one or multiple layers
        """
        deconv_kwargs = dict(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

        if not self._use_additive_upsampling:
            return layer.Deconv3dBlock(**deconv_kwargs)

        return AdditiveUpsampling(output_shape=output_shape, **deconv_kwargs)

    def build_skip_block(self) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for combining skipped tensor and up-sampled one.

        The input to this block is a list of tensors,
        which are added element-wise.

        :return: an Add layer
        """
        skip = tfkl.Add()
        return skip

    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: List[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which extractions are taken.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: an Extraction layer
        """
        extraction_kwargs = dict(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )
        return Extraction(**extraction_kwargs)
331