Test Setup Failed
Pull Request — master (#4154)
by W
04:10
created

ChangeRevisionMongoDBAccess.get()   A

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
c 0
b 0
f 0
dl 0
loc 4
rs 10
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from __future__ import absolute_import
17
18
import copy
19
import importlib
20
import traceback
21
import ssl as ssl_lib
22
23
import six
24
import mongoengine
25
from mongoengine.queryset import visitor
26
from pymongo import uri_parser
27
from pymongo.errors import OperationFailure
28
from pymongo.errors import ConnectionFailure
29
30
from st2common import log as logging
31
from st2common.util import isotime
32
from st2common.models.db import stormbase
33
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
34
from st2common.exceptions import db as db_exc
35
36
37
LOG = logging.getLogger(__name__)

# Modules which define database models. get_model_classes() imports each of them
# (in this order) and collects the classes listed in the module-level "MODELS" attribute.
MODEL_MODULE_NAMES = [
    'st2common.models.db.auth',
    'st2common.models.db.action',
    'st2common.models.db.actionalias',
    'st2common.models.db.keyvalue',
    'st2common.models.db.execution',
    'st2common.models.db.executionstate',
    'st2common.models.db.liveaction',
    'st2common.models.db.notification',
    'st2common.models.db.pack',
    'st2common.models.db.policy',
    'st2common.models.db.rbac',
    'st2common.models.db.rule',
    'st2common.models.db.rule_enforcement',
    'st2common.models.db.runner',
    'st2common.models.db.sensor',
    'st2common.models.db.trace',
    'st2common.models.db.trigger',
    'st2common.models.db.webhook',
]

# Names of model classes for which db_ensure_indexes() skips the extra index cleanup step.
INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = [
    'PermissionGrantDB',
]
64
65
66
def get_model_classes():
    """
    Import every module listed in MODEL_MODULE_NAMES and gather its model classes.

    Each module may expose a ``MODELS`` attribute; modules without one contribute
    nothing. Classes are returned in module order.

    :rtype: ``list``
    """
    classes = []

    for name in MODEL_MODULE_NAMES:
        db_module = importlib.import_module(name)
        classes += getattr(db_module, 'MODELS', [])

    return classes
79
80
81
def _db_connect(db_name, db_host, db_port, username=None, password=None,
                ssl=False, ssl_keyfile=None, ssl_certfile=None,
                ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Connect to MongoDB through mongoengine and verify that the server is reachable.

    :param db_host: Either a plain hostname (combined with ``db_port``) or a full MongoDB
                    connection URI (any value containing ``://``).
    :param username: Username; when given it takes precedence over a username embedded
                     in a connection URI (for logging purposes).

    :return: The client returned by ``mongoengine.connection.connect``.

    :raises ConnectionFailure: If the "ismaster" probe against the server fails.
    """
    if '://' in db_host:
        # Hostname is provided as a URI string. Make sure we don't log the password in case one is
        # included as part of the URI string.
        uri_dict = uri_parser.parse_uri(db_host)
        username_string = uri_dict.get('username', username) or username

        if uri_dict.get('username', None) and username:
            # Username argument has precedence over connection string username
            username_string = username

        hostnames = get_host_names_for_uri_dict(uri_dict=uri_dict)

        if len(uri_dict['nodelist']) > 1:
            host_string = '%s (replica set)' % (hostnames)
        else:
            host_string = hostnames
    else:
        host_string = '%s:%s' % (db_host, db_port)
        username_string = username

    LOG.info('Connecting to database "%s" @ "%s" as user "%s".',
             db_name, host_string, str(username_string))

    ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile,
                                 ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                                 ssl_match_hostname=ssl_match_hostname)

    connection = mongoengine.connection.connect(db_name, host=db_host,
                                                port=db_port, tz_aware=True,
                                                username=username, password=password,
                                                **ssl_kwargs)

    # NOTE: Since pymongo 3.0, connect() method is lazy and not blocking (always returns success)
    # so we need to issue a command / query to check if connection has been
    # successfully established.
    # See http://api.mongodb.com/python/current/api/pymongo/mongo_client.html for details
    try:
        # The ismaster command is cheap and does not require auth
        connection.admin.command('ismaster')
    except ConnectionFailure as e:
        LOG.error('Failed to connect to database "%s" @ "%s" as user "%s": %s',
                  db_name, host_string, str(username_string), str(e))
        # Bare "raise" preserves the original traceback ("raise e" drops it on Python 2).
        raise

    LOG.info('Successfully connected to database "%s" @ "%s" as user "%s".',
             db_name, host_string, str(username_string))

    return connection
133
134
135
def db_setup(db_name, db_host, db_port, username=None, password=None, ensure_indexes=True,
             ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Open the database connection and optionally create all model indexes.

    :param ensure_indexes: When true, ``db_ensure_indexes()`` is called right after
                           the connection has been established.

    :return: The connection object returned by ``_db_connect``.
    """
    connect_kwargs = {
        'username': username,
        'password': password,
        'ssl': ssl,
        'ssl_keyfile': ssl_keyfile,
        'ssl_certfile': ssl_certfile,
        'ssl_cert_reqs': ssl_cert_reqs,
        'ssl_ca_certs': ssl_ca_certs,
        'ssl_match_hostname': ssl_match_hostname,
    }
    connection = _db_connect(db_name, db_host, db_port, **connect_kwargs)

    # Build every index eagerly, up front, so lazy index creation can't cause race conditions.
    if ensure_indexes:
        db_ensure_indexes()

    return connection
151
152
153
def _clone_exception_with_message(exc, msg):
    """
    Return a new exception of the same type as ``exc`` whose only argument is ``msg``.

    Returns ``None`` when the exception type cannot be constructed from a single message
    argument (for example ``UnicodeDecodeError`` requires five arguments).
    """
    try:
        return type(exc)(msg)
    except Exception:
        return None


def db_ensure_indexes():
    """
    This function ensures that indexes for all the models have been created and the
    extra indexes cleaned up.

    Note #1: When calling this method database connection already needs to be
    established.

    Note #2: This method blocks until all the index have been created (indexes
    are created in real-time and not in background).

    :raises: The original error (or a copy of it with the model name and traceback added
             to the message) if ensuring the indexes of any model fails.
    """
    LOG.debug('Ensuring database indexes...')
    model_classes = get_model_classes()

    for model_class in model_classes:
        class_name = model_class.__name__

        # Note: We need to ensure / create new indexes before removing extra ones
        try:
            model_class.ensure_indexes()
        except OperationFailure as e:
            # Special case for "uid" index. MongoDB 3.4 has dropped "_types" index option so we
            # need to re-create the index to make it work and avoid "index with different options
            # already exists" error.
            # Note: This condition would only be encountered when upgrading existing StackStorm
            # installation from MongoDB 3.2 to 3.4.
            msg = str(e)
            if 'already exists with different options' in msg and 'uid_1' in msg:
                drop_obsolete_types_indexes(model_class=model_class)
            else:
                raise
        except Exception as e:
            tb_msg = traceback.format_exc()
            msg = 'Failed to ensure indexes for model "%s": %s' % (class_name, str(e))
            msg += '\n\n' + tb_msg

            new_exc = _clone_exception_with_message(e, msg)
            if new_exc is None:
                # The exception type can't carry the extra context; re-raise it untouched
                # instead of masking it with a TypeError from the constructor.
                raise
            raise new_exc

        if class_name in INDEX_CLEANUP_MODEL_NAMES_BLACKLIST:
            LOG.debug('Skipping index cleanup for blacklisted model "%s"...', class_name)
            continue

        removed_count = cleanup_extra_indexes(model_class=model_class)
        if removed_count:
            LOG.debug('Removed "%s" extra indexes for model "%s"', removed_count, class_name)

    LOG.debug('Indexes are ensured for models: %s',
              ', '.join(sorted(model_class.__name__ for model_class in model_classes)))
201
202
203
def cleanup_extra_indexes(model_class):
    """
    Drop every index that exists in MongoDB but is not declared on ``model_class``.

    A failure to drop an individual index is logged as a warning and does not stop
    the remaining drops.

    :return: Number of indexes which were actually dropped.
    :rtype: ``int``
    """
    extras = model_class.compare_indexes().get('extra', None)
    if not extras:
        return 0

    # mongoengine has no public API for dropping indexes, so go down to the
    # underlying pymongo collection via a private accessor.
    collection = model_class._get_collection()
    dropped = 0

    for index_name in extras:
        try:
            collection.drop_index(index_name)
        except OperationFailure:
            LOG.warning('Attempt to cleanup index %s failed.', index_name, exc_info=True)
        else:
            LOG.debug('Dropped index %s for model %s.', index_name, model_class.__name__)
            dropped += 1

    return dropped
224
225
226
def drop_obsolete_types_indexes(model_class):
    """
    Remove obsolete "_types" data and indexes from a model's collection, then re-create
    the model's indexes.

    Support for "_types" has been removed in mongoengine and MongoDB 3.4, so collections
    created by older versions may carry indexes which conflict with the current index
    definitions.
    For more info, see: http://docs.mongoengine.org/upgrade.html#inheritance
    """
    class_name = model_class.__name__

    LOG.debug('Dropping obsolete types index for model "%s"', class_name)
    collection = model_class._get_collection()

    # Strip the "_types" field from every document. Collection.update() is deprecated
    # since pymongo 3.0 (removed in 4.0); update_many() is the equivalent of multi=True.
    collection.update_many({}, {'$unset': {'_types': 1}})

    info = collection.index_information()
    indexes_to_drop = [key for key, value in six.iteritems(info)
                       if '_types' in dict(value['key']) or 'types' in value]

    LOG.debug('Will drop obsolete types indexes for model "%s": %s',
              class_name, str(indexes_to_drop))

    for index in indexes_to_drop:
        collection.drop_index(index)

    LOG.debug('Recreating indexes for model "%s"', class_name)
    model_class.ensure_indexes()
250
251
252
def db_teardown():
    """Disconnect the default mongoengine connection."""
    disconnect = mongoengine.connection.disconnect
    disconnect()
254
255
256
def db_cleanup(db_name, db_host, db_port, username=None, password=None,
               ssl=False, ssl_keyfile=None, ssl_certfile=None,
               ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Connect to the database and drop it entirely.

    :param db_host: Either a plain hostname (combined with ``db_port``) or a full MongoDB
                    connection URI (any value containing ``://``).

    :return: The connection object returned by ``_db_connect``.
    """
    connection = _db_connect(db_name, db_host, db_port, username=username,
                             password=password, ssl=ssl, ssl_keyfile=ssl_keyfile,
                             ssl_certfile=ssl_certfile,
                             ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                             ssl_match_hostname=ssl_match_hostname)

    if '://' in db_host:
        # db_host is a connection URI which may embed credentials. Only log the parsed
        # host:port pairs so a password never ends up in the logs.
        uri_dict = uri_parser.parse_uri(db_host)
        host_string = get_host_names_for_uri_dict(uri_dict=uri_dict)
    else:
        host_string = '%s:%s' % (db_host, db_port)

    LOG.info('Dropping database "%s" @ "%s" as user "%s".',
             db_name, host_string, str(username))

    connection.drop_database(db_name)
    return connection
271
272
273
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
274
                    ssl_ca_certs=None, ssl_match_hostname=True):
275
    ssl_kwargs = {
276
        'ssl': ssl,
277
    }
278
    if ssl_keyfile:
279
        ssl_kwargs['ssl'] = True
280
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
281
    if ssl_certfile:
282
        ssl_kwargs['ssl'] = True
283
        ssl_kwargs['ssl_certfile'] = ssl_certfile
284
    if ssl_cert_reqs:
285
        if ssl_cert_reqs is 'none':
286
            ssl_cert_reqs = ssl_lib.CERT_NONE
287
        elif ssl_cert_reqs is 'optional':
288
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
289
        elif ssl_cert_reqs is 'required':
290
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
291
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
292
    if ssl_ca_certs:
293
        ssl_kwargs['ssl'] = True
294
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
295
    if ssl_kwargs.get('ssl', False):
296
        # pass in ssl_match_hostname only if ssl is True. The right default value
297
        # for ssl_match_hostname in almost all cases is True.
298
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
299
    return ssl_kwargs
300
301
302
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        # Model (document) class; all queries go through ``self.model.objects``.
        self.model = model

    def get_by_name(self, value):
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        return self.get(id=value, raise_exception=True)

    def get_by_uid(self, value):
        return self.get(uid=value, raise_exception=True)

    def get_by_ref(self, value):
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        return self.get(pack=value, raise_exception=True)

    def get(self, *args, **kwargs):
        """
        Return the first instance matching the given filters, or ``None``.

        Non-filter keyword arguments (removed before querying):

        * ``exclude_fields`` - fields to leave out of the returned document.
        * ``only_fields`` - if given, load only these fields.
        * ``raise_exception`` - raise ``StackStormDBObjectNotFoundError`` instead of
          returning ``None`` when nothing matches.

        :raises ValueError: If ``only_fields`` references an invalid attribute.
        """
        exclude_fields = kwargs.pop('exclude_fields', None)
        raise_exception = kwargs.pop('raise_exception', False)
        only_fields = kwargs.pop('only_fields', None)

        args = self._process_arg_filters(args)

        instances = self.model.objects(*args, **kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        if only_fields:
            try:
                instances = instances.only(*only_fields)
            except mongoengine.errors.LookUpError as e:
                msg = ('Invalid or unsupported include attribute specified: %s' % str(e))
                raise ValueError(msg)

        # first() fetches at most one document in a single round trip (or returns None);
        # testing the queryset's truthiness and then indexing it issued two queries.
        instance = instances.first()
        log_query_and_profile_data_for_queryset(queryset=instances)

        if instance is None and raise_exception:
            msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
            raise db_exc.StackStormDBObjectNotFoundError(msg)

        return instance

    def get_all(self, *args, **kwargs):
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """Return the number of documents matching the given filters."""
        queryset = self.model.objects(*args, **kwargs)
        result = queryset.count()
        # The profiling helper works on querysets, so hand it the queryset rather than
        # the integer count (which gave it nothing to describe).
        log_query_and_profile_data_for_queryset(queryset=queryset)
        return result

    # TODO: PEP-3102 introduced keyword-only arguments, so once we support Python 3+, we can change
    #       this definition to have explicit keyword-only arguments:
    #
    #           def query(self, *args, offset=0, limit=None, order_by=None, exclude_fields=None,
    #                     **filters):
    def query(self, *args, **filters):
        """
        Return a (sliced) queryset of documents matching ``args`` and ``filters``.

        Non-filter keyword arguments: ``offset`` (default 0), ``limit``, ``order_by``,
        ``exclude_fields``, ``only_fields`` and ``no_dereference``. String filter values
        of the form "<start>..<end>" become datetime range filters, and ``None`` / "null"
        values become "<field>__exists=False" filters.

        :raises ValueError: If ``only_fields`` references an invalid attribute.
        """
        # Python 2: Pop keyword parameters that aren't actually filters off of the kwargs
        offset = filters.pop('offset', 0)
        limit = filters.pop('limit', None)
        order_by = filters.pop('order_by', None)
        exclude_fields = filters.pop('exclude_fields', None)
        only_fields = filters.pop('only_fields', None)
        no_dereference = filters.pop('no_dereference', None)

        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        args = self._process_arg_filters(args)
        # Process the filters: datetime ranges are expanded first, then null filters are
        # rewritten into "__exists" filters.
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(*args, **filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        if only_fields:
            try:
                result = result.only(*only_fields)
            except mongoengine.errors.LookUpError as e:
                msg = ('Invalid or unsupported include attribute specified: %s' % str(e))
                raise ValueError(msg)

        if no_dereference:
            result = result.no_dereference()

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """
        Return the distinct values of ``field`` among documents matching ``kwargs``.

        NOTE(review): positional ``args`` are accepted but ignored here - confirm whether
        callers rely on passing Q objects to this method.
        """
        field = kwargs.pop('field')
        queryset = self.model.objects(**kwargs)
        result = queryset.distinct(field)
        # Profile the queryset itself, not the list of distinct values.
        log_query_and_profile_data_for_queryset(queryset=queryset)
        return result

    def aggregate(self, *args, **kwargs):
        # NOTE(review): kwargs are used both as queryset filters and as options for the
        # underlying collection's aggregate() call - confirm that is intended.
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        return instance.update(**kwargs)

    def delete(self, instance):
        return instance.delete()

    def delete_by_query(self, *args, **query):
        """
        Delete objects by query and return number of deleted objects.
        """
        qs = self.model.objects.filter(*args, **query)
        count = qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)

        return count

    def _undo_dict_field_escape(self, instance):
        """Convert every ``EscapedDictField`` value on ``instance`` back to its Python form."""
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_arg_filters(self, args):
        """
        Fix filter arguments in nested Q objects

        Returns a new tuple of Q / QCombination objects; the originals are not modified.

        :raises TypeError: If an argument is neither a Q nor a QCombination.
        """
        _args = tuple()

        for arg in args:
            # Unfortunately mongoengine doesn't expose any visitors other than Q, so we have to
            # extract QCombination from the module itself
            if isinstance(arg, visitor.Q):
                # Datetime ranges are expanded first, then null filters are rewritten.
                filters, _ = self._process_datetime_range_filters(filters=arg.query)
                filters = self._process_null_filters(filters=filters)

                # Create a new Q object with the same filters as the old one
                _args += (visitor.Q(**filters),)
            elif isinstance(arg, visitor.QCombination):
                # Recurse if we need to
                children = self._process_arg_filters(arg.children)

                # Create a new QCombination object with the same operation and fixed filters
                _args += (visitor.QCombination(arg.operation, children),)
            else:
                raise TypeError("Unknown argument type '%s' of argument '%s'"
                    % (type(arg), repr(arg)))

        return _args

    def _process_null_filters(self, filters):
        """
        Return a copy of ``filters`` where "null" filters become "<field>__exists=False".

        A value is treated as null when it is ``None`` or a string equal to "null"
        (case-insensitive). ``filters`` itself is not modified.
        """
        result = copy.copy(filters)

        for key, value in six.iteritems(filters):
            # Compare the lowered text directly; wrapping it in str() raised
            # UnicodeEncodeError for non-ASCII unicode values on Python 2.
            is_null = value is None or (isinstance(value, six.string_types) and
                                        value.lower() == 'null')
            if is_null:
                result['%s__exists' % (key)] = False
                del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Expand "<start>..<end>" string filters into "__gte" / "__lte" datetime filters.

        The range bounds are parsed with ``isotime.parse``; their order decides whether the
        field is sorted ascending or descending. That sort key is added to (or replaces the
        opposite direction in) the returned ordering list.

        :return: ``(filters, order_by_list)``; neither input is modified.
        """
        # Work on a copy so callers' dicts (e.g. a Q object's ``query``) are not mutated.
        filters = copy.copy(filters)
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = list(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
516
517
518
class ChangeRevisionMongoDBAccess(MongoDBAccess):
    """
    Data access class which guards updates with an optimistic "rev" counter.

    An update of an existing document is only written when the stored "rev" still equals
    the in-memory one; a successful write increments it.
    """

    def insert(self, instance):
        instance = self.model.objects.insert(instance)

        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        return self.save(instance)

    def update(self, instance, **kwargs):
        for k, v in six.iteritems(kwargs):
            setattr(instance, k, v)

        return self.save(instance)

    def save(self, instance):
        """
        Insert a new instance, or conditionally update an existing one.

        :raises StackStormDBObjectWriteConflictError: If the stored document's "rev" no
                                                      longer matches ``instance.rev``.
        """
        if not hasattr(instance, 'id') or not instance.id:
            return self.insert(instance)

        current_rev = instance.rev
        save_condition = {'id': instance.id, 'rev': current_rev}
        instance.rev = current_rev + 1

        try:
            instance.save(save_condition=save_condition)
        except mongoengine.SaveConditionError:
            # Roll back the in-memory bump. Otherwise a retry with this same object would
            # match a concurrent writer's newer "rev" and silently overwrite its changes.
            instance.rev = current_rev
            raise db_exc.StackStormDBObjectWriteConflictError(instance)
        except Exception:
            # Nothing was written; keep the in-memory "rev" consistent with the database.
            instance.rev = current_rev
            raise

        return self._undo_dict_field_escape(instance)
546
547
548
def get_host_names_for_uri_dict(uri_dict):
    """
    Render the ``nodelist`` of a parsed MongoDB URI as a "host:port,host:port" string.

    :param uri_dict: Result of ``pymongo.uri_parser.parse_uri``; only ``nodelist``
                     (a sequence of ``(host, port)`` pairs) is read.
    :rtype: ``str``
    """
    return ','.join('%s:%s' % (host, port) for host, port in uri_dict['nodelist'])
556