Passed
Pull Request — main (#656)
by Yunguan
02:48
created

deepreg.model.backbone.local_net   A

Complexity

Total Complexity 9

Size/Duplication

Total Lines 217
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 9
eloc 129
dl 0
loc 217
rs 10
c 0
b 0
f 0

2 Methods

Rating   Name   Duplication   Size   Complexity  
C LocalNet.__init__() 0 133 5
A LocalNet.call() 0 49 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
import deepreg.model.layer_util
9
from deepreg.model import layer
10
from deepreg.model.backbone.interface import Backbone
11
from deepreg.registry import REGISTRY
12
13
14
@REGISTRY.register_backbone(name="local")
class LocalNet(Backbone):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        use_additive_upsampling: bool = True,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E,
        then it is decoded gradually, j from level E to D.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, E] with E = max(extract_levels),
        and D = min(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of initial channels;
            level i uses num_channel_initial * 2**i channels.
        :param extract_levels: non-empty list of non-negative levels whose
            decoded feature maps are used to build the output.
        :param out_kernel_initializer: initializer to use for the kernels
            of the output (extraction) convolutions.
        :param out_activation: activation to use at end layer.
        :param use_additive_upsampling: whether use additive up-sampling.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises ValueError: if extract_levels is empty or contains a
            negative level.
        """
        # Validate early: an empty list would otherwise fail inside max() with
        # an uninformative message, and a negative level would silently build
        # a nonsensical graph (negative indexing into shapes/encodings).
        if not extract_levels:
            raise ValueError(
                f"extract_levels must be a non-empty list, got {extract_levels}"
            )
        if min(extract_levels) < 0:
            raise ValueError(
                f"extract_levels must be non-negative, got {extract_levels}"
            )

        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        self._extract_levels = extract_levels
        self._use_additive_upsampling = use_additive_upsampling
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # init layer variables
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        kernel_sizes = [
            7 if level == 0 else 3 for level in range(self._extract_max_level + 1)
        ]
        # Kernel size of the transposed convolutions; the output padding must
        # be computed for this same kernel size.
        deconv_kernel_size = 3

        # encoder: conv + residual conv, then 2x max-pool, for levels 0 to E-1
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]  # spatial shape at each level
        for i in range(self._extract_max_level):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                ]
            )
            downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            # pooling with "same" padding rounds odd sizes up
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )  # level E

        # decoder: deconv (+ optional additive upsampling) + skip + convs,
        # built in decoding order, i.e. level E-1 down to D
        self._upsample_deconvs = []
        self._resizes = []
        self._upsample_convs = []
        for level in range(
            self._extract_max_level - 1, self._extract_min_level - 1, -1
        ):  # level E-1 to D
            # output padding so the deconv output matches the encoder shape
            # at this level exactly (needed when a size was odd)
            padding = deepreg.model.layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[level + 1],
                output_shape=self._tensor_shapes[level],
                kernel_size=deconv_kernel_size,
                stride=2,
                padding="same",
            )
            upsample_deconv = layer.Deconv3dBlock(
                filters=num_channels[level],
                output_padding=padding,
                kernel_size=deconv_kernel_size,
                strides=2,
                padding="same",
            )
            upsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[level], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[level], kernel_size=3, padding="same"
                    ),
                ]
            )
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)
            if self._use_additive_upsampling:
                resize = layer.Resize3d(shape=self._tensor_shapes[level])
                self._resizes.append(resize)

        # one output head per extract level: conv to out_channels, then
        # resize to the full image size so all heads can be averaged
        self._extract_layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    layer.Resize3d(shape=image_size),
                ]
            )
            for _ in self._extract_levels
        ]

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build LocalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool, forwarded to the encoder/decoder blocks.
        :param mask: None or tf.Tensor, unused.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # down sample from level 0 to E
        # outputs used for decoding, encoded[i] corresponds -> level i
        # stored only 0 to E-1
        encoded = []
        h_in = inputs
        for level in range(self._extract_max_level):  # level 0 to E - 1
            skip = self._downsample_convs[level](inputs=h_in, training=training)
            h_in = self._downsample_pools[level](inputs=skip, training=training)
            encoded.append(skip)
        h_bottom = self._conv3d_block(
            inputs=h_in, training=training
        )  # level E of encoding/decoding

        # up sample from level E to D
        # decoded[k] corresponds -> level E - k
        decoded = [h_bottom]  # level E
        for idx, level in enumerate(
            range(self._extract_max_level - 1, self._extract_min_level - 1, -1)
        ):  # level E-1 to D
            h = self._upsample_deconvs[idx](inputs=h_bottom, training=training)
            if self._use_additive_upsampling:
                # Resize the coarser map to this level's shape, then sum its
                # two channel halves: the coarser level has twice as many
                # channels, so this matches num_channels[level].
                up_sampled = self._resizes[idx](inputs=h_bottom)
                up_sampled = tf.split(up_sampled, num_or_size_splits=2, axis=4)
                up_sampled = tf.add_n(up_sampled)
                h = h + up_sampled
            h = h + encoded[level]  # additive skip connection
            h_bottom = self._upsample_convs[idx](inputs=h, training=training)
            decoded.append(h_bottom)

        # output: average of the extraction heads, one per extract level
        output = tf.add_n(
            [
                self._extract_layers[idx](
                    inputs=decoded[self._extract_max_level - level]
                )
                for idx, level in enumerate(self._extract_levels)
            ]
        ) / len(self._extract_levels)

        return output