1
|
|
|
import unittest |
2
|
|
|
import threading |
3
|
|
|
import _thread |
4
|
|
|
import pyarrow |
5
|
|
|
import os |
6
|
|
|
import pyarrow.csv as csv |
7
|
|
|
|
8
|
|
|
from tabpy.tabpy_server.app.arrow_server import FlightServer |
9
|
|
|
import tabpy.tabpy_server.app.arrow_server as pa |
10
|
|
|
|
11
|
|
|
class TestArrowServer(unittest.TestCase):
    """Tests for ``FlightServer`` from ``tabpy.tabpy_server.app.arrow_server``.

    A single server instance is started on a background thread for the
    whole class and is exercised through a ``pyarrow.flight.FlightClient``
    connected to the same location.
    """

    @classmethod
    def setUpClass(cls):
        """Start one Flight server and create one client shared by all tests."""
        host = "localhost"
        port = 13620
        scheme = "grpc+tcp"
        location = "{}://{}:{}".format(scheme, host, port)
        cls.arrow_server = FlightServer(host, location)
        # Run the server off the main thread so the tests can proceed.
        # threading.Thread (instead of the low-level _thread API) keeps a
        # handle that tearDownClass can join; daemon=True means a server
        # that fails to stop cannot keep the interpreter alive.
        cls.server_thread = threading.Thread(
            target=pa.start, args=(cls.arrow_server,), daemon=True
        )
        cls.server_thread.start()
        cls.arrow_client = pyarrow.flight.FlightClient(location)

    @classmethod
    def tearDownClass(cls):
        """Shut the shared server down and wait for its thread to finish."""
        cls.arrow_server.shutdown()
        # Bounded wait: never hang the test run if the thread does not exit.
        cls.server_thread.join(timeout=5)

    def setUp(self):
        """Locate the test resources and reset the server's flight registry."""
        self.resources_path = os.path.join(os.path.dirname(__file__), "resources")
        # The server is shared across tests, so clear any flights left
        # behind by a previous test.
        self.arrow_server.flights = {}

    def get_descriptor(self, data_path):
        """Return a path-based ``FlightDescriptor`` for ``data_path``."""
        return pyarrow.flight.FlightDescriptor.for_path(data_path)

    def write_data(self, data_path):
        """Read the CSV at ``data_path`` and upload it with ``do_put``.

        Returns the ``pyarrow`` table that was read from the file and
        written to the server.
        """
        table = csv.read_csv(data_path)
        descriptor = self.get_descriptor(data_path)
        writer, _ = self.arrow_client.do_put(descriptor, table.schema)
        writer.write_table(table)
        writer.close()
        return table

    def test_server_do_put(self):
        """Uploading one table makes exactly one flight listed by the server."""
        self.write_data(os.path.join(self.resources_path, "data.csv"))
        flight_info = list(self.arrow_server.list_flights(None, None))
        self.assertEqual(len(flight_info), 1)

    def test_server_do_get(self):
        """A stored table is returned intact and is gone from the server after."""
        data_path = os.path.join(self.resources_path, "data.csv")
        table = self.write_data(data_path)
        descriptor = self.get_descriptor(data_path)
        self.assertEqual(len(self.arrow_server.flights), 1)
        info = self.arrow_client.get_flight_info(descriptor)
        reader = self.arrow_client.do_get(info.endpoints[0].ticket)
        self.assertTrue(reader.read_all().equals(table))
        # The test expects the flight to be removed once it has been read.
        self.assertEqual(len(self.arrow_server.flights), 0)

    def test_list_flights_on_new_server(self):
        """With no uploads, the server lists no flights."""
        flight_info = list(self.arrow_server.list_flights(None, None))
        self.assertEqual(len(flight_info), 0)
60
|
|
|
|