Passed
Push — develop ( 8c7641...71cfc9 )
by Plexxi
07:22 queued 03:39
created

MongoDBAccess.query()   A

Complexity

Conditions 3

Size

Total Lines 22

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 22
rs 9.2
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import importlib
18
import ssl as ssl_lib
19
20
import six
21
import mongoengine
22
23
from st2common import log as logging
24
from st2common.util import isotime
25
from st2common.models.db import stormbase
26
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
27
28
29
LOG = logging.getLogger(__name__)

# Fully qualified names of the modules inspected by ``get_model_classes`` for a
# ``MODELS`` attribute. The order of this list is the order in which the model
# classes are collected (and therefore the order ``db_ensure_indexes`` visits them).
MODEL_MODULE_NAMES = list(map('st2common.models.db.{0}'.format, (
    'auth',
    'action',
    'actionalias',
    'keyvalue',
    'execution',
    'executionstate',
    'liveaction',
    'pack',
    'policy',
    'rule',
    'runner',
    'sensor',
    'trigger',
)))
46
47
48
def get_model_classes():
    """
    Collect the model classes declared by every module in ``MODEL_MODULE_NAMES``.

    Each module is imported in list order and the contents of its ``MODELS``
    attribute are appended to the result. A module without ``MODELS`` contributes
    nothing.

    :rtype: ``list``
    """
    return [model_cls
            for module_name in MODEL_MODULE_NAMES
            for model_cls in getattr(importlib.import_module(module_name), 'MODELS', [])]
