Completed
Pull Request — master (#2744)
by W
10:48 queued 04:22
created

AuthHook.on_error()   D

Complexity

Conditions 8

Size

Total Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 8
dl 0
loc 23
rs 4.7619
c 0
b 0
f 0
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import httplib
17
import re
18
import traceback
19
import uuid
20
21
import webob
22
from oslo_config import cfg
23
from pecan.hooks import PecanHook
24
from six.moves.urllib import parse as urlparse
25
from webob import exc
26
27
from st2common import log as logging
28
from st2common.persistence.auth import User
29
from st2common.exceptions import db as db_exceptions
30
from st2common.exceptions import auth as auth_exceptions
31
from st2common.exceptions import rbac as rbac_exceptions
32
from st2common.exceptions.db import StackStormDBObjectNotFoundError
33
from st2common.exceptions.apivalidation import ValueValidationException
34
from st2common.util.jsonify import json_encode
35
from st2common.util.auth import validate_token, validate_api_key
36
from st2common.util.debugging import is_enabled as is_debugging_enabled
37
from st2common.constants.api import REQUEST_ID_HEADER
38
from st2common.constants.auth import HEADER_ATTRIBUTE_NAME
39
from st2common.constants.auth import QUERY_PARAM_ATTRIBUTE_NAME
40
from st2common.constants.auth import HEADER_API_KEY_ATTRIBUTE_NAME
41
from st2common.constants.auth import QUERY_PARAM_API_KEY_ATTRIBUTE_NAME
42
43
44
LOG = logging.getLogger(__name__)


# A list of method names for which we don't want to log the result / response
RESPONSE_LOGGING_METHOD_NAME_BLACKLIST = [
    'get_all',
]

# A list of controller classes for which we don't want to log the result / response.
# NOTE: every entry needs its own trailing comma. A comma that ends up inside a
# comment makes Python silently concatenate two adjacent string literals into one
# bogus controller name, which un-blacklists both controllers.
RESPONSE_LOGGING_CONTROLLER_NAME_BLACKLIST = [
    'ActionExecutionChildrenController',  # action executions can be big
    'ActionExecutionAttributeController',  # result can be big
    'ActionExecutionsController',  # action executions can be big
    'FilesController',  # files controller returns files content
    'FileController',  # file controller returns binary file data
]
59
60
61
class CorsHook(PecanHook):
    """
    Add CORS headers to every API response.
    """

    # Origins which are always allowed in addition to the configured ones.
    _DEFAULT_ORIGINS = (
        # Default gulp development server WebUI URL
        'http://127.0.0.1:3000',
        # By default WebUI simple http server listens on 8080
        'http://localhost:8080',
        'http://127.0.0.1:8080',
    )

    @classmethod
    def _get_allowed_origins(cls):
        """
        Return the allowed origins as a de-duplicated list in a stable order.

        Order: configured ``api.allow_origin`` values, then the built-in defaults,
        then ``auth.api_url`` (when set). A list is used instead of a set so that
        the origin advertised for requests without an ``Origin`` header is
        deterministic rather than dependent on set iteration order.

        :rtype: ``list`` of ``str``
        """
        candidates = list(cfg.CONF.api.allow_origin)
        candidates.extend(cls._DEFAULT_ORIGINS)

        public_api_url = cfg.CONF.auth.api_url
        if public_api_url:
            # Public API URL
            candidates.append(public_api_url)

        origins = []
        for candidate in candidates:
            if candidate not in origins:
                origins.append(candidate)
        return origins

    def after(self, state):
        """
        Set the CORS response headers and, for non-streaming responses without
        one, a ``Content-Length`` header.
        """
        headers = state.response.headers

        origin = state.request.headers.get('Origin')
        origins = self._get_allowed_origins()

        if origin:
            if '*' in origins:
                origin_allowed = '*'
            else:
                # See http://www.w3.org/TR/cors/#access-control-allow-origin-response-header
                origin_allowed = origin if origin in origins else 'null'
        else:
            # No Origin header sent - advertise the first allowed origin. The list
            # always contains the defaults, so it is never empty.
            origin_allowed = origins[0]

        methods_allowed = ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS']
        request_headers_allowed = ['Content-Type', 'Authorization', 'X-Auth-Token',
                                   HEADER_API_KEY_ATTRIBUTE_NAME, REQUEST_ID_HEADER]
        response_headers_allowed = ['Content-Type', 'X-Limit', 'X-Total-Count',
                                    REQUEST_ID_HEADER]

        headers['Access-Control-Allow-Origin'] = origin_allowed
        headers['Access-Control-Allow-Methods'] = ','.join(methods_allowed)
        headers['Access-Control-Allow-Headers'] = ','.join(request_headers_allowed)
        headers['Access-Control-Expose-Headers'] = ','.join(response_headers_allowed)

        # Event streams have no known length, so Content-Length is only filled in
        # for other responses which don't already carry one.
        if not headers.get('Content-Length') \
                and not headers.get('Content-type', '').startswith('text/event-stream'):
            headers['Content-Length'] = str(len(state.response.body))

    def on_error(self, state, e):
        """
        Answer failing OPTIONS requests with an empty default ``webob.Response``.
        Returns ``None`` for any other request method.
        """
        if state.request.method == 'OPTIONS':
            return webob.Response()
109
110
111
class AuthHook(PecanHook):
    """
    Authenticate incoming API requests using either an auth token or an API key.

    The resolved user is stored in ``state.request.context['auth']['user']``.
    """

    # Versioned token endpoint, e.g. "/v1/tokens". A raw string is used so "\d" is
    # the regex digit class rather than an invalid string escape sequence.
    _VERSIONED_TOKENS_PATH_RE = re.compile(r'/v\d+/tokens$')

    # Ordered (exception class, log message) pairs for authentication failures
    # which on_error() logs with LOG.exception and turns into a 401 response.
    _UNAUTHORIZED_ERRORS = (
        (auth_exceptions.TokenNotProvidedError, 'Token is not provided.'),
        (auth_exceptions.TokenNotFoundError, 'Token is not found.'),
        (auth_exceptions.TokenExpiredError, 'Token has expired.'),
        (auth_exceptions.ApiKeyNotProvidedError, 'API key is not provided.'),
        (auth_exceptions.ApiKeyNotFoundError, 'API key is not found.'),
        (auth_exceptions.ApiKeyDisabledError, 'API key is disabled.'),
    )

    def before(self, state):
        request = state.request

        # OPTIONS requests doesn't need to be authenticated
        if request.method == 'OPTIONS':
            return

        # Token request is authenticated separately.
        if request.method == 'POST' and (request.path == '/tokens' or
                                         self._VERSIONED_TOKENS_PATH_RE.search(request.path)):
            return

        user_db = self._validate_creds_and_get_user(request=request)

        # Store related user object in the context. The token is not passed
        # along any longer as that should only be used in the auth domain.
        request.context['auth'] = {
            'user': user_db
        }

        # Strip the credential query parameters from the controller keyword arguments.
        if QUERY_PARAM_ATTRIBUTE_NAME in state.arguments.keywords:
            del state.arguments.keywords[QUERY_PARAM_ATTRIBUTE_NAME]

        if QUERY_PARAM_API_KEY_ATTRIBUTE_NAME in state.arguments.keywords:
            del state.arguments.keywords[QUERY_PARAM_API_KEY_ATTRIBUTE_NAME]

    def on_error(self, state, e):
        """
        Convert authentication errors into a ``401 Unauthorized`` JSON response.

        :return: ``webob.Response`` for recognised authentication errors, ``None``
                 for any other exception.
        """
        if isinstance(e, (auth_exceptions.NoAuthSourceProvidedError,
                          auth_exceptions.MultipleAuthSourcesError)):
            # Missing / ambiguous credentials: log only the message, no traceback.
            LOG.error(str(e))
            return self._abort_unauthorized(str(e))

        for exception_class, log_message in self._UNAUTHORIZED_ERRORS:
            if isinstance(e, exception_class):
                LOG.exception(log_message)
                return self._abort_unauthorized(str(e))

        return None

    @staticmethod
    def _abort_unauthorized(msg):
        """
        Build a 401 JSON response whose ``faultstring`` includes ``msg`` when given.
        """
        faultstring = 'Unauthorized - %s' % msg if msg else 'Unauthorized'
        body = json_encode({
            'faultstring': faultstring
        })
        headers = {}
        headers['Content-Type'] = 'application/json'
        status = httplib.UNAUTHORIZED

        return webob.Response(body=body, status=status, headers=headers)

    @staticmethod
    def _abort_other_errors():
        """
        Build a generic 500 JSON response.
        """
        body = json_encode({
            'faultstring': 'Internal Server Error'
        })
        headers = {}
        headers['Content-Type'] = 'application/json'
        status = httplib.INTERNAL_SERVER_ERROR

        return webob.Response(body=body, status=status, headers=headers)

    @staticmethod
    def _validate_creds_and_get_user(request):
        """
        Validate the token or API key supplied in the request headers or query
        parameters and return the matching user.

        Exactly one of token / API key may be supplied.

        :raises MultipleAuthSourcesError: If both a token and an API key are supplied.
        :raises NoAuthSourceProvidedError: If neither a token nor an API key is supplied.
        :return: The user, or ``None`` if no user could be resolved.
        :rtype: :class:`UserDB`
        """

        headers = request.headers
        query_string = request.query_string
        query_params = dict(urlparse.parse_qsl(query_string))

        token_in_headers = headers.get(HEADER_ATTRIBUTE_NAME, None)
        token_in_query_params = query_params.get(QUERY_PARAM_ATTRIBUTE_NAME, None)

        api_key_in_headers = headers.get(HEADER_API_KEY_ATTRIBUTE_NAME, None)
        api_key_in_query_params = query_params.get(QUERY_PARAM_API_KEY_ATTRIBUTE_NAME, None)

        if (token_in_headers or token_in_query_params) and \
           (api_key_in_headers or api_key_in_query_params):
            raise auth_exceptions.MultipleAuthSourcesError(
                'Only one of Token or API key expected.')

        user = None
        if token_in_headers or token_in_query_params:
            token_db = validate_token(token_in_headers=token_in_headers,
                                      token_in_query_params=token_in_query_params)
            user = token_db.user
        elif api_key_in_headers or api_key_in_query_params:
            api_key_db = validate_api_key(api_key_in_headers=api_key_in_headers,
                                          api_key_query_params=api_key_in_query_params)
            user = api_key_db.user
        else:
            raise auth_exceptions.NoAuthSourceProvidedError('One of Token or API key required.')

        if not user:
            LOG.warning('User not found for supplied token or api-key.')
            return None

        try:
            return User.get(user)
        except StackStormDBObjectNotFoundError:
            # User doesn't exist - we should probably also invalidate token/apikey if
            # this happens.
            LOG.warning('User %s not found.', user)
            return None
232
233
234
class JSONErrorResponseHook(PecanHook):
    """
    Handle all the errors and respond with JSON.

    Known exception types are mapped to HTTP status codes; anything else becomes
    a 500. The response body is JSON containing at least a ``faultstring`` key.
    """

    def on_error(self, state, e):
        # Copy the exception-provided body so the keys added below ('conflict-id',
        # 'faultstring') do not mutate the exception object itself.
        if hasattr(e, 'body') and isinstance(e.body, dict):
            body = dict(e.body)
        else:
            body = {}

        if isinstance(e, exc.HTTPException):
            # NOTE(review): relies on state.response carrying the HTTP exception's
            # status at this point; e.status may be a more direct source - confirm.
            status_code = state.response.status
            message = str(e)
        elif isinstance(e, db_exceptions.StackStormDBObjectNotFoundError):
            status_code = httplib.NOT_FOUND
            message = str(e)
        elif isinstance(e, db_exceptions.StackStormDBObjectConflictError):
            status_code = httplib.CONFLICT
            message = str(e)
            body['conflict-id'] = e.conflict_id
        elif isinstance(e, rbac_exceptions.AccessDeniedError):
            status_code = httplib.FORBIDDEN
            message = str(e)
        elif isinstance(e, (ValueValidationException, ValueError)):
            status_code = httplib.BAD_REQUEST
            message = getattr(e, 'message', str(e))
        else:
            # Don't leak internal error details to the client.
            status_code = httplib.INTERNAL_SERVER_ERROR
            message = 'Internal Server Error'

        # Log the error
        is_internal_server_error = status_code == httplib.INTERNAL_SERVER_ERROR
        error_msg = getattr(e, 'comment', str(e))
        extra = {
            'exception_class': e.__class__.__name__,
            'exception_message': str(e),
            'exception_data': e.__dict__
        }

        if is_internal_server_error:
            # LOG.exception() already records the traceback of the exception that is
            # currently being handled (the same one traceback.format_exc() returns),
            # so it is not logged a second time.
            LOG.exception('API call failed: %s', error_msg, extra=extra)
        else:
            LOG.debug('API call failed: %s', error_msg, extra=extra)

            if is_debugging_enabled():
                LOG.debug(traceback.format_exc())

        body['faultstring'] = message

        response_body = json_encode(body)
        headers = state.response.headers or {}

        headers['Content-Type'] = 'application/json'
        headers['Content-Length'] = str(len(response_body))

        return webob.Response(response_body, status=status_code, headers=headers)
292
293
294
class LoggingHook(PecanHook):
    """
    Logs all incoming requests and outgoing responses
    """

    def before(self, state):
        # Note: We use getattr since in some places (tests) request is mocked
        method = getattr(state.request, 'method', None)
        path = getattr(state.request, 'path', None)
        remote_addr = getattr(state.request, 'remote_addr', None)

        # Log the incoming request
        values = {'method': method, 'path': path, 'remote_addr': remote_addr}
        values['filters'] = state.arguments.keywords

        request_id = state.request.headers.get(REQUEST_ID_HEADER, None)
        values['request_id'] = request_id

        # Passing the dict as the single logging argument lets the logging module
        # apply the %(name)s formatting itself instead of formatting eagerly here.
        LOG.info('%(request_id)s -  %(method)s %(path)s with filters=%(filters)s',
                 values, extra=values)

    def after(self, state):
        # Note: We use getattr since in some places (tests) request is mocked
        method = getattr(state.request, 'method', None)
        path = getattr(state.request, 'path', None)
        remote_addr = getattr(state.request, 'remote_addr', None)
        request_id = state.request.headers.get(REQUEST_ID_HEADER, None)

        # Log the outgoing response
        values = {'method': method, 'path': path, 'remote_addr': remote_addr}
        values['status_code'] = state.response.status
        values['request_id'] = request_id

        controller_name, function_name = self._get_controller_method_names(state.controller)
        log_result = (function_name is not None and
                      function_name not in RESPONSE_LOGGING_METHOD_NAME_BLACKLIST and
                      controller_name not in RESPONSE_LOGGING_CONTROLLER_NAME_BLACKLIST)

        if log_result:
            values['result'] = state.response.body
            log_format = '%(request_id)s - %(method)s %(path)s result=%(result)s'
        else:
            # Note: We don't want to include a result for some
            # methods which have a large result
            log_format = '%(request_id)s - %(method)s %(path)s'

        LOG.info(log_format, values, extra=values)

    @staticmethod
    def _get_controller_method_names(controller):
        """
        Return ``(controller_class_name, function_name)`` for a method controller.

        Supports both Python 2 (``im_class``) and Python 3 (``__func__`` /
        ``__self__``) method objects. Returns ``(None, None)`` when ``controller``
        is not a method (e.g. a plain function or a mock).
        """
        function = getattr(controller, '__func__', None)

        # Python 2 methods expose the owning class directly (also when unbound);
        # on Python 3 it is derived from the object the method is bound to.
        owner_class = getattr(controller, 'im_class', None)
        if owner_class is None:
            bound_to = getattr(controller, '__self__', None)
            owner_class = type(bound_to) if bound_to is not None else None

        if function is None or owner_class is None:
            return None, None

        return owner_class.__name__, function.__name__
346
347
348
class RequestIDHook(PecanHook):
    """
    If request id header isn't present, this hooks adds one.

    The request id is also copied onto the response headers.
    """

    def before(self, state):
        headers = getattr(state.request, 'headers', None)
        if headers is None:
            return

        # Header names are mapping keys of the headers object, not attributes, so
        # they must be looked up with .get() - an attribute lookup never finds a
        # client-supplied id and would always replace it.
        if not headers.get(REQUEST_ID_HEADER, None):
            headers[REQUEST_ID_HEADER] = str(uuid.uuid4())

    def after(self, state):
        req_headers = getattr(state.request, 'headers', None)
        resp_headers = getattr(state.response, 'headers', None)

        if req_headers and resp_headers:
            req_id_header = req_headers.get(REQUEST_ID_HEADER, None)
            if req_id_header:
                # Echo the request id back so clients can correlate responses.
                resp_headers[REQUEST_ID_HEADER] = req_id_header
371