Completed
Pull Request — master (#2677)
by Manas
07:10 queued 20s
created

cleanup_extra_indexes()   A

Complexity

Conditions 3

Size

Total Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 3
c 1
b 0
f 0
dl 0
loc 12
rs 9.4285
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import importlib
18
import ssl as ssl_lib
19
20
import six
21
import mongoengine
22
23
from st2common import log as logging
24
from st2common.util import isotime
25
from st2common.models.db import stormbase
26
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
27
from st2common.exceptions.db import StackStormDBObjectNotFoundError
28
29
30
LOG = logging.getLogger(__name__)


# Modules which define database models, given by their name relative to the
# ``st2common.models.db`` package. The order is preserved, so model classes are
# collected (and their indexes handled) in the order listed here.
MODEL_MODULE_NAMES = list(
    'st2common.models.db.%s' % (module_suffix)
    for module_suffix in (
        'auth',
        'action',
        'actionalias',
        'keyvalue',
        'execution',
        'executionstate',
        'liveaction',
        'pack',
        'policy',
        'rule',
        'runner',
        'sensor',
        'trigger',
    )
)
47
48
49
def get_model_classes():
    """
    Collect the model classes declared by every module in ``MODEL_MODULE_NAMES``.

    Each module is imported and the contents of its ``MODELS`` attribute are
    gathered in module order. Modules without a ``MODELS`` attribute contribute
    nothing.

    :rtype: ``list``
    """
    return [model_cls
            for module_path in MODEL_MODULE_NAMES
            for model_cls in getattr(importlib.import_module(module_path), 'MODELS', [])]
62
63
64
def db_setup(db_name, db_host, db_port, username=None, password=None,
             ensure_indexes=True, ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Connect mongoengine to the given database and optionally set up indexes.

    :param ensure_indexes: When True, ``db_ensure_indexes()`` is run right after the
                           connection is made.
    :param ssl*: SSL options; translated into connect() keyword arguments by
                 ``_get_ssl_kwargs``.

    :return: Whatever ``mongoengine.connection.connect`` returned.
    """
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    connect_kwargs = dict(host=db_host, port=db_port, tz_aware=True,
                          username=username, password=password)
    connect_kwargs.update(_get_ssl_kwargs(ssl=ssl,
                                          ssl_keyfile=ssl_keyfile,
                                          ssl_certfile=ssl_certfile,
                                          ssl_cert_reqs=ssl_cert_reqs,
                                          ssl_ca_certs=ssl_ca_certs,
                                          ssl_match_hostname=ssl_match_hostname))

    connection = mongoengine.connection.connect(db_name, **connect_kwargs)

    # Build every index up front; relying on lazy index creation can lead to
    # race conditions.
    if ensure_indexes:
        db_ensure_indexes()

    return connection
85
86
87
def db_ensure_indexes():
    """
    Ensure that indexes for all the models have been created and that any extra
    indexes (present in the database but no longer declared) have been dropped.

    Note #1: A database connection must already be established when calling this
    function.

    Note #2: This function blocks until all the indexes have been created (indexes
    are created in real-time and not in background).
    """
    LOG.debug('Ensuring database indexes...')
    model_classes = get_model_classes()

    for model_class in model_classes:
        # First drop indexes the model no longer declares, then create the declared ones.
        cleanup_extra_indexes(model_class=model_class)
        # Pass the model name as a logging argument so the message is only formatted
        # when DEBUG logging is actually enabled.
        LOG.debug('Ensuring indexes for model "%s"...', model_class.__name__)
        model_class.ensure_indexes()
106
107
108
def cleanup_extra_indexes(model_class):
    """
    Drop indexes which exist in MongoDB but are reported as "extra" for ``model_class``.

    The list of extra indexes comes from ``model_class.compare_indexes()``. mongoengine
    offers no method for dropping them, so the underlying pymongo collection is
    obtained through the private ``_get_collection()`` helper and each index is
    dropped there.

    :param model_class: Model (document) class whose stale indexes should be removed.
    """
    index_report = model_class.compare_indexes()
    stale_indexes = index_report.get('extra', None)

    if stale_indexes:
        collection = model_class._get_collection()
        for index_spec in stale_indexes:
            collection.drop_index(index_spec)
120
121
122
def db_teardown():
    """
    Close the mongoengine connection by calling
    ``mongoengine.connection.disconnect()`` with its default arguments.
    """
    connection_module = mongoengine.connection
    connection_module.disconnect()
124
125
126
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
                    ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Build the SSL-related keyword arguments for ``mongoengine.connection.connect``.

    Passing ``ssl_keyfile``, ``ssl_certfile`` or ``ssl_ca_certs`` implicitly turns
    ``ssl`` on. ``ssl_cert_reqs`` may be given as ``'none'``, ``'optional'`` or
    ``'required'`` and is translated to the matching ``ssl`` module constant; any
    other truthy value is passed through unchanged. ``ssl_match_hostname`` is only
    included when SSL ends up enabled.

    :rtype: ``dict``
    """
    ssl_kwargs = {
        'ssl': ssl,
    }

    if ssl_keyfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile

    if ssl_certfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_certfile'] = ssl_certfile

    if ssl_cert_reqs:
        # Compare by value, not identity: ``is`` against a string literal only matches
        # when both strings happen to be the same interned object. That is not the case
        # for strings built at runtime (e.g. parsed from a config file).
        if ssl_cert_reqs == 'none':
            ssl_cert_reqs = ssl_lib.CERT_NONE
        elif ssl_cert_reqs == 'optional':
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
        elif ssl_cert_reqs == 'required':
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs

    if ssl_ca_certs:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs

    if ssl_kwargs.get('ssl', False):
        # pass in ssl_match_hostname only if ssl is True. The right default value
        # for ssl_match_hostname in almost all cases is True.
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname

    return ssl_kwargs
153
154
155
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        # Model (document) class that every query issued by this accessor targets.
        self.model = model

    def get_by_name(self, value):
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        return self.get(id=value, raise_exception=True)

    def get_by_ref(self, value):
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        return self.get(pack=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        """
        Return the first instance matching the keyword filters in ``kwargs``.

        :param exclude_fields: Optional list of field names to leave out of the result.
        :param raise_exception: (popped from ``kwargs``) When True, raise
                                ``StackStormDBObjectNotFoundError`` instead of
                                returning ``None`` if nothing matches.

        Note: positional ``args`` are accepted but not used.
        """
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        # ``first()`` fetches at most one document in a single round trip. Testing
        # the queryset's truthiness and then indexing it can query the database twice.
        instance = instances.first()
        log_query_and_profile_data_for_queryset(queryset=instances)

        if instance is None and raise_exception:
            msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
            raise StackStormDBObjectNotFoundError(msg)

        return instance

    def get_all(self, *args, **kwargs):
        """Alias for ``query()``; all arguments are forwarded unchanged."""
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """
        Return the number of instances matching the keyword filters in ``kwargs``.

        Note: positional ``args`` are accepted but not used.
        """
        result = self.model.objects(**kwargs).count()
        # NOTE(review): ``result`` is an int here, not a queryset; presumably the
        # profiling helper ignores non-queryset values - confirm.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        """
        Return a (sliced) queryset of instances matching ``filters``.

        :param offset: Number of matching instances to skip.
        :param limit: Maximum number of instances to return; a falsy value means no limit.
        :param order_by: List of field names to sort by (``-`` prefix for descending).
        :param exclude_fields: Field names to leave out of the returned documents.
        :param filters: Field filters passed to ``objects()``. String values of the form
                        ``<start>..<end>`` are turned into datetime range filters, and
                        ``None`` / ``'null'`` values into "field does not exist" filters.
        """
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """
        Return the distinct values of ``kwargs['field']`` among the instances matching
        the remaining keyword filters.
        """
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        """Run an aggregation on the model's raw collection, forwarding all arguments."""
        # NOTE(review): ``kwargs`` is passed both to ``objects()`` and to ``aggregate()``.
        # Only the ``_collection`` handle of the queryset is used, so the filter part
        # presumably has no effect - confirm before relying on it.
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        """Insert ``instance`` via ``objects.insert`` and un-escape its dict fields."""
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Save ``instance`` and un-escape its dict fields."""
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        return instance.update(**kwargs)

    def delete(self, instance):
        return instance.delete()

    def delete_by_query(self, **query):
        qs = self.model.objects.filter(**query)
        qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)
        # mongoengine does not return anything useful so cannot return anything meaningful.
        return None

    def _undo_dict_field_escape(self, instance):
        """
        Replace the value of every ``EscapedDictField`` on ``instance`` with the result
        of that field's ``to_python()``, and return the same instance.
        """
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        """
        Rewrite "null" filters into ``<field>__exists=False`` lookups.

        A filter value counts as null when it is ``None`` or a string equal to
        ``'null'`` (case-insensitive). Returns a new dict; ``filters`` is not modified.
        """
        result = copy.deepcopy(filters)

        # Compare the lowered value directly: wrapping it in ``str()`` would raise
        # UnicodeEncodeError for non-ASCII unicode values on Python 2.
        null_keys = [k for k, v in six.iteritems(filters)
                     if v is None or
                     (isinstance(v, six.string_types) and v.lower() == 'null')]

        for key in null_keys:
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Expand ``<start>..<end>`` string filters into ``__gte`` / ``__lte`` lookups.

        Any string filter value containing ``..`` is treated as a range. Both ends are
        parsed with ``isotime.parse``, and the smaller one becomes the ``__gte`` bound.

        The sort list is adjusted so results follow the direction the range was
        written: ascending when start < end, descending otherwise. An existing
        opposite sort on the same field is flipped in place. Otherwise the sort key is
        prepended, unless it is already present.

        :return: ``(filters, order_by)`` tuple. Neither input is modified.
        """
        result_filters = copy.copy(filters)
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del result_filters[k]
            result_filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list[idx] = sort_key
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return result_filters, order_by_list
299