Passed
Pull Request — main (#656)
by Yunguan
02:48
created

deepreg.model.backbone.u_net   A

Complexity

Total Complexity 8

Size/Duplication

Total Lines 187
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 8
eloc 109
dl 0
loc 187
rs 10
c 0
b 0
f 0

2 Methods

Rating   Name   Duplication   Size   Complexity  
B UNet.__init__() 0 116 4
A UNet.call() 0 42 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
4
import tensorflow as tf
5
import tensorflow.keras.layers as tfkl
6
7
import deepreg.model.layer_util
8
from deepreg.model import layer
9
from deepreg.model.backbone.interface import Backbone
10
from deepreg.registry import REGISTRY
11
12
13
@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    The network has ``depth`` down-sampling levels, one bottom level and
    ``depth`` up-sampling levels that consume the skipped tensors of the
    corresponding down-sampling levels.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      "U-net: Convolutional networks for biomedical image segmentation",
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234-241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels,
            level d uses num_channel_initial * 2**d filters
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # init layer variables
        # number of filters doubles at each level, from level 0 to level depth
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._concat_skip = concat_skip

        # down-sampling path, level 0 to depth - 1
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        # spatial shape at each level, index 0 is the input image size
        self._tensor_shapes = [tensor_shape]
        for d in range(depth):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                ]
            )
            if pooling:
                downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            else:
                downsample_pool = layer.Conv3dBlock(
                    filters=num_channels[d], kernel_size=3, strides=2, padding="same"
                )
            # both options use stride 2 with "same" padding,
            # so each spatial dim becomes ceil(x / 2)
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        # bottom, level depth
        self._bottom_conv3d = layer.Conv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )
        self._bottom_res3d = layer.ResidualConv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )

        # up-sampling path, layers at index d bring level d + 1 back to level d
        self._upsample_deconvs = []
        self._upsample_convs = []
        for d in range(depth):
            # output padding requested for going from the shape of level d + 1
            # to the shape of level d, as odd sizes are rounded up when halved
            padding = deepreg.model.layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[d + 1],
                output_shape=self._tensor_shapes[d],
                kernel_size=3,
                stride=2,
                padding="same",
            )
            upsample_deconv = layer.Deconv3dBlock(
                filters=num_channels[d],
                output_padding=padding,
                kernel_size=3,
                strides=2,
                padding="same",
            )
            upsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                ]
            )
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)

        # output layer, a convolution followed by a resize to image_size
        self._output_conv3d = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    filters=out_channels,
                    kernel_size=3,
                    strides=1,
                    padding="same",
                    kernel_initializer=out_kernel_initializer,
                    activation=out_activation,
                ),
                layer.Resize3d(shape=image_size),
            ]
        )

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = [batch, f_dim1, f_dim2, f_dim3, in_channels]
        :param training: forwarded unchanged as the ``training`` argument
            of every sub-layer called in this method.
        :param mask: not used in this method, presumably kept to match
            the Keras ``call`` signature.
        :return: shape = [batch, f_dim1, f_dim2, f_dim3, out_channels]
        """

        down_sampled = inputs

        # down sample
        # the tensor before pooling is kept at each level for the skip connection
        skips = []
        for d_var in range(self._depth):  # level 0 to D-1
            skip = self._downsample_convs[d_var](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[d_var](inputs=skip, training=training)
            skips.append(skip)

        # bottom, level D
        up_sampled = self._bottom_res3d(
            inputs=self._bottom_conv3d(inputs=down_sampled, training=training),
            training=training,
        )

        # up sample, level D-1 to 0
        for d_var in range(self._depth - 1, -1, -1):
            up_sampled = self._upsample_deconvs[d_var](
                inputs=up_sampled, training=training
            )
            # merge with the skipped tensor of the same level,
            # axis 4 is the channel axis of the 5D tensors
            if self._concat_skip:
                up_sampled = tf.concat([up_sampled, skips[d_var]], axis=4)
            else:
                up_sampled = up_sampled + skips[d_var]
            up_sampled = self._upsample_convs[d_var](
                inputs=up_sampled, training=training
            )

        # output
        output = self._output_conv3d(inputs=up_sampled, training=training)

        return output
187