Passed
Pull Request — main (#656)
by Yunguan
03:16
created

deepreg.model.backbone.local_net   A

Complexity

Total Complexity 10

Size/Duplication

Total Lines 237
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 10
eloc 117
dl 0
loc 237
rs 10
c 0
b 0
f 0

8 Methods

Rating   Name   Duplication   Size   Complexity  
A AdditiveUpsampling.call() 0 5 1
A AdditiveUpsampling.__init__() 0 30 1
A LocalNet.build_output_block() 0 26 1
A LocalNet.__init__() 0 56 2
A LocalNet.build_up_sampling_block() 0 40 2
A LocalNet.build_bottom_block() 0 16 1
A AdditiveUpsampling.get_config() 0 13 1
A LocalNet.get_config() 0 5 1
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Optional, Tuple, Union
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
from deepreg.model import layer
9
from deepreg.model.backbone.u_net import UNet
10
from deepreg.model.layer import Extraction
11
from deepreg.registry import REGISTRY
12
13
14
class AdditiveUpsampling(tfkl.Layer):
    """
    Up-sampling layer that sums a learned path and a parameter-free path.

    The input is up-sampled in two ways:

    - by a transposed convolution (``layer.Deconv3dBlock``) producing
      ``filters`` channels;
    - by resizing to ``output_shape`` (``layer.Resize3d``), after which the
      channel axis is split into two halves that are summed element-wise.

    The two results are added. For the addition to be well-defined the
    halved channel count of the resized path must match ``filters``.
    Assuming ``Resize3d`` keeps the channel count, this means inputs are
    expected to have ``2 * filters`` channels (TODO confirm with callers).
    """

    def __init__(
        self,
        filters: int,
        output_padding: Union[int, Tuple, List],
        kernel_size: Union[int, Tuple, List],
        padding: str,
        strides: Union[int, Tuple, List],
        output_shape: Tuple,
        name: str = "AdditiveUpsampling",
        **kwargs,
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :param name: name of the layer.
        :param kwargs: additional arguments forwarded to ``tfkl.Layer``.
            This is needed so that configs produced by ``get_config``,
            which also contain base-layer keys such as ``trainable`` and
            ``dtype``, can be passed back to the constructor.
        """
        super().__init__(name=name, **kwargs)
        # Keep our own copy so get_config does not read Resize3d's
        # protected state.
        self._output_shape = output_shape
        self.deconv3d = layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        """
        Up-sample ``inputs`` along both paths and return their sum.

        :param inputs: input tensor, assumed channels-last with channels on
            axis 4, i.e. (batch, dim1, dim2, dim3, channels) — TODO confirm
        :param kwargs: unused, accepted for Keras call compatibility
        :return: sum of the deconvolved and the resized (channel-halved)
            tensors
        """
        deconved = self.deconv3d(inputs)
        resized = self.resize(inputs)
        # Halve the channel count of the resized path by summing its two
        # channel halves, so that it can be added to the deconvolved path.
        resized = tf.add_n(tf.split(resized, num_or_size_splits=2, axis=4))
        return deconved + resized

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        deconv_config = self.deconv3d.get_config()
        config.update(
            filters=deconv_config["filters"],
            output_padding=deconv_config["output_padding"],
            kernel_size=deconv_config["kernel_size"],
            strides=deconv_config["strides"],
            padding=deconv_config["padding"],
            output_shape=self._output_shape,
        )
        return config
65
66
67
@REGISTRY.register_backbone(name="local")
class LocalNet(UNet):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int, ...],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        use_additive_upsampling: bool = True,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        The image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :param depth: depth of the encoder. Defaults to ``max(extract_levels)``.
        :param use_additive_upsampling: whether to use the additive up-sampling
            layer for decoding.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate the skipped
                            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments forwarded to ``UNet``.
            Any ``encode_kernel_sizes`` entry is overwritten (see below).
        :raises ValueError: if ``depth`` is None and ``extract_levels`` is empty.
        """
        # Set before super().__init__, which presumably builds the decoder by
        # calling build_up_sampling_block (that method reads this flag).
        self._use_additive_upsampling = use_additive_upsampling
        if depth is None:
            if not extract_levels:
                raise ValueError(
                    "extract_levels must be non-empty when depth is not given, "
                    f"got {extract_levels!r}"
                )
            depth = max(extract_levels)
        # LocalNet fixes the encoder kernel sizes: 7 at level 0 and 3 at each
        # of the `depth` deeper levels. A caller-supplied value is replaced.
        kwargs["encode_kernel_sizes"] = [7] + [3] * depth
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            pooling=pooling,
            concat_skip=concat_skip,
            name=name,
            **kwargs,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for the bottom layer.

        The block is a single ``layer.Conv3dBlock``. It is intended to keep
        the spatial shape (width, height, depth) and change only the number
        of channels, to ``filters``.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return layer.Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the spatial shape (width, height, depth) and
        outputs ``filters`` channels. It is an ``AdditiveUpsampling`` layer
        when additive up-sampling is enabled, otherwise a plain
        ``layer.Deconv3dBlock``.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor; only used by the
            additive variant
        :return: a block consisting of one or multiple layers
        """
        if self._use_additive_upsampling:
            return AdditiveUpsampling(
                filters=filters,
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=output_shape,
            )

        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: decoder levels from which outputs are extracted.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(use_additive_upsampling=self._use_additive_upsampling)
        return config
237