Test Setup Failed
Pull Request — master (#4154)
by W
03:25
created

ChangeRevisionMongoDBAccess.get()   A

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
c 0
b 0
f 0
dl 0
loc 4
rs 10
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from __future__ import absolute_import
17
18
import copy
19
import importlib
20
import traceback
21
import ssl as ssl_lib
22
23
import six
24
import mongoengine
25
from mongoengine.queryset import visitor
26
from pymongo import uri_parser
27
from pymongo.errors import OperationFailure
28
29
from st2common import log as logging
30
from st2common.util import isotime
31
from st2common.models.db import stormbase
32
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
33
from st2common.exceptions import db as db_exc
34
35
36
# Module logger, obtained from st2common's logging wrapper (imported above as ``logging``).
LOG = logging.getLogger(__name__)
37
38
# Dotted paths of the modules holding database model definitions. get_model_classes()
# imports each one and collects the classes listed in its ``MODELS`` attribute, if any.
MODEL_MODULE_NAMES = [
    'st2common.models.db.auth',
    'st2common.models.db.action',
    'st2common.models.db.actionalias',
    'st2common.models.db.keyvalue',
    'st2common.models.db.execution',
    'st2common.models.db.executionstate',
    'st2common.models.db.liveaction',
    'st2common.models.db.notification',
    'st2common.models.db.pack',
    'st2common.models.db.policy',
    'st2common.models.db.rbac',
    'st2common.models.db.rule',
    'st2common.models.db.rule_enforcement',
    'st2common.models.db.runner',
    'st2common.models.db.sensor',
    'st2common.models.db.trace',
    'st2common.models.db.trigger',
    'st2common.models.db.webhook',
]

# Names of model classes for which db_ensure_indexes() skips the extra-index cleanup step.
INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = [
    'PermissionGrantDB',
]
63
64
65
def get_model_classes():
    """
    Collect every model class declared by the modules in ``MODEL_MODULE_NAMES``.

    Each module is imported in list order and the contents of its ``MODELS`` attribute are
    appended to the result; a module without ``MODELS`` contributes nothing.

    :rtype: ``list``
    """
    modules = (importlib.import_module(module_name) for module_name in MODEL_MODULE_NAMES)
    return [model_class
            for module in modules
            for model_class in getattr(module, 'MODELS', [])]
78
79
80
def _db_connect(db_name, db_host, db_port, username=None, password=None,
                ssl=False, ssl_keyfile=None, ssl_certfile=None,
                ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Open a mongoengine connection to ``db_name`` and return it.

    The target is logged before and after connecting. When ``db_host`` is a MongoDB URI, the
    logged host is rebuilt from the parsed node list, so a password embedded in the URI is
    never written to the log.

    :return: The object returned by ``mongoengine.connection.connect``.
    """
    if '://' not in db_host:
        host_string = '%s:%s' % (db_host, db_port)
        username_string = username
    else:
        uri_dict = uri_parser.parse_uri(db_host)
        # An explicitly passed username takes precedence over the one in the URI; the URI
        # username is only shown when no (truthy) username argument was given.
        username_string = username or uri_dict.get('username') or username

        hostnames = get_host_names_for_uri_dict(uri_dict=uri_dict)
        is_replica_set = len(uri_dict['nodelist']) > 1
        host_string = '%s (replica set)' % (hostnames) if is_replica_set else hostnames

    LOG.info('Connecting to database "%s" @ "%s" as user "%s".' % (db_name, host_string,
                                                                   str(username_string)))

    ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile,
                                 ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                                 ssl_match_hostname=ssl_match_hostname)

    conn = mongoengine.connection.connect(db_name, host=db_host, port=db_port,
                                          tz_aware=True, username=username,
                                          password=password, **ssl_kwargs)

    LOG.info('Successfully connected to database "%s" @ "%s" as user "%s".' % (
        db_name, host_string, str(username_string)))

    return conn
120
121
122
def db_setup(db_name, db_host, db_port, username=None, password=None, ensure_indexes=True,
             ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Connect to the database and, optionally, make sure all model indexes exist.

    :param ensure_indexes: When true, db_ensure_indexes() is called right after connecting.
    :return: The connection object returned by _db_connect().
    """
    connect_kwargs = dict(username=username, password=password, ssl=ssl,
                          ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile,
                          ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                          ssl_match_hostname=ssl_match_hostname)
    connection = _db_connect(db_name, db_host, db_port, **connect_kwargs)

    # Build every index up front instead of relying on lazy index creation, which is prone
    # to race conditions.
    if ensure_indexes:
        db_ensure_indexes()

    return connection
138
139
140
def db_ensure_indexes():
    """
    Ensure indexes for all the models have been created and extra indexes cleaned up.

    Note #1: A database connection must already be established when calling this function.

    Note #2: This function blocks until all the indexes have been created (indexes are
    created in real-time and not in background).

    :raises OperationFailure: If index creation fails for any reason other than the known
        "uid_1 ... already exists with different options" upgrade case.
    :raises Exception: Any other error is re-raised as the same exception type with a message
        naming the failing model and containing the original traceback text.
    """
    LOG.debug('Ensuring database indexes...')
    model_classes = get_model_classes()

    for model_class in model_classes:
        class_name = model_class.__name__

        # Note: We need to ensure / create new indexes before removing extra ones
        try:
            model_class.ensure_indexes()
        except OperationFailure as e:
            # Special case for "uid" index. MongoDB 3.4 has dropped "_types" index option so we
            # need to re-create the index to make it work and avoid "index with different options
            # already exists" error.
            # Note: This condition would only be encountered when upgrading existing StackStorm
            # installation from MongoDB 3.2 to 3.4.
            msg = str(e)
            if 'already exists with different options' in msg and 'uid_1' in msg:
                drop_obsolete_types_indexes(model_class=model_class)
            else:
                # Bare "raise" keeps the original traceback ("raise e" drops it on Python 2).
                raise
        except Exception as e:
            tb_msg = traceback.format_exc()
            msg = 'Failed to ensure indexes for model "%s": %s' % (class_name, str(e))
            msg += '\n\n' + tb_msg

            exc_cls = type(e)
            try:
                new_exc = exc_cls(msg)
            except Exception:
                # Not every exception type can be built from a single message argument; in
                # that case surface the original error instead of an unrelated TypeError.
                new_exc = None

            if new_exc is None:
                raise e

            # Keep the original exception attached as the cause (Python 3).
            six.raise_from(new_exc, e)

        if class_name in INDEX_CLEANUP_MODEL_NAMES_BLACKLIST:
            LOG.debug('Skipping index cleanup for blacklisted model "%s"...' % (class_name))
            continue

        removed_count = cleanup_extra_indexes(model_class=model_class)
        if removed_count:
            LOG.debug('Removed "%s" extra indexes for model "%s"' % (removed_count, class_name))

    LOG.debug('Indexes are ensured for models: %s' %
              ', '.join(sorted(model_class.__name__ for model_class in model_classes)))
188
189
190
def cleanup_extra_indexes(model_class):
    """
    Drop indexes present in MongoDB but reported as "extra" for ``model_class``.

    The extra indexes come from the model's ``compare_indexes()`` result. If dropping one of
    them raises OperationFailure, a warning is logged and the remaining ones are still tried.

    :return: Number of indexes that were actually dropped.
    :rtype: ``int``
    """
    extra_indexes = model_class.compare_indexes().get('extra', None)
    if not extra_indexes:
        return 0

    # mongoengine does not offer a way to drop an index, so use the underlying pymongo
    # collection obtained through a private mongoengine method.
    collection = model_class._get_collection()
    dropped = 0
    for index_spec in extra_indexes:
        try:
            collection.drop_index(index_spec)
        except OperationFailure:
            LOG.warning('Attempt to cleanup index %s failed.', index_spec, exc_info=True)
        else:
            LOG.debug('Dropped index %s for model %s.', index_spec, model_class.__name__)
            dropped += 1

    return dropped
211
212
213
def drop_obsolete_types_indexes(model_class):
    """
    Drop obsolete "types" indexes of ``model_class`` and re-create its indexes.

    Support for the "_types" field and the "types" index option has been removed in
    mongoengine and MongoDB 3.4, so indexes created with them have to be dropped before the
    model's indexes can be re-created. For more info, see:
    http://docs.mongoengine.org/upgrade.html#inheritance

    Steps:
    1. Unset the "_types" field on every document in the model's collection.
    2. Drop every index whose key contains "_types" or which carries a "types" option.
    3. Re-create the model's declared indexes via ``model_class.ensure_indexes()``.
    """
    class_name = model_class.__name__

    LOG.debug('Dropping obsolete types index for model "%s"' % (class_name))
    collection = model_class._get_collection()
    # Collection.update(..., multi=True) is deprecated since PyMongo 3.0 and removed in 4.0;
    # update_many() performs the same multi-document update.
    collection.update_many({}, {'$unset': {'_types': 1}})

    info = collection.index_information()
    indexes_to_drop = [key for key, value in six.iteritems(info)
                       if '_types' in dict(value['key']) or 'types' in value]

    LOG.debug('Will drop obsolete types indexes for model "%s": %s' % (class_name,
                                                                       str(indexes_to_drop)))

    for index in indexes_to_drop:
        collection.drop_index(index)

    LOG.debug('Recreating indexes for model "%s"' % (class_name))
    model_class.ensure_indexes()
237
238
239
def db_teardown():
    """Disconnect mongoengine by calling ``mongoengine.connection.disconnect()`` with defaults."""
    connection_module = mongoengine.connection
    connection_module.disconnect()
241
242
243
def db_cleanup(db_name, db_host, db_port, username=None, password=None,
               ssl=False, ssl_keyfile=None, ssl_certfile=None,
               ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Connect to the database and drop it entirely.

    When ``db_host`` is a MongoDB URI, only the parsed "host:port" list is logged (never the
    raw URI), so a password embedded in the URI does not end up in the log.

    :return: The connection object returned by _db_connect().
    """
    connection = _db_connect(db_name, db_host, db_port, username=username,
                             password=password, ssl=ssl, ssl_keyfile=ssl_keyfile,
                             ssl_certfile=ssl_certfile,
                             ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                             ssl_match_hostname=ssl_match_hostname)

    if '://' in db_host:
        # Logging db_host verbatim would leak any "user:password@" part of the URI.
        uri_dict = uri_parser.parse_uri(db_host)
        host_string = get_host_names_for_uri_dict(uri_dict=uri_dict)
        # Same precedence as _db_connect(): an explicit username wins over the URI one.
        username_string = username or uri_dict.get('username') or username
    else:
        host_string = '%s:%s' % (db_host, db_port)
        username_string = username

    LOG.info('Dropping database "%s" @ "%s" as user "%s".',
             db_name, host_string, str(username_string))

    connection.drop_database(db_name)
    return connection
258
259
260
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
261
                    ssl_ca_certs=None, ssl_match_hostname=True):
262
    ssl_kwargs = {
263
        'ssl': ssl,
264
    }
265
    if ssl_keyfile:
266
        ssl_kwargs['ssl'] = True
267
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
268
    if ssl_certfile:
269
        ssl_kwargs['ssl'] = True
270
        ssl_kwargs['ssl_certfile'] = ssl_certfile
271
    if ssl_cert_reqs:
272
        if ssl_cert_reqs is 'none':
273
            ssl_cert_reqs = ssl_lib.CERT_NONE
274
        elif ssl_cert_reqs is 'optional':
275
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
276
        elif ssl_cert_reqs is 'required':
277
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
278
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
279
    if ssl_ca_certs:
280
        ssl_kwargs['ssl'] = True
281
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
282
    if ssl_kwargs.get('ssl', False):
283
        # pass in ssl_match_hostname only if ssl is True. The right default value
284
        # for ssl_match_hostname in almost all cases is True.
285
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
286
    return ssl_kwargs
287
288
289
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        # Model (document) class whose ``objects`` manager every operation below goes through.
        self.model = model

    def get_by_name(self, value):
        """Return the instance whose ``name`` equals ``value``; raise if there is none."""
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        """Return the instance whose ``id`` equals ``value``; raise if there is none."""
        return self.get(id=value, raise_exception=True)

    def get_by_uid(self, value):
        """Return the instance whose ``uid`` equals ``value``; raise if there is none."""
        return self.get(uid=value, raise_exception=True)

    def get_by_ref(self, value):
        """Return the instance whose ``ref`` equals ``value``; raise if there is none."""
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        """Return the first instance whose ``pack`` equals ``value``; raise if there is none."""
        return self.get(pack=value, raise_exception=True)

    def get(self, *args, **kwargs):
        """
        Return the first instance matching the given filters, or None.

        The following keyword arguments are consumed here and not used as filters:

        :param exclude_fields: Fields to exclude from the returned document.
        :param only_fields: Restrict the returned document to these fields.
        :param raise_exception: When true, raise StackStormDBObjectNotFoundError instead of
            returning None if nothing matches.

        Positional ``Q`` / ``QCombination`` arguments are run through _process_arg_filters();
        the remaining keyword arguments are passed to the query unchanged.

        :raises ValueError: If ``only_fields`` references an invalid or unsupported field.
        """
        exclude_fields = kwargs.pop('exclude_fields', None)
        raise_exception = kwargs.pop('raise_exception', False)
        only_fields = kwargs.pop('only_fields', None)

        args = self._process_arg_filters(args)

        instances = self.model.objects(*args, **kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        if only_fields:
            try:
                instances = instances.only(*only_fields)
            except mongoengine.errors.LookUpError as e:
                msg = ('Invalid or unsupported include attribute specified: %s' % str(e))
                raise ValueError(msg)

        instance = instances[0] if instances else None
        log_query_and_profile_data_for_queryset(queryset=instances)

        if not instance and raise_exception:
            msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
            raise db_exc.StackStormDBObjectNotFoundError(msg)

        return instance

    def get_all(self, *args, **kwargs):
        """Alias for query()."""
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """Return the number of instances matching the given filters."""
        result = self.model.objects(*args, **kwargs).count()
        # NOTE(review): this passes the integer count, not a queryset, to the profiling
        # helper; confirm that helper tolerates non-queryset values.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    # TODO: PEP-3102 introduced keyword-only arguments, so once we support Python 3+, we can change
    #       this definition to have explicit keyword-only arguments:
    #
    #           def query(self, *args, offset=0, limit=None, order_by=None, exclude_fields=None,
    #                     **filters):
    def query(self, *args, **filters):
        """
        Return a queryset of instances matching the given filters.

        The following keyword arguments are consumed here and not used as filters:

        :param offset: Number of leading results to skip (default 0).
        :param limit: Maximum number of results; falsy means unlimited.
        :param order_by: List of field names to order by ("-" prefix for descending).
        :param exclude_fields: Fields to exclude from the returned documents.
        :param only_fields: Restrict the returned documents to these fields.
        :param no_dereference: When true, call ``no_dereference()`` on the queryset.

        String filter values of the form "<start>..<end>" become datetime range filters and
        None / "null" values become ``<field>__exists=False`` filters.

        :raises ValueError: If ``only_fields`` references an invalid or unsupported field.
        """
        # Python 2: Pop keyword parameters that aren't actually filters off of the kwargs
        offset = filters.pop('offset', 0)
        limit = filters.pop('limit', None)
        order_by = filters.pop('order_by', None)
        exclude_fields = filters.pop('exclude_fields', None)
        only_fields = filters.pop('only_fields', None)
        no_dereference = filters.pop('no_dereference', None)

        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        args = self._process_arg_filters(args)
        # Process the filters
        # Note: Both of those functions rewrite the filters they receive (returning new
        # values), so the order in which they are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(*args, **filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        if only_fields:
            try:
                result = result.only(*only_fields)
            except mongoengine.errors.LookUpError as e:
                msg = ('Invalid or unsupported include attribute specified: %s' % str(e))
                raise ValueError(msg)

        if no_dereference:
            result = result.no_dereference()

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """
        Return the distinct values of ``field`` among instances matching ``kwargs``.

        :param field: Required keyword argument naming the field.
        """
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        """Run an aggregation on the model's underlying collection."""
        # NOTE(review): ``kwargs`` is used both as the queryset filter and forwarded to the
        # collection's aggregate() call; confirm this is intended.
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        """Insert ``instance`` and return it with escaped dict fields converted back."""
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Save ``instance`` and return it with escaped dict fields converted back."""
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        """Apply ``kwargs`` as an update to ``instance`` and return ``instance.update``'s result."""
        return instance.update(**kwargs)

    def delete(self, instance):
        """Delete ``instance`` and return ``instance.delete``'s result."""
        return instance.delete()

    def delete_by_query(self, *args, **query):
        """
        Delete objects by query and return number of deleted objects.
        """
        qs = self.model.objects.filter(*args, **query)
        count = qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)

        return count

    def _undo_dict_field_escape(self, instance):
        """Convert every EscapedDictField value of ``instance`` back via ``to_python``."""
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_arg_filters(self, args):
        """
        Fix filter arguments in nested Q objects.

        Every ``Q`` gets the same datetime range / null filter processing as query() keyword
        filters; ``QCombination`` objects are rebuilt recursively with processed children.

        :return: ``tuple`` of new Q / QCombination objects.
        :raises TypeError: If an argument is neither a Q nor a QCombination.
        """
        _args = tuple()

        for arg in args:
            # Unfortunately mongoengine doesn't expose any visitors other than Q, so we have to
            # extract QCombination from the module itself
            if isinstance(arg, visitor.Q):
                # Note: Both of those functions rewrite the filters they receive, so the order
                # in which they are called matters
                filters, _ = self._process_datetime_range_filters(filters=arg.query)
                filters = self._process_null_filters(filters=filters)

                # Create a new Q object with the same filters as the old one
                _args += (visitor.Q(**filters),)
            elif isinstance(arg, visitor.QCombination):
                # Recurse if we need to
                children = self._process_arg_filters(arg.children)

                # Create a new QCombination object with the same operation and fixed filters
                _args += (visitor.QCombination(arg.operation, children),)
            else:
                raise TypeError("Unknown argument type '%s' of argument '%s'"
                                % (type(arg), repr(arg)))

        return _args

    def _process_null_filters(self, filters):
        """
        Turn "is null" filters into ``<field>__exists=False`` filters.

        A filter counts as null when its value is None or a string equal to "null"
        (case-insensitive). ``filters`` itself is not modified; a new dict is returned.
        """
        result = copy.deepcopy(filters)

        # No str() around v.lower(): on Python 2 that raised UnicodeEncodeError for
        # non-ASCII unicode values.
        null_filters = {k: v for k, v in six.iteritems(filters)
                        if v is None or
                        (type(v) in [str, six.text_type] and v.lower() == 'null')}

        for key in null_filters:
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Expand "<start>..<end>" string filters into ``__gte`` / ``__lte`` range filters.

        Both ends are parsed with ``isotime.parse`` and the smaller value becomes the
        ``__gte`` bound. The field is also put at the front of the ordering (ascending when
        start < end, descending otherwise) unless that entry is already present; an entry
        for the opposite direction is replaced in place.

        :return: ``(filters, order_by_list)``. Neither the ``filters`` dict nor the
            ``order_by`` list passed in is modified.
        """
        # Work on a copy: callers pass e.g. a Q object's ``query`` dict, which must not be
        # rewritten behind their back.
        filters = copy.copy(filters)

        ranges = {k: v for k, v in six.iteritems(filters)
                  if type(v) in [str, six.text_type] and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
503
504
505
class ChangeRevisionMongoDBAccess(MongoDBAccess):
    """
    Access class for models carrying a ``rev`` counter (optimistic concurrency control).

    Updates are saved with a ``save_condition`` on the current ``id`` and ``rev``, and
    ``rev`` is incremented on every successful update.
    """

    def insert(self, instance):
        """Insert ``instance`` and return it with escaped dict fields converted back."""
        instance = self.model.objects.insert(instance)

        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Insert or conditionally update ``instance`` via save()."""
        return self.save(instance)

    def update(self, instance, **kwargs):
        """Set each ``kwargs`` item as an attribute of ``instance``, then save() it."""
        for k, v in six.iteritems(kwargs):
            setattr(instance, k, v)

        return self.save(instance)

    def save(self, instance):
        """
        Insert a new instance or conditionally update an existing one.

        An instance without a (truthy) ``id`` is inserted. Otherwise it is saved only if the
        stored document still has the same ``id`` and ``rev``, and ``rev`` is incremented.

        :raises StackStormDBObjectWriteConflictError: If the save condition is not met.
        """
        if not hasattr(instance, 'id') or not instance.id:
            return self.insert(instance)

        current_rev = instance.rev
        save_condition = {'id': instance.id, 'rev': current_rev}
        instance.rev = current_rev + 1

        try:
            instance.save(save_condition=save_condition)
        except mongoengine.SaveConditionError:
            # The write did not happen, so roll back the in-memory bump. Otherwise a retry
            # with this same object would use the bumped value as its condition and could
            # match (and silently overwrite) a revision written by a concurrent writer.
            instance.rev = current_rev
            raise db_exc.StackStormDBObjectWriteConflictError(instance)

        return self._undo_dict_field_escape(instance)
533
534
535
def get_host_names_for_uri_dict(uri_dict):
    """
    Render the node list of a parsed MongoDB URI as a comma separated "host:port" string.

    :param uri_dict: Dictionary produced by ``pymongo.uri_parser.parse_uri``; only its
        ``nodelist`` entry, a sequence of ``(host, port)`` pairs, is read.
    :rtype: ``str``
    """
    return ','.join('%s:%s' % (host, port) for host, port in uri_dict['nodelist'])
543