Completed
Pull Request — master (#2677)
by Manas
06:27
created

cleanup_extra_indexes()   A

Complexity

Conditions 4

Size

Total Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 4
c 1
b 0
f 0
dl 0
loc 15
rs 9.2
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import importlib
18
import ssl as ssl_lib
19
20
import six
21
import mongoengine
22
from pymongo.errors import OperationFailure
23
24
from st2common import log as logging
25
from st2common.util import isotime
26
from st2common.models.db import stormbase
27
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
28
from st2common.exceptions.db import StackStormDBObjectNotFoundError
29
30
31
LOG = logging.getLogger(__name__)

# Fully qualified names of the modules which declare persisted model classes.
# Each listed module is expected to expose its model classes through a module-level
# ``MODELS`` attribute (see ``get_model_classes``); modules without it contribute nothing.
MODEL_MODULE_NAMES = [
    'st2common.models.db.auth',
    'st2common.models.db.action',
    'st2common.models.db.actionalias',
    'st2common.models.db.keyvalue',
    'st2common.models.db.execution',
    'st2common.models.db.executionstate',
    'st2common.models.db.liveaction',
    'st2common.models.db.pack',
    'st2common.models.db.policy',
    'st2common.models.db.rule',
    'st2common.models.db.runner',
    'st2common.models.db.sensor',
    'st2common.models.db.trigger',
]
48
49
50
def get_model_classes():
    """
    Collect the model classes declared by every module in ``MODEL_MODULE_NAMES``.

    Each module is imported in list order and the contents of its ``MODELS``
    attribute (treated as empty when the attribute is missing) are appended to
    the result.

    :rtype: ``list``
    """
    return [model_class
            for module_name in MODEL_MODULE_NAMES
            for model_class in getattr(importlib.import_module(module_name), 'MODELS', [])]
63
64
65
def db_setup(db_name, db_host, db_port, username=None, password=None,
             ensure_indexes=True, ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Open the mongoengine connection and, optionally, create all model indexes.

    :param ensure_indexes: When True, call ``db_ensure_indexes`` right after
                           connecting.
    :param ssl*: SSL options; normalized by ``_get_ssl_kwargs`` before being
                 handed to ``mongoengine.connection.connect``.

    :return: The object returned by ``mongoengine.connection.connect``.
    """
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    connect_kwargs = dict(host=db_host, port=db_port, tz_aware=True,
                          username=username, password=password)
    connect_kwargs.update(_get_ssl_kwargs(ssl=ssl,
                                          ssl_keyfile=ssl_keyfile,
                                          ssl_certfile=ssl_certfile,
                                          ssl_cert_reqs=ssl_cert_reqs,
                                          ssl_ca_certs=ssl_ca_certs,
                                          ssl_match_hostname=ssl_match_hostname))

    connection = mongoengine.connection.connect(db_name, **connect_kwargs)

    # Building every index eagerly avoids races caused by lazy index creation.
    if ensure_indexes:
        db_ensure_indexes()

    return connection
86
87
88
def db_ensure_indexes():
    """
    Make sure every model's declared indexes exist and drop undeclared ones.

    For each model class (see ``get_model_classes``) stale indexes are removed
    first via ``cleanup_extra_indexes`` and then ``ensure_indexes()`` is called.

    Note #1: A database connection must already be established when this is called.

    Note #2: This call blocks until all the indexes have been created (indexes
    are created in real-time and not in background).
    """
    LOG.debug('Ensuring database indexes...')

    for model_cls in get_model_classes():
        # Drop indexes which are no longer declared before (re)creating the declared ones.
        cleanup_extra_indexes(model_class=model_cls)
        LOG.debug('Ensuring indexes for model "%s"...' % (model_cls.__name__))
        model_cls.ensure_indexes()
107
108
109
def cleanup_extra_indexes(model_class):
    """
    Drop the indexes which exist in MongoDB but are not declared on ``model_class``.

    The indexes to drop are taken from the ``'extra'`` entry of
    ``model_class.compare_indexes()``. A failure to drop an individual index
    (``OperationFailure``) is logged as a warning and the remaining indexes are
    still processed.

    :param model_class: mongoengine document class whose collection is cleaned up.
    """
    extra_indexes = model_class.compare_indexes().get('extra', None)
    if not extra_indexes:
        return

    # mongoengine does not have the necessary method so we need to drop to
    # pymongo interfaces via some private methods.
    collection = model_class._get_collection()
    for extra_index in extra_indexes:
        try:
            collection.drop_index(extra_index)
        except OperationFailure:
            # Best effort: log and continue with the remaining indexes.
            # Note: "%s" (the original "% f" was parsed as a float conversion and
            # broke the log call itself).
            LOG.warning('Attempt to cleanup index %s failed.', extra_index, exc_info=True)
124
125
126
def db_teardown():
    """
    Close the database connection opened by ``db_setup``.

    Delegates to ``mongoengine.connection.disconnect``.
    """
    disconnect = mongoengine.connection.disconnect
    disconnect()
128
129
130
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
                    ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Build the SSL-related keyword arguments for ``mongoengine.connection.connect``.

    Supplying ``ssl_keyfile``, ``ssl_certfile`` or ``ssl_ca_certs`` implicitly turns
    SSL on. ``ssl_cert_reqs`` may be one of the strings ``'none'``, ``'optional'`` or
    ``'required'`` (translated to the matching ``ssl.CERT_*`` constant); any other
    non-empty value (e.g. an ``ssl.CERT_*`` constant) is passed through unchanged.
    ``ssl_match_hostname`` is only included when SSL ends up enabled.

    :rtype: ``dict``
    """
    ssl_kwargs = {
        'ssl': ssl,
    }

    if ssl_keyfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile

    if ssl_certfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_certfile'] = ssl_certfile

    # Explicit None / empty-string check instead of truthiness: ssl.CERT_NONE == 0,
    # so a plain "if ssl_cert_reqs:" would silently drop an explicit CERT_NONE.
    if ssl_cert_reqs is not None and ssl_cert_reqs != '':
        # "==" (not "is"): identity comparison fails for equal strings that are
        # different objects, e.g. values read from a config file.
        if ssl_cert_reqs == 'none':
            ssl_cert_reqs = ssl_lib.CERT_NONE
        elif ssl_cert_reqs == 'optional':
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
        elif ssl_cert_reqs == 'required':
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs

    if ssl_ca_certs:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs

    if ssl_kwargs.get('ssl', False):
        # pass in ssl_match_hostname only if ssl is True. The right default value
        # for ssl_match_hostname in almost all cases is True.
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname

    return ssl_kwargs
157
158
159
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        # mongoengine document class all queries of this accessor operate on.
        self.model = model

    def get_by_name(self, value):
        """Return the instance whose ``name`` equals ``value``; raise if missing."""
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        """Return the instance whose ``id`` equals ``value``; raise if missing."""
        return self.get(id=value, raise_exception=True)

    def get_by_ref(self, value):
        """Return the instance whose ``ref`` equals ``value``; raise if missing."""
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        """Return the first instance whose ``pack`` equals ``value``; raise if missing."""
        return self.get(pack=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        """
        Return the first instance matching the ``kwargs`` filters.

        :param exclude_fields: Optional list of field names to exclude from the result.
        :param raise_exception: Keyword argument popped from ``kwargs``; when True a
                                ``StackStormDBObjectNotFoundError`` is raised if no
                                instance matches. Defaults to False.

        :return: The matching instance, or ``None``.
        """
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        instance = instances[0] if instances else None
        log_query_and_profile_data_for_queryset(queryset=instances)

        if not instance and raise_exception:
            msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
            raise StackStormDBObjectNotFoundError(msg)

        return instance

    def get_all(self, *args, **kwargs):
        """Alias for ``query``."""
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """Return the number of instances matching the ``kwargs`` filters."""
        result = self.model.objects(**kwargs).count()
        # NOTE(review): ``result`` is an int here, not a queryset; presumably the
        # profiling helper tolerates that - confirm.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        """
        Return a sliced, ordered queryset of instances matching ``filters``.

        String filter values of the form ``"<start>..<end>"`` become datetime range
        conditions, and ``None`` / ``"null"`` values become ``<field>__exists=False``.

        :param offset: Index of the first result to return.
        :param limit: Maximum number of results; a falsy value means no limit.
        :param order_by: List of sort keys (may be adjusted for datetime ranges).
        :param exclude_fields: Field names to exclude from the results.
        """
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """Return the distinct values of the ``field`` kwarg among matching instances."""
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        """Run a pymongo ``aggregate`` on the model's underlying collection."""
        # NOTE(review): ``kwargs`` is forwarded both as queryset filters and to
        # ``aggregate()`` itself - confirm this double use is intended.
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        """Insert ``instance`` and return it with escaped dict fields restored."""
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Save ``instance`` and return it with escaped dict fields restored."""
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        """Apply ``kwargs`` as an update to ``instance`` and return the driver result."""
        return instance.update(**kwargs)

    def delete(self, instance):
        """Delete ``instance`` and return the driver result."""
        return instance.delete()

    def delete_by_query(self, **query):
        """Delete every instance matching ``query``. Always returns ``None``."""
        qs = self.model.objects.filter(**query)
        qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)
        # mongoengine does not return anything useful so cannot return anything meaningful.
        return None

    def _undo_dict_field_escape(self, instance):
        """
        Convert every ``EscapedDictField`` value on ``instance`` back through the
        field's ``to_python`` and return the (mutated) instance.
        """
        # six.iteritems keeps this working on both Python 2 and Python 3.
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        """
        Replace "null" filters with ``<field>__exists=False`` conditions.

        A filter whose value is ``None`` or a string equal to ``'null'`` (case
        insensitive) is removed and replaced with ``'<field>__exists': False``.
        ``filters`` itself is not modified; a deep copy is returned.
        """
        result = copy.deepcopy(filters)

        # six.string_types covers str/unicode on Python 2 and str on Python 3 (the
        # bare name "unicode" does not exist on Python 3). Compare the lowered
        # string directly: wrapping it in str() raises UnicodeEncodeError for
        # non-ASCII unicode values on Python 2.
        null_filters = {k: v for k, v in six.iteritems(filters)
                        if v is None or
                        (isinstance(v, six.string_types) and v.lower() == 'null')}

        for key in null_filters:
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Expand ``"<start>..<end>"`` string filters into ``__gte`` / ``__lte`` ranges.

        Both ends are parsed with ``isotime.parse``. The smaller value becomes the
        ``__gte`` bound and the larger the ``__lte`` bound. The sort key follows the
        order in which the range was written (ascending when start < end, otherwise
        descending): if the opposite key is already in ``order_by`` it is replaced
        in place, otherwise the key is prepended unless already present.

        Note: ``filters`` is modified in place and also returned.

        :return: ``(filters, order_by_list)`` tuple.
        """
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
303