Passed
Pull Request — main (#656)
by Yunguan
25:40
created

deepreg.model.backbone.local_net.LocalNet.call()   A

Complexity

Conditions 3

Size

Total Lines 50
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 50
rs 9.304
c 0
b 0
f 0
cc 3
nop 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import Optional, Tuple, Union
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
from deepreg.model import layer
9
from deepreg.model.backbone.u_net import UNet
10
from deepreg.model.layer import Extraction
11
from deepreg.registry import REGISTRY
12
13
14
class AdditiveUpsampling(tfkl.Layer):
    """
    Up-sampling layer that sums a transposed-convolution path and a resize path.

    The same input goes through two branches whose outputs are added:

    - ``self.deconv3d``: a ``layer.Deconv3dBlock`` whose output has ``filters``
      channels;
    - ``self.resize``: a ``layer.Resize3d`` to ``output_shape``, whose result is
      split into two equal chunks along axis 4 and the chunks are summed.

    NOTE(review): the final addition only works if both branches agree in
    shape, which presumably means the input carries ``2 * filters`` channels —
    confirm against the callers.
    """

    def __init__(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
        name: str = "AdditiveUpsampling",
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :param name: name of the layer.
        """
        super().__init__(name=name)
        # Arguments forwarded unchanged to the transposed-convolution block.
        deconv_args = {
            "filters": filters,
            "kernel_size": kernel_size,
            "strides": strides,
            "padding": padding,
            "output_padding": output_padding,
        }
        # Attribute names and assignment order are kept as-is: Keras tracks
        # sub-layers (and hence weights/checkpoint keys) through them.
        self.deconv3d = layer.Deconv3dBlock(**deconv_args)
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        """
        Up-sample ``inputs`` along both branches and return their sum.

        :param inputs: input tensor; splitting on axis 4 requires rank >= 5
            (presumably (batch, dim1, dim2, dim3, channels) — confirm).
        :param kwargs: additional arguments, unused here.
        :return: deconvolution output plus the channel-folded resize output.
        """
        from_deconv = self.deconv3d(inputs)
        # Fold the resized tensor's channels in half: two equal chunks on
        # axis 4, added element-wise.
        chunks = tf.split(self.resize(inputs), num_or_size_splits=2, axis=4)
        from_resize = tf.add_n(chunks)
        return from_deconv + from_resize
51
52
53
@REGISTRY.register_backbone(name="local")
class LocalNet(UNet):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        use_additive_upsampling: bool = True,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        Image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :param depth: depth D of the encoder/decoder; if None, it defaults
                      to max(extract_levels).
        :param use_additive_upsampling: whether to use additive up-sampling.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises ValueError: if depth is None and extract_levels is empty.
        """
        # Must be set before the parent constructor runs: build_up_sampling_block
        # reads this flag, and the parent presumably builds its blocks during
        # construction (TODO confirm in UNet.__init__).
        self._use_additive_upsampling = use_additive_upsampling
        if depth is None:
            # len() rather than truthiness so array-like inputs keep working;
            # without this check max() fails with an opaque message.
            if len(extract_levels) == 0:
                raise ValueError(
                    "extract_levels must not be empty when depth is not given, "
                    f"got extract_levels={extract_levels}."
                )
            depth = max(extract_levels)
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            pooling=pooling,
            concat_skip=concat_skip,
            # depth + 1 entries: kernel size 7 for level 0, 3 for levels 1..depth.
            encode_kernel_sizes=[7] + [3] * depth,
            name=name,
            **kwargs,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return layer.Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth);
        its output has `filters` channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """

        if self._use_additive_upsampling:
            return AdditiveUpsampling(
                filters=filters,
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=output_shape,
            )

        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: Tuple[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which the output will be built.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )
215