Passed
Push — develop ( 9f128b...266903 )
by Plexxi
06:37 queued 03:15
created

MongoDBAccess.get_by_id()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import importlib
18
import ssl as ssl_lib
19
20
import six
21
import mongoengine
22
from pymongo.errors import OperationFailure
23
24
from st2common import log as logging
25
from st2common.util import isotime
26
from st2common.models.db import stormbase
27
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
28
from st2common.exceptions.db import StackStormDBObjectNotFoundError
29
30
31
LOG = logging.getLogger(__name__)


# Fully-qualified names of the modules that define database models. Each of
# these modules may expose a ``MODELS`` list that get_model_classes() gathers.
MODEL_MODULE_NAMES = list(
    'st2common.models.db.%s' % (module_suffix,)
    for module_suffix in (
        'auth',
        'action',
        'actionalias',
        'keyvalue',
        'execution',
        'executionstate',
        'liveaction',
        'pack',
        'policy',
        'rule',
        'runner',
        'sensor',
        'trigger',
    )
)
48
49
50
def get_model_classes():
    """
    Import each module named in ``MODEL_MODULE_NAMES`` and collect the model
    classes it lists in its ``MODELS`` attribute.

    Modules that do not define ``MODELS`` contribute nothing. Classes are
    returned in module order, then in the order each ``MODELS`` lists them.

    :rtype: ``list``
    """
    return [model_cls
            for module_name in MODEL_MODULE_NAMES
            for model_cls in getattr(importlib.import_module(module_name), 'MODELS', [])]
