Passed: Pull Request — main (#656), created by Yunguan, 02:35

deepreg.model.backbone.local_net.LocalNet.call() — rating: B

Complexity
  Conditions: 5

Size
  Total Lines: 53
  Code Lines: 31

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0
Metric  Value
------  ------
eloc    31
dl      0
loc     53
rs      8.6693
c       0
b       0
f       0
cc      5
nop     4

How to fix: Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

Commonly applied refactorings include:

- Extract Method (illustrated below)

# coding=utf-8
# ignored issue: Missing module docstring

from typing import List

import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util
from deepreg.model import layer
from deepreg.model.backbone.interface import Backbone
from deepreg.registry import REGISTRY


@REGISTRY.register_backbone(name="local")
class LocalNet(Backbone):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    # ignored issue: "use_additive_upsampling" missing in parameter documentation
    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        use_additive_upsampling: bool = True,
        control_points: (tuple, None) = None,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E,
        then it is decoded gradually, j from level E to D.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, E] with E = max(extract_levels),
        and D = min(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of initial channels.
        :param extract_levels: number of extraction levels.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param control_points: specify the distance between control points (in voxels).
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        self._extract_levels = extract_levels
        self._use_additive_upsampling = use_additive_upsampling
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # init layer variables
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        kernel_sizes = [
            7 if level == 0 else 3 for level in range(self._extract_max_level + 1)
        ]
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]
        for i in range(self._extract_max_level):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                ]
            )
            downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )  # level E

        self._upsample_deconvs = []
        self._resizes = []
        self._upsample_convs = []
        for level in range(
            self._extract_max_level - 1, self._extract_min_level - 1, -1
        ):  # level D to E-1
            padding = deepreg.model.layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[level + 1],
                output_shape=self._tensor_shapes[level],
                kernel_size=kernel_sizes[level],
                stride=2,
                padding="same",
            )
            upsample_deconv = layer.Deconv3dBlock(
                filters=num_channels[level],
                output_padding=padding,
                kernel_size=3,
                strides=2,
                padding="same",
            )
            upsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[level], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[level], kernel_size=3, padding="same"
                    ),
                ]
            )
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)
            if self._use_additive_upsampling:
                resize = layer.Resize3d(shape=self._tensor_shapes[level])
                self._resizes.append(resize)
        self._extract_layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    layer.Resize3d(shape=image_size),
                ]
            )
            for _ in self._extract_levels
        ]

        self.resize = (
            layer.ResizeCPTransform(control_points)
            if control_points is not None
            else False
        )
        self.interpolate = (
            layer.BSplines3DTransform(control_points, image_size)
            if control_points is not None
            else False
        )

    # ignored issue: "mask, training" missing in parameter type documentation
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build LocalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # down sample from level 0 to E
        # outputs used for decoding, encoded[i] corresponds -> level i
        # stored only 0 to E-1
        encoded = []
        h_in = inputs
        for level in range(self._extract_max_level):  # level 0 to E - 1
            skip = self._downsample_convs[level](inputs=h_in, training=training)
            h_in = self._downsample_pools[level](inputs=skip, training=training)
            encoded.append(skip)
        h_bottom = self._conv3d_block(
            inputs=h_in, training=training
        )  # level E of encoding/decoding

        # up sample from level E to D
        decoded = [h_bottom]  # level E
        for idx, level in enumerate(
            range(self._extract_max_level - 1, self._extract_min_level - 1, -1)
        ):  # level E-1 to D
            h = self._upsample_deconvs[idx](inputs=h_bottom, training=training)
            if self._use_additive_upsampling:
                up_sampled = self._resizes[idx](inputs=h_bottom)
                # ignored issue (Unused Code): Argument 'axis' passed by position and keyword in function call
                up_sampled = tf.split(up_sampled, num_or_size_splits=2, axis=4)
                up_sampled = tf.add_n(up_sampled)
                h = h + up_sampled
            h = h + encoded[level]
            h_bottom = self._upsample_convs[idx](inputs=h, training=training)
            decoded.append(h_bottom)

        # output
        output = tf.add_n(
            [
                self._extract_layers[idx](
                    inputs=decoded[self._extract_max_level - level]
                )
                for idx, level in enumerate(self._extract_levels)
            ]
        ) / len(self._extract_levels)

        if self.resize:
            output = self.resize(output)
            output = self.interpolate(output)

        return output
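
To connect the advice above with the flagged method: LocalNet.call() walks through three commented phases (down-sampling to level E, up-sampling back to level D, and output extraction), so a natural Extract Method split is one private helper per phase. The sketch below is only an illustration under that assumption; the helper names _encode, _decode and _extract are invented here and are not part of the DeepReg source.

    # Hypothetical refactoring sketch: the body of call() split into one helper per phase,
    # intended to live inside the LocalNet class shown above. Behaviour is unchanged.
    def _encode(self, inputs: tf.Tensor, training):
        """Down-sample from level 0 to E; return the skip tensors and the bottom tensor."""
        encoded = []
        h_in = inputs
        for level in range(self._extract_max_level):
            skip = self._downsample_convs[level](inputs=h_in, training=training)
            h_in = self._downsample_pools[level](inputs=skip, training=training)
            encoded.append(skip)
        h_bottom = self._conv3d_block(inputs=h_in, training=training)
        return encoded, h_bottom

    def _decode(self, encoded, h_bottom: tf.Tensor, training):
        """Up-sample from level E-1 down to D, adding skip connections at each level."""
        decoded = [h_bottom]
        for idx, level in enumerate(
            range(self._extract_max_level - 1, self._extract_min_level - 1, -1)
        ):
            h = self._upsample_deconvs[idx](inputs=h_bottom, training=training)
            if self._use_additive_upsampling:
                up_sampled = self._resizes[idx](inputs=h_bottom)
                up_sampled = tf.add_n(tf.split(up_sampled, num_or_size_splits=2, axis=4))
                h = h + up_sampled
            h_bottom = self._upsample_convs[idx](inputs=h + encoded[level], training=training)
            decoded.append(h_bottom)
        return decoded

    def _extract(self, decoded) -> tf.Tensor:
        """Average the resized extraction-layer outputs over all extract levels."""
        return tf.add_n(
            [
                self._extract_layers[idx](inputs=decoded[self._extract_max_level - level])
                for idx, level in enumerate(self._extract_levels)
            ]
        ) / len(self._extract_levels)

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        encoded, h_bottom = self._encode(inputs, training)
        decoded = self._decode(encoded, h_bottom, training)
        output = self._extract(decoded)
        if self.resize:
            output = self.resize(output)
            output = self.interpolate(output)
        return output

Each helper stays well under the 53 total lines flagged above, and call() itself becomes a short summary of the algorithm, which is the usual payoff of this refactoring.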