Passed
Pull Request — main (#656) by Yunguan, created 03:15

deepreg.model.backbone.u_net.UNet.__init__() (rated B)

Complexity: Conditions 4
Size: Total Lines 116, Code Lines 80
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0
Metric                           Value
eloc (effective lines of code)   80
dl (duplicated lines)            0
loc (lines of code)              116
rs                               7.6545
c                                0
b                                0
f                                0
cc (cyclomatic complexity)       4
nop (number of parameters)       11

How to fix: Long Method, Many Parameters

Long Method

Small methods make your code easier to understand, particularly when combined with a good name; and when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment as a starting point for naming it.

The most commonly applied refactoring is Extract Method: move each logically distinct chunk of the long method into its own small, well-named method, as sketched below.
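A minimal, self-contained sketch of Extract Method in Python; the class and method names here are illustrative and are not part of DeepReg:

# Before: one long constructor doing several jobs inline.
class Pipeline:
    def __init__(self, depth: int):
        # set up the encoder ...
        self.encoder = [f"conv_block_{d}" for d in range(depth)]
        # set up the decoder ...
        self.decoder = [f"deconv_block_{d}" for d in range(depth - 1, -1, -1)]


# After: each commented chunk becomes a small, well-named method.
class RefactoredPipeline:
    def __init__(self, depth: int):
        self.encoder = self._build_encoder(depth)
        self.decoder = self._build_decoder(depth)

    def _build_encoder(self, depth: int) -> list:
        return [f"conv_block_{d}" for d in range(depth)]

    def _build_decoder(self, depth: int) -> list:
        return [f"deconv_block_{d}" for d in range(depth - 1, -1, -1)]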

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as the method comes to need more, or different, data.

There are several approaches to avoiding long parameter lists, most commonly Introduce Parameter Object (group related parameters into one cohesive object) and Preserve Whole Object (pass the object the values come from instead of unpacking it). The first is sketched below.
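A minimal sketch of Introduce Parameter Object, assuming a hypothetical UNetConfig dataclass that is not part of DeepReg's API; the field names simply mirror the constructor parameters shown in the listing below:

from dataclasses import dataclass


@dataclass
class UNetConfig:
    image_size: tuple
    out_channels: int
    num_channel_initial: int
    depth: int
    out_kernel_initializer: str
    out_activation: str
    pooling: bool = True
    concat_skip: bool = False


def build_unet(config: UNetConfig):
    # The constructor receives one cohesive object instead of ~10 scalars.
    print(f"building a UNet of depth {config.depth} on {config.image_size}")


build_unet(
    UNetConfig(
        image_size=(64, 64, 64),
        out_channels=3,
        num_channel_initial=4,
        depth=3,
        out_kernel_initializer="glorot_uniform",
        out_activation="sigmoid",
    )
)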

# coding=utf-8
# (issue introduced by this pull request: Missing module docstring)


import tensorflow as tf
import tensorflow.keras.layers as tfkl

import deepreg.model.layer_util
from deepreg.model import layer
from deepreg.model.backbone.interface import Backbone
from deepreg.registry import REGISTRY


@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”,
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # init layer variables
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._concat_skip = concat_skip
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]
        for d in range(depth):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                ]
            )
            if pooling:
                downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            else:
                downsample_pool = layer.Conv3dBlock(
                    filters=num_channels[d], kernel_size=3, strides=2, padding="same"
                )
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)
        self._bottom_conv3d = layer.Conv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )
        self._bottom_res3d = layer.ResidualConv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )
        self._upsample_deconvs = []
        self._upsample_convs = []
        for d in range(depth):
            padding = deepreg.model.layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[d + 1],
                output_shape=self._tensor_shapes[d],
                kernel_size=3,
                stride=2,
                padding="same",
            )
            upsample_deconv = layer.Deconv3dBlock(
                filters=num_channels[d],
                output_padding=padding,
                kernel_size=3,
                strides=2,
                padding="same",
            )
            upsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                ]
            )
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)
        self._output_conv3d = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    filters=out_channels,
                    kernel_size=3,
                    strides=1,
                    padding="same",
                    kernel_initializer=out_kernel_initializer,
                    activation=out_activation,
                ),
                layer.Resize3d(shape=image_size),
            ]
        )

    # (issue introduced by this pull request:
    # "mask, training" missing in parameter type documentation)
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = [batch, f_dim1, f_dim2, f_dim3, in_channels]
        :param training:
        :param mask:
        :return: shape = [batch, f_dim1, f_dim2, f_dim3, out_channels]
        """

        down_sampled = inputs

        # down sample
        skips = []
        for d_var in range(self._depth):  # level 0 to D-1
            skip = self._downsample_convs[d_var](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[d_var](inputs=skip, training=training)
            skips.append(skip)

        # bottom, level D
        up_sampled = self._bottom_res3d(
            inputs=self._bottom_conv3d(inputs=down_sampled, training=training),
            training=training,
        )

        # up sample, level D-1 to 0
        for d_var in range(self._depth - 1, -1, -1):
            up_sampled = self._upsample_deconvs[d_var](
                inputs=up_sampled, training=training
            )
            if self._concat_skip:
                up_sampled = tf.concat([up_sampled, skips[d_var]], axis=4)
            else:
                up_sampled = up_sampled + skips[d_var]
            up_sampled = self._upsample_convs[d_var](
                inputs=up_sampled, training=training
            )

        # output
        output = self._output_conv3d(inputs=up_sampled, training=training)

        return output
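For reference, a minimal usage sketch of the class above. The argument values and the two-channel input tensor are illustrative assumptions, not taken from this pull request:

import tensorflow as tf

from deepreg.model.backbone.u_net import UNet

unet = UNet(
    image_size=(16, 16, 16),
    out_channels=3,
    num_channel_initial=2,
    depth=2,
    out_kernel_initializer="glorot_uniform",
    out_activation="sigmoid",
    pooling=True,
    concat_skip=False,
)
# Input shape is [batch, dim1, dim2, dim3, in_channels]; the output keeps the
# spatial dims (image_size) and has out_channels channels, here (1, 16, 16, 16, 3).
output = unet(tf.ones((1, 16, 16, 16, 2)))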