Passed
Push — develop ( ce0eb9...be6acb )
by Plexxi
05:19 queued 02:35
created

db_ensure_indexes()   C

Complexity

Conditions 7

Size

Total Lines 47

Duplication

Lines 0
Ratio 0 %

Importance

Changes 4
Bugs 0 Features 0
Metric Value
cc 7
c 4
b 0
f 0
dl 0
loc 47
rs 5.5
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import importlib
18
import traceback
19
import ssl as ssl_lib
20
21
import six
22
import mongoengine
23
from pymongo.errors import OperationFailure
24
25
from st2common import log as logging
26
from st2common.util import isotime
27
from st2common.models.db import stormbase
28
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
29
from st2common.exceptions.db import StackStormDBObjectNotFoundError
30
31
32
LOG = logging.getLogger(__name__)
33
34
MODEL_MODULE_NAMES = [
35
    'st2common.models.db.auth',
36
    'st2common.models.db.action',
37
    'st2common.models.db.actionalias',
38
    'st2common.models.db.keyvalue',
39
    'st2common.models.db.execution',
40
    'st2common.models.db.executionstate',
41
    'st2common.models.db.liveaction',
42
    'st2common.models.db.notification',
43
    'st2common.models.db.pack',
44
    'st2common.models.db.policy',
45
    'st2common.models.db.rbac',
46
    'st2common.models.db.rule',
47
    'st2common.models.db.rule_enforcement',
48
    'st2common.models.db.runner',
49
    'st2common.models.db.sensor',
50
    'st2common.models.db.trace',
51
    'st2common.models.db.trigger',
52
    'st2common.models.db.webhook'
53
]
54
55
# A list of model names for which we don't perform extra index cleanup
56
INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = [
57
    'PermissionGrantDB'
58
]
59
60
61
def get_model_classes():
    """
    Retrieve a list of all the defined model classes.

    :rtype: ``list``
    """
    classes = []
    modules = (importlib.import_module(name) for name in MODEL_MODULE_NAMES)
    for module in modules:
        # Each model module exposes its document classes via a MODELS attribute
        classes.extend(getattr(module, 'MODELS', []))
    return classes
