Passed
Pull Request — main (#656)
by Yunguan
02:51
created

AdditiveUpsampling.__init__()   A

Complexity

Conditions 1

Size

Total Lines 30
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 30
rs 9.55
c 0
b 0
f 0
cc 1
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

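There are several approaches to avoid long parameter lists, for example replacing a parameter with a method call, introducing a parameter object, or passing a whole object instead of several values extracted from it.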

1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import Optional, Tuple, Union
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
from deepreg.model import layer
9
from deepreg.model.backbone.u_net import UNet
10
from deepreg.model.layer import Extraction
11
from deepreg.registry import REGISTRY
12
13
14
class AdditiveUpsampling(tfkl.Layer):
    """
    Additive up-sampling layer.

    The input is up-sampled along two paths whose results are summed:

    - a learned path: a deconvolution block (``layer.Deconv3dBlock``)
      producing ``filters`` channels;
    - a parameter-free path: the input is resized to ``output_shape``
      (``layer.Resize3d``), then the two halves of its channel axis are
      added together, which halves the number of channels.

    For the final addition to be well defined, the input must have an even
    number of channels, and ``filters`` is expected to equal half of them.
    """

    def __init__(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
        name: str = "AdditiveUpsampling",
    ):
        """
        Additive up-sampling layer.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor, used by the resize path
        :param name: name of the layer.
        """
        super().__init__(name=name)
        self.deconv3d = layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        """
        Up-sample ``inputs`` and return the sum of the two up-sampling paths.

        :param inputs: input tensor; assumed to be 5D
            (batch, dim1, dim2, dim3, channels) since axis 4 is treated as
            the channel axis — TODO confirm with callers. Its channel count
            must be even.
        :param kwargs: unused, accepted for Keras ``call`` compatibility.
        :return: deconvolved tensor plus the resized, channel-folded tensor.
        """
        deconved = self.deconv3d(inputs)
        resized = self.resize(inputs)
        # Fold the channel axis in half (first half + second half) so the
        # parameter-free path matches the channel count of the deconv path.
        resized = tf.add_n(tf.split(resized, num_or_size_splits=2, axis=4))
        return deconved + resized
51
52
53
@REGISTRY.register_backbone(name="local")
class LocalNet(UNet):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int, ...],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        use_additive_upsampling: bool = True,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        Image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :param depth: depth of the encoder; defaults to max(extract_levels).
        :param use_additive_upsampling: whether use additive up-sampling.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises ValueError: if depth is None and extract_levels is empty.
        """
        # Must be set before super().__init__: build_up_sampling_block reads
        # it, and presumably the UNet constructor builds its blocks (calling
        # that method) — confirm against UNet.
        self._use_additive_upsampling = use_additive_upsampling
        if depth is None:
            if not extract_levels:
                # max() of an empty sequence would raise an opaque ValueError.
                raise ValueError(
                    "extract_levels must be non-empty when depth is not given, "
                    f"got extract_levels={extract_levels}."
                )
            depth = max(extract_levels)
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            pooling=pooling,
            concat_skip=concat_skip,
            # depth + 1 kernel sizes, one per encoding level 0..D:
            # 7 for the first level, 3 for every deeper level.
            encode_kernel_sizes=[7] + [3] * depth,
            name=name,
            **kwargs,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consists of one or multiple layers
        """
        return layer.Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth) and
        outputs ``filters`` channels. With additive up-sampling enabled, the
        input is expected to have ``2 * filters`` channels, because its
        channel axis is folded in half before being added to the
        deconvolution output.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
            (only used with additive up-sampling)
        :return: a block consists of one or multiple layers
        """
        deconv_kwargs = {
            "filters": filters,
            "output_padding": output_padding,
            "kernel_size": kernel_size,
            "strides": strides,
            "padding": padding,
        }
        if self._use_additive_upsampling:
            return AdditiveUpsampling(output_shape=output_shape, **deconv_kwargs)
        # Plain deconvolution: output_shape is not used on this path.
        return layer.Deconv3dBlock(**deconv_kwargs)

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which the extractions are built.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consists of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )
216