Passed
Pull Request — master (#595)
by
unknown
13:17
created

tabpy.tabpy_server.app.arrow_server.start()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 1.2963

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
ccs 1
cts 3
cp 0.3333
rs 10
c 0
b 0
f 0
cc 1
nop 1
crap 1.2963
1
# Licensed to the Apache Software Foundation (ASF) under one
2
# or more contributor license agreements.  See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership.  The ASF licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License.  You may obtain a copy of the License at
8
#
9
#   http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied.  See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
18
19 1
import ast
20 1
import logging
21 1
import threading
22 1
import time
23 1
import uuid
24
25 1
import pyarrow
26 1
import pyarrow.flight
27
28 1
from tabpy.tabpy_server.app.app_parameters import SettingsParameters, ConfigParameters
29 1
from tabpy.tabpy_server.app.util import parse_pwd_file
30 1
from tabpy.tabpy_server.handlers import NoOpAuthHandler, BasicAuthServerMiddlewareFactory
31
32
33 1
# Named under the '__main__' logger hierarchy so it inherits handlers/levels
# configured on '__main__'.
logger = logging.getLogger(f"__main__.{__name__}")
34
35 1
class FlightServer(pyarrow.flight.FlightServerBase):
    """Arrow Flight server that keeps uploaded tables in memory.

    Clients upload a table with ``do_put``; it is stored under a key derived
    from its FlightDescriptor (see ``descriptor_to_key``). ``do_get`` returns
    the table for a ticket and removes it from the store, so each upload can
    be fetched once.
    """

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None, middleware=None):
        """Create the server.

        Args:
            host: Host name advertised in the endpoints built by
                ``_make_flight_info``.
            location: Location to serve on; forwarded to ``FlightServerBase``.
            tls_certificates: TLS certificates, forwarded to
                ``FlightServerBase``. When truthy, advertised endpoints use
                ``Location.for_grpc_tls`` instead of ``for_grpc_tcp``.
            verify_client, root_certificates, auth_handler, middleware:
                Forwarded unchanged to ``FlightServerBase``.
        """
        super().__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates, middleware)
        # Uploaded tables, keyed by descriptor_to_key(descriptor).
        self.flights = {}
        self.host = host
        self.tls_certificates = tls_certificates
        self.location = location

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Return a hashable key identifying *descriptor*.

        The key is ``(descriptor_type.value, command, path_tuple)``; a missing
        ``path`` is normalised to an empty tuple.
        """
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or ()))

    def _make_flight_info(self, key, descriptor, table):
        """Build a FlightInfo describing *table*, stored under *key*.

        The single endpoint's ticket is ``repr(key)``; ``do_get`` parses it
        back into the key with ``ast.literal_eval``.
        """
        if self.tls_certificates:
            location = pyarrow.flight.Location.for_grpc_tls(
                self.host, self.port)
        else:
            location = pyarrow.flight.Location.for_grpc_tcp(
                self.host, self.port)
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location])]

        # Write the table to a MockOutputStream purely to measure the size of
        # its IPC stream representation. The `with` block closes the writer
        # even if writing fails.
        mock_sink = pyarrow.MockOutputStream()
        with pyarrow.RecordBatchStreamWriter(
                mock_sink, table.schema) as stream_writer:
            stream_writer.write_table(table)
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(table.schema,
                                         descriptor, endpoints,
                                         table.num_rows, data_size)

    def list_flights(self, context, criteria):
        """Yield a FlightInfo for every stored table; *criteria* is ignored."""
        # Iterate over a snapshot. This generator is consumed lazily, and
        # other requests (do_put / do_get) may add or remove entries
        # meanwhile. Iterating the live dict could then raise "dictionary
        # changed size during iteration".
        for key, table in list(self.flights.items()):
            command = key[1]
            if command is not None:
                descriptor = \
                    pyarrow.flight.FlightDescriptor.for_command(command)
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

            yield self._make_flight_info(key, descriptor, table)

    def get_flight_info(self, context, descriptor):
        """Return the FlightInfo for the table stored under *descriptor*.

        Raises:
            KeyError: if no table is stored for *descriptor*.
        """
        key = self.descriptor_to_key(descriptor)
        logger.info(f"get_flight_info: key={key}")
        try:
            table = self.flights[key]
        except KeyError:
            raise KeyError('Flight not found.') from None
        return self._make_flight_info(key, descriptor, table)

    def do_put(self, context, descriptor, reader, writer):
        """Read the entire uploaded stream and store it under *descriptor*.

        A table already stored under the same key is replaced.
        """
        key = self.descriptor_to_key(descriptor)
        logger.info(f"do_put: key={key}")
        self.flights[key] = reader.read_all()

    def do_get(self, context, ticket):
        """Return the stored table for *ticket* and remove it from the store.

        The ticket bytes are expected to be ``repr(key)`` as produced by
        ``_make_flight_info``.

        Raises:
            KeyError: if the ticket is malformed or no table is stored for it.
        """
        logger.info(f"do_get: ticket={ticket}")
        # The ticket comes from the client. ast.literal_eval only evaluates
        # Python literals, so it cannot run code. Anything that is not a
        # valid literal (including non-UTF-8 bytes) is treated as not found.
        try:
            key = ast.literal_eval(ticket.ticket.decode())
        except (ValueError, SyntaxError):
            logger.warning(f"do_get: malformed ticket={ticket}")
            raise KeyError('Flight not found.') from None
        # pop() directly rather than check-then-pop. Another request may
        # remove the same key in between. A TypeError means the parsed
        # literal is unhashable and so cannot be a stored key.
        try:
            flight = self.flights.pop(key)
        except (KeyError, TypeError):
            logger.warning(f"do_get: key={key} not found")
            raise KeyError('Flight not found.') from None
        logger.info(f"do_get: returning key={key}")
        return pyarrow.flight.RecordBatchStream(flight)

    def list_actions(self, context):
        """Return an iterator of ``(type, description)`` for supported actions."""
        return iter([
            ("getUniquePath", "Get a unique FileDescriptor path to put data to."),
            ("clear", "Clear the stored flights."),
            ("healthcheck", "Check that the server is responding."),
            ("shutdown", "Shut down this server."),
        ])

    def do_action(self, context, action):
        """Run the action named by ``action.type`` and yield its results.

        Supported actions:
            getUniquePath: yields a fresh UUID4 string encoded as UTF-8.
            clear: drops all stored tables; yields nothing.
            healthcheck: does nothing; yields nothing.
            shutdown: drops all stored tables, yields ``b'Shutdown!'``, then
                stops the server from a background thread.

        Raises:
            KeyError: for an unknown action type.
        """
        logger.info(f"do_action: action={action.type}")
        if action.type == "getUniquePath":
            unique_id = str(uuid.uuid4())
            logger.info(f"getUniquePath id={unique_id}")
            yield unique_id.encode('utf-8')
        elif action.type == "clear":
            self._clear()
        elif action.type == "healthcheck":
            # No work needed: the request completing is the health signal.
            pass
        elif action.type == "shutdown":
            self._clear()
            yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
            # Shut down on a background thread to avoid blocking the current
            # request.
            threading.Thread(target=self._shutdown).start()
        else:
            raise KeyError("Unknown action {!r}".format(action.type))

    def _clear(self):
        """Clear the stored flights."""
        self.flights = {}

    def _shutdown(self):
        """Shut down after a 2-second delay."""
        logger.info("Server is shutting down...")
        time.sleep(2)
        self.shutdown()
142
143 1
def start(server):
    """Log the server's configured location, then serve (blocking).

    Args:
        server: A FlightServer (or any object exposing ``location`` and
            ``serve()``).
    """
    where = server.location
    logger.info(f"Serving on {where}")
    server.serve()
146
147
148
if __name__ == '__main__':
    # start() needs a server instance; calling it with no arguments raised
    # TypeError. With location=None, pyarrow's FlightServerBase is documented
    # to serve on localhost with a system-chosen port.
    start(FlightServer())