61
62
63
def db_setup(db_name, db_host, db_port, username=None, password=None,
             ensure_indexes=True, ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Connect mongoengine to the given database and optionally create all indexes.

    The SSL-related arguments are normalized by ``_get_ssl_kwargs`` before being
    forwarded to ``mongoengine.connection.connect``. The connection is always
    opened with ``tz_aware=True``.

    :param ensure_indexes: When True, ``db_ensure_indexes`` runs right after connecting.
    :type ensure_indexes: ``bool``

    :return: The object returned by ``mongoengine.connection.connect``.
    """
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    connect_kwargs = dict(host=db_host, port=db_port, tz_aware=True,
                          username=username, password=password)
    connect_kwargs.update(_get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile,
                                          ssl_certfile=ssl_certfile,
                                          ssl_cert_reqs=ssl_cert_reqs,
                                          ssl_ca_certs=ssl_ca_certs,
                                          ssl_match_hostname=ssl_match_hostname))

    connection = mongoengine.connection.connect(db_name, **connect_kwargs)

    if not ensure_indexes:
        return connection

    # Building every index up front avoids race conditions that lazy index
    # creation would otherwise cause.
    db_ensure_indexes()
    return connection
84
85
86
def db_ensure_indexes():
    """
    Ensure that the indexes for all the models have been created.

    Every class returned by ``get_model_classes`` has its ``ensure_indexes()``
    method called, in order.

    Note #1: When calling this function the database connection already needs
    to be established.

    Note #2: This function blocks until all the indexes have been created
    (indexes are created in real-time and not in the background).
    """
    LOG.debug('Ensuring database indexes...')

    for cls in get_model_classes():
        # Pass the argument lazily so the message is only formatted when DEBUG
        # logging is actually enabled.
        LOG.debug('Ensuring indexes for model "%s"...', cls.__name__)
        cls.ensure_indexes()
102
103
104
def db_teardown():
    """Tear down the database connection via ``mongoengine.connection.disconnect()``."""
    connection_module = mongoengine.connection
    connection_module.disconnect()
106
107
108
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
109
                    ssl_ca_certs=None, ssl_match_hostname=True):
110
    ssl_kwargs = {
111
        'ssl': ssl,
112
    }
113
    if ssl_keyfile:
114
        ssl_kwargs['ssl'] = True
115
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
116
    if ssl_certfile:
117
        ssl_kwargs['ssl'] = True
118
        ssl_kwargs['ssl_certfile'] = ssl_certfile
119
    if ssl_cert_reqs:
120
        if ssl_cert_reqs is 'none':
121
            ssl_cert_reqs = ssl_lib.CERT_NONE
122
        elif ssl_cert_reqs is 'optional':
123
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
124
        elif ssl_cert_reqs is 'required':
125
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
126
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
127
    if ssl_ca_certs:
128
        ssl_kwargs['ssl'] = True
129
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
130
    if ssl_kwargs.get('ssl', False):
131
        # pass in ssl_match_hostname only if ssl is True. The right default value
132
        # for ssl_match_hostname in almost all cases is True.
133
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
134
    return ssl_kwargs
135
136
137
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        # Model class whose ``objects`` manager backs every query issued below.
        self.model = model

    def get_by_name(self, value):
        """Return the instance whose ``name`` equals ``value``; raise ``ValueError`` if none."""
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        """Return the instance whose ``id`` equals ``value``; raise ``ValueError`` if none."""
        return self.get(id=value, raise_exception=True)

    def get_by_ref(self, value):
        """Return the instance whose ``ref`` equals ``value``; raise ``ValueError`` if none."""
        return self.get(ref=value, raise_exception=True)

    def get_by_pack(self, value):
        """Return the first instance whose ``pack`` equals ``value``; raise ``ValueError`` if none."""
        return self.get(pack=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        """
        Return the first instance matching the ``kwargs`` filters.

        :param exclude_fields: Optional field names excluded from the returned instance.
        :param raise_exception: Popped from ``kwargs``. When True, a ``ValueError``
                                is raised instead of returning ``None`` if nothing matches.

        :return: The matching instance or ``None``.
        """
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        # ``first()`` fetches at most one document in one round trip. Testing the
        # queryset's truthiness and then indexing it would query twice.
        instance = instances.first()
        log_query_and_profile_data_for_queryset(queryset=instances)

        if instance is None and raise_exception:
            raise ValueError('Unable to find the %s instance. %s' % (self.model.__name__, kwargs))
        return instance

    def get_all(self, *args, **kwargs):
        """Alias for ``query``."""
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """Return the number of instances matching the ``kwargs`` filters."""
        result = self.model.objects(**kwargs).count()
        # NOTE(review): ``result`` is an int, not a queryset; presumably the helper
        # ignores non-queryset values - confirm.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        """
        Return a queryset of instances matching ``filters``.

        String filter values containing ``..`` are treated as ``<start>..<end>``
        datetime ranges. Filter values that are ``None`` or the string ``'null'``
        (any case) become "field does not exist" conditions.

        :param offset: Number of leading results to skip.
        :param limit: Maximum number of results; a falsy value means no upper bound.
        :param order_by: Field names passed to the queryset's ``order_by``.
        :param exclude_fields: Field names excluded from the returned documents.
        """
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """
        Return the distinct values of ``field`` among instances matching the other kwargs.

        ``field`` is required and is popped from ``kwargs`` (``KeyError`` if missing).
        """
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        """Run an aggregation on the model's underlying collection."""
        # NOTE(review): ``kwargs`` is used both as the ``objects()`` filter and as
        # keyword arguments to the raw collection's ``aggregate``; confirm intended.
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        """Insert ``instance`` and return it with escaped dict fields decoded."""
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Save ``instance`` and return it with escaped dict fields decoded."""
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        """Apply ``kwargs`` via ``instance.update`` and return its result."""
        return instance.update(**kwargs)

    def delete(self, instance):
        """Delete ``instance`` and return the result of ``instance.delete()``."""
        return instance.delete()

    def delete_by_query(self, **query):
        """Delete every instance matching ``query``. Always returns ``None``."""
        qs = self.model.objects.filter(**query)
        qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)
        # mongoengine does not return anything useful so cannot return anything meaningful.
        return None

    def _undo_dict_field_escape(self, instance):
        """
        Pass the value of every ``EscapedDictField`` on ``instance`` back through
        that field's ``to_python`` and return the (mutated) instance.
        """
        for attr, field in instance._fields.items():
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        """
        Return a deep copy of ``filters`` where every key whose value is ``None`` or
        the string ``'null'`` (any case) is replaced by ``<key>__exists=False``.
        """
        result = copy.deepcopy(filters)

        # Compare the lowered string directly: wrapping it in ``str()`` raises
        # UnicodeEncodeError for non-ASCII unicode values on Python 2.
        null_keys = [k for k, v in filters.items()
                     if v is None or (isinstance(v, six.string_types) and v.lower() == 'null')]

        for key in null_keys:
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Rewrite ``'<start>..<end>'`` string filters into ``__gte`` / ``__lte`` filters.

        Both bounds are parsed with ``isotime.parse``; ``__gte`` receives the
        earlier one. The field is also placed in the ordering: ascending when
        start < end, descending otherwise. An existing entry for the opposite
        direction is replaced in place; otherwise the key is prepended unless
        already present.

        :return: ``(filters, order_by_list)``. The caller's ``filters`` dict is not modified.
        """
        ranges = {k: v for k, v in filters.items()
                  if isinstance(v, six.string_types) and '..' in v}

        # Work on copies so the caller's arguments are left untouched; ``list()``
        # also makes tuple inputs support the ``pop``/``insert`` calls below.
        filters = dict(filters)
        order_by_list = list(order_by) if order_by else []

        for k, v in ranges.items():
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
279