Passed
Pull Request — main (#656)
by Yunguan
created 02:56

LocalNet.__init__()   Rating: A

Complexity
  Conditions: 2

Size
  Total Lines: 56
  Code Lines: 30

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0
Metric   Value
eloc     30      (executable lines of code)
dl       0       (duplicated lines)
loc      56      (lines of code)
rs       9.16
c        0
b        0
f        0
cc       2       (cyclomatic complexity)
nop      13      (number of parameters)

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment as a starting point for its name.

Commonly applied refactorings include Extract Method, sketched below.
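
For instance, a minimal sketch of Extract Method (hypothetical code, not taken from this pull request): each commented phase of a long function becomes a small, well-named function.

def stats_long(values):
    # compute the mean
    mean = sum(values) / len(values)
    # compute the variance
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, variance

# After extraction, each comment has turned into a name:
def mean(values):
    return sum(values) / len(values)

def variance(values):
    m = mean(values)
    return sum((v - m) ** 2 for v in values) / len(values)

def stats(values):
    return mean(values), variance(values)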

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more, or different, data.

There are several approaches to avoid long parameter lists; two common ones are Introduce Parameter Object and Preserve Whole Object. A sketch of the first follows.
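
Below is a minimal sketch of Introduce Parameter Object (hypothetical names, not code from this pull request). Applied to LocalNet.__init__() and its 13 parameters, the related out_* arguments could be grouped the same way.

from dataclasses import dataclass

@dataclass
class OutputSpec:
    channels: int
    kernel_initializer: str = "glorot_uniform"
    activation: str = "sigmoid"

def build_backbone(image_size: tuple, output: OutputSpec):
    # one object travels together instead of three loose arguments
    return {"image_size": image_size, "out_channels": output.channels}

build_backbone((16, 16, 16), OutputSpec(channels=3))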

# coding=utf-8
# inspection: missing module docstring

from typing import List, Optional, Tuple, Union

import tensorflow as tf
import tensorflow.keras.layers as tfkl

from deepreg.model import layer
from deepreg.model.backbone.u_net import UNet
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY


class AdditiveUpsampling(tfkl.Layer):
    # inspection: missing class docstring
    def __init__(
        self,
        filters: int,
        output_padding: Union[int, Tuple, List],
        kernel_size: Union[int, Tuple, List],
        padding: str,
        strides: Union[int, Tuple, List],
        output_shape: Tuple,
        name: str = "AdditiveUpsampling",
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :param name: name of the layer.
        """
        super().__init__(name=name)
        self.deconv3d = layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        deconved = self.deconv3d(inputs)
        resized = self.resize(inputs)
        # halve the channels of the resized tensor by splitting and summing
        # inspection (unused code): argument 'axis' passed by position and
        # keyword in function call
        resized = tf.add_n(tf.split(resized, num_or_size_splits=2, axis=4))
        return deconved + resized

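    # Illustration (added note, not part of the original file): the
    # split/add_n combination above halves the channel dimension, e.g.
    #
    #     x = tf.random.normal((1, 4, 4, 4, 8))
    #     parts = tf.split(x, num_or_size_splits=2, axis=4)  # two (1, 4, 4, 4, 4)
    #     tf.add_n(parts).shape  # TensorShape([1, 4, 4, 4, 4])
    #
    # so the resized skip tensor matches the deconvolution branch in channels.
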
    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        deconv_config = self.deconv3d.get_config()
        config.update(
            filters=deconv_config["filters"],
            output_padding=deconv_config["output_padding"],
            kernel_size=deconv_config["kernel_size"],
            strides=deconv_config["strides"],
            padding=deconv_config["padding"],
        )
        config.update(output_shape=self.resize._shape)
        # inspection (coding style best practice): it seems like `_shape` was
        # declared protected and should not be accessed from this context.
        # Prefixing a member variable with `_` is conventionally equivalent to
        # the protected visibility found in other languages; such a member
        # should only be accessed from the same class or a child class:
        #
        #     class MyParent:
        #         def __init__(self):
        #             self._x = 1
        #             self.y = 2
        #
        #     class MyChild(MyParent):
        #         def some_method(self):
        #             return self._x  # ok, since accessed from a child class
        #
        #     class AnotherClass:
        #         def some_method(self, instance_of_my_child):
        #             # would be flagged, as AnotherClass is not a child
        #             # class of MyParent
        #             return instance_of_my_child._x
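        #
        # A possible fix (a sketch under the assumption that output_shape is
        # still in scope in __init__; not code from this PR): keep a record of
        # the target shape on this layer, e.g. self._output_shape =
        # output_shape, and serialize that instead of Resize3d's state:
        #
        #     config.update(output_shape=self._output_shape)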
        return config


@REGISTRY.register_backbone(name="local")
class LocalNet(UNet):
    """
    Build LocalNet for image registration.

    References:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical Image Analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration."
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        use_additive_upsampling: bool = True,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        The image is encoded gradually, with i going from level 0 to D,
        then decoded gradually, with j going from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions.
        :param depth: depth of the encoder.
        :param use_additive_upsampling: whether to use the additive
            up-sampling layer for decoding.
        :param pooling: for down-sampling, use non-parameterized
            pooling if true, otherwise use conv3d.
        :param concat_skip: when up-sampling, concatenate the skipped
            tensor if true, otherwise use addition.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        self._use_additive_upsampling = use_additive_upsampling
        if depth is None:
            depth = max(extract_levels)
        # the first encoding level uses 7x7x7 kernels, all others 3x3x3
        kwargs["encode_kernel_sizes"] = [7] + [3] * depth
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            pooling=pooling,
            concat_skip=concat_skip,
            name=name,
            **kwargs,
        )

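    # Worked example (assumed values, added note): with extract_levels=(0, 1, 2)
    # and depth=None, depth becomes max(extract_levels) = 2, so the encoder
    # kernel sizes are [7] + [3] * 2 = [7, 3, 3]: a 7x7x7 kernel at the first
    # encoding level, then 3x3x3 kernels below it.
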
    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for the bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return layer.Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """
        if self._use_additive_upsampling:
            return AdditiveUpsampling(
                filters=filters,
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=output_shape,
            )

        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: Tuple[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: from which depths the output will be built.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(use_additive_upsampling=self._use_additive_upsampling)
        return config
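

# Hedged usage sketch (not part of the original file; argument values are
# illustrative assumptions): constructing the backbone directly.
if __name__ == "__main__":
    net = LocalNet(
        image_size=(16, 16, 16),
        num_channel_initial=2,
        extract_levels=(0, 1, 2),
        out_kernel_initializer="glorot_uniform",
        out_activation="sigmoid",
        out_channels=3,
    )
    print(net.get_config()["use_additive_upsampling"])  # True by default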