Passed
Pull Request — main (#673)
by Yunguan
04:31 (queued 01:45)

UNet.build_encode_layers()   (rating: B)

Complexity
    Conditions: 3

Size
    Total Lines: 63
    Code Lines: 41

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric                          Value
eloc                            41
dl                              0
loc (lines of code)             63
rs                              8.896
c                               0
b                               0
f                               0
cc (cyclomatic complexity)      3
nop (number of parameters)      7

How to fix: Long Method

Small methods make your code easier to understand, especially when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment as a starting point for naming it.

Commonly applied refactorings include:

- Extract Method: move a coherent piece of the long method into a new, well-named method and call it from the original, as sketched below.
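
The snippet below is only a sketch of this idea applied to build_encode_layers (whose source appears further down). The helper name _build_encode_level is hypothetical and not part of DeepReg; the body reuses only calls already present in the method.

    def _build_encode_level(self, filters, kernel_size, strides, padding, tensor_shape):
        """Build the conv and down-sampling blocks for one encoding level.

        Hypothetical helper: returns the two blocks plus the spatial shape
        after down-sampling, so the calling loop only collects results.
        """
        encode_conv = self.build_conv_block(
            filters=filters, kernel_size=kernel_size, padding=padding
        )
        encode_pool = self.build_down_sampling_block(
            filters=filters, kernel_size=strides, strides=strides, padding=padding
        )
        # spatial shape after this level's down-sampling, computed per dimension
        next_shape = tuple(
            conv_utils.conv_output_length(
                input_length=x,
                filter_size=strides,
                padding=padding,
                stride=strides,
                dilation=1,
            )
            for x in tensor_shape
        )
        return encode_conv, encode_pool, next_shape

With such a helper, the loop in build_encode_layers would shrink to one call per level, which is the kind of change the Long Method warning is asking for. The flagged method, with the rest of the UNet class for context, follows below.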

# coding=utf-8
# [issue] Missing module docstring

from typing import List, Optional, Tuple, Union

import tensorflow as tf
import tensorflow.keras.layers as tfkl
from tensorflow.python.keras.utils import conv_utils

from deepreg.model import layer, layer_util
from deepreg.model.backbone.interface import Backbone
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY


@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”,
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: Optional[int],
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        extract_levels: Tuple = (0,),
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: Union[int, List[int]] = 3,
        decode_kernel_sizes: Union[int, List[int]] = 3,
        encode_num_channels: Optional[Tuple] = None,
        decode_num_channels: Optional[Tuple] = None,
        strides: int = 2,
        padding: str = "same",
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param out_channels: number of channels for the output
        :param extract_levels: list, which levels from net to extract.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            by default it is the same as encode_num_channels
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        assert max(extract_levels) <= depth
        self._extract_levels = extract_levels
        self._depth = depth

        # save extra parameters
        self._concat_skip = concat_skip
        self._pooling = pooling
        self._encode_kernel_sizes = encode_kernel_sizes
        self._decode_kernel_sizes = decode_kernel_sizes
        self._encode_num_channels = encode_num_channels
        self._decode_num_channels = decode_num_channels
        self._strides = strides
        self._padding = padding

        # init layers
        # all lists start with d = 0
        self._encode_convs = None
        self._encode_pools = None
        self._bottom_block = None
        self._decode_deconvs = None
        self._decode_convs = None
        self._output_block = None

        # build layers
        self.build_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            encode_kernel_sizes=encode_kernel_sizes,
            decode_kernel_sizes=decode_kernel_sizes,
            encode_num_channels=encode_num_channels,
            decode_num_channels=decode_num_channels,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    def build_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling or up-sampling.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_down_sampling_block(
        self, filters: int, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output, arg for conv3d
        :param kernel_size: arg for pool3d or conv3d
        :param padding: arg for pool3d or conv3d
        :param strides: arg for pool3d or conv3d
        :return: a block consisting of one or multiple layers
        """
        if self._pooling:
            # [issue] Unnecessary "else" after "return"
            return tfkl.MaxPool3D(
                pool_size=kernel_size, strides=strides, padding=padding
            )
        else:
            return layer.Conv3dBlock(
                filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
            )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,  # [issue] The argument output_shape seems to be unused.
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """
        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_skip_block(self) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for combining skipped tensor and up-sampled one.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        The input to this block is a list of tensors.

        :return: a block consisting of one or multiple layers
        """
        if self._concat_skip:
            # [issue] Unnecessary "else" after "return"
            return tfkl.Concatenate()
        else:
            return tfkl.Add()

    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: Tuple[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: number of extraction levels.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    # [issue] "ValueError" not documented as being raised
    def build_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: Tuple[int],
        encode_kernel_sizes: Union[int, List[int]],
        decode_kernel_sizes: Union[int, List[int]],
        encode_num_channels: Optional[Tuple],
        decode_num_channels: Optional[Tuple],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers that will be used in call.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            by default it is the same as encode_num_channels
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        if encode_num_channels is None:
            assert num_channel_initial > 1
            encode_num_channels = tuple(
                num_channel_initial * (2 ** d) for d in range(depth + 1)
            )
        assert len(encode_num_channels) == depth + 1
        if decode_num_channels is None:
            decode_num_channels = encode_num_channels
        assert len(decode_num_channels) == depth + 1
        if not self._concat_skip:
            # in case of adding skip tensors, the channels should match
            if decode_num_channels != encode_num_channels:
                raise ValueError(
                    "For UNet, if the skipped tensor is added "
                    "instead of being concatenated, "
                    "the encode_num_channels and decode_num_channels "
                    "should be the same. "
                    f"But got encode_num_channels = {encode_num_channels},"
                    f"decode_num_channels = {decode_num_channels}."
                )
        tensor_shapes = self.build_encode_layers(
            image_size=image_size,
            num_channels=encode_num_channels,
            depth=depth,
            encode_kernel_sizes=encode_kernel_sizes,
            strides=strides,
            padding=padding,
        )
        self.build_decode_layers(
            tensor_shapes=tensor_shapes,
            image_size=image_size,
            num_channels=decode_num_channels,
            depth=depth,
            extract_levels=extract_levels,
            decode_kernel_sizes=decode_kernel_sizes,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    def build_encode_layers(
        self,
        image_size: Tuple,
        num_channels: Tuple,
        depth: int,
        encode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
    ) -> List[Tuple]:
        """
        Build layers for encoding.

        :param image_size: (dim1, dim2, dim3).
        :param num_channels: number of channels for each layer,
            starting from the top layer.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :return: list of tensor shapes starting from d = 0
        """
        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
        assert len(encode_kernel_sizes) == depth + 1

        # encoding / down-sampling
        self._encode_convs = []
        self._encode_pools = []
        tensor_shape = image_size
        tensor_shapes = [tensor_shape]
        for d in range(depth):
            encode_conv = self.build_conv_block(
                filters=num_channels[d],
                kernel_size=encode_kernel_sizes[d],
                padding=padding,
            )
            encode_pool = self.build_down_sampling_block(
                filters=num_channels[d],
                kernel_size=strides,
                strides=strides,
                padding=padding,
            )
            tensor_shape = tuple(
                conv_utils.conv_output_length(
                    input_length=x,
                    filter_size=strides,
                    padding=padding,
                    stride=strides,
                    dilation=1,
                )
                for x in tensor_shape
            )
            self._encode_convs.append(encode_conv)
            self._encode_pools.append(encode_pool)
            tensor_shapes.append(tensor_shape)

        # bottom layer
        self._bottom_block = self.build_bottom_block(
            filters=num_channels[depth],
            kernel_size=encode_kernel_sizes[depth],
            padding=padding,
        )
        return tensor_shapes

    def build_decode_layers(
        self,
        tensor_shapes: List[Tuple],
        image_size: Tuple,
        num_channels: Tuple,
        depth: int,
        extract_levels: Tuple[int],
        decode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers for decoding.

        :param tensor_shapes: shapes calculated in encoder
        :param image_size: (dim1, dim2, dim3).
        :param num_channels: number of channels for each layer,
            starting from the top layer.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param decode_kernel_sizes: kernel size for up-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        # init params
        min_extract_level = min(extract_levels)
        if isinstance(decode_kernel_sizes, int):
            decode_kernel_sizes = [decode_kernel_sizes] * depth
        assert len(decode_kernel_sizes) == depth

        # decoding / up-sampling
        self._decode_deconvs = []
        self._decode_convs = []
        for d in range(depth - 1, min_extract_level - 1, -1):
            kernel_size = decode_kernel_sizes[d]
            output_padding = layer_util.deconv_output_padding(
                input_shape=tensor_shapes[d + 1],
                output_shape=tensor_shapes[d],
                kernel_size=kernel_size,
                stride=strides,
                padding=padding,
            )
            decode_deconv = self.build_up_sampling_block(
                filters=num_channels[d],
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=tensor_shapes[d],
            )
            decode_conv = self.build_conv_block(
                filters=num_channels[d], kernel_size=kernel_size, padding=padding
            )
            self._decode_deconvs = [decode_deconv] + self._decode_deconvs
            self._decode_convs = [decode_conv] + self._decode_convs
        if min_extract_level > 0:
            # prepend Nones so the lists cover indices 0 .. depth - 1
            self._decode_deconvs = [None] * min_extract_level + self._decode_deconvs
            self._decode_convs = [None] * min_extract_level + self._decode_convs

        # extraction
        self._output_block = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    # [issue] "mask, training" missing in parameter type documentation
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build UNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # encoding / down-sampling
        skips = []
        encoded = inputs
        for d in range(self._depth):
            skip = self._encode_convs[d](inputs=encoded, training=training)
            encoded = self._encode_pools[d](inputs=skip, training=training)
            skips.append(skip)

        # bottom
        decoded = self._bottom_block(inputs=encoded, training=training)

        # decoding / up-sampling
        outs = [decoded]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            decoded = self._decode_deconvs[d](inputs=decoded, training=training)
            decoded = self.build_skip_block()([decoded, skips[d]])
            decoded = self._decode_convs[d](inputs=decoded, training=training)
            outs.append(decoded)

        # output
        output = self._output_block(outs)

        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            depth=self._depth,
            extract_levels=self._extract_levels,
            pooling=self._pooling,
            concat_skip=self._concat_skip,
            encode_kernel_sizes=self._encode_kernel_sizes,
            decode_kernel_sizes=self._decode_kernel_sizes,
            encode_num_channels=self._encode_num_channels,
            decode_num_channels=self._decode_num_channels,
            strides=self._strides,
            padding=self._padding,
        )
        return config
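
For orientation, here is a minimal sketch of constructing the backbone defined above directly from its constructor. The argument values are arbitrary placeholders chosen only so the shapes are consistent (not recommended settings), and it assumes Backbone behaves as a tf.keras.Model, as the call signature above suggests; a real DeepReg run would normally create the backbone from a config via the registry.

import tensorflow as tf

# placeholder values for illustration only
unet = UNet(
    image_size=(16, 16, 16),
    num_channel_initial=4,
    depth=2,
    out_kernel_initializer="glorot_uniform",
    out_activation="softmax",
    out_channels=3,
)

# input shape (batch, dim1, dim2, dim3, channels); per the call docstring the
# output keeps the spatial dims and has out_channels channels: (1, 16, 16, 16, 3)
features = unet(tf.ones((1, 16, 16, 16, 1)))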