Passed · Pull Request — main (#785) · by unknown · created 01:38

efficient_net.EfficientNet.build_efficient_net()   A

Complexity
    Conditions: 4

Size
    Total Lines: 51
    Code Lines: 35

Duplication
    Lines: 0
    Ratio: 0%

Importance
    Changes: 0
Metric  Value
eloc    35       (effective lines of code)
dl      0        (duplicated lines)
loc     51       (total lines)
rs      9.0399
c       0
b       0
f       0
cc      4        (cyclomatic complexity)
nop     3        (number of parameters)

How to fix: Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for naming it.

Commonly applied refactorings include: Extract Method.
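For the method flagged above, Extract Method would split build_efficient_net() into a stem, a block loop, and a head, each behind a well-named helper. A minimal sketch of that refactoring (the helper names _build_stem, _build_blocks, and _build_top are hypothetical, not part of the codebase):

    def build_efficient_net(self, inputs: tf.Tensor, training=None) -> tf.Tensor:
        x = self._build_stem(inputs)   # stem conv + BN + activation
        x = self._build_blocks(x)      # repeated MBConv blocks
        return self._build_top(x)      # final 1x1 conv + BN + activation

    def _build_stem(self, x: tf.Tensor) -> tf.Tensor:
        x = layers.Conv3D(self.round_filters(32), 3, strides=1, padding='same',
                          use_bias=False,
                          kernel_initializer=CONV_KERNEL_INITIALIZER,
                          name='stem_conv')(x)
        x = layers.BatchNormalization(axis=4, name='stem_bn')(x)
        return layers.Activation(self.activation_fn, name='stem_activation')(x)

Each helper then fits on one screen, and its name replaces the section comment it grew out of.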

"""This script provides an example of using EfficientNet for training."""

import os
import math
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend, layers  # backend is needed by correct_pad()
from copy import deepcopy
from typing import List, Optional, Tuple, Union

from deepreg.model import layer
from deepreg.model.backbone import Backbone
from deepreg.model.backbone.local_net import LocalNet
from deepreg.model.backbone.u_net import UNet
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY
from deepreg.train import train


EFFICIENTNET_PARAMS = {
    # model_name: (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate)
    "efficientnet-b0": (1.0, 1.0, 224, 0.2, 0.2),
    "efficientnet-b1": (1.0, 1.1, 240, 0.2, 0.2),
    "efficientnet-b2": (1.1, 1.2, 260, 0.3, 0.2),
    "efficientnet-b3": (1.2, 1.4, 300, 0.3, 0.2),
    "efficientnet-b4": (1.4, 1.8, 380, 0.4, 0.2),
    "efficientnet-b5": (1.6, 2.2, 456, 0.4, 0.2),
    "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2),
    "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2),
}

# Per-block parameters for the seven MBConv stages.
DEFAULT_BLOCKS_ARGS = [
    {'kernel_size': 3, 'repeats': 1, 'filters_in': 32, 'filters_out': 16,
     'expand_ratio': 1, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 2, 'filters_in': 16, 'filters_out': 24,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 2, 'filters_in': 24, 'filters_out': 40,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 3, 'filters_in': 40, 'filters_out': 80,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 3, 'filters_in': 80, 'filters_out': 112,
     'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 4, 'filters_in': 112, 'filters_out': 192,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 1, 'filters_in': 192, 'filters_out': 320,
     'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25}
]

# Kernel initializers for convolutional and dense layers.
CONV_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 2.0,
        'mode': 'fan_out',
        'distribution': 'normal'
    }
}

DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 1. / 3.,
        'mode': 'fan_out',
        'distribution': 'uniform'
    }
}


@REGISTRY.register_backbone(name="efficient_net")
class EfficientNet(LocalNet):
    """
    Class that implements an EfficientNet backbone for image registration.

    Reference:
    - Authors: Mingxing Tan, Quoc V. Le,
      EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks
      https://arxiv.org/pdf/1905.11946.pdf
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int, ...],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        pooling: bool = True,
        concat_skip: bool = False,
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        default_size: int = 224,
        dropout_rate: float = 0.2,
        drop_connect_rate: float = 0.2,
        depth_divisor: int = 8,
        name: str = "EfficientNet",
        **kwargs,
    ):
        """
        Init.

        The image is encoded gradually, i from level 0 to D,
        then decoded gradually, j from level D back to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at the end layer.
        :param out_channels: number of channels for the extractions.
        :param depth: depth of the encoder.
            If depth is not given, depth = max(extract_levels) will be used.
        :param pooling: for down-sampling, use non-parameterized
            pooling if true, otherwise use conv3d.
        :param concat_skip: when up-sampling, concatenate skipped
            tensor if true, otherwise use addition.
        :param width_coefficient: scaling coefficient for network width.
        :param depth_coefficient: scaling coefficient for network depth.
        :param default_size: default input image size.
        :param dropout_rate: dropout rate before the final classifier layer.
        :param drop_connect_rate: dropout rate at skip connections.
        :param depth_divisor: divisor for rounding channel counts.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        if depth is None:
            depth = max(extract_levels)
        kwargs["encode_kernel_sizes"] = [7] + [3] * depth
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            use_additive_upsampling=False,
            pooling=pooling,
            concat_skip=concat_skip,
            name=name,
            **kwargs,
        )

        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.default_size = default_size
        self.dropout_rate = dropout_rate
        self.drop_connect_rate = drop_connect_rate
        self.depth_divisor = depth_divisor
        self.activation_fn = tf.nn.swish

    # NOTE: the review tool flags this method as duplicated elsewhere in the project.
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build compute graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # encoding / down-sampling
        skips = []
        encoded = inputs
        for d in range(self._depth):
            skip = self._encode_convs[d](inputs=encoded, training=training)
            encoded = self._encode_pools[d](inputs=skip, training=training)
            skips.append(skip)

        # bottom
        decoded = self.build_efficient_net(inputs=encoded, training=training)  # type: ignore

        # decoding / up-sampling
        outs = [decoded]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            decoded = self._decode_deconvs[d](inputs=decoded, training=training)
            decoded = self.build_skip_block()([decoded, skips[d]])
            decoded = self._decode_convs[d](inputs=decoded, training=training)
            outs = [decoded] + outs

        # output
        output = self._output_block(outs)  # type: ignore

        return output

    def build_efficient_net(self, inputs: tf.Tensor, training=None) -> tf.Tensor:
        """
        Build the EfficientNet bottom-block graph based on built layers.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, in_channels)
        :param training: None or bool.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """
        img_input = layers.Input(tensor=inputs, shape=self.image_size)
        bn_axis = 4
        x = img_input
        # x = layers.ZeroPadding3D(padding=self.correct_pad(x, 3),
        #                          name='stem_conv_pad')(x)

        # stem
        x = layers.Conv3D(self.round_filters(32), 3,
                          strides=1,
                          padding='same',
                          use_bias=False,
                          kernel_initializer=CONV_KERNEL_INITIALIZER,
                          name='stem_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
        x = layers.Activation(self.activation_fn, name='stem_activation')(x)
        blocks_args = deepcopy(DEFAULT_BLOCKS_ARGS)

        b = 0
        # total number of blocks, used to scale the per-block drop-connect rate
        blocks = float(sum(args['repeats'] for args in blocks_args))
        for (i, args) in enumerate(blocks_args):
            assert args['repeats'] > 0
            args['filters_in'] = self.round_filters(args['filters_in'])
            args['filters_out'] = self.round_filters(args['filters_out'])

            for j in range(self.round_repeats(args.pop('repeats'))):
                if j > 0:
                    args['strides'] = 1
                    args['filters_in'] = args['filters_out']
                x = self.block(x, self.activation_fn, self.drop_connect_rate * b / blocks,
                               name='block{}{}_'.format(i + 1, chr(j + 97)), **args)
                b += 1

        # top
        x = layers.Conv3D(self.round_filters(128), 1,
                          padding='same',
                          use_bias=False,
                          kernel_initializer=CONV_KERNEL_INITIALIZER,
                          name='top_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
        x = layers.Activation(self.activation_fn, name='top_activation')(x)

        print("input.shape", inputs.shape, x.shape)
        return x

    def round_filters(self, filters):
        """Round number of filters based on the width multiplier."""
        filters *= self.width_coefficient
        divisor = self.depth_divisor
        # round to the nearest multiple of divisor, e.g. 40 * 1.1 = 44.0 -> 48
        new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
        # Make sure that rounding down does not go down by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += divisor
        return int(new_filters)

    def round_repeats(self, repeats):
        """Round number of block repeats based on the depth multiplier."""
        return int(math.ceil(self.depth_coefficient * repeats))

    def correct_pad(self, inputs, kernel_size):
        """Return 3D zero-padding for a down-sampling convolution."""
        img_dim = 1
        # backend is imported from tensorflow.keras at the top of the module
        input_size = backend.int_shape(inputs)[img_dim:(img_dim + 3)]

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)

        if input_size[0] is None:
            adjust = (1, 1, 1)
        else:
            adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2, 1 - input_size[2] % 2)

        correct = (kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2)

        return ((correct[0] - adjust[0], correct[0]),
                (correct[1] - adjust[1], correct[1]),
                (correct[2] - adjust[2], correct[2]))

    def block(self, inputs, activation_fn=tf.nn.swish, drop_rate=0., name='',
              filters_in=32, filters_out=16, kernel_size=3, strides=1,
              expand_ratio=1, se_ratio=0., id_skip=True):
        """Mobile inverted residual block with optional squeeze-and-excitation."""
        bn_axis = 4

        filters = filters_in * expand_ratio

        # Inverted residuals: expand channels first
        if expand_ratio != 1:
            x = layers.Conv3D(filters, 1,
                              padding='same',
                              use_bias=False,
                              kernel_initializer=CONV_KERNEL_INITIALIZER,
                              name=name + 'expand_conv')(inputs)
            x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
            x = layers.Activation(activation_fn, name=name + 'expand_activation')(x)
        else:
            x = inputs

        # padding
        # if strides == 2:
        #     x = layers.ZeroPadding3D(padding=self.correct_pad(x, kernel_size),
        #                              name=name + 'dwconv_pad')(x)
        #     conv_pad = 'valid'
        # else:
        #     conv_pad = 'same'

        # TODO(Sicong): Find DepthwiseConv3D
        # x = layers.DepthwiseConv2D(kernel_size,
        #                            strides=strides,
        #                            padding=conv_pad,
        #                            use_bias=False,
        #                            depthwise_initializer=CONV_KERNEL_INITIALIZER,
        #                            name=name + 'dwconv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
        x = layers.Activation(activation_fn, name=name + 'activation')(x)

        # squeeze-and-excitation
        if 0 < se_ratio <= 1:
            filters_se = max(1, int(filters_in * se_ratio))
            se = layers.GlobalAveragePooling3D(name=name + 'se_squeeze')(x)
            se = layers.Reshape((1, 1, 1, filters), name=name + 'se_reshape')(se)
            se = layers.Conv3D(filters_se, 1,
                               padding='same',
                               activation=activation_fn,
                               kernel_initializer=CONV_KERNEL_INITIALIZER,
                               name=name + 'se_reduce')(se)
            se = layers.Conv3D(filters, 1,
                               padding='same',
                               activation='sigmoid',
                               kernel_initializer=CONV_KERNEL_INITIALIZER,
                               name=name + 'se_expand')(se)
            x = layers.multiply([x, se], name=name + 'se_excite')

        # project back to the output channel count
        x = layers.Conv3D(filters_out, 1,
                          padding='same',
                          use_bias=False,
                          kernel_initializer=CONV_KERNEL_INITIALIZER,
                          name=name + 'project_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)

        # identity skip connection with optional drop-connect
        if id_skip and strides == 1 and filters_in == filters_out:
            if drop_rate > 0:
                x = layers.Dropout(drop_rate,
                                   noise_shape=None,
                                   name=name + 'drop')(x)
            x = layers.add([x, inputs], name=name + 'add')

        return x

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        return config


if __name__ == "__main__":
    config_path = "examples/config_efficient_net.yaml"
    train(
        gpu="",
        config_path=config_path,
        gpu_allow_growth=True,
        ckpt_path="",
    )
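
For a quick shape check outside the training entry point, the backbone can also be exercised directly. A minimal, untested sketch using the constructor shown above (the sizes and the initializer/activation choices are illustrative only, not taken from the project config):

    net = EfficientNet(
        image_size=(16, 16, 16),
        num_channel_initial=4,
        extract_levels=(0, 1, 2),
        out_kernel_initializer="glorot_uniform",
        out_activation="sigmoid",
        out_channels=3,
    )
    images = tf.random.uniform((2, 16, 16, 16, 1))  # (batch, dim1, dim2, dim3, ch)
    ddf = net(images)  # expected shape: (2, 16, 16, 16, 3)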