Passed
Push — master (ccf839...2e3794)
by Plexxi
02:27
created

db_ensure_indexes()   B

Complexity
  Conditions: 3

Size
  Total Lines: 28

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 1
  Bugs: 0
  Features: 0

Metric   Value
cc       3
c        1
b        0
f        0
dl       0
loc      28
rs       8.8571
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import importlib
import ssl as ssl_lib

import six
import mongoengine
from pymongo.errors import OperationFailure

from st2common import log as logging
from st2common.util import isotime
from st2common.models.db import stormbase
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
from st2common.exceptions.db import StackStormDBObjectNotFoundError


LOG = logging.getLogger(__name__)

MODEL_MODULE_NAMES = [
    'st2common.models.db.auth',
    'st2common.models.db.action',
    'st2common.models.db.actionalias',
    'st2common.models.db.keyvalue',
    'st2common.models.db.execution',
    'st2common.models.db.executionstate',
    'st2common.models.db.liveaction',
    'st2common.models.db.notification',
    'st2common.models.db.pack',
    'st2common.models.db.policy',
    'st2common.models.db.rbac',
    'st2common.models.db.rule',
    'st2common.models.db.rule_enforcement',
    'st2common.models.db.runner',
    'st2common.models.db.sensor',
    'st2common.models.db.trace',
    'st2common.models.db.trigger',
    'st2common.models.db.webhook'
]

# A list of model names for which we don't perform extra index cleanup
INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = [
    'PermissionGrantDB'
]


def get_model_classes():
    """
    Retrieve a list of all the defined model classes.

    :rtype: ``list``
    """
    result = []
    for module_name in MODEL_MODULE_NAMES:
        module = importlib.import_module(module_name)
        model_classes = getattr(module, 'MODELS', [])
        result.extend(model_classes)

    return result


def db_setup(db_name, db_host, db_port, username=None, password=None, ensure_indexes=True,
             ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile,
                                 ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                                 ssl_match_hostname=ssl_match_hostname)

    connection = mongoengine.connection.connect(db_name, host=db_host,
                                                port=db_port, tz_aware=True,
                                                username=username, password=password,
                                                **ssl_kwargs)

    # Create all the indexes upfront to prevent race-conditions caused by
    # lazy index creation
    if ensure_indexes:
        db_ensure_indexes()

    return connection


def db_ensure_indexes():
    """
    This function ensures that indexes for all the models have been created and that
    the extra indexes have been cleaned up.

    Note #1: A database connection already needs to be established before calling
    this method.

    Note #2: This method blocks until all the indexes have been created (indexes
    are created in the foreground and not in the background).
    """
    LOG.debug('Ensuring database indexes...')
    model_classes = get_model_classes()

    for model_class in model_classes:
        class_name = model_class.__name__

        # Note: We need to ensure / create new indexes before removing extra ones
        LOG.debug('Ensuring indexes for model "%s"...' % (model_class.__name__))
        model_class.ensure_indexes()

        if model_class.__name__ in INDEX_CLEANUP_MODEL_NAMES_BLACKLIST:
            LOG.debug('Skipping index cleanup for blacklisted model "%s"...' % (class_name))
            continue

        LOG.debug('Removing extra indexes for model "%s"...' % (class_name))
        removed_count = cleanup_extra_indexes(model_class=model_class)
        LOG.debug('Removed "%s" extra indexes for model "%s"' % (removed_count, class_name))


def cleanup_extra_indexes(model_class):
    """
    Finds any extra indexes and removes them from MongoDB.
    """
    extra_indexes = model_class.compare_indexes().get('extra', None)
    if not extra_indexes:
        return 0

    # mongoengine does not expose the necessary method, so we need to drop down to
    # the pymongo interface via some private methods.
    removed_count = 0
    c = model_class._get_collection()
    for extra_index in extra_indexes:
        try:
            c.drop_index(extra_index)
            LOG.debug('Dropped index %s for model %s.', extra_index, model_class.__name__)
            removed_count += 1
        except OperationFailure:
            LOG.warning('Attempt to cleanup index %s failed.', extra_index, exc_info=True)

    return removed_count


def db_teardown():
    mongoengine.connection.disconnect()


def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
                    ssl_ca_certs=None, ssl_match_hostname=True):
    ssl_kwargs = {
        'ssl': ssl,
    }
    if ssl_keyfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
    if ssl_certfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_certfile'] = ssl_certfile
    if ssl_cert_reqs:
        if ssl_cert_reqs == 'none':
            ssl_cert_reqs = ssl_lib.CERT_NONE
        elif ssl_cert_reqs == 'optional':
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
        elif ssl_cert_reqs == 'required':
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
    if ssl_ca_certs:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
    if ssl_kwargs.get('ssl', False):
        # Pass in ssl_match_hostname only if ssl is True. The right default value
        # for ssl_match_hostname in almost all cases is True.
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
    return ssl_kwargs


class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        self.model = model

    def get_by_name(self, value):
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        return self.get(id=value, raise_exception=True)

    def get_by_uid(self, value):
        return self.get(uid=value, raise_exception=True)

    def get_by_ref(self, value):
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        return self.get(pack=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        instance = instances[0] if instances else None
        log_query_and_profile_data_for_queryset(queryset=instances)

        if not instance and raise_exception:
            msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
            raise StackStormDBObjectNotFoundError(msg)

        return instance

    def get_all(self, *args, **kwargs):
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        result = self.model.objects(**kwargs).count()
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of these functions manipulate the "filters" variable, so the
        # order in which they are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        return instance.update(**kwargs)

    def delete(self, instance):
        return instance.delete()

    def delete_by_query(self, **query):
        """
        Delete objects by query and return number of deleted objects.
        """
        qs = self.model.objects.filter(**query)
        count = qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)

        return count

    def _undo_dict_field_escape(self, instance):
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        result = copy.deepcopy(filters)

        null_filters = {k: v for k, v in six.iteritems(filters)
                        if v is None or (isinstance(v, six.string_types) and v.lower() == 'null')}

        for key in null_filters.keys():
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
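
For context, a minimal usage sketch of the helpers defined above (not part of the reviewed file). It assumes this module is importable as st2common.models.db, that a MongoDB instance is reachable on localhost:27017, and it uses ActionDB purely as an example model; names such as the 'st2' database and 'my_action' are placeholders.

# Illustrative usage sketch only; assumptions are noted in the paragraph above.
from st2common.models.db import db_setup, db_teardown, MongoDBAccess
from st2common.models.db.action import ActionDB

# Connect once at process start-up. Indexes are created eagerly by default
# (ensure_indexes=True), which avoids the lazy index creation races noted
# in db_setup().
connection = db_setup(db_name='st2', db_host='localhost', db_port=27017)

try:
    # Wrap a model class to get the generic accessor helpers.
    action_access = MongoDBAccess(ActionDB)

    # get_by_name() raises StackStormDBObjectNotFoundError if nothing matches.
    action_db = action_access.get_by_name('my_action')

    # query() supports offset/limit paging, ordering and field exclusion;
    # string filter values of the form '<start>..<end>' are parsed as
    # datetime ranges by _process_datetime_range_filters().
    recent = action_access.query(limit=10, order_by=['-id'])
    print(action_db)
    print(list(recent))
finally:
    db_teardown()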