Passed
Pull Request — main (#656)
by Yunguan
02:51
created

deepreg.model.backbone.u_net   A

Complexity

Total Complexity 20

Size/Duplication

Total Lines 494
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 20
eloc 262
dl 0
loc 494
rs 10
c 0
b 0
f 0

11 Methods

Rating   Name   Duplication   Size   Complexity  
B UNet.__init__() 0 80 1
A UNet.build_skip_block() 0 15 2
A UNet.build_up_sampling_block() 0 29 1
B UNet.build_encode_layers() 0 64 3
A UNet.build_down_sampling_block() 0 25 2
A UNet.call() 0 33 3
B UNet.build_decode_layers() 0 73 4
A UNet.build_layers() 0 49 1
A UNet.build_bottom_block() 0 25 1
A UNet.build_output_block() 0 26 1
A UNet.build_conv_block() 0 25 1
1
# coding=utf-8
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
3
from typing import List, Tuple, Union
4
5
import tensorflow as tf
6
import tensorflow.keras.layers as tfkl
7
from tensorflow.python.keras.utils import conv_utils
8
9
from deepreg.model import layer, layer_util
10
from deepreg.model.backbone.interface import Backbone
11
from deepreg.model.layer import Extraction
12
from deepreg.registry import REGISTRY
13
14
15
@REGISTRY.register_backbone(name="unet")
class UNet(Backbone):
    """
    Class that implements an adapted 3D UNet.

    Reference:

    - O. Ronneberger, P. Fischer, and T. Brox,
      “U-net: Convolutional networks for biomedical image segmentation,”,
      Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
      https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        extract_levels: Tuple[int, ...] = (0,),
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: Union[int, List[int]] = 3,
        decode_kernel_sizes: Union[int, List[int]] = 3,
        strides: int = 2,
        padding: str = "same",
        name: str = "Unet",
        **kwargs,
    ):
        """
        Initialise UNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param out_channels: number of channels for the output
        :param extract_levels: levels of the network from which outputs are
            extracted; every level must lie in [0, depth].
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        # Both bounds are checked: a negative level would otherwise be
        # accepted and silently index the per-level lists from the end.
        assert max(extract_levels) <= depth, (
            f"extract_levels {extract_levels} must not exceed depth {depth}"
        )
        assert min(extract_levels) >= 0, (
            f"extract_levels {extract_levels} must be non-negative"
        )
        self._extract_levels = extract_levels
        self._depth = depth

        # save extra parameters
        self._concat_skip = concat_skip
        self._pooling = pooling

        # init layers
        # all lists start with d = 0
        self._encode_convs = None
        self._encode_pools = None
        self._bottom_block = None
        self._decode_deconvs = None
        self._decode_convs = None
        self._output_block = None

        # build layers
        self.build_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            encode_kernel_sizes=encode_kernel_sizes,
            decode_kernel_sizes=decode_kernel_sizes,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    def build_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling or up-sampling.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_down_sampling_block(
        self, filters: int, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output, arg for conv3d
        :param kernel_size: arg for pool3d or conv3d
        :param padding: arg for pool3d or conv3d
        :param strides: arg for pool3d or conv3d
        :return: a block consisting of one or multiple layers
        """
        if self._pooling:
            # non-parameterized down-sampling; `filters` is not needed
            return tfkl.MaxPool3D(
                pool_size=kernel_size, strides=strides, padding=padding
            )
        return layer.Conv3dBlock(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consisting of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: int,
        kernel_size: int,
        padding: str,
        strides: int,
        output_shape: tuple,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not change the number of channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor; unused here but kept
            in the signature so overriding implementations can rely on it.
        :return: a block consisting of one or multiple layers
        """
        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_skip_block(self) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for combining skipped tensor and up-sampled one.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels (when concatenating).

        The input to this block is a list of tensors.

        :return: a block consisting of one or multiple layers
        """
        return tfkl.Concatenate() if self._concat_skip else tfkl.Add()

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: levels from which outputs are extracted.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consisting of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def build_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: Tuple[int, ...],
        encode_kernel_sizes: Union[int, List[int]],
        decode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers that will be used in call.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        # the encoder reports the spatial shape at every level; the decoder
        # needs those shapes to compute the deconv output padding
        tensor_shapes = self.build_encode_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            encode_kernel_sizes=encode_kernel_sizes,
            strides=strides,
            padding=padding,
        )
        self.build_decode_layers(
            tensor_shapes=tensor_shapes,
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            decode_kernel_sizes=decode_kernel_sizes,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    def build_encode_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        encode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
    ) -> List[Tuple]:
        """
        Build layers for encoding.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param encode_kernel_sizes: kernel size for down-sampling, either one
            int shared by all levels or a list of depth + 1 values.
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :return: list of tensor shapes starting from d = 0 (length depth + 1)
        """
        # init params
        # channels double at every level: c, 2c, 4c, ...
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
        assert len(encode_kernel_sizes) == depth + 1

        # encoding / down-sampling
        self._encode_convs = []
        self._encode_pools = []
        tensor_shape = image_size
        tensor_shapes = [tensor_shape]
        for d in range(depth):
            encode_conv = self.build_conv_block(
                filters=num_channels[d],
                kernel_size=encode_kernel_sizes[d],
                padding=padding,
            )
            encode_pool = self.build_down_sampling_block(
                filters=num_channels[d],
                kernel_size=strides,
                strides=strides,
                padding=padding,
            )
            # spatial shape after down-sampling, using the same window
            # (kernel = strides) as the down-sampling block above
            tensor_shape = tuple(
                conv_utils.conv_output_length(
                    input_length=x,
                    filter_size=strides,
                    padding=padding,
                    stride=strides,
                    dilation=1,
                )
                for x in tensor_shape
            )
            self._encode_convs.append(encode_conv)
            self._encode_pools.append(encode_pool)
            tensor_shapes.append(tensor_shape)

        # bottom layer
        self._bottom_block = self.build_bottom_block(
            filters=num_channels[depth],
            kernel_size=encode_kernel_sizes[depth],
            padding=padding,
        )
        return tensor_shapes

    def build_decode_layers(
        self,
        tensor_shapes: List[Tuple],
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: Tuple[int, ...],
        decode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers for decoding.

        :param tensor_shapes: shapes calculated in encoder
        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param decode_kernel_sizes: kernel size for up-sampling, either one
            int shared by all levels or a list of depth values.
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        # init params
        min_extract_level = min(extract_levels)
        num_channels = [num_channel_initial * (2 ** d) for d in range(depth + 1)]
        if isinstance(decode_kernel_sizes, int):
            decode_kernel_sizes = [decode_kernel_sizes] * depth
        assert len(decode_kernel_sizes) == depth

        # decoding / up-sampling
        # Layers are created from the bottom up (d = depth - 1 first); levels
        # below min_extract_level are never decoded, so no layers are built.
        decode_deconvs = []
        decode_convs = []
        for d in range(depth - 1, min_extract_level - 1, -1):
            kernel_size = decode_kernel_sizes[d]
            # output padding so that the deconv restores the encoder shape
            # at level d exactly (handles odd sizes)
            output_padding = layer_util.deconv_output_padding(
                input_shape=tensor_shapes[d + 1],
                output_shape=tensor_shapes[d],
                kernel_size=kernel_size,
                stride=strides,
                padding=padding,
            )
            decode_deconv = self.build_up_sampling_block(
                filters=num_channels[d],
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=tensor_shapes[d],
            )
            decode_conv = self.build_conv_block(
                filters=num_channels[d], kernel_size=kernel_size, padding=padding
            )
            decode_deconvs.append(decode_deconv)
            decode_convs.append(decode_conv)

        # Reverse so that index d maps to level d, and left-pad with None for
        # the undecoded levels; both lists end up with length depth.
        self._decode_deconvs = [None] * min_extract_level + decode_deconvs[::-1]
        self._decode_convs = [None] * min_extract_level + decode_convs[::-1]

        # extraction
        self._output_block = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build UNet graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor, unused.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # encoding / down-sampling
        # the pre-pooling output at each level is kept as a skip connection
        skips = []
        encoded = inputs
        for d in range(self._depth):
            skip = self._encode_convs[d](inputs=encoded, training=training)
            encoded = self._encode_pools[d](inputs=skip, training=training)
            skips.append(skip)

        # bottom
        decoded = self._bottom_block(inputs=encoded, training=training)

        # decoding / up-sampling
        # outs is ordered [bottom, level depth - 1, ..., level min(extract_levels)]
        outs = [decoded]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            decoded = self._decode_deconvs[d](inputs=decoded, training=training)
            # NOTE(review): build_skip_block() is invoked on every forward pass,
            # creating a fresh merge layer each time. The default Add /
            # Concatenate layers hold no weights, so this is only overhead;
            # kept to preserve behaviour for subclasses overriding the hook.
            decoded = self.build_skip_block()([decoded, skips[d]])
            decoded = self._decode_convs[d](inputs=decoded, training=training)
            outs.append(decoded)

        # output
        output = self._output_block(outs)

        return output
494