aws_alb_oauth_proxy.__main__   A
last analyzed

Complexity

Total Complexity 3

Size/Duplication

Total Lines 100
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 61
dl 0
loc 100
rs 10
c 0
b 0
f 0
wmc 3

1 Function

Rating   Name   Duplication   Size   Complexity  
A work() 0 25 3
1
import asyncio
2
import argparse
3
import logging
4
import sys
5
from concurrent.futures.process import ProcessPoolExecutor
6
7
from aiohttp import web
8
from prometheus_client import start_http_server
9
10
from helpers import _aws_region
11
from server import Proxy
12
13
# Command line arguments
14
15
# Command-line interface.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Decode AWS ALB OIDC JWT to Proxy Auth for Grafana",
)

# Declarative option table: (flags, add_argument() keyword settings).
# Entries are registered in order, so the generated --help layout is stable.
_cli_options = (
    (("upstream",), {"help": "Upstream server URL: scheme://host:port"}),
    (("-p", "--port"), {"type": int, "default": 8080, "help": "Port to listen on"}),
    (("--ignore-auth",), {"action": "store_true", "help": "Whether to ignore the JWT token"}),
    (
        ("--loglevel",),
        {
            "default": "info",
            "choices": ["debug", "info", "warning", "error", "critical"],
            "help": "Logging verbosity",
        },
    ),
    # Currently disabled:
    # parser.add_argument("--logtz", default="local", choices=["utc", "local"], help="Time zone to use for logging")
    (("--mon-port",), {"type": int, "default": 8081, "help": "Port for exposing metrics"}),
)
for _flags, _settings in _cli_options:
    parser.add_argument(*_flags, **_settings)

args = parser.parse_args()

# Flatten the parsed namespace into the module-level settings used below.
upstream, port, ignore_auth = args.upstream, args.port, args.ignore_auth
loglevel = args.loglevel
monitor_port = args.mon_port
37
38
# Logging
39
40
# Record layout shared by every handler installed through basicConfig().
_LOG_FORMAT = "%(asctime)s %(processName)-22s %(levelname)-8s %(name)-15s %(message)s"

# Translate the textual --loglevel value into the numeric constant of the
# same (upper-cased) name exposed by the logging module.
numeric_level = getattr(logging, loglevel.upper(), None)
if isinstance(numeric_level, int):
    logging.basicConfig(level=numeric_level, format=_LOG_FORMAT)
else:
    raise ValueError("Invalid log level: %s" % loglevel)

logger = logging.getLogger("main")
47
48
49
# Actual work
50
51
# A detected region is only mandatory when authentication is enforced;
# with --ignore-auth the script carries on even if none was found.
region = _aws_region()
if not (region or ignore_auth):
    logger.error("Could not detect AWS region. Are we running on AWS?")
    sys.exit(1)

# Echo the effective configuration once at startup.
for _startup_line in (
    f"Upstream:     {upstream}",
    f"Client port:  {port}",
    f"Metrics port: {monitor_port}",
):
    logger.info(_startup_line)
59
60
61
def work():
    """Run the proxy on an asyncio event loop until it is stopped.

    Builds a ``Proxy`` from the module-level ``region``, ``upstream`` and
    ``ignore_auth`` settings, binds its runner to a ``web.TCPSite`` on
    ``port`` and then blocks in ``loop.run_forever()``.

    The runner is cleaned up on every exit path: errors raised while
    starting or serving are logged as a warning, Ctrl-C is treated as a
    normal shutdown request, and in all cases ``runner.cleanup()`` is
    awaited before returning.
    """
    proxy = Proxy(aws_region=region, upstream=upstream, ignore_auth=ignore_auth)
    runner = proxy.runner()

    async def start():
        await runner.setup()
        site = web.TCPSite(runner, port=port, reuse_address=True, reuse_port=True)
        await site.start()
        # Log only once the site is actually started.
        logger.debug("Started site...")

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No usable current loop in this thread; create a private one.
        loop = asyncio.new_event_loop()

    try:
        loop.run_until_complete(start())
        loop.run_forever()
    except KeyboardInterrupt:
        # KeyboardInterrupt is not an Exception subclass, so it needs its own
        # handler; it is the expected way to stop run_forever() interactively.
        logger.info("Interrupted. Shutting down...")
    except Exception as exc:
        logger.warning(f"Got exception: {exc}. Shutting down...")
    finally:
        # Always release the runner's resources, whichever way the loop
        # stopped. The loop itself is left open in case work() is called
        # again in the same process.
        loop.run_until_complete(runner.cleanup())
86
87
88
# with ProcessPoolExecutor(max_workers=4) as executor:
89
#     workers = {executor.submit(work) for _ in range(4)}
90
#     for future in workers:
91
#         try:
92
#             future.result()
93
#         except Exception as exc:
94
#             logger.warning(f"Worker {future} got an exception: {exc}")
95
#         else:
96
#             logger.info(f"Worker {future} is shut down.")
97
98
def _main():
    """Start the metrics HTTP server, then run the proxy until it stops."""
    # The metrics server is started before work() blocks in its event loop.
    start_http_server(monitor_port)
    work()


_main()
100