Passed
Pull Request — main (#656)
by Yunguan
02:46
created

LocalNet.__init__()   C

Complexity

Conditions 6

Size

Total Lines 126
Code Lines 83

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 83
dl 0
loc 126
rs 6.623
c 0
b 0
f 0
cc 6
nop 10

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

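Commonly applied refactorings include Extract Method. If many parameters or temporary variables make extraction awkward, Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object can clear the way.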

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

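There are several approaches to avoid long parameter lists: replace a parameter with a method call when the callee can compute the value itself, pass a whole object instead of several of its fields (Preserve Whole Object), or group parameters that always travel together into a single object (Introduce Parameter Object).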

1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Optional
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
from deepreg.model import layer
9
from deepreg.model.backbone.interface import Backbone
10
from deepreg.registry import REGISTRY
11
12
13
@REGISTRY.register_backbone(name="local")
class LocalNet(Backbone):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        control_points: Optional[tuple] = None,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E,
        then it is decoded gradually, j from level E to D.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, E] with E = max(extract_levels),
        and D = min(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of channels at level 0; the channel
            count doubles at every deeper level.
        :param extract_levels: levels whose decoded features produce extractions.
            The encoder depth is E = max(extract_levels) and decoding stops at
            D = min(extract_levels). Extractions are averaged.
        :param out_kernel_initializer: kernel initializer of the extraction convs.
        :param out_activation: activation of the extraction convs.
        :param control_points: specify the distance between control points (in voxels).
            When given, the averaged output is passed through ResizeCPTransform
            and then BSplines3DTransform; when None, no such step is applied.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises ValueError: if extract_levels is empty or contains a negative level.
        """
        # fail fast with a clear message instead of a bare max()/min() error
        # or silently wrong negative indexing into the per-level shapes
        if len(extract_levels) == 0:
            raise ValueError("extract_levels must contain at least one level.")
        if min(extract_levels) < 0:
            raise ValueError(
                f"extract_levels must be non-negative, got {extract_levels}."
            )

        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # per-level settings for levels 0 to E: channels double with depth,
        # level 0 uses kernel size 7, deeper levels use kernel size 3
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]
        kernel_sizes = [
            7 if level == 0 else 3 for level in range(self._extract_max_level + 1)
        ]

        # keep this creation order (encoder, decoder, extraction, control points):
        # Keras derives default layer names from creation order, so reordering
        # could rename layers/weights
        self._init_encoder_layers(image_size, num_channels, kernel_sizes)
        self._init_decoder_layers(num_channels, kernel_sizes)
        self._init_extraction_layers(
            image_size, out_channels, out_kernel_initializer, out_activation
        )
        self._init_control_point_layers(image_size, control_points)

    def _decoding_levels(self) -> range:
        """Return the decoding levels in visiting order: E-1 down to D (inclusive)."""
        return range(self._extract_max_level - 1, self._extract_min_level - 1, -1)

    def _init_encoder_layers(
        self, image_size: tuple, num_channels: List[int], kernel_sizes: List[int]
    ):
        """Create down-sampling blocks for levels 0 to E-1 and the bottom block at E."""
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        # spatial shape of every level 0..E, used to size the decoder
        self._tensor_shapes = [tensor_shape]
        for i in range(self._extract_max_level):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                ]
            )
            downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            # stride-2 pooling with "same" padding yields ceil(x / 2) per dimension
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )  # level E

    def _init_decoder_layers(self, num_channels: List[int], kernel_sizes: List[int]):
        """Create one up-sampling block per decoding level, ordered E-1 down to D."""
        self._upsample_blocks = []
        for level in self._decoding_levels():
            # output padding for going from the shape of level + 1 back to
            # the shape of `level`, so decoded and encoded features align
            padding = layer.deconv_output_padding(
                input_shape=self._tensor_shapes[level + 1],
                output_shape=self._tensor_shapes[level],
                kernel_size=kernel_sizes[level],
                stride=2,
                padding="same",
            )
            upsample_block = layer.LocalNetUpSampleResnetBlock(
                num_channels[level],
                output_padding=padding,
                output_shape=self._tensor_shapes[level],
            )
            self._upsample_blocks.append(upsample_block)

    def _init_extraction_layers(
        self,
        image_size: tuple,
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ):
        """Create, for each extract level, a conv head followed by a resize to image_size."""
        self._extract_layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    layer.Resize3d(shape=image_size),
                ]
            )
            for _ in self._extract_levels
        ]

    def _init_control_point_layers(
        self, image_size: tuple, control_points: Optional[tuple]
    ):
        """Create the optional control-point transforms; both are None when unused."""
        if control_points is None:
            self.resize = None
            self.interpolate = None
            return
        self.resize = layer.ResizeCPTransform(control_points)
        self.interpolate = layer.BSplines3DTransform(control_points, image_size)

    def call(
        self,
        inputs: tf.Tensor,
        training: Optional[bool] = None,
        mask: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Build LocalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool, forwarded to the encoder/decoder blocks.
        :param mask: None or tf.Tensor; not used by this backbone.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels).
            When control points were configured, the averaged extraction is
            additionally passed through ``self.resize`` then ``self.interpolate``.
        """
        # down sample from level 0 to E-1; encoded[i] is the skip output of level i
        encoded = []
        h_in = inputs
        for downsample_conv, downsample_pool in zip(
            self._downsample_convs, self._downsample_pools
        ):
            skip = downsample_conv(inputs=h_in, training=training)
            h_in = downsample_pool(inputs=skip, training=training)
            encoded.append(skip)
        # level E of encoding/decoding
        h_bottom = self._conv3d_block(inputs=h_in, training=training)

        # up sample from level E-1 to D; decoded[k] corresponds to level E - k
        decoded = [h_bottom]
        for upsample_block, level in zip(self._upsample_blocks, self._decoding_levels()):
            h_bottom = upsample_block(
                inputs=[h_bottom, encoded[level]], training=training
            )
            decoded.append(h_bottom)

        # average the extractions of all requested levels
        extractions = [
            extract_layer(inputs=decoded[self._extract_max_level - level])
            for extract_layer, level in zip(self._extract_layers, self._extract_levels)
        ]
        output = tf.add_n(extractions) / len(self._extract_levels)

        if self.resize is not None:
            output = self.resize(output)
            output = self.interpolate(output)

        return output
208