Passed
Pull Request — main (#656)
by Yunguan
02:35
created

deepreg.model.backbone.u_net.UNet.call()   B

Complexity

Conditions 5

Size

Total Lines 46
Code Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 23
dl 0
loc 46
rs 8.8613
c 0
b 0
f 0
cc 5
nop 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
4
from typing import Optional

import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util
from deepreg.model import layer
from deepreg.model.backbone.interface import Backbone
from deepreg.registry import REGISTRY
11
12
13
@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        pooling: bool = True,
        concat_skip: bool = False,
        control_points: Optional[tuple] = None,
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param control_points: specify the distance between control points (in voxels).
            When given, the network output is additionally passed through
            ``layer.ResizeCPTransform`` and ``layer.BSplines3DTransform``.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # level d (0 <= d <= depth) uses num_channel_initial * 2**d channels
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._concat_skip = concat_skip

        # encoder: per level a conv/residual block followed by a stride-2 downsample;
        # _tensor_shapes[d] records the spatial shape entering level d
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]
        for d in range(depth):
            downsample_conv = self._conv_res_block(filters=num_channels[d])
            if pooling:
                downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            else:
                downsample_pool = layer.Conv3dBlock(
                    filters=num_channels[d], kernel_size=3, strides=2, padding="same"
                )
            # stride 2 with "same" padding yields ceil(x / 2) along each spatial dim
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        # bottom, level depth
        self._bottom_conv3d = layer.Conv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )
        self._bottom_res3d = layer.ResidualConv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )

        # decoder: the transposed conv of level d maps _tensor_shapes[d + 1] back to
        # _tensor_shapes[d]; output_padding resolves the odd-size ambiguity of ceil(x / 2)
        self._upsample_deconvs = []
        self._upsample_convs = []
        for d in range(depth):
            padding = deepreg.model.layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[d + 1],
                output_shape=self._tensor_shapes[d],
                kernel_size=3,
                stride=2,
                padding="same",
            )
            upsample_deconv = layer.Deconv3dBlock(
                filters=num_channels[d],
                output_padding=padding,
                kernel_size=3,
                strides=2,
                padding="same",
            )
            upsample_conv = self._conv_res_block(filters=num_channels[d])
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)

        self._output_conv3d = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    filters=out_channels,
                    kernel_size=3,
                    strides=1,
                    padding="same",
                    kernel_initializer=out_kernel_initializer,
                    activation=out_activation,
                ),
                layer.Resize3d(shape=image_size),
            ]
        )

        # optional control-point resize + B-spline interpolation of the output;
        # both attributes are False when disabled (checked in call())
        if control_points is not None:
            self.resize = layer.ResizeCPTransform(control_points)
            self.interpolate = layer.BSplines3DTransform(control_points, image_size)
        else:
            self.resize = False
            self.interpolate = False

    @staticmethod
    def _conv_res_block(filters: int) -> tf.keras.Sequential:
        """
        Build a Conv3dBlock followed by a ResidualConv3dBlock (kernel_size=3, "same").

        :param filters: number of output channels of both blocks.
        :return: the two blocks wrapped in a tf.keras.Sequential.
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(filters=filters, kernel_size=3, padding="same"),
                layer.ResidualConv3dBlock(
                    filters=filters, kernel_size=3, padding="same"
                ),
            ]
        )

    def call(
        self, inputs: tf.Tensor, training: Optional[bool] = None, mask=None
    ) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = [batch, f_dim1, f_dim2, f_dim3, in_channels]
        :param training: training flag, forwarded unchanged to every sub-layer call.
        :param mask: not used by this backbone; accepted to match the Keras
            ``call(inputs, training, mask)`` signature.
        :return: shape = [batch, f_dim1, f_dim2, f_dim3, out_channels]
        """
        down_sampled = inputs

        # down sample, level 0 to depth - 1; keep pre-downsampling features as skips
        skips = []
        for d_var in range(self._depth):
            skip = self._downsample_convs[d_var](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[d_var](inputs=skip, training=training)
            skips.append(skip)

        # bottom, level depth
        up_sampled = self._bottom_res3d(
            inputs=self._bottom_conv3d(inputs=down_sampled, training=training),
            training=training,
        )

        # up sample, level depth - 1 to 0
        for d_var in reversed(range(self._depth)):
            up_sampled = self._upsample_deconvs[d_var](
                inputs=up_sampled, training=training
            )
            if self._concat_skip:
                # concatenate along the channel (last) axis
                up_sampled = tf.concat([up_sampled, skips[d_var]], axis=-1)
            else:
                up_sampled = up_sampled + skips[d_var]
            up_sampled = self._upsample_convs[d_var](
                inputs=up_sampled, training=training
            )

        # output
        output = self._output_conv3d(inputs=up_sampled, training=training)

        if self.resize:
            output = self.resize(output)
            output = self.interpolate(output)

        return output
204