Completed
Pull Request — master (#2622)
by Manas
06:03
created

_get_ssl_kwargs()   D

Complexity

Conditions 8

Size

Total Lines 24

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 8
dl 0
loc 24
rs 4.3478
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import importlib
18
import ssl as ssl_lib
19
20
import six
21
import mongoengine
22
23
from st2common import log as logging
24
from st2common.util import isotime
25
from st2common.models.db import stormbase
26
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
27
28
29
# Module-level logger, obtained through st2common's ``log`` wrapper (imported as ``logging``).
LOG = logging.getLogger(__name__)
30
31
# Fully-qualified names of the modules scanned by ``get_model_classes()`` for a
# ``MODELS`` attribute. Order is preserved. A generator expression is used (rather
# than a list comprehension) so no loop variable leaks into the module namespace
# under Python 2.
MODEL_MODULE_NAMES = list('st2common.models.db.' + module_suffix for module_suffix in (
    'auth',
    'action',
    'actionalias',
    'keyvalue',
    'execution',
    'executionstate',
    'liveaction',
    'policy',
    'rule',
    'runner',
    'sensor',
    'trigger',
))
45
46
47
def get_model_classes():
    """
    Collect the model classes declared by every module in ``MODEL_MODULE_NAMES``.

    Each module is imported in list order and its ``MODELS`` attribute (if any) is
    appended to the result. Modules without ``MODELS`` contribute nothing.

    :rtype: ``list``
    """
    return [model_cls
            for module_path in MODEL_MODULE_NAMES
            for model_cls in getattr(importlib.import_module(module_path), 'MODELS', [])]
60
61
62
def db_setup(db_name, db_host, db_port, username=None, password=None,
             ensure_indexes=True, ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Open the mongoengine connection and optionally create all model indexes.

    :param db_name: Name of the database to connect to.
    :param db_host: Database host.
    :param db_port: Database port.
    :param username: Optional user name for authentication.
    :param password: Optional password for authentication.
    :param ensure_indexes: When true, call ``db_ensure_indexes()`` right after connecting.
    :param ssl: Enable SSL for the connection.
    :param ssl_keyfile: Path to the client private key file.
    :param ssl_certfile: Path to the client certificate file.
    :param ssl_cert_reqs: Peer certificate requirement (``'none'``, ``'optional'``,
                          ``'required'`` or a value understood by the driver).
    :param ssl_ca_certs: Path to the CA certificates bundle.
    :param ssl_match_hostname: Whether the server hostname must match its certificate.

    :return: The connection object returned by ``mongoengine.connection.connect``.
    """
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    # Forward every SSL option so the caller's settings actually reach the driver;
    # _get_ssl_kwargs() translates/normalizes them into connect() keyword arguments.
    ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile,
                                 ssl_certfile=ssl_certfile, ssl_cert_reqs=ssl_cert_reqs,
                                 ssl_ca_certs=ssl_ca_certs,
                                 ssl_match_hostname=ssl_match_hostname)
    connection = mongoengine.connection.connect(db_name, host=db_host,
                                                port=db_port, tz_aware=True,
                                                username=username, password=password,
                                                **ssl_kwargs)

    # Create all the indexes upfront to prevent race-conditions caused by
    # lazy index creation
    if ensure_indexes:
        db_ensure_indexes()

    return connection
79
80
81
def db_ensure_indexes():
    """
    Ensure that the indexes for all the models returned by ``get_model_classes()``
    have been created.

    Note #1: A database connection must already be established when this function
    is called.

    Note #2: This function blocks until all the indexes have been created (indexes
    are created in real-time and not in the background).
    """
    LOG.debug('Ensuring database indexes...')
    model_classes = get_model_classes()

    for cls in model_classes:
        # Pass the name as a logging argument so formatting only happens if the
        # debug record is actually emitted.
        LOG.debug('Ensuring indexes for model "%s"...', cls.__name__)
        cls.ensure_indexes()
97
98
99
def db_teardown():
    """Disconnect the default mongoengine connection."""
    connection_api = mongoengine.connection
    connection_api.disconnect()
101
102
103
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
104
                    ssl_ca_certs=None, ssl_match_hostname=True):
105
    ssl_kwargs = {
106
        'ssl': ssl,
107
        'ssl_match_hostname': ssl_match_hostname
108
    }
109
    if ssl_keyfile:
110
        ssl_kwargs['ssl'] = True
111
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
112
    if ssl_certfile:
113
        ssl_kwargs['ssl'] = True
114
        ssl_kwargs['ssl_certfile'] = ssl_certfile
115
    if ssl_cert_reqs:
116
        if ssl_cert_reqs is 'none':
117
            ssl_cert_reqs = ssl_lib.CERT_NONE
118
        elif ssl_cert_reqs is 'optional':
119
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
120
        elif ssl_cert_reqs is 'required':
121
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
122
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
123
    if ssl_ca_certs:
124
        ssl_kwargs['ssl'] = True
125
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
126
    return ssl_kwargs
127
128
129
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    def __init__(self, model):
        """
        :param model: Document class whose ``objects`` manager all queries go through.
        """
        self.model = model

    def get_by_name(self, value):
        """Return the instance whose ``name`` is ``value``; raise ``ValueError`` if none."""
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        """Return the instance whose ``id`` is ``value``; raise ``ValueError`` if none."""
        return self.get(id=value, raise_exception=True)

    def get_by_ref(self, value):
        """Return the instance whose ``ref`` is ``value``; raise ``ValueError`` if none."""
        return self.get(ref=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        """
        Return the first instance matching the ``kwargs`` filters, or ``None``.

        :param exclude_fields: Optional iterable of field names to omit from the result.
        :param raise_exception: Popped from ``kwargs``. When true, raise instead of
                                returning ``None`` if nothing matches.

        :raises ValueError: If nothing matches and ``raise_exception`` is true.
        """
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        # ``first()`` returns the first match or ``None``. This avoids evaluating the
        # queryset once for truthiness and again for ``[0]``.
        instance = instances.first()
        log_query_and_profile_data_for_queryset(queryset=instances)

        # Test for None explicitly instead of relying on the document's truthiness.
        if instance is None and raise_exception:
            raise ValueError('Unable to find the %s instance. %s' % (self.model.__name__, kwargs))
        return instance

    def get_all(self, *args, **kwargs):
        """Alias for ``query()``; all arguments are forwarded unchanged."""
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """Return the number of documents matching the ``kwargs`` filters."""
        result = self.model.objects(**kwargs).count()
        # NOTE(review): ``result`` is an int here, not a queryset. Confirm the
        # profiling helper is meant to receive (and ignore) non-queryset values.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        """
        Return a sliced, ordered queryset of documents matching ``filters``.

        :param offset: Number of leading results to skip.
        :param limit: Maximum number of results. A falsy value means "no limit".
        :param order_by: List of sort keys (``'-field'`` for descending).
        :param exclude_fields: Field names to omit from the returned documents.
        :param filters: Field filters. String values containing ``'..'`` are treated
                        as datetime ranges. ``None`` / ``'null'`` values become
                        ``<field>__exists=False`` filters.
        """
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """
        Return the distinct values of ``field`` among documents matching the remaining
        ``kwargs`` filters.

        :raises KeyError: If the ``field`` keyword argument is missing.
        """
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        # NOTE(review): as in count(), this passes the distinct() result rather than a
        # queryset to the profiling helper. Confirm that is intended.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        """Run an aggregation on the model's underlying collection."""
        # NOTE(review): the same ``kwargs`` are used both as queryset filters and as
        # ``aggregate()`` arguments. The filters presumably do not restrict a call made
        # directly on ``_collection``. Confirm the intended semantics with callers.
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        """Insert ``instance`` via ``objects.insert`` and undo dict-field escaping on the result."""
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Save ``instance`` and return it with dict-field escaping undone."""
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        """Apply ``kwargs`` through ``instance.update`` and return its result."""
        return instance.update(**kwargs)

    def delete(self, instance):
        """Delete ``instance`` and return the result of ``instance.delete()``."""
        return instance.delete()

    def delete_by_query(self, **query):
        """Delete every document matching the ``query`` filters. Always returns ``None``."""
        qs = self.model.objects.filter(**query)
        qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)
        # mongoengine does not return anything useful so cannot return anything meaningful.
        return None

    def _undo_dict_field_escape(self, instance):
        """
        Convert every ``EscapedDictField`` value on ``instance`` back through the
        field's ``to_python`` and return the (mutated) instance.
        """
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        """
        Return a deep copy of ``filters`` in which each key whose value is ``None`` or
        the string ``'null'`` (any case) is replaced by ``<key>__exists=False``.
        """
        result = copy.deepcopy(filters)

        # Compare the lowered value directly. Wrapping it in ``str()`` raised
        # UnicodeEncodeError on Python 2 for non-ASCII unicode filter values.
        null_keys = [k for k, v in six.iteritems(filters)
                     if v is None or (isinstance(v, six.string_types) and v.lower() == 'null')]

        for key in null_keys:
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Expand ``"<start>..<end>"`` string filters into ``__gte`` / ``__lte`` bounds.

        Both endpoints are parsed with ``isotime.parse`` and swapped if needed, so the
        lower bound always uses ``__gte``. The range key is also forced into the sort
        order: ascending when ``start < end``, otherwise descending. An existing
        opposite-direction key is replaced in place; otherwise the key is prepended
        (unless already present).

        Note: ``filters`` is modified in place and also returned. ``order_by`` is
        deep-copied and left untouched.

        :return: ``(filters, order_by_list)``
        """
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
268