Passed
Pull Request — main (#673)
by Yunguan
03:10
created

UNet.build_conv_block()   A

Complexity

Conditions 1

Size

Total Lines 25
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 25
rs 9.8
c 0
b 0
f 0
cc 1
nop 4
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Optional, Tuple, Union
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
from tensorflow.python.keras.utils import conv_utils
8
9
from deepreg.model import layer, layer_util
10
from deepreg.model.backbone.interface import Backbone
11
from deepreg.model.layer import Extraction
12
from deepreg.registry import REGISTRY
13
14
15
@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    The network is assembled from overridable ``build_*`` factory methods:
    an encoder (conv block + down-sampling per level), a bottom block,
    a decoder (up-sampling + skip merge + conv block per level) and
    an output block fed with the decoded tensors.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”,
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: Optional[int],
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        extract_levels: Tuple = (0,),
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: Union[int, List[int]] = 3,
        decode_kernel_sizes: Union[int, List[int]] = 3,
        encode_num_channels: Optional[Tuple] = None,
        decode_num_channels: Optional[Tuple] = None,
        strides: int = 2,
        padding: str = "same",
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param num_channel_initial: number of initial channels,
            only used when encode_num_channels is None.
        :param depth: input is at level 0, bottom is at level depth.
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param out_channels: number of channels for the output
        :param extract_levels: list, which levels from net to extract.
            The largest level must not exceed depth.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param encode_kernel_sizes: kernel size for down-sampling,
            an int or a list of depth + 1 ints
        :param decode_kernel_sizes: kernel size for up-sampling,
            an int or a list of depth ints
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            by default it is the same as encode_num_channels
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        assert max(extract_levels) <= depth, (
            f"extract_levels = {extract_levels} exceeds depth = {depth}."
        )
        self._extract_levels = extract_levels
        self._depth = depth

        # save extra parameters, they are returned by get_config
        self._concat_skip = concat_skip
        self._pooling = pooling
        self._encode_kernel_sizes = encode_kernel_sizes
        self._decode_kernel_sizes = decode_kernel_sizes
        self._encode_num_channels = encode_num_channels
        self._decode_num_channels = decode_num_channels
        self._strides = strides
        self._padding = padding

        # init layers
        # all lists start with d = 0
        self._encode_convs = None
        self._encode_pools = None
        self._bottom_block = None
        self._decode_deconvs = None
        self._decode_convs = None
        self._output_block = None

        # build layers
        self.build_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            encode_kernel_sizes=encode_kernel_sizes,
            decode_kernel_sizes=decode_kernel_sizes,
            encode_num_channels=encode_num_channels,
            decode_num_channels=decode_num_channels,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    def build_encode_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_down_sampling_block(
        self, filters: int, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output, arg for conv3d,
            ignored when pooling is used
        :param kernel_size: arg for pool3d or conv3d
        :param padding: arg for pool3d or conv3d
        :param strides: arg for pool3d or conv3d
        :return: a block consisting of one or multiple layers
        """
        if self._pooling:
            return tfkl.MaxPool3D(
                pool_size=kernel_size, strides=strides, padding=padding
            )
        return layer.Conv3dBlock(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor.
            It is not used by this implementation, but build_decode_layers
            always passes it, so it stays in the signature.
        :return: a block consisting of one or multiple layers
        """
        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_skip_block(self) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for combining skipped tensor and up-sampled one.

        The returned layer is a concatenation if concat_skip is true,
        otherwise it is an addition.

        The input to this block is a list of tensors.

        :return: a block consisting of one or multiple layers
        """
        if self._concat_skip:
            return tfkl.Concatenate()
        return tfkl.Add()

    def build_decode_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for up-sampling.

        This block does not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: Tuple[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which the output is extracted.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def build_layers(
        self,
        image_size: tuple,
        num_channel_initial: Optional[int],
        depth: int,
        extract_levels: Tuple[int],
        encode_kernel_sizes: Union[int, List[int]],
        decode_kernel_sizes: Union[int, List[int]],
        encode_num_channels: Optional[Tuple],
        decode_num_channels: Optional[Tuple],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers that will be used in call.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels,
            only used when encode_num_channels is None.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            by default it is the same as encode_num_channels
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :raises ValueError: if skipped tensors are added (concat_skip is
            false) and the encoding and decoding channels differ.
        """
        if encode_num_channels is None:
            assert num_channel_initial >= 1, (
                "num_channel_initial must be at least 1 "
                "when encode_num_channels is not given."
            )
            encode_num_channels = tuple(
                num_channel_initial * (2 ** d) for d in range(depth + 1)
            )
        assert len(encode_num_channels) == depth + 1, (
            f"encode_num_channels must have {depth + 1} values, "
            f"got {encode_num_channels}."
        )
        if decode_num_channels is None:
            decode_num_channels = encode_num_channels
        assert len(decode_num_channels) == depth + 1, (
            f"decode_num_channels must have {depth + 1} values, "
            f"got {decode_num_channels}."
        )
        # In case of adding skip tensors, the channels should match.
        # Both sides are converted to tuples before comparing, because
        # a list and a tuple holding equal values compare as different.
        if not self._concat_skip and tuple(decode_num_channels) != tuple(
            encode_num_channels
        ):
            raise ValueError(
                "For UNet, if the skipped tensor is added "
                "instead of being concatenated, "
                "the encode_num_channels and decode_num_channels "
                "should be the same. "
                f"But got encode_num_channels = {encode_num_channels}, "
                f"decode_num_channels = {decode_num_channels}."
            )
        tensor_shapes = self.build_encode_layers(
            image_size=image_size,
            num_channels=encode_num_channels,
            depth=depth,
            encode_kernel_sizes=encode_kernel_sizes,
            strides=strides,
            padding=padding,
        )
        self.build_decode_layers(
            tensor_shapes=tensor_shapes,
            image_size=image_size,
            num_channels=decode_num_channels,
            depth=depth,
            extract_levels=extract_levels,
            decode_kernel_sizes=decode_kernel_sizes,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    def build_encode_layers(
        self,
        image_size: Tuple,
        num_channels: Tuple,
        depth: int,
        encode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
    ) -> List[Tuple]:
        """
        Build layers for encoding.

        It fills the encoding conv blocks, the down-sampling blocks
        and the bottom block.

        :param image_size: (dim1, dim2, dim3).
        :param num_channels: number of channels for each layer,
            starting from the top layer.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param encode_kernel_sizes: kernel size for down-sampling,
            an int is repeated for all depth + 1 levels
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :return: list of tensor shapes starting from d = 0,
            the list has depth + 1 elements
        """
        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
        assert len(encode_kernel_sizes) == depth + 1, (
            f"encode_kernel_sizes must have {depth + 1} values, "
            f"got {encode_kernel_sizes}."
        )

        # encoding / down-sampling
        self._encode_convs = []
        self._encode_pools = []
        tensor_shape = image_size
        tensor_shapes = [tensor_shape]
        for d in range(depth):
            encode_conv = self.build_encode_conv_block(
                filters=num_channels[d],
                kernel_size=encode_kernel_sizes[d],
                padding=padding,
            )
            # the down-sampling kernel size equals the strides
            encode_pool = self.build_down_sampling_block(
                filters=num_channels[d],
                kernel_size=strides,
                strides=strides,
                padding=padding,
            )
            # spatial shape after down-sampling, needed later by the decoder
            tensor_shape = tuple(
                conv_utils.conv_output_length(
                    input_length=x,
                    filter_size=strides,
                    padding=padding,
                    stride=strides,
                    dilation=1,
                )
                for x in tensor_shape
            )
            self._encode_convs.append(encode_conv)
            self._encode_pools.append(encode_pool)
            tensor_shapes.append(tensor_shape)

        # bottom layer
        self._bottom_block = self.build_bottom_block(
            filters=num_channels[depth],
            kernel_size=encode_kernel_sizes[depth],
            padding=padding,
        )
        return tensor_shapes

    def build_decode_layers(
        self,
        tensor_shapes: List[Tuple],
        image_size: Tuple,
        num_channels: Tuple,
        depth: int,
        extract_levels: Tuple[int],
        decode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers for decoding.

        Decoding layers are only built for levels from depth - 1 down to
        the smallest extraction level, shallower levels are set to None.

        :param tensor_shapes: shapes calculated in encoder
        :param image_size: (dim1, dim2, dim3).
        :param num_channels: number of channels for each layer,
            starting from the top layer.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param decode_kernel_sizes: kernel size for up-sampling,
            an int is repeated for all depth levels
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        # init params
        min_extract_level = min(extract_levels)
        if isinstance(decode_kernel_sizes, int):
            decode_kernel_sizes = [decode_kernel_sizes] * depth
        assert len(decode_kernel_sizes) == depth, (
            f"decode_kernel_sizes must have {depth} values, "
            f"got {decode_kernel_sizes}."
        )

        # decoding / up-sampling
        self._decode_deconvs = []
        self._decode_convs = []
        for d in range(depth - 1, min_extract_level - 1, -1):
            kernel_size = decode_kernel_sizes[d]
            output_padding = layer_util.deconv_output_padding(
                input_shape=tensor_shapes[d + 1],
                output_shape=tensor_shapes[d],
                kernel_size=kernel_size,
                stride=strides,
                padding=padding,
            )
            decode_deconv = self.build_up_sampling_block(
                filters=num_channels[d],
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=tensor_shapes[d],
            )
            decode_conv = self.build_decode_conv_block(
                filters=num_channels[d], kernel_size=kernel_size, padding=padding
            )
            # layers are created bottom-up, so prepend to keep index == level
            self._decode_deconvs = [decode_deconv] + self._decode_deconvs
            self._decode_convs = [decode_conv] + self._decode_convs
        if min_extract_level > 0:
            # add Nones to make lists have length depth
            self._decode_deconvs = [None] * min_extract_level + self._decode_deconvs
            self._decode_convs = [None] * min_extract_level + self._decode_convs

        # extraction
        self._output_block = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def call(
        self,
        inputs: tf.Tensor,
        training: Optional[bool] = None,
        mask: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """
        Build compute graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor, it is not used in this method.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # encoding / down-sampling
        skips = []
        encoded = inputs
        for d in range(self._depth):
            skip = self._encode_convs[d](inputs=encoded, training=training)
            encoded = self._encode_pools[d](inputs=skip, training=training)
            skips.append(skip)

        # bottom
        decoded = self._bottom_block(inputs=encoded, training=training)

        # decoding / up-sampling
        # outs is ordered from the shallowest decoded level to the bottom
        outs = [decoded]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            decoded = self._decode_deconvs[d](inputs=decoded, training=training)
            # a new merge layer is built for each level on each call
            decoded = self.build_skip_block()([decoded, skips[d]])
            decoded = self._decode_convs[d](inputs=decoded, training=training)
            outs = [decoded] + outs

        # output
        output = self._output_block(outs)

        return output

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            depth=self._depth,
            extract_levels=self._extract_levels,
            pooling=self._pooling,
            concat_skip=self._concat_skip,
            encode_kernel_sizes=self._encode_kernel_sizes,
            decode_kernel_sizes=self._decode_kernel_sizes,
            encode_num_channels=self._encode_num_channels,
            decode_num_channels=self._decode_num_channels,
            strides=self._strides,
            padding=self._padding,
        )
        return config
579