Passed
Push — develop ( e1e109...2dd18e )
by Plexxi
06:53 queued 03:26
created

AuthHook.on_error()   D

Complexity

Conditions 8

Size

Total Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 8
dl 0
loc 23
rs 4.7619
c 0
b 0
f 0
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import httplib
17
import re
18
import traceback
19
import uuid
20
21
import webob
22
from oslo_config import cfg
23
from pecan.hooks import PecanHook
24
from six.moves.urllib import parse as urlparse
25
from webob import exc
26
27
from st2common import log as logging
28
from st2common.persistence.auth import User
29
from st2common.exceptions import db as db_exceptions
30
from st2common.exceptions import auth as auth_exceptions
31
from st2common.exceptions import rbac as rbac_exceptions
32
from st2common.exceptions.db import StackStormDBObjectNotFoundError
33
from st2common.exceptions.apivalidation import ValueValidationException
34
from st2common.util import auth as auth_utils
35
from st2common.util.jsonify import json_encode
36
from st2common.util.debugging import is_enabled as is_debugging_enabled
37
from st2common.constants.api import REQUEST_ID_HEADER
38
from st2common.constants.auth import HEADER_ATTRIBUTE_NAME
39
from st2common.constants.auth import QUERY_PARAM_ATTRIBUTE_NAME
40
from st2common.constants.auth import HEADER_API_KEY_ATTRIBUTE_NAME
41
from st2common.constants.auth import QUERY_PARAM_API_KEY_ATTRIBUTE_NAME
42
43
44
LOG = logging.getLogger(__name__)


# A list of method names for which we don't want to log the result / response
RESPONSE_LOGGING_METHOD_NAME_BLACKLIST = [
    'get_all'
]

# A list of controller classes for which we don't want to log the result / response.
# NOTE: every entry needs a trailing comma - without one Python silently
# concatenates two adjacent string literals into a single bogus entry.
RESPONSE_LOGGING_CONTROLLER_NAME_BLACKLIST = [
    'ActionExecutionChildrenController',  # action executions can be big
    'ActionExecutionAttributeController',  # result can be big
    'ActionExecutionsController',  # action executions can be big
    'FilesController',  # files controller returns files content
    'FileController',  # file controller returns binary file data
]

# Regex for the st2 auth tokens endpoint (i.e. /tokens or /v1/tokens).
# Raw string so that "\d" reaches the regex engine untouched.
AUTH_TOKENS_URL_REGEX = r'^(?:/tokens|/v\d+/tokens)$'
0 ignored issues
show
Bug introduced by
A suspicious escape sequence \d was found. Did you maybe forget to add an r prefix?

Escape sequences in Python string literals are interpreted according to rules similar to those of Standard C. Only strings prefixed with r or R (raw strings) leave backslashes uninterpreted, which is what regular expression patterns usually need.

The escape sequence that was used indicates that you might have intended to write a regular expression.

Learn more about the available escape sequences in the Python documentation.

Loading history...
62
63
64
class CorsHook(PecanHook):
    """
    Attach CORS headers to API responses and answer failed OPTIONS
    (preflight) requests with a default ``webob.Response()``.
    """

    def after(self, state):
        """
        Set the Access-Control-* headers on ``state.response`` and fill in
        Content-Length when it is missing (except for event streams).
        """
        request_origin = state.request.headers.get('Origin')

        # Operator-configured origins plus a few built-in defaults.
        allowed_origins = set(cfg.CONF.api.allow_origin)
        public_api_url = cfg.CONF.auth.api_url
        allowed_origins.add('http://127.0.0.1:3000')  # default gulp development server WebUI URL
        allowed_origins.add('http://localhost:8080')  # WebUI simple http server listens on 8080
        allowed_origins.add('http://127.0.0.1:8080')
        if public_api_url:
            allowed_origins.add(public_api_url)  # public API URL

        if not request_origin:
            # NOTE(review): set iteration order is arbitrary, so which origin is
            # advertised for requests without an Origin header is not well defined.
            origin_allowed = list(allowed_origins)[0]
        elif '*' in allowed_origins:
            origin_allowed = '*'
        elif request_origin in allowed_origins:
            origin_allowed = request_origin
        else:
            # See http://www.w3.org/TR/cors/#access-control-allow-origin-response-header
            origin_allowed = 'null'

        response_headers = state.response.headers
        response_headers['Access-Control-Allow-Origin'] = origin_allowed
        response_headers['Access-Control-Allow-Methods'] = ','.join(
            ('GET', 'POST', 'PUT', 'DELETE', 'OPTIONS'))
        response_headers['Access-Control-Allow-Headers'] = ','.join(
            ('Content-Type', 'Authorization', 'X-Auth-Token',
             HEADER_API_KEY_ATTRIBUTE_NAME, REQUEST_ID_HEADER))
        response_headers['Access-Control-Expose-Headers'] = ','.join(
            ('Content-Type', 'X-Limit', 'X-Total-Count', REQUEST_ID_HEADER))

        # Compute Content-Length ourselves unless it is already present or the
        # response is a server-sent event stream.
        if not (response_headers.get('Content-Length') or
                response_headers.get('Content-type', '').startswith('text/event-stream')):
            response_headers['Content-Length'] = str(len(state.response.body))

    def on_error(self, state, e):
        """
        Swallow errors raised while serving OPTIONS requests; leave all other
        errors (return None) to the remaining hooks.
        """
        if state.request.method != 'OPTIONS':
            return None
        return webob.Response()
