Passed
Pull Request — main (#785)
by unknown, created 04:05

efficient_net.EfficientNet.get_config()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 1
"""This script provides an example of using EfficientNet for training."""

import os
import math
import numpy as np
import tensorflow as tf
# tf.config.experimental_run_functions_eagerly(True)

from tensorflow.keras import layers
from copy import deepcopy
from typing import List, Optional, Tuple, Union

from deepreg.model import layer
from deepreg.model.backbone import Backbone
from deepreg.model.backbone.local_net import LocalNet
from deepreg.model.backbone.u_net import UNet
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY
from deepreg.train import train


EFFICIENTNET_PARAMS = {
    # model_name: (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate)
    "efficientnet-b0": (1.0, 1.0, 224, 0.2, 0.2),
    "efficientnet-b1": (1.0, 1.1, 240, 0.2, 0.2),
    "efficientnet-b2": (1.1, 1.2, 260, 0.3, 0.2),
    "efficientnet-b3": (1.2, 1.4, 300, 0.3, 0.2),
    "efficientnet-b4": (1.4, 1.8, 380, 0.4, 0.2),
    "efficientnet-b5": (1.6, 2.2, 456, 0.4, 0.2),
    "efficientnet-b6": (1.8, 2.6, 528, 0.5, 0.2),
    "efficientnet-b7": (2.0, 3.1, 600, 0.5, 0.2),
}
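
# Note: EFFICIENTNET_PARAMS is not consumed elsewhere in this file. The sketch
# below is only an assumption of how one variant's tuple could be unpacked into
# the constructor arguments of the EfficientNet backbone defined further down.
#
#   w, d, size, dropout, dropconnect = EFFICIENTNET_PARAMS["efficientnet-b4"]
#   net_kwargs = dict(
#       width_coefficient=w,            # 1.4
#       depth_coefficient=d,            # 1.8
#       default_size=size,              # 380
#       dropout_rate=dropout,           # 0.4
#       drop_connect_rate=dropconnect,  # 0.2
#   )
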
# Per-block parameters
DEFAULT_BLOCKS_ARGS = [
    {'kernel_size': 3, 'repeats': 1, 'filters_in': 32, 'filters_out': 16,
     'expand_ratio': 1, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 2, 'filters_in': 16, 'filters_out': 24,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 2, 'filters_in': 24, 'filters_out': 40,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 3, 'filters_in': 40, 'filters_out': 80,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 3, 'filters_in': 80, 'filters_out': 112,
     'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25},
    {'kernel_size': 5, 'repeats': 4, 'filters_in': 112, 'filters_out': 192,
     'expand_ratio': 6, 'id_skip': True, 'strides': 2, 'se_ratio': 0.25},
    {'kernel_size': 3, 'repeats': 1, 'filters_in': 192, 'filters_out': 320,
     'expand_ratio': 6, 'id_skip': True, 'strides': 1, 'se_ratio': 0.25}
]

# Two kernel initializers
CONV_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 2.0,
        'mode': 'fan_out',
        'distribution': 'normal'
    }
}

DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 1. / 3.,
        'mode': 'fan_out',
        'distribution': 'uniform'
    }
}

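# These dicts are standard Keras initializer configs (VarianceScaling). In this
# file they are only referenced from the commented-out kernel_initializer
# arguments below; if re-enabled, they would be passed as, e.g.,
# layers.Conv3D(..., kernel_initializer=CONV_KERNEL_INITIALIZER).
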
@REGISTRY.register_backbone(name="efficient_net")
class EfficientNet(LocalNet):
    """
    Class that implements an EfficientNet for image registration.

    Reference:
    - Authors: Mingxing Tan, Quoc V. Le
      EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks
      https://arxiv.org/pdf/1905.11946.pdf
    """
    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int, ...],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        pooling: bool = True,
        concat_skip: bool = False,
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        default_size: int = 224,
        dropout_rate: float = 0.2,
        drop_connect_rate: float = 0.2,
        depth_divisor: int = 8,
        name: str = "EfficientNet",
        **kwargs,
    ):
        """
        Init.

        The image is encoded gradually, i from level 0 to D,
        then decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.

        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: from which depths the output will be built.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at the end layer.
        :param out_channels: number of channels for the extractions.
        :param depth: depth of the encoder.
            If depth is not given, depth = max(extract_levels) will be used.
        :param pooling: for down-sampling, use non-parameterized
                        pooling if true, otherwise use conv3d.
        :param concat_skip: when up-sampling, concatenate skipped
                            tensor if true, otherwise use addition.
        :param width_coefficient: float, scaling coefficient for network width.
        :param depth_coefficient: float, scaling coefficient for network depth.
        :param default_size: int, default input image size.
        :param dropout_rate: float, dropout rate before final classifier layer.
        :param drop_connect_rate: float, dropout rate at skip connections.
        :param depth_divisor: int, unit used when rounding the number of filters
            (see round_filters).
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        """
        if depth is None:
            depth = max(extract_levels)
        kwargs["encode_kernel_sizes"] = [7] + [3] * depth
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            use_additive_upsampling=False,
            pooling=pooling,
            concat_skip=concat_skip,
            name=name,
            **kwargs,
        )

        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.default_size = default_size
        self.dropout_rate = dropout_rate
        self.drop_connect_rate = drop_connect_rate
        self.depth_divisor = depth_divisor
        self.activation_fn = tf.nn.swish

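    # A minimal sketch of how this backbone could be instantiated directly, using
    # hypothetical argument values; in practice it is built by REGISTRY from a
    # config file (see the __main__ block at the bottom).
    #
    #   backbone = EfficientNet(
    #       image_size=(16, 16, 16),
    #       num_channel_initial=4,
    #       extract_levels=(0, 1, 2),
    #       out_kernel_initializer="glorot_uniform",
    #       out_activation="sigmoid",
    #       out_channels=3,
    #       width_coefficient=1.0,
    #       depth_coefficient=1.0,
    #   )
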
    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        """
        Build compute graph based on built layers.

        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
        :param training: None or bool.
        :param mask: None or tf.Tensor.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """

        # encoding / down-sampling
        skips = []
        encoded = inputs
        print("before:", encoded.shape)

        for d in range(self._depth):
            skip = self._encode_convs[d](inputs=encoded, training=training)
            encoded = self._encode_pools[d](inputs=skip, training=training)
            skips.append(skip)
        print("skip: ", len(skip), self._depth)
Issue: The variable skip does not seem to be defined in case the for loop above is not entered. Are you sure this can never be the case?
        print("encode: ", encoded.shape)
        # bottom
        import efficientnet_3D.keras as efn
        e_model = efn.EfficientNetB0(input_shape=(24, 24, 26, 64), weights='imagenet')
        decoded = e_model(encoded)
        # decoded = self.build_efficient_net(inputs=encoded)  # type: ignore

        print("decode: ", decoded.shape)
        # decoding / up-sampling
        outs = [decoded]
        for d in range(self._depth - 1, min(self._extract_levels) - 1, -1):
            decoded = self._decode_deconvs[d](inputs=decoded, training=training)
            decoded = self.build_skip_block()([decoded, skips[d]])
            decoded = self._decode_convs[d](inputs=decoded, training=training)
            outs = [decoded] + outs

        # output
        output = self._output_block(outs)  # type: ignore

        return output

    def build_efficient_net(self, inputs: tf.Tensor, training=None) -> tf.Tensor:
        """
        Builds graph based on built layers.

        :param inputs: shape = (batch, f_dim1, f_dim2, f_dim3, in_channels)
        :param training: None or bool.
        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
        """
        img_input = layers.Input(tensor=inputs, shape=self.image_size)
        bn_axis = 4
        x = img_input
        # x = layers.ZeroPadding3D(padding=self.correct_pad(x, 3),
        #                         name='stem_conv_pad')(x)

        x = layers.Conv3D(self.round_filters(32), 3,
                        strides=1,
                        padding='same',
                        use_bias=False,
                        # kernel_initializer=CONV_KERNEL_INITIALIZER,
                        name='stem_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
        x = layers.Activation(self.activation_fn, name='stem_activation')(x)
        blocks_args = deepcopy(DEFAULT_BLOCKS_ARGS)

        b = 0
        # Calculate the number of blocks
        # blocks = float(sum(args['repeats'] for args in blocks_args))
        # for (i, args) in enumerate(blocks_args):
        #     assert args['repeats'] > 0
        #     args['filters_in'] = self.round_filters(args['filters_in'])
        #     args['filters_out'] = self.round_filters(args['filters_out'])

        #     for j in range(self.round_repeats(args.pop('repeats'))):
        #         if j > 0:
        #             args['strides'] = 1
        #             args['filters_in'] = args['filters_out']
        #         x = self.block(x, self.activation_fn, self.drop_connect_rate * b / blocks,
        #                 name='block{}{}_'.format(i + 1, chr(j + 97)), **args)
        #         b += 1

        x = layers.Conv3D(self.round_filters(128), 1,
                        padding='same',
                        use_bias=False,
                        # kernel_initializer=CONV_KERNEL_INITIALIZER,
                        name='top_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
        x = layers.Activation(self.activation_fn, name='top_activation')(x)

        print("input.shape", inputs.shape, x.shape)
        return x

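    # Note: with the MBConv loop above commented out, build_efficient_net currently
    # reduces to the stem Conv3D -> BN -> swish followed by the top Conv3D -> BN ->
    # swish; call() also bypasses it in favour of the external efficientnet_3D model.
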
    def round_filters(self, filters):
        """Round number of filters based on the width multiplier."""
        filters *= self.width_coefficient
        divisor = self.depth_divisor
        new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += divisor
        return int(new_filters)

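    # Worked example (hypothetical values): with width_coefficient=1.4 and
    # depth_divisor=8, round_filters(32) computes 32 * 1.4 = 44.8, rounds to the
    # nearest multiple of 8 via int(44.8 + 4) // 8 * 8 = 48, and since
    # 48 >= 0.9 * 44.8 no correction is applied, so 48 is returned.
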
    def round_repeats(self, repeats):
        """Round number of block repeats based on the depth multiplier."""
        return int(math.ceil(self.depth_coefficient * repeats))

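    # Worked example (hypothetical value): with depth_coefficient=1.8,
    # round_repeats(3) returns ceil(1.8 * 3) = ceil(5.4) = 6.
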
    def correct_pad(self, inputs, kernel_size):
        """Return a padding tuple for zero-padding a 3D convolution with downsampling."""
        img_dim = 1
        input_size = backend.int_shape(inputs)[img_dim:(img_dim + 3)]
Issue (Comprehensibility Best Practice): The variable backend does not seem to be defined.

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)

        if input_size[0] is None:
            adjust = (1, 1, 1)
        else:
            adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2, 1 - input_size[2] % 2)

        correct = (kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2)

        return ((correct[0] - adjust[0], correct[0]),
                (correct[1] - adjust[1], correct[1]),
                (correct[2] - adjust[2], correct[2]))

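    # Worked example (hypothetical shape): for kernel_size=3 and a spatial input
    # size of (24, 24, 26), adjust = (1, 1, 1) and correct = (1, 1, 1), so the
    # returned padding is ((0, 1), (0, 1), (0, 1)) for each spatial dimension.
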
    def block(self, inputs, activation_fn=tf.nn.swish, drop_rate=0., name='',
            filters_in=32, filters_out=16, kernel_size=3, strides=1,
            expand_ratio=1, se_ratio=0., id_skip=True):
        """Build an inverted-residual (MBConv) block."""

        bn_axis = 4

        filters = filters_in * expand_ratio

        # Inverted residuals
        if expand_ratio != 1:
            x = layers.Conv3D(filters, 1,
                            padding='same',
                            use_bias=False,
                            # kernel_initializer=CONV_KERNEL_INITIALIZER,
                            name=name + 'expand_conv')(inputs)
            x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
            x = layers.Activation(activation_fn, name=name + 'expand_activation')(x)
        else:
            x = inputs

        # padding
        # if strides == 2:
        #     x = layers.ZeroPadding3D(padding=self.correct_pad(x, kernel_size),
        #                             name=name + 'dwconv_pad')(x)
        #     conv_pad = 'valid'
        # else:
        #     conv_pad = 'same'

        # TODO(Sicong): Find DepthwiseConv3D
        # x = layers.DepthwiseConv2D(kernel_size,
        #                         strides=strides,
        #                         padding=conv_pad,
        #                         use_bias=False,
        #                         depthwise_initializer=CONV_KERNEL_INITIALIZER,
        #                         name=name + 'dwconv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
        x = layers.Activation(activation_fn, name=name + 'activation')(x)

        if 0 < se_ratio <= 1:
            filters_se = max(1, int(filters_in * se_ratio))
            se = layers.GlobalAveragePooling3D(name=name + 'se_squeeze')(x)
            se = layers.Reshape((1, 1, 1, filters), name=name + 'se_reshape')(se)
            se = layers.Conv3D(filters_se, 1,
                            padding='same',
                            activation=activation_fn,
                            # kernel_initializer=CONV_KERNEL_INITIALIZER,
                            name=name + 'se_reduce')(se)
            se = layers.Conv3D(filters, 1,
                            padding='same',
                            activation='sigmoid',
                            # kernel_initializer=CONV_KERNEL_INITIALIZER,
                            name=name + 'se_expand')(se)
            x = layers.multiply([x, se], name=name + 'se_excite')

        x = layers.Conv3D(filters_out, 1,
                        padding='same',
                        use_bias=False,
                        # kernel_initializer=CONV_KERNEL_INITIALIZER,
                        name=name + 'project_conv')(x)
        x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)

        if id_skip is True and strides == 1 and filters_in == filters_out:
            if drop_rate > 0:
                x = layers.Dropout(drop_rate,
                                noise_shape=None,
                                name=name + 'drop')(x)
            x = layers.add([x, inputs], name=name + 'add')

        return x

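    # The block above follows the standard MBConv layout: optional 1x1x1 expansion
    # conv -> (depthwise conv, currently missing pending a DepthwiseConv3D
    # equivalent) -> squeeze-and-excitation -> 1x1x1 projection conv, with an
    # identity skip and optional dropout when strides and filter counts allow it.
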
    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        return config

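# A minimal sketch (assumption, not exercised in this file): since the backbone is
# registered with REGISTRY, an equivalent instance could be rebuilt from its config,
# e.g. EfficientNet(**backbone.get_config()).
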
if __name__ == "__main__":
    config_path = "examples/config_efficient_net.yaml"
    train(
        gpu="0",
        config_path=config_path,
        gpu_allow_growth=True,
        ckpt_path="",
    )