Passed
Pull Request — main (#785)
by
unknown
01:27
created

EfficientNet.block()   B

Complexity

Conditions 7

Size

Total Lines 44
Code Lines 37

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 37
dl 0
loc 44
rs 7.592
c 0
b 0
f 0
cc 7
nop 12

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

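There are several approaches to avoid long parameter lists, for example: group related parameters into a parameter object (such as a dataclass), pass a whole object instead of several of its fields, or let the method obtain a value itself instead of receiving it as an argument.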

1
# coding=utf-8
2
3
from typing import List, Optional, Tuple, Union
4
5
import numpy as np
6
import tensorflow as tf
7
import tensorflow.keras.layers as tfkl
8
from tensorflow.python.keras.utils import conv_utils
9
10
from deepreg.model import layer, layer_util
11
from deepreg.model.backbone.interface import Backbone
12
from deepreg.model.layer import Extraction
13
from deepreg.registry import REGISTRY
14
15
# Compound-scaling presets keyed by model name. Every value is the tuple
# (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate).
EFFICIENTNET_PARAMS = {
    "efficientnet-b0": (1.0, 1.0, 224, 0.2, 0.2),
    "efficientnet-b1": (1.0, 1.1, 240, 0.2, 0.2),
    "efficientnet-b2": (1.1, 1.2, 260, 0.3, 0.2),
    "efficientnet-b3": (1.2, 1.4, 300, 0.3, 0.2),
    "efficientnet-b4": (1.4, 1.8, 380, 0.4, 0.2),
    "efficientnet-b5": (1.6, 2.2, 456, 0.4, 0.2),
    "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2),
    "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2),
}

# Base (unscaled) settings of the seven MBConv stages, one dict per stage.
# EfficientNet scales "repeats" by the depth coefficient and the filter
# counts by the width coefficient when it builds the network.
DEFAULT_BLOCKS_ARGS = [
    dict(
        kernel_size=3, repeats=1, filters_in=32, filters_out=16,
        expand_ratio=1, id_skip=True, strides=1, se_ratio=0.25,
    ),
    dict(
        kernel_size=3, repeats=2, filters_in=16, filters_out=24,
        expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25,
    ),
    dict(
        kernel_size=5, repeats=2, filters_in=24, filters_out=40,
        expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25,
    ),
    dict(
        kernel_size=3, repeats=3, filters_in=40, filters_out=80,
        expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25,
    ),
    dict(
        kernel_size=5, repeats=3, filters_in=80, filters_out=112,
        expand_ratio=6, id_skip=True, strides=1, se_ratio=0.25,
    ),
    dict(
        kernel_size=5, repeats=4, filters_in=112, filters_out=192,
        expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25,
    ),
    dict(
        kernel_size=3, repeats=1, filters_in=192, filters_out=320,
        expand_ratio=6, id_skip=True, strides=1, se_ratio=0.25,
    ),
]
43
44 View Code Duplication
class AffineHead(tfkl.Layer):
    """
    Output head that regresses a global affine transformation.

    The incoming features are flattened and mapped by a single dense layer
    to 12 parameters, reshaped to ``theta`` of shape (batch, 4, 3), which is
    then used to warp a reference grid into a dense displacement field.
    """

    def __init__(
        self,
        image_size: tuple,
        name: str = "AffineHead",
    ):
        """
        Init.

        :param image_size: such as (dim1, dim2, dim3)
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.reference_grid = layer_util.get_reference_grid(image_size)
        # np.eye(4, 3) flattened row-major: a 3x3 identity followed by a zero
        # row; it is the initial bias of the dense layer, i.e. the starting
        # theta (presumably the identity transform under warp_grid's
        # convention -- TODO confirm).
        initial_theta = np.eye(4, 3).reshape((-1))
        self.transform_initial = tf.constant_initializer(value=list(initial_theta))
        self._flatten = tfkl.Flatten()
        self._dense = tfkl.Dense(units=12, bias_initializer=self.transform_initial)

    def call(
        self, inputs: Union[tf.Tensor, List], **kwargs
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Predict the affine parameters and the displacement field they induce.

        :param inputs: a tensor or a list of tensor with length 1
        :param kwargs: additional args, not used
        :return: ddf and theta

            - ddf has shape (batch, dim1, dim2, dim3, 3)
            - theta has shape (batch, 4, 3)
        """
        features = inputs[0] if isinstance(inputs, list) else inputs
        params = self._dense(self._flatten(features))
        theta = tf.reshape(params, shape=(-1, 4, 3))
        # the ddf is the offset between the affinely warped grid and the
        # untouched reference grid
        ddf = layer_util.warp_grid(self.reference_grid, theta) - self.reference_grid
        return ddf, theta

    def get_config(self):
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        # the leading three dims of the reference grid are presumably the
        # original image_size -- TODO confirm against get_reference_grid
        config["image_size"] = self.reference_grid.shape[:3]
        return config
90
91
92
class _MBConv3dBlock(tfkl.Layer):
    """
    3D mobile inverted-bottleneck (MBConv) block used by EfficientNet.

    Structure: optional 1x1x1 expansion conv (+ batch norm, activation),
    optional squeeze-and-excitation, 1x1x1 projection conv (+ batch norm),
    and an optional residual connection preceded by dropout.

    There is no depthwise convolution in this 3D variant, so ``kernel_size``
    is accepted (to match the keys of ``DEFAULT_BLOCKS_ARGS``) but unused,
    and ``strides`` only decides whether the residual connection is applied.

    All sub-layers are created once in ``__init__`` and reused by ``call``,
    so Keras tracks their weights and every forward pass shares them.
    """

    def __init__(
        self,
        *,
        activation_fn=tf.nn.swish,
        drop_rate: float = 0.0,
        name: str = "",
        filters_in: int = 32,
        filters_out: int = 16,
        kernel_size: int = 3,
        strides: int = 1,
        expand_ratio: int = 1,
        se_ratio: float = 0.0,
        id_skip: bool = True,
    ):
        """
        Create the sub-layers of one block.

        :param activation_fn: activation used after the expansion conv and
            inside the squeeze-and-excitation branch.
        :param drop_rate: dropout rate applied before the residual addition.
        :param name: prefix for the sub-layer names, e.g. "block1a_";
            an empty string lets Keras generate the block name.
        :param filters_in: number of input channels.
        :param filters_out: number of output channels.
        :param kernel_size: unused, see class docstring.
        :param strides: the residual is only used when this equals 1.
        :param expand_ratio: channel expansion factor of the inverted residual.
        :param se_ratio: squeeze-and-excitation ratio, the branch is only
            built when 0 < se_ratio <= 1.
        :param id_skip: the residual is only used when this is True.
        """
        # The trailing "_" is only a separator for sub-layer names.
        super().__init__(name=name.rstrip("_") or None)
        filters = filters_in * expand_ratio

        # Inverted residual: widen the channels by expand_ratio.
        self._expand_block = None
        if expand_ratio != 1:
            self._expand_block = tf.keras.Sequential(
                [
                    tfkl.Conv3D(
                        filters,
                        1,
                        padding="same",
                        use_bias=False,
                        name=name + "expand_conv",
                    ),
                    tfkl.BatchNormalization(axis=4, name=name + "expand_bn"),
                    tfkl.Activation(activation_fn, name=name + "expand_activation"),
                ],
                name=name + "expand",
            )

        # Squeeze-and-excitation: per-channel gates in (0, 1).
        self._se_block = None
        self._se_excite = None
        if 0 < se_ratio <= 1:
            filters_se = max(1, int(filters_in * se_ratio))
            self._se_block = tf.keras.Sequential(
                [
                    tfkl.GlobalAveragePooling3D(name=name + "se_squeeze"),
                    tfkl.Reshape((1, 1, 1, filters), name=name + "se_reshape"),
                    tfkl.Conv3D(
                        filters_se,
                        1,
                        padding="same",
                        activation=activation_fn,
                        name=name + "se_reduce",
                    ),
                    tfkl.Conv3D(
                        filters,
                        1,
                        padding="same",
                        activation="sigmoid",
                        name=name + "se_expand",
                    ),
                ],
                name=name + "se",
            )
            self._se_excite = tfkl.Multiply(name=name + "se_excite")

        # Projection back to filters_out channels (no activation).
        self._project_block = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    filters_out,
                    1,
                    padding="same",
                    use_bias=False,
                    name=name + "project_conv",
                ),
                tfkl.BatchNormalization(axis=4, name=name + "project_bn"),
            ],
            name=name + "project",
        )

        # The residual is only valid when the block keeps the tensor shape.
        use_skip = id_skip is True and strides == 1 and filters_in == filters_out
        self._residual_drop = None
        self._residual_add = None
        if use_skip:
            if drop_rate > 0:
                # NOTE(review): noise_shape=None drops individual activations;
                # the Keras reference uses a per-sample noise shape for
                # drop-connect -- confirm which one is intended.
                self._residual_drop = tfkl.Dropout(
                    drop_rate, noise_shape=None, name=name + "drop"
                )
            self._residual_add = tfkl.Add(name=name + "add")

    def call(self, inputs: tf.Tensor, training=None) -> tf.Tensor:
        """
        Apply the block.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, filters_in)
        :param training: None or bool, forwarded to batch-norm / dropout.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, filters_out)
        """
        x = inputs
        if self._expand_block is not None:
            x = self._expand_block(x, training=training)
        if self._se_block is not None:
            gates = self._se_block(x, training=training)
            x = self._se_excite([x, gates])
        x = self._project_block(x, training=training)
        if self._residual_add is not None:
            if self._residual_drop is not None:
                x = self._residual_drop(x, training=training)
            x = self._residual_add([x, inputs])
        return x


@REGISTRY.register_backbone(name="efficient_net")
class EfficientNet(Backbone):
    """
    Class that implements an Efficient-Net for image registration.

    The input is processed by an EfficientNet-style stack of 3D MBConv
    blocks and the result is fed to :class:`AffineHead`, so the model
    outputs a ddf together with the affine parameters theta.

    Reference:
    - Author: Mingxing Tan, Quoc V. Le,
      EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks
      https://arxiv.org/pdf/1905.11946.pdf
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        extract_levels: Tuple = (0,),
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: Union[int, List[int]] = 3,
        decode_kernel_sizes: Union[int, List[int]] = 3,
        encode_num_channels: Optional[Tuple] = None,
        decode_num_channels: Optional[Tuple] = None,
        strides: int = 2,
        padding: str = "same",
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        default_size: int = 224,
        dropout_rate: float = 0.2,
        drop_connect_rate: float = 0.2,
        depth_divisor: int = 8,
        name: str = "EfficientNet",
        **kwargs,
    ):
        """
        Initialise EfficientNet.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param num_channel_initial: number of initial channels
        :param depth: input is at level 0, bottom is at level depth.
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param out_channels: number of channels for the output
        :param extract_levels: list, which levels from net to extract.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            by default it is the same as encode_num_channels
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param width_coefficient: float, scaling coefficient for network width.
        :param depth_coefficient: float, scaling coefficient for network depth.
        :param default_size: int, default input image size (stored only).
        :param dropout_rate: float, dropout rate before a final classifier
            layer; stored only, no classifier layer is built here.
        :param drop_connect_rate: float, maximum dropout rate before the
            residual additions, ramped linearly over the blocks.
        :param depth_divisor: int, filter counts are rounded to multiples of it.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        # save parameters
        assert max(extract_levels) <= depth
        self._extract_levels = extract_levels
        self._depth = depth

        # save extra parameters
        self._concat_skip = concat_skip
        self._pooling = pooling
        self._encode_kernel_sizes = encode_kernel_sizes
        self._decode_kernel_sizes = decode_kernel_sizes
        self._encode_num_channels = encode_num_channels
        self._decode_num_channels = decode_num_channels
        self._strides = strides
        self._padding = padding

        # efficient parameters
        self._width_coefficient = width_coefficient
        self._depth_coefficient = depth_coefficient
        self._default_size = default_size
        self._dropout_rate = dropout_rate
        self._drop_connect_rate = drop_connect_rate
        self._depth_divisor = depth_divisor
        self._activation_fn = tf.nn.swish

        # init layers
        # all lists start with d = 0
        self._encode_convs: List[tfkl.Layer] = []
        self._encode_pools: List[tfkl.Layer] = []
        self._bottom_block = None
        self._decode_convs: List[tfkl.Layer] = []
        self._output_block = None
        # efficient-net layers, created once and reused by every call
        self._stem_block: Optional[tfkl.Layer] = None
        self._mb_blocks: List[tfkl.Layer] = []
        self._top_block: Optional[tfkl.Layer] = None

        # build layers
        self.build_layers(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            encode_kernel_sizes=encode_kernel_sizes,
            decode_kernel_sizes=decode_kernel_sizes,
            encode_num_channels=encode_num_channels,
            decode_num_channels=decode_num_channels,
            strides=strides,
            padding=padding,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
        )
        self._build_efficient_net_layers()

    def build_encode_conv_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a conv block for down-sampling

        This block do not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consists of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_down_sampling_block(
        self, filters: int, kernel_size: int, padding: str, strides: int
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for down-sampling.

        This block changes the tensor shape (width, height, depth),
        but it does not changes the number of channels.

        :param filters: number of channels for output, arg for conv3d
        :param kernel_size: arg for pool3d or conv3d
        :param padding: arg for pool3d or conv3d
        :param strides: arg for pool3d or conv3d
        :return: a block consists of one or multiple layers
        """
        if self._pooling:
            return tfkl.MaxPool3D(
                pool_size=kernel_size, strides=strides, padding=padding
            )
        return layer.Conv3dBlock(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block do not change the tensor shape (width, height, depth),
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consists of one or multiple layers
        """
        return tf.keras.Sequential(
            [
                layer.Conv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                layer.ResidualConv3dBlock(
                    filters=filters,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
            ]
        )

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of length 1.
        The output has two tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: not used
        :param out_channels: not used
        :param out_kernel_initializer: not used
        :param out_activation: not used
        :return: a block consists of one or multiple layers
        """
        return AffineHead(image_size=image_size)

    def build_layers(
        self,
        image_size: tuple,
        num_channel_initial: int,
        depth: int,
        extract_levels: Tuple[int, ...],
        encode_kernel_sizes: Union[int, List[int]],
        decode_kernel_sizes: Union[int, List[int]],
        encode_num_channels: Optional[Tuple],
        decode_num_channels: Optional[Tuple],
        strides: int,
        padding: str,
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
    ):
        """
        Build the encoder layers and the output block.

        :param image_size: (dim1, dim2, dim3).
        :param num_channel_initial: number of initial channels.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param extract_levels: from which depths the output will be built.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param decode_kernel_sizes: kernel size for up-sampling, currently unused
        :param encode_num_channels: filters/channels for down-sampling,
            by default it is doubled at each layer during down-sampling
        :param decode_num_channels: filters/channels for up-sampling,
            currently unused
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        """
        if encode_num_channels is None:
            assert num_channel_initial >= 1
            encode_num_channels = tuple(
                num_channel_initial * (2 ** d) for d in range(depth + 1)
            )
        assert len(encode_num_channels) == depth + 1
        # NOTE(review): the encoder / bottom layers built here are not used by
        # call() yet (see the decoder TODO there).
        self.build_encode_layers(
            image_size=image_size,
            num_channels=encode_num_channels,
            depth=depth,
            encode_kernel_sizes=encode_kernel_sizes,
            strides=strides,
            padding=padding,
        )
        self._output_block = self.build_output_block(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def build_encode_layers(
        self,
        image_size: Tuple,
        num_channels: Tuple,
        depth: int,
        encode_kernel_sizes: Union[int, List[int]],
        strides: int,
        padding: str,
    ) -> List[Tuple]:
        """
        Build layers for encoding.

        :param image_size: (dim1, dim2, dim3).
        :param num_channels: number of channels for each layer,
            starting from the top layer.
        :param depth: network starts with d = 0, and the bottom has d = depth.
        :param encode_kernel_sizes: kernel size for down-sampling
        :param strides: strides for down-sampling
        :param padding: padding mode for all conv layers
        :return: list of tensor shapes starting from d = 0
        """
        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (depth + 1)
        assert len(encode_kernel_sizes) == depth + 1

        # encoding / down-sampling
        self._encode_convs = []
        self._encode_pools = []
        tensor_shape = image_size
        tensor_shapes = [tensor_shape]
        for d in range(depth):
            encode_conv = self.build_encode_conv_block(
                filters=num_channels[d],
                kernel_size=encode_kernel_sizes[d],
                padding=padding,
            )
            encode_pool = self.build_down_sampling_block(
                filters=num_channels[d],
                kernel_size=strides,
                strides=strides,
                padding=padding,
            )
            tensor_shape = tuple(
                conv_utils.conv_output_length(
                    input_length=x,
                    filter_size=strides,
                    padding=padding,
                    stride=strides,
                    dilation=1,
                )
                for x in tensor_shape
            )
            self._encode_convs.append(encode_conv)
            self._encode_pools.append(encode_pool)
            tensor_shapes.append(tensor_shape)

        # bottom layer
        self._bottom_block = self.build_bottom_block(
            filters=num_channels[depth],
            kernel_size=encode_kernel_sizes[depth],
            padding=padding,
        )
        return tensor_shapes

    def _build_efficient_net_layers(self):
        """
        Create the stem, the MBConv blocks and the top layers.

        They are created once here instead of inside call() so that Keras
        tracks their weights and every forward pass reuses the same layers.
        """
        import copy  # the module header does not import copy

        # The stem width is rounded like the block widths so that it always
        # matches the first block's rounded filters_in.
        self._stem_block = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    self.round_filters(32),
                    3,
                    strides=1,
                    padding="same",
                    use_bias=False,
                    name="stem_conv",
                ),
                tfkl.BatchNormalization(axis=4, name="stem_bn"),
                tfkl.Activation(self._activation_fn, name="stem_activation"),
            ],
            name="stem",
        )

        # deepcopy: the loop below mutates the per-stage dicts
        blocks_args = copy.deepcopy(DEFAULT_BLOCKS_ARGS)
        # total number of blocks actually built, so the drop rate ramps
        # linearly from 0 towards drop_connect_rate over the whole stack
        num_blocks_total = float(
            sum(self.round_repeats(args["repeats"]) for args in blocks_args)
        )
        mb_blocks = []
        for stage, args in enumerate(blocks_args):
            assert args["repeats"] > 0
            args["filters_in"] = self.round_filters(args["filters_in"])
            args["filters_out"] = self.round_filters(args["filters_out"])
            for repeat in range(self.round_repeats(args.pop("repeats"))):
                if repeat > 0:
                    # only the first block of a stage changes strides/channels
                    args["strides"] = 1
                    args["filters_in"] = args["filters_out"]
                drop_rate = self._drop_connect_rate * len(mb_blocks) / num_blocks_total
                mb_blocks.append(
                    _MBConv3dBlock(
                        activation_fn=self._activation_fn,
                        drop_rate=drop_rate,
                        name="block{}{}_".format(stage + 1, chr(repeat + 97)),
                        **args,
                    )
                )
        self._mb_blocks = mb_blocks

        self._top_block = tf.keras.Sequential(
            [
                tfkl.Conv3D(
                    128, 1, padding="same", use_bias=False, name="top_conv"
                ),
                tfkl.BatchNormalization(axis=4, name="top_bn"),
                tfkl.Activation(self._activation_fn, name="top_activation"),
            ],
            name="top",
        )

    def call(
        self, inputs: tf.Tensor, training=None, mask=None
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Build compute graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool, forwarded to batch-norm / dropout layers.
        :param mask: None or tf.Tensor, not used.
        :return: (ddf, theta) from the affine output head,
            ddf has shape (batch, dim1, dim2, dim3, 3), theta (batch, 4, 3)
        """
        # TODO(SicongLu): add an efficient_net based decoder; the encoder and
        # bottom layers created in build_layers are not used yet.
        decoded = self.build_efficient_net(inputs=inputs, training=training)
        outs = [decoded]
        return self._output_block(outs)  # type: ignore

    def build_efficient_net(self, inputs: tf.Tensor, training=None) -> tf.Tensor:
        """
        Apply the efficient-net layers created in __init__.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, in_channels)
        :param training: None or bool, forwarded to batch-norm / dropout layers.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, 128)
        """
        # NOTE(review): unlike the reference EfficientNet there is no depthwise
        # conv, so no layer here changes the spatial size and AffineHead
        # flattens a full-resolution 128-channel tensor -- confirm intended.
        x = self._stem_block(inputs, training=training)
        for mb_block in self._mb_blocks:
            x = mb_block(x, training=training)
        return self._top_block(x, training=training)

    def round_filters(self, filters: float) -> int:
        """
        Round number of filters based on the width coefficient.

        The scaled value is rounded to the nearest multiple of depth_divisor
        (at least depth_divisor), bumped up one step if rounding down lost
        more than 10%.
        """
        filters *= self._width_coefficient
        divisor = self._depth_divisor
        new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += divisor
        return int(new_filters)

    def round_repeats(self, repeats: int) -> int:
        """Round number of block repeats up based on the depth coefficient."""
        import math  # the module header does not import math

        return int(math.ceil(self._depth_coefficient * repeats))

    def block(
        self,
        inputs,
        activation_fn=tf.nn.swish,
        drop_rate=0.0,
        name="",
        filters_in=32,
        filters_out=16,
        kernel_size=3,
        strides=1,
        expand_ratio=1,
        se_ratio=0.0,
        id_skip=True,
    ):
        """
        Create a new MBConv block and apply it to ``inputs``.

        The keyword arguments mirror the keys of ``DEFAULT_BLOCKS_ARGS``.
        Every call creates NEW layers, so this is only suitable for one-off
        functional-style graph building; call() uses the blocks created once
        in _build_efficient_net_layers instead.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, filters_in)
        :param activation_fn: activation of the expansion / SE layers.
        :param drop_rate: dropout rate before the residual addition.
        :param name: prefix for the layer names, e.g. "block1a_".
        :param filters_in: number of input channels.
        :param filters_out: number of output channels.
        :param kernel_size: unused (no depthwise conv in this 3D variant).
        :param strides: residual is only used when this equals 1.
        :param expand_ratio: channel expansion factor.
        :param se_ratio: squeeze-and-excitation ratio, used if in (0, 1].
        :param id_skip: residual is only used when this is True.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, filters_out)
        """
        mb_block = _MBConv3dBlock(
            activation_fn=activation_fn,
            drop_rate=drop_rate,
            name=name,
            filters_in=filters_in,
            filters_out=filters_out,
            kernel_size=kernel_size,
            strides=strides,
            expand_ratio=expand_ratio,
            se_ratio=se_ratio,
            id_skip=id_skip,
        )
        return mb_block(inputs)

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        config.update(
            depth=self._depth,
            extract_levels=self._extract_levels,
            pooling=self._pooling,
            concat_skip=self._concat_skip,
            encode_kernel_sizes=self._encode_kernel_sizes,
            decode_kernel_sizes=self._decode_kernel_sizes,
            encode_num_channels=self._encode_num_channels,
            decode_num_channels=self._decode_num_channels,
            strides=self._strides,
            padding=self._padding,
            width_coefficient=self._width_coefficient,
            depth_coefficient=self._depth_coefficient,
            default_size=self._default_size,
            dropout_rate=self._dropout_rate,
            drop_connect_rate=self._drop_connect_rate,
            depth_divisor=self._depth_divisor,
        )
        return config
599