1
|
1 |
|
import base64 |
2
|
1 |
|
import secrets |
3
|
|
|
|
4
|
1 |
|
from pyarrow.flight import ServerMiddlewareFactory, ServerMiddleware |
5
|
1 |
|
from pyarrow.flight import FlightUnauthenticatedError |
6
|
|
|
|
7
|
1 |
|
from tabpy.tabpy_server.handlers.util import hash_password |
8
|
|
|
|
9
|
1 |
|
class BasicAuthServerMiddleware(ServerMiddleware):
    """Per-call Flight middleware that returns a bearer token to the client.

    The factory creates one instance per authenticated call. The token it
    was constructed with is echoed back in the ``authorization`` response
    header.
    """

    def __init__(self, token):
        """Store *token* for inclusion in the outgoing response headers."""
        self.token = token

    def sending_headers(self):
        """Return the response headers: ``authorization: Bearer <token>``."""
        bearer_value = "Bearer {}".format(self.token)
        return dict(authorization=bearer_value)
15
|
|
|
|
16
|
1 |
|
class BasicAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Authenticate incoming Flight calls via the ``authorization`` header.

    Two schemes are accepted (scheme name matched case-insensitively):

    * ``Basic <base64(username:password)>``: the password is checked
      against ``creds``. On success, a new random bearer token is issued,
      remembered in ``tokens``, and sent back to the client.
    * ``Bearer <token>``: accepted only if the token was previously issued
      by this factory.

    Anything else raises ``FlightUnauthenticatedError``.
    """

    def __init__(self, creds):
        """
        Args:
            creds: mapping of username -> stored password hash. A user is
                valid when ``hash_password(username, password)`` matches
                the stored value, compared case-insensitively.
        """
        self.creds = creds
        # Bearer tokens issued after a successful Basic login -> username.
        self.tokens = {}

    def is_valid_user(self, username, password):
        """Return True if *password* is correct for *username*.

        Returns False for unknown users.
        """
        if username not in self.creds:
            return False
        hashed_pwd = hash_password(username, password)
        # compare_digest runs in constant time, so response timing does not
        # reveal how much of the hash matched. Comparing bytes (not str)
        # avoids the TypeError compare_digest raises for non-ASCII strings.
        return secrets.compare_digest(
            self.creds[username].lower().encode("utf-8"),
            hashed_pwd.lower().encode("utf-8"),
        )

    def start_call(self, info, headers):
        """Authenticate a call and return its middleware.

        Args:
            info: call information supplied by Flight (unused).
            headers: mapping of header name -> list of header values.

        Returns:
            A ``BasicAuthServerMiddleware`` carrying the caller's bearer token.

        Raises:
            FlightUnauthenticatedError: credentials are missing, malformed,
                use an unsupported scheme, or are invalid.
        """
        auth_header = self._find_authorization(headers)
        if not auth_header:
            raise FlightUnauthenticatedError("No credentials supplied")

        auth_type, _, value = auth_header.partition(" ")
        # HTTP authentication scheme names are case-insensitive (RFC 7235).
        scheme = auth_type.lower()
        if scheme == "basic":
            return self._start_basic(value)
        if scheme == "bearer":
            return self._start_bearer(value)

        raise FlightUnauthenticatedError("No credentials supplied")

    @staticmethod
    def _find_authorization(headers):
        """Return the first ``authorization`` header value, or None if absent."""
        for header in headers:
            if header.lower() == "authorization":
                values = headers[header]
                return values[0] if values else None
        return None

    def _start_basic(self, value):
        """Validate Basic credentials and issue a fresh bearer token."""
        try:
            decoded = base64.b64decode(value).decode("utf-8")
        except ValueError as err:
            # binascii.Error (bad base64) and UnicodeDecodeError both derive
            # from ValueError. A malformed payload is an authentication
            # failure, not an internal server error.
            raise FlightUnauthenticatedError("Invalid credentials") from err
        username, _, password = decoded.partition(":")
        if not self.is_valid_user(username, password):
            raise FlightUnauthenticatedError("Invalid credentials")
        token = secrets.token_urlsafe(32)
        self.tokens[token] = username
        return BasicAuthServerMiddleware(token)

    def _start_bearer(self, token):
        """Accept a bearer token previously issued by ``_start_basic``."""
        if token not in self.tokens:
            raise FlightUnauthenticatedError("Invalid token")
        return BasicAuthServerMiddleware(token)
49
|
|
|
|