Passed
Pull Request — main (#656)
by Yunguan · 25:40 · created

deepreg.model.backbone.u_net   A

Complexity
    Total Complexity    20

Size/Duplication
    Total Lines         491
    Duplicated Lines    0%

Importance
    Changes             0

Metric    Value
wmc       20
eloc      262
dl        0
loc       491
rs        10
c         0
b         0
f         0

11 Methods

Rating    Name                                Duplication    Size    Complexity
B         UNet.__init__()                     0              79      1
A         UNet.build_skip_block()             0              15      2
A         UNet.build_up_sampling_block()      0              29      1
B         UNet.build_encode_layers()          0              63      3
A         UNet.build_down_sampling_block()    0              25      2
A         UNet.call()                         0              33      3
B         UNet.build_decode_layers()          0              72      4
A         UNet.build_layers()                 0              49      1
A         UNet.build_bottom_block()           0              25      1
A         UNet.build_output_block()           0              26      1
A         UNet.build_conv_block()             0              25      1
# coding=utf-8
# [issue] Missing module docstring

from typing import List, Tuple, Union

import tensorflow as tf
import tensorflow.keras.layers as tfkl
from tensorflow.python.keras.utils import conv_utils

from deepreg.model import layer, layer_util
from deepreg.model.backbone.interface import Backbone
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY

@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    # [issue] "extract_levels" missing in parameter documentation
    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        extract_levels: Tuple[int] = (0,),
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: Union[int, List[int]] = 3,
        decode_kernel_sizes: Union[int, List[int]] = 3,
        strides: int = 2,
        padding: str = "same",
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        assert max(extract_levels) <= depth
        self._extract_levels = extract_levels
        self._depth = depth

        # save extra parameters
        self._concat_skip = concat_skip
        self._pooling = pooling

        # init layers
        # all lists start with d = 0
        self._encode_convs = None
        self._encode_pools = None
        self._bottom_block = None
        self._decode_deconvs = None
        self._decode_convs = None
        self._output_block = None

        # build layers
        self.build_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            encode_kernel_sizes=encode_kernel_sizes,
            decode_kernel_sizes=decode_kernel_sizes,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

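    # Illustrative note (not part of the reviewed file): the constructor only
    # checks max(extract_levels) <= depth, so with an assumed depth=3 the valid
    # extraction levels are 0..3, and the encoder channel count at level d is
    # num_channel_initial * 2 ** d (e.g. 4, 8, 16, 32 for num_channel_initial=4).
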
    def build_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling or up-sampling.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_down_sampling_block(
        self, filters: int, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output, arg for conv3d
        :param kernel_size: arg for pool3d or conv3d
        :param padding: arg for pool3d or conv3d
        :param strides: arg for pool3d or conv3d
        :return: a block consisting of one or multiple layers
        """
        # [unused-code issue] Unnecessary "else" after "return"
        if self._pooling:
            return tfkl.MaxPool3D(
                pool_size=kernel_size, strides=strides, padding=padding
            )
        else:
            return layer.Conv3dBlock(
                filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
            )

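    # Illustrative note (not part of the reviewed file): with the assumed
    # defaults strides=2 and padding="same", each down-sampling block halves
    # every spatial dimension (rounding up). For example, an input of shape
    # (batch, 16, 16, 16, c) becomes (batch, 8, 8, 8, c) through MaxPool3D
    # when pooling=True, or (batch, 8, 8, 8, filters) through the strided
    # Conv3dBlock when pooling=False.
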
    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for the bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        # [unused-code issue] The argument output_shape seems to be unused.
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :return: a block consisting of one or multiple layers
        """
        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_skip_block(self) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for combining the skipped tensor and the up-sampled one.

        This block does not change the tensor shape (width, height, depth);
        only the number of channels may change (concatenation doubles it,
        addition keeps it).

        The input to this block is a list of tensors.

        :return: a block consisting of one or multiple layers
        """
        # [unused-code issue] Unnecessary "else" after "return"
        if self._concat_skip:
            return tfkl.Concatenate()
        else:
            return tfkl.Add()

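    # Example sketch (not part of the reviewed file), assuming two skip /
    # up-sampled tensors of identical shape:
    #
    #   x = tf.zeros((1, 8, 8, 8, 16))
    #   tfkl.Concatenate()([x, x]).shape  # (1, 8, 8, 8, 32), concat_skip=True
    #   tfkl.Add()([x, x]).shape          # (1, 8, 8, 8, 16), concat_skip=False
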
    def build_output_block(
        self,
        image_size: Tuple[int],
        extract_levels: Tuple[int],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: depths from which the outputs are extracted.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def build_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: Tuple[int],
        encode_kernel_sizes: Union[int, List[int]],
        decode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers that will be used in call.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        tensor_shapes = self.build_encode_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            encode_kernel_sizes=encode_kernel_sizes,
            strides=strides,
            padding=padding,
        )
        self.build_decode_layers(
            tensor_shapes=tensor_shapes,
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            decode_kernel_sizes=decode_kernel_sizes,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    # [issue] Missing return documentation; missing return type documentation
    def build_encode_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        encode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
    ):
        """
        Build layers for encoding.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        """
        # init params
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
        assert len(encode_kernel_sizes) == depth + 1

        # encoding / down-sampling
        self._encode_convs = []
        self._encode_pools = []
        tensor_shape = image_size
        tensor_shapes = [tensor_shape]
        for d in range(depth):
            encode_conv = self.build_conv_block(
                filters=num_channels[d],
                kernel_size=encode_kernel_sizes[d],
                padding=padding,
            )
            encode_pool = self.build_down_sampling_block(
                filters=num_channels[d],
                kernel_size=strides,
                strides=strides,
                padding=padding,
            )
            tensor_shape = tuple(
                conv_utils.conv_output_length(
                    input_length=x,
                    filter_size=strides,
                    padding=padding,
                    stride=strides,
                    dilation=1,
                )
                for x in tensor_shape
            )
            self._encode_convs.append(encode_conv)
            self._encode_pools.append(encode_pool)
            tensor_shapes.append(tensor_shape)

        # bottom layer
        self._bottom_block = self.build_bottom_block(
            filters=num_channels[depth],
            kernel_size=encode_kernel_sizes[depth],
            padding=padding,
        )
        return tensor_shapes

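    # Illustrative note (not part of the reviewed file): for an assumed
    # image_size=(16, 16, 16), depth=2, strides=2 and padding="same",
    # the returned tensor_shapes is
    #   [(16, 16, 16), (8, 8, 8), (4, 4, 4)]
    # i.e. one spatial shape per level, from d = 0 down to the bottom at d = depth.
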
    # [issue] "tensor_shapes" missing in parameter documentation
    def build_decode_layers(
        self,
        tensor_shapes: List[Tuple],
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: Tuple[int],
        decode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers for decoding.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param decode_kernel_sizes: kernel size for up-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        # init params
        min_extract_level = min(extract_levels)
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
        if isinstance(decode_kernel_sizes, int):
            decode_kernel_sizes = [decode_kernel_sizes] * depth
        assert len(decode_kernel_sizes) == depth

        # decoding / up-sampling
        self._decode_deconvs = []
        self._decode_convs = []
        for d in range(depth - 1, min_extract_level - 1, -1):
            kernel_size = decode_kernel_sizes[d]
            output_padding = layer_util.deconv_output_padding(
                input_shape=tensor_shapes[d + 1],
                output_shape=tensor_shapes[d],
                kernel_size=kernel_size,
                stride=strides,
                padding=padding,
            )
            decode_deconv = self.build_up_sampling_block(
                filters=num_channels[d],
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=tensor_shapes[d],
            )
            decode_conv = self.build_conv_block(
                filters=num_channels[d], kernel_size=kernel_size, padding=padding
            )
            self._decode_deconvs = [decode_deconv] + self._decode_deconvs
            self._decode_convs = [decode_conv] + self._decode_convs
        if min_extract_level > 0:
            # prepend Nones so both lists have length depth and can be indexed by d
            self._decode_deconvs = [None] * min_extract_level + self._decode_deconvs
            self._decode_convs = [None] * min_extract_level + self._decode_convs

        # extraction
        self._output_block = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

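    # Illustrative note (not part of the reviewed file): with assumed
    # num_channel_initial=4 and depth=2, num_channels is [4, 8, 16]; with the
    # default extract_levels=(0,), the loop builds one (deconv, conv) pair for
    # each level d = 1, 0, so _decode_deconvs and _decode_convs each end up
    # with length depth = 2 and no leading None placeholders are needed.
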
    # [issue] "mask, training" missing in parameter type documentation
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build the UNet graph based on the built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # encoding / down-sampling
        skips = []
        encoded = inputs
        for d in range(self._depth):
            skip = self._encode_convs[d](inputs=encoded, training=training)
            encoded = self._encode_pools[d](inputs=skip, training=training)
            skips.append(skip)

        # bottom
        decoded = self._bottom_block(inputs=encoded, training=training)

        # decoding / up-sampling
        outs = [decoded]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            decoded = self._decode_deconvs[d](inputs=decoded, training=training)
            decoded = self.build_skip_block()([decoded, skips[d]])
            decoded = self._decode_convs[d](inputs=decoded, training=training)
            outs.append(decoded)

        # output
        output = self._output_block(outs)

        return output
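
Illustrative usage sketch (not part of the reviewed file): a minimal, hedged
example of instantiating this backbone directly and calling it on a dummy
batch. The constructor arguments and the expected output shape follow the
signatures and docstrings above; the concrete values (grid size, channel
counts, initializer and activation names) are assumptions.

import tensorflow as tf

from deepreg.model.backbone.u_net import UNet

# assumed configuration: a small 16**3 grid with a depth-2 encoder/decoder
unet = UNet(
    image_size=(16, 16, 16),
    num_channel_initial=4,
    depth=2,
    out_kernel_initializer="glorot_uniform",
    out_activation="sigmoid",
    out_channels=3,
)

# dummy image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
images = tf.random.uniform((2, 16, 16, 16, 1))
outputs = unet(images)  # expected shape = (2, 16, 16, 16, 3), per the call() docstring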