Passed
Push — develop ( f5fdfa...ccf839 )
by Plexxi
05:58 queued 02:59
created

AuthHook.before()   C

Complexity

Conditions 7

Size

Total Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
dl 0
loc 24
rs 5.5
c 0
b 0
f 0
1
# Licensed to the StackStorm, Inc ('StackStorm') under one or more
2
# contributor license agreements.  See the NOTICE file distributed with
3
# this work for additional information regarding copyright ownership.
4
# The ASF licenses this file to You under the Apache License, Version 2.0
5
# (the "License"); you may not use this file except in compliance with
6
# the License.  You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import httplib
17
import re
18
import traceback
19
import uuid
20
21
import webob
22
from oslo_config import cfg
23
from pecan.hooks import PecanHook
24
from six.moves.urllib import parse as urlparse
25
from webob import exc
26
27
from st2common import log as logging
28
from st2common.persistence.auth import User
29
from st2common.exceptions import db as db_exceptions
30
from st2common.exceptions import auth as auth_exceptions
31
from st2common.exceptions import rbac as rbac_exceptions
32
from st2common.exceptions.db import StackStormDBObjectNotFoundError
33
from st2common.exceptions.apivalidation import ValueValidationException
34
from st2common.util import auth as auth_utils
35
from st2common.util.jsonify import json_encode
36
from st2common.util.debugging import is_enabled as is_debugging_enabled
37
from st2common.constants.api import REQUEST_ID_HEADER
38
from st2common.constants.auth import HEADER_ATTRIBUTE_NAME
39
from st2common.constants.auth import QUERY_PARAM_ATTRIBUTE_NAME
40
from st2common.constants.auth import HEADER_API_KEY_ATTRIBUTE_NAME
41
from st2common.constants.auth import QUERY_PARAM_API_KEY_ATTRIBUTE_NAME
42
43
44
LOG = logging.getLogger(__name__)

# A list of method names for which we don't want to log the result / response
RESPONSE_LOGGING_METHOD_NAME_BLACKLIST = [
    'get_all'
]

# A list of controller classes for which we don't want to log the result / response.
# NOTE: every entry needs its own trailing comma - Python silently concatenates adjacent
# string literals, so a missing comma merges two entries into one name that never matches.
RESPONSE_LOGGING_CONTROLLER_NAME_BLACKLIST = [
    'ActionExecutionChildrenController',  # action executions can be big
    'ActionExecutionAttributeController',  # result can be big
    'ActionExecutionsController',  # action executions can be big
    'FilesController',  # files controller returns files content
    'FileController',  # file controller returns binary file data
]

# Regex for the st2 auth tokens endpoint (i.e. /tokens or /v1/tokens).
# Raw string so "\d" reaches the regex engine as a digit class instead of relying on
# Python passing an unrecognised string escape through unchanged.
AUTH_TOKENS_URL_REGEX = r'^(?:/tokens|/v\d+/tokens)$'
0 ignored issues
show
Bug introduced by
A suspicious escape sequence \d was found. Did you maybe forget to add an r prefix?

Escape sequences in Python are generally interpreted according to rules similar to standard C. Only strings prefixed with r or R (raw strings) leave backslashes uninterpreted, which is what regular expressions usually need.

The escape sequence that was used indicates that you might have intended to write a regular expression.

Learn more about the available escape sequences in the Python documentation.

