Passed
Push — develop ( f534b1...a82689 )
by Plexxi
06:09 queued 03:13
created

db_ensure_indexes()   F

Complexity

Conditions 9

Size

Total Lines 48

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
c 0
b 0
f 0
dl 0
loc 48
rs 3.5294
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import importlib
18
import traceback
19
import ssl as ssl_lib
20
21
import six
22
import mongoengine
23
from pymongo.errors import OperationFailure
24
25
from st2common import log as logging
26
from st2common.util import isotime
27
from st2common.models.db import stormbase
28
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
29
from st2common.exceptions.db import StackStormDBObjectNotFoundError
30
31
32
# Module level logger
LOG = logging.getLogger(__name__)

# Names of the modules which hold database model definitions. Each module is
# expected to expose the model classes it defines via a module level "MODELS"
# attribute (see get_model_classes below).
MODEL_MODULE_NAMES = [
    'st2common.models.db.auth',
    'st2common.models.db.action',
    'st2common.models.db.actionalias',
    'st2common.models.db.keyvalue',
    'st2common.models.db.execution',
    'st2common.models.db.executionstate',
    'st2common.models.db.liveaction',
    'st2common.models.db.notification',
    'st2common.models.db.pack',
    'st2common.models.db.policy',
    'st2common.models.db.rbac',
    'st2common.models.db.rule',
    'st2common.models.db.rule_enforcement',
    'st2common.models.db.runner',
    'st2common.models.db.sensor',
    'st2common.models.db.trace',
    'st2common.models.db.trigger',
    'st2common.models.db.webhook'
]

# A list of model names for which we don't perform extra index cleanup
INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = [
    'PermissionGrantDB'
]
59
60
61
def get_model_classes():
    """
    Retrieve a list of all the defined model classes.

    :rtype: ``list``
    """
    # Every model module exposes its model classes via a module level "MODELS"
    # attribute; modules without that attribute contribute nothing.
    return [model_class
            for module_name in MODEL_MODULE_NAMES
            for model_class in getattr(importlib.import_module(module_name), 'MODELS', [])]
