Passed
Pull Request — main (#656)
by Yunguan
03:01
created

deepreg.model.backbone.local_net.LocalNet.call()   A

Complexity

Conditions 3

Size

Total Lines 50
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 50
rs 9.304
c 0
b 0
f 0
cc 3
nop 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Optional, Tuple, Union
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
from deepreg.model import layer
9
from deepreg.model.backbone.u_net import UNet
10
from deepreg.model.layer import Extraction
11
from deepreg.registry import REGISTRY
12
13
14
class AdditiveUpsampling(tfkl.Layer):
    """
    Up-sampling layer combining a deconvolution branch and a resize branch.

    The same input is fed to a ``layer.Deconv3dBlock`` and to a
    ``layer.Resize3d``. The resized tensor is split into two parts along
    axis 4, the two parts are summed, and the sum is added to the
    output of the deconvolution branch.
    """

    def __init__(
        self,
        filters: int,
        output_padding: Union[int, Tuple, List],
        kernel_size: Union[int, Tuple, List],
        padding: str,
        strides: Union[int, Tuple, List],
        output_shape: Tuple,
        name: str = "AdditiveUpsampling",
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :param name: name of the layer.
        """
        super().__init__(name=name)
        # The deconvolution block is created before the resize layer;
        # keep this order so sub-layer registration is unchanged.
        deconv_kwargs = dict(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.deconv3d = layer.Deconv3dBlock(**deconv_kwargs)
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        """
        Up-sample the inputs through both branches and add the results.

        :param inputs: input tensor, passed to both the deconvolution
            block and the resize layer.
        :param kwargs: additional arguments, not used.
        :return: deconvolution output plus the folded resize output.
        """
        upsampled = self.deconv3d(inputs)
        # Split the resized tensor in two along axis 4 (presumably the
        # channel axis of a 5-D tensor -- confirm against Resize3d) and
        # sum the two parts before adding to the deconvolution branch.
        halves = tf.split(self.resize(inputs), num_or_size_splits=2, axis=4)
        return upsampled + tf.add_n(halves)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        deconv_config = self.deconv3d.get_config()
        # Copy the constructor arguments back from the deconvolution block.
        for key in (
            "filters",
            "output_padding",
            "kernel_size",
            "strides",
            "padding",
        ):
            config[key] = deconv_config[key]
        # NOTE(review): this reads the protected attribute ``_shape`` of
        # the Resize3d layer. A member prefixed with ``_`` should normally
        # only be accessed from its own class or a subclass; static
        # analysis flags this access.
        config["output_shape"] = self.resize._shape
        return config
65
66
67
@REGISTRY.register_backbone(name="local")
class LocalNet(UNet):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        use_additive_upsampling: bool = True,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        Image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :param depth: depth of the encoder.
            If depth is not given, depth = max(extract_levels) will be used.
        :param use_additive_upsampling: whether to use the additive
            up-sampling layer for decoding.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments. Any ``encode_kernel_sizes``
            entry is overwritten.
        """
        # Stored before the parent constructor runs: presumably
        # UNet.__init__ calls build_up_sampling_block, which reads this
        # flag -- confirm against UNet before moving this line.
        self._use_additive_upsampling = use_additive_upsampling
        depth = max(extract_levels) if depth is None else depth
        # One kernel size of 7 followed by ``depth`` kernel sizes of 3.
        kwargs["encode_kernel_sizes"] = [7, *([3] * depth)]
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            pooling=pooling,
            concat_skip=concat_skip,
            name=name,
            **kwargs,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        bottom_block = layer.Conv3dBlock(
            filters=filters,
            kernel_size=kernel_size,
            padding=padding,
        )
        return bottom_block

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """
        # Arguments shared by both kinds of up-sampling block.
        deconv_kwargs = dict(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

        if not self._use_additive_upsampling:
            return layer.Deconv3dBlock(**deconv_kwargs)

        # Only the additive variant takes the target output shape.
        return AdditiveUpsampling(output_shape=output_shape, **deconv_kwargs)

    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: Tuple[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: extraction levels, forwarded to Extraction.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        output_block = Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )
        return output_block

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config["use_additive_upsampling"] = self._use_additive_upsampling
        return config
238