Passed
Pull Request — main (#656) · created by Yunguan · 03:01

AdditiveUpsampling.__init__() — rating: A

Complexity: Conditions 1
Size: Total Lines 30 · Code Lines 17
Duplication: Lines 0 · Ratio 0 %
Importance: Changes 0
Metric   Value
eloc     17
dl       0
loc      30
rs       9.55
c        0
b        0
f        0
cc       1
nop      8

(cc is the cyclomatic complexity, matching Conditions 1 above; nop, presumably the number of parameters, counts all eight parameters of __init__ including self, which is what triggers the Many Parameters advice below.)

How to fix: Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists: group related parameters into a single parameter object, compute values inside the method instead of passing them in, or split the method into smaller ones. A sketch of the first approach follows.
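For illustration only (this is not part of the PR; the name DeconvParams is hypothetical), the five deconvolution arguments of AdditiveUpsampling.__init__ could be bundled into one parameter object:

from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class DeconvParams:
    # Hypothetical parameter object grouping the deconv3d arguments.
    filters: int
    output_padding: Union[int, Tuple, List]
    kernel_size: Union[int, Tuple, List]
    padding: str
    strides: Union[int, Tuple, List]


# __init__ could then take three parameters instead of eight:
#     def __init__(self, deconv: DeconvParams, output_shape: Tuple,
#                  name: str = "AdditiveUpsampling"): ...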

# coding=utf-8
# [inspection] Missing module docstring

from typing import List, Optional, Tuple, Union

import tensorflow as tf
import tensorflow.keras.layers as tfkl

from deepreg.model import layer
from deepreg.model.backbone.u_net import UNet
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY


class AdditiveUpsampling(tfkl.Layer):
    # [inspection] Missing class docstring
    def __init__(
        self,
        filters: int,
        output_padding: Union[int, Tuple, List],
        kernel_size: Union[int, Tuple, List],
        padding: str,
        strides: Union[int, Tuple, List],
        output_shape: Tuple,
        name: str = "AdditiveUpsampling",
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :param name: name of the layer.
        """
        super().__init__(name=name)
        self.deconv3d = layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.resize = layer.Resize3d(shape=output_shape)
    def call(self, inputs, **kwargs):
        deconved = self.deconv3d(inputs)
        resized = self.resize(inputs)
        # Split the resized tensor into two halves along the channel axis
        # and sum them, halving the channel count.
        resized = tf.add_n(tf.split(resized, num_or_size_splits=2, axis=4))
        # [inspection: Unused Code] Argument 'axis' passed by position and
        # keyword in function call
        return deconved + resized
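To see what the split-and-sum inside call() does, here is a standalone sketch (shapes are hypothetical): resizing preserves the input's channel count, so the tensor is split into two halves along the channel axis (axis 4) and the halves are summed, halving the channel count to match the deconvolution output.

import tensorflow as tf

# Hypothetical input: batch 1, spatial 4 x 4 x 4, 8 channels.
x = tf.random.normal((1, 4, 4, 4, 8))
halves = tf.split(x, num_or_size_splits=2, axis=4)  # two tensors of shape (1, 4, 4, 4, 4)
summed = tf.add_n(halves)                           # shape (1, 4, 4, 4, 4): channels halved
print(summed.shape)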
    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        deconv_config = self.deconv3d.get_config()
        config.update(
            filters=deconv_config["filters"],
            output_padding=deconv_config["output_padding"],
            kernel_size=deconv_config["kernel_size"],
            strides=deconv_config["strides"],
            padding=deconv_config["padding"],
        )
        config.update(output_shape=self.resize._shape)
        # [inspection: Coding Style Best Practice] It seems like _shape was
        # declared protected and should not be accessed from this context.
        return config

Prefixing a member variable with _ is usually regarded as the equivalent of declaring it with the protected visibility that exists in other languages. Consequently, such a member should only be accessed from the same class or a child class:

class MyParent:
    def __init__(self):
        self._x = 1
        self.y = 2

class MyChild(MyParent):
    def some_method(self):
        return self._x    # OK, since accessed from a child class

class AnotherClass:
    def some_method(self, instance_of_my_child):
        return instance_of_my_child._x   # Flagged, since AnotherClass is not
                                         # a child class of MyParent
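One way to address this finding (a minimal sketch, assuming the layer can simply remember its own constructor argument; both class names below are hypothetical stand-ins, not DeepReg code):

class Resize3dStub:
    # Stand-in for layer.Resize3d, which keeps its shape protected.
    def __init__(self, shape):
        self._shape = shape


class AdditiveUpsamplingSketch:
    def __init__(self, output_shape):
        self._output_shape = output_shape  # keep our own copy of the argument
        self.resize = Resize3dStub(shape=output_shape)

    def get_config(self) -> dict:
        # Read the locally stored value instead of self.resize._shape.
        return {"output_shape": self._output_shape}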

@REGISTRY.register_backbone(name="local")
class LocalNet(UNet):
    """
    Build LocalNet for image registration.

    References:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical Image Analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration."
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        use_additive_upsampling: bool = True,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        The image is encoded gradually, with i going from level 0 to D,
        and then decoded gradually, with j going from level D to 0.
        Some of the decoded levels are used for generating extractions,
        so extract_levels must lie in [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: depths from which the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at the end layer.
        :param out_channels: number of channels for the extractions
        :param depth: depth of the encoder.
            If depth is not given, depth = max(extract_levels) will be used.
        :param use_additive_upsampling: whether to use the additive
            up-sampling layer for decoding.
        :param pooling: for down-sampling, use non-parameterized
            pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate the skipped
            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        self._use_additive_upsampling = use_additive_upsampling
        if depth is None:
            depth = max(extract_levels)
        # The first encoder block uses a 7x7x7 kernel; all others use 3x3x3.
        kwargs["encode_kernel_sizes"] = [7] + [3] * depth
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            pooling=pooling,
            concat_skip=concat_skip,
            name=name,
            **kwargs,
        )
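As a usage illustration (all values are hypothetical, not taken from the PR or its tests), the backbone can be constructed directly, leaving depth to default to max(extract_levels):

model = LocalNet(
    image_size=(32, 32, 32),
    num_channel_initial=4,
    extract_levels=(0, 1, 2, 3),
    out_kernel_initializer="glorot_uniform",
    out_activation="sigmoid",
    out_channels=3,
)
# encode_kernel_sizes becomes [7, 3, 3, 3]: a 7x7x7 kernel at level 0,
# then 3x3x3 kernels for the remaining encoder levels.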
    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for the bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return layer.Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding=padding
        )
    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """
        if self._use_additive_upsampling:
            return AdditiveUpsampling(
                filters=filters,
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=output_shape,
            )

        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: Tuple[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: depths from which the outputs are extracted.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at the end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )
    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(use_additive_upsampling=self._use_additive_upsampling)
        return config
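Finally, a config round-trip sketch (hedged: it assumes UNet.get_config, which is not shown in this diff, returns exactly the matching constructor arguments):

config = model.get_config()
assert config["use_additive_upsampling"] is True  # the default set above
clone = LocalNet.from_config(config)  # standard Keras path, roughly LocalNet(**config)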