Passed
Push — develop ( f534b1...a82689 )
by Plexxi
06:09 queued 03:13
created

extend_with_default()   B

Complexity

Conditions 5

Size

Total Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
c 0
b 0
f 0
dl 0
loc 15
rs 8.5454

1 Method

Rating   Name   Duplication   Size   Complexity  
A set_defaults() 0 9 4
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import copy
17
import functools
18
import re
19
import six
20
import sys
21
import traceback
22
23
import jsonschema
24
from oslo_config import cfg
25
import routes
26
from six.moves.urllib import parse as urlparse  # pylint: disable=import-error
27
from swagger_spec_validator.validator20 import validate_spec
28
import webob
29
from webob import exc, Request
30
31
from st2common.exceptions import rbac as rbac_exc
32
from st2common.exceptions import auth as auth_exc
33
from st2common import log as logging
34
from st2common.persistence.auth import User
35
from st2common.rbac import resolvers
36
from st2common.util.jsonify import json_encode
37
from st2common.util.http import parse_content_type_header
38
39
40
LOG = logging.getLogger(__name__)
41
42
43
def op_resolver(op_id):
    """
    Resolve an operation id of the form ``<module>:<attribute path>`` to the object it names.

    The part before the first colon is imported as a module. The part after it is a dotted
    attribute path which is looked up on that module one attribute at a time.

    :param op_id: Operation id, for example ``package.module:Class.method``.
    :type op_id: ``str``

    :raises ValueError: If ``op_id`` does not contain a colon separator.
    :raises ImportError: If the module part can't be imported.
    :raises AttributeError: If the attribute path can't be resolved on the module.

    :return: The object the attribute path points to.
    """
    parts = op_id.split(':', 1)

    if len(parts) != 2:
        # Fail with a message which names the offending id instead of a bare unpacking error.
        raise ValueError('Invalid operation id "%s", expected format is '
                         '"<module>:<attribute path>"' % (op_id))

    module_name, func_name = parts

    # __import__ returns the top-level package for dotted names, so the actual module is
    # fetched from sys.modules after the import has registered it there.
    __import__(module_name)
    module = sys.modules[module_name]

    return functools.reduce(getattr, func_name.split('.'), module)
48
49
50
def abort(status_code=exc.HTTPInternalServerError.code, message='Unhandled exception'):
    """
    Raise the webob HTTP exception which ``exc.status_map`` registers for ``status_code``.

    :param status_code: HTTP status code used to look up the exception class.
    :param message: Message passed to the exception constructor.
    """
    error_cls = exc.status_map[status_code]
    raise error_cls(message)
52
53
54
def abort_unauthorized(msg=None):
    """
    Raise ``HTTPUnauthorized``, appending ``msg`` to the error text when one is given.

    :param msg: Optional detail describing why the request is unauthorized.
    """
    if msg:
        detail = 'Unauthorized - %s' % msg
    else:
        detail = 'Unauthorized'

    raise exc.HTTPUnauthorized(detail)
56
57
58
def extend_with_default(validator_class):
    """
    Return a validator class which fills in schema defaults while validating "properties".

    For every property whose subschema declares a "default", the value is set on the instance
    (only if the property is not already present) before the original "properties" validator
    of ``validator_class`` runs. Note that this mutates the instance being validated.

    :param validator_class: jsonschema validator class to extend.

    :return: New validator class with the "properties" keyword handler replaced.
    """
    validate_properties = validator_class.VALIDATORS["properties"]

    def set_defaults(validator, properties, instance, schema):
        # Defaults can only be applied to JSON objects. For any other instance (list, string,
        # null, ...) calling setdefault() would raise AttributeError instead of letting the
        # regular keyword validators report a proper validation error.
        if validator.is_type(instance, "object"):
            for prop_name, subschema in six.iteritems(properties):
                if "default" in subschema:
                    instance.setdefault(prop_name, subschema["default"])

        for error in validate_properties(
            validator, properties, instance, schema,
        ):
            yield error

    return jsonschema.validators.extend(
        validator_class, {"properties": set_defaults},
    )
