aws_alb_oauth_proxy.server.Proxy.handle_request()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 19
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 18
nop 2
dl 0
loc 19
rs 9.5
c 0
b 0
f 0
1
import asyncio
2
import logging
3
from typing import Mapping
4
5
import jwt
6
from aiohttp import ClientSession, web, DummyCookieJar
7
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPProxyAuthenticationRequired, HTTPBadRequest
8
from jwt import DecodeError, ExpiredSignatureError
9
from yarl import URL
10
11
from helpers import clean_response_headers
12
from monitoring import REQUEST_HISTOGRAM, UPSTREAM_STATUS_COUNTER
13
14
# Module-level logger; it is also handed to the aiohttp application in ``Proxy.runner``.
logger = logging.getLogger(__name__)
15
16
17
class Proxy:
    """A reverse proxy that translates some headers. Cookies and sessions are deliberately not handled.

    This takes the OIDC data from the load balancer, validates it, and adds new headers as expected by Grafana.
    Some form of key caching may be useful and will be implemented later: currently the ALB public key is
    fetched once per authenticated request.
    """

    def __init__(
        self,
        upstream: str,
        aws_region: str,
        header_name: str = "X-WEBAUTH-USER",
        header_property: str = "email",
        ignore_auth: bool = False,
    ):
        """Creates a server for a given AWS region.

        :param upstream: The URL of the upstream server
        :param aws_region: The AWS region where this is running, used to fetch the key.
        :param header_name: HTTP header name to send, as configured in ``grafana.ini``.
        :param header_property: The header property to use from the payload. Should match what Grafana expects.
        :param ignore_auth: Whether to run without authentication. Should only be used in testing.
        """
        self._ignore_auth = ignore_auth
        self._upstream = URL(upstream)
        self._key_url = URL(f"https://public-keys.auth.elb.{aws_region}.amazonaws.com")
        self._header_name = header_name
        self._header_property = header_property

    async def _setup_session(self, app):
        """Create the client sessions on startup and close them on shutdown.

        Registered as an aiohttp ``cleanup_ctx`` entry: code before ``yield`` runs on startup, code after it
        runs on cleanup.

        `See docs <https://docs.aiohttp.org/en/latest/client_advanced.html#persistent-session>`_"""
        # Key fetches must succeed: an error status raises rather than returning an error body as the "key".
        self._key_session = ClientSession(raise_for_status=True)
        # Upstream statuses are relayed as-is, cookies are never stored, and bodies are passed through
        # undecoded (``handle_request`` copies the upstream headers, including any Content-Encoding).
        self._upstream_session = ClientSession(
            raise_for_status=False, cookie_jar=DummyCookieJar(), auto_decompress=False
        )
        yield
        await asyncio.gather(self._key_session.close(), self._upstream_session.close())

    def runner(self):
        """Build the aiohttp application and return an (unstarted) ``web.AppRunner`` for it.

        Every method and path is routed to :meth:`handle_request`, behind :meth:`auth_middleware`.
        """
        app = web.Application(middlewares=[self.auth_middleware], logger=logger)
        app.router.add_route("*", "/{tail:.*}", self.handle_request)
        app.cleanup_ctx.append(self._setup_session)
        return web.AppRunner(app)

    async def _decode_payload(self, oidc_data: str) -> str:
        """Verify the OIDC data sent by the ALB and return the configured claim from its payload.

        `Relevant AWS Documentation
        <https://docs.aws.amazon.com/elasticloadbalancing/latest/application/listener-authenticate-users.html#user-claims-encoding>`_

        :param oidc_data: OIDC data (a JWT) from the ALB's ``X-Amzn-Oidc-Data`` header
        :return: the value of the ``header_property`` claim
        :raise jwt.exceptions.ExpiredSignatureError: If the token is no longer valid
        :raise jwt.exceptions.DecodeError: If the token is malformed, its ``kid`` is missing or does not
            resolve to the ALB key endpoint, or its signature does not verify
        :raise jwt.exceptions.InvalidTokenError: For any other token validation failure
        :raise HTTPBadRequest: If the payload has no ``header_property`` claim
        """
        header = jwt.get_unverified_header(oidc_data)
        # The header is not verified yet, so ``kid`` is untrusted input.
        kid = header.get("kid")
        if not isinstance(kid, str) or not kid:
            raise DecodeError("OIDC data header has no usable 'kid'")
        try:
            key_url = self._key_url.join(URL(kid))
        except ValueError as e:
            raise DecodeError("OIDC data header has a malformed 'kid'") from e
        # An absolute or network-path ``kid`` (e.g. "//attacker.example/key") would make ``join`` resolve to
        # another host, letting the sender pick the key their own signature is checked against.
        expected_origin = (self._key_url.scheme, self._key_url.host, self._key_url.port)
        if (key_url.scheme, key_url.host, key_url.port) != expected_origin:
            raise DecodeError("OIDC data 'kid' does not resolve to the ALB public key endpoint")

        async with self._key_session.get(key_url) as response:
            pub_key = await response.text()

        # The algorithm is pinned instead of taken from the unverified ``alg`` header; ALBs sign with ES256.
        payload = jwt.decode(oidc_data, pub_key, algorithms=["ES256"])
        try:
            return payload[self._header_property]
        except KeyError:
            logger.warning("Could not find '%s' key in OIDC Data.", self._header_property)
            raise HTTPBadRequest

    async def _add_auth_info(self, request: web.Request):
        """Adds the authentication information, if any, to the request.

        On success stores ``(header_name, claim_value)`` under ``request["auth_payload"]``.
        Catches exceptions from decoding the payload and converts them to HTTP exceptions to be propagated:
        missing header -> 407, expired or otherwise invalid token -> 401, undecodable token -> 400.
        If authentication is disabled via :attr:`~_ignore_auth` doesn't do anything.

        Headers are kept in a `CIMultiDictProxy`_ so case of the header is not important.

        .. _CIMultiDictProxy: https://multidict.readthedocs.io/en/stable/multidict.html#multidict.CIMultiDictProxy
        """
        if self._ignore_auth:
            return None

        try:
            oidc_data = request.headers["X-Amzn-Oidc-Data"]
        except KeyError:
            logger.warning("No 'X-Amzn-Oidc-Data' header present. Dropping request.")
            raise HTTPProxyAuthenticationRequired()
        try:
            request["auth_payload"] = (self._header_name, await self._decode_payload(oidc_data))
        except ExpiredSignatureError:
            logger.warning("Got expired token. Dropping request.")
            raise HTTPUnauthorized()
        except DecodeError as e:
            logger.warning("Couldn't decode token. Dropping request.")
            logger.debug("Couldn't decode token: %s", e)
            raise HTTPBadRequest()
        except jwt.exceptions.InvalidTokenError as e:
            # Remaining validation failures (e.g. a disallowed algorithm or a not-yet-valid token) would
            # otherwise surface as 500 Internal Server Error.
            logger.warning("Got invalid token. Dropping request.")
            logger.debug("Invalid token: %s", e)
            raise HTTPUnauthorized()

    async def handle_request(self, request: web.Request) -> web.StreamResponse:
        """Forward ``request`` to the upstream server and stream the upstream response back unchanged.

        Method, path, query string and body are forwarded; request headers come from
        ``clean_response_headers(request)``. Upstream redirects are relayed, not followed, and every upstream
        status is counted in ``UPSTREAM_STATUS_COUNTER`` by method.
        """
        # Timed with the context-manager form: the decorator form wraps the coroutine function in a plain
        # function, so it would only time the creation of the coroutine, not the awaited proxying.
        # NOTE(review): assumes REQUEST_HISTOGRAM is a prometheus_client Histogram/Summary, whose ``time()``
        # works as both decorator and context manager.
        with REQUEST_HISTOGRAM.time():
            # ``relative()`` keeps the query string, so it must not be passed again via ``params``: aiohttp
            # appends ``params`` to the URL's existing query, which duplicated every query parameter.
            upstream_url = self._upstream.join(request.url.relative())
            upstream_request = self._upstream_session.request(
                url=upstream_url,
                method=request.method,
                headers=clean_response_headers(request),
                data=request.content,
                allow_redirects=False,
            )
            async with upstream_request as upstream_response:
                UPSTREAM_STATUS_COUNTER.labels(method=upstream_response.method, status=upstream_response.status).inc()
                response = web.StreamResponse(status=upstream_response.status, headers=upstream_response.headers)
                await response.prepare(request)
                # Relay body chunks as they arrive instead of buffering the whole upstream response.
                async for data in upstream_response.content.iter_any():
                    await response.write(data)
                await response.write_eof()
                return response

    @web.middleware
    async def auth_middleware(self, request, handler):
        """Authenticate ``request`` via :meth:`_add_auth_info` before passing it to ``handler``."""
        await self._add_auth_info(request)
        return await handler(request)
139