Passed
Pull Request — main (#656)
by Yunguan
02:46
created

LocalNet.__init__()   C

Complexity

Conditions 6

Size

Total Lines 126
Code Lines 83

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 83
dl 0
loc 126
rs 6.623
c 0
b 0
f 0
cc 6
nop 10

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a cohesive part of the body into its own well-named method.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists, such as Replace Parameter with Method Call, Preserve Whole Object, and Introduce Parameter Object.

1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Optional
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
from deepreg.model import layer
9
from deepreg.model.backbone.interface import Backbone
10
from deepreg.registry import REGISTRY
11
12
13
@REGISTRY.register_backbone(name="local")
class LocalNet(Backbone):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        control_points: Optional[tuple] = None,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E,
        then it is decoded gradually, j from level E to D.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, E] with E = max(extract_levels),
        and D = min(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of initial channels.
        :param extract_levels: non-empty list of non-negative level indices
            whose decoded features are used for extraction.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param control_points: specify the distance between control points (in voxels).
            If None, the output is returned without resizing/interpolation.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises ValueError: if extract_levels is empty or has a negative level.
        """
        # Fail early with a clear message instead of an opaque max()/index error.
        if len(extract_levels) == 0:
            raise ValueError("extract_levels must contain at least one level.")
        if min(extract_levels) < 0:
            raise ValueError(
                f"extract_levels must be non-negative, got {extract_levels}."
            )

        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # per-level hyper-parameters, level 0 to E
        num_levels = self._extract_max_level + 1
        num_channels = [num_channel_initial * (2 ** level) for level in range(num_levels)]
        kernel_sizes = [7 if level == 0 else 3 for level in range(num_levels)]

        # NOTE: the builders below are called in the same order in which the
        # layers were originally created, so that layer names which Keras
        # auto-generates from creation order are not affected.
        self._init_downsample_layers(
            image_size=image_size,
            num_channels=num_channels,
            kernel_sizes=kernel_sizes,
        )

        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )  # level E

        self._init_upsample_blocks(
            num_channels=num_channels, kernel_sizes=kernel_sizes
        )
        self._init_extract_layers(
            image_size=image_size,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

        # Optional control-point post-processing; False means "disabled".
        self.resize = False
        self.interpolate = False
        if control_points is not None:
            self.resize = layer.ResizeCPTransform(control_points)
            self.interpolate = layer.BSplines3DTransform(control_points, image_size)

    def _init_downsample_layers(
        self, image_size: tuple, num_channels: List[int], kernel_sizes: List[int]
    ):
        """
        Build the encoder layers for level 0 to E-1 and record tensor shapes.

        Sets self._downsample_convs, self._downsample_pools and
        self._tensor_shapes, where self._tensor_shapes[i] is the spatial
        shape at level i, for i from 0 to E.

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channels: number of channels per level, level 0 to E.
        :param kernel_sizes: kernel size per level, level 0 to E.
        """
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]
        for level in range(self._extract_max_level):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[level],
                        kernel_size=kernel_sizes[level],
                        padding="same",
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[level],
                        kernel_size=kernel_sizes[level],
                        padding="same",
                    ),
                ]
            )
            downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            # stride-2 pooling with "same" padding halves each dim, rounding up
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

    def _init_upsample_blocks(self, num_channels: List[int], kernel_sizes: List[int]):
        """
        Build the decoder blocks, ordered from level E-1 down to level D.

        Sets self._upsample_blocks; requires self._tensor_shapes to be set.

        :param num_channels: number of channels per level, level 0 to E.
        :param kernel_sizes: kernel size per level, level 0 to E.
        """
        self._upsample_blocks = []
        for level in range(
            self._extract_max_level - 1, self._extract_min_level - 1, -1
        ):  # level E-1 to D
            padding = layer.deconv_output_padding(
                input_shape=self._tensor_shapes[level + 1],
                output_shape=self._tensor_shapes[level],
                kernel_size=kernel_sizes[level],
                stride=2,
                padding="same",
            )
            upsample_block = layer.LocalNetUpSampleResnetBlock(
                num_channels[level],
                output_padding=padding,
                output_shape=self._tensor_shapes[level],
            )
            self._upsample_blocks.append(upsample_block)

    def _init_extract_layers(
        self,
        image_size: tuple,
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ):
        """
        Build one extraction head (conv followed by resize) per extract level.

        Sets self._extract_layers, aligned index-wise with self._extract_levels.

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        """
        self._extract_layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    layer.Resize3d(shape=image_size),
                ]
            )
            for _ in self._extract_levels
        ]

    def call(
        self,
        inputs: tf.Tensor,
        training: Optional[bool] = None,
        mask: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Build LocalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor, not used by this method.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # down sample from level 0 to E
        # encoded[i] is the skip output of level i, stored only for 0 to E-1
        encoded = []
        h_in = inputs
        for level in range(self._extract_max_level):  # level 0 to E - 1
            skip = self._downsample_convs[level](inputs=h_in, training=training)
            h_in = self._downsample_pools[level](inputs=skip)
            encoded.append(skip)
        h_bottom = self._conv3d_block(
            inputs=h_in, training=training
        )  # level E of encoding/decoding

        # up sample from level E to D
        # decoded[k] holds the features of level E - k
        decoded = [h_bottom]  # level E
        decode_levels = range(
            self._extract_max_level - 1, self._extract_min_level - 1, -1
        )  # level E-1 to D
        for upsample_block, level in zip(self._upsample_blocks, decode_levels):
            h_bottom = upsample_block(
                inputs=[h_bottom, encoded[level]], training=training
            )
            decoded.append(h_bottom)

        # output: average the extractions of all requested levels
        extractions = [
            extract_layer(inputs=decoded[self._extract_max_level - level])
            for extract_layer, level in zip(
                self._extract_layers, self._extract_levels
            )
        ]
        # stack on a new trailing axis (index 5) and average it away
        output = tf.reduce_mean(tf.stack(extractions, axis=5), axis=5)

        # resize and interpolate are either both layers or both False
        if self.resize:
            output = self.resize(output)
            output = self.interpolate(output)

        return output
213