1
|
|
|
import asyncio |
2
|
|
|
import logging |
3
|
|
|
from typing import Mapping |
4
|
|
|
|
5
|
|
|
import jwt |
6
|
|
|
from aiohttp import ClientSession, web, DummyCookieJar |
7
|
|
|
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPProxyAuthenticationRequired, HTTPBadRequest |
8
|
|
|
from jwt import DecodeError, ExpiredSignatureError |
9
|
|
|
from yarl import URL |
10
|
|
|
|
11
|
|
|
from helpers import clean_response_headers |
12
|
|
|
from monitoring import REQUEST_HISTOGRAM, UPSTREAM_STATUS_COUNTER |
13
|
|
|
|
14
|
|
|
# Module-level logger named after this module; Proxy.runner also hands it to web.Application.
logger = logging.getLogger(name=__name__)
15
|
|
|
|
16
|
|
|
|
17
|
|
|
class Proxy:
    """This is basically a reverse proxy that translates some headers. We don't care about cookies or sessions.

    This takes the OIDC data from the load balancer, validates it, and adds new headers as expected by Grafana.

    Some form of key caching may be useful and will be implemented later.
    """

    # AWS documents that ALB signs the ``X-Amzn-Oidc-Data`` JWT with ECDSA P-256 + SHA-256 (ES256).
    # The algorithm is pinned here instead of trusting the ``alg`` field of the unverified header.
    _JWT_ALGORITHM = "ES256"

    # Characters accepted in a token's ``kid``. AWS key ids are UUID-like; anything else (slashes,
    # dots, colons, ...) could steer the key lookup away from the AWS public key endpoint.
    _KID_CHARACTERS = frozenset("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")

    def __init__(
        self,
        upstream: str,
        aws_region: str,
        header_name: str = "X-WEBAUTH-USER",
        header_property: str = "email",
        ignore_auth: bool = False,
    ):
        """Creates a server for a given AWS region.

        :param upstream: The URL of the upstream server
        :param aws_region: The AWS region where this is running, used to fetch the key.
        :param header_name: HTTP header name to send, as configured in ``grafana.ini``.
        :param header_property: The header property to use from the payload. Should match what Grafana expects.
        :param ignore_auth: Whether to run without authentication. Should only be used in testing.
        """
        self._ignore_auth = ignore_auth
        self._upstream = URL(upstream)
        self._key_url = URL(f"https://public-keys.auth.elb.{aws_region}.amazonaws.com")
        self._header_name = header_name
        self._header_property = header_property
        # ``self._key_session`` and ``self._upstream_session`` are created in ``_setup_session``.

    async def _setup_session(self, app):
        """Create the client sessions on startup and close both on cleanup.

        Registered through ``app.cleanup_ctx``: code before ``yield`` runs on startup, code after it on cleanup.

        `See docs <https://docs.aiohttp.org/en/latest/client_advanced.html#persistent-session>`_"""
        # Key fetches should fail loudly on HTTP errors (unknown kid, etc.).
        self._key_session = ClientSession(raise_for_status=True)
        # Upstream statuses are relayed to the client as-is; cookies are never stored and bodies are
        # passed through still encoded.
        self._upstream_session = ClientSession(
            raise_for_status=False, cookie_jar=DummyCookieJar(), auto_decompress=False
        )
        yield
        await asyncio.gather(self._key_session.close(), self._upstream_session.close())

    def runner(self):
        """Build the aiohttp application (auth middleware + catch-all route) and return its ``AppRunner``."""
        app = web.Application(middlewares=[self.auth_middleware], logger=logger)
        app.router.add_route("*", "/{tail:.*}", self.handle_request)
        app.cleanup_ctx.append(self._setup_session)
        return web.AppRunner(app)

    @classmethod
    def _is_safe_kid(cls, kid) -> bool:
        """Return True if ``kid`` is a non-empty string made only of ``_KID_CHARACTERS``."""
        return isinstance(kid, str) and bool(kid) and set(kid) <= cls._KID_CHARACTERS

    async def _decode_payload(self, oidc_data: str) -> str:
        """Verify the OIDC data sent by the ALB and return the configured claim from its payload.

        `Relevant AWS Documentation
        <https://docs.aws.amazon.com/elasticloadbalancing/latest/application/listener-authenticate-users.html#user-claims-encoding>`_

        :param oidc_data: OIDC data from the ALB (the ``X-Amzn-Oidc-Data`` JWT)
        :return: the value of the ``header_property`` claim of the verified payload
        :raise jwt.exceptions.ExpiredSignatureError: If the token is no longer valid
        :raise jwt.exceptions.DecodeError: If the token is malformed, its header has a missing or
            unsafe ``kid``, or the signature does not verify
        :raise jwt.exceptions.InvalidTokenError: For other validation failures, e.g. an algorithm
            other than ES256
        :raise HTTPBadRequest: If the verified payload lacks the ``header_property`` claim
        """
        header = jwt.get_unverified_header(oidc_data)
        kid = header.get("kid")
        # The header is not verified at this point, so ``kid`` is attacker-controlled. Without this
        # check a value such as "//attacker.example/key" (or an absolute URL) would make URL.join
        # fetch the "public key" from an arbitrary host, allowing forged tokens.
        if not self._is_safe_kid(kid):
            raise DecodeError("Missing or invalid 'kid' in token header")

        async with self._key_session.get(self._key_url.join(URL(kid))) as response:
            pub_key = await response.text()

        # NOTE(review): AWS also recommends checking the header's ``signer`` field against this
        # ALB's ARN, so tokens from other load balancers in the same region are rejected. Not done here.
        payload = jwt.decode(oidc_data, pub_key, algorithms=[self._JWT_ALGORITHM])
        try:
            return payload[self._header_property]
        except KeyError:
            logger.warning("Could not find '%s' key in OIDC Data.", self._header_property)
            raise HTTPBadRequest

    async def _add_auth_info(self, request: web.Request):
        """Adds the authentication information, if any, to the request.

        Stores ``(header_name, claim_value)`` under ``request["auth_payload"]``.
        Catches exceptions from decoding the payload and converts them to HTTP exceptions to be propagated.
        If authentication is disabled via :attr:`~_ignore_auth` doesn't do anything.

        Headers are kept in a `CIMultiDictProxy`_ so case of the header is not important.

        .. _CIMultiDictProxy: https://multidict.readthedocs.io/en/stable/multidict.html#multidict.CIMultiDictProxy
        """
        if self._ignore_auth:
            return None

        try:
            oidc_data = request.headers["X-Amzn-Oidc-Data"]
        except KeyError:
            logger.warning("No 'X-Amzn-Oidc-Data' header present. Dropping request.")
            raise HTTPProxyAuthenticationRequired()
        try:
            request["auth_payload"] = (self._header_name, await self._decode_payload(oidc_data))
        except ExpiredSignatureError:
            logger.warning("Got expired token. Dropping request.")
            raise HTTPUnauthorized()
        except DecodeError as e:
            logger.warning("Couldn't decode token. Dropping request.")
            logger.debug("Couldn't decode token: %s", e)
            raise HTTPBadRequest()
        except jwt.InvalidTokenError as e:
            # Remaining validation failures (e.g. disallowed algorithm) used to surface as a 500.
            logger.warning("Got invalid token. Dropping request.")
            logger.debug("Invalid token: %s", e)
            raise HTTPUnauthorized()

    # NOTE(review): if REQUEST_HISTOGRAM is a prometheus_client Histogram, using ``.time()`` as a
    # decorator on a coroutine function may only time coroutine creation, not the request.
    # ``with REQUEST_HISTOGRAM.time():`` inside the body would cover the whole request — confirm.
    @REQUEST_HISTOGRAM.time()
    async def handle_request(self, request: web.Request) -> web.StreamResponse:
        """Forward ``request`` to the upstream server and stream the upstream response back.

        Redirects are not followed; the upstream status and headers are relayed unchanged.
        """
        # ``relative()`` keeps both path and query string. The query must not be passed again via
        # ``params=``: aiohttp appends ``params`` to the URL's existing query, duplicating every parameter.
        upstream_url = self._upstream.join(request.url.relative())
        upstream_request = self._upstream_session.request(
            url=upstream_url,
            method=request.method,
            headers=clean_response_headers(request),
            # Only forward a body when the client sent one. An unconditional stream of unknown length
            # makes aiohttp send a (chunked) body even for bodiless requests such as GET, unless a
            # Content-Length header is forwarded.
            data=request.content if request.body_exists else None,
            allow_redirects=False,
        )
        async with upstream_request as upstream_response:
            UPSTREAM_STATUS_COUNTER.labels(method=upstream_response.method, status=upstream_response.status).inc()
            response = web.StreamResponse(status=upstream_response.status, headers=upstream_response.headers)
            await response.prepare(request)
            async for data in upstream_response.content.iter_any():
                await response.write(data)
            await response.write_eof()
            return response

    @web.middleware
    async def auth_middleware(self, request, handler):
        """Run authentication (which may raise an HTTP error) before passing the request to ``handler``."""
        await self._add_auth_info(request)
        return await handler(request)
139
|
|
|
|