74
75
76
def extend_with_additional_check(validator_class):
    """
    Return a validator class which understands the "x-additional-check" schema keyword.

    The keyword value is read from the schema and resolved with ``op_resolver``. The resolved
    callable is invoked with the same four arguments the keyword handler receives and every
    error it produces is yielded.

    :param validator_class: jsonschema validator class to extend.

    :return: New validator class with the "x-additional-check" keyword handler registered.
    """
    def set_additional_check(validator, properties, instance, schema):
        check_func = op_resolver(schema.get("x-additional-check"))

        for error in check_func(validator, properties, instance, schema):
            yield error

    extra_validators = {"x-additional-check": set_additional_check}

    return jsonschema.validators.extend(validator_class, extra_validators)
86
87
88
def extend_with_nullable(validator_class):
    """
    Return a validator class whose "type" check accepts ``None`` for "x-nullable" schemas.

    When the schema sets a truthy "x-nullable" and the instance is ``None``, the "type" keyword
    produces no errors. In every other case the original "type" validator of
    ``validator_class`` is used.

    :param validator_class: jsonschema validator class to extend.

    :return: New validator class with the "type" keyword handler replaced.
    """
    base_type_check = validator_class.VALIDATORS["type"]

    def set_type_draft4(validator, types, instance, schema):
        null_allowed = instance is None and schema.get("x-nullable", False)

        if not null_allowed:
            for error in base_type_check(validator, types, instance, schema):
                yield error

    type_validators = {"type": set_type_draft4}

    return jsonschema.validators.extend(validator_class, type_validators)
103
104
105
# Draft 4 validator extended, in this order, with "x-nullable" handling, the
# "x-additional-check" keyword and the filling in of property defaults.
CustomValidator = extend_with_default(
    extend_with_additional_check(
        extend_with_nullable(jsonschema.Draft4Validator)
    )
)
109
110
111
class NotFoundException(Exception):
    """
    Raised when no registered route matches the path of a request.
    """
113
114
115
class Response(webob.Response):
    """
    webob response which serializes JSON bodies with ``json_encode``.
    """

    def __init__(self, body=None, status=None, headerlist=None, app_iter=None, content_type=None,
                 *args, **kwargs):
        # A JSON payload passed as a keyword argument is only turned into the body when neither
        # a raw body nor an app_iter was supplied; "json_body" wins over "json" if both are given.
        if body is None and app_iter is None:
            json_key = next((key for key in ('json_body', 'json') if key in kwargs), None)

            if json_key is not None:
                body = json_encode(kwargs.pop(json_key)).encode('UTF-8')

                if content_type is None:
                    content_type = 'application/json'

        super(Response, self).__init__(body, status, headerlist, app_iter, content_type,
                                       *args, **kwargs)

    def _json_body__get(self):
        # Reading is delegated to webob unchanged.
        return super(Response, self)._json_body__get()

    def _json_body__set(self, value):
        # Writing goes through json_encode instead of webob's default serializer.
        encoded = json_encode(value)
        self.body = encoded.encode('UTF-8')

    def _json_body__del(self):
        # Deleting is delegated to webob unchanged.
        return super(Response, self)._json_body__del()

    json_body = property(_json_body__get, _json_body__set, _json_body__del)
    json = json_body
