Passed
Pull Request — main (#656), by Yunguan, created 03:01

UNet.build_encode_layers() (rating: B)

Complexity: Conditions 3
Size: Total Lines 64, Code Lines 42
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0
Metric   Value
eloc     42
dl       0
loc      64
rs       8.872
c        0
b        0
f        0
cc       3
nop      7

How to fix: Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign to extract the commented part into a new method, and to use the comment as a starting point for naming it.

Commonly applied refactorings include Extract Method, sketched below for this method.
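As an illustration only, the per-level work in the loop body of UNet.build_encode_layers() (shown further down) could be pulled into a smaller, well-named helper inside the class. The helper name _build_encode_level and its signature are invented for this sketch and are not part of the codebase; the assert, the bottom-block construction, and the return are left unchanged and omitted here for brevity.

    def build_encode_layers(self, image_size, num_channel_initial, depth,
                            encode_kernel_sizes, strides, padding):
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
        self._encode_convs = []
        self._encode_pools = []
        tensor_shapes = [image_size]
        for d in range(depth):
            # each level's work now lives in one named helper
            tensor_shapes.append(
                self._build_encode_level(
                    filters=num_channels[d],
                    kernel_size=encode_kernel_sizes[d],
                    strides=strides,
                    padding=padding,
                    input_shape=tensor_shapes[-1],
                )
            )
        # assert, bottom-block construction, and return unchanged (omitted in this sketch)

    def _build_encode_level(self, filters, kernel_size, strides, padding, input_shape):
        """Append the conv and pooling blocks for one level; return the pooled shape."""
        self._encode_convs.append(
            self.build_conv_block(filters=filters, kernel_size=kernel_size, padding=padding)
        )
        self._encode_pools.append(
            self.build_down_sampling_block(
                filters=filters, kernel_size=strides, strides=strides, padding=padding
            )
        )
        return tuple(
            conv_utils.conv_output_length(
                input_length=x, filter_size=strides, padding=padding, stride=strides, dilation=1
            )
            for x in input_shape
        )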

# coding=utf-8
# ignored issue: Missing module docstring

from typing import List, Tuple, Union

import tensorflow as tf
import tensorflow.keras.layers as tfkl
from tensorflow.python.keras.utils import conv_utils

from deepreg.model import layer, layer_util
from deepreg.model.backbone.interface import Backbone
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY


@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”,
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        extract_levels: Tuple[int] = (0,),
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: Union[int, List[int]] = 3,
        decode_kernel_sizes: Union[int, List[int]] = 3,
        strides: int = 2,
        padding: str = "same",
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth.
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param out_channels: number of channels for the output
        :param extract_levels: list, which levels from net to extract.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        assert max(extract_levels) <= depth
        self._extract_levels = extract_levels
        self._depth = depth

        # save extra parameters
        self._concat_skip = concat_skip
        self._pooling = pooling
        self._encode_kernel_sizes = encode_kernel_sizes
        self._decode_kernel_sizes = decode_kernel_sizes
        self._strides = strides
        self._padding = padding

        # init layers
        # all lists start with d = 0
        self._encode_convs = None
        self._encode_pools = None
        self._bottom_block = None
        self._decode_deconvs = None
        self._decode_convs = None
        self._output_block = None

        # build layers
        self.build_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            encode_kernel_sizes=encode_kernel_sizes,
            decode_kernel_sizes=decode_kernel_sizes,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

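    # Illustrative note (not part of the reviewed file): a backbone could be
    # constructed as, e.g.,
    #   UNet(image_size=(16, 16, 16), num_channel_initial=2, depth=2,
    #        out_kernel_initializer="glorot_uniform", out_activation="sigmoid",
    #        out_channels=3)
    # where every value above is a placeholder, not a recommended setting.
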
    def build_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling or up-sampling.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_down_sampling_block(
        self, filters: int, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output, arg for conv3d
        :param kernel_size: arg for pool3d or conv3d
        :param padding: arg for pool3d or conv3d
        :param strides: arg for pool3d or conv3d
        :return: a block consisting of one or multiple layers
        """
        # ignored issue: unnecessary "else" after "return"
        if self._pooling:
            return tfkl.MaxPool3D(
                pool_size=kernel_size, strides=strides, padding=padding
            )
        else:
            return layer.Conv3dBlock(
                filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
            )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for the bottom layer.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,  # ignored issue: the argument output_shape seems to be unused
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """
        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_skip_block(self) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for combining the skipped tensor and the up-sampled one.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        The input to this block is a list of tensors.

        :return: a block consisting of one or multiple layers
        """
        # ignored issue: unnecessary "else" after "return"
        if self._concat_skip:
            return tfkl.Concatenate()
        else:
            return tfkl.Add()

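    # Illustrative note (not part of the reviewed file): with concat_skip=True,
    # tfkl.Concatenate() joins along the channel axis, so a skip tensor and an
    # up-sampled tensor of shape (batch, 8, 8, 8, C) combine into
    # (batch, 8, 8, 8, 2 * C); with concat_skip=False, tfkl.Add() keeps the
    # shape at (batch, 8, 8, 8, C). The example shapes are placeholders.
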
    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: Tuple[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which the output is extracted.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at the end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def build_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: Tuple[int],
        encode_kernel_sizes: Union[int, List[int]],
        decode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers that will be used in call.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at the end layer.
        :param out_channels: number of channels for the extractions
        """
        tensor_shapes = self.build_encode_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            encode_kernel_sizes=encode_kernel_sizes,
            strides=strides,
            padding=padding,
        )
        self.build_decode_layers(
            tensor_shapes=tensor_shapes,
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            decode_kernel_sizes=decode_kernel_sizes,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    def build_encode_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        encode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
    ) -> List[Tuple]:
        """
        Build layers for encoding.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :return: list of tensor shapes starting from d = 0
        """
        # init params
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
        assert len(encode_kernel_sizes) == depth + 1

        # encoding / down-sampling
        self._encode_convs = []
        self._encode_pools = []
        tensor_shape = image_size
        tensor_shapes = [tensor_shape]
        for d in range(depth):
            encode_conv = self.build_conv_block(
                filters=num_channels[d],
                kernel_size=encode_kernel_sizes[d],
                padding=padding,
            )
            encode_pool = self.build_down_sampling_block(
                filters=num_channels[d],
                kernel_size=strides,
                strides=strides,
                padding=padding,
            )
            tensor_shape = tuple(
                conv_utils.conv_output_length(
                    input_length=x,
                    filter_size=strides,
                    padding=padding,
                    stride=strides,
                    dilation=1,
                )
                for x in tensor_shape
            )
            self._encode_convs.append(encode_conv)
            self._encode_pools.append(encode_pool)
            tensor_shapes.append(tensor_shape)

        # bottom layer
        self._bottom_block = self.build_bottom_block(
            filters=num_channels[depth],
            kernel_size=encode_kernel_sizes[depth],
            padding=padding,
        )
        return tensor_shapes

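    # Illustrative note (not part of the reviewed file): with padding="same"
    # and strides=2, conv_utils.conv_output_length returns ceil(input / 2) per
    # spatial dimension, so for image_size=(64, 64, 64) and depth=3 the
    # returned tensor_shapes would be
    # [(64, 64, 64), (32, 32, 32), (16, 16, 16), (8, 8, 8)].
    # The sizes here are made-up examples.
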
    def build_decode_layers(
        self,
        tensor_shapes: List[Tuple],
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: Tuple[int],
        decode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers for decoding.

        :param tensor_shapes: shapes calculated in encoder
        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param decode_kernel_sizes: kernel size for up-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at the end layer.
        :param out_channels: number of channels for the extractions
        """
        # init params
        min_extract_level = min(extract_levels)
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
        if isinstance(decode_kernel_sizes, int):
            decode_kernel_sizes = [decode_kernel_sizes] * depth
        assert len(decode_kernel_sizes) == depth

        # decoding / up-sampling
        self._decode_deconvs = []
        self._decode_convs = []
        for d in range(depth - 1, min_extract_level - 1, -1):
            kernel_size = decode_kernel_sizes[d]
            output_padding = layer_util.deconv_output_padding(
                input_shape=tensor_shapes[d + 1],
                output_shape=tensor_shapes[d],
                kernel_size=kernel_size,
                stride=strides,
                padding=padding,
            )
            decode_deconv = self.build_up_sampling_block(
                filters=num_channels[d],
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=tensor_shapes[d],
            )
            decode_conv = self.build_conv_block(
                filters=num_channels[d], kernel_size=kernel_size, padding=padding
            )
            self._decode_deconvs = [decode_deconv] + self._decode_deconvs
            self._decode_convs = [decode_conv] + self._decode_convs
        if min_extract_level > 0:
            # prepend Nones so that the lists have length depth and stay indexed by d
            self._decode_deconvs = [None] * min_extract_level + self._decode_deconvs
            self._decode_convs = [None] * min_extract_level + self._decode_convs

        # extraction
        self._output_block = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

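    # Illustrative note (not part of the reviewed file): for depth=3 and
    # extract_levels=(1,), the loop above builds deconv/conv pairs for d = 2
    # and d = 1, and the prepended None keeps the lists indexed by d, giving
    # [None, block_d1, block_d2]; call() only indexes d >= min(extract_levels),
    # so the None entry is never used. These numbers are made-up examples.
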
    # ignored issue: "mask, training" missing in parameter type documentation
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build compute graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # encoding / down-sampling
        skips = []
        encoded = inputs
        for d in range(self._depth):
            skip = self._encode_convs[d](inputs=encoded, training=training)
            encoded = self._encode_pools[d](inputs=skip, training=training)
            skips.append(skip)

        # bottom
        decoded = self._bottom_block(inputs=encoded, training=training)

        # decoding / up-sampling
        outs = [decoded]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            decoded = self._decode_deconvs[d](inputs=decoded, training=training)
            decoded = self.build_skip_block()([decoded, skips[d]])
            decoded = self._decode_convs[d](inputs=decoded, training=training)
            outs = [decoded] + outs

        # output
        output = self._output_block(outs)

        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            depth=self._depth,
            extract_levels=self._extract_levels,
            pooling=self._pooling,
            concat_skip=self._concat_skip,
            encode_kernel_sizes=self._encode_kernel_sizes,
            decode_kernel_sizes=self._decode_kernel_sizes,
            strides=self._strides,
            padding=self._padding,
        )
        return config
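For context, a minimal usage sketch of this backbone follows. It assumes deepreg is installed and that the class lives at deepreg/model/backbone/u_net.py; every configuration value below is a placeholder rather than a recommended setting.

import tensorflow as tf

from deepreg.model.backbone.u_net import UNet

# build a tiny backbone; all values are placeholders
unet = UNet(
    image_size=(16, 16, 16),
    num_channel_initial=2,
    depth=2,
    out_kernel_initializer="glorot_uniform",
    out_activation="sigmoid",
    out_channels=3,
)

# forward pass on a dummy batch of shape (batch, dim1, dim2, dim3, channels)
images = tf.random.uniform((1, 16, 16, 16, 1))
outputs = unet(images)
print(outputs.shape)  # per the call() docstring: (batch, dim1, dim2, dim3, out_channels)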