63
64
65
def db_setup(db_name, db_host, db_port, username=None, password=None,
             ensure_indexes=True, ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Connect to MongoDB through ``mongoengine.connection.connect``.

    The ``ssl*`` arguments are normalized by ``_get_ssl_kwargs`` before being
    forwarded to ``connect``.

    :param ensure_indexes: When true, call ``db_ensure_indexes()`` once connected.
    :return: Whatever ``mongoengine.connection.connect`` returns.
    """
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    connect_kwargs = dict(host=db_host, port=db_port, tz_aware=True,
                          username=username, password=password)
    connect_kwargs.update(_get_ssl_kwargs(ssl=ssl,
                                          ssl_keyfile=ssl_keyfile,
                                          ssl_certfile=ssl_certfile,
                                          ssl_cert_reqs=ssl_cert_reqs,
                                          ssl_ca_certs=ssl_ca_certs,
                                          ssl_match_hostname=ssl_match_hostname))

    connection = mongoengine.connection.connect(db_name, **connect_kwargs)

    # Build every index upfront so lazy index creation cannot race.
    if ensure_indexes:
        db_ensure_indexes()

    return connection
86
87
88
def db_ensure_indexes():
    """
    Ensure that indexes for all the models have been created and that any
    extra indexes have been cleaned up.

    Note #1: A database connection must already be established when this
    function is called.

    Note #2: This function blocks until all the indexes have been created
    (indexes are created in real-time and not in the background).
    """
    LOG.debug('Ensuring database indexes...')

    for model_class in get_model_classes():
        # First drop the indexes reported as "extra" for this model, then
        # create the declared ones.
        cleanup_extra_indexes(model_class=model_class)
        # Lazy %-args: the message is only formatted if DEBUG is enabled.
        LOG.debug('Ensuring indexes for model "%s"...', model_class.__name__)
        model_class.ensure_indexes()
107
108
109
def cleanup_extra_indexes(model_class):
    """
    Find any extra indexes for ``model_class`` and drop them from MongoDB.

    The indexes to drop are taken from ``model_class.compare_indexes()['extra']``.
    If dropping one index raises ``OperationFailure``, a warning is logged and
    the remaining indexes are still processed.

    :param model_class: Document class whose collection is cleaned up.
    """
    extra_indexes = model_class.compare_indexes().get('extra', None)
    if not extra_indexes:
        return

    # mongoengine does not have the necessary method so we need to drop to
    # pymongo interfaces via some private methods.
    collection = model_class._get_collection()
    for extra_index in extra_indexes:
        try:
            collection.drop_index(extra_index)
        except OperationFailure:
            # Best-effort cleanup: log and continue with the next index.
            LOG.warning('Attempt to cleanup index %s for model %s failed.',
                        extra_index, model_class.__name__, exc_info=True)
        else:
            LOG.debug('Dropped index %s for model %s.', extra_index, model_class.__name__)
125
126
127
def db_teardown():
    """Disconnect via ``mongoengine.connection.disconnect()``."""
    disconnect = mongoengine.connection.disconnect
    disconnect()
129
130
131
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
                    ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Build the SSL-related keyword arguments passed to ``mongoengine.connection.connect``.

    - Supplying ``ssl_keyfile``, ``ssl_certfile`` or ``ssl_ca_certs`` implicitly sets
      ``ssl`` to True. ``ssl_cert_reqs`` on its own does not.
    - ``ssl_cert_reqs`` may be one of the strings ``'none'``, ``'optional'`` or
      ``'required'``. These are translated to the matching ``ssl.CERT_*`` constant.
      Any other truthy value is passed through unchanged.
    - ``ssl_match_hostname`` is only included when SSL ends up enabled.

    :rtype: ``dict``
    """
    ssl_kwargs = {
        'ssl': ssl,
    }

    if ssl_keyfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile

    if ssl_certfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_certfile'] = ssl_certfile

    if ssl_cert_reqs:
        # Compare by value, not identity: ``is`` only matches interned literals
        # and silently fails for strings read from config files.
        if ssl_cert_reqs == 'none':
            ssl_cert_reqs = ssl_lib.CERT_NONE
        elif ssl_cert_reqs == 'optional':
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
        elif ssl_cert_reqs == 'required':
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs

    if ssl_ca_certs:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs

    if ssl_kwargs.get('ssl', False):
        # pass in ssl_match_hostname only if ssl is True. The right default value
        # for ssl_match_hostname in almost all cases is True.
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname

    return ssl_kwargs
158
159
160
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        # Document class that every query of this accessor runs against.
        self.model = model

    def get_by_name(self, value):
        """Return the instance whose ``name`` equals ``value``; raise if none exists."""
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        """Return the instance whose ``id`` equals ``value``; raise if none exists."""
        return self.get(id=value, raise_exception=True)

    def get_by_uid(self, value):
        """Return the instance whose ``uid`` equals ``value``; raise if none exists."""
        return self.get(uid=value, raise_exception=True)

    def get_by_ref(self, value):
        """Return the instance whose ``ref`` equals ``value``; raise if none exists."""
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        """Return the first instance whose ``pack`` equals ``value``; raise if none exists."""
        return self.get(pack=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        """
        Return the first instance matching the ``kwargs`` filters.

        :param exclude_fields: Optional iterable of field names to exclude from the result.
        :param raise_exception: (passed via ``kwargs``) When true, raise
            ``StackStormDBObjectNotFoundError`` instead of returning ``None``.
        :return: The matching instance, or ``None`` when nothing matches.
        """
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        # first() fetches at most one document. It avoids a truthiness check
        # followed by a separate indexing query.
        instance = instances.first()
        log_query_and_profile_data_for_queryset(queryset=instances)

        if instance is None and raise_exception:
            msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
            raise StackStormDBObjectNotFoundError(msg)

        return instance

    def get_all(self, *args, **kwargs):
        """Alias for :meth:`query`."""
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """Return the number of instances matching the ``kwargs`` filters."""
        result = self.model.objects(**kwargs).count()
        # NOTE(review): ``result`` is the count value here, not a queryset.
        # Confirm this profiling call is meant to receive it.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        """
        Return a queryset of instances matching ``filters``.

        String filter values of the form ``"<start>..<end>"`` become datetime range
        filters. ``None`` or ``"null"`` values become ``<field>__exists=False`` filters.

        :param offset: Number of leading results to skip.
        :param limit: Maximum number of results; a falsy value means no limit.
        :param order_by: List of field names to order by (``-`` prefix for descending).
        :param exclude_fields: Field names to exclude from the returned documents.
        """
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """Return the distinct values of ``kwargs['field']`` among instances matching the rest of ``kwargs``."""
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        """Run an aggregation directly on the model's underlying collection."""
        # NOTE(review): ``kwargs`` is used both as the document filter and as
        # aggregate() options. Confirm this is intended.
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        """Insert ``instance`` and return it with escaped dict fields restored."""
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Save ``instance`` and return it with escaped dict fields restored."""
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        """Apply ``kwargs`` as an update to ``instance`` and return the update result."""
        return instance.update(**kwargs)

    def delete(self, instance):
        """Delete ``instance`` and return the result of ``instance.delete()``."""
        return instance.delete()

    def delete_by_query(self, **query):
        """Delete every instance matching ``query``. Always returns ``None``."""
        qs = self.model.objects.filter(**query)
        qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)
        # mongoengine does not return anything useful so cannot return anything meaningful.
        return None

    def _undo_dict_field_escape(self, instance):
        """Convert every ``EscapedDictField`` value on ``instance`` through ``field.to_python``."""
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        """
        Rewrite "null" filters into ``<field>__exists=False`` filters.

        A value counts as null when it is ``None`` or a string equal to ``'null'``
        (case-insensitive). ``filters`` is not modified; a deep copy is returned.
        """
        result = copy.deepcopy(filters)

        # Compare the lowered value directly. Wrapping it in str() raised
        # UnicodeEncodeError for non-ASCII unicode values on Python 2.
        null_keys = [key for key, value in six.iteritems(filters)
                     if value is None or
                     (isinstance(value, six.string_types) and value.lower() == 'null')]

        for key in null_keys:
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Turn ``"<start>..<end>"`` string filters into ``__gte`` / ``__lte`` bounds.

        For each such filter, both endpoints are parsed with ``isotime.parse``. The
        original key in ``filters`` is replaced (in place) by ``<key>__gte`` and
        ``<key>__lte``. The direction the range is written in (start before or after
        end) selects an ascending or descending sort key:

        - If the opposite ordering on that field is already in ``order_by``, it is
          replaced in place.
        - Otherwise the sort key is prepended, unless it is already present.

        :return: ``(filters, order_by_list)``. ``filters`` is the same (mutated) dict;
            ``order_by_list`` is a new list, and ``order_by`` itself is not modified.
        """
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                # The range direction overrides an explicit opposite ordering
                # on the same field, keeping its position in the list.
                idx = order_by_list.index(reverse_sort_key)
                order_by_list[idx] = sort_key
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
307