1
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one |
2
|
|
|
# or more contributor license agreements. See the NOTICE file |
3
|
|
|
# distributed with this work for additional information |
4
|
|
|
# regarding copyright ownership. The ASF licenses this file |
5
|
|
|
# to you under the Apache License, Version 2.0 (the |
6
|
|
|
# "License"); you may not use this file except in compliance |
7
|
|
|
# with the License. You may obtain a copy of the License at |
8
|
|
|
# |
9
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
10
|
|
|
# |
11
|
|
|
# Unless required by applicable law or agreed to in writing, |
12
|
|
|
# software distributed under the License is distributed on an |
13
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14
|
|
|
# KIND, either express or implied. See the License for the |
15
|
|
|
# specific language governing permissions and limitations |
16
|
|
|
# under the License. |
17
|
|
|
|
18
|
|
|
|
19
|
1 |
|
import ast |
20
|
1 |
|
import logging |
21
|
1 |
|
import threading |
22
|
1 |
|
import time |
23
|
1 |
|
import uuid |
24
|
|
|
|
25
|
1 |
|
import pyarrow |
26
|
1 |
|
import pyarrow.flight |
27
|
|
|
|
28
|
|
|
|
29
|
1 |
|
# Module-level logger. The name is prefixed with '__main__.' so the logger
# sits below a '__main__' logger in the logging hierarchy.
logger = logging.getLogger(f"__main__.{__name__}")
30
|
|
|
|
31
|
1 |
|
class FlightServer(pyarrow.flight.FlightServerBase):
    """Arrow Flight server that keeps uploaded tables in an in-memory dict.

    Tables received through ``do_put`` are stored in ``self.flights`` under
    the key produced by ``descriptor_to_key``. ``do_get`` returns a stored
    table and removes it from the store.
    """

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None, middleware=None):
        """Initialise the base server and an empty flight store.

        ``host`` is only used to build the endpoint locations advertised in
        FlightInfo objects; every other argument is forwarded to
        ``FlightServerBase.__init__``.
        """
        super().__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates, middleware)
        # Maps descriptor key -> table read from a do_put upload.
        self.flights = {}
        self.host = host
        self.tls_certificates = tls_certificates
        self.location = location

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Return a hashable key for ``descriptor``.

        The key is the tuple ``(descriptor type value, command, path)``;
        a missing or empty path is stored as an empty tuple.
        """
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or ()))

    def _make_flight_info(self, key, descriptor, table):
        """Build the FlightInfo describing ``table`` stored under ``key``.

        The single endpoint uses ``repr(key)`` as its ticket, which is the
        representation ``do_get`` parses back into a key.
        """
        # NOTE(review): self.port is not set in this class; it is expected
        # to be provided by FlightServerBase.
        if self.tls_certificates:
            location = pyarrow.flight.Location.for_grpc_tls(
                self.host, self.port)
        else:
            location = pyarrow.flight.Location.for_grpc_tcp(
                self.host, self.port)
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location])]

        # Write the table to a mock sink purely to measure how many bytes
        # its stream representation occupies.
        mock_sink = pyarrow.MockOutputStream()
        stream_writer = pyarrow.RecordBatchStreamWriter(
            mock_sink, table.schema)
        stream_writer.write_table(table)
        stream_writer.close()
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(table.schema,
                                         descriptor, endpoints,
                                         table.num_rows, data_size)

    def list_flights(self, context, criteria):
        """Yield a FlightInfo for every stored flight.

        ``criteria`` is not used; all stored flights are listed.
        """
        for key, table in self.flights.items():
            # key is (descriptor type value, command, path): rebuild a
            # command descriptor when a command is present, else a path one.
            if key[1] is not None:
                descriptor = \
                    pyarrow.flight.FlightDescriptor.for_command(key[1])
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

            yield self._make_flight_info(key, descriptor, table)

    def get_flight_info(self, context, descriptor):
        """Return the FlightInfo for ``descriptor``.

        Raises:
            KeyError: if no flight is stored for the descriptor.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        logger.info(f"get_flight_info: key={key}")
        try:
            table = self.flights[key]
        except KeyError:
            raise KeyError('Flight not found.') from None
        return self._make_flight_info(key, descriptor, table)

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole upload and store it under the descriptor's key.

        An existing entry with the same key is replaced.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        logger.info(f"do_put: key={key}")
        self.flights[key] = reader.read_all()

    def do_get(self, context, ticket):
        """Return a stream over the flight named by ``ticket`` and drop it.

        The ticket bytes must hold the ``repr`` of a flight key (as issued
        by ``_make_flight_info``); they are parsed with ``ast.literal_eval``,
        which accepts Python literals only.

        Raises:
            KeyError: if no flight is stored under the ticket's key.
        """
        logger.info(f"do_get: ticket={ticket}")
        key = ast.literal_eval(ticket.ticket.decode())
        try:
            # Single lookup-and-remove: a flight can be fetched only once.
            flight = self.flights.pop(key)
        except KeyError:
            logger.warning(f"do_get: key={key} not found")
            # Fail with the same error as get_flight_info rather than
            # returning None, which is not a valid data stream.
            raise KeyError('Flight not found.') from None
        logger.info(f"do_get: returning key={key}")
        return pyarrow.flight.RecordBatchStream(flight)

    def list_actions(self, context):
        """Return an iterator of (action type, description) pairs.

        Every action type handled by ``do_action`` is listed.
        """
        return iter([
            ("getUniquePath",
             "Get a unique FlightDescriptor path to put data to."),
            ("clear", "Clear the stored flights."),
            ("healthcheck", "Check that the server is responding."),
            ("shutdown", "Shut down this server."),
        ])

    def do_action(self, context, action):
        """Run ``action`` and yield its results, if any.

        This is a generator, so the action runs (and an unknown action type
        raises) only once the result is iterated.

        Raises:
            KeyError: if ``action.type`` is not a known action.
        """
        logger.info(f"do_action: action={action.type}")
        if action.type == "getUniquePath":
            unique_id = str(uuid.uuid4())
            logger.info(f"getUniquePath id={unique_id}")
            yield unique_id.encode('utf-8')
        elif action.type == "clear":
            self._clear()
        elif action.type == "healthcheck":
            # Nothing to do: completing without error is the health signal.
            pass
        elif action.type == "shutdown":
            self._clear()
            yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
            # Shut down on background thread to avoid blocking current
            # request
            threading.Thread(target=self._shutdown).start()
        else:
            raise KeyError("Unknown action {!r}".format(action.type))

    def _clear(self):
        """Clear the stored flights."""
        # Rebind instead of mutating in place, so a list_flights generator
        # that is still iterating the old dict is not disturbed.
        self.flights = {}

    def _shutdown(self):
        """Shut down after a delay."""
        logger.info("Server is shutting down...")
        time.sleep(2)
        self.shutdown()
136
|
|
|
|
137
|
1 |
|
def start(server):
    """Log the location of ``server`` and run its ``serve()`` method."""
    logger.info("Serving on %s", server.location)
    server.serve()
140
|
|
|
|
141
|
|
|
|
142
|
|
|
if __name__ == '__main__':
    # start() requires the server to run; calling it with no argument
    # raises TypeError. Serve a FlightServer built with default arguments.
    start(FlightServer())