Passed
Pull Request — main (#656)
by Yunguan
02:46
created

deepreg.model.backbone.u_net.UNet.__init__()   C

Complexity

Conditions 6

Size

Total Lines 130
Code Lines 89

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 89
dl 0
loc 130
rs 6.4266
c 0
b 0
f 0
cc 6
nop 12

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
4
import tensorflow as tf
5
import tensorflow.keras.layers as tfkl
6
7
from deepreg.model import layer
8
from deepreg.model.backbone.interface import Backbone
9
from deepreg.registry import REGISTRY
10
11
12
@REGISTRY.register_backbone(name="unet")
13
class UNet(Backbone):
14
    """
15
    Class that implements an adapted 3D UNet.
16
17
    Reference:
18
19
    - O. Ronneberger, P. Fischer, and T. Brox,
20
      “U-net: Convolutional networks for biomedical image segmentation,”,
21
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
22
      https://arxiv.org/abs/1505.04597
23
    """
24
25
    def __init__(
26
        self,
27
        image_size: tuple,
28
        out_channels: int,
29
        num_channel_initial: int,
30
        depth: int,
31
        out_kernel_initializer: str,
32
        out_activation: str,
33
        pooling: bool = True,
34
        concat_skip: bool = False,
35
        control_points: (tuple, None) = None,
36
        name: str = "Unet",
37
        **kwargs,
38
    ):
39
        """
40
        Initialise UNet.
41
42
        :param image_size: (dim1, dim2, dim3), dims of input image.
43
        :param out_channels: number of channels for the output
44
        :param num_channel_initial: number of initial channels
45
        :param depth: input is at level 0, bottom is at level depth
46
        :param out_kernel_initializer: kernel initializer for the last layer
47
        :param out_activation: activation at the last layer
48
        :param pooling: for downsampling, use non-parameterized
49
                        pooling if true, otherwise use conv3d
50
        :param concat_skip: when upsampling, concatenate skipped
51
                            tensor if true, otherwise use addition
52
        :param control_points: specify the distance between control points (in voxels).
53
        :param name: name of the backbone.
54
        :param kwargs: additional arguments.
55
        """
56
        super().__init__(
57
            image_size=image_size,
58
            out_channels=out_channels,
59
            num_channel_initial=num_channel_initial,
60
            out_kernel_initializer=out_kernel_initializer,
61
            out_activation=out_activation,
62
            name=name,
63
            **kwargs,
64
        )
65
66
        # init layer variables
67
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
68
69
        self._num_channel_initial = num_channel_initial
70
        self._depth = depth
71
        self._concat_skip = concat_skip
72
        self._downsample_convs = []
73
        self._downsample_pools = []
74
        tensor_shape = image_size
75
        self._tensor_shapes = [tensor_shape]
76
        for d in range(depth):
77
            downsample_conv = tf.keras.Sequential(
78
                [
79
                    layer.Conv3dBlock(
80
                        filters=num_channels[d], kernel_size=3, padding="same"
81
                    ),
82
                    layer.ResidualConv3dBlock(
83
                        filters=num_channels[d], kernel_size=3, padding="same"
84
                    ),
85
                ]
86
            )
87
            if pooling:
88
                downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
89
            else:
90
                downsample_pool = layer.Conv3dBlock(
91
                    filters=num_channels[d], kernel_size=3, strides=2, padding="same"
92
                )
93
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
94
            self._downsample_convs.append(downsample_conv)
95
            self._downsample_pools.append(downsample_pool)
96
            self._tensor_shapes.append(tensor_shape)
97
        self._bottom_conv3d = layer.Conv3dBlock(
98
            filters=num_channels[depth], kernel_size=3, padding="same"
99
        )
100
        self._bottom_res3d = layer.ResidualConv3dBlock(
101
            filters=num_channels[depth], kernel_size=3, padding="same"
102
        )
103
        self._upsample_deconvs = []
104
        self._upsample_convs = []
105
        for d in range(depth):
106
            padding = layer.deconv_output_padding(
107
                input_shape=self._tensor_shapes[d + 1],
108
                output_shape=self._tensor_shapes[d],
109
                kernel_size=3,
110
                stride=2,
111
                padding="same",
112
            )
113
            upsample_deconv = layer.Deconv3dBlock(
114
                filters=num_channels[d],
115
                output_padding=padding,
116
                kernel_size=3,
117
                strides=2,
118
                padding="same",
119
            )
120
            upsample_conv = tf.keras.Sequential(
121
                [
122
                    layer.Conv3dBlock(
123
                        filters=num_channels[d], kernel_size=3, padding="same"
124
                    ),
125
                    layer.ResidualConv3dBlock(
126
                        filters=num_channels[d], kernel_size=3, padding="same"
127
                    ),
128
                ]
129
            )
130
            self._upsample_deconvs.append(upsample_deconv)
131
            self._upsample_convs.append(upsample_conv)
132
        self._output_conv3d = tf.keras.Sequential(
133
            [
134
                tfkl.Conv3D(
135
                    filters=out_channels,
136
                    kernel_size=3,
137
                    strides=1,
138
                    padding="same",
139
                    kernel_initializer=out_kernel_initializer,
140
                    activation=out_activation,
141
                ),
142
                layer.Resize3d(shape=image_size),
143
            ]
144
        )
145
146
        self.resize = (
147
            layer.ResizeCPTransform(control_points)
148
            if control_points is not None
149
            else False
150
        )
151
        self.interpolate = (
152
            layer.BSplines3DTransform(control_points, image_size)
153
            if control_points is not None
154
            else False
155
        )
156
157
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
0 ignored issues
show
introduced by
"mask, training" missing in parameter type documentation
Loading history...
158
        """
159
        Builds graph based on built layers.
160
161
        :param inputs: shape = [batch, f_dim1, f_dim2, f_dim3, in_channels]
162
        :param training:
163
        :param mask:
164
        :return: shape = [batch, f_dim1, f_dim2, f_dim3, out_channels]
165
        """
166
167
        down_sampled = inputs
168
169
        # down sample
170
        skips = []
171
        for d_var in range(self._depth):  # level 0 to D-1
172
            skip = self._downsample_convs[d_var](inputs=down_sampled, training=training)
173
            down_sampled = self._downsample_pools[d_var](inputs=skip, training=training)
174
            skips.append(skip)
175
176
        # bottom, level D
177
        up_sampled = self._bottom_res3d(
178
            inputs=self._bottom_conv3d(inputs=down_sampled, training=training),
179
            training=training,
180
        )
181
182
        # up sample, level D-1 to 0
183
        for d_var in range(self._depth - 1, -1, -1):
184
            up_sampled = self._upsample_deconvs[d_var](
185
                inputs=up_sampled, training=training
186
            )
187
            if self._concat_skip:
188
                up_sampled = tf.concat([up_sampled, skips[d_var]], axis=4)
189
            else:
190
                up_sampled = up_sampled + skips[d_var]
191
            up_sampled = self._upsample_convs[d_var](
192
                inputs=up_sampled, training=training
193
            )
194
195
        # output
196
        output = self._output_conv3d(inputs=up_sampled, training=training)
197
198
        if self.resize:
199
            output = self.resize(output)
200
            output = self.interpolate(output)
201
202
        return output
203