Passed
Pull Request — main (#656)
by Yunguan
11:15 queued 50s
created

AbstractUNet.__init__()   A

Complexity

Conditions 1

Size

Total Lines 55
Code Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 28
dl 0
loc 55
rs 9.208
c 0
b 0
f 0
cc 1
nop 10

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method (move a coherent part of the body into a new, well-named method).

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists: Replace Parameter with Method Call, Introduce Parameter Object, and Preserve Whole Object.

1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
4
from abc import abstractmethod
5
from typing import List, Tuple, Union
6
7
import tensorflow as tf
8
import tensorflow.keras.layers as tfkl
9
from tensorflow.python.keras.utils import conv_utils
10
11
from deepreg.model import layer, layer_util
12
from deepreg.model.backbone.interface import Backbone
13
from deepreg.registry import REGISTRY
14
15
16
class AbstractUNet(Backbone):
    """
    An interface for u-net style backbones.

    Subclasses implement the ``build_*_block`` hooks. ``build_layers`` uses
    those hooks to create the layers that ``call`` runs, so ``build_layers``
    has to be invoked (e.g. from the subclass constructor) before the model
    is called; until then the layer attributes are ``None``.
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: List[int],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        name: str = "AbstractUNet",
        **kwargs,
    ):
        """
        Init.

        Image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param depth: d = 0 to depth, both side included
        :param extract_levels: from which depths the output will be built,
            every level must lie within [0, depth].
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        # A negative level would make the decoder loops (which run from
        # depth - 1 down to min(extract_levels)) go past level 0.
        assert (
            min(extract_levels) >= 0
        ), f"extract_levels must be >= 0, got {extract_levels}"
        assert (
            max(extract_levels) <= depth
        ), f"extract_levels must be <= depth={depth}, got {extract_levels}"
        self._extract_levels = extract_levels
        self._depth = depth

        # init layers
        # all lists start with d = 0; they are populated by build_layers
        self._downsample_convs = None
        self._downsample_pools = None
        self._bottom_block = None
        self._upsample_deconvs = None
        self._upsample_convs = None
        self._output_block = None

    def build_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: List[int],
        downsample_kernel_sizes: Union[int, List[int]],
        upsample_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers that will be used in call.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param downsample_kernel_sizes: kernel size for down-sampling, either
            one int for all levels or a list of length depth + 1
            (the last entry is used by the bottom block).
        :param upsample_kernel_sizes: kernel size for up-sampling, either
            one int for all levels or a list of length depth.
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        # init params
        min_extract_level = min(extract_levels)
        # channels double at every level: level d has num_channel_initial * 2^d
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
        if isinstance(downsample_kernel_sizes, int):
            downsample_kernel_sizes = [downsample_kernel_sizes] * (depth + 1)
        assert len(downsample_kernel_sizes) == depth + 1
        if isinstance(upsample_kernel_sizes, int):
            upsample_kernel_sizes = [upsample_kernel_sizes] * depth
        assert len(upsample_kernel_sizes) == depth

        # down-sampling
        # tensor_shapes[d] is the spatial shape at level d. It is computed with
        # the same kernel size / stride / padding as the down-sampling block,
        # so the up-sampling blocks can restore each level's exact shape.
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        tensor_shapes = [tensor_shape]
        for d in range(depth):
            downsample_conv = self.build_conv_block(
                filters=num_channels[d],
                kernel_size=downsample_kernel_sizes[d],
                padding=padding,
            )
            downsample_pool = self.build_down_sampling_block(
                kernel_size=strides, strides=strides, padding=padding
            )
            tensor_shape = tuple(
                conv_utils.conv_output_length(
                    input_length=x,
                    filter_size=strides,
                    padding=padding,
                    stride=strides,
                    dilation=1,
                )
                for x in tensor_shape
            )
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            tensor_shapes.append(tensor_shape)

        # bottom layer
        self._bottom_block = self.build_bottom_block(
            filters=num_channels[depth],
            kernel_size=downsample_kernel_sizes[depth],
            padding=padding,
        )

        # up-sampling, from level depth - 1 down to the lowest extraction level.
        # Levels below min_extract_level are never decoded, so their entries
        # stay None; both lists have length depth and are indexed by level d.
        upsample_deconvs = [None] * depth
        upsample_convs = [None] * depth
        for d in range(depth - 1, min_extract_level - 1, -1):
            kernel_size = upsample_kernel_sizes[d]
            output_padding = layer_util.deconv_output_padding(
                input_shape=tensor_shapes[d + 1],
                output_shape=tensor_shapes[d],
                kernel_size=kernel_size,
                stride=strides,
                padding=padding,
            )
            upsample_deconvs[d] = self.build_up_sampling_block(
                filters=num_channels[d],
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=tensor_shapes[d],
            )
            upsample_convs[d] = self.build_conv_block(
                filters=num_channels[d], kernel_size=kernel_size, padding=padding
            )
        self._upsample_deconvs = upsample_deconvs
        self._upsample_convs = upsample_convs

        # extraction
        self._output_block = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    @abstractmethod
    def build_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling or up-sampling.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """

    @abstractmethod
    def build_down_sampling_block(
        self, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param kernel_size: arg for pool3d
        :param padding: arg for pool3d
        :param strides: arg for pool3d
        :return: a block consisting of one or multiple layers
        """

    @abstractmethod
    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """

    @abstractmethod
    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """

    @abstractmethod
    def build_skip_block(self) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for combining skipped tensor and up-sampled one.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        The input to this block is a list of tensors.

        :return: a block consisting of one or multiple layers
        """

    @abstractmethod
    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: List[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which the output is extracted.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build the u-net graph from the layers created in build_layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool, forwarded to every sub-block.
        :param mask: None or tf.Tensor, unused.
        :return: output of the output block, expected
            shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # down-sampling, levels 0 to depth - 1
        skips = []
        down_sampled = inputs
        for d in range(self._depth):
            skip = self._downsample_convs[d](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[d](inputs=skip, training=training)
            skips.append(skip)

        # bottom, level depth
        up_sampled = self._bottom_block(inputs=down_sampled, training=training)

        # up-sampling, from level depth - 1 down to the lowest extraction level;
        # outs collects the bottom output followed by every decoded level
        outs = [up_sampled]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            up_sampled = self._upsample_deconvs[d](inputs=up_sampled, training=training)
            # NOTE(review): a new skip block is built on every forward pass,
            # which is only safe if build_skip_block returns a weight-free layer.
            up_sampled = self.build_skip_block()([up_sampled, skips[d]])
            up_sampled = self._upsample_convs[d](inputs=up_sampled, training=training)
            outs.append(up_sampled)

        # output
        output = self._output_block(outs)

        return output
@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for downsampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when upsampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # init layer variables
        # channels double at every level: level d has num_channel_initial * 2^d
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]

        self._num_channel_initial = num_channel_initial
        self._depth = depth
        self._concat_skip = concat_skip

        # encoder, levels 0 to depth - 1
        self._downsample_convs = []
        self._downsample_pools = []
        tensor_shape = image_size
        self._tensor_shapes = [tensor_shape]
        for d in range(depth):
            downsample_conv = self._build_conv_res_block(filters=num_channels[d])
            if pooling:
                downsample_pool = tfkl.MaxPool3D(pool_size=2, strides=2, padding="same")
            else:
                downsample_pool = layer.Conv3dBlock(
                    filters=num_channels[d], kernel_size=3, strides=2, padding="same"
                )
            # stride 2 with "same" padding yields ceil(x / 2) per spatial dim
            tensor_shape = tuple((x + 1) // 2 for x in tensor_shape)
            self._downsample_convs.append(downsample_conv)
            self._downsample_pools.append(downsample_pool)
            self._tensor_shapes.append(tensor_shape)

        # bottom, level depth
        self._bottom_conv3d = layer.Conv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )
        self._bottom_res3d = layer.ResidualConv3dBlock(
            filters=num_channels[depth], kernel_size=3, padding="same"
        )

        # decoder, indexed by level d = 0 to depth - 1
        self._upsample_deconvs = []
        self._upsample_convs = []
        for d in range(depth):
            # extra output padding so the deconv restores the exact shape of
            # level d (odd dims were rounded up when down-sampling)
            output_padding = layer_util.deconv_output_padding(
                input_shape=self._tensor_shapes[d + 1],
                output_shape=self._tensor_shapes[d],
                kernel_size=3,
                stride=2,
                padding="same",
            )
            upsample_deconv = layer.Deconv3dBlock(
                filters=num_channels[d],
                output_padding=output_padding,
                kernel_size=3,
                strides=2,
                padding="same",
            )
            upsample_conv = self._build_conv_res_block(filters=num_channels[d])
            self._upsample_deconvs.append(upsample_deconv)
            self._upsample_convs.append(upsample_conv)

        # output: final conv then resize to the input image size
        self._output_conv3d = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    filters=out_channels,
                    kernel_size=3,
                    strides=1,
                    padding="same",
                    kernel_initializer=out_kernel_initializer,
                    activation=out_activation,
                ),
                layer.Resize3d(shape=image_size),
            ]
        )

    @staticmethod
    def _build_conv_res_block(filters: int) -> tf.keras.Sequential:
        """
        Build a Conv3dBlock followed by a ResidualConv3dBlock.

        Both layers use kernel size 3 and "same" padding.

        :param filters: number of output channels of both layers.
        :return: the two layers chained in a tf.keras.Sequential.
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(filters=filters, kernel_size=3, padding="same"),
                layer.ResidualConv3dBlock(
                    filters=filters, kernel_size=3, padding="same"
                ),
            ]
        )

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = [batch, f_dim1, f_dim2, f_dim3, in_channels]
        :param training: None or bool, forwarded to every sub-layer.
        :param mask: None or tf.Tensor, unused.
        :return: shape = [batch, f_dim1, f_dim2, f_dim3, out_channels]
        """

        down_sampled = inputs

        # down sample
        skips = []
        for d_var in range(self._depth):  # level 0 to D-1
            skip = self._downsample_convs[d_var](inputs=down_sampled, training=training)
            down_sampled = self._downsample_pools[d_var](inputs=skip, training=training)
            skips.append(skip)

        # bottom, level D
        up_sampled = self._bottom_res3d(
            inputs=self._bottom_conv3d(inputs=down_sampled, training=training),
            training=training,
        )

        # up sample, level D-1 to 0
        for d_var in range(self._depth - 1, -1, -1):
            up_sampled = self._upsample_deconvs[d_var](
                inputs=up_sampled, training=training
            )
            if self._concat_skip:
                # axis 4 is the channel axis of a 5D tensor
                up_sampled = tf.concat([up_sampled, skips[d_var]], axis=4)
            else:
                up_sampled = up_sampled + skips[d_var]
            up_sampled = self._upsample_convs[d_var](
                inputs=up_sampled, training=training
            )

        # output
        output = self._output_conv3d(inputs=up_sampled, training=training)

        return output