Passed — Pull Request on main (#785), created by unknown, 01:28
efficient_net.EfficientNet.call() — rated A

Complexity
    Conditions: 4

Size
    Total Lines: 62
    Code Lines: 35

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
eloc     35
dl       0
loc      62
rs       9.0399
c        0
b        0
f        0
cc       4
nop      4

How to fix: Long Method

Small methods make your code easier to understand, especially when combined with a good name; and if a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment as a starting point for naming it.

Commonly applied refactorings include Extract Method, sketched below.
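For instance, a minimal sketch of Extract Method (the names below are hypothetical, not taken from the analyzed file):

    # Before: a long method whose steps need explanatory comments.
    def preprocess(self, image):
        # normalize intensities to [0, 1]
        image = (image - image.min()) / (image.max() - image.min())
        return image

    # After: the commented step becomes a helper named after its comment.
    def preprocess(self, image):
        return self.normalize_intensities(image)

    def normalize_intensities(self, image):
        return (image - image.min()) / (image.max() - image.min())

The analyzed source follows: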

"""This script provides an example of using EfficientNet for training."""
import math
from copy import deepcopy

import tensorflow as tf
from tensorflow.keras import backend
from tensorflow.keras import layers

from deepreg.model.backbone import Backbone
from deepreg.registry import REGISTRY
from deepreg.train import train

efficientnet_params = {
    # model_name: (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate)
    "efficientnet-b0": (1.0, 1.0, 224, 0.2, 0.2),
    "efficientnet-b1": (1.0, 1.1, 240, 0.2, 0.2),
    "efficientnet-b2": (1.1, 1.2, 260, 0.3, 0.2),
    "efficientnet-b3": (1.2, 1.4, 300, 0.3, 0.2),
    "efficientnet-b4": (1.4, 1.8, 380, 0.4, 0.2),
    "efficientnet-b5": (1.6, 2.2, 456, 0.4, 0.2),
    "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2),
    "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2),
}
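
# For reference, one entry unpacks in the order given by the comment above
# (illustration only, not part of the original script):
#     width_mult, depth_mult, image_size, dropout, dropconnect = \
#         efficientnet_params["efficientnet-b2"]
#     # -> (1.1, 1.2, 260, 0.3, 0.2)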

# Parameters for each of the seven block groups
DEFAULT_BLOCKS_ARGS = [
    {'kernel_size': 3, 'repeats': 1, 'filters_in': 32, 'filters_out': 16,
     'expand_ratio': 1, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 2, 'filters_in': 16, 'filters_out': 24,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 2, 'filters_in': 24, 'filters_out': 40,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 3, 'filters_in': 40, 'filters_out': 80,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 3, 'filters_in': 80, 'filters_out': 112,
     'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 4, 'filters_in': 112, 'filters_out': 192,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 1, 'filters_in': 192, 'filters_out': 320,
     'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25}
]

# Kernel initializers for the convolutional and dense layers
CONV_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 2.0,
        'mode': 'fan_out',
        'distribution': 'normal'
    }
}

DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 1. / 3.,
        'mode': 'fan_out',
        'distribution': 'uniform'
    }
}
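
# Keras deserializes these dicts when they are passed as kernel_initializer.
# Equivalently, as a quick check (not part of the original script):
#     tf.keras.initializers.get(CONV_KERNEL_INITIALIZER)  # VarianceScaling instance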

@REGISTRY.register_backbone(name="efficient_net")
class EfficientNet(Backbone):
    """
    A 3D EfficientNet backbone, for demonstration purposes only.
    """

    def __init__(
        self,
        image_size: tuple,
        out_channels: int,
        num_channel_initial: int,
        out_kernel_initializer: str,
        out_activation: str,
        name: str = "EfficientNet",
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        default_size: int = 224,
        dropout_rate: float = 0.2,
        drop_connect_rate: float = 0.2,
        depth_divisor: int = 8,
        **kwargs,
    ):
        """
        Init.

        :param image_size: (dim1, dim2, dim3), dims of input image.
        :param out_channels: number of channels for the output
        :param num_channel_initial: number of initial channels
        :param out_kernel_initializer: kernel initializer for the last layer
        :param out_activation: activation at the last layer
        :param name: name of the backbone
        :param width_coefficient: scaling coefficient for the number of filters
        :param depth_coefficient: scaling coefficient for the number of block repeats
        :param default_size: default input image size
        :param dropout_rate: dropout rate of the (currently disabled) top dropout layer
        :param drop_connect_rate: maximum drop rate for block skip connections
        :param depth_divisor: filter counts are rounded to multiples of this value
        :param kwargs: additional arguments.
        """
        super().__init__(
            image_size=image_size,
            out_channels=out_channels,
            num_channel_initial=num_channel_initial,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            name=name,
            **kwargs,
        )

        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.default_size = default_size
        self.dropout_rate = dropout_rate
        self.drop_connect_rate = drop_connect_rate
        self.depth_divisor = depth_divisor
        self.activation_fn = tf.nn.swish

    def correct_pad(self, inputs, kernel_size):
        """Return the zero-padding needed before a strided 3D convolution."""
        img_dim = 1
        input_size = backend.int_shape(inputs)[img_dim:(img_dim + 3)]

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)

        if input_size[0] is None:
            adjust = (1, 1, 1)
        else:
            adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2, 1 - input_size[2] % 2)

        correct = (kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2)

        return ((correct[0] - adjust[0], correct[0]),
                (correct[1] - adjust[1], correct[1]),
                (correct[2] - adjust[2], correct[2]))
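
    # Worked by hand from correct_pad above (a sanity check, not original code):
    #     kernel_size=3 -> correct = (1, 1, 1)
    #     even spatial dim (e.g. 64): adjust = 1 -> pad (0, 1) on that axis
    #     odd spatial dim  (e.g. 65): adjust = 0 -> pad (1, 1) on that axis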

    def block(self, inputs, activation_fn=tf.nn.swish, drop_rate=0., name='',
              filters_in=32, filters_out=16, kernel_size=3, strides=1,
              expand_ratio=1, se_ratio=0., id_skip=True):
        """A single mobile inverted residual (MBConv) block."""
        bn_axis = 4

        filters = filters_in * expand_ratio

        # Expansion phase of the inverted residual
        if expand_ratio != 1:
            x = layers.Conv3D(filters, 1,
                              padding='same',
                              use_bias=False,
                              kernel_initializer=CONV_KERNEL_INITIALIZER,
                              name=name + 'expand_conv')(inputs)
            x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
            x = layers.Activation(activation_fn, name=name + 'expand_activation')(x)
        else:
            x = inputs

        # padding
        # if strides == 2:
        #     x = layers.ZeroPadding3D(padding=self.correct_pad(x, kernel_size),
        #                             name=name + 'dwconv_pad')(x)
        #     conv_pad = 'valid'
        # else:
        #     conv_pad = 'same'

        # TODO(Sicong): Find DepthwiseConv3D
        # x = layers.DepthwiseConv2D(kernel_size,
        #                         strides=strides,
        #                         padding=conv_pad,
        #                         use_bias=False,
        #                         depthwise_initializer=CONV_KERNEL_INITIALIZER,
        #                         name=name + 'dwconv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
        x = layers.Activation(activation_fn, name=name + 'activation')(x)

        # Squeeze-and-excitation phase
        if 0 < se_ratio <= 1:
            filters_se = max(1, int(filters_in * se_ratio))
            se = layers.GlobalAveragePooling3D(name=name + 'se_squeeze')(x)
            se = layers.Reshape((1, 1, 1, filters), name=name + 'se_reshape')(se)
            se = layers.Conv3D(filters_se, 1,
                               padding='same',
                               activation=activation_fn,
                               kernel_initializer=CONV_KERNEL_INITIALIZER,
                               name=name + 'se_reduce')(se)
            se = layers.Conv3D(filters, 1,
                               padding='same',
                               activation='sigmoid',
                               kernel_initializer=CONV_KERNEL_INITIALIZER,
                               name=name + 'se_expand')(se)
            x = layers.multiply([x, se], name=name + 'se_excite')

        # Projection phase: project back to filters_out channels
        x = layers.Conv3D(filters_out, 1,
                          padding='same',
                          use_bias=False,
                          kernel_initializer=CONV_KERNEL_INITIALIZER,
                          name=name + 'project_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)

        # Identity skip connection, with optional drop-connect
        if id_skip and strides == 1 and filters_in == filters_out:
            if drop_rate > 0:
                x = layers.Dropout(drop_rate,
                                   noise_shape=None,
                                   name=name + 'drop')(x)
            x = layers.add([x, inputs], name=name + 'add')

        return x

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, in_channels)
        :param training: training flag passed to layers such as BatchNormalization
        :param mask: unused
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """
        img_input = layers.Input(tensor=inputs, shape=self.image_size)
        bn_axis = 4
        # Build stem
        x = img_input
        # x = layers.ZeroPadding3D(padding=self.correct_pad(x, 3),
        #                         name='stem_conv_pad')(x)

        x = layers.Conv3D(self.round_filters(32), 3,
                          strides=2,
                          padding='same',
                          use_bias=False,
                          kernel_initializer=CONV_KERNEL_INITIALIZER,
                          name='stem_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
        x = layers.Activation(self.activation_fn, name='stem_activation')(x)
        blocks_args = deepcopy(DEFAULT_BLOCKS_ARGS)

        b = 0
        # Total number of blocks, used to scale the per-block drop-connect rate
        blocks = float(sum(args['repeats'] for args in blocks_args))
        for (i, args) in enumerate(blocks_args):
            assert args['repeats'] > 0
            args['filters_in'] = self.round_filters(args['filters_in'])
            args['filters_out'] = self.round_filters(args['filters_out'])

            for j in range(self.round_repeats(args.pop('repeats'))):
                if j > 0:
                    args['strides'] = 1
                    args['filters_in'] = args['filters_out']
                # The drop-connect rate grows linearly from 0 to
                # self.drop_connect_rate over the sequence of blocks.
                x = self.block(x, self.activation_fn, self.drop_connect_rate * b / blocks,
                               name='block{}{}_'.format(i + 1, chr(j + 97)), **args)
                b += 1

        x = layers.Conv3D(self.round_filters(1280), 1,
                          padding='same',
                          use_bias=False,
                          kernel_initializer=CONV_KERNEL_INITIALIZER,
                          name='top_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
        x = layers.Activation(self.activation_fn, name='top_activation')(x)

        # Use GlobalAveragePooling3D in place of a fully-connected layer
        # x = layers.GlobalAveragePooling3D(name='avg_pool')(x)
        print("input.shape", inputs.shape, x.shape)
        # if dropout_rate > 0:
        #     x = layers.Dropout(dropout_rate, name='top_dropout')(x)

        # x = layers.Dense(classes,
        #                     activation='softmax',
        #                     kernel_initializer=DENSE_KERNEL_INITIALIZER,
        #                     name='probs')(x)

        return x
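
    # The report above flags call() as a Long Method (62 lines, 35 of code).
    # A hedged sketch of how Extract Method could apply, with hypothetical
    # helper names that are not part of the analyzed file:
    #
    #     def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
    #         x = self._build_stem(inputs)    # stem conv + bn + activation
    #         x = self._build_blocks(x)       # the MBConv block loop
    #         return self._build_top(x)       # top conv + bn + activation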

    # Ensure the filter count is divisible by the depth divisor (8 by default).
    def round_filters(self, filters):
        """Round the number of filters based on the width multiplier."""
        filters *= self.width_coefficient
        divisor = self.depth_divisor
        new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
        # Make sure that rounding down does not reduce filters by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += divisor
        return int(new_filters)

    # Scale the number of repeats by the depth multiplier, rounding up.
    def round_repeats(self, repeats):
        """Round the number of block repeats based on the depth multiplier."""
        return int(math.ceil(self.depth_coefficient * repeats))
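
# A quick sanity check of the scaling rules above, worked by hand and assuming
# the efficientnet-b2 settings (width_coefficient=1.1, depth_coefficient=1.2,
# depth_divisor=8); illustration only, not part of the original script:
#     round_filters(32): 32 * 1.1 = 35.2 -> int(35.2 + 4) // 8 * 8 = 32,
#                        kept as is, since 32 >= 0.9 * 35.2
#     round_repeats(2):  ceil(1.2 * 2) = 3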

config_path = "examples/config_efficient_net.yaml"
train(
    gpu="",
    config_path=config_path,
    gpu_allow_growth=True,
    ckpt_path="",
)