Loading history...
62
63
64
class CorsHook(PecanHook):

    def after(self, state):
        """
        Add CORS headers (and a Content-Length when missing) to the response.

        Allowed origins are ``cfg.CONF.api.allow_origin`` plus the WebUI development
        defaults and, when configured, ``cfg.CONF.auth.api_url``.

        :param state: Pecan hook state. ``state.request`` is read and
                      ``state.response.headers`` is modified in place.
        """
        headers = state.response.headers

        origin = state.request.headers.get('Origin')
        origins = set(cfg.CONF.api.allow_origin)

        # Build a list of the default allowed origins
        public_api_url = cfg.CONF.auth.api_url

        # Default gulp development server WebUI URL
        origins.add('http://127.0.0.1:3000')

        # By default WebUI simple http server listens on 8080
        origins.add('http://localhost:8080')
        origins.add('http://127.0.0.1:8080')

        if public_api_url:
            # Public API URL
            origins.add(public_api_url)

        if origin:
            if '*' in origins:
                origin_allowed = '*'
            else:
                # See http://www.w3.org/TR/cors/#access-control-allow-origin-response-header
                origin_allowed = origin if origin in origins else 'null'
        else:
            # No Origin header, so any allowed origin will do - but pick one
            # deterministically. Set iteration order is arbitrary (and can differ between
            # processes), so indexing into list(origins) gave an unpredictable value.
            origin_allowed = public_api_url or sorted(origins)[0]

        methods_allowed = ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS']
        request_headers_allowed = ['Content-Type', 'Authorization', 'X-Auth-Token',
                                   HEADER_API_KEY_ATTRIBUTE_NAME, REQUEST_ID_HEADER]
        response_headers_allowed = ['Content-Type', 'X-Limit', 'X-Total-Count',
                                    REQUEST_ID_HEADER]

        headers['Access-Control-Allow-Origin'] = origin_allowed
        headers['Access-Control-Allow-Methods'] = ','.join(methods_allowed)
        headers['Access-Control-Allow-Headers'] = ','.join(request_headers_allowed)
        headers['Access-Control-Expose-Headers'] = ','.join(response_headers_allowed)

        # Fill in Content-Length from the body, except for text/event-stream responses
        # (presumably because those are streamed and have no fixed length).
        if not headers.get('Content-Length') \
                and not headers.get('Content-type', '').startswith('text/event-stream'):
            headers['Content-Length'] = str(len(state.response.body))

    def on_error(self, state, e):
        """
        Return a default ``webob.Response()`` when an OPTIONS request fails.

        For any other method this returns None, so other error hooks handle the error.
        """
        if state.request.method == 'OPTIONS':
            return webob.Response()
113
114
class AuthHook(PecanHook):

    def before(self, state):
        """
        Authenticate the request (when auth is enabled) and strip auth query params.

        OPTIONS requests and POSTs to the tokens endpoint are skipped entirely. When
        ``cfg.CONF.auth.enable`` is set, the resolved user is stored in
        ``state.request.context['auth']['user']``.

        :raises: auth exceptions from ``_validate_creds_and_get_user``; these are turned
                 into 401 responses by ``on_error``.
        """
        # OPTIONS requests don't need to be authenticated
        if state.request.method == 'OPTIONS':
            return

        # Token request is authenticated separately.
        if self._is_token_request(state.request):
            return

        if cfg.CONF.auth.enable:
            user_db = self._validate_creds_and_get_user(request=state.request)

            # Store related user object in the context. The token is not passed
            # along any longer as that should only be used in the auth domain.
            # NOTE(review): user_db may be None (see _validate_creds_and_get_user) and the
            # request still proceeds; confirm downstream code rejects a None user.
            state.request.context['auth'] = {
                'user': user_db
            }

        self._remove_auth_query_params(state.arguments.keywords)

    @staticmethod
    def _is_token_request(request):
        """Return True for a POST to the st2 auth tokens endpoint."""
        return bool(request.method == 'POST' and
                    re.search(AUTH_TOKENS_URL_REGEX, request.path))

    @staticmethod
    def _remove_auth_query_params(keywords):
        """Drop token / API key query parameters so they are not passed to the controller."""
        for param_name in (QUERY_PARAM_ATTRIBUTE_NAME, QUERY_PARAM_API_KEY_ATTRIBUTE_NAME):
            keywords.pop(param_name, None)

    def on_error(self, state, e):
        """
        Convert authentication errors into a 401 JSON response.

        Returns None for any other exception so the remaining error hooks handle it.
        """
        if isinstance(e, (auth_exceptions.NoAuthSourceProvidedError,
                          auth_exceptions.MultipleAuthSourcesError)):
            LOG.error(str(e))
            return self._abort_unauthorized(str(e))

        # (exception class, message logged together with the traceback), checked in order.
        unauthorized_errors = (
            (auth_exceptions.TokenNotProvidedError, 'Token is not provided.'),
            (auth_exceptions.TokenNotFoundError, 'Token is not found.'),
            (auth_exceptions.TokenExpiredError, 'Token has expired.'),
            (auth_exceptions.ApiKeyNotProvidedError, 'API key is not provided.'),
            (auth_exceptions.ApiKeyNotFoundError, 'API key is not found.'),
            (auth_exceptions.ApiKeyDisabledError, 'API key is disabled.'),
        )
        for exc_class, log_message in unauthorized_errors:
            if isinstance(e, exc_class):
                LOG.exception(log_message)
                return self._abort_unauthorized(str(e))

        return None

    @staticmethod
    def _abort_unauthorized(msg):
        """Build a 401 JSON response whose faultstring is 'Unauthorized[ - msg]'."""
        faultstring = 'Unauthorized - %s' % msg if msg else 'Unauthorized'
        body = json_encode({
            'faultstring': faultstring
        })
        headers = {}
        headers['Content-Type'] = 'application/json'
        status = httplib.UNAUTHORIZED

        return webob.Response(body=body, status=status, headers=headers)

    @staticmethod
    def _abort_other_errors():
        """Build a generic 500 JSON response."""
        body = json_encode({
            'faultstring': 'Internal Server Error'
        })
        headers = {}
        headers['Content-Type'] = 'application/json'
        status = httplib.INTERNAL_SERVER_ERROR

        return webob.Response(body=body, status=status, headers=headers)

    @staticmethod
    def _validate_creds_and_get_user(request):
        """
        Validate the token or API key in the headers / query parameters and return its user.

        Exactly one credential type (token or API key) may be supplied.

        :param request: incoming request; ``headers`` and ``query_string`` are read.
        :return: the object returned by ``User.get``, or None when the credential has no
                 associated user or that user does not exist.
        :rtype: :class:`UserDB`
        :raises MultipleAuthSourcesError: both a token and an API key were supplied.
        :raises NoAuthSourceProvidedError: neither a token nor an API key was supplied.
        Other auth exceptions may be raised by the ``auth_utils`` validators.
        """

        headers = request.headers
        query_string = request.query_string
        query_params = dict(urlparse.parse_qsl(query_string))

        token_in_headers = headers.get(HEADER_ATTRIBUTE_NAME, None)
        token_in_query_params = query_params.get(QUERY_PARAM_ATTRIBUTE_NAME, None)

        api_key_in_headers = headers.get(HEADER_API_KEY_ATTRIBUTE_NAME, None)
        api_key_in_query_params = query_params.get(QUERY_PARAM_API_KEY_ATTRIBUTE_NAME, None)

        if ((token_in_headers or token_in_query_params) and
                (api_key_in_headers or api_key_in_query_params)):
            raise auth_exceptions.MultipleAuthSourcesError(
                'Only one of Token or API key expected.')

        user = None

        if token_in_headers or token_in_query_params:
            token_db = auth_utils.validate_token_and_source(
                token_in_headers=token_in_headers,
                token_in_query_params=token_in_query_params)
            user = token_db.user
        elif api_key_in_headers or api_key_in_query_params:
            api_key_db = auth_utils.validate_api_key_and_source(
                api_key_in_headers=api_key_in_headers,
                api_key_query_params=api_key_in_query_params)
            user = api_key_db.user
        else:
            raise auth_exceptions.NoAuthSourceProvidedError('One of Token or API key required.')

        if not user:
            LOG.warning('User not found for supplied token or api-key.')
            return None

        try:
            return User.get(user)
        except StackStormDBObjectNotFoundError:
            # User doesn't exist - we should probably also invalidate token/apikey if
            # this happens.
            LOG.warning('User %s not found.', user)
            return None
