Completed
Pull Request — master (#2622)
by Manas
06:42
created

_get_ssl_kwargs()   F

Complexity

Conditions 9

Size

Total Lines 27

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 9
dl 0
loc 27
rs 3
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import importlib
18
import ssl as ssl_lib
19
20
import six
21
import mongoengine
22
23
from st2common import log as logging
24
from st2common.util import isotime
25
from st2common.models.db import stormbase
26
from st2common.models.utils.profiling import log_query_and_profile_data_for_queryset
27
28
29
# Module-level logger, obtained through the st2common logging wrapper imported above.
LOG = logging.getLogger(__name__)
30
31
# Dotted paths of the modules which are imported by get_model_classes() to collect model
# classes (each module's ``MODELS`` attribute). Models are collected in the order listed here.
MODEL_MODULE_NAMES = [
32
    'st2common.models.db.auth',
33
    'st2common.models.db.action',
34
    'st2common.models.db.actionalias',
35
    'st2common.models.db.keyvalue',
36
    'st2common.models.db.execution',
37
    'st2common.models.db.executionstate',
38
    'st2common.models.db.liveaction',
39
    'st2common.models.db.policy',
40
    'st2common.models.db.rule',
41
    'st2common.models.db.runner',
42
    'st2common.models.db.sensor',
43
    'st2common.models.db.trigger',
44
]
45
46
47
def get_model_classes():
    """
    Collect the model classes exposed by every module listed in MODEL_MODULE_NAMES.

    Each module is imported and the content of its ``MODELS`` attribute (treated as empty
    when the attribute is missing) is appended to the result, in module order.

    :rtype: ``list``
    """
    models = []

    for name in MODEL_MODULE_NAMES:
        # In-place concatenation extends the list with whatever iterable MODELS is
        models += getattr(importlib.import_module(name), 'MODELS', [])

    return models
60
61
62
def db_setup(db_name, db_host, db_port, username=None, password=None,
             ensure_indexes=True, ssl=False, ssl_keyfile=None, ssl_certfile=None,
             ssl_cert_reqs=None, ssl_ca_certs=None, ssl_match_hostname=True):
    """
    Establish a mongoengine connection to the database and optionally create the indexes.

    :param db_name: Name of the database to connect to.
    :param db_host: Database host.
    :param db_port: Database port.
    :param username: Optional username used for authentication.
    :param password: Optional password used for authentication.
    :param ensure_indexes: True to create the indexes for all the models right after
                           connecting.
    :type ensure_indexes: ``bool``

    :param ssl: True to connect using SSL.
    :param ssl_keyfile: Optional private key file (implies ``ssl``).
    :param ssl_certfile: Optional certificate file (implies ``ssl``).
    :param ssl_cert_reqs: Optional certificate requirement - ``none``, ``optional`` or
                          ``required``.
    :param ssl_ca_certs: Optional CA certificates file (implies ``ssl``).
    :param ssl_match_hostname: Forwarded to the connection only when SSL is enabled.

    :return: Object returned by ``mongoengine.connection.connect``.
    """
    LOG.info('Connecting to database "%s" @ "%s:%s" as user "%s".',
             db_name, db_host, db_port, str(username))

    # The SSL options need to be forwarded explicitly, otherwise the helper falls back to its
    # defaults and the values supplied by the caller are silently ignored.
    ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile,
                                 ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
                                 ssl_match_hostname=ssl_match_hostname)

    connection = mongoengine.connection.connect(db_name, host=db_host,
                                                port=db_port, tz_aware=True,
                                                username=username, password=password,
                                                **ssl_kwargs)

    # Create all the indexes upfront to prevent race-conditions caused by
    # lazy index creation
    if ensure_indexes:
        db_ensure_indexes()

    return connection
79
80
81
def db_ensure_indexes():
    """
    Make sure the indexes for every model class have been created.

    Note #1: A database connection already needs to be established when this function is
    called.

    Note #2: The call blocks until all the indexes have been created (indexes are created
    in real-time and not in background).
    """
    LOG.debug('Ensuring database indexes...')

    for model_cls in get_model_classes():
        LOG.debug('Ensuring indexes for model "%s"...' % (model_cls.__name__))
        model_cls.ensure_indexes()
97
98
99
def db_teardown():
    """Disconnect from the database through ``mongoengine.connection``."""
    connection_module = mongoengine.connection
    connection_module.disconnect()
101
102
103
def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
104
                    ssl_ca_certs=None, ssl_match_hostname=True):
105
    ssl_kwargs = {
106
        'ssl': ssl,
107
    }
108
    if ssl_keyfile:
