Passed
Pull Request — main (#656)
by Yunguan
02:46
created

deepreg.model.backbone.local_net   A

Complexity

Total Complexity 10

Size/Duplication

Total Lines 213
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 10
eloc 120
dl 0
loc 213
rs 10
c 0
b 0
f 0

2 Methods

Rating   Name   Duplication   Size   Complexity  
C LocalNet.__init__() 0 126 6
A LocalNet.call() 0 53 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Optional
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
8
from deepreg.model import layer
9
from deepreg.model.backbone.interface import Backbone
10
from deepreg.registry import REGISTRY
11
12
13
@REGISTRY.register_backbone(name="local")
class LocalNet(Backbone):
    """
    Build LocalNet for image registration.

    The network is an encoder-decoder: the input is down-sampled level by
    level, then up-sampled again, and the decoded feature maps at the
    requested levels are each converted into an output of the input image
    size; these outputs are averaged to form the final output.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        control_points: Optional[tuple] = None,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E,
        then it is decoded gradually, j from level E to D.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, E] with E = max(extract_levels),
        and D = min(extract_levels).

        :param image_size: such as (dim1, dim2, dim3)
        :param out_channels: number of channels for the extractions
        :param num_channel_initial: number of initial channels,
            the number of channels at level i is num_channel_initial * 2**i.
        :param extract_levels: non-empty list of non-negative level indices
            from which the outputs are extracted.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param control_points: specify the distance between control points
            (in voxels), if None the output is returned without the
            control-point resizing and B-spline interpolation.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises ValueError: if extract_levels is empty
            or contains a negative level.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # validate levels explicitly, otherwise an empty list fails inside
        # max() with an unrelated message and a negative level silently
        # wraps around when indexing the per-level lists below
        if not extract_levels:
            raise ValueError("extract_levels must contain at least one level.")
        if min(extract_levels) < 0:
            raise ValueError(
                f"extract_levels must be non-negative, got {extract_levels}."
            )

        # save parameters
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self._extract_min_level = min(self._extract_levels)  # D

        # init layer variables
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        # a larger kernel is used at the first level only
        kernel_sizes = [
            7 if level == 0 else 3 for level in range(self._extract_max_level + 1)
        ]

        # encoder, level 0 to E-1
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        # spatial shape at each level, needed to configure the decoder
        self._tensor_shapes = [tensor_shape]
        for i in range(self._extract_max_level):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[i],
                        kernel_size=kernel_sizes[i],
                        padding="same",
                    ),
                ]
            )
            downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            # ceil division by 2, i.e. stride 2 pooling with "same" padding
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        self._conv3d_block = layer.Conv3dBlock(
            filters=num_channels[-1], kernel_size=3, padding="same"
        )  # level E

        # decoder, blocks are stored in decoding order
        self._upsample_blocks = []
        for level in range(
            self._extract_max_level - 1, self._extract_min_level - 1, -1
        ):  # level E-1 down to D
            padding = layer.deconv_output_padding(
                input_shape=self._tensor_shapes[level + 1],
                output_shape=self._tensor_shapes[level],
                kernel_size=kernel_sizes[level],
                stride=2,
                padding="same",
            )
            upsample_block = layer.LocalNetUpSampleResnetBlock(
                num_channels[level],
                output_padding=padding,
                output_shape=self._tensor_shapes[level],
            )
            self._upsample_blocks.append(upsample_block)

        # one extraction head per requested level,
        # each head resizes its output to the input image size
        self._extract_layers = [
            tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters=out_channels,
                        kernel_size=3,
                        strides=1,
                        padding="same",
                        kernel_initializer=out_kernel_initializer,
                        activation=out_activation,
                    ),
                    layer.Resize3d(shape=image_size),
                ]
            )
            for _ in self._extract_levels
        ]

        # optional post-processing, False means disabled
        self.resize = (
            layer.ResizeCPTransform(control_points)
            if control_points is not None
            else False
        )
        self.interpolate = (
            layer.BSplines3DTransform(control_points, image_size)
            if control_points is not None
            else False
        )

    def call(
        self,
        inputs: tf.Tensor,
        training: Optional[bool] = None,
        mask: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Build LocalNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool, forwarded to the conv and up-sample blocks.
        :param mask: None or tf.Tensor, not used by this backbone.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # down sample from level 0 to E
        # encoded[i] is the pre-pooling output of level i, used as skip input
        # for decoding, only levels 0 to E-1 are stored
        encoded = []

        h_in = inputs
        for level in range(self._extract_max_level):  # level 0 to E - 1
            skip = self._downsample_convs[level](inputs=h_in, training=training)
            h_in = self._downsample_pools[level](inputs=skip)
            encoded.append(skip)
        h_bottom = self._conv3d_block(
            inputs=h_in, training=training
        )  # level E of encoding/decoding

        # up sample from level E to D
        # decoded[k] corresponds to level E - k
        decoded = [h_bottom]  # level E
        for idx, level in enumerate(
            range(self._extract_max_level - 1, self._extract_min_level - 1, -1)
        ):  # level E-1 to D
            h_bottom = self._upsample_blocks[idx](
                inputs=[h_bottom, encoded[level]], training=training
            )
            decoded.append(h_bottom)

        # output
        # stack the extraction of each level on a new trailing axis
        # and average over it
        output = tf.reduce_mean(
            tf.stack(
                [
                    self._extract_layers[idx](
                        inputs=decoded[self._extract_max_level - level]
                    )
                    for idx, level in enumerate(self._extract_levels)
                ],
                axis=5,
            ),
            axis=5,
        )

        # only applied when control_points was provided at construction
        if self.resize:
            output = self.resize(output)
            output = self.interpolate(output)

        return output
213