Passed
Pull Request — main (#785)
by
unknown
04:05
created

depth_wise_3D.DepthwiseConv3D.__init__()   B

Complexity

Conditions 1

Size

Total Lines 42
Code Lines 42

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 42
dl 0
loc 42
rs 8.872
c 0
b 0
f 0
cc 1
nop 18

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
# https://github.com/alexandrosstergiou/keras-DepthwiseConv3D
2
# -*- coding: utf-8 -*-
3
4
'''
5
This is a modification of the SeparableConv3D code in Keras,
6
to perform just the Depthwise Convolution (1st step) of the
7
Depthwise Separable Convolution layer.
8
'''
9
from __future__ import absolute_import
10
11
from keras import backend as K
12
from keras import initializers
13
from keras import regularizers
14
from keras import constraints
15
from keras import layers
16
from keras.engine import InputSpec
17
from keras.utils import conv_utils
18
from keras.legacy.interfaces import conv3d_args_preprocessor, generate_legacy_interface
19
from keras.layers import Conv3D
20
from keras.backend.tensorflow_backend import _preprocess_padding, _preprocess_conv3d_input
21
22
import tensorflow as tf
23
24
25
def depthwise_conv3d_args_preprocessor(args, kwargs):
    """Translate legacy (Keras 1 style) constructor arguments.

    Renames the legacy `init` keyword to `depthwise_initializer`, then
    delegates to Keras' `conv3d_args_preprocessor` for the remaining
    Conv3D-style conversions.

    # Arguments
        args: positional arguments passed to the layer constructor.
        kwargs: keyword arguments passed to the layer constructor. If it
            contains `init`, that entry is popped from this dict and stored
            under `depthwise_initializer` instead (the dict is mutated).

    # Returns
        Tuple `(args, kwargs, converted)`, where `converted` lists the
        `(old_name, new_name)` pairs renamed here, followed by the
        conversions reported by `conv3d_args_preprocessor`.
    """
    converted = []

    if 'init' in kwargs:
        kwargs['depthwise_initializer'] = kwargs.pop('init')
        converted.append(('init', 'depthwise_initializer'))

    args, kwargs, _converted = conv3d_args_preprocessor(args, kwargs)
    return args, kwargs, converted + _converted


# Decorator that maps legacy argument names/values onto the current ones for
# DepthwiseConv3D. It must live at module level so that it exists when used
# as a decorator.
# NOTE(review): `allowed_positional_args` and the `nb_filter` conversion both
# refer to `filters`, which DepthwiseConv3D.__init__ does not accept; confirm
# the mapping before applying this decorator to the layer.
legacy_depthwise_conv3d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=depthwise_conv3d_args_preprocessor)
49
50
51
class DepthwiseConv3D(Conv3D):
    """Depthwise 3D convolution.

    Performs only the first step of a depthwise separable convolution: the
    input channels are split into `groups` groups and each group is
    convolved separately. The pointwise (second) step is not performed.
    The `depth_multiplier` argument controls how many output channels are
    generated per group in the depthwise step.

    # Arguments
        kernel_size: An integer or tuple/list of 3 integers, specifying the
            depth, height and width of the 3D convolution window.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        strides: An integer or tuple/list of 3 integers,
            specifying the strides of the convolution along the depth,
            height and width.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        padding: one of `"valid"` or `"same"` (case-insensitive).
        depth_multiplier: The number of depthwise convolution output channels
            generated for each group of input channels.
            The total number of output channels will be equal to
            `groups * depth_multiplier`.
        groups: Number of groups the input channels are split into; each
            group of `channels_in // groups` channels is convolved separately
            (a grouped variant of the depthwise convolution). Must be a
            positive divisor of the number of input channels. Defaults to
            `None`, which means one group per input channel (a plain
            depthwise convolution).
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, depth, height, width, channels)` while `channels_first`
            corresponds to inputs with shape
            `(batch, channels, depth, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        activation: Activation function to use
            (see [activations](../activations.md)).
            If you don't specify anything, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        depthwise_initializer: Initializer for the depthwise kernel matrix
            (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        dilation_rate: An integer or tuple/list of 3 integers, the dilation
            factor for each spatial dimension. Defaults to (1, 1, 1).
        depthwise_regularizer: Regularizer function applied to
            the depthwise kernel matrix
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
            (see [regularizer](../regularizers.md)).
        depthwise_constraint: Constraint function applied to
            the depthwise kernel matrix
            (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).

    # Input shape
        5D tensor with shape:
        `(batch, channels, depth, rows, cols)` if data_format='channels_first'
        or 5D tensor with shape:
        `(batch, depth, rows, cols, channels)` if data_format='channels_last'.

    # Output shape
        5D tensor with shape:
        `(batch, groups * depth_multiplier, new_depth, new_rows, new_cols)`
        if data_format='channels_first'
        or 5D tensor with shape:
        `(batch, new_depth, new_rows, new_cols, groups * depth_multiplier)`
        if data_format='channels_last'.
        `depth`, `rows` and `cols` values might have changed due to padding,
        striding and dilation.
    """

    # NOTE: the legacy-argument decorator stays disabled: its
    # `allowed_positional_args` include `filters`, which this layer does not
    # accept.
    # @legacy_depthwise_conv3d_support
    def __init__(self,
                 kernel_size,
                 strides=(1, 1, 1),
                 padding='valid',
                 depth_multiplier=1,
                 groups=None,
                 data_format=None,
                 activation=None,
                 use_bias=True,
                 depthwise_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 dilation_rate=(1, 1, 1),
                 depthwise_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 depthwise_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # `filters` is meaningless for a depthwise convolution; the output
        # channel count is derived from `groups * depth_multiplier` in build.
        super(DepthwiseConv3D, self).__init__(
            filters=None,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            bias_regularizer=bias_regularizer,
            dilation_rate=dilation_rate,
            activity_regularizer=activity_regularizer,
            bias_constraint=bias_constraint,
            **kwargs)
        # `self.dilation_rate` is intentionally NOT reassigned here: the
        # parent constructor already stored it normalized to a 3-tuple, which
        # `call` and `compute_output_shape` rely on.
        self.depth_multiplier = depth_multiplier
        self.groups = groups  # resolved to the channel count in build if None
        self.depthwise_initializer = initializers.get(depthwise_initializer)
        self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
        self.depthwise_constraint = constraints.get(depthwise_constraint)
        self.bias_initializer = initializers.get(bias_initializer)
        # TensorFlow-style padding string ('VALID' / 'SAME') for tf.nn.conv3d.
        self._padding = _preprocess_padding(self.padding)
        self.input_dim = None

    def build(self, input_shape):
        """Validate the input shape and create the kernel and bias weights.

        # Raises
            ValueError: if the input is not rank 5, the channel dimension is
                undefined, or `groups` is not a positive divisor of the
                number of input channels.
        """
        if len(input_shape) != 5:
            raise ValueError('Inputs to `DepthwiseConv3D` should have rank 5. '
                             'Received input shape: ' + str(input_shape))
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs to '
                             '`DepthwiseConv3D` '
                             'should be defined. Found `None`.')
        self.input_dim = int(input_shape[channel_axis])

        if self.groups is None:
            # Default: one group per channel, i.e. a plain depthwise conv.
            self.groups = self.input_dim

        if self.groups < 1:
            # Guards the modulo / floor-division by `groups` below and in call.
            raise ValueError('The number of groups must be a positive '
                             'integer. Received: ' + str(self.groups))

        if self.groups > self.input_dim:
            raise ValueError('The number of groups cannot exceed the number of channels')

        if self.input_dim % self.groups != 0:
            raise ValueError('Warning! The channels dimension is not divisible by the group size chosen')

        # One (kd, kh, kw, group_size, depth_multiplier) slice of this kernel
        # is used per group in `call`.
        depthwise_kernel_shape = (self.kernel_size[0],
                                  self.kernel_size[1],
                                  self.kernel_size[2],
                                  self.input_dim,
                                  self.depth_multiplier)

        self.depthwise_kernel = self.add_weight(
            shape=depthwise_kernel_shape,
            initializer=self.depthwise_initializer,
            name='depthwise_kernel',
            regularizer=self.depthwise_regularizer,
            constraint=self.depthwise_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.groups * self.depth_multiplier,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=5, axes={channel_axis: self.input_dim})
        self.built = True

    def call(self, inputs, training=None):
        """Convolve each channel group separately and concatenate the results."""
        # Returns the (possibly transposed) tensor together with the layout
        # TensorFlow should use for it: 'NDHWC' or 'NCDHW'.
        x, tf_data_format = _preprocess_conv3d_input(inputs, self.data_format)

        # Strides / dilations / channel axis must match the layout that the
        # tensor is actually in, not the layer's `data_format`.
        if tf_data_format == 'NCDHW':
            concat_axis = 1
            strides = (1, 1) + tuple(self.strides)
            dilations = (1, 1) + tuple(self.dilation_rate)
        else:
            concat_axis = -1
            strides = (1,) + tuple(self.strides) + (1,)
            dilations = (1,) + tuple(self.dilation_rate) + (1,)

        group_size = self.input_dim // self.groups
        group_outputs = []
        for start in range(0, self.input_dim, group_size):
            stop = start + group_size
            if tf_data_format == 'NCDHW':
                x_group = x[:, start:stop, :, :, :]
            else:
                x_group = x[:, :, :, :, start:stop]
            group_outputs.append(tf.nn.conv3d(
                x_group,
                self.depthwise_kernel[:, :, :, start:stop, :],
                strides=strides,
                padding=self._padding,
                dilations=dilations,
                data_format=tf_data_format))
        outputs = tf.concat(group_outputs, axis=concat_axis)

        # A channels_first input that was transposed to NDHWC for the
        # convolution is transposed back so the output honours `data_format`.
        if self.data_format == 'channels_first' and tf_data_format == 'NDHWC':
            outputs = tf.transpose(outputs, (0, 4, 1, 2, 3))

        if self.bias is not None:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)

        return outputs

    def compute_output_shape(self, input_shape):
        """Return the output shape for a 5D `input_shape` tuple."""
        channels_first = self.data_format == 'channels_first'
        if channels_first:
            channels = input_shape[1]
            spatial = input_shape[2:5]
        else:
            channels = input_shape[4]
            spatial = input_shape[1:4]

        # Before `build`, `groups` may still be None (one group per channel).
        groups = self.groups if self.groups is not None else channels
        out_filters = None if groups is None else groups * self.depth_multiplier

        new_spatial = tuple(
            conv_utils.conv_output_length(length,
                                          self.kernel_size[i],
                                          self.padding,
                                          self.strides[i],
                                          dilation=self.dilation_rate[i])
            for i, length in enumerate(spatial))

        if channels_first:
            return (input_shape[0], out_filters) + new_spatial
        return (input_shape[0],) + new_spatial + (out_filters,)

    def get_config(self):
        """Return the layer config, replacing the Conv3D kernel/filter
        entries with their depthwise counterparts."""
        config = super(DepthwiseConv3D, self).get_config()
        config.pop('filters')
        config.pop('kernel_initializer')
        config.pop('kernel_regularizer')
        config.pop('kernel_constraint')
        config['depth_multiplier'] = self.depth_multiplier
        # Needed so a layer re-created from its config keeps its grouping.
        config['groups'] = self.groups
        config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
        config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
        config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
        return config
291
# Alias so the layer can also be referenced as `DepthwiseConvolution3D`.
DepthwiseConvolution3D = DepthwiseConv3D