|
1
|
1 |
|
import base64 |
|
2
|
1 |
|
import secrets |
|
3
|
|
|
|
|
4
|
1 |
|
from pyarrow.flight import ServerMiddlewareFactory, ServerMiddleware |
|
5
|
1 |
|
from pyarrow.flight import FlightUnauthenticatedError |
|
6
|
|
|
|
|
7
|
1 |
|
from tabpy.tabpy_server.handlers.util import hash_password |
|
8
|
|
|
|
|
9
|
1 |
|
class BasicAuthServerMiddleware(ServerMiddleware):
    """Server middleware that returns an issued token to the client.

    Instances are returned from ``BasicAuthServerMiddlewareFactory.start_call``
    after a call has been authenticated. The only thing this middleware does
    is echo its token back in the outgoing ``authorization`` header.
    """

    def __init__(self, token):
        # Token to advertise to the client as "Bearer <token>".
        self.token = token

    def sending_headers(self):
        """Return the response headers carrying the bearer token."""
        bearer_value = "Bearer {}".format(self.token)
        return {"authorization": bearer_value}
|
15
|
|
|
|
|
16
|
1 |
|
class BasicAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Authenticate Flight calls using HTTP-Basic-style credentials.

    Every incoming call must carry an ``authorization`` header of the form
    ``Basic <base64("username:password")>``. On success a random token is
    generated, recorded in ``self.tokens`` and handed back to the client via
    :class:`BasicAuthServerMiddleware`.

    :param creds: mapping of username -> stored password hash; a password is
        accepted when ``hash_password(username, password)`` equals the stored
        value, ignoring case.
    """

    def __init__(self, creds):
        # username -> stored password hash
        self.creds = creds
        # issued token -> username it was issued to.
        # NOTE(review): an entry is added on every successful call and this
        # class never reads or removes entries, so the dict grows without
        # bound; confirm whether other code consumes it or it can be pruned.
        self.tokens = {}

    def is_valid_user(self, username, password):
        """Return True if *password* is the correct password for *username*.

        :param username: user name decoded from the Basic credentials.
        :param password: plaintext password decoded from the Basic credentials.
        :returns: False for unknown users or a non-matching password hash.
        """
        if username not in self.creds:
            return False
        expected = self.creds[username].lower().encode("utf-8")
        actual = hash_password(username, password).lower().encode("utf-8")
        # Constant-time comparison so response timing does not reveal how many
        # leading characters of the hash matched.
        return secrets.compare_digest(expected, actual)

    def start_call(self, info, headers):
        """Validate the credentials supplied with an incoming call.

        :param info: call information supplied by Flight (not used here).
        :param headers: incoming headers; each value is indexed with ``[0]``,
            i.e. a sequence of values per header name.
        :returns: a :class:`BasicAuthServerMiddleware` carrying a new token.
        :raises FlightUnauthenticatedError: if no authorization header is
            present, its scheme is not ``Basic``, the payload cannot be
            decoded, or the username/password pair is wrong.
        """
        # Header names may arrive in any case, so search case-insensitively.
        auth_header = None
        for header in headers:
            if header.lower() == "authorization":
                auth_header = headers[header][0]
                break

        if not auth_header:
            raise FlightUnauthenticatedError("No credentials supplied")

        # Header layout is "<scheme> <credentials>".
        auth_type, _, value = auth_header.partition(" ")

        # Authentication scheme names are case-insensitive (RFC 7235).
        if auth_type.lower() != "basic":
            raise FlightUnauthenticatedError("No credentials supplied")

        try:
            decoded = base64.b64decode(value).decode("utf-8")
        except ValueError:
            # binascii.Error (bad base64 / non-ASCII input) and
            # UnicodeDecodeError both subclass ValueError; a garbled payload
            # is a client authentication failure, not a server error.
            raise FlightUnauthenticatedError("Invalid credentials") from None

        username, _, password = decoded.partition(":")
        if not self.is_valid_user(username, password):
            raise FlightUnauthenticatedError("Invalid credentials")

        token = secrets.token_urlsafe(32)
        self.tokens[token] = username
        return BasicAuthServerMiddleware(token)
|
49
|
|
|
|