Passed
Pull Request — main (#785)
by
unknown
01:27
created

EfficientNet.build_encode_layers()   B

Complexity

Conditions 3

Size

Total Lines 63
Code Lines 41

Duplication

Lines 63
Ratio 100 %

Importance

Changes 0
Metric Value
eloc 41
dl 63
loc 63
rs 8.896
c 0
b 0
f 0
cc 3
nop 7

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a coherent, nameable part of the body into its own method.

1
# coding=utf-8

import math
from copy import deepcopy
from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl
from tensorflow.python.keras.utils import conv_utils

from deepreg.model import layer, layer_util
from deepreg.model.backbone.interface import Backbone
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY


# Compound-scaling settings per EfficientNet variant, keyed "efficientnet-b<i>".
# Each value is (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate).
EFFICIENTNET_PARAMS = {
    "efficientnet-b{}".format(index): scaling
    for index, scaling in enumerate(
        [
            (1.0, 1.0, 224, 0.2, 0.2),  # b0
            (1.0, 1.1, 240, 0.2, 0.2),  # b1
            (1.1, 1.2, 260, 0.3, 0.2),  # b2
            (1.2, 1.4, 300, 0.3, 0.2),  # b3
            (1.4, 1.8, 380, 0.4, 0.2),  # b4
            (1.6, 2.2, 456, 0.4, 0.2),  # b5
            (1.8, 2.6, 528, 0.5, 0.2),  # b6
            (2.0, 3.1, 600, 0.5, 0.2),  # b7
        ]
    )
}

# Per-stage arguments of the baseline (b0) network. "repeats" is how many times
# the stage's block is stacked (before depth scaling). The other keys are
# forwarded as keyword arguments to the block builder.
DEFAULT_BLOCKS_ARGS = [
    dict(kernel_size=3, repeats=1, filters_in=32, filters_out=16,
         expand_ratio=1, id_skip=True, strides=1, se_ratio=0.25),
    dict(kernel_size=3, repeats=2, filters_in=16, filters_out=24,
         expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25),
    dict(kernel_size=5, repeats=2, filters_in=24, filters_out=40,
         expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25),
    dict(kernel_size=3, repeats=3, filters_in=40, filters_out=80,
         expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25),
    dict(kernel_size=5, repeats=3, filters_in=80, filters_out=112,
         expand_ratio=6, id_skip=True, strides=1, se_ratio=0.25),
    dict(kernel_size=5, repeats=4, filters_in=112, filters_out=192,
         expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25),
    dict(kernel_size=3, repeats=1, filters_in=192, filters_out=320,
         expand_ratio=6, id_skip=True, strides=1, se_ratio=0.25),
]


class AffineHead(tfkl.Layer):
    """
    Regress affine parameters from features and turn them into a ddf.

    The features are flattened and fed to a dense layer with 12 units, reshaped
    to theta of shape (batch, 4, 3). The reference grid is warped with theta,
    and the displacement relative to the grid is returned as the ddf.
    """

    def __init__(
        self,
        image_size: tuple,
        name: str = "AffineHead",
        **kwargs,
    ):
        """
        Init.

        :param image_size: such as (dim1, dim2, dim3)
        :param name: name of the layer
        :param kwargs: additional arguments forwarded to tfkl.Layer. This lets
            the extra entries emitted by the base ``get_config`` (e.g.
            ``trainable``, ``dtype``) be passed back by ``from_config``.
        """
        super().__init__(name=name, **kwargs)
        # keep the user-supplied size so get_config can return a plain tuple
        self._image_size = tuple(image_size)
        self.reference_grid = layer_util.get_reference_grid(image_size)
        # bias starts at np.eye(4, 3) flattened, i.e. theta rows
        # [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]
        self.transform_initial = tf.constant_initializer(
            value=list(np.eye(4, 3).reshape(-1))
        )
        self._flatten = tfkl.Flatten()
        self._dense = tfkl.Dense(units=12, bias_initializer=self.transform_initial)

    def call(
        self, inputs: Union[tf.Tensor, List], **kwargs
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Predict theta from the inputs and convert it into a ddf.

        :param inputs: a tensor or a list of tensor with length 1
        :param kwargs: additional args
        :return: ddf and theta

            - ddf has shape (batch, dim1, dim2, dim3, 3)
            - theta has shape (batch, 4, 3)
        """
        if isinstance(inputs, list):
            inputs = inputs[0]
        theta = self._dense(self._flatten(inputs))
        theta = tf.reshape(theta, shape=(-1, 4, 3))
        # warp the reference grid with affine parameters to output a ddf
        grid_warped = layer_util.warp_grid(self.reference_grid, theta)
        ddf = grid_warped - self.reference_grid
        return ddf, theta

    def get_config(self):
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(image_size=tuple(self._image_size))
        return config


@REGISTRY.register_backbone(name="efficient_net")
class EfficientNet(Backbone):
    """
    Class that implements an Efficient-Net for image registration.

    Reference:
    - Author: Mingxing Tan, Quoc V. Le,
      EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks
      https://arxiv.org/pdf/1905.11946.pdf
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        extract_levels: Tuple = (0,),
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: Union[int, List[int]] = 3,
        decode_kernel_sizes: Union[int, List[int]] = 3,
        encode_num_channels: Optional[Tuple] = None,
        decode_num_channels: Optional[Tuple] = None,
        strides: int = 2,
        padding: str = "same",
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        default_size: int = 224,
        dropout_rate: float = 0.2,
        drop_connect_rate: float = 0.2,
        depth_divisor: int = 8,
        name: str = "EfficientNet",
        **kwargs,
    ):
        """
        Initialise EfficientNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth.
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param out_channels: number of channels for the output
        :param extract_levels: list, which levels from net to extract.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            by default it is the same as encode_num_channels
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param width_coefficient: float, scaling coefficient for network width.
        :param depth_coefficient: float, scaling coefficient for network depth.
        :param default_size: int, default input image size.
        :param dropout_rate: float, dropout rate before final classifier layer.
        :param drop_connect_rate: float, dropout rate at skip connections.
        :param depth_divisor: int divisor for depth.
        :param name: name of the backbone.
        :param kwargs: additional arguments.

        NOTE(review): concat_skip, the decode_* arguments, default_size and
        dropout_rate are stored and returned by get_config, but the current
        forward pass does not use them.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        assert max(extract_levels) <= depth
        self._extract_levels = extract_levels
        self._depth = depth

        # save extra parameters
        self._concat_skip = concat_skip
        self._pooling = pooling
        self._encode_kernel_sizes = encode_kernel_sizes
        self._decode_kernel_sizes = decode_kernel_sizes
        self._encode_num_channels = encode_num_channels
        self._decode_num_channels = decode_num_channels
        self._strides = strides
        self._padding = padding

        # efficient parameters
        self._width_coefficient = width_coefficient
        self._depth_coefficient = depth_coefficient
        self._default_size = default_size
        self._dropout_rate = dropout_rate
        self._drop_connect_rate = drop_connect_rate
        self._depth_divisor = depth_divisor
        self._activation_fn = tf.nn.swish

        # init layers
        # all lists start with d = 0
        self._encode_convs: List[tfkl.Layer] = []
        self._encode_pools: List[tfkl.Layer] = []
        self._bottom_block = None
        self._decode_convs: List[tfkl.Layer] = []
        self._output_block = None
        # efficient-net layers: created once in build_layers and reused by call,
        # so their weights persist across forward passes and are tracked
        self._stem_layers: List[tfkl.Layer] = []
        self._mb_blocks: List[dict] = []
        self._top_layers: List[tfkl.Layer] = []

        # build layers
        self.build_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            encode_kernel_sizes=encode_kernel_sizes,
            decode_kernel_sizes=decode_kernel_sizes,
            encode_num_channels=encode_num_channels,
            decode_num_channels=decode_num_channels,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )

    @staticmethod
    def _build_conv_residual_block(
        filters: int, kernel_size: int, padding: str
    ) -> tf.keras.Model:
        """Build a Conv3dBlock followed by a ResidualConv3dBlock (shared helper)."""
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_encode_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling

        This block do not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consists of one or multiple layers
        """
        return self._build_conv_residual_block(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_down_sampling_block(
        self, filters: int, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not changes the number of channels.

        :param filters: number of channels for output, arg for conv3d
        :param kernel_size: arg for pool3d or conv3d
        :param padding: arg for pool3d or conv3d
        :param strides: arg for pool3d or conv3d
        :return: a block consists of one or multiple layers
        """
        if self._pooling:
            return tfkl.MaxPool3D(
                pool_size=kernel_size, strides=strides, padding=padding
            )
        return layer.Conv3dBlock(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block do not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consists of one or multiple layers
        """
        return self._build_conv_residual_block(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of length 1.
        The output has two tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: not used
        :param out_channels: not used
        :param out_kernel_initializer: not used
        :param out_activation: not used
        :return: a block consists of one or multiple layers
        """
        return AffineHead(image_size=image_size)

    def build_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: Tuple[int, ...],
        encode_kernel_sizes: Union[int, List[int]],
        decode_kernel_sizes: Union[int, List[int]],
        encode_num_channels: Optional[Tuple],
        decode_num_channels: Optional[Tuple],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build layers that will be used in call.

        This builds the encoder layers, the efficient-net layers
        (stem, MBConv-style blocks, top) and the output block.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            by default it is the same as encode_num_channels
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        if encode_num_channels is None:
            assert num_channel_initial >= 1
            encode_num_channels = tuple(
                num_channel_initial * (2 ** d) for d in range(depth + 1)
            )
        assert len(encode_num_channels) == depth + 1
        self.build_encode_layers(
            image_size=image_size,
            num_channels=encode_num_channels,
            depth=depth,
            encode_kernel_sizes=encode_kernel_sizes,
            strides=strides,
            padding=padding,
        )
        self._build_efficient_net_layers()
        self._output_block = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    @staticmethod
    def _down_sampled_shape(tensor_shape: Tuple, strides: int, padding: str) -> Tuple:
        """Return the spatial shape after down-sampling with window size == strides."""
        return tuple(
            conv_utils.conv_output_length(
                input_length=x,
                filter_size=strides,
                padding=padding,
                stride=strides,
                dilation=1,
            )
            for x in tensor_shape
        )

    def build_encode_layers(
        self,
        image_size: Tuple,
        num_channels: Tuple,
        depth: int,
        encode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
    ) -> List[Tuple]:
        """
        Build layers for encoding.

        :param image_size: (dim1, dim2, dim3).
        :param num_channels: number of channels for each layer,
            starting from the top layer.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :return: list of tensor shapes starting from d = 0
        """
        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
        assert len(encode_kernel_sizes) == depth + 1

        # encoding / down-sampling
        self._encode_convs = []
        self._encode_pools = []
        tensor_shape = image_size
        tensor_shapes = [tensor_shape]
        for d in range(depth):
            self._encode_convs.append(
                self.build_encode_conv_block(
                    filters=num_channels[d],
                    kernel_size=encode_kernel_sizes[d],
                    padding=padding,
                )
            )
            self._encode_pools.append(
                self.build_down_sampling_block(
                    filters=num_channels[d],
                    kernel_size=strides,
                    strides=strides,
                    padding=padding,
                )
            )
            tensor_shape = self._down_sampled_shape(
                tensor_shape, strides=strides, padding=padding
            )
            tensor_shapes.append(tensor_shape)

        # bottom layer
        self._bottom_block = self.build_bottom_block(
            filters=num_channels[depth],
            kernel_size=encode_kernel_sizes[depth],
            padding=padding,
        )
        return tensor_shapes

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build compute graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: output of the output block; with AffineHead this is
            (ddf, theta).
        """
        # TODO(SicongLu): Add efficient_net based decoder.
        decoded = self.build_efficient_net(inputs=inputs, training=training)
        output = self._output_block([decoded])  # type: ignore
        return output

    def build_efficient_net(self, inputs: tf.Tensor, training=None) -> tf.Tensor:
        """
        Apply the pre-built stem, MBConv-style blocks and top layers.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, in_channels)
        :param training: None or bool, forwarded to the sub-layers
            (used by BatchNormalization and Dropout).
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, 128); the stem conv
            has stride 1 and every other conv is 1x1x1, so spatial dims are kept.
        """
        x = inputs
        for stem_layer in self._stem_layers:
            x = stem_layer(x, training=training)
        for block_layers in self._mb_blocks:
            x = self._apply_mb_block(x, block_layers, training=training)
        for top_layer in self._top_layers:
            x = top_layer(x, training=training)
        return x

    def _build_efficient_net_layers(self):
        """Create the stem, MBConv-style block and top layers once."""
        self._stem_layers = [
            # width-scaled so it matches block1's (rounded) filters_in
            tfkl.Conv3D(
                self.round_filters(32),
                3,
                strides=1,
                padding="same",
                use_bias=False,
                name="stem_conv",
            ),
            tfkl.BatchNormalization(axis=4, name="stem_bn"),
            tfkl.Activation(self._activation_fn, name="stem_activation"),
        ]

        blocks_args = deepcopy(DEFAULT_BLOCKS_ARGS)
        # total number of blocks actually built (after depth scaling); used to
        # ramp the drop-connect rate linearly from 0 up to below drop_connect_rate
        num_blocks = float(
            sum(self.round_repeats(args["repeats"]) for args in blocks_args)
        )
        mb_blocks = []
        b = 0
        for i, args in enumerate(blocks_args):
            assert args["repeats"] > 0
            args["filters_in"] = self.round_filters(args["filters_in"])
            args["filters_out"] = self.round_filters(args["filters_out"])

            for j in range(self.round_repeats(args.pop("repeats"))):
                if j > 0:
                    # only the first block of a stage may change stride/width
                    args["strides"] = 1
                    args["filters_in"] = args["filters_out"]
                mb_blocks.append(
                    self._make_mb_block_layers(
                        activation_fn=self._activation_fn,
                        drop_rate=self._drop_connect_rate * b / num_blocks,
                        name="block{}{}_".format(i + 1, chr(j + 97)),
                        **args,
                    )
                )
                b += 1
        # assign once so Keras wraps and tracks the nested layers
        self._mb_blocks = mb_blocks

        self._top_layers = [
            tfkl.Conv3D(128, 1, padding="same", use_bias=False, name="top_conv"),
            tfkl.BatchNormalization(axis=4, name="top_bn"),
            tfkl.Activation(self._activation_fn, name="top_activation"),
        ]

    def round_filters(self, filters):
        """Round number of filters based on width multiplier."""
        filters *= self._width_coefficient
        divisor = self._depth_divisor
        new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += divisor
        return int(new_filters)

    def round_repeats(self, repeats):
        """Round number of block repeats up, based on depth multiplier."""
        return int(math.ceil(self._depth_coefficient * repeats))

    @staticmethod
    def _make_mb_block_layers(
        activation_fn=tf.nn.swish,
        drop_rate=0.0,
        name="",
        filters_in=32,
        filters_out=16,
        kernel_size=3,
        strides=1,
        expand_ratio=1,
        se_ratio=0.0,
        id_skip=True,
    ) -> dict:
        """
        Create (but do not apply) the layers of one MBConv-style block.

        NOTE(review): there is no depthwise conv stage, so kernel_size is
        unused and strides only decides whether the residual add is built;
        the block never changes spatial size.

        :return: dict with keys "expand", "se", "se_excite", "project",
            "drop", "add"; absent parts are [] or None.
        """
        filters = filters_in * expand_ratio

        # Inverted residuals
        expand = []
        if expand_ratio != 1:
            expand = [
                tfkl.Conv3D(
                    filters, 1, padding="same", use_bias=False,
                    name=name + "expand_conv",
                ),
                tfkl.BatchNormalization(axis=4, name=name + "expand_bn"),
                tfkl.Activation(activation_fn, name=name + "expand_activation"),
            ]

        # squeeze-and-excitation
        squeeze_excite, se_excite = [], None
        if 0 < se_ratio <= 1:
            filters_se = max(1, int(filters_in * se_ratio))
            squeeze_excite = [
                tfkl.GlobalAveragePooling3D(name=name + "se_squeeze"),
                tfkl.Reshape((1, 1, 1, filters), name=name + "se_reshape"),
                tfkl.Conv3D(
                    filters_se, 1, padding="same", activation=activation_fn,
                    name=name + "se_reduce",
                ),
                tfkl.Conv3D(
                    filters, 1, padding="same", activation="sigmoid",
                    name=name + "se_expand",
                ),
            ]
            se_excite = tfkl.Multiply(name=name + "se_excite")

        project = [
            tfkl.Conv3D(
                filters_out, 1, padding="same", use_bias=False,
                name=name + "project_conv",
            ),
            tfkl.BatchNormalization(axis=4, name=name + "project_bn"),
        ]

        drop, add = None, None
        if id_skip is True and strides == 1 and filters_in == filters_out:
            if drop_rate > 0:
                drop = tfkl.Dropout(drop_rate, noise_shape=None, name=name + "drop")
            add = tfkl.Add(name=name + "add")

        return dict(
            expand=expand,
            se=squeeze_excite,
            se_excite=se_excite,
            project=project,
            drop=drop,
            add=add,
        )

    @staticmethod
    def _apply_mb_block(inputs, block_layers, training=None):
        """Run ``inputs`` through layers created by ``_make_mb_block_layers``."""
        x = inputs
        for expand_layer in block_layers["expand"]:
            x = expand_layer(x, training=training)
        if block_layers["se_excite"] is not None:
            se = x
            for se_layer in block_layers["se"]:
                se = se_layer(se, training=training)
            x = block_layers["se_excite"]([x, se])
        for project_layer in block_layers["project"]:
            x = project_layer(x, training=training)
        if block_layers["add"] is not None:
            if block_layers["drop"] is not None:
                x = block_layers["drop"](x, training=training)
            x = block_layers["add"]([x, inputs])
        return x

    def block(self, inputs, activation_fn=tf.nn.swish, drop_rate=0., name='',
              filters_in=32, filters_out=16, kernel_size=3, strides=1,
              expand_ratio=1, se_ratio=0., id_skip=True):
        """
        Build a new MBConv-style block and apply it to ``inputs``.

        Every call creates fresh layers. The model itself uses the layers
        pre-built in build_layers instead.

        :return: the block output tensor
        """
        block_layers = self._make_mb_block_layers(
            activation_fn=activation_fn,
            drop_rate=drop_rate,
            name=name,
            filters_in=filters_in,
            filters_out=filters_out,
            kernel_size=kernel_size,
            strides=strides,
            expand_ratio=expand_ratio,
            se_ratio=se_ratio,
            id_skip=id_skip,
        )
        return self._apply_mb_block(inputs, block_layers)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            depth=self._depth,
            extract_levels=self._extract_levels,
            pooling=self._pooling,
            concat_skip=self._concat_skip,
            encode_kernel_sizes=self._encode_kernel_sizes,
            decode_kernel_sizes=self._decode_kernel_sizes,
            encode_num_channels=self._encode_num_channels,
            decode_num_channels=self._decode_num_channels,
            strides=self._strides,
            padding=self._padding,
            width_coefficient=self._width_coefficient,
            depth_coefficient=self._depth_coefficient,
            default_size=self._default_size,
            dropout_rate=self._dropout_rate,
            drop_connect_rate=self._drop_connect_rate,
            depth_divisor=self._depth_divisor,
        )
        return config