74
75
76
def db_setup(db_name, db_host, db_port, username=None, password=None, ensure_indexes=True,
             ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Establish a mongoengine connection to the provided database and, optionally,
    eagerly create all the model indexes.

    :return: The established mongoengine connection object.
    """
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile,
                                 ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                                 ssl_match_hostname=ssl_match_hostname)

    connect_kwargs = dict(host=db_host, port=db_port, tz_aware=True,
                          username=username, password=password)
    connect_kwargs.update(ssl_kwargs)
    connection = mongoengine.connection.connect(db_name, **connect_kwargs)

    # Create all the indexes upfront to prevent race-conditions caused by
    # lazy index creation
    if ensure_indexes:
        db_ensure_indexes()

    return connection
97
98
99
def db_ensure_indexes():
    """
    This function ensures that indexes for all the models have been created and the
    extra indexes cleaned up.

    Note #1: When calling this method database connection already needs to be
    established.

    Note #2: This method blocks until all the index have been created (indexes
    are created in real-time and not in background).
    """
    LOG.debug('Ensuring database indexes...')
    model_classes = get_model_classes()

    for model_class in model_classes:
        class_name = model_class.__name__

        # Note: We need to ensure / create new indexes before removing extra ones
        LOG.debug('Ensuring indexes for model "%s"...' % (class_name))

        try:
            model_class.ensure_indexes()
        except OperationFailure as e:
            # Special case for "uid" index. MongoDB 3.4 has dropped "_types" index option so we
            # need to re-create the index to make it work and avoid "index with different options
            # already exists" error.
            # Note: This condition would only be encountered when upgrading existing StackStorm
            # installation from MongoDB 3.2 to 3.4.
            msg = str(e)
            if 'already exists with different options' in msg and 'uid_1' in msg:
                drop_obsolete_types_indexes(model_class=model_class)
            else:
                # Bare "raise" preserves the original traceback ("raise e" would
                # re-raise from this point and lose the original stack)
                raise
        except Exception as e:
            tb_msg = traceback.format_exc()
            msg = 'Failed to ensure indexes for model "%s": %s' % (class_name, str(e))
            msg += '\n\n' + tb_msg
            exc_cls = type(e)
            raise exc_cls(msg)

        if class_name in INDEX_CLEANUP_MODEL_NAMES_BLACKLIST:
            LOG.debug('Skipping index cleanup for blacklisted model "%s"...' % (class_name))
            continue

        LOG.debug('Removing extra indexes for model "%s"...' % (class_name))
        removed_count = cleanup_extra_indexes(model_class=model_class)
        LOG.debug('Removed "%s" extra indexes for model "%s"' % (removed_count, class_name))
146
147
148
def cleanup_extra_indexes(model_class):
    """
    Finds any extra indexes and removes those from mongodb.
    """
    extra_indexes = model_class.compare_indexes().get('extra', None)
    if not extra_indexes:
        return 0

    # mongoengine does not have the necessary method so we need to drop to
    # pymongo interfaces via some private methods.
    collection = model_class._get_collection()
    dropped = 0
    for index_name in extra_indexes:
        try:
            collection.drop_index(index_name)
        except OperationFailure:
            # Best-effort cleanup - log and move on to the next index
            LOG.warning('Attempt to cleanup index %s failed.', index_name, exc_info=True)
        else:
            LOG.debug('Dropped index %s for model %s.', index_name, model_class.__name__)
            dropped += 1

    return dropped
169
170
171
def drop_obsolete_types_indexes(model_class):
    """
    Special function for dropping offending "types" indexes for which support has
    been removed in mongoengine and MongoDB 3.4.

    For more info, see: http://docs.mongoengine.org/upgrade.html#inheritance

    :param model_class: Document model class for which the obsolete indexes
                        should be dropped and re-created.
    """
    class_name = model_class.__name__

    LOG.debug('Dropping obsolete types index for model "%s"' % (class_name))
    collection = model_class._get_collection()
    # Remove the legacy "_types" field from all the existing documents
    collection.update({}, {'$unset': {'_types': 1}}, multi=True)

    info = collection.index_information()
    # Note: use items() instead of the Python 2 only iteritems() so this also
    # works under Python 3 (consistent with the rest of the codebase)
    indexes_to_drop = [key for key, value in info.items()
                       if '_types' in dict(value['key']) or 'types' in value]

    LOG.debug('Will drop obsolete types indexes for model "%s": %s' % (class_name,
                                                                       str(indexes_to_drop)))

    for index in indexes_to_drop:
        collection.drop_index(index)

    LOG.debug('Recreating indexes for model "%s"' % (class_name))
    model_class.ensure_indexes()
195
196
197
def db_teardown():
    """
    Disconnect the established mongoengine database connection.
    """
    mongoengine.connection.disconnect()
199
200
201
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
202
                    ssl_ca_certs=None, ssl_match_hostname=True):
203
    ssl_kwargs = {
204
        'ssl': ssl,
205
    }
206
    if ssl_keyfile:
207
        ssl_kwargs['ssl'] = True
208
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
209
    if ssl_certfile:
210
        ssl_kwargs['ssl'] = True
211
        ssl_kwargs['ssl_certfile'] = ssl_certfile
212
    if ssl_cert_reqs:
213
        if ssl_cert_reqs is 'none':
214
            ssl_cert_reqs = ssl_lib.CERT_NONE
215
        elif ssl_cert_reqs is 'optional':
216
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
217
        elif ssl_cert_reqs is 'required':
218
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
219
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
220
    if ssl_ca_certs:
221
        ssl_kwargs['ssl'] = True
222
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
223
    if ssl_kwargs.get('ssl', False):
224
        # pass in ssl_match_hostname only if ssl is True. The right default value
225
        # for ssl_match_hostname in almost all cases is True.
226
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
227
    return ssl_kwargs
228
229
230
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        # Mongoengine document class this accessor operates on
        self.model = model

    def get_by_name(self, value):
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        return self.get(id=value, raise_exception=True)

    def get_by_uid(self, value):
        return self.get(uid=value, raise_exception=True)

    def get_by_ref(self, value):
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        return self.get(pack=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        """
        Retrieve a single object matching the provided filters.

        :param exclude_fields: Optional list of field names to exclude from the result.
        :raises StackStormDBObjectNotFoundError: If no object matches and
                ``raise_exception=True`` was passed.
        """
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        instance = instances[0] if instances else None
        log_query_and_profile_data_for_queryset(queryset=instances)

        if not instance and raise_exception:
            msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
            raise StackStormDBObjectNotFoundError(msg)

        return instance

    def get_all(self, *args, **kwargs):
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        result = self.model.objects(**kwargs).count()
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        """
        Query objects with support for pagination, ordering, field exclusion and
        special "null" / datetime-range filter values.
        """
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        return instance.update(**kwargs)

    def delete(self, instance):
        return instance.delete()

    def delete_by_query(self, **query):
        """
        Delete objects by query and return number of deleted objects.
        """
        qs = self.model.objects.filter(**query)
        count = qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)

        return count

    def _undo_dict_field_escape(self, instance):
        # Note: six.iteritems instead of Python 2 only dict.iteritems() for
        # Python 3 compatibility (consistent with _process_null_filters below)
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        """
        Translate filters with a None / "null" value into "<field>__exists=False"
        queries.
        """
        result = copy.deepcopy(filters)

        # six.string_types covers both str and unicode under Python 2 and plain
        # str under Python 3 (the old "type(v) in [str, unicode]" check breaks
        # on Python 3 and rejects str subclasses)
        null_filters = {k: v for k, v in six.iteritems(filters)
                        if v is None or (isinstance(v, six.string_types) and
                                         str(v.lower()) == 'null')}

        for key in null_filters.keys():
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Translate "value1..value2" datetime range filter values into
        "<field>__gte" / "<field>__lte" queries and adjust the sort order
        accordingly.
        """
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
380