112
113
114
class AuthHook(PecanHook):
    """
    Authenticate API requests using either an auth token or an API key,
    supplied in the request headers or query parameters.
    """

    # Ordered (exception types, log message) pairs that on_error() turns into a
    # 401 response. The first matching entry wins. A log message of None means
    # str(e) is logged via LOG.error (no traceback) instead of LOG.exception.
    _UNAUTHORIZED_ERRORS = (
        ((auth_exceptions.NoAuthSourceProvidedError,
          auth_exceptions.MultipleAuthSourcesError), None),
        (auth_exceptions.TokenNotProvidedError, 'Token is not provided.'),
        (auth_exceptions.TokenNotFoundError, 'Token is not found.'),
        (auth_exceptions.TokenExpiredError, 'Token has expired.'),
        (auth_exceptions.ApiKeyNotProvidedError, 'API key is not provided.'),
        (auth_exceptions.ApiKeyNotFoundError, 'API key is not found.'),
        (auth_exceptions.ApiKeyDisabledError, 'API key is disabled.'),
    )

    def before(self, state):
        """
        Authenticate the request and store the user in ``state.request.context['auth']``.

        OPTIONS requests and POST requests to the auth tokens endpoint are
        not authenticated here.
        """
        # OPTIONS requests don't need to be authenticated
        if state.request.method == 'OPTIONS':
            return

        # Token request is authenticated separately.
        if (state.request.method == 'POST' and
                re.search(AUTH_TOKENS_URL_REGEX, state.request.path)):
            return

        user_db = self._validate_creds_and_get_user(request=state.request)

        # Store related user object in the context. The token is not passed
        # along any longer as that should only be used in the auth domain.
        state.request.context['auth'] = {
            'user': user_db
        }

        # Drop the credential query parameters from the parsed keyword
        # arguments (presumably the controller kwargs) so they go no further.
        if QUERY_PARAM_ATTRIBUTE_NAME in state.arguments.keywords:
            del state.arguments.keywords[QUERY_PARAM_ATTRIBUTE_NAME]

        if QUERY_PARAM_API_KEY_ATTRIBUTE_NAME in state.arguments.keywords:
            del state.arguments.keywords[QUERY_PARAM_API_KEY_ATTRIBUTE_NAME]

    def on_error(self, state, e):
        """
        Convert authentication errors into a 401 JSON response.

        :return: A 401 ``webob.Response`` for the exception types listed in
                 ``_UNAUTHORIZED_ERRORS``; None for any other error so that
                 other hooks can handle it.
        """
        for exception_types, log_message in self._UNAUTHORIZED_ERRORS:
            if not isinstance(e, exception_types):
                continue

            if log_message is None:
                LOG.error(str(e))
            else:
                LOG.exception(log_message)
            return self._abort_unauthorized(str(e))

        return None

    @staticmethod
    def _abort_unauthorized(msg):
        """
        Build a 401 response with a JSON ``faultstring`` body.

        :param msg: Optional detail appended to "Unauthorized".
        """
        faultstring = 'Unauthorized - %s' % msg if msg else 'Unauthorized'
        body = json_encode({
            'faultstring': faultstring
        })
        headers = {}
        headers['Content-Type'] = 'application/json'
        status = httplib.UNAUTHORIZED

        return webob.Response(body=body, status=status, headers=headers)

    @staticmethod
    def _abort_other_errors():
        """
        Build a generic 500 response with a JSON ``faultstring`` body.
        """
        body = json_encode({
            'faultstring': 'Internal Server Error'
        })
        headers = {}
        headers['Content-Type'] = 'application/json'
        status = httplib.INTERNAL_SERVER_ERROR

        return webob.Response(body=body, status=status, headers=headers)

    @staticmethod
    def _validate_creds_and_get_user(request):
        """
        Validate one of token or api_key provided either in headers or query parameters.
        Will return the User.

        :param request: Incoming request; its headers and query string are inspected.

        :return: The matching user, or None when the credential has no user
                 attached or that user does not exist.
        :rtype: :class:`UserDB` or ``None``

        :raises MultipleAuthSourcesError: If both a token and an API key were supplied.
        :raises NoAuthSourceProvidedError: If neither a token nor an API key was supplied.

        Token / API key validation errors raised by ``auth_utils`` (presumably
        the Token*/ApiKey* errors handled in on_error) propagate unchanged.
        """
        headers = request.headers
        query_params = dict(urlparse.parse_qsl(request.query_string))

        token_in_headers = headers.get(HEADER_ATTRIBUTE_NAME, None)
        token_in_query_params = query_params.get(QUERY_PARAM_ATTRIBUTE_NAME, None)

        api_key_in_headers = headers.get(HEADER_API_KEY_ATTRIBUTE_NAME, None)
        api_key_in_query_params = query_params.get(QUERY_PARAM_API_KEY_ATTRIBUTE_NAME, None)

        if ((token_in_headers or token_in_query_params) and
                (api_key_in_headers or api_key_in_query_params)):
            raise auth_exceptions.MultipleAuthSourcesError(
                'Only one of Token or API key expected.')

        if token_in_headers or token_in_query_params:
            token_db = auth_utils.validate_token_and_source(
                token_in_headers=token_in_headers,
                token_in_query_params=token_in_query_params)
            user = token_db.user
        elif api_key_in_headers or api_key_in_query_params:
            api_key_db = auth_utils.validate_api_key_and_source(
                api_key_in_headers=api_key_in_headers,
                api_key_query_params=api_key_in_query_params)
            user = api_key_db.user
        else:
            raise auth_exceptions.NoAuthSourceProvidedError('One of Token or API key required.')

        if not user:
            LOG.warning('User not found for supplied token or api-key.')
            return None

        try:
            return User.get(user)
        except StackStormDBObjectNotFoundError:
            # User doesn't exist - we should probably also invalidate token/apikey if
            # this happens.
            LOG.warning('User %s not found.', user)
            return None
