Passed
Pull Request — main (#656)
by Yunguan
02:51
created

deepreg.model.backbone.local_net.LocalNet.call()   A

Complexity

Conditions 3

Size

Total Lines 50
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 50
rs 9.304
c 0
b 0
f 0
cc 3
nop 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import Optional, Tuple, Union
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
from deepreg.model import layer
9
from deepreg.model.backbone.u_net import UNet
10
from deepreg.model.layer import Extraction
11
from deepreg.registry import REGISTRY
12
13
14
class AdditiveUpsampling(tfkl.Layer):
    """
    Additive up-sampling layer.

    The output is the sum of two up-sampled versions of the same input:

    - a learned up-sampling through a transposed-convolution block
      (``layer.Deconv3dBlock``), and
    - a non-parameterised up-sampling through ``layer.Resize3d``, whose
      channels are split into two halves that are then summed together.

    Because the resized branch is split in two along the channel axis, the
    input must have an even number of channels, and the halved channel count
    must match ``filters`` so the two branches can be added.
    """

    def __init__(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
        name: str = "AdditiveUpsampling",
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :param name: name of the layer.
        """
        super().__init__(name=name)
        self.deconv3d = layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        """
        Up-sample ``inputs`` by adding a deconvolution branch and a resize branch.

        :param inputs: input tensor; assumed channels-last, e.g.
            (batch, dim1, dim2, dim3, channels), with an even channel count
            (TODO confirm against callers).
        :param kwargs: additional arguments, unused.
        :return: sum of the deconvolved tensor and the channel-halved resized
            tensor.
        """
        deconved = self.deconv3d(inputs)
        resized = self.resize(inputs)
        # Sum the two halves of the channel (last) axis so the resized branch
        # has half the input channels, matching the deconvolution output.
        resized = tf.add_n(tf.split(resized, num_or_size_splits=2, axis=-1))
        return deconved + resized
51
52
53
@REGISTRY.register_backbone(name="local")
class LocalNet(UNet):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int, ...],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        use_additive_upsampling: bool = True,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        Image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :param depth: depth of the encoder. If None, it defaults to
            ``max(extract_levels)``.
        :param use_additive_upsampling: whether to use additive up-sampling.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises ValueError: if ``depth`` is None and ``extract_levels`` is empty.
        """
        # Set before super().__init__(): presumably the parent constructor
        # builds the decoder by calling build_up_sampling_block, which reads
        # this flag. TODO(review): confirm against UNet.__init__.
        self._use_additive_upsampling = use_additive_upsampling
        if depth is None:
            # max() of an empty sequence gives an opaque error; fail clearly.
            if len(extract_levels) == 0:
                raise ValueError(
                    "extract_levels must not be empty when depth is not given, "
                    f"got extract_levels={extract_levels}."
                )
            depth = max(extract_levels)
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            pooling=pooling,
            concat_skip=concat_skip,
            # kernel size 7 for the first encoding level, 3 for the
            # remaining `depth` levels (depth + 1 entries in total)
            encode_kernel_sizes=[7] + [3] * depth,
            name=name,
            **kwargs,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return layer.Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth);
        the number of output channels is given by ``filters``.
        With additive up-sampling, the input channel count must be
        twice ``filters`` (see AdditiveUpsampling).

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """

        if self._use_additive_upsampling:
            return AdditiveUpsampling(
                filters=filters,
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=output_shape,
            )

        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: decoder levels from which the extractions
            are built.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )
216