| 1 |  |  | # https://github.com/alexandrosstergiou/keras-DepthwiseConv3D | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | # -*- coding: utf-8 -*- | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | ''' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | This is a modification of the SeparableConv3D code in Keras, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | to perform just the Depthwise Convolution (1st step) of the | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | Depthwise Separable Convolution layer. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | ''' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | from __future__ import absolute_import | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from keras import backend as K | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | from keras import initializers | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  | from keras import regularizers | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  | from keras import constraints | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | from keras import layers | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  | from keras.engine import InputSpec | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  | from keras.utils import conv_utils | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  | from keras.legacy.interfaces import conv3d_args_preprocessor, generate_legacy_interface | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  | from keras.layers import Conv3D | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  | from keras.backend.tensorflow_backend import _preprocess_padding, _preprocess_conv3d_input | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  | import tensorflow as tf | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 24 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                def depthwise_conv3d_args_preprocessor(args, kwargs):
                    """Normalise legacy constructor arguments for `DepthwiseConv3D`.

                    Renames the legacy ``init`` keyword to ``depthwise_initializer``
                    and then hands the (possibly updated) arguments to Keras's
                    ``conv3d_args_preprocessor`` for the remaining conversions.

                    # Arguments
                        args: positional arguments given to the layer constructor.
                        kwargs: keyword arguments given to the layer constructor.
                            Mutated in place: ``init`` is popped from it and stored
                            under ``depthwise_initializer`` when present.

                    # Returns
                        A tuple ``(args, kwargs, converted)``, where ``converted`` is
                        a new list of ``(old_name, new_name)`` pairs: the ``init``
                        rename (if it happened) followed by the pairs reported by
                        ``conv3d_args_preprocessor``.
                    """
                    renamed = []
                    if 'init' in kwargs:
                        # `init` is the legacy spelling of the depthwise kernel initializer.
                        kwargs['depthwise_initializer'] = kwargs.pop('init')
                        renamed.append(('init', 'depthwise_initializer'))

                    # Let the stock Conv3D preprocessor handle everything else.
                    args, kwargs, conv3d_renamed = conv3d_args_preprocessor(args, kwargs)
                    return args, kwargs, renamed + conv3d_renamed
            
                                                                        
                            
            
                                    
            
            
                | 35 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                # Legacy-argument interface for DepthwiseConv3D, built with Keras's
                # `generate_legacy_interface`: it maps old argument names/values
                # (listed in `conversions` / `value_conversions`) to current ones and
                # runs `depthwise_conv3d_args_preprocessor` on the arguments.
                # This must be a module-level statement: if it sits inside the
                # preprocessor after its `return`, it never executes and the name
                # `legacy_depthwise_conv3d_support` is never defined.
                # NOTE(review): the layer's documented arguments start with
                # `kernel_size` and include no `filters`; confirm that
                # `allowed_positional_args` and the `nb_filter` conversion match
                # `DepthwiseConv3D.__init__` before applying this as a decorator.
                legacy_depthwise_conv3d_support = generate_legacy_interface(
                    allowed_positional_args=['filters', 'kernel_size'],
                    conversions=[('nb_filter', 'filters'),
                                 ('subsample', 'strides'),
                                 ('border_mode', 'padding'),
                                 ('dim_ordering', 'data_format'),
                                 ('b_regularizer', 'bias_regularizer'),
                                 ('b_constraint', 'bias_constraint'),
                                 ('bias', 'use_bias')],
                    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                                        'th': 'channels_first',
                                                        'default': None}},
                    preprocessor=depthwise_conv3d_args_preprocessor)
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  | class DepthwiseConv3D(Conv3D): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |     """Depthwise 3D convolution. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |     Depth-wise part of separable convolutions consist in performing | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |     just the first step/operation | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |     (which acts on each input channel separately). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |     It does not perform the pointwise convolution (second step). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |     The `depth_multiplier` argument controls how many | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |     output channels are generated per input channel in the depthwise step. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |     # Arguments | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |         kernel_size: An integer or tuple/list of 3 integers, specifying the | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |             depth, width and height of the 3D convolution window. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |             Can be a single integer to specify the same value for | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |             all spatial dimensions. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         strides: An integer or tuple/list of 3 integers, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |             specifying the strides of the convolution along the depth, width and height. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |             Can be a single integer to specify the same value for | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |             all spatial dimensions. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         padding: one of `"valid"` or `"same"` (case-insensitive). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         depth_multiplier: The number of depthwise convolution output channels | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |             for each input channel. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |             The total number of depthwise convolution output | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |             channels will be equal to `filterss_in * depth_multiplier`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |         groups: The depth size of the convolution (as a variant of the original Depthwise conv) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |         data_format: A string, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |             one of `channels_last` (default) or `channels_first`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |             The ordering of the dimensions in the inputs. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |             `channels_last` corresponds to inputs with shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |             `(batch, height, width, channels)` while `channels_first` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |             corresponds to inputs with shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |             `(batch, channels, height, width)`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |             It defaults to the `image_data_format` value found in your | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |             Keras config file at `~/.keras/keras.json`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |             If you never set it, then it will be "channels_last". | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |         activation: Activation function to use | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |             (see [activations](../activations.md)). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |             If you don't specify anything, no activation is applied | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |             (ie. "linear" activation: `a(x) = x`). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |         use_bias: Boolean, whether the layer uses a bias vector. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |         depthwise_initializer: Initializer for the depthwise kernel matrix | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |             (see [initializers](../initializers.md)). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |         bias_initializer: Initializer for the bias vector | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |             (see [initializers](../initializers.md)). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |         depthwise_regularizer: Regularizer function applied to | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |             the depthwise kernel matrix | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |             (see [regularizer](../regularizers.md)). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |         bias_regularizer: Regularizer function applied to the bias vector | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |             (see [regularizer](../regularizers.md)). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |         dialation_rate: List of ints. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |                         Defines the dilation factor for each dimension in the | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |                         input. Defaults to (1,1,1) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |         activity_regularizer: Regularizer function applied to | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |             the output of the layer (its "activation"). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |             (see [regularizer](../regularizers.md)). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |         depthwise_constraint: Constraint function applied to | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |             the depthwise kernel matrix | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |             (see [constraints](../constraints.md)). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |         bias_constraint: Constraint function applied to the bias vector | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |             (see [constraints](../constraints.md)). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |     # Input shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |         5D tensor with shape: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |         `(batch, channels, depth, rows, cols)` if data_format='channels_first' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |         or 5D tensor with shape: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |         `(batch, depth, rows, cols, channels)` if data_format='channels_last'. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |     # Output shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |         5D tensor with shape: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |         `(batch, groups * depth_multiplier, new_depth, new_rows, new_cols)` if data_format='channels_first' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |         or 5D tensor with shape: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |         `(batch, new_depth, new_rows, new_cols, groups * depth_multiplier)` if data_format='channels_last'. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |         `depth`, `rows` and `cols` values might have changed due to padding. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  |     #@legacy_depthwise_conv3d_support | 
            
                                                                                                            
                            
            
                                    
            
            
                def __init__(self,
                             kernel_size,
                             strides=(1, 1, 1),
                             padding='valid',
                             depth_multiplier=1,
                             groups=None,
                             data_format=None,
                             activation=None,
                             use_bias=True,
                             depthwise_initializer='glorot_uniform',
                             bias_initializer='zeros',
                             dilation_rate=(1, 1, 1),
                             depthwise_regularizer=None,
                             bias_regularizer=None,
                             activity_regularizer=None,
                             depthwise_constraint=None,
                             bias_constraint=None,
                             **kwargs):
                    """Configure the depthwise 3D convolution layer.

                    Shared convolution arguments are forwarded to the parent class
                    with `filters=None`. The depthwise-specific settings
                    (`depth_multiplier`, `groups`, depthwise initializer/regularizer/
                    constraint) and `bias_initializer` are stored on the instance.
                    `groups=None` is resolved later, in `build`, once the channel
                    count is known.
                    """
                    super(DepthwiseConv3D, self).__init__(
                        filters=None,
                        kernel_size=kernel_size,
                        strides=strides,
                        padding=padding,
                        data_format=data_format,
                        activation=activation,
                        use_bias=use_bias,
                        bias_regularizer=bias_regularizer,
                        dilation_rate=dilation_rate,
                        activity_regularizer=activity_regularizer,
                        bias_constraint=bias_constraint,
                        **kwargs)
                    self.depth_multiplier = depth_multiplier
                    self.groups = groups
                    self.depthwise_initializer = initializers.get(depthwise_initializer)
                    self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
                    self.depthwise_constraint = constraints.get(depthwise_constraint)
                    self.bias_initializer = initializers.get(bias_initializer)

                    # Store a plain 3-tuple. `call` builds the TF `dilations` argument
                    # by tuple concatenation, e.g. `(1,) + self.dilation_rate + (1,)`.
                    # That concatenation raised TypeError when the caller passed an
                    # int (e.g. `dilation_rate=2`) or a list.
                    if isinstance(dilation_rate, int):
                        rate = (dilation_rate,) * 3
                    else:
                        rate = tuple(dilation_rate)
                    self.dilation_rate = rate

                    # `_preprocess_padding` is a module-level helper outside this
                    # view; presumably it maps Keras padding names to TF's form.
                    self._padding = _preprocess_padding(self.padding)
                    # Assumes the parent class stored `strides` as a 3-tuple. The
                    # batch and channel axes always get stride 1.
                    self._strides = (1,) + self.strides + (1,)
                    # The TF-level layout is fixed. `call` first runs inputs through
                    # `_preprocess_conv3d_input` (presumably converting to
                    # channels-last — confirm).
                    self._data_format = "NDHWC"
                    # Filled in by `build` from the input's channel dimension.
                    self.input_dim = None
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def build(self, input_shape):
                    """Validate `input_shape` and create the depthwise kernel and bias.

                    Resolves `groups`, which defaults to the number of input channels,
                    and checks it against the channel count before adding weights.

                    # Arguments
                        input_shape: Shape of the layer input. Must have exactly 5
                            dimensions and a defined channel dimension.

                    # Raises
                        ValueError: if the input rank is not 5, if the channel
                            dimension is `None`, or if `groups` is not a positive
                            divisor of the number of input channels.
                    """
                    if len(input_shape) != 5:
                        # Concatenate into ONE message. Passing two args to
                        # ValueError rendered the error as a tuple.
                        raise ValueError('Inputs to `DepthwiseConv3D` should have rank 5. '
                                         'Received input shape: ' + str(input_shape))

                    channel_axis = 1 if self.data_format == 'channels_first' else -1
                    if input_shape[channel_axis] is None:
                        raise ValueError('The channel dimension of the inputs to '
                                         '`DepthwiseConv3D` '
                                         'should be defined. Found `None`.')
                    self.input_dim = int(input_shape[channel_axis])

                    if self.groups is None:
                        self.groups = self.input_dim

                    # Reject 0 and negative values. 0 would hit ZeroDivisionError in
                    # the modulo check below. A negative value would pass both checks.
                    if self.groups < 1:
                        raise ValueError('The number of groups must be a positive integer')
                    if self.groups > self.input_dim:
                        raise ValueError('The number of groups cannot exceed the number of channels')
                    if self.input_dim % self.groups != 0:
                        raise ValueError('Warning! The channels dimension is not divisible by the group size chosen')

                    # Kernel layout: (k_depth, k_rows, k_cols, in_channels, depth_multiplier).
                    depthwise_kernel_shape = (self.kernel_size[0],
                                              self.kernel_size[1],
                                              self.kernel_size[2],
                                              self.input_dim,
                                              self.depth_multiplier)

                    self.depthwise_kernel = self.add_weight(
                        shape=depthwise_kernel_shape,
                        initializer=self.depthwise_initializer,
                        name='depthwise_kernel',
                        regularizer=self.depthwise_regularizer,
                        constraint=self.depthwise_constraint)

                    if self.use_bias:
                        # One bias entry per output channel: groups * depth_multiplier.
                        self.bias = self.add_weight(shape=(self.groups * self.depth_multiplier,),
                                                    initializer=self.bias_initializer,
                                                    name='bias',
                                                    regularizer=self.bias_regularizer,
                                                    constraint=self.bias_constraint)
                    else:
                        self.bias = None
                    # Pin the channel count so later calls with a different channel
                    # size are rejected.
                    self.input_spec = InputSpec(ndim=5, axes={channel_axis: self.input_dim})
                    self.built = True
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def call(self, inputs, training=None):
                    """Apply the grouped depthwise 3D convolution to `inputs`.

                    The channel axis is split into contiguous slices of
                    `self.input_dim // self.groups` channels. Each slice is convolved
                    with the matching input-channel slice of `self.depthwise_kernel`
                    via `tf.nn.conv3d`, and the per-group results are concatenated
                    back along the channel axis. Bias and activation are applied
                    afterwards when configured.

                    # Arguments
                        inputs: 5D input tensor.
                        training: Unused; kept for the Keras `call` signature.

                    # Returns
                        Output tensor of the layer.
                    """
                    # `_preprocess_conv3d_input` returns a sequence whose first element
                    # is the tensor that gets convolved.
                    x = _preprocess_conv3d_input(inputs, self.data_format)[0]

                    # tf.nn.conv3d interprets a length-5 `dilations` in the order given
                    # by its own `data_format`, and the batch and channel entries must
                    # be 1. Derive both the channel axis and the dilation tuple from the
                    # TF layout actually passed to conv3d below so they always agree
                    # (building the channels_first case as `dilation_rate + (1, 1)`
                    # would put the depth dilation on the batch axis).
                    if self._data_format == 'NCDHW':
                        channel_axis = 1
                        dilation = (1, 1) + tuple(self.dilation_rate)
                    else:
                        channel_axis = -1
                        dilation = (1,) + tuple(self.dilation_rate) + (1,)

                    group_size = self.input_dim // self.groups

                    def _channels(start):
                        # Index tuple selecting channels [start, start + group_size)
                        # on `channel_axis` and everything on the other four axes.
                        index = [slice(None)] * 5
                        index[channel_axis] = slice(start, start + group_size)
                        return tuple(index)

                    # One convolution per channel group; the kernel is always sliced
                    # on its 4th (input-channel) axis.
                    group_outputs = [
                        tf.nn.conv3d(
                            x[_channels(start)],
                            self.depthwise_kernel[:, :, :, start:start + group_size, :],
                            strides=self._strides,
                            padding=self._padding,
                            dilations=dilation,
                            data_format=self._data_format)
                        for start in range(0, self.input_dim, group_size)]
                    outputs = tf.concat(group_outputs, axis=channel_axis)

                    if self.bias is not None:
                        outputs = K.bias_add(
                            outputs,
                            self.bias,
                            data_format=self.data_format)

                    if self.activation is not None:
                        return self.activation(outputs)

                    return outputs
            
                                                                                                            
                            
            
                                    
            
            
                | 248 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def compute_output_shape(self, input_shape):
                    """Return the static output shape for a 5D `input_shape`.

                    Each of the three spatial dimensions is passed through
                    `conv_utils.conv_output_length` using this layer's kernel size,
                    padding, stride and dilation rate. The channel dimension becomes
                    `self.groups * self.depth_multiplier`.

                    # Arguments
                        input_shape: Shape tuple of the 5D input, laid out according to
                            `self.data_format`.

                    # Returns
                        Output shape tuple in the same layout as the input.

                    # Raises
                        ValueError: if `self.data_format` is neither
                            `'channels_first'` nor `'channels_last'`.
                    """
                    if self.data_format == 'channels_first':
                        spatial_axes = (2, 3, 4)
                    elif self.data_format == 'channels_last':
                        spatial_axes = (1, 2, 3)
                    else:
                        raise ValueError('Unsupported data_format: ' + str(self.data_format))

                    out_filters = self.groups * self.depth_multiplier

                    # `call` convolves with `self.dilation_rate`, so the dilation has to
                    # enter the length computation too; without it the reported size is
                    # wrong for 'valid'/'full' padding whenever dilation_rate > 1.
                    depth, rows, cols = (
                        conv_utils.conv_output_length(input_shape[axis],
                                                      self.kernel_size[i],
                                                      self.padding,
                                                      self.strides[i],
                                                      dilation=self.dilation_rate[i])
                        for i, axis in enumerate(spatial_axes))

                    if self.data_format == 'channels_first':
                        return (input_shape[0], out_filters, depth, rows, cols)
                    return (input_shape[0], depth, rows, cols, out_filters)
            
                                                                                                            
                            
            
                                    
            
            
                | 278 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def get_config(self):
                    """Return the serializable configuration of this layer.

                    Starts from the parent config, drops the parent's `filters` and
                    `kernel_*` entries (this layer's weights live in
                    `depthwise_kernel`), and adds the depthwise-specific settings.

                    # Returns
                        A dict of layer settings suitable for `from_config`.
                    """
                    config = super(DepthwiseConv3D, self).get_config()
                    # Parent entries that do not describe a depthwise layer; tolerate
                    # Keras versions whose base config lacks any of them.
                    for key in ('filters', 'kernel_initializer',
                                'kernel_regularizer', 'kernel_constraint'):
                        config.pop(key, None)
                    config['depth_multiplier'] = self.depth_multiplier
                    # `groups` drives both `call` and `compute_output_shape`; without it
                    # a layer rebuilt from this config would lose its channel grouping.
                    # NOTE(review): assumes `__init__` accepts a `groups` keyword — confirm.
                    config['groups'] = self.groups
                    config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
                    config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
                    config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
                    return config
            
                                                                                                            
                                                                
            
                                    
            
            
                | 290 |  |  |  | 
            
                                                        
            
                                    
            
            
                # Alternative public name bound to the same class object, so
                # `DepthwiseConvolution3D` and `DepthwiseConv3D` are interchangeable.
                DepthwiseConvolution3D = DepthwiseConv3D