237
238
239
class JSONErrorResponseHook(PecanHook):
    """
    Handle all the errors and respond with JSON.
    """

    def on_error(self, state, e):
        """
        Build a JSON error response for exception ``e``.

        The exception type selects the HTTP status code and the ``faultstring``
        sent to the client. Errors mapped to 500 are logged at ERROR level with
        a traceback; all other errors are logged at DEBUG level.

        :return: ``webob.Response`` whose JSON body contains ``faultstring``
                 (plus ``conflict-id`` for DB conflicts and any keys from a
                 dict ``e.body``).
        """
        # Reuse the exception's own body when it carries a dict one.
        if hasattr(e, 'body') and isinstance(e.body, dict):
            body = e.body
        else:
            body = {}

        if isinstance(e, exc.HTTPException):
            # NOTE(review): webob's Response.status is a status *string*
            # (e.g. '404 Not Found'), so the INTERNAL_SERVER_ERROR comparison
            # below never matches on this branch - confirm this is intended.
            status_code = state.response.status
            message = str(e)
        elif isinstance(e, db_exceptions.StackStormDBObjectNotFoundError):
            status_code = httplib.NOT_FOUND
            message = str(e)
        elif isinstance(e, db_exceptions.StackStormDBObjectConflictError):
            status_code = httplib.CONFLICT
            message = str(e)
            body['conflict-id'] = e.conflict_id
        elif isinstance(e, rbac_exceptions.AccessDeniedError):
            status_code = httplib.FORBIDDEN
            message = str(e)
        elif isinstance(e, (ValueValidationException, ValueError)):
            status_code = httplib.BAD_REQUEST
            message = getattr(e, 'message', str(e))
        else:
            # Never leak internal error details to the client.
            status_code = httplib.INTERNAL_SERVER_ERROR
            message = 'Internal Server Error'

        # Log the error
        is_internal_server_error = status_code == httplib.INTERNAL_SERVER_ERROR
        error_msg = getattr(e, 'comment', str(e))
        extra = {
            'exception_class': e.__class__.__name__,
            'exception_message': str(e),
            'exception_data': e.__dict__
        }

        if is_internal_server_error:
            # LOG.exception() already attaches the active exception's traceback
            # (this assumes on_error runs while the exception is being handled,
            # which traceback.format_exc() below relies on as well), so no
            # separate traceback record is emitted here.
            LOG.exception('API call failed: %s', error_msg, extra=extra)
        else:
            LOG.debug('API call failed: %s', error_msg, extra=extra)

            if is_debugging_enabled():
                LOG.debug(traceback.format_exc())

        body['faultstring'] = message

        response_body = json_encode(body)
        headers = state.response.headers or {}

        headers['Content-Type'] = 'application/json'
        headers['Content-Length'] = str(len(response_body))

        return webob.Response(response_body, status=status_code, headers=headers)
