Test Failed
Pull Request — master (#593)
by
unknown
13:45
created

FlightServer._clear()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
# Licensed to the Apache Software Foundation (ASF) under one
2
# or more contributor license agreements.  See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership.  The ASF licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License.  You may obtain a copy of the License at
8
#
9
#   http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied.  See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
18
"""An example Flight Python server."""
19
20
import argparse
21
import ast
22
import logging
23
import threading
24
import time
25
import uuid
26
27
import pyarrow
28
import pyarrow.flight
29
30
# Module-level logger used by the FlightServer request handlers below.
logger = logging.getLogger(__name__)
31
32
class FlightServer(pyarrow.flight.FlightServerBase):
    """In-memory Arrow Flight example server.

    Tables uploaded with ``do_put`` are kept in ``self.flights``, keyed by
    the tuple returned from :meth:`descriptor_to_key`. They can then be
    listed (``list_flights``), described (``get_flight_info``) and
    downloaded (``do_get``). A few custom actions are exposed through
    ``list_actions`` / ``do_action``.
    """

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None):
        """Create the server.

        :param host: hostname advertised in the endpoints of the returned
            FlightInfo objects.
        :param location: URI to listen on; passed to ``FlightServerBase``.
        :param tls_certificates: TLS certificate/key material passed to
            ``FlightServerBase`` (``start`` builds a list of
            ``(cert_chain, private_key)`` tuples). When truthy, advertised
            endpoints use gRPC+TLS instead of plain gRPC+TCP.
        :param verify_client: passed through to ``FlightServerBase``.
        :param root_certificates: passed through to ``FlightServerBase``.
        :param auth_handler: passed through to ``FlightServerBase``.
        """
        super().__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates)
        # Maps descriptor keys (see descriptor_to_key) to uploaded tables.
        self.flights = {}
        self.host = host
        self.tls_certificates = tls_certificates

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Return a hashable key identifying *descriptor*.

        The key is ``(descriptor_type value, command, path tuple)``. A
        missing path is normalised to an empty tuple.
        """
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or ()))

    def _make_flight_info(self, key, descriptor, table):
        """Build a FlightInfo describing *table*, stored under *key*.

        The single endpoint points back at this server. Its ticket is
        ``repr(key)``, which ``do_get`` parses back with
        ``ast.literal_eval``.
        """
        if self.tls_certificates:
            location = pyarrow.flight.Location.for_grpc_tls(
                self.host, self.port)
        else:
            location = pyarrow.flight.Location.for_grpc_tcp(
                self.host, self.port)
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location])]

        # Write the table to a MockOutputStream, which only counts bytes,
        # to report the IPC stream size. The context manager closes the
        # writer even if writing fails.
        mock_sink = pyarrow.MockOutputStream()
        with pyarrow.RecordBatchStreamWriter(
                mock_sink, table.schema) as stream_writer:
            stream_writer.write_table(table)
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(table.schema,
                                         descriptor, endpoints,
                                         table.num_rows, data_size)

    def list_flights(self, context, criteria):
        """Yield a FlightInfo for every stored flight (*criteria* is ignored)."""
        # Iterate over a snapshot: do_put may insert into self.flights from
        # another request while this generator is still being consumed,
        # which would otherwise raise "dictionary changed size".
        for key, table in list(self.flights.items()):
            if key[1] is not None:
                descriptor = \
                    pyarrow.flight.FlightDescriptor.for_command(key[1])
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

            yield self._make_flight_info(key, descriptor, table)

    def get_flight_info(self, context, descriptor):
        """Return the FlightInfo for the flight stored under *descriptor*.

        :raises KeyError: if no flight is stored under *descriptor*.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        logger.info("get_flight_info: key=%s", key)
        # Single lookup, so a concurrent clear cannot slip in between a
        # membership test and the read.
        try:
            table = self.flights[key]
        except KeyError:
            raise KeyError('Flight not found.') from None
        return self._make_flight_info(key, descriptor, table)

    def do_put(self, context, descriptor, reader, writer):
        """Store the uploaded stream under *descriptor*.

        Any flight already stored under the same key is replaced.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        logger.info("do_put: key=%s", key)
        self.flights[key] = reader.read_all()

    def do_get(self, context, ticket):
        """Stream the table referenced by *ticket*.

        The ticket is expected to contain ``repr(key)``, as produced by
        ``_make_flight_info``.

        :raises KeyError: if the ticket is malformed or names an unknown
            flight.
        """
        logger.info("do_get: ticket=%s", ticket)
        # The ticket comes from the client. literal_eval only accepts Python
        # literals (no code execution); anything it rejects is reported as a
        # bad ticket rather than an internal server error.
        try:
            key = ast.literal_eval(ticket.ticket.decode())
        except (ValueError, SyntaxError) as err:
            logger.warning("do_get: malformed ticket=%s", ticket)
            raise KeyError("Malformed ticket: {}".format(err)) from None
        try:
            table = self.flights[key]
        except (KeyError, TypeError):  # TypeError: unhashable literal
            logger.warning("do_get: key=%s not found", key)
            # Returning None here made pyarrow fail with an opaque TypeError.
            raise KeyError('Flight not found.') from None
        logger.info("do_get: returning key=%s", key)
        return pyarrow.flight.RecordBatchStream(table)

    def list_actions(self, context):
        """Describe the custom actions accepted by ``do_action``."""
        return iter([
            ("getUniquePath", "Get a unique FileDescriptor path to put data to."),
            ("clear", "Clear the stored flights."),
            ("healthcheck", "Check that the server is responsive."),
            ("shutdown", "Shut down this server."),
        ])

    def do_action(self, context, action):
        """Execute the custom action named by ``action.type``.

        :raises KeyError: if the action type is not recognised.
        """
        logger.info("do_action: action=%s", action.type)
        if action.type == "getUniquePath":
            unique_id = str(uuid.uuid4())
            logger.info("getUniquePath id=%s", unique_id)
            yield unique_id.encode('utf-8')
        elif action.type == "clear":
            self._clear()
        elif action.type == "healthcheck":
            # Nothing to do: completing the call without error is the answer.
            pass
        elif action.type == "shutdown":
            self._clear()
            yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
            # Shut down on a background thread: shutdown() blocks until
            # in-flight requests, including this one, have finished.
            threading.Thread(target=self._shutdown).start()
        else:
            raise KeyError("Unknown action {!r}".format(action.type))

    def _clear(self):
        """Drop all stored flights."""
        self.flights = {}

    def _shutdown(self):
        """Shut the server down after a short delay."""
        print("Server is shutting down...")
        # NOTE(review): presumably gives the client time to receive the
        # action result before the server stops; confirm.
        time.sleep(2)
        self.shutdown()
137
138
139
def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string,
    including ``"False"``, as true; this converter does not.

    :raises argparse.ArgumentTypeError: if *value* is not a recognised
        boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def start():
    """Parse command-line arguments and serve a FlightServer until shutdown."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost",
                        help="Address or hostname to listen on")
    parser.add_argument("--port", type=int, default=5005,
                        help="Port number to listen on")
    parser.add_argument("--tls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Enable transport-level security")
    # A bare `--verify_client` means True. An explicit value is parsed
    # strictly, so `--verify_client False` really disables verification
    # (with type=bool it would have evaluated to True).
    parser.add_argument("--verify_client", type=_str_to_bool, nargs="?",
                        const=True, default=False,
                        help="enable mutual TLS and verify the client if True")
    # TODO: implement config; the option is currently accepted and ignored.
    parser.add_argument("--config", type=str, default="",
                        help="should be ignored")

    args = parser.parse_args()
    tls_certificates = []
    scheme = "grpc+tcp"
    if args.tls:
        scheme = "grpc+tls"
        cert_path, key_path = args.tls
        with open(cert_path, "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(key_path, "rb") as key_file:
            tls_private_key = key_file.read()
        tls_certificates.append((tls_cert_chain, tls_private_key))

    location = "{}://{}:{}".format(scheme, args.host, args.port)

    server = FlightServer(args.host, location,
                          tls_certificates=tls_certificates,
                          verify_client=args.verify_client)
    print("Serving on", location)
    server.serve()


if __name__ == '__main__':
    start()
174