109
        ssl_kwargs['ssl'] = True
110
        ssl_kwargs['ssl_keyfile'] = ssl_keyfile
111
    if ssl_certfile:
112
        ssl_kwargs['ssl'] = True
113
        ssl_kwargs['ssl_certfile'] = ssl_certfile
114
    if ssl_cert_reqs:
115
        if ssl_cert_reqs is 'none':
116
            ssl_cert_reqs = ssl_lib.CERT_NONE
117
        elif ssl_cert_reqs is 'optional':
118
            ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
119
        elif ssl_cert_reqs is 'required':
120
            ssl_cert_reqs = ssl_lib.CERT_REQUIRED
121
        ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
122
    if ssl_ca_certs:
123
        ssl_kwargs['ssl'] = True
124
        ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
125
    if ssl_kwargs.get('ssl', False):
126
        # pass in ssl_match_hostname only if ssl is True. The right default value
127
        # for ssl_match_hostname in almost all cases is True.
128
        ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
129
    return ssl_kwargs
130
131
132
class MongoDBAccess(object):
    """Database object access class that provides general functions for a model type."""

    # Types treated as strings when inspecting filter values: ``str`` on Python 3 and
    # ``str`` plus ``unicode`` on Python 2 (``type(u'')`` is the text type on both).
    _STRING_TYPES = (str, type(u''))

    def __init__(self, model):
        """
        :param model: Model class whose ``objects`` manager is used by all the methods.
        """
        self.model = model

    def get_by_name(self, value):
        """Return the instance with the provided name, raising ValueError if not found."""
        return self.get(name=value, raise_exception=True)

    def get_by_id(self, value):
        """Return the instance with the provided id, raising ValueError if not found."""
        return self.get(id=value, raise_exception=True)

    def get_by_ref(self, value):
        """Return the instance with the provided ref, raising ValueError if not found."""
        return self.get(ref=value, raise_exception=True)

    def get(self, exclude_fields=None, *args, **kwargs):
        """
        Retrieve the first instance which matches the provided filters.

        :param exclude_fields: Optional field names to exclude from the retrieved instance.
        :param raise_exception: Optional keyword argument (removed from the filters). When
                                True, ValueError is raised if no instance is found.
        :param kwargs: Filters passed to the model ``objects`` manager. Positional arguments
                       are accepted but not used.

        :return: First matching instance or ``None``.
        """
        raise_exception = kwargs.pop('raise_exception', False)

        instances = self.model.objects(**kwargs)

        if exclude_fields:
            instances = instances.exclude(*exclude_fields)

        instance = instances[0] if instances else None
        log_query_and_profile_data_for_queryset(queryset=instances)

        if not instance and raise_exception:
            raise ValueError('Unable to find the %s instance. %s' % (self.model.__name__, kwargs))
        return instance

    def get_all(self, *args, **kwargs):
        """Alias for :meth:`query`."""
        return self.query(*args, **kwargs)

    def count(self, *args, **kwargs):
        """
        Return the number of instances which match the provided filters.

        Positional arguments are accepted but not used.
        """
        result = self.model.objects(**kwargs).count()
        # NOTE(review): "result" is the value returned by count() and not the queryset itself
        # - confirm the profiling helper is meant to receive it.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def query(self, offset=0, limit=None, order_by=None, exclude_fields=None,
              **filters):
        """
        Retrieve the instances which match the provided filters.

        :param offset: Index of the first result to return.
        :param limit: Maximum number of results to return. A falsy value means no limit.
        :param order_by: Optional list of fields to order the results by.
        :param exclude_fields: Optional field names to exclude from the results.
        :param filters: Filters passed to the model ``objects`` manager once the datetime
                        range and null values have been processed.

        :return: Ordered and sliced query result.
        """
        order_by = order_by or []
        exclude_fields = exclude_fields or []
        eop = offset + int(limit) if limit else None

        # Process the filters
        # Note: Both of those functions manipulate "filters" variable so the order in which they
        # are called matters
        filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
        filters = self._process_null_filters(filters=filters)

        result = self.model.objects(**filters)

        if exclude_fields:
            result = result.exclude(*exclude_fields)

        result = result.order_by(*order_by)
        result = result[offset:eop]
        log_query_and_profile_data_for_queryset(queryset=result)

        return result

    def distinct(self, *args, **kwargs):
        """
        Return the distinct values of a field for the instances which match the filters.

        :param field: Required keyword argument (removed from the filters) with the name of
                      the field. KeyError is raised when it is missing.
        """
        field = kwargs.pop('field')
        result = self.model.objects(**kwargs).distinct(field)
        # NOTE(review): "result" is the value returned by distinct() - confirm the profiling
        # helper is meant to receive it.
        log_query_and_profile_data_for_queryset(queryset=result)
        return result

    def aggregate(self, *args, **kwargs):
        """Run ``aggregate`` on the collection which backs the model."""
        # NOTE(review): kwargs are used both as the query filters and as the keyword
        # arguments of aggregate() - confirm this is intended.
        return self.model.objects(**kwargs)._collection.aggregate(*args, **kwargs)

    def insert(self, instance):
        """Insert the instance and return it with the escaped dict fields converted back."""
        instance = self.model.objects.insert(instance)
        return self._undo_dict_field_escape(instance)

    def add_or_update(self, instance):
        """Save the instance and return it with the escaped dict fields converted back."""
        instance.save()
        return self._undo_dict_field_escape(instance)

    def update(self, instance, **kwargs):
        """Call ``update`` on the instance with the provided keyword arguments."""
        return instance.update(**kwargs)

    def delete(self, instance):
        """Call ``delete`` on the instance."""
        return instance.delete()

    def delete_by_query(self, **query):
        """
        Delete all the instances which match the provided query.

        :return: Always ``None``.
        """
        qs = self.model.objects.filter(**query)
        qs.delete()
        log_query_and_profile_data_for_queryset(queryset=qs)
        # mongoengine does not return anything useful so cannot return anything meaningful.
        return None

    def _undo_dict_field_escape(self, instance):
        """
        Replace the value of every ``EscapedDictField`` attribute on the instance with the
        result of the field's ``to_python`` and return the instance.
        """
        # items() (and not the Python 2 only iteritems()) so this works on Python 2 and 3
        for attr, field in instance._fields.items():
            if isinstance(field, stormbase.EscapedDictField):
                value = getattr(instance, attr)
                setattr(instance, attr, field.to_python(value))
        return instance

    def _process_null_filters(self, filters):
        """
        Translate filters with a null value into "field does not exist" filters.

        A filter is considered null when its value is ``None`` or the string "null"
        (case-insensitive). Each such ``<key>`` is replaced with ``<key>__exists = False``.

        :param filters: Filters to process. This dictionary is not modified.
        :type filters: ``dict``

        :return: Processed deep copy of the filters.
        :rtype: ``dict``
        """
        result = copy.deepcopy(filters)

        # Note: The lower-cased value is compared directly - coercing it with str() is not
        # needed and fails on Python 2 for unicode values with non-ascii characters.
        null_keys = [key for key, value in filters.items()
                     if value is None or
                     (isinstance(value, self._STRING_TYPES) and value.lower() == 'null')]

        for key in null_keys:
            result['%s__exists' % (key)] = False
            del result[key]

        return result

    def _process_datetime_range_filters(self, filters, order_by=None):
        """
        Translate "<datetime1>..<datetime2>" filter values into ``__gte`` / ``__lte`` filters.

        Both ends of the range are parsed using ``isotime.parse`` and the smaller one becomes
        the lower bound. The results are ordered by the field in ascending order when the
        first datetime is smaller than the second one and in descending order otherwise. An
        existing opposite ordering on the same field is replaced in place, otherwise the
        ordering is prepended to the list.

        :param filters: Filters to process. Note: This dictionary is modified in place.
        :type filters: ``dict``

        :param order_by: Optional list of fields to order by. This list is not modified.
        :type order_by: ``list``

        :return: Processed filters (same object as passed in) and the new order by list.
        :rtype: ``tuple``
        """
        # The ranges are collected into a separate dictionary so "filters" can be safely
        # modified while iterating.
        ranges = {k: v for k, v in filters.items()
                  if isinstance(v, self._STRING_TYPES) and '..' in v}

        order_by_list = copy.deepcopy(order_by) if order_by else []
        for k, v in ranges.items():
            values = v.split('..')
            dt1 = isotime.parse(values[0])
            dt2 = isotime.parse(values[1])

            k__gte = '%s__gte' % k
            k__lte = '%s__lte' % k
            if dt1 < dt2:
                query = {k__gte: dt1, k__lte: dt2}
                sort_key, reverse_sort_key = k, '-' + k
            else:
                query = {k__gte: dt2, k__lte: dt1}
                sort_key, reverse_sort_key = '-' + k, k
            del filters[k]
            filters.update(query)

            if reverse_sort_key in order_by_list:
                idx = order_by_list.index(reverse_sort_key)
                order_by_list.pop(idx)
                order_by_list.insert(idx, sort_key)
            elif sort_key not in order_by_list:
                order_by_list = [sort_key] + order_by_list

        return filters, order_by_list
271