238
239
240
class JSONErrorResponseHook(PecanHook):
    """
    Handle all the errors and respond with JSON.
    """

    def on_error(self, state, e):
        """
        Map exception ``e`` to an HTTP status code and return a JSON error response.

        The body is ``e.body`` (when it is a dict) plus a ``faultstring`` key;
        ``conflict-id`` is added for DB conflict errors. Unrecognised exceptions become a
        500 with the generic message 'Internal Server Error'.
        """
        if hasattr(e, 'body') and isinstance(e.body, dict):
            body = e.body
        else:
            body = {}

        if isinstance(e, exc.HTTPException):
            # NOTE(review): the status is read from state.response, presumably already set
            # from the HTTP exception by Pecan - confirm. It may be a status string
            # (e.g. '404 Not Found'), which the int 500 comparison below never matches.
            status_code = state.response.status
            message = str(e)
        elif isinstance(e, db_exceptions.StackStormDBObjectNotFoundError):
            status_code = httplib.NOT_FOUND
            message = str(e)
        elif isinstance(e, db_exceptions.StackStormDBObjectConflictError):
            status_code = httplib.CONFLICT
            message = str(e)
            body['conflict-id'] = e.conflict_id
        elif isinstance(e, rbac_exceptions.AccessDeniedError):
            status_code = httplib.FORBIDDEN
            message = str(e)
        elif isinstance(e, (ValueValidationException, ValueError)):
            status_code = httplib.BAD_REQUEST
            message = getattr(e, 'message', str(e))
        else:
            status_code = httplib.INTERNAL_SERVER_ERROR
            message = 'Internal Server Error'

        # Log the error
        is_internal_server_error = status_code == httplib.INTERNAL_SERVER_ERROR
        error_msg = getattr(e, 'comment', str(e))
        extra = {
            'exception_class': e.__class__.__name__,
            'exception_message': str(e),
            'exception_data': e.__dict__
        }

        if is_internal_server_error:
            # LOG.exception already attaches the active traceback to this record, so the
            # traceback is not logged a second time as a separate message.
            LOG.exception('API call failed: %s', error_msg, extra=extra)
        else:
            LOG.debug('API call failed: %s', error_msg, extra=extra)

            if is_debugging_enabled():
                LOG.debug(traceback.format_exc())

        body['faultstring'] = message

        response_body = json_encode(body)
        headers = state.response.headers or {}

        headers['Content-Type'] = 'application/json'
        headers['Content-Length'] = str(len(response_body))

        return webob.Response(response_body, status=status_code, headers=headers)