297
298
299
class LoggingHook(PecanHook):
    """
    Logs all incoming requests and outgoing responses.

    Messages are %-formatted before being handed to the logger, and the same
    values dict is passed as ``extra`` so each field is also set on the log record.
    """

    def before(self, state):
        """
        Log the incoming request together with its parsed keyword arguments.
        """
        request = state.request
        # Note: getattr() is used because the request object is mocked in
        # some places (tests) and may lack these attributes.
        values = {
            'method': getattr(request, 'method', None),
            'path': getattr(request, 'path', None),
            'remote_addr': getattr(request, 'remote_addr', None),
            'filters': state.arguments.keywords,
            'request_id': request.headers.get(REQUEST_ID_HEADER, None),
        }

        message = '%(request_id)s -  %(method)s %(path)s with filters=%(filters)s' % values
        LOG.info(message, extra=values)

    def after(self, state):
        """
        Log the outgoing response; include the response body unless the
        controller or method is blacklisted.
        """
        request = state.request
        # Note: getattr() is used because the request object is mocked in
        # some places (tests) and may lack these attributes.
        method = getattr(request, 'method', None)
        path = getattr(request, 'path', None)
        remote_addr = getattr(request, 'remote_addr', None)
        request_id = request.headers.get(REQUEST_ID_HEADER, None)

        values = {
            'method': method,
            'path': path,
            'remote_addr': remote_addr,
            'status_code': state.response.status,
            'request_id': request_id,
        }

        # Only (Python 2) method objects expose im_self / im_func / im_class;
        # for any other controller the result is never logged.
        controller = state.controller
        log_result = False
        if hasattr(controller, 'im_self'):
            function_name = controller.im_func.__name__
            controller_name = controller.im_class.__name__
            log_result = (function_name not in RESPONSE_LOGGING_METHOD_NAME_BLACKLIST and
                          controller_name not in RESPONSE_LOGGING_CONTROLLER_NAME_BLACKLIST)

        if log_result:
            values['result'] = state.response.body
            template = '%(request_id)s - %(method)s %(path)s result=%(result)s'
        else:
            # Note: We don't want to include a result for some
            # methods which have a large result
            template = '%(request_id)s - %(method)s %(path)s'

        LOG.info(template % values, extra=values)
351
352
353
class RequestIDHook(PecanHook):
    """
    If the request id header isn't present, this hook adds one; the request
    id is then copied onto the response.
    """

    def before(self, state):
        """
        Generate a random (uuid4) request id for requests that don't carry one.
        """
        headers = getattr(state.request, 'headers', None)

        if headers is not None:
            # Headers must be queried as a mapping: getattr(headers, REQUEST_ID_HEADER)
            # never finds the header and would overwrite a client-supplied id.
            req_id_header = headers.get(REQUEST_ID_HEADER, None)

            if not req_id_header:
                req_id = str(uuid.uuid4())
                state.request.headers[REQUEST_ID_HEADER] = req_id

    def after(self, state):
        """
        Copy the request id header from the request onto the response.
        """
        req_headers = getattr(state.request, 'headers', None)
        resp_headers = getattr(state.response, 'headers', None)

        if req_headers is not None and resp_headers is not None:
            req_id_header = req_headers.get(REQUEST_ID_HEADER, None)
            if req_id_header:
                resp_headers[REQUEST_ID_HEADER] = req_id_header
376