Passed
Pull Request — main (#656)
by Yunguan
02:35
created

LocalNet.__init__()   C

Complexity

Conditions 7

Size

Total Lines 145
Code Lines 99

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 99
dl 0
loc 145
rs 5.6327
c 0
b 0
f 0
cc 7
nop 11

How to fix: Long Method, Many Parameters

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

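Commonly applied refactorings include Extract Method and, when many parameters or temporary variables are involved, Replace Temp with Query, Introduce Parameter Object, or Preserve Whole Object.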

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists, such as Replace Parameter with Method Call, Preserve Whole Object, and Introduce Parameter Object.

1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Optional
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
import deepreg.model.layer_util
9
from deepreg.model import layer
10
from deepreg.model.backbone.interface import Backbone
11
from deepreg.registry import REGISTRY
12
13
14
@REGISTRY.register_backbone(name="local")
class LocalNet(Backbone):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        use_additive_upsampling: bool = True,
        control_points: Optional[tuple] = None,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E,
        then it is decoded gradually, j from level E to D.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, E] with E = max(extract_levels),
        and D = min(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of channels at level 0;
            the channel count doubles at every deeper level.
        :param extract_levels: decoder levels whose outputs are turned into
            extractions, e.g. [0, 1, 2, 3]; must not be empty.
        :param out_kernel_initializer: initializer for the kernels of the
            extraction convolutions.
        :param out_activation: activation of the extraction convolutions.
        :param use_additive_upsampling: if True, every decoding step also adds
            a resized copy of the coarser feature map (its two channel halves
            summed) to the deconvolution output.
        :param control_points: distance between control points (in voxels);
            if given, the output is passed through ResizeCPTransform and
            BSplines3DTransform.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises ValueError: if extract_levels is empty.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # fail with a clear message instead of max()'s "empty sequence" error
        if not extract_levels:
            raise ValueError("extract_levels must contain at least one level.")

        # save parameters
        self._extract_levels = extract_levels
        self._use_additive_upsampling = use_additive_upsampling
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # number of channels and kernel size per level, level 0 to E
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]
        kernel_sizes = [
            7 if level == 0 else 3 for level in range(self._extract_max_level + 1)
        ]

        # layers are created in the same order as before so that Keras
        # layer tracking / auto-naming is unchanged
        self._build_encoder(
            image_size=image_size,
            num_channels=num_channels,
            kernel_sizes=kernel_sizes,
        )
        self._build_decoder(num_channels=num_channels)
        self._build_extractors(
            image_size=image_size,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

        # optional control-point based post-processing of the output
        if control_points is not None:
            self.resize = layer.ResizeCPTransform(control_points)
            self.interpolate = layer.BSplines3DTransform(control_points, image_size)
        else:
            # False (rather than None) is kept for backward compatibility
            self.resize = False
            self.interpolate = False

    def _decode_levels(self) -> range:
        """
        Levels visited by the decoder, from E-1 down to D (both inclusive).

        :return: range of decoding levels, in decoding order.
        """
        return range(self._extract_max_level - 1, self._extract_min_level - 1, -1)

    def _build_encoder(
        self, image_size: tuple, num_channels: List[int], kernel_sizes: List[int]
    ):
        """
        Build the encoding layers of levels 0 to E and record the spatial
        shape of the tensor at every level in ``self._tensor_shapes``.

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channels: number of channels per level, level 0 to E.
        :param kernel_sizes: convolution kernel size per level, level 0 to E.
        """
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]
        for level in range(self._extract_max_level):  # level 0 to E-1
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[level],
                        kernel_size=kernel_sizes[level],
                        padding="same",
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[level],
                        kernel_size=kernel_sizes[level],
                        padding="same",
                    ),
                ]
            )
            downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            # stride-2 "same" pooling halves each dimension, rounding up
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        # bottom of the network, level E
        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )

    def _build_decoder(self, num_channels: List[int]):
        """
        Build the decoding layers of levels E-1 down to D.

        Must be called after the encoder is built, because it reads
        ``self._tensor_shapes``.

        :param num_channels: number of channels per level, level 0 to E.
        """
        deconv_kernel_size = 3
        self._upsample_deconvs = []
        self._resizes = []
        self._upsample_convs = []
        for level in self._decode_levels():  # level E-1 to D
            # the output padding must be derived from the kernel size the
            # deconvolution actually uses
            padding = deepreg.model.layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[level + 1],
                output_shape=self._tensor_shapes[level],
                kernel_size=deconv_kernel_size,
                stride=2,
                padding="same",
            )
            upsample_deconv = layer.Deconv3dBlock(
                filters=num_channels[level],
                output_padding=padding,
                kernel_size=deconv_kernel_size,
                strides=2,
                padding="same",
            )
            upsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[level], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[level], kernel_size=3, padding="same"
                    ),
                ]
            )
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)
            if self._use_additive_upsampling:
                resize = layer.Resize3d(shape=self._tensor_shapes[level])
                self._resizes.append(resize)

    def _build_extractors(
        self,
        image_size: tuple,
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ):
        """
        Build one extraction head per entry of ``self._extract_levels``.
        Each head is a 3x3x3 convolution to ``out_channels`` followed by a
        resize to ``image_size``.

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer for the head convolutions.
        :param out_activation: activation of the head convolutions.
        """
        self._extract_layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    layer.Resize3d(shape=image_size),
                ]
            )
            for _ in self._extract_levels
        ]

    def call(
        self,
        inputs: tf.Tensor,
        training: Optional[bool] = None,
        mask: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Build LocalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool, forwarded to the encoding and
            decoding layers.
        :param mask: None or tf.Tensor; not used.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """
        # down sample from level 0 to E
        # encoded[i] is the skip connection of level i, stored for 0 to E-1
        encoded = []
        h_in = inputs
        for level in range(self._extract_max_level):  # level 0 to E-1
            skip = self._downsample_convs[level](inputs=h_in, training=training)
            h_in = self._downsample_pools[level](inputs=skip, training=training)
            encoded.append(skip)
        h_bottom = self._conv3d_block(inputs=h_in, training=training)  # level E

        # up sample from level E to D
        # decoded[E - j] holds the decoded output of level j
        decoded = [h_bottom]  # level E
        for idx, level in enumerate(self._decode_levels()):  # level E-1 to D
            h = self._upsample_deconvs[idx](inputs=h_bottom, training=training)
            if self._use_additive_upsampling:
                # resize the coarser features and sum their two channel halves
                # so the channel count matches the deconvolution output
                up_sampled = self._resizes[idx](inputs=h_bottom)
                up_sampled = tf.add_n(
                    tf.split(up_sampled, num_or_size_splits=2, axis=4)
                )
                h = h + up_sampled
            h = h + encoded[level]
            h_bottom = self._upsample_convs[idx](inputs=h, training=training)
            decoded.append(h_bottom)

        # output: mean of the extractions over all extract levels
        output = tf.add_n(
            [
                extract_layer(inputs=decoded[self._extract_max_level - level])
                for extract_layer, level in zip(
                    self._extract_layers, self._extract_levels
                )
            ]
        ) / len(self._extract_levels)

        if self.resize:
            output = self.resize(output)
            output = self.interpolate(output)

        return output
233