Passed: Pull Request on main (#656) by Yunguan, created 02:46

deepreg.model.backbone.u_net.UNet.call()  (rating: A)

Complexity
  Conditions: 4

Size
  Total Lines: 39
  Code Lines: 18

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0
Metric                            Value
eloc (effective lines of code)    18
dl (duplicated lines)             0
loc (lines of code)               39
rs                                9.5
c                                 0
b                                 0
f                                 0
cc (cyclomatic complexity)        4
nop (number of parameters)        4

The cyclomatic complexity of 4 matches "Conditions: 4" above: call() has three decision points (the two for loops and the if self.resize branch) plus the base path.
[Issue introduced, 0 ignored: Missing module docstring]

# coding=utf-8


import tensorflow as tf
import tensorflow.keras.layers as tfkl

from deepreg.model import layer
from deepreg.model.backbone.interface import Backbone
from deepreg.registry import REGISTRY


@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”,
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        pooling: bool = True,
        concat_skip: bool = False,
        control_points: (tuple, None) = None,
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param control_points: specify the distance between control points (in voxels).
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # init layer variables
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
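        # For illustration (hypothetical values, not part of this PR): with
        # num_channel_initial=4 and depth=3 this gives num_channels = [4, 8, 16, 32],
        # i.e. the filter count doubles at each level down to the bottom of the U.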

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._downsample_convs = []
        self._downsample_pools = []
        self._upsample_blocks = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]
        for d in range(depth):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                ]
            )
            if pooling:
                downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            else:
                downsample_pool = layer.Conv3dBlock(
                    filters=num_channels[d], kernel_size=3, strides=2, padding="same"
                )
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
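            # Explanatory comment (not in the original source): (x + 1) // 2 is
            # ceil(x / 2), matching the "same"-padded stride-2 pooling/conv above;
            # e.g. an axis of 9 voxels becomes 5, then 3, then 2 at deeper levels.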
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)
        self._bottom_conv3d = layer.Conv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )
        self._bottom_res3d = layer.ResidualConv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )
        for d in range(depth):
            padding = layer.deconv_output_padding(
                input_shape=self._tensor_shapes[d + 1],
                output_shape=self._tensor_shapes[d],
                kernel_size=3,
                stride=2,
                padding="same",
            )
            upsample_block = layer.UpSampleResnetBlock(
                filters=num_channels[d], output_padding=padding, concat=concat_skip
            )
            self._upsample_blocks.append(upsample_block)
            self._tensor_shapes.append(tensor_shape)
        self._output_conv3d = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    filters=out_channels,
                    kernel_size=3,
                    strides=1,
                    padding="same",
                    kernel_initializer=out_kernel_initializer,
                    activation=out_activation,
                ),
                layer.Resize3d(shape=image_size),
            ]
        )
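        # Explanatory comment (not in the original source): the output head is a
        # stride-1 Conv3D producing out_channels features, followed by
        # Resize3d(shape=image_size), which guarantees the output spatial shape
        # matches the input image size.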

        self.resize = (
            layer.ResizeCPTransform(control_points)
            if control_points is not None
            else False
        )
        self.interpolate = (
            layer.BSplines3DTransform(control_points, image_size)
            if control_points is not None
            else False
        )
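        # Explanatory comment (not in the original source): when control_points is
        # set, call() first resamples the dense output onto the coarser
        # control-point grid (ResizeCPTransform) and then interpolates it back to
        # image_size with B-splines (BSplines3DTransform).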

    [Issue introduced, 0 ignored: "mask, training" missing in parameter type documentation]
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = [batch, f_dim1, f_dim2, f_dim3, in_channels]
        :param training:
        :param mask:
        :return: shape = [batch, f_dim1, f_dim2, f_dim3, out_channels]
        """

        down_sampled = inputs

        # down sample
        skips = []
        for d_var in range(self._depth):  # level 0 to D-1
            skip = self._downsample_convs[d_var](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[d_var](inputs=skip)
            skips.append(skip)
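            # Explanatory comment (not in the original source): skips[d_var] keeps
            # the level-d feature map before pooling so the matching upsample block
            # can reuse it, by concatenation or addition depending on concat_skip.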

        # bottom, level D
        up_sampled = self._bottom_res3d(
            inputs=self._bottom_conv3d(inputs=down_sampled, training=training),
            training=training,
        )

        # up sample, level D-1 to 0
        for d_var in range(self._depth - 1, -1, -1):
            up_sampled = self._upsample_blocks[d_var](
                inputs=[up_sampled, skips[d_var]], training=training
            )

        # output
        output = self._output_conv3d(inputs=up_sampled)

        if self.resize:
            output = self.resize(output)
            output = self.interpolate(output)

        return output
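
For context, here is a minimal usage sketch (hypothetical hyperparameter values, not taken from this PR) of how the registered backbone might be instantiated and called on a dummy image batch:

import tensorflow as tf

from deepreg.model.backbone.u_net import UNet

# Illustrative values only; real configurations come from DeepReg config files.
unet = UNet(
    image_size=(16, 16, 16),
    out_channels=3,
    num_channel_initial=4,
    depth=2,
    out_kernel_initializer="glorot_uniform",
    out_activation="sigmoid",
    pooling=True,
    concat_skip=False,
)

images = tf.random.uniform(shape=(1, 16, 16, 16, 1))  # [batch, f_dim1, f_dim2, f_dim3, in_channels]
outputs = unet(images)  # shape = (1, 16, 16, 16, 3), as per the call() docstring
print(outputs.shape)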