142
143
144
class Router(object):
    """
    Dispatches webob requests to controller functions described by an OpenAPI spec.

    The spec is registered with ``add_spec``. Afterwards the instance can be called with a
    request (``router(req)``) or used through ``as_wsgi``.
    """

    def __init__(self, arguments=None, debug=False, auth=True):
        """
        :param arguments: Optional dictionary stored on the instance as ``arguments``.
        :param debug: Flag stored on the instance as ``debug``.
        :param auth: When false, the security section of the spec is not enforced.
        """
        self.debug = debug
        self.auth = auth

        self.arguments = arguments or {}

        self.spec = {}
        self.spec_resolver = None
        self.routes = routes.Mapper()

    def add_spec(self, spec, transforms):
        """
        Validate the spec and register a route for every matching path / method pair.

        :param spec: Parsed OpenAPI spec dictionary.
        :param transforms: Mapping of path regular expression -> list of replacement strings.
                           Each spec path matching the expression is registered once per
                           replacement, using ``re.sub`` to build the route path.
        """
        info = spec.get('info', {})
        LOG.debug('Adding API: %s %s', info.get('title', 'untitled'), info.get('version', '0.0.0'))

        self.spec = spec
        # validate_spec is given a copy so the spec stored on the instance is left untouched.
        self.spec_resolver = validate_spec(copy.deepcopy(self.spec))

        for path_filter in transforms:
            for (path, methods) in six.iteritems(spec['paths']):
                if not re.search(path_filter, path):
                    continue

                for (method, endpoint) in six.iteritems(methods):
                    conditions = {
                        'method': [method.upper()]
                    }

                    connect_kw = {}
                    if 'x-requirements' in endpoint:
                        connect_kw['requirements'] = endpoint['x-requirements']

                    # The original spec path and method are stored on the route so match() can
                    # find the endpoint definition again.
                    m = self.routes.submapper(_api_path=path, _api_method=method,
                                              conditions=conditions)
                    for transform in transforms[path_filter]:
                        m.connect(None, re.sub(path_filter, transform, path), **connect_kw)

        for route in sorted(self.routes.matchlist, key=lambda r: r.routepath):
            LOG.debug('Route registered: %+6s %s', route.conditions['method'][0], route.routepath)

    def match(self, req):
        """
        Find the endpoint definition for the request.

        :param req: Request object providing ``path`` and ``environ``.

        :raises NotFoundException: If no registered route matches the request path.

        :return: Tuple of (endpoint definition from the spec, path variables).
        """
        path = req.path

        # A single trailing slash is ignored, except for the root path.
        if len(path) > 1 and path.endswith('/'):
            path = path[:-1]

        match = self.routes.match(path, req.environ)

        if match is None:
            raise NotFoundException('No route matches "%s" path' % req.path)

        # To account for situation when match may return multiple values
        try:
            path_vars = match[0]
        except KeyError:
            path_vars = match

        path = path_vars.pop('_api_path')
        method = path_vars.pop('_api_method')
        endpoint = self.spec['paths'][path][method]

        return endpoint, path_vars

    def __call__(self, req):
        """
        The method is invoked on every request and shows the lifecycle of the request received from
        the middleware.

        Although some middleware may use parts of the API spec, it is safe to assume that if you're
        looking for the particular spec property handler, it's most likely a part of this method.

        At the time of writing, the only property being utilized by middleware was `x-log-result`.

        The steps are: match the route, authenticate, check RBAC permissions, collect and cast
        the parameters, call the controller and validate the response against the spec.

        :param req: Request object.

        :return: Response object (anything callable returned by the controller is passed through
                 as is, other values are wrapped in a JSON ``Response``).
        """
        endpoint, path_vars = self.match(req)

        context = copy.copy(getattr(self, 'mock_context', {}))

        # Handle security
        if 'security' in endpoint:
            security = endpoint.get('security')
        else:
            security = self.spec.get('security', [])

        if self.auth and security:
            self._authenticate(req, security, context)

        if cfg.CONF.rbac.enable:
            # NOTE(review): raises KeyError when nothing put a user into the context (auth
            # disabled or endpoint without security) - confirm this is intended.
            user_db = context['user']

            permission_type = endpoint.get('x-permissions', None)
            if permission_type:
                resolver = resolvers.get_resolver_for_permission_type(permission_type)
                has_permission = resolver.user_has_permission(user_db, permission_type)

                if not has_permission:
                    raise rbac_exc.ResourceTypeAccessDeniedError(user_db,
                                                                 permission_type)

        # Collect parameters
        kw = {}
        for param in endpoint.get('parameters', []) + endpoint.get('x-parameters', []):
            name = param['name']
            argument_name = param.get('x-as', None) or name
            source = param['in']
            default = param.get('default', None)

            # Collecting params from different sources
            if source == 'query':
                kw[argument_name] = req.GET.get(name, default)
            elif source == 'path':
                kw[argument_name] = path_vars[name]
            elif source == 'header':
                kw[argument_name] = req.headers.get(name, default)
            elif source == 'formData':
                kw[argument_name] = req.POST.get(name, default)
            elif source == 'environ':
                kw[argument_name] = req.environ.get(name.upper(), default)
            elif source == 'context':
                kw[argument_name] = context.get(name, default)
            elif source == 'request':
                kw[argument_name] = getattr(req, name)
            elif source == 'body':
                if req.body:
                    kw[argument_name] = self._parse_body(req, param)
                else:
                    kw[argument_name] = None

            # Making sure all required params are present
            required = param.get('required', False)
            if required and kw[argument_name] is None:
                detail = 'Required parameter "%s" is missing' % name
                raise exc.HTTPBadRequest(detail=detail)

            # Validating and casting param types
            param_type = param.get('type', None)
            if kw[argument_name] is not None:
                kw[argument_name] = self._cast_value(kw[argument_name], param_type,
                                                     argument_name)

        # Call the controller
        func = op_resolver(endpoint['operationId'])
        resp = func(**kw)

        # Handle response
        if resp is None:
            resp = Response()

        if not hasattr(resp, '__call__'):
            resp = Response(json=resp)

        responses = endpoint.get('responses', {})
        response_spec = responses.get(str(resp.status_code), responses.get('default', None))

        if response_spec and 'schema' in response_spec:
            try:
                validator = CustomValidator(response_spec['schema'], resolver=self.spec_resolver)
                validator.validate(resp.json)
            except (jsonschema.ValidationError, ValueError):
                # A response which doesn't match the spec is still returned, only flagged.
                LOG.exception('Response validation failed.')
                resp.headers.add('Warning', '199 OpenAPI "Response validation failed"')

        return resp

    def _authenticate(self, req, security, context):
        """
        Authenticate the request and store the resulting user in ``context['user']``.

        Only security definitions of type "apiKey" are handled; the credential is read from
        the header or query parameter the definition names. Any authentication failure is
        turned into an HTTP 401 through ``abort_unauthorized``.

        :param req: Request object.
        :param security: List of security requirement statements which apply to the endpoint.
        :param context: Request context dictionary, updated in place.
        """
        try:
            auth_resp = None
            security_definitions = self.spec.get('securityDefinitions', {})
            for statement in security:
                declaration, options = statement.copy().popitem()
                definition = security_definitions[declaration]

                if definition['type'] == 'apiKey':
                    if definition['in'] == 'header':
                        token = req.headers.get(definition['name'])
                    elif definition['in'] == 'query':
                        token = req.GET.get(definition['name'])
                    else:
                        token = None

                    if token:
                        # A second credential on the same request is rejected.
                        if auth_resp:
                            raise auth_exc.MultipleAuthSourcesError(
                                'Only one of Token or API key expected.')

                        auth_func = op_resolver(definition['x-operationId'])
                        auth_resp = auth_func(token)

                        context['user'] = User.get_by_name(auth_resp.user)

            if 'user' not in context:
                raise auth_exc.NoAuthSourceProvidedError('One of Token or API key required.')
        except (auth_exc.NoAuthSourceProvidedError,
                auth_exc.MultipleAuthSourcesError) as e:
            LOG.error(str(e))
            return abort_unauthorized(str(e))
        except auth_exc.TokenNotProvidedError as e:
            LOG.exception('Token is not provided.')
            return abort_unauthorized(str(e))
        except auth_exc.TokenNotFoundError as e:
            LOG.exception('Token is not found.')
            return abort_unauthorized(str(e))
        except auth_exc.TokenExpiredError as e:
            LOG.exception('Token has expired.')
            return abort_unauthorized(str(e))
        except auth_exc.ApiKeyNotProvidedError as e:
            LOG.exception('API key is not provided.')
            return abort_unauthorized(str(e))
        except auth_exc.ApiKeyNotFoundError as e:
            LOG.exception('API key is not found.')
            return abort_unauthorized(str(e))
        except auth_exc.ApiKeyDisabledError as e:
            LOG.exception('API key is disabled.')
            return abort_unauthorized(str(e))

    def _parse_body(self, req, param):
        """
        Parse and validate a non-empty request body for a parameter with ``in: body``.

        :param req: Request object.
        :param param: Parameter definition from the spec; its "schema" is used for validation.

        :raises webob.exc.HTTPBadRequest: If the body can't be parsed or fails validation.

        :return: The raw body for "text/plain", otherwise an instance of the class named by
                 the schema's "x-api-model" (or a plain attribute holder) built from the data.
        """
        content_type = req.headers.get('Content-Type', 'application/json')
        content_type = parse_content_type_header(content_type=content_type)[0]
        schema = param['schema']

        try:
            if content_type == 'application/json':
                data = req.json
            elif content_type == 'text/plain':
                data = req.body
            elif content_type in ['application/x-www-form-urlencoded',
                                  'multipart/form-data']:
                data = urlparse.parse_qs(req.body)
            else:
                raise ValueError('Unsupported Content-Type: "%s"' % (content_type))
        except Exception as e:
            detail = 'Failed to parse request body: %s' % str(e)
            raise exc.HTTPBadRequest(detail=detail)

        try:
            CustomValidator(schema, resolver=self.spec_resolver).validate(data)
        except (jsonschema.ValidationError, ValueError) as e:
            # ValueError has no "message" attribute on Python 3, so fall back to str(e).
            raise exc.HTTPBadRequest(detail=getattr(e, 'message', str(e)),
                                     comment=traceback.format_exc())

        if content_type == 'text/plain':
            return data

        class Body(object):
            def __init__(self, **entries):
                self.__dict__.update(entries)

        # "x-api-model" is looked up on the referenced definition when the schema is a $ref.
        ref = schema.get('$ref', None)
        if ref:
            with self.spec_resolver.resolving(ref) as resolved:
                schema = resolved

        if 'x-api-model' in schema:
            Model = op_resolver(schema['x-api-model'])
        else:
            Model = Body

        return Model(**data)

    @staticmethod
    def _cast_value(value, param_type, argument_name):
        """
        Validate a parameter value against its declared spec type and cast it.

        :param value: Parameter value (not ``None``).
        :param param_type: Value of the parameter's "type" property, or ``None``.
        :param argument_name: Name used in the error message.

        :raises webob.exc.HTTPBadRequest: If the value doesn't match the declared type.

        :return: ``bool`` / ``int`` / ``float`` for "boolean" / "integer" / "number" types,
                 the value unchanged for any other type.
        """
        if param_type == 'boolean':
            positive = ('true', '1', 'yes', 'y')
            negative = ('false', '0', 'no', 'n')

            if str(value).lower() not in positive + negative:
                detail = 'Parameter "%s" is not of type boolean' % argument_name
                raise exc.HTTPBadRequest(detail=detail)

            return str(value).lower() in positive
        elif param_type == 'integer':
            regex = r'^-?[0-9]+$'

            if not re.search(regex, str(value)):
                detail = 'Parameter "%s" is not of type integer' % argument_name
                raise exc.HTTPBadRequest(detail=detail)

            return int(value)
        elif param_type == 'number':
            regex = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$'

            if not re.search(regex, str(value)):
                detail = 'Parameter "%s" is not of type float' % argument_name
                raise exc.HTTPBadRequest(detail=detail)

            return float(value)

        return value

    def as_wsgi(self, environ, start_response):
        """
        Converts WSGI request to webob.Request and initiates the response returned by controller.
        """
        req = Request(environ)
        resp = self(req)
        return resp(environ, start_response)
426