Passed
Pull Request — main (#785)
by
unknown
01:24
created

efficient_net.EfficientNet.block()   B

Complexity

Conditions 7

Size

Total Lines 70
Code Lines 44

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 44
dl 0
loc 70
rs 7.424
c 0
b 0
f 0
cc 7
nop 12

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method.

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists: Replace Parameter with Method Call, Introduce Parameter Object, and Preserve Whole Object.

1
"""This script provides an example of using efficient for training."""
2
import os
3
import math
4
import tensorflow as tf
5
import numpy as np
6
from copy import deepcopy
7
from tensorflow import keras
8
from tensorflow.keras import backend
9
from tensorflow.keras import layers
10
from tensorflow.keras.models import Model
11
from tensorflow.keras.applications import imagenet_utils
12
from tensorflow.keras.applications.imagenet_utils import decode_predictions
13
from tensorflow.keras.preprocessing import image
14
15
from deepreg.model.backbone import Backbone
16
from deepreg.registry import REGISTRY
17
from deepreg.train import train
18
19
# Scaling presets for the EfficientNet family, keyed by model name.
# Each value is the tuple
# (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate).
efficientnet_params = {
    "efficientnet-b0": (1.0, 1.0, 224, 0.2, 0.2),
    "efficientnet-b1": (1.0, 1.1, 240, 0.2, 0.2),
    "efficientnet-b2": (1.1, 1.2, 260, 0.3, 0.2),
    "efficientnet-b3": (1.2, 1.4, 300, 0.3, 0.2),
    "efficientnet-b4": (1.4, 1.8, 380, 0.4, 0.2),
    "efficientnet-b5": (1.6, 2.2, 456, 0.4, 0.2),
    "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2),
    "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2),
}
30
31
32
# Per-stage block parameters. Every entry describes one stage: the keyword
# arguments of a single block plus 'repeats', the number of times the block
# is stacked before width/depth scaling is applied.
DEFAULT_BLOCKS_ARGS = [
    {
        "kernel_size": 3,
        "repeats": 1,
        "filters_in": 32,
        "filters_out": 16,
        "expand_ratio": 1,
        "id_skip": True,
        "strides": 1,
        "se_ratio": 0.25,
    },
    {
        "kernel_size": 3,
        "repeats": 2,
        "filters_in": 16,
        "filters_out": 24,
        "expand_ratio": 6,
        "id_skip": True,
        "strides": 2,
        "se_ratio": 0.25,
    },
    {
        "kernel_size": 5,
        "repeats": 2,
        "filters_in": 24,
        "filters_out": 40,
        "expand_ratio": 6,
        "id_skip": True,
        "strides": 2,
        "se_ratio": 0.25,
    },
    {
        "kernel_size": 3,
        "repeats": 3,
        "filters_in": 40,
        "filters_out": 80,
        "expand_ratio": 6,
        "id_skip": True,
        "strides": 2,
        "se_ratio": 0.25,
    },
    {
        "kernel_size": 5,
        "repeats": 3,
        "filters_in": 80,
        "filters_out": 112,
        "expand_ratio": 6,
        "id_skip": True,
        "strides": 1,
        "se_ratio": 0.25,
    },
    {
        "kernel_size": 5,
        "repeats": 4,
        "filters_in": 112,
        "filters_out": 192,
        "expand_ratio": 6,
        "id_skip": True,
        "strides": 2,
        "se_ratio": 0.25,
    },
    {
        "kernel_size": 3,
        "repeats": 1,
        "filters_in": 192,
        "filters_out": 320,
        "expand_ratio": 6,
        "id_skip": True,
        "strides": 1,
        "se_ratio": 0.25,
    },
]
49
50
# Serialized Keras initializer config used as kernel_initializer for the
# convolution layers in this file.
CONV_KERNEL_INITIALIZER = {
    "class_name": "VarianceScaling",
    "config": {
        "scale": 2.0,
        "mode": "fan_out",
        "distribution": "normal",
    },
}
59
60
# Serialized Keras initializer config intended for dense layers.
DENSE_KERNEL_INITIALIZER = {
    "class_name": "VarianceScaling",
    "config": {
        "scale": 1.0 / 3.0,
        "mode": "fan_out",
        "distribution": "uniform",
    },
}
68
69
70
@REGISTRY.register_backbone(name="efficient_net")
class EfficientNet(Backbone):
    """
    EfficientNet-style backbone assembled from 3D Keras layers.

    For demonstration purpose only: the depthwise convolution of each block
    is not implemented yet (see the TODO in ``block``).
    """

    # Channel axis of the (batch, dim1, dim2, dim3, channels) tensors
    # that the batch-normalization layers operate on.
    _BN_AXIS = 4

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "EfficientNet",
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        default_size: int = 224,
        dropout_rate: float = 0.2,
        drop_connect_rate: float = 0.2,
        depth_divisor: int = 8,
        **kwargs,
    ):
        """
        Init.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param name: name of the backbone
        :param width_coefficient: multiplier applied to filter counts,
            see round_filters
        :param depth_coefficient: multiplier applied to block repeats,
            see round_repeats
        :param default_size: stored on the instance, not read by this class
        :param dropout_rate: stored on the instance, not read by this class
        :param drop_connect_rate: scale of the dropout rate applied in
            residual blocks; the rate grows linearly with the block index
            and stays below this value
        :param depth_divisor: filter counts are rounded to a multiple of
            this value
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.default_size = default_size
        self.dropout_rate = dropout_rate
        self.drop_connect_rate = drop_connect_rate
        self.depth_divisor = depth_divisor
        self.activation_fn = tf.nn.swish

    def correct_pad(self, inputs, kernel_size):
        """
        Compute the zero padding for a strided 3D convolution.

        :param inputs: tensor, axes 1 to 3 are read as the spatial dims
        :param kernel_size: int, or a sequence of three ints
        :return: tuple of three (before, after) padding pairs,
            one pair per spatial axis
        """
        spatial_shape = backend.int_shape(inputs)[1:4]

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)

        # Decide per axis: an unknown (None) size gets the same adjustment
        # as an even size. Checking every axis avoids `None % 2` when only
        # some of the spatial dims are unknown.
        adjust = tuple(
            1 if size is None else 1 - size % 2 for size in spatial_shape
        )
        correct = tuple(size // 2 for size in kernel_size[:3])

        return tuple((c - a, c) for c, a in zip(correct, adjust))

    def _expand(self, inputs, filters, activation_fn, name):
        """Expansion phase: 1x1x1 convolution, batch norm and activation."""
        x = layers.Conv3D(
            filters,
            1,
            padding="same",
            use_bias=False,
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            name=name + "expand_conv",
        )(inputs)
        x = layers.BatchNormalization(
            axis=self._BN_AXIS, name=name + "expand_bn"
        )(x)
        return layers.Activation(
            activation_fn, name=name + "expand_activation"
        )(x)

    def _squeeze_excite(self, x, filters, filters_se, activation_fn, name):
        """Squeeze-and-excitation: multiply x by a per-channel sigmoid gate."""
        se = layers.GlobalAveragePooling3D(name=name + "se_squeeze")(x)
        se = layers.Reshape((1, 1, 1, filters), name=name + "se_reshape")(se)
        se = layers.Conv3D(
            filters_se,
            1,
            padding="same",
            activation=activation_fn,
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            name=name + "se_reduce",
        )(se)
        se = layers.Conv3D(
            filters,
            1,
            padding="same",
            activation="sigmoid",
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            name=name + "se_expand",
        )(se)
        return layers.multiply([x, se], name=name + "se_excite")

    def block(self, inputs, activation_fn=tf.nn.swish, drop_rate=0., name='',
              filters_in=32, filters_out=16, kernel_size=3, strides=1,
              expand_ratio=1, se_ratio=0., id_skip=True):
        """
        Build one inverted-residual block on top of inputs.

        :param inputs: input tensor
        :param activation_fn: activation used after the batch norms
        :param drop_rate: dropout rate applied before the residual addition
        :param name: prefix for the names of all layers of the block
        :param filters_in: number of input filters
        :param filters_out: number of output filters
        :param kernel_size: currently unused, reserved for the depthwise
            convolution that is not implemented yet
        :param strides: currently only used to decide whether the residual
            connection is added
        :param expand_ratio: filters_in is multiplied by it in the
            expansion phase; 1 skips the expansion
        :param se_ratio: squeeze-and-excitation ratio, the SE branch is
            only built when 0 < se_ratio <= 1
        :param id_skip: the residual connection is only added when True
        :return: output tensor of the block
        """
        filters = filters_in * expand_ratio

        # Inverted residuals: widen the channels first (skipped for ratio 1).
        if expand_ratio != 1:
            x = self._expand(inputs, filters, activation_fn, name)
        else:
            x = inputs

        # TODO(Sicong): Find DepthwiseConv3D
        # The depthwise convolution (and the ZeroPadding3D based on
        # correct_pad that precedes it when strides == 2) is missing here,
        # so neither kernel_size nor strides changes the tensor for now.
        x = layers.BatchNormalization(axis=self._BN_AXIS, name=name + "bn")(x)
        x = layers.Activation(activation_fn, name=name + "activation")(x)

        if 0 < se_ratio <= 1:
            filters_se = max(1, int(filters_in * se_ratio))
            x = self._squeeze_excite(x, filters, filters_se, activation_fn, name)

        # Projection back to filters_out channels, without activation.
        x = layers.Conv3D(
            filters_out,
            1,
            padding="same",
            use_bias=False,
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            name=name + "project_conv",
        )(x)
        x = layers.BatchNormalization(
            axis=self._BN_AXIS, name=name + "project_bn"
        )(x)

        # The residual connection requires matching input/output channels.
        if id_skip is True and strides == 1 and filters_in == filters_out:
            if drop_rate > 0:
                x = layers.Dropout(
                    drop_rate, noise_shape=None, name=name + "drop"
                )(x)
            x = layers.add([x, inputs], name=name + "add")

        return x

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build the stem, the stacked blocks and the top convolution.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, in_channels)
        :param training: not used by this method
        :param mask: not used by this method
        :return: output of the top activation; it has round_filters(1280)
            channels, and the stem convolution uses strides of 2
        """
        # NOTE(review): all layers are instantiated inside call, i.e. new
        # layer objects are created on every invocation -- confirm that
        # this is intended for a Keras model.
        img_input = layers.Input(tensor=inputs, shape=self.image_size)

        # Build stem
        x = layers.Conv3D(
            self.round_filters(32),
            3,
            strides=2,
            padding="same",
            use_bias=False,
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            name="stem_conv",
        )(img_input)
        x = layers.BatchNormalization(axis=self._BN_AXIS, name="stem_bn")(x)
        x = layers.Activation(self.activation_fn, name="stem_activation")(x)

        # Work on a copy: the loop below mutates the per-stage arguments.
        blocks_args = deepcopy(DEFAULT_BLOCKS_ARGS)

        # Total number of blocks that are actually built. The repeats are
        # rounded here exactly like in the loop, so that the drop rate
        # below stays within [0, drop_connect_rate) for any
        # depth_coefficient.
        total_blocks = float(
            sum(self.round_repeats(args["repeats"]) for args in blocks_args)
        )
        block_index = 0
        for i, args in enumerate(blocks_args):
            assert args["repeats"] > 0
            args["filters_in"] = self.round_filters(args["filters_in"])
            args["filters_out"] = self.round_filters(args["filters_out"])

            for j in range(self.round_repeats(args.pop("repeats"))):
                # Only the first block of a stage may stride or change
                # the number of channels.
                if j > 0:
                    args["strides"] = 1
                    args["filters_in"] = args["filters_out"]
                x = self.block(
                    x,
                    self.activation_fn,
                    self.drop_connect_rate * block_index / total_blocks,
                    name="block{}{}_".format(i + 1, chr(j + 97)),
                    **args,
                )
                block_index += 1

        x = layers.Conv3D(
            self.round_filters(1280),
            1,
            padding="same",
            use_bias=False,
            kernel_initializer=CONV_KERNEL_INITIALIZER,
            name="top_conv",
        )(x)
        x = layers.BatchNormalization(axis=self._BN_AXIS, name="top_bn")(x)
        x = layers.Activation(self.activation_fn, name="top_activation")(x)

        # The classification head (global average pooling, top dropout and
        # a softmax dense layer) is disabled; the feature map is returned.
        # Debug output of the input and output shapes.
        print("input.shape", inputs.shape, x.shape)

        return x

    def round_filters(self, filters):
        """
        Scale a filter count by the width coefficient and round it.

        The result is a multiple of depth_divisor and at least depth_divisor.

        :param filters: unscaled number of filters
        :return: scaled and rounded number of filters
        """
        filters *= self.width_coefficient
        divisor = self.depth_divisor
        new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += divisor
        return int(new_filters)

    def round_repeats(self, repeats):
        """
        Scale a repeat count by the depth coefficient, rounding up.

        :param repeats: unscaled number of repeats
        :return: scaled number of repeats
        """
        return int(math.ceil(self.depth_coefficient * repeats))
290
291
292
# Launch training with the example configuration. This runs at module level,
# i.e. as soon as the file is executed or imported.
config_path = "examples/config_efficient_net.yaml"
train(
    gpu="",  # NOTE(review): presumably selects no GPU -- confirm in deepreg.train
    config_path=config_path,
    gpu_allow_growth=True,
    ckpt_path="",
)
299