74
75
76
def db_setup(db_name, db_host, db_port, username=None, password=None, ensure_indexes=True,
             ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Establish a mongoengine connection to the provided database and, unless
    disabled, synchronously create the indexes for all the known models.

    :param ensure_indexes: When True (the default), call db_ensure_indexes()
                           right after connecting so indexes exist up-front.
    :return: The established connection object.
    """
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    ssl_kwargs = _get_ssl_kwargs(
        ssl=ssl,
        ssl_keyfile=ssl_keyfile,
        ssl_certfile=ssl_certfile,
        ssl_cert_reqs=ssl_cert_reqs,
        ssl_ca_certs=ssl_ca_certs,
        ssl_match_hostname=ssl_match_hostname)

    connection = mongoengine.connection.connect(
        db_name,
        host=db_host,
        port=db_port,
        tz_aware=True,
        username=username,
        password=password,
        **ssl_kwargs)

    # Create all the indexes upfront to prevent race-conditions caused by
    # lazy index creation
    if ensure_indexes:
        db_ensure_indexes()

    return connection
97
98
99
def db_ensure_indexes():
    """
    This function ensures that indexes for all the models have been created and the
    extra indexes cleaned up.

    Note #1: When calling this method database connection already needs to be
    established.

    Note #2: This method blocks until all the index have been created (indexes
    are created in real-time and not in background).
    """
    LOG.debug('Ensuring database indexes...')
    model_classes = get_model_classes()

    for model_class in model_classes:
        class_name = model_class.__name__

        # Note: We need to ensure / create new indexes before removing extra ones
        try:
            model_class.ensure_indexes()
        except OperationFailure as e:
            # Special case for "uid" index. MongoDB 3.4 has dropped "_types" index option so we
            # need to re-create the index to make it work and avoid "index with different options
            # already exists" error.
            # Note: This condition would only be encountered when upgrading existing StackStorm
            # installation from MongoDB 3.2 to 3.4.
            msg = str(e)
            if 'already exists with different options' in msg and 'uid_1' in msg:
                drop_obsolete_types_indexes(model_class=model_class)
            else:
                raise e
        except Exception as e:
            # Re-raise with the model name and the full traceback appended so the
            # failure is attributable to a specific model.
            # NOTE(review): re-raising via type(e)(msg) assumes the exception class
            # accepts a single message argument - confirm for custom exception types.
            tb_msg = traceback.format_exc()
            msg = 'Failed to ensure indexes for model "%s": %s' % (class_name, str(e))
            msg += '\n\n' + tb_msg
            exc_cls = type(e)
            raise exc_cls(msg)

        # Some models are explicitly excluded from the extra-index cleanup step.
        if model_class.__name__ in INDEX_CLEANUP_MODEL_NAMES_BLACKLIST:
            LOG.debug('Skipping index cleanup for blacklisted model "%s"...' % (class_name))
            continue

        removed_count = cleanup_extra_indexes(model_class=model_class)
        if removed_count:
            LOG.debug('Removed "%s" extra indexes for model "%s"' % (removed_count, class_name))

    LOG.debug('Indexes are ensured for models: %s' %
              ', '.join(sorted((model_class.__name__ for model_class in model_classes))))
147
148
149
def cleanup_extra_indexes(model_class):
    """
    Finds any extra indexes and removes those from mongodb.

    :return: Number of indexes which were successfully dropped.
    :rtype: ``int``
    """
    extra_indexes = model_class.compare_indexes().get('extra', None)
    if not extra_indexes:
        return 0

    # mongoengine does not have the necessary method so we need to drop to
    # pymongo interfaces via some private methods.
    collection = model_class._get_collection()

    dropped_count = 0
    for index_name in extra_indexes:
        try:
            collection.drop_index(index_name)
        except OperationFailure:
            LOG.warning('Attempt to cleanup index %s failed.', index_name, exc_info=True)
        else:
            LOG.debug('Dropped index %s for model %s.', index_name, model_class.__name__)
            dropped_count += 1

    return dropped_count
170
171
172
def drop_obsolete_types_indexes(model_class):
    """
    Special class for droping offending "types" indexes for which support has
    been removed in mongoengine and MongoDB 3.4.
    For more info, see: http://docs.mongoengine.org/upgrade.html#inheritance

    :param model_class: Model class whose collection should be cleaned up.
    """
    class_name = model_class.__name__

    LOG.debug('Dropping obsolete types index for model "%s"' % (class_name))
    collection = model_class._get_collection()
    # Remove the legacy "_types" field from all the existing documents.
    collection.update({}, {'$unset': {'_types': 1}}, multi=True)

    info = collection.index_information()
    # NOTE: Use six.iteritems instead of the Python 2 only dict.iteritems so this
    # function also works under Python 3 and is consistent with the rest of this
    # module (e.g. MongoDBAccess._process_null_filters).
    indexes_to_drop = [key for key, value in six.iteritems(info)
                       if '_types' in dict(value['key']) or 'types' in value]

    LOG.debug('Will drop obsolete types indexes for model "%s": %s' % (class_name,
                                                                       str(indexes_to_drop)))

    for index in indexes_to_drop:
        collection.drop_index(index)

    # Re-create the indexes now that the offending ones have been removed.
    LOG.debug('Recreating indexes for model "%s"' % (class_name))
    model_class.ensure_indexes()
196
197
198
def db_teardown():
    """
    Disconnect the default mongoengine database connection.
    """
    mongoengine.connection.disconnect()
200
201
202
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
203
                    ssl_ca_certs=None, ssl_match_hostname=True):
204
    ssl_kwargs = {
205
        'ssl': ssl,
206
    }
207
    if ssl_keyfile:
208
        ssl_kwargs['ssl'] = True
209
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
210
    if ssl_certfile:
211
        ssl_kwargs['ssl'] = True
212
        ssl_kwargs['ssl_certfile'] = ssl_certfile
213
    if ssl_cert_reqs:
214
        if ssl_cert_reqs is 'none':
215
            ssl_cert_reqs = ssl_lib.CERT_NONE
216
        elif ssl_cert_reqs is 'optional':
217
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
218
        elif ssl_cert_reqs is 'required':
219
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
220
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
221
    if ssl_ca_certs:
222
        ssl_kwargs['ssl'] = True
223
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
224
    if ssl_kwargs.get('ssl', False):
225
        # pass in ssl_match_hostname only if ssl is True. The right default value
226
        # for ssl_match_hostname in almost all cases is True.
227
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
228
    return ssl_kwargs
229
230
231
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        self.model = model

    def get_by_name(self, value):
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        return self.get(id=value, raise_exception=True)

    def get_by_uid(self, value):
        return self.get(uid=value, raise_exception=True)

    def get_by_ref(self, value):
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        return self.get(pack=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        """
        Retrieve a single instance which matches the provided filters.

        :param exclude_fields: Optional list of field names to exclude from the
                               returned document.
        :param raise_exception: Keyword only argument. When True, raise
                                ``StackStormDBObjectNotFoundError`` instead of
                                returning None when no instance matches.
        """
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        instance = instances[0] if instances else None
        log_query_and_profile_data_for_queryset(queryset=instances)

        if not instance and raise_exception:
            msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
            raise StackStormDBObjectNotFoundError(msg)

        return instance

    def get_all(self, *args, **kwargs):
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        result = self.model.objects(**kwargs).count()
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        """
        Retrieve all the instances which match the provided filters.

        :param offset: Index of the first instance to return.
        :param limit: Maximum number of instances to return (None means no limit).
        :param order_by: Optional list of field names to order the result by.
        :param exclude_fields: Optional list of field names to exclude.
        """
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        return instance.update(**kwargs)

    def delete(self, instance):
        return instance.delete()

    def delete_by_query(self, **query):
        """
        Delete objects by query and return number of deleted objects.
        """
        qs = self.model.objects.filter(**query)
        count = qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)

        return count

    def _undo_dict_field_escape(self, instance):
        # Convert escaped dict field values back to their unescaped form.
        # NOTE: Use six.iteritems instead of the Python 2 only dict.iteritems so this
        # code also works under Python 3 and is consistent with the rest of the class.
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        # Translate None / "null" filter values into "<field>__exists: False" queries.
        result = copy.deepcopy(filters)

        # NOTE: Use six.string_types instead of referencing "unicode" directly so the
        # code is also valid under Python 3.
        null_filters = {k: v for k, v in six.iteritems(filters)
                        if v is None or
                        (isinstance(v, six.string_types) and str(v.lower()) == 'null')}

        for key in null_filters.keys():
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        # Translate "start..end" string filter values into __gte / __lte range
        # queries and adjust the sort order so it matches the range direction.
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            # Replace the raw range string with the two concrete range conditions.
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
381