Passed
Pull Request — main (#656)
by Yunguan
02:35
created

deepreg.model.backbone.u_net.UNet.__init__()   C

Complexity

Conditions 6

Size

Total Lines 130
Code Lines 89

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 89
dl 0
loc 130
rs 6.4266
c 0
b 0
f 0
cc 6
nop 12

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists: Replace Parameter with Method Call, Preserve Whole Object, and Introduce Parameter Object.

1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
4
from typing import Optional

import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util
from deepreg.model import layer
from deepreg.model.backbone.interface import Backbone
from deepreg.registry import REGISTRY
11
12
13
@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    The network is made of ``depth`` down-sampling levels, a bottom level and
    ``depth`` up-sampling levels; each up-sampling level is merged with the
    skip tensor of the matching down-sampling level.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”,
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        pooling: bool = True,
        concat_skip: bool = False,
        control_points: Optional[tuple] = None,
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param control_points: specify the distance between control points (in voxels).
                               If None, no control-point resizing or B-spline
                               interpolation is applied to the output.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # number of filters per level: doubled at each level, from 0 to depth
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._concat_skip = concat_skip

        # down-sampling path, levels 0 to depth - 1
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        # spatial shape at each level, reused below to configure the deconvs
        self._tensor_shapes = [tensor_shape]
        for d in range(depth):
            downsample_conv = self._conv_res_block(filters=num_channels[d])
            if pooling:
                downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            else:
                downsample_pool = layer.Conv3dBlock(
                    filters=num_channels[d], kernel_size=3, strides=2, padding="same"
                )
            # stride 2 with "same" padding halves each dimension, rounding up
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        # bottom, level depth
        self._bottom_conv3d = layer.Conv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )
        self._bottom_res3d = layer.ResidualConv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )

        # up-sampling path; index d maps level d + 1 back to level d
        self._upsample_deconvs = []
        self._upsample_convs = []
        for d in range(depth):
            # output padding is derived from the shapes recorded on the way
            # down, so that the deconv output can be merged with the skip
            padding = deepreg.model.layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[d + 1],
                output_shape=self._tensor_shapes[d],
                kernel_size=3,
                stride=2,
                padding="same",
            )
            upsample_deconv = layer.Deconv3dBlock(
                filters=num_channels[d],
                output_padding=padding,
                kernel_size=3,
                strides=2,
                padding="same",
            )
            upsample_conv = self._conv_res_block(filters=num_channels[d])
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)

        # output head: final convolution followed by a resize to image_size
        self._output_conv3d = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    filters=out_channels,
                    kernel_size=3,
                    strides=1,
                    padding="same",
                    kernel_initializer=out_kernel_initializer,
                    activation=out_activation,
                ),
                layer.Resize3d(shape=image_size),
            ]
        )

        # optional control-point post-processing; False (not None) marks it as
        # disabled, which is what ``call`` checks via truthiness
        if control_points is not None:
            self.resize = layer.ResizeCPTransform(control_points)
            self.interpolate = layer.BSplines3DTransform(control_points, image_size)
        else:
            self.resize = False
            self.interpolate = False

    @staticmethod
    def _conv_res_block(filters: int) -> tf.keras.Sequential:
        """
        Build a Conv3dBlock followed by a ResidualConv3dBlock.

        Both blocks use kernel size 3 and "same" padding. The same pair is
        used at every down-sampling and up-sampling level.

        :param filters: number of filters for both blocks.
        :return: a sequential model applying the two blocks in order.
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(filters=filters, kernel_size=3, padding="same"),
                layer.ResidualConv3dBlock(
                    filters=filters, kernel_size=3, padding="same"
                ),
            ]
        )

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = [batch, f_dim1, f_dim2, f_dim3, in_channels]
        :param training: forwarded as the ``training`` argument to the
            down-sampling, bottom, up-sampling and output layers.
        :param mask: not used by this backbone.
        :return: shape = [batch, f_dim1, f_dim2, f_dim3, out_channels];
            if control points were given at construction, this is the result
            of additionally applying the resize and interpolate layers.
        """

        down_sampled = inputs

        # down sample
        skips = []
        for d_var in range(self._depth):  # level 0 to D-1
            skip = self._downsample_convs[d_var](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[d_var](inputs=skip, training=training)
            skips.append(skip)

        # bottom, level D
        up_sampled = self._bottom_res3d(
            inputs=self._bottom_conv3d(inputs=down_sampled, training=training),
            training=training,
        )

        # up sample, level D-1 to 0
        for d_var in range(self._depth - 1, -1, -1):
            up_sampled = self._upsample_deconvs[d_var](
                inputs=up_sampled, training=training
            )
            # merge with the skip tensor of the same level
            if self._concat_skip:
                up_sampled = tf.concat([up_sampled, skips[d_var]], axis=4)
            else:
                up_sampled = up_sampled + skips[d_var]
            up_sampled = self._upsample_convs[d_var](
                inputs=up_sampled, training=training
            )

        # output
        output = self._output_conv3d(inputs=up_sampled, training=training)

        # self.resize is False unless control points were configured
        if self.resize:
            output = self.resize(output)
            output = self.interpolate(output)

        return output
204