Test Failed
Pull Request — master (#4023)
by W
03:56
created

ChangeRevisionMongoDBAccess.insert()   A

Complexity

Conditions 1

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
c 0
b 0
f 0
dl 0
loc 5
rs 9.4285
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from __future__ import absolute_import
17
18
import copy
19
import importlib
20
import traceback
21
import ssl as ssl_lib
22
23
import six
24
import mongoengine
25
from mongoengine.queryset import visitor
26
from pymongo import uri_parser
27
from pymongo.errors import OperationFailure
28
29
from st2common import log as logging
30
from st2common.util import isotime
31
from st2common.models.db import stormbase
32
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
33
from st2common.exceptions.db import StackStormDBObjectNotFoundError
34
35
36
LOG = logging.getLogger(__name__)

# Dotted paths of the modules which get_model_classes() imports. The model classes of each
# module are read from its "MODELS" attribute.
MODEL_MODULE_NAMES = [
    'st2common.models.db.auth',
    'st2common.models.db.action',
    'st2common.models.db.actionalias',
    'st2common.models.db.keyvalue',
    'st2common.models.db.execution',
    'st2common.models.db.executionstate',
    'st2common.models.db.liveaction',
    'st2common.models.db.notification',
    'st2common.models.db.pack',
    'st2common.models.db.policy',
    'st2common.models.db.rbac',
    'st2common.models.db.rule',
    'st2common.models.db.rule_enforcement',
    'st2common.models.db.runner',
    'st2common.models.db.sensor',
    'st2common.models.db.trace',
    'st2common.models.db.trigger',
    'st2common.models.db.webhook',
]

# A list of model names for which we don't perform extra index cleanup
INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = [
    'PermissionGrantDB',
]
63
64
65
def get_model_classes():
    """
    Collect the model classes exposed by every module listed in ``MODEL_MODULE_NAMES``.

    Each module is imported in list order and the content of its ``MODELS`` attribute is
    appended to the result. A module which doesn't define ``MODELS`` contributes nothing.

    :rtype: ``list``
    """
    model_classes = []

    for name in MODEL_MODULE_NAMES:
        imported = importlib.import_module(name)
        model_classes += getattr(imported, 'MODELS', [])

    return model_classes
78
79
80
def _db_connect(db_name, db_host, db_port, username=None, password=None,
                ssl=False, ssl_keyfile=None, ssl_certfile=None,
                ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Log the connection target and open a mongoengine connection to it.

    ``db_host`` may either be a plain host name or a full connection URI (detected by the
    presence of "://"). The ``ssl*`` arguments are translated by ``_get_ssl_kwargs`` before
    being handed to mongoengine.

    :return: Whatever ``mongoengine.connection.connect`` returns.
    """
    if '://' not in db_host:
        host_string = '%s:%s' % (db_host, db_port)
        username_string = username
    else:
        # Hostname is provided as a URI string. Only host:port pairs and the user name are
        # logged so a password included in the URI never ends up in the log.
        uri_dict = uri_parser.parse_uri(db_host)
        uri_username = uri_dict.get('username', username)

        # Username argument has precedence over connection string username
        username_string = username if username else (uri_username or username)

        hostnames = get_host_names_for_uri_dict(uri_dict=uri_dict)
        is_replica_set = len(uri_dict['nodelist']) > 1
        host_string = '%s (replica set)' % (hostnames) if is_replica_set else hostnames

    LOG.info('Connecting to database "%s" @ "%s" as user "%s".' % (db_name, host_string,
                                                                   str(username_string)))

    ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile,
                                 ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                                 ssl_match_hostname=ssl_match_hostname)

    return mongoengine.connection.connect(db_name, host=db_host,
                                          port=db_port, tz_aware=True,
                                          username=username, password=password,
                                          **ssl_kwargs)
116
117
118
def db_setup(db_name, db_host, db_port, username=None, password=None, ensure_indexes=True,
             ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Connect to the database and, unless disabled, make sure all the model indexes exist.

    :param ensure_indexes: When true, ``db_ensure_indexes`` is called once the connection has
                           been established.
    :type ensure_indexes: ``bool``

    :return: Connection object returned by ``_db_connect``.
    """
    ssl_options = {
        'ssl': ssl,
        'ssl_keyfile': ssl_keyfile,
        'ssl_certfile': ssl_certfile,
        'ssl_cert_reqs': ssl_cert_reqs,
        'ssl_ca_certs': ssl_ca_certs,
        'ssl_match_hostname': ssl_match_hostname,
    }
    connection = _db_connect(db_name, db_host, db_port, username=username,
                             password=password, **ssl_options)

    if not ensure_indexes:
        return connection

    # Create all the indexes upfront to prevent race-conditions caused by
    # lazy index creation
    db_ensure_indexes()

    return connection
134
135
136
def db_ensure_indexes():
    """
    This function ensures that indexes for all the models have been created and the
    extra indexes cleaned up.

    Note #1: When calling this method database connection already needs to be
    established.

    Note #2: This method blocks until all the index have been created (indexes
    are created in real-time and not in background).

    :raises OperationFailure: If index creation fails for a reason other than the obsolete
                              "uid" index handled below.
    """
    LOG.debug('Ensuring database indexes...')
    model_classes = get_model_classes()

    for model_class in model_classes:
        class_name = model_class.__name__

        # Note: We need to ensure / create new indexes before removing extra ones
        try:
            model_class.ensure_indexes()
        except OperationFailure as e:
            # Special case for "uid" index. MongoDB 3.4 has dropped "_types" index option so we
            # need to re-create the index to make it work and avoid "index with different options
            # already exists" error.
            # Note: This condition would only be encountered when upgrading existing StackStorm
            # installation from MongoDB 3.2 to 3.4.
            msg = str(e)
            if 'already exists with different options' in msg and 'uid_1' in msg:
                drop_obsolete_types_indexes(model_class=model_class)
            else:
                # Bare raise keeps the original traceback (``raise e`` loses it on Python 2)
                raise
        except Exception as e:
            # Re-raise the same exception type with the model name and the traceback included in
            # the message so the failing model can be identified
            tb_msg = traceback.format_exc()
            msg = 'Failed to ensure indexes for model "%s": %s' % (class_name, str(e))
            msg += '\n\n' + tb_msg
            exc_cls = type(e)
            raise exc_cls(msg)

        if class_name in INDEX_CLEANUP_MODEL_NAMES_BLACKLIST:
            LOG.debug('Skipping index cleanup for blacklisted model "%s"...', class_name)
            continue

        removed_count = cleanup_extra_indexes(model_class=model_class)
        if removed_count:
            LOG.debug('Removed "%s" extra indexes for model "%s"', removed_count, class_name)

    LOG.debug('Indexes are ensured for models: %s',
              ', '.join(sorted(model_class.__name__ for model_class in model_classes)))
184
185
186
def cleanup_extra_indexes(model_class):
    """
    Drop every index which ``model_class.compare_indexes()`` reports under the "extra" key.

    A failure to drop an individual index is logged as a warning and doesn't abort the
    cleanup.

    :return: Number of indexes which were successfully dropped.
    :rtype: ``int``
    """
    index_names = model_class.compare_indexes().get('extra', None)
    if not index_names:
        return 0

    # mongoengine does not have the necessary method so we need to drop to
    # pymongo interfaces via some private methods.
    collection = model_class._get_collection()
    dropped = 0

    for index_name in index_names:
        try:
            collection.drop_index(index_name)
        except OperationFailure:
            LOG.warning('Attempt to cleanup index %s failed.', index_name, exc_info=True)
            continue

        LOG.debug('Dropped index %s for model %s.', index_name, model_class.__name__)
        dropped += 1

    return dropped
207
208
209
def drop_obsolete_types_indexes(model_class):
    """
    Special function for dropping offending "types" indexes for which support has
    been removed in mongoengine and MongoDB 3.4.

    The "_types" attribute is unset on every document of the collection, all the indexes
    which reference "_types" / "types" are dropped and the model indexes are then re-created.

    For more info, see: http://docs.mongoengine.org/upgrade.html#inheritance
    """
    class_name = model_class.__name__

    LOG.debug('Dropping obsolete types index for model "%s"' % (class_name))
    collection = model_class._get_collection()
    collection.update({}, {'$unset': {'_types': 1}}, multi=True)

    indexes_to_drop = []
    for index_name, index_info in six.iteritems(collection.index_information()):
        if '_types' in dict(index_info['key']) or 'types' in index_info:
            indexes_to_drop.append(index_name)

    LOG.debug('Will drop obsolete types indexes for model "%s": %s' % (class_name,
                                                                       str(indexes_to_drop)))

    for index_name in indexes_to_drop:
        collection.drop_index(index_name)

    LOG.debug('Recreating indexes for model "%s"' % (class_name))
    model_class.ensure_indexes()
233
234
235
def db_teardown():
    """
    Disconnect the mongoengine connection by calling ``mongoengine.connection.disconnect``.
    """
    mongoengine.connection.disconnect()
237
238
239
def db_cleanup(db_name, db_host, db_port, username=None, password=None,
               ssl=False, ssl_keyfile=None, ssl_certfile=None,
               ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Connect to the database and drop it.

    ``db_host`` may either be a plain host name or a full connection URI (detected by the
    presence of "://").

    :return: Connection object returned by ``_db_connect``.
    """
    connection = _db_connect(db_name, db_host, db_port, username=username,
                             password=password, ssl=ssl, ssl_keyfile=ssl_keyfile,
                             ssl_certfile=ssl_certfile,
                             ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                             ssl_match_hostname=ssl_match_hostname)

    if '://' in db_host:
        # Hostname is provided as a URI string which may include a password. Log only the
        # host:port pairs so the password is never written to the log.
        uri_dict = uri_parser.parse_uri(db_host)
        host_string = get_host_names_for_uri_dict(uri_dict=uri_dict)
    else:
        host_string = '%s:%s' % (db_host, db_port)

    LOG.info('Dropping database "%s" @ "%s" as user "%s".',
             db_name, host_string, str(username))

    connection.drop_database(db_name)
    return connection
254
255
256
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
                    ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Build the SSL related keyword arguments which are passed to the mongoengine connect call.

    Providing ``ssl_keyfile``, ``ssl_certfile`` or ``ssl_ca_certs`` implicitly enables SSL.
    ``ssl_cert_reqs`` on its own does not.

    :param ssl_cert_reqs: Either one of the strings "none", "optional" or "required" (mapped to
                          the corresponding ``ssl.CERT_*`` constant) or any other value which is
                          passed through unchanged.

    :param ssl_match_hostname: Only included in the result when SSL is enabled.
    :type ssl_match_hostname: ``bool``

    :rtype: ``dict``
    """
    ssl_kwargs = {
        'ssl': ssl,
    }
    if ssl_keyfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
    if ssl_certfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_certfile'] = ssl_certfile
    if ssl_cert_reqs:
        # Note: Strings need to be compared by value ("=="). Identity comparison ("is") only
        # matches when both strings happen to be the same interned object.
        if ssl_cert_reqs == 'none':
            ssl_cert_reqs = ssl_lib.CERT_NONE
        elif ssl_cert_reqs == 'optional':
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
        elif ssl_cert_reqs == 'required':
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
    if ssl_ca_certs:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
    if ssl_kwargs.get('ssl', False):
        # pass in ssl_match_hostname only if ssl is True. The right default value
        # for ssl_match_hostname in almost all cases is True.
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
    return ssl_kwargs
283
284
285
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        """
        :param model: Model class. Its ``objects`` manager is used by all the methods below.
        """
        self.model = model

    def get_by_name(self, value):
        """Return the instance with the provided ``name`` or raise if it doesn't exist."""
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        """Return the instance with the provided ``id`` or raise if it doesn't exist."""
        return self.get(id=value, raise_exception=True)

    def get_by_uid(self, value):
        """Return the instance with the provided ``uid`` or raise if it doesn't exist."""
        return self.get(uid=value, raise_exception=True)

    def get_by_ref(self, value):
        """Return the instance with the provided ``ref`` or raise if it doesn't exist."""
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        """Return the first instance with the provided ``pack`` or raise if none exists."""
        return self.get(pack=value, raise_exception=True)

    def get(self, *args, **kwargs):
        """
        Return the first instance which matches the provided filters.

        ``exclude_fields``, ``only_fields`` and ``raise_exception`` keyword arguments are
        consumed by this method, all the other keyword arguments are used as query filters.

        :raises ValueError: If ``only_fields`` references a field which can't be looked up.
        :raises StackStormDBObjectNotFoundError: If nothing matches and ``raise_exception`` is
                                                 true.

        :return: Matching instance or ``None``.
        """
        exclude_fields = kwargs.pop('exclude_fields', None)
        raise_exception = kwargs.pop('raise_exception', False)
        only_fields = kwargs.pop('only_fields', None)

        args = self._process_arg_filters(args)

        instances = self.model.objects(*args, **kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        if only_fields:
            instances = self._apply_only_fields(instances, only_fields)

        instance = instances[0] if instances else None
        log_query_and_profile_data_for_queryset(queryset=instances)

        if not instance and raise_exception:
            msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
            raise StackStormDBObjectNotFoundError(msg)

        return instance

    def get_all(self, *args, **kwargs):
        """Alias for ``query``."""
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """Return the number of objects which match the provided filters."""
        result = self.model.objects(*args, **kwargs).count()
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    # TODO: PEP-3102 introduced keyword-only arguments, so once we support Python 3+, we can change
    #       this definition to have explicit keyword-only arguments:
    #
    #           def query(self, *args, offset=0, limit=None, order_by=None, exclude_fields=None,
    #                     **filters):
    def query(self, *args, **filters):
        """
        Return a queryset of the objects which match the provided filters.

        ``offset``, ``limit``, ``order_by``, ``exclude_fields``, ``only_fields`` and
        ``no_dereference`` keyword arguments are consumed by this method, all the other keyword
        arguments are used as query filters.

        :raises ValueError: If ``only_fields`` references a field which can't be looked up.
        """
        # Python 2: Pop keyword parameters that aren't actually filters off of the kwargs
        offset = filters.pop('offset', 0)
        limit = filters.pop('limit', None)
        order_by = filters.pop('order_by', None)
        exclude_fields = filters.pop('exclude_fields', None)
        only_fields = filters.pop('only_fields', None)
        no_dereference = filters.pop('no_dereference', None)

        order_by = order_by or []
        exclude_fields = exclude_fields or []
        # End of page index. None means the slice below is open ended.
        eop = (offset + int(limit)) if limit else None

        args = self._process_arg_filters(args)
        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(*args, **filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        if only_fields:
            result = self._apply_only_fields(result, only_fields)

        if no_dereference:
            result = result.no_dereference()

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """
        Return distinct values of a field for the objects which match the provided filters.

        The mandatory ``field`` keyword argument names the field (``KeyError`` is raised if it's
        missing), all the other keyword arguments are used as query filters. Positional
        arguments are accepted, but not used.
        """
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        """
        Call ``aggregate`` on the underlying collection.

        Note: Keyword arguments are used both as query filters for the queryset and as keyword
        arguments for the collection ``aggregate`` call.
        """
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        """Insert the instance and return it with the escaped dict fields converted back."""
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Save the instance and return it with the escaped dict fields converted back."""
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        """Call ``instance.update`` with the provided keyword arguments and return its result."""
        return instance.update(**kwargs)

    def delete(self, instance):
        """Call ``instance.delete`` and return its result."""
        return instance.delete()

    def delete_by_query(self, *args, **query):
        """
        Delete objects by query and return number of deleted objects.
        """
        qs = self.model.objects.filter(*args, **query)
        count = qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)

        return count

    def _apply_only_fields(self, queryset, only_fields):
        """
        Limit the queryset to ``only_fields`` and translate a lookup failure to ``ValueError``.
        """
        try:
            return queryset.only(*only_fields)
        except mongoengine.errors.LookUpError as e:
            msg = ('Invalid or unsupported include attribute specified: %s' % str(e))
            raise ValueError(msg)

    def _undo_dict_field_escape(self, instance):
        """
        Replace the value of every ``EscapedDictField`` attribute on the instance with the result
        of the field's ``to_python`` and return the instance.
        """
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_arg_filters(self, args):
        """
        Fix filter arguments in nested Q objects

        :raises TypeError: If an argument is neither a ``Q`` nor a ``QCombination`` object.

        :rtype: ``tuple``
        """
        _args = tuple()

        for arg in args:
            # Unfortunately mongoengine doesn't expose any visitors other than Q, so we have to
            # extract QCombination from the module itself
            if isinstance(arg, visitor.Q):
                # Note: Both of those functions manipulate "filters" variable so the order in which
                # they are called matters
                filters, _ = self._process_datetime_range_filters(filters=arg.query)
                filters = self._process_null_filters(filters=filters)

                # Create a new Q object with the same filters as the old one
                _args += (visitor.Q(**filters),)
            elif isinstance(arg, visitor.QCombination):
                # Recurse if we need to
                children = self._process_arg_filters(arg.children)

                # Create a new QCombination object with the same operation and fixed filters
                _args += (visitor.QCombination(arg.operation, children),)
            else:
                raise TypeError("Unknown argument type '%s' of argument '%s'"
                    % (type(arg), repr(arg)))

        return _args

    def _process_null_filters(self, filters):
        """
        Return a copy of ``filters`` in which every filter whose value is ``None`` or the string
        "null" (case-insensitive) is replaced with a ``<key>__exists=False`` filter.

        :rtype: ``dict``
        """
        result = copy.deepcopy(filters)

        # Note: The lowered value is compared directly. Coercing it with str() fails on
        # Python 2 for unicode values which contain non-ASCII characters.
        null_keys = [k for k, v in six.iteritems(filters)
                     if v is None or
                     (isinstance(v, six.string_types) and v.lower() == 'null')]

        for key in null_keys:
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Replace every string filter of the form "<start>..<end>" with a ``<key>__gte`` /
        ``<key>__lte`` filter pair and adjust the sort order to match the range direction.

        Note: ``filters`` is modified in place. ``order_by`` is copied and left untouched.

        :return: Tuple of (filters, order_by list).
        :rtype: ``tuple``
        """
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                # Ascending range, sort ascending on this key
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                # Descending (or empty) range, swap the bounds and sort descending on this key
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                # Caller asked for the opposite direction, replace it in the same position
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
499
500
501
class ChangeRevisionMongoDBAccess(MongoDBAccess):
    """
    Access class which guards writes with the ``rev`` attribute of the instance.

    The revision an instance was loaded / written with is remembered on the instance as
    ``_original_rev``. Every save increments ``rev`` by one and is conditional on the stored
    document still having the remembered revision.
    """

    def get(self, *args, **kwargs):
        """
        Same as ``MongoDBAccess.get``, but also remembers the revision of the returned instance.

        :return: Matching instance or ``None``.
        """
        instance = super(ChangeRevisionMongoDBAccess, self).get(*args, **kwargs)

        # Parent returns None when nothing matches and raise_exception is not set
        if instance is not None:
            instance._original_rev = instance.rev

        return instance

    def insert(self, instance):
        """Insert the instance and remember the revision it was inserted with."""
        instance = self.model.objects.insert(instance)
        instance._original_rev = instance.rev

        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """
        Insert the instance if it has no remembered revision, otherwise perform a revision
        guarded save.
        """
        if not hasattr(instance, '_original_rev'):
            return self.insert(instance)

        return self._save_with_rev_check(instance)

    def update(self, instance, **kwargs):
        """
        Set the provided attributes on the instance and perform a revision guarded save.

        Note: The instance needs to have ``_original_rev`` set (it was returned by ``get`` or
        ``insert`` of this class).
        """
        for k, v in six.iteritems(kwargs):
            setattr(instance, k, v)

        return self._save_with_rev_check(instance)

    def _save_with_rev_check(self, instance):
        """
        Increment ``rev`` and save the instance on the condition that the stored document still
        has the remembered revision.
        """
        save_condition = {
            'id': instance.id,
            'rev': instance._original_rev
        }

        instance.rev = instance._original_rev + 1
        instance.save(save_condition=save_condition)

        # Save succeeded so the stored revision is now the incremented one. Remember it,
        # otherwise the next save of this instance would use a stale save condition.
        instance._original_rev = instance.rev

        return self._undo_dict_field_escape(instance)
541
542
543
def get_host_names_for_uri_dict(uri_dict):
    """
    Build a comma delimited string of "host:port" pairs for the nodes of a parsed URI.

    :param uri_dict: Parsed URI. ``nodelist`` key needs to contain an iterable of
                     (host, port) pairs.
    :type uri_dict: ``dict``

    :rtype: ``str``
    """
    return ','.join('%s:%s' % (host, port) for host, port in uri_dict['nodelist'])
551