Passed
Pull Request — main (#785)
by unknown, created 01:38

efficient_net.EfficientNet.__init__()   B

Complexity

Conditions 2

Size

Total Lines 75
Code Lines 42

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric    Value
eloc      42
dl        0
loc       75
rs        8.872
c         0
b         0
f         0
cc        2
nop       18

How to fix: Long Method, Many Parameters

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Moreover, when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

Commonly applied refactorings include Extract Method.
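As a minimal, hypothetical sketch of Extract Method (the ReportBuilder class below is illustrative only, not part of the analysed file), the commented filtering step of a longer method is pulled out into a helper whose name comes from the comment:

# Hypothetical sketch of Extract Method: the commented section of a long method
# becomes a separate, well-named helper, and the comment becomes the method name.
class ReportBuilder:
    def __init__(self, rows):
        self.rows = rows

    # Before: one method with an explanatory comment in its body.
    def build_before(self):
        # keep only non-empty rows
        rows = [r for r in self.rows if r.strip()]
        return "\n".join(rows)

    # After: the commented part is extracted into a named helper.
    def build_after(self):
        return "\n".join(self._non_empty_rows())

    def _non_empty_rows(self):
        return [r for r in self.rows if r.strip()]


print(ReportBuilder(["a", " ", "b"]).build_after())  # prints "a" and "b" on two lines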

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you need more, or different, data.

There are several approaches to avoiding long parameter lists, such as introducing a parameter object, passing a whole object instead of several of its values, or replacing a parameter with a method call.
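One common option is to introduce a parameter object. The sketch below is hypothetical: it reuses the scaling-related parameter names from the __init__ analysed above, but ScalingConfig and make_scaling_kwargs are illustrative names, not DeepReg API:

# Hypothetical sketch of "Introduce Parameter Object": the six scaling-related
# arguments of a constructor like EfficientNet.__init__ are bundled into one dataclass.
from dataclasses import asdict, dataclass


@dataclass
class ScalingConfig:
    width_coefficient: float = 1.0
    depth_coefficient: float = 1.0
    default_size: int = 224
    dropout_rate: float = 0.2
    drop_connect_rate: float = 0.2
    depth_divisor: int = 8


def make_scaling_kwargs(scaling: ScalingConfig) -> dict:
    # Callers now pass one coherent object instead of six loose parameters.
    return asdict(scaling)


print(make_scaling_kwargs(ScalingConfig(width_coefficient=1.4, depth_coefficient=1.8)))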

"""This script provides an example of using EfficientNet for training."""

import os
import math
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend, layers  # backend is needed by correct_pad below
from copy import deepcopy
from typing import List, Optional, Tuple, Union

from deepreg.model import layer
from deepreg.model.backbone import Backbone
from deepreg.model.backbone.local_net import LocalNet
from deepreg.model.backbone.u_net import UNet
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY
from deepreg.train import train

EFFICIENTNET_PARAMS = {
    # model_name: (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate)
    "efficientnet-b0": (1.0, 1.0, 224, 0.2, 0.2),
    "efficientnet-b1": (1.0, 1.1, 240, 0.2, 0.2),
    "efficientnet-b2": (1.1, 1.2, 260, 0.3, 0.2),
    "efficientnet-b3": (1.2, 1.4, 300, 0.3, 0.2),
    "efficientnet-b4": (1.4, 1.8, 380, 0.4, 0.2),
    "efficientnet-b5": (1.6, 2.2, 456, 0.4, 0.2),
    "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2),
    "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2),
}

# Parameters for each block
DEFAULT_BLOCKS_ARGS = [
    {'kernel_size': 3, 'repeats': 1, 'filters_in': 32, 'filters_out': 16,
     'expand_ratio': 1, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 2, 'filters_in': 16, 'filters_out': 24,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 2, 'filters_in': 24, 'filters_out': 40,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 3, 'filters_in': 40, 'filters_out': 80,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 3, 'filters_in': 80, 'filters_out': 112,
     'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 4, 'filters_in': 112, 'filters_out': 192,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 1, 'filters_in': 192, 'filters_out': 320,
     'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25}
]

# Two kernel initializers
CONV_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 2.0,
        'mode': 'fan_out',
        'distribution': 'normal'
    }
}

DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 1. / 3.,
        'mode': 'fan_out',
        'distribution': 'uniform'
    }
}

@REGISTRY.register_backbone(name="efficient_net")
class EfficientNet(LocalNet):
    """
    Class that implements an Efficient-Net for image registration.

    Reference:
    - Author: Mingxing Tan, Quoc V. Le,
      EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks
      https://arxiv.org/pdf/1905.11946.pdf
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int, ...],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        pooling: bool = True,
        concat_skip: bool = False,
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        default_size: int = 224,
        dropout_rate: float = 0.2,
        drop_connect_rate: float = 0.2,
        depth_divisor: int = 8,
        name: str = "EfficientNet",
        **kwargs,
    ):
        """
        Init.

        Image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions.
        :param depth: depth of the encoder.
            If depth is not given, depth = max(extract_levels) will be used.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition
        :param width_coefficient: float, scaling coefficient for network width.
        :param depth_coefficient: float, scaling coefficient for network depth.
        :param default_size: int, default input image size.
        :param dropout_rate: float, dropout rate before final classifier layer.
        :param drop_connect_rate: float, dropout rate at skip connections.
        :param depth_divisor: int divisor for depth.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        if depth is None:
            depth = max(extract_levels)
        kwargs["encode_kernel_sizes"] = [7] + [3] * depth
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            use_additive_upsampling=False,
            pooling=pooling,
            concat_skip=concat_skip,
            name=name,
            **kwargs,
        )

        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.default_size = default_size
        self.dropout_rate = dropout_rate
        self.drop_connect_rate = drop_connect_rate
        self.depth_divisor = depth_divisor
        self.activation_fn = tf.nn.swish
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build compute graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """
        # encoding / down-sampling
        skips = []
        encoded = inputs
        for d in range(self._depth):
            skip = self._encode_convs[d](inputs=encoded, training=training)
            encoded = self._encode_pools[d](inputs=skip, training=training)
            skips.append(skip)

        # bottom
        decoded = self.build_efficient_net(inputs=encoded, training=training)  # type: ignore

        # decoding / up-sampling
        outs = [decoded]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            decoded = self._decode_deconvs[d](inputs=decoded, training=training)
            decoded = self.build_skip_block()([decoded, skips[d]])
            decoded = self._decode_convs[d](inputs=decoded, training=training)
            outs = [decoded] + outs

        # output
        output = self._output_block(outs)  # type: ignore

        return output

    def build_efficient_net(self, inputs: tf.Tensor, training=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, in_channels)
        :param training: None or bool.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """
        img_input = layers.Input(tensor=inputs, shape=self.image_size)
        bn_axis = 4
        x = img_input
        # x = layers.ZeroPadding3D(padding=self.correct_pad(x, 3),
        #                         name='stem_conv_pad')(x)

        x = layers.Conv3D(self.round_filters(32), 3,
                        strides=1,
                        padding='same',
                        use_bias=False,
                        kernel_initializer=CONV_KERNEL_INITIALIZER,
                        name='stem_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
        x = layers.Activation(self.activation_fn, name='stem_activation')(x)
        blocks_args = deepcopy(DEFAULT_BLOCKS_ARGS)

        b = 0
        # Calculate the total number of blocks
        blocks = float(sum(args['repeats'] for args in blocks_args))
        for (i, args) in enumerate(blocks_args):
            assert args['repeats'] > 0
            args['filters_in'] = self.round_filters(args['filters_in'])
            args['filters_out'] = self.round_filters(args['filters_out'])

            for j in range(self.round_repeats(args.pop('repeats'))):
                if j > 0:
                    args['strides'] = 1
                    args['filters_in'] = args['filters_out']
                x = self.block(x, self.activation_fn, self.drop_connect_rate * b / blocks,
                        name='block{}{}_'.format(i + 1, chr(j + 97)), **args)
                b += 1

        x = layers.Conv3D(self.round_filters(128), 1,
                        padding='same',
                        use_bias=False,
                        kernel_initializer=CONV_KERNEL_INITIALIZER,
                        name='top_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
        x = layers.Activation(self.activation_fn, name='top_activation')(x)

        print("input.shape", inputs.shape, x.shape)
        return x

    def round_filters(self, filters):
        """Round number of filters based on the width multiplier."""
        filters *= self.width_coefficient
        divisor = self.depth_divisor
        new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
        # Make sure that rounding down does not reduce filters by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += divisor
        return int(new_filters)

    def round_repeats(self, repeats):
        """Round number of block repeats based on the depth multiplier."""
        return int(math.ceil(self.depth_coefficient * repeats))

    def correct_pad(self, inputs, kernel_size):
        """Return a tuple of 3D zero-padding for a convolution with downsampling."""
        img_dim = 1
        input_size = backend.int_shape(inputs)[img_dim:(img_dim + 3)]

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)

        if input_size[0] is None:
            adjust = (1, 1, 1)
        else:
            adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2, 1 - input_size[2] % 2)

        correct = (kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2)

        return ((correct[0] - adjust[0], correct[0]),
                (correct[1] - adjust[1], correct[1]),
                (correct[2] - adjust[2], correct[2]))

    def block(self, inputs, activation_fn=tf.nn.swish, drop_rate=0., name='',
            filters_in=32, filters_out=16, kernel_size=3, strides=1,
            expand_ratio=1, se_ratio=0., id_skip=True):
        """Build a 3D inverted-residual block with optional squeeze-and-excitation."""
        bn_axis = 4

        filters = filters_in * expand_ratio

        # Inverted residuals
        if expand_ratio != 1:
            x = layers.Conv3D(filters, 1,
                            padding='same',
                            use_bias=False,
                            kernel_initializer=CONV_KERNEL_INITIALIZER,
                            name=name + 'expand_conv')(inputs)
            x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
            x = layers.Activation(activation_fn, name=name + 'expand_activation')(x)
        else:
            x = inputs

        # padding
        # if strides == 2:
        #     x = layers.ZeroPadding3D(padding=self.correct_pad(x, kernel_size),
        #                             name=name + 'dwconv_pad')(x)
        #     conv_pad = 'valid'
        # else:
        #     conv_pad = 'same'

        # TODO(Sicong): Find DepthwiseConv3D
        # x = layers.DepthwiseConv2D(kernel_size,
        #                         strides=strides,
        #                         padding=conv_pad,
        #                         use_bias=False,
        #                         depthwise_initializer=CONV_KERNEL_INITIALIZER,
        #                         name=name + 'dwconv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
        x = layers.Activation(activation_fn, name=name + 'activation')(x)

        if 0 < se_ratio <= 1:
            filters_se = max(1, int(filters_in * se_ratio))
            se = layers.GlobalAveragePooling3D(name=name + 'se_squeeze')(x)
            se = layers.Reshape((1, 1, 1, filters), name=name + 'se_reshape')(se)
            se = layers.Conv3D(filters_se, 1,
                            padding='same',
                            activation=activation_fn,
                            kernel_initializer=CONV_KERNEL_INITIALIZER,
                            name=name + 'se_reduce')(se)
            se = layers.Conv3D(filters, 1,
                            padding='same',
                            activation='sigmoid',
                            kernel_initializer=CONV_KERNEL_INITIALIZER,
                            name=name + 'se_expand')(se)
            x = layers.multiply([x, se], name=name + 'se_excite')

        x = layers.Conv3D(filters_out, 1,
                        padding='same',
                        use_bias=False,
                        kernel_initializer=CONV_KERNEL_INITIALIZER,
                        name=name + 'project_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)

        if id_skip and strides == 1 and filters_in == filters_out:
            if drop_rate > 0:
                x = layers.Dropout(drop_rate,
                                noise_shape=None,
                                name=name + 'drop')(x)
            x = layers.add([x, inputs], name=name + 'add')

        return x

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        return config

if __name__ == "__main__":
    config_path = "examples/config_efficient_net.yaml"
    train(
        gpu="",
        config_path=config_path,
        gpu_allow_growth=True,
        ckpt_path="",
    )