Passed
Pull Request — main (#656) by Yunguan
created 03:15

LocalNet.__init__()   C

Complexity

Conditions 5

Size

Total Lines 133
Code Lines 90

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric    Value
eloc      90
dl        0
loc       133
rs        6.8606
c         0
b         0
f         0
cc        5
nop       10

How to fix: Long Method · Many Parameters

Long Method

Small methods make your code easier to understand, especially when combined with a good name. Moreover, when a method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method; the comment then becomes a natural starting point for naming it.

Commonly applied refactorings include Extract Method, as sketched below.
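For instance, a minimal sketch of Extract Method in Python (the class and method names here are hypothetical, not taken from the code under review): each commented phase of a long constructor becomes a small helper named after its comment.

# Hypothetical example: the phases of a long __init__ are extracted into
# small, well-named private helpers.
class Pipeline:
    def __init__(self, levels: int, base_channels: int):
        # Before the refactoring these two blocks lived inline, each
        # introduced by a comment; the comments became the method names.
        self.encoder_channels = self._build_encoder_channels(levels, base_channels)
        self.kernel_sizes = self._build_kernel_sizes(levels)

    def _build_encoder_channels(self, levels: int, base_channels: int) -> list:
        # the channel count doubles at every level
        return [base_channels * (2 ** level) for level in range(levels + 1)]

    def _build_kernel_sizes(self, levels: int) -> list:
        # a larger kernel is used only at the first level
        return [7 if level == 0 else 3 for level in range(levels + 1)]


if __name__ == "__main__":
    p = Pipeline(levels=3, base_channels=16)
    print(p.encoder_channels)  # [16, 32, 64, 128]
    print(p.kernel_sizes)      # [7, 3, 3, 3]

Each helper is then short enough to read at a glance, which is exactly what its name promises.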

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoiding long parameter lists, such as Introduce Parameter Object, Preserve Whole Object, and Extract Method; see the sketch below.
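One common option is Introduce Parameter Object, sketched below (the dataclass and function names are hypothetical, not DeepReg's API): related constructor arguments are grouped into a single config object, so call sites stay consistent when options are added or changed.

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class BackboneConfig:
    # hypothetical parameter object grouping the ten constructor arguments
    image_size: Tuple[int, int, int]
    out_channels: int
    num_channel_initial: int
    extract_levels: List[int]
    out_kernel_initializer: str = "glorot_uniform"
    out_activation: str = "linear"
    use_additive_upsampling: bool = True
    name: str = "LocalNet"


def build_backbone(config: BackboneConfig) -> BackboneConfig:
    # one parameter instead of many; adding an option only touches the dataclass
    print(f"building {config.name} for image size {config.image_size}")
    return config


if __name__ == "__main__":
    cfg = BackboneConfig(
        image_size=(32, 32, 32),
        out_channels=3,
        num_channel_initial=16,
        extract_levels=[0, 1, 2, 3],
    )
    build_backbone(cfg)

Existing call sites then pass a single config object instead of ten separate arguments.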

# coding=utf-8
[Issue: Missing module docstring]

from typing import List

import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util
from deepreg.model import layer
from deepreg.model.backbone.interface import Backbone
from deepreg.registry import REGISTRY


@REGISTRY.register_backbone(name="local")
class LocalNet(Backbone):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        use_additive_upsampling: bool = True,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        The image is encoded gradually from level 0 to E,
        then decoded gradually from level E down to D.
        Some of the decoded levels are used for generating extractions.

        So extract_levels lie in [0, E] with E = max(extract_levels)
        and D = min(extract_levels).

        :param image_size: image size, e.g. (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of initial channels.
        :param extract_levels: list of levels from which outputs are extracted.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param use_additive_upsampling: whether to use additive up-sampling.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        self._extract_levels = extract_levels
        self._use_additive_upsampling = use_additive_upsampling
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # init layer variables
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        kernel_sizes = [
            7 if level == 0 else 3 for level in range(self._extract_max_level + 1)
        ]
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]
        for i in range(self._extract_max_level):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                ]
            )
            downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )  # level E

        self._upsample_deconvs = []
        self._resizes = []
        self._upsample_convs = []
        for level in range(
            self._extract_max_level - 1, self._extract_min_level - 1, -1
        ):  # level D to E-1
            padding = deepreg.model.layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[level + 1],
                output_shape=self._tensor_shapes[level],
                kernel_size=kernel_sizes[level],
                stride=2,
                padding="same",
            )
            upsample_deconv = layer.Deconv3dBlock(
                filters=num_channels[level],
                output_padding=padding,
                kernel_size=3,
                strides=2,
                padding="same",
            )
            upsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[level], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[level], kernel_size=3, padding="same"
                    ),
                ]
            )
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)
            if self._use_additive_upsampling:
                resize = layer.Resize3d(shape=self._tensor_shapes[level])
                self._resizes.append(resize)
        self._extract_layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    layer.Resize3d(shape=image_size),
                ]
            )
            for _ in self._extract_levels
        ]

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
[Issue: "mask, training" missing in parameter type documentation]
        """
        Build LocalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # down sample from level 0 to E
        # outputs used for decoding, encoded[i] corresponds -> level i
        # stored only 0 to E-1
        encoded = []
        h_in = inputs
        for level in range(self._extract_max_level):  # level 0 to E - 1
            skip = self._downsample_convs[level](inputs=h_in, training=training)
            h_in = self._downsample_pools[level](inputs=skip, training=training)
            encoded.append(skip)
        h_bottom = self._conv3d_block(
            inputs=h_in, training=training
        )  # level E of encoding/decoding

        # up sample from level E to D
        decoded = [h_bottom]  # level E
        for idx, level in enumerate(
            range(self._extract_max_level - 1, self._extract_min_level - 1, -1)
        ):  # level E-1 to D
            h = self._upsample_deconvs[idx](inputs=h_bottom, training=training)
            if self._use_additive_upsampling:
                up_sampled = self._resizes[idx](inputs=h_bottom)
                up_sampled = tf.split(up_sampled, num_or_size_splits=2, axis=4)
[Unused Code issue: Argument 'axis' passed by position and keyword in function call]
                up_sampled = tf.add_n(up_sampled)
                h = h + up_sampled
            h = h + encoded[level]
            h_bottom = self._upsample_convs[idx](inputs=h, training=training)
            decoded.append(h_bottom)

        # output
        output = tf.add_n(
            [
                self._extract_layers[idx](
                    inputs=decoded[self._extract_max_level - level]
                )
                for idx, level in enumerate(self._extract_levels)
            ]
        ) / len(self._extract_levels)

        return output
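For context, a minimal usage sketch of the class above; the argument values are illustrative, and the import path assumes the file lives at deepreg/model/backbone/local_net.py.

import tensorflow as tf

from deepreg.model.backbone.local_net import LocalNet

# illustrative settings: a 32^3 volume, four extraction levels, 3 output channels
backbone = LocalNet(
    image_size=(32, 32, 32),
    out_channels=3,
    num_channel_initial=16,
    extract_levels=[0, 1, 2, 3],
    out_kernel_initializer="glorot_uniform",
    out_activation="linear",
    use_additive_upsampling=True,
)
images = tf.random.uniform(shape=(2, 32, 32, 32, 1))  # (batch, dim1, dim2, dim3, ch)
outputs = backbone(images)  # (batch, dim1, dim2, dim3, out_channels), per the call() docstring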