Completed
Pull Request — master (#2622)
by Manas
11:51 queued 05:50
created

_get_ssl_kwargs()   F

Complexity

Conditions 9

Size

Total Lines 27

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 9
dl 0
loc 27
rs 3
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import importlib
18
import ssl as ssl_lib
19
20
import six
21
import mongoengine
22
23
from st2common import log as logging
24
from st2common.util import isotime
25
from st2common.models.db import stormbase
26
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
27
28
29
LOG = logging.getLogger(__name__)

# Modules under the ``st2common.models.db`` package which get_model_classes()
# imports (in this order), collecting each module's optional ``MODELS`` attribute.
MODEL_MODULE_NAMES = list('st2common.models.db.' + _module_suffix for _module_suffix in (
    'auth',
    'action',
    'actionalias',
    'keyvalue',
    'execution',
    'executionstate',
    'liveaction',
    'policy',
    'rule',
    'runner',
    'sensor',
    'trigger',
))
45
46
47
def get_model_classes():
    """
    Import every module listed in MODEL_MODULE_NAMES and gather its models.

    Modules are imported in list order and each one's ``MODELS`` attribute is
    read right after its import; a module without ``MODELS`` contributes
    nothing.

    :rtype: ``list``
    """
    # The inner iterable is re-evaluated for each module name, so importing and
    # reading ``MODELS`` stay interleaved exactly as in a plain for-loop.
    return [model_cls
            for module_name in MODEL_MODULE_NAMES
            for model_cls in getattr(importlib.import_module(module_name), 'MODELS', [])]
60
61
62
def db_setup(db_name, db_host, db_port, username=None, password=None,
             ensure_indexes=True, ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Open the mongoengine connection to ``db_name`` and optionally build indexes.

    The ``ssl*`` arguments are normalised by ``_get_ssl_kwargs`` and forwarded,
    together with the host/port/credentials, to
    ``mongoengine.connection.connect``; the connection is always created with
    ``tz_aware=True``.

    :param ensure_indexes: When True, call ``db_ensure_indexes()`` right after
                           connecting.
    :return: The value returned by ``mongoengine.connection.connect``.
    """
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    connect_kwargs = {
        'host': db_host,
        'port': db_port,
        'tz_aware': True,
        'username': username,
        'password': password,
    }
    connect_kwargs.update(_get_ssl_kwargs(ssl=ssl,
                                          ssl_keyfile=ssl_keyfile,
                                          ssl_certfile=ssl_certfile,
                                          ssl_cert_reqs=ssl_cert_reqs,
                                          ssl_ca_certs=ssl_ca_certs,
                                          ssl_match_hostname=ssl_match_hostname))
    connection = mongoengine.connection.connect(db_name, **connect_kwargs)

    if not ensure_indexes:
        return connection

    # Build every index up front: lazy index creation can cause race conditions.
    db_ensure_indexes()
    return connection
83
84
85
def db_ensure_indexes():
    """
    Ensure that the indexes for all the models have been created.

    Note #1: A database connection must already be established when calling
    this function (``db_setup`` calls it right after connecting).

    Note #2: This function blocks until all the indexes have been created
    (indexes are created in real-time and not in the background).
    """
    LOG.debug('Ensuring database indexes...')
    model_classes = get_model_classes()

    for cls in model_classes:
        # Pass the name as a logging argument so the message is only formatted
        # when DEBUG logging is actually enabled.
        LOG.debug('Ensuring indexes for model "%s"...', cls.__name__)
        cls.ensure_indexes()
101
102
103
def db_teardown():
    """Tear down the database connection via ``mongoengine.connection.disconnect()``."""
    connection_api = mongoengine.connection
    connection_api.disconnect()
105
106
107
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
                    ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Build the SSL related keyword arguments for ``mongoengine.connection.connect``.

    ``ssl`` is always present in the result. Providing a key file, a
    certificate file or a CA certificates file implicitly turns ``ssl`` on;
    ``ssl_cert_reqs`` on its own does not. ``ssl_match_hostname`` is only
    included when ``ssl`` ends up truthy.

    :param ssl_cert_reqs: One of the strings ``'none'``, ``'optional'`` or
                          ``'required'`` (mapped to the matching ``ssl`` module
                          constant), or any other truthy value, which is
                          passed through unchanged.
    :rtype: ``dict``
    """
    ssl_kwargs = {
        'ssl': ssl,
    }
    if ssl_keyfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
    if ssl_certfile:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_certfile'] = ssl_certfile
    # NOTE: this is a truthiness check, so the integer ``ssl.CERT_NONE`` (0) is
    # skipped; pass the string 'none' to request CERT_NONE.
    if ssl_cert_reqs:
        # Compare by value with ``==``. ``is`` tests object identity and only
        # matches when the caller's string is the very same interned object,
        # which strings built at runtime generally are not.
        if ssl_cert_reqs == 'none':
            ssl_cert_reqs = ssl_lib.CERT_NONE
        elif ssl_cert_reqs == 'optional':
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
        elif ssl_cert_reqs == 'required':
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
    if ssl_ca_certs:
        ssl_kwargs['ssl'] = True
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
    if ssl_kwargs.get('ssl', False):
        # pass in ssl_match_hostname only if ssl is True. The right default value
        # for ssl_match_hostname in almost all cases is True.
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
    return ssl_kwargs
134
135
136
class MongoDBAccess(object):
    """
    Database object access class that provides general functions for a model type.

    Every operation goes through ``self.model.objects`` or acts on a model
    instance directly; most read paths also hand their result to
    ``log_query_and_profile_data_for_queryset``.
    """

    def __init__(self, model):
        # Model class whose ``objects`` manager is queried by all methods.
        self.model = model

    def get_by_name(self, value):
        """Return the first instance with ``name == value``; raise ValueError if none."""
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        """Return the first instance with ``id == value``; raise ValueError if none."""
        return self.get(id=value, raise_exception=True)

    def get_by_ref(self, value):
        """Return the first instance with ``ref == value``; raise ValueError if none."""
        return self.get(ref=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        """
        Return the first instance matching the ``kwargs`` filters, or None.

        :param exclude_fields: Optional iterable of field names passed to the
                               queryset's ``exclude``.
        :param raise_exception: Popped from ``kwargs`` (default False); when
                                True a missing instance raises instead of
                                returning None.
        :raises ValueError: If nothing matches and ``raise_exception`` is set.

        Extra positional ``args`` are accepted but ignored.
        """
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        instance = instances[0] if instances else None
        log_query_and_profile_data_for_queryset(queryset=instances)

        # Compare against None explicitly instead of relying on the truthiness
        # of the returned document.
        if instance is None and raise_exception:
            raise ValueError('Unable to find the %s instance. %s' % (self.model.__name__, kwargs))
        return instance

    def get_all(self, *args, **kwargs):
        """Alias for ``query(*args, **kwargs)``."""
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """Return the number of instances matching the ``kwargs`` filters."""
        result = self.model.objects(**kwargs).count()
        # NOTE(review): ``result`` is the plain count, not a queryset -
        # presumably the profiling helper tolerates that; confirm.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        """
        Return an ordered, sliced queryset of instances matching ``filters``.

        :param offset: Index of the first result to return.
        :param limit: Maximum number of results (converted with ``int``); a
                      falsy value means no upper bound.
        :param order_by: Field names passed to ``order_by``; adjusted when a
                         datetime range filter is present.
        :param exclude_fields: Field names to leave out of the results.
        :param filters: Field filters. String values containing ``'..'`` are
                        treated as ``'<start>..<end>'`` datetime ranges; None or
                        ``'null'`` (any case) becomes ``<field>__exists=False``.
        """
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """
        Return the distinct values of ``kwargs['field']``.

        Only instances matching the remaining ``kwargs`` are considered.
        """
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        """Run ``aggregate(*args, **kwargs)`` on the model's underlying collection."""
        # NOTE(review): ``kwargs`` is used both to build the queryset and as
        # aggregate() options; presumably the queryset only serves to reach
        # ``_collection`` - confirm callers rely on this.
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        """Insert ``instance`` and return it with escaped dict fields converted back."""
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Save ``instance`` and return it with escaped dict fields converted back."""
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        """Return the result of ``instance.update(**kwargs)``."""
        return instance.update(**kwargs)

    def delete(self, instance):
        """Return the result of ``instance.delete()``."""
        return instance.delete()

    def delete_by_query(self, **query):
        """Delete every instance matching ``query``; always returns None."""
        qs = self.model.objects.filter(**query)
        qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)
        # mongoengine does not return anything useful so cannot return anything meaningful.
        return None

    def _undo_dict_field_escape(self, instance):
        """Replace each EscapedDictField value on ``instance`` with ``field.to_python(value)``."""
        for attr, field in six.iteritems(instance._fields):
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        """
        Rewrite "is null" filters into ``<field>__exists=False`` conditions.

        A filter is a null filter when its value is None or a string equal to
        ``'null'`` ignoring case. ``filters`` is not modified; a deep copy with
        the rewritten entries is returned.
        """
        result = copy.deepcopy(filters)

        # Compare the lowered value directly: wrapping it in str() raised
        # UnicodeEncodeError on Python 2 for non-ASCII unicode values.
        null_keys = [k for k, v in six.iteritems(filters)
                     if v is None or
                     (isinstance(v, six.string_types) and v.lower() == 'null')]

        for key in null_keys:
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Expand ``'<start>..<end>'`` string filters into datetime range bounds.

        For each such filter ``k``, both endpoints are parsed with
        ``isotime.parse``. The entry is replaced by ``k__gte`` / ``k__lte``
        bounds, with the endpoints swapped if they were written in descending
        order. The sort order follows the direction the range was written:
        ascending (``k``) when start < end, otherwise descending (``'-k'``).
        If ``order_by`` already sorts on ``k`` in the opposite direction, that
        entry is replaced in place. If it does not sort on ``k`` at all, the
        key is prepended.

        ``filters`` is modified in place and also returned; ``order_by`` is
        copied first, so the caller's list is left untouched.

        :return: ``(filters, order_by_list)``
        """
        # NOTE(review): any string value containing '..' is treated as a range,
        # so such a non-datetime value presumably makes isotime.parse fail.
        ranges = {k: v for k, v in six.iteritems(filters)
                  if isinstance(v, six.string_types) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in six.iteritems(ranges):
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
275