|
1
|
|
|
import unittest |
|
2
|
|
|
import threading |
|
3
|
|
|
import _thread |
|
4
|
|
|
import pyarrow |
|
5
|
|
|
import os |
|
6
|
|
|
import pyarrow.csv as csv |
|
7
|
|
|
|
|
8
|
|
|
from tabpy.tabpy_server.app.arrow_server import FlightServer |
|
9
|
|
|
import tabpy.tabpy_server.app.arrow_server as pa |
|
10
|
|
|
|
|
11
|
|
|
class TestArrowServer(unittest.TestCase):
    """Tests for the TabPy Arrow Flight server.

    One ``FlightServer`` is created for the whole test class and driven on a
    background thread. The tests talk to it through a
    ``pyarrow.flight.FlightClient`` and also inspect the server's ``flights``
    attribute directly.
    """

    @classmethod
    def setUpClass(cls):
        """Create the Flight server, run it on a background thread, connect a client."""
        host = "localhost"
        port = 13620
        scheme = "grpc+tcp"
        location = "{}://{}:{}".format(scheme, host, port)

        cls.arrow_server = FlightServer(host, location)
        # pa.start is run off the main thread, presumably because it blocks
        # while serving. Use a joinable threading.Thread (instead of the
        # low-level _thread API) so tearDownClass can wait for it to finish.
        # daemon=True keeps the old behavior of not blocking interpreter exit.
        cls.server_thread = threading.Thread(
            target=pa.start, args=(cls.arrow_server,), daemon=True
        )
        cls.server_thread.start()
        cls.arrow_client = pyarrow.flight.FlightClient(location)

    @classmethod
    def tearDownClass(cls):
        """Release the client, shut the server down and wait for its thread."""
        # NOTE(review): older pyarrow releases may not provide
        # FlightClient.close(), so only call it when it exists.
        close_client = getattr(cls.arrow_client, "close", None)
        if close_client is not None:
            close_client()
        cls.arrow_server.shutdown()
        # Bounded wait so a server that fails to stop cannot hang the suite.
        cls.server_thread.join(timeout=10)

    def setUp(self):
        """Point at the test resources directory and reset the server's flights."""
        self.resources_path = os.path.join(os.path.dirname(__file__), "resources")
        # The server is shared by every test, so clear the state earlier tests left.
        self.arrow_server.flights = {}

    def get_descriptor(self, data_path):
        """Return a path-based FlightDescriptor for ``data_path``."""
        return pyarrow.flight.FlightDescriptor.for_path(data_path)

    def write_data(self, data_path):
        """Read the CSV at ``data_path`` and upload it to the server via do_put.

        Returns the pyarrow table that was read, so callers can compare it
        with what the server returns later.
        """
        table = csv.read_csv(data_path)
        descriptor = self.get_descriptor(data_path)
        writer, _ = self.arrow_client.do_put(descriptor, table.schema)
        writer.write_table(table)
        writer.close()
        return table

    def test_server_do_put(self):
        """An upload via do_put shows up as exactly one listed flight."""
        self.write_data(os.path.join(self.resources_path, "data.csv"))
        flight_info = list(self.arrow_server.list_flights(None, None))
        self.assertEqual(len(flight_info), 1)

    def test_server_do_get(self):
        """do_get returns the uploaded table and removes it from the server."""
        data_path = os.path.join(self.resources_path, "data.csv")
        table = self.write_data(data_path)
        descriptor = self.get_descriptor(data_path)
        self.assertEqual(len(self.arrow_server.flights), 1)

        info = self.arrow_client.get_flight_info(descriptor)
        reader = self.arrow_client.do_get(info.endpoints[0].ticket)
        self.assertTrue(reader.read_all().equals(table))
        # After the read, the server's flights mapping is expected to be empty.
        self.assertEqual(len(self.arrow_server.flights), 0)

    def test_list_flights_on_new_server(self):
        """With no uploads, list_flights yields nothing."""
        flight_info = list(self.arrow_server.list_flights(None, None))
        self.assertEqual(len(flight_info), 0)
|
60
|
|
|
|