Passed
Pull Request — master (#595)
by
unknown
13:12
created

FlightServer.descriptor_to_key()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 1.125

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 4
ccs 1
cts 2
cp 0.5
rs 10
c 0
b 0
f 0
cc 1
nop 2
crap 1.125
1
# Licensed to the Apache Software Foundation (ASF) under one
2
# or more contributor license agreements.  See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership.  The ASF licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License.  You may obtain a copy of the License at
8
#
9
#   http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied.  See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
18 1
"""An example Flight Python server."""
19
20 1
import argparse
21 1
import ast
22 1
import logging
23 1
import threading
24 1
import time
25 1
import uuid
26
27 1
import pyarrow
28 1
import pyarrow.flight
29
30 1
# Module logger, named as a child of the "__main__" logger (presumably so its
# records follow whatever logging setup the launching script applies -- confirm).
logger = logging.getLogger("__main__.{}".format(__name__))
31
32 1
class FlightServer(pyarrow.flight.FlightServerBase):
    """An in-memory Arrow Flight server.

    Tables uploaded with ``do_put`` are kept in ``self.flights``, keyed by
    the tuple returned from :meth:`descriptor_to_key`.  Each stored table is
    advertised with a single endpoint whose ticket is ``repr(key)``;
    ``do_get`` streams the table back once and then removes it.
    """

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None):
        """Create the server.

        :param host: host name used, together with the bound port, to build
            the endpoint locations advertised to clients.
        :param location: location passed to ``FlightServerBase``.
        :param tls_certificates: passed to ``FlightServerBase``; when truthy,
            advertised endpoints use ``grpc+tls`` instead of ``grpc+tcp``.
        :param verify_client: passed to ``FlightServerBase``.
        :param root_certificates: passed to ``FlightServerBase``.
        :param auth_handler: passed to ``FlightServerBase``.
        """
        super().__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates)
        self.flights = {}
        self.host = host
        self.tls_certificates = tls_certificates

    @staticmethod
    def descriptor_to_key(descriptor):
        """Return a hashable key identifying *descriptor*.

        The key is ``(descriptor_type.value, command, path_tuple)``.  A
        missing path becomes an empty tuple; ``list_flights`` relies on
        ``command`` being ``None`` for path descriptors.
        """
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or ()))

    def _make_flight_info(self, key, descriptor, table):
        """Build the ``FlightInfo`` advertised for *table* stored under *key*.

        The single endpoint's ticket is ``repr(key)``, which ``do_get`` parses
        back with :func:`ast.literal_eval`.  The reported byte count is the
        size of *table* serialized as an Arrow IPC stream.
        """
        if self.tls_certificates:
            location = pyarrow.flight.Location.for_grpc_tls(
                self.host, self.port)
        else:
            location = pyarrow.flight.Location.for_grpc_tcp(
                self.host, self.port)
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location])]

        # Write the table to a MockOutputStream purely to measure the size of
        # its IPC stream serialization.
        mock_sink = pyarrow.MockOutputStream()
        stream_writer = pyarrow.RecordBatchStreamWriter(
            mock_sink, table.schema)
        stream_writer.write_table(table)
        stream_writer.close()
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(table.schema,
                                         descriptor, endpoints,
                                         table.num_rows, data_size)

    def list_flights(self, context, criteria):
        """Yield a ``FlightInfo`` for every stored table.

        *criteria* is ignored.
        """
        # Snapshot the items: this generator is consumed lazily, and other
        # RPCs (do_put/do_get/do_action) may mutate self.flights meanwhile.
        for key, table in list(self.flights.items()):
            if key[1] is not None:
                descriptor = \
                    pyarrow.flight.FlightDescriptor.for_command(key[1])
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

            yield self._make_flight_info(key, descriptor, table)

    def get_flight_info(self, context, descriptor):
        """Return the ``FlightInfo`` for the table stored under *descriptor*.

        :raises KeyError: if nothing is stored under that descriptor.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        logger.info("get_flight_info: key=%s", key)
        if key in self.flights:
            table = self.flights[key]
            return self._make_flight_info(key, descriptor, table)
        raise KeyError('Flight not found.')

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole uploaded stream and store it under *descriptor*.

        An existing table under the same key is replaced.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        logger.info("do_put: key=%s", key)
        self.flights[key] = reader.read_all()

    def do_get(self, context, ticket):
        """Stream the table named by *ticket* and remove it from the store.

        Each stored table can therefore be fetched only once.

        :raises KeyError: if no table is stored under the ticket's key.
        """
        logger.info("do_get: ticket=%s", ticket)
        # The ticket is the repr() of a key built by _make_flight_info.
        # literal_eval only accepts Python literals (no code execution), but
        # note the ticket is client-supplied and otherwise unvalidated.
        key = ast.literal_eval(ticket.ticket.decode())
        # A single pop() with a default avoids a check-then-pop race between
        # concurrent do_get calls for the same ticket.
        flight = self.flights.pop(key, None)
        if flight is None:
            logger.warning("do_get: key=%s not found", key)
            # Returning None here would be rejected by Flight with an opaque
            # TypeError; raise the same error get_flight_info uses instead.
            raise KeyError('Flight not found.')
        logger.info("do_get: returning key=%s", key)
        return pyarrow.flight.RecordBatchStream(flight)

    def list_actions(self, context):
        """Return ``(type, description)`` pairs for every supported action."""
        return iter([
            ("getUniquePath", "Get a unique FileDescriptor path to put data to."),
            ("clear", "Clear the stored flights."),
            ("healthcheck", "Do nothing; succeeds if the server responds."),
            ("shutdown", "Shut down this server."),
        ])

    def do_action(self, context, action):
        """Run the custom action named by ``action.type``.

        This is a generator, so the action's work happens only as the caller
        consumes it.

        - ``getUniquePath``: yield a fresh UUID4 string encoded as UTF-8.
        - ``clear``: drop all stored flights; yields nothing.
        - ``healthcheck``: no-op; yields nothing.
        - ``shutdown``: drop all stored flights, yield ``b'Shutdown!'``, then
          (once the generator is resumed) start the delayed shutdown on a
          background thread.

        :raises KeyError: for any other action type.
        """
        logger.info("do_action: action=%s", action.type)
        if action.type == "getUniquePath":
            unique_id = str(uuid.uuid4())
            logger.info("getUniquePath id=%s", unique_id)
            yield unique_id.encode('utf-8')
        elif action.type == "clear":
            self._clear()
        elif action.type == "healthcheck":
            pass
        elif action.type == "shutdown":
            self._clear()
            yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
            # Shut down on a background thread so the current request is not
            # blocked by the server shutting down underneath it.
            threading.Thread(target=self._shutdown).start()
        else:
            raise KeyError("Unknown action {!r}".format(action.type))

    def _clear(self):
        """Clear the stored flights."""
        self.flights = {}

    def _shutdown(self):
        """Shut down after a delay."""
        logger.info("Server is shutting down...")
        # The delay gives the in-flight "shutdown" action time to finish
        # delivering its result before shutdown() is called.
        time.sleep(2)
        self.shutdown()
138
139
140 1
def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1``, ``no``.

    Used instead of ``type=bool``, which treats every non-empty string --
    including ``"False"`` -- as true.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def start():
    """Parse command-line options, build the server and serve until shutdown.

    The listen location is ``grpc+tcp://HOST:PORT``, or ``grpc+tls://...``
    when ``--tls CERTFILE KEYFILE`` is given, in which case both files are
    read and passed to the server as a single (cert chain, key) pair.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost",
                        help="Address or hostname to listen on")
    parser.add_argument("--port", type=int, default=13622,
                        help="Port number to listen on")
    parser.add_argument("--tls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Enable transport-level security")
    # Accepts "--verify_client", "--verify_client True" and
    # "--verify_client False"; type=bool would make "False" evaluate as True.
    parser.add_argument("--verify_client", type=_str_to_bool, nargs="?",
                        const=True, default=False,
                        help="enable mutual TLS and verify the client if True")
    # TODO: implement config
    parser.add_argument("--config", type=str, default="",
                        help="should be ignored")

    args = parser.parse_args()
    tls_certificates = []
    scheme = "grpc+tcp"
    if args.tls:
        scheme = "grpc+tls"
        with open(args.tls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(args.tls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        tls_certificates.append((tls_cert_chain, tls_private_key))

    location = "{}://{}:{}".format(scheme, args.host, args.port)

    # NOTE(review): no root_certificates are passed, so verify_client may have
    # nothing to verify clients against -- confirm intended mTLS setup.
    server = FlightServer(args.host, location,
                          tls_certificates=tls_certificates,
                          verify_client=args.verify_client)
    logger.info("Serving on %s", location)
    server.serve()
171
172
173
if __name__ == "__main__":
    # Launch the server only when this file is run as a script; importing the
    # module must not start serving.
    start()