Test Failed — Pull Request #592 (master), created by unknown at 06:58

tabpy.tabpy_server.app.arrow_server   A

Complexity

Total Complexity 22

Size/Duplication

Total Lines 155
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 22
eloc 108
dl 0
loc 155
rs 10
c 0
b 0
f 0

10 Methods

Rating   Name   Duplication   Size   Complexity  
A FlightServer.descriptor_to_key() 0 4 1
A FlightServer._make_flight_info() 0 19 2
A FlightServer.list_flights() 0 9 3
A FlightServer.__init__() 0 9 1
A FlightServer.do_put() 0 5 1
A FlightServer.get_flight_info() 0 6 2
A FlightServer.list_actions() 0 4 1
A FlightServer.do_action() 0 13 4
A FlightServer.do_get() 0 5 2
A FlightServer._shutdown() 0 5 1

1 Function

Rating   Name   Duplication   Size   Complexity  
A start() 0 30 4
1
# Licensed to the Apache Software Foundation (ASF) under one
2
# or more contributor license agreements.  See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership.  The ASF licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License.  You may obtain a copy of the License at
8
#
9
#   http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied.  See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
18
"""An example Flight Python server."""
19
20
import argparse
21
import ast
22
import threading
23
import time
24
25
import pyarrow
26
import pyarrow.flight
27
28
29
class FlightServer(pyarrow.flight.FlightServerBase):
    """Arrow Flight server that keeps uploaded tables in memory.

    Tables sent with ``do_put`` are stored in ``self.flights`` under a key
    derived from their FlightDescriptor and can then be listed, described
    and fetched back by clients.
    """

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None):
        """Initialize the server.

        :param host: Hostname advertised in the endpoints of FlightInfo
            responses.
        :param location: URI to listen on; forwarded to the base class.
        :param tls_certificates: Sequence of (cert_chain, private_key)
            pairs forwarded to the base class. When truthy, advertised
            endpoints use gRPC over TLS.
        :param verify_client: Forwarded to the base class.
        :param root_certificates: Forwarded to the base class.
        :param auth_handler: Forwarded to the base class.
        """
        super(FlightServer, self).__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates)
        # Descriptor key (see descriptor_to_key) -> table read by do_put.
        self.flights = {}
        self.host = host
        self.tls_certificates = tls_certificates

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Return a hashable ``(type, command, path)`` key for *descriptor*."""
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or tuple()))

    def _make_flight_info(self, key, descriptor, table):
        """Build a FlightInfo describing *table* as served by this server.

        The endpoint ticket is ``repr(key)``, which ``do_get`` parses back
        with ``ast.literal_eval``.
        """
        if self.tls_certificates:
            location = pyarrow.flight.Location.for_grpc_tls(
                self.host, self.port)
        else:
            location = pyarrow.flight.Location.for_grpc_tcp(
                self.host, self.port)
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location]), ]

        # Write the table into a mock sink only to measure the byte size of
        # its IPC stream representation; nothing is kept.
        mock_sink = pyarrow.MockOutputStream()
        stream_writer = pyarrow.RecordBatchStreamWriter(
            mock_sink, table.schema)
        stream_writer.write_table(table)
        stream_writer.close()
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(table.schema,
                                         descriptor, endpoints,
                                         table.num_rows, data_size)

    def list_flights(self, context, criteria):
        """Yield a FlightInfo for every stored table.

        *criteria* is ignored; all stored flights are listed.
        """
        # Iterate over a snapshot: this generator is consumed lazily while
        # other requests (do_put, the "clear" action) may mutate the dict.
        for key, table in list(self.flights.items()):
            if key[1] is not None:
                descriptor = \
                    pyarrow.flight.FlightDescriptor.for_command(key[1])
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

            yield self._make_flight_info(key, descriptor, table)

    def get_flight_info(self, context, descriptor):
        """Return the FlightInfo for *descriptor*.

        :raises KeyError: if no flight is stored for *descriptor*.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        table = self.flights.get(key)
        if table is None:
            raise KeyError('Flight not found.')
        return self._make_flight_info(key, descriptor, table)

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole uploaded stream and store it under *descriptor*."""
        key = FlightServer.descriptor_to_key(descriptor)
        print(key)
        self.flights[key] = reader.read_all()
        print(self.flights[key])

    def do_get(self, context, ticket):
        """Stream the table whose key is encoded in *ticket*.

        :raises KeyError: if no flight is stored under that key.
        """
        # The ticket comes from the client; literal_eval (not eval) only
        # accepts Python literals such as the repr'd key tuple.
        key = ast.literal_eval(ticket.ticket.decode())
        # Single lookup so a concurrent "clear" cannot remove the entry
        # between a membership test and the read.
        table = self.flights.get(key)
        if table is None:
            # Returning None here would make the Flight base class fail with
            # an unrelated type error; report the missing flight instead.
            raise KeyError('Flight not found.')
        return pyarrow.flight.RecordBatchStream(table)

    def list_actions(self, context):
        """Return ``(type, description)`` pairs for the supported actions."""
        return [
            ("clear", "Clear the stored flights."),
            ("healthcheck", "No-op that succeeds while the server is up."),
            ("shutdown", "Shut down this server."),
        ]

    def do_action(self, context, action):
        """Run *action*; only "shutdown" yields a result.

        :raises KeyError: for an action type not listed by list_actions.
        """
        if action.type == "clear":
            self.flights.clear()
        elif action.type == "healthcheck":
            pass
        elif action.type == "shutdown":
            yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
            # Shut down on background thread to avoid blocking current
            # request
            threading.Thread(target=self._shutdown).start()
        else:
            raise KeyError("Unknown action {!r}".format(action.type))

    def _shutdown(self):
        """Shut down the server after a 2-second delay."""
        print("Server is shutting down...")
        time.sleep(2)
        self.shutdown()
119
120
121
def start():
    """Parse command-line options, then run a FlightServer until shut down.

    Options: ``--host``, ``--port``, ``--tls CERTFILE KEYFILE`` and
    ``--verify_client BOOL``. With ``--tls`` the server listens on
    ``grpc+tls``; otherwise it uses ``grpc+tcp``.
    """
    def _str_to_bool(value):
        """Parse a boolean command-line value.

        argparse's ``type=bool`` turns every non-empty string (including
        "False") into True, so the text is interpreted explicitly.
        """
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost",
                        help="Address or hostname to listen on")
    parser.add_argument("--port", type=int, default=5005,
                        help="Port number to listen on")
    parser.add_argument("--tls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Enable transport-level security")
    parser.add_argument("--verify_client", type=_str_to_bool, default=False,
                        help="enable mutual TLS and verify the client if True")

    args = parser.parse_args()
    # List of (cert_chain, private_key) pairs; stays empty without --tls.
    tls_certificates = []
    scheme = "grpc+tcp"
    if args.tls:
        scheme = "grpc+tls"
        with open(args.tls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(args.tls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        tls_certificates.append((tls_cert_chain, tls_private_key))

    location = "{}://{}:{}".format(scheme, args.host, args.port)

    server = FlightServer(args.host, location,
                          tls_certificates=tls_certificates,
                          verify_client=args.verify_client)
    print("Serving on", location)
    server.serve()
151
152
153
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    start()
155