Passed
Pull Request — main (#656)
by Yunguan
02:46
created

deepreg.model.backbone.u_net.UNet.__init__()   C

Complexity

Conditions 6

Size

Total Lines 114
Code Lines 77

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 77
dl 0
loc 114
rs 6.8193
c 0
b 0
f 0
cc 6
nop 12

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a coherent part of the method body into a new, well-named method.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

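There are several approaches to avoid long parameter lists, for example grouping related parameters into a single parameter object, or passing a whole object instead of several of its individual fields.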

1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
4
import tensorflow as tf
5
import tensorflow.keras.layers as tfkl
6
7
from deepreg.model import layer
8
from deepreg.model.backbone.interface import Backbone
9
from deepreg.registry import REGISTRY
10
11
12
@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    The encoder applies ``depth`` down-sampling steps (stride 2). A bottom
    block processes the coarsest level. The decoder then up-samples level by
    level, merging each level with the matching encoder output; the merge is
    a concatenation or an addition, depending on ``concat_skip``.

    When ``control_points`` is given, the output is additionally passed
    through ``layer.ResizeCPTransform`` and ``layer.BSplines3DTransform``.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”,
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        pooling: bool = True,
        concat_skip: bool = False,
        control_points: (tuple, None) = None,
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels; the channel
            count doubles at each deeper level
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param control_points: specify the distance between control points (in voxels).
            If None, no control-point resize / interpolation is applied.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # channels per level, from level 0 (input) to level `depth` (bottom)
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._downsample_convs = []
        self._downsample_pools = []
        self._upsample_blocks = []

        # Spatial shape at each level, so the decoder can recover the exact
        # encoder shapes. A stride-2 step with "same" padding maps x to
        # ceil(x / 2), computed here as (x + 1) // 2.
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]

        # encoder, level 0 to depth - 1
        for d in range(depth):
            downsample_conv = tf.keras.Sequential(
                [
                    layer.Conv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                    layer.ResidualConv3dBlock(
                        filters=num_channels[d], kernel_size=3, padding="same"
                    ),
                ]
            )
            if pooling:
                downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            else:
                downsample_pool = layer.Conv3dBlock(
                    filters=num_channels[d], kernel_size=3, strides=2, padding="same"
                )
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        # bottom, level depth
        self._bottom_conv3d = layer.Conv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )
        self._bottom_res3d = layer.ResidualConv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )

        # Decoder blocks, indexed by the level they up-sample *to*.
        # output_padding is chosen so the transposed convolution from level
        # d + 1 reproduces the exact (possibly odd) shape recorded for level d.
        # The previous version also appended the bottom shape to
        # self._tensor_shapes here on every iteration. That only padded the
        # table with meaningless duplicates, so it was removed.
        for d in range(depth):
            padding = layer.deconv_output_padding(
                input_shape=self._tensor_shapes[d + 1],
                output_shape=self._tensor_shapes[d],
                kernel_size=3,
                stride=2,
                padding="same",
            )
            self._upsample_blocks.append(
                layer.UpSampleResnetBlock(
                    filters=num_channels[d], output_padding=padding, concat=concat_skip
                )
            )

        # output projection to out_channels, then resize to the input image size
        self._output_conv3d = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    filters=out_channels,
                    kernel_size=3,
                    strides=1,
                    padding="same",
                    kernel_initializer=out_kernel_initializer,
                    activation=out_activation,
                ),
                layer.Resize3d(shape=image_size),
            ]
        )

        # Optional control-point post-processing. Without control points both
        # attributes stay False, as before; call() tests self.resize's truthiness.
        if control_points is not None:
            self.resize = layer.ResizeCPTransform(control_points)
            self.interpolate = layer.BSplines3DTransform(control_points, image_size)
        else:
            self.resize = False
            self.interpolate = False

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = [batch, f_dim1, f_dim2, f_dim3, in_channels]
        :param training: training flag, forwarded to the down-sampling
            convolutions, the bottom blocks and the up-sampling blocks
        :param mask: unused; accepted to match the Keras ``call`` signature
        :return: shape = [batch, f_dim1, f_dim2, f_dim3, out_channels]
        """
        down_sampled = inputs

        # down sample, level 0 to D-1; keep each level's output as a skip
        skips = []
        for d_var in range(self._depth):
            skip = self._downsample_convs[d_var](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[d_var](inputs=skip)
            skips.append(skip)

        # bottom, level D
        up_sampled = self._bottom_res3d(
            inputs=self._bottom_conv3d(inputs=down_sampled, training=training),
            training=training,
        )

        # up sample, level D-1 to 0, merging with the matching skip tensor
        for d_var in reversed(range(self._depth)):
            up_sampled = self._upsample_blocks[d_var](
                inputs=[up_sampled, skips[d_var]], training=training
            )

        # output
        output = self._output_conv3d(inputs=up_sampled)

        if self.resize:
            output = self.resize(output)
            output = self.interpolate(output)

        return output
180