298
299
300
class LoggingHook(PecanHook):
    """
    Logs all incoming requests and outgoing responses
    """

    def before(self, state):
        """Log the incoming request's id, method, path and routed keyword arguments."""
        # Note: We use getattr since in some places (tests) request is mocked
        method = getattr(state.request, 'method', None)
        path = getattr(state.request, 'path', None)
        remote_addr = getattr(state.request, 'remote_addr', None)

        # Log the incoming request
        values = {'method': method, 'path': path, 'remote_addr': remote_addr}
        values['filters'] = state.arguments.keywords

        request_id = state.request.headers.get(REQUEST_ID_HEADER, None)
        values['request_id'] = request_id

        # ``values`` is passed as the format mapping so interpolation is done by the
        # logging framework (and skipped when INFO is disabled) instead of eagerly here.
        LOG.info('%(request_id)s -  %(method)s %(path)s with filters=%(filters)s',
                 values, extra=values)

    def after(self, state):
        """
        Log the outgoing response.

        The response body is included only when the controller is a method whose name
        and class are not in the response-logging blacklists.
        """
        # Note: We use getattr since in some places (tests) request is mocked
        method = getattr(state.request, 'method', None)
        path = getattr(state.request, 'path', None)
        remote_addr = getattr(state.request, 'remote_addr', None)
        request_id = state.request.headers.get(REQUEST_ID_HEADER, None)

        # Log the outgoing response
        values = {'method': method, 'path': path, 'remote_addr': remote_addr}
        values['status_code'] = state.response.status
        values['request_id'] = request_id

        controller_names = self._get_controller_names(state.controller)
        if controller_names:
            function_name, controller_name = controller_names
            log_result = (function_name not in RESPONSE_LOGGING_METHOD_NAME_BLACKLIST and
                          controller_name not in RESPONSE_LOGGING_CONTROLLER_NAME_BLACKLIST)
        else:
            log_result = False

        if log_result:
            values['result'] = state.response.body
            LOG.info('%(request_id)s - %(method)s %(path)s result=%(result)s',
                     values, extra=values)
        else:
            # Note: We don't want to include a result for some
            # methods which have a large result
            LOG.info('%(request_id)s - %(method)s %(path)s', values, extra=values)

    @staticmethod
    def _get_controller_names(controller):
        """
        Return ``(function_name, controller_class_name)`` for a method controller, else None.

        Python 2 method objects expose ``im_func`` / ``im_class``; Python 3 bound methods
        only expose ``__func__`` / ``__self__``, which are used as a fallback.
        """
        if hasattr(controller, 'im_self'):
            return controller.im_func.__name__, controller.im_class.__name__

        func = getattr(controller, '__func__', None)
        owner = getattr(controller, '__self__', None)
        if func is None or owner is None:
            return None
        return func.__name__, owner.__class__.__name__
352
353
354
class RequestIDHook(PecanHook):
    """
    If request id header isn't present, this hook adds one.
    """

    def before(self, state):
        """Add a generated (uuid4) request id header when the client did not send one."""
        headers = getattr(state.request, 'headers', None)

        # Compare with None: an empty header mapping still needs a request id.
        if headers is None:
            return

        # Header mappings are dict-like. The header name is a key, not an attribute
        # (it contains dashes), so it must be looked up with .get() - getattr() never
        # found it and a client-supplied id was always overwritten.
        if not headers.get(REQUEST_ID_HEADER, None):
            req_id = str(uuid.uuid4())
            state.request.headers[REQUEST_ID_HEADER] = req_id

    def after(self, state):
        """Copy the request id header (if set) onto the response headers."""
        req_headers = getattr(state.request, 'headers', None)
        resp_headers = getattr(state.response, 'headers', None)

        # Compare with None so an (as yet) empty response header mapping still gets the id.
        if req_headers is None or resp_headers is None:
            return

        req_id_header = req_headers.get(REQUEST_ID_HEADER, None)
        if req_id_header:
            resp_headers[REQUEST_ID_HEADER] = req_id_header
377