Passed
Pull Request — main (#785)
by
unknown
04:05
created

DepthwiseConv3D.compute_output_shape()   B

Complexity

Conditions 5

Size

Total Lines 29
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 24
dl 0
loc 29
rs 8.8373
c 0
b 0
f 0
cc 5
nop 2
1
# https://github.com/alexandrosstergiou/keras-DepthwiseConv3D
2
# -*- coding: utf-8 -*-
3
4
'''
5
This is a modification of the SeparableConv3D code in Keras,
6
to perform just the Depthwise Convolution (1st step) of the
7
Depthwise Separable Convolution layer.
8
'''
9
from __future__ import absolute_import
10
11
from keras import backend as K
12
from keras import initializers
13
from keras import regularizers
14
from keras import constraints
15
from keras import layers
16
from keras.engine import InputSpec
17
from keras.utils import conv_utils
18
from keras.legacy.interfaces import conv3d_args_preprocessor, generate_legacy_interface
19
from keras.layers import Conv3D
20
from keras.backend.tensorflow_backend import _preprocess_padding, _preprocess_conv3d_input
21
22
import tensorflow as tf
23
24
25
def depthwise_conv3d_args_preprocessor(args, kwargs):
    """Convert Keras 1 style constructor arguments for `DepthwiseConv3D`.

    Renames a legacy ``init`` keyword to ``depthwise_initializer`` and then
    delegates all remaining conversions to Keras' ``conv3d_args_preprocessor``.

    # Arguments
        args: positional arguments passed to the layer constructor.
        kwargs: keyword arguments passed to the layer constructor; this dict
            is modified in place when it contains ``init``.

    # Returns
        A tuple ``(args, kwargs, converted)`` where ``converted`` lists the
        ``(old_name, new_name)`` pairs that were renamed, starting with the
        ones applied here.
    """
    converted = []

    if 'init' in kwargs:
        kwargs['depthwise_initializer'] = kwargs.pop('init')
        converted.append(('init', 'depthwise_initializer'))

    args, kwargs, _converted = conv3d_args_preprocessor(args, kwargs)
    return args, kwargs, converted + _converted


# Decorator translating Keras 1 argument names for `DepthwiseConv3D.__init__`.
# Defined at module level (not inside the function above) so that it is
# actually created at import time and can be applied as a decorator.
legacy_depthwise_conv3d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=depthwise_conv3d_args_preprocessor)
49
50
51
class DepthwiseConv3D(Conv3D):
    """Depthwise 3D convolution.

    Depth-wise separable convolutions consist of performing just the first
    step/operation (which acts on each input channel separately). This layer
    does not perform the pointwise convolution (second step).

    The input channels are split into `groups` contiguous groups of equal
    size. Each group is convolved with its own slice of the depthwise kernel
    and produces `depth_multiplier` output channels; the per-group results
    are concatenated along the channel axis. With the default `groups=None`
    (one group per input channel) this is a plain depthwise convolution.

    # Arguments
        kernel_size: An integer or tuple/list of 3 integers, specifying the
            depth, height and width of the 3D convolution window.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        strides: An integer or tuple/list of 3 integers,
            specifying the strides of the convolution along the depth,
            height and width.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        padding: one of `"valid"` or `"same"` (case-insensitive).
        depth_multiplier: The number of output channels produced by each
            channel group. The total number of output channels will be
            equal to `groups * depth_multiplier`.
        groups: Number of channel groups (a variant of the original depthwise
            convolution). Must be a positive divisor of the number of input
            channels. Defaults to the number of input channels.
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, depth, height, width, channels)` while `channels_first`
            corresponds to inputs with shape
            `(batch, channels, depth, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        activation: Activation function to use
            (see [activations](../activations.md)).
            If you don't specify anything, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        depthwise_initializer: Initializer for the depthwise kernel matrix
            (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        dilation_rate: An integer or tuple/list of 3 integers, the dilation
            factor for the depth, height and width dimensions.
            Defaults to (1, 1, 1).
        depthwise_regularizer: Regularizer function applied to
            the depthwise kernel matrix
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
            (see [regularizer](../regularizers.md)).
        depthwise_constraint: Constraint function applied to
            the depthwise kernel matrix
            (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).

    # Input shape
        5D tensor with shape:
        `(batch, channels, depth, rows, cols)` if data_format='channels_first'
        or 5D tensor with shape:
        `(batch, depth, rows, cols, channels)` if data_format='channels_last'.

    # Output shape
        5D tensor with shape:
        `(batch, groups * depth_multiplier, new_depth, new_rows, new_cols)`
        if data_format='channels_first'
        or 5D tensor with shape:
        `(batch, new_depth, new_rows, new_cols, groups * depth_multiplier)`
        if data_format='channels_last'.
        `depth`, `rows` and `cols` values might have changed due to
        padding, strides and dilation.
    """

    #@legacy_depthwise_conv3d_support
    def __init__(self,
                 kernel_size,
                 strides=(1, 1, 1),
                 padding='valid',
                 depth_multiplier=1,
                 groups=None,
                 data_format=None,
                 activation=None,
                 use_bias=True,
                 depthwise_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 dilation_rate=(1, 1, 1),
                 depthwise_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 depthwise_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(DepthwiseConv3D, self).__init__(
            filters=None,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            bias_regularizer=bias_regularizer,
            dilation_rate=dilation_rate,
            activity_regularizer=activity_regularizer,
            bias_constraint=bias_constraint,
            **kwargs)
        self.depth_multiplier = depth_multiplier
        self.groups = groups
        self.depthwise_initializer = initializers.get(depthwise_initializer)
        self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
        self.depthwise_constraint = constraints.get(depthwise_constraint)
        self.bias_initializer = initializers.get(bias_initializer)
        # `self.dilation_rate` is intentionally left as set by the Conv3D
        # base class (a normalized 3-tuple): call() concatenates it with
        # other tuples and compute_output_shape() indexes it, so it must not
        # be replaced by the raw (possibly int or list) argument.
        self._padding = _preprocess_padding(self.padding)
        self.input_dim = None

    def build(self, input_shape):
        """Validate the input shape and create the kernel and bias weights.

        # Raises
            ValueError: if the input is not rank 5, the channel dimension is
                undefined, or `groups` is not a positive divisor of the
                number of input channels.
        """
        if len(input_shape) != 5:
            raise ValueError('Inputs to `DepthwiseConv3D` should have rank 5. '
                             'Received input shape: ' + str(input_shape))
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs to '
                             '`DepthwiseConv3D` '
                             'should be defined. Found `None`.')
        self.input_dim = int(input_shape[channel_axis])

        if self.groups is None:
            self.groups = self.input_dim

        # Guard against 0/negative values, which would otherwise fail with
        # ZeroDivisionError in the divisibility check below.
        if self.groups < 1:
            raise ValueError('The number of groups must be a positive integer')

        if self.groups > self.input_dim:
            raise ValueError('The number of groups cannot exceed the number of channels')

        if self.input_dim % self.groups != 0:
            raise ValueError('Warning! The channels dimension is not divisible by the group size chosen')

        # One kernel slice per input channel; call() feeds each group the
        # slice [..., start:stop, :] matching its channels.
        depthwise_kernel_shape = (self.kernel_size[0],
                                  self.kernel_size[1],
                                  self.kernel_size[2],
                                  self.input_dim,
                                  self.depth_multiplier)

        self.depthwise_kernel = self.add_weight(
            shape=depthwise_kernel_shape,
            initializer=self.depthwise_initializer,
            name='depthwise_kernel',
            regularizer=self.depthwise_regularizer,
            constraint=self.depthwise_constraint)

        if self.use_bias:
            # One bias per output channel: `depth_multiplier` per group.
            self.bias = self.add_weight(shape=(self.groups * self.depth_multiplier,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=5, axes={channel_axis: self.input_dim})
        self.built = True

    def call(self, inputs, training=None):
        """Apply the grouped depthwise convolution, bias and activation.

        `training` is accepted for API compatibility and is not used.
        """
        # The backend helper returns the (possibly transposed) tensor together
        # with the TensorFlow layout it is actually in ('NDHWC' or 'NCDHW').
        inputs, tf_data_format = _preprocess_conv3d_input(inputs, self.data_format)

        if tf_data_format == 'NCDHW':
            channel_axis = 1
            strides = (1, 1) + tuple(self.strides)
            dilations = (1, 1) + tuple(self.dilation_rate)
        else:
            channel_axis = -1
            strides = (1,) + tuple(self.strides) + (1,)
            dilations = (1,) + tuple(self.dilation_rate) + (1,)

        group_size = self.input_dim // self.groups
        group_outputs = []
        for start in range(0, self.input_dim, group_size):
            stop = start + group_size
            if channel_axis == 1:
                group_input = inputs[:, start:stop]
            else:
                group_input = inputs[..., start:stop]
            group_outputs.append(tf.nn.conv3d(
                group_input,
                self.depthwise_kernel[:, :, :, start:stop, :],
                strides=strides,
                padding=self._padding,
                dilations=dilations,
                data_format=tf_data_format))
        outputs = tf.concat(group_outputs, axis=channel_axis)

        # If channels_first input was transposed to NDHWC for the convolution,
        # restore the layout the caller (and bias_add below) expects.
        if self.data_format == 'channels_first' and tf_data_format == 'NDHWC':
            outputs = tf.transpose(outputs, (0, 4, 1, 2, 3))

        if self.bias is not None:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)

        return outputs

    def compute_output_shape(self, input_shape):
        """Return the output shape for a rank-5 `input_shape`.

        Spatial sizes account for kernel size, padding, strides and dilation.
        If called before build(), `groups` is taken to be the number of
        input channels (the same default build() applies).
        """
        if self.data_format == 'channels_first':
            channel_axis = 1
            spatial = input_shape[2:5]
        else:
            channel_axis = -1
            spatial = input_shape[1:4]

        groups = self.groups
        if groups is None:
            groups = input_shape[channel_axis]
        out_filters = None if groups is None else groups * self.depth_multiplier

        new_spatial = tuple(
            conv_utils.conv_output_length(length,
                                          self.kernel_size[axis],
                                          self.padding,
                                          self.strides[axis],
                                          dilation=self.dilation_rate[axis])
            for axis, length in enumerate(spatial))

        if self.data_format == 'channels_first':
            return (input_shape[0], out_filters) + new_spatial
        return (input_shape[0],) + new_spatial + (out_filters,)

    def get_config(self):
        """Return the layer config, replacing Conv3D kernel keys with depthwise ones."""
        config = super(DepthwiseConv3D, self).get_config()
        config.pop('filters')
        config.pop('kernel_initializer')
        config.pop('kernel_regularizer')
        config.pop('kernel_constraint')
        config['depth_multiplier'] = self.depth_multiplier
        # Needed so that from_config() rebuilds the same channel grouping.
        config['groups'] = self.groups
        config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
        config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
        config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
        return config
291
# Longer-form alias: `DepthwiseConvolution3D` and `DepthwiseConv3D` are the same class.
DepthwiseConvolution3D = DepthwiseConv3D