Passed
Push — master ( fad680...3bdfcd )
by
unknown
13:18
created

TestArrowServer.test_server_do_get()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import _thread
import os
import threading
import unittest

import pyarrow
import pyarrow.csv as csv
import pyarrow.flight

import tabpy.tabpy_server.app.arrow_server as pa
from tabpy.tabpy_server.app.arrow_server import FlightServer
10
11
class TestArrowServer(unittest.TestCase):
    """Integration tests for the TabPy Arrow Flight server.

    One ``FlightServer`` is started on a background thread for the whole
    class and exercised through a ``pyarrow.flight.FlightClient``.
    """

    # Endpoint shared by the server and the client; subclasses may override.
    HOST = "localhost"
    PORT = 13620
    SCHEME = "grpc+tcp"
    # CSV fixture (under ./resources) uploaded by the put/get tests.
    DATA_FILE = "data.csv"
    # Upper bound, in seconds, on waiting for the server thread at teardown.
    SHUTDOWN_TIMEOUT = 5

    @classmethod
    def setUpClass(cls):
        """Start the Flight server on a daemon thread and create a client."""
        location = "{}://{}:{}".format(cls.SCHEME, cls.HOST, cls.PORT)
        cls.arrow_server = FlightServer(cls.HOST, location)
        # threading.Thread (rather than the low-level _thread module) gives a
        # handle we can join on teardown; daemon=True keeps the previous
        # behaviour of not blocking interpreter exit.
        cls.server_thread = threading.Thread(
            target=pa.start, args=(cls.arrow_server,), daemon=True
        )
        cls.server_thread.start()
        cls.arrow_client = pyarrow.flight.FlightClient(location)

    @classmethod
    def tearDownClass(cls):
        """Shut the server down and give its thread a bounded time to exit."""
        cls.arrow_server.shutdown()
        # NOTE(review): assumes pa.start returns once the server is shut
        # down; the timeout bounds the wait if it does not.
        cls.server_thread.join(timeout=cls.SHUTDOWN_TIMEOUT)

    def setUp(self):
        """Reset stored flights so every test starts from an empty server."""
        self.resources_path = os.path.join(os.path.dirname(__file__), "resources")
        self.data_path = os.path.join(self.resources_path, self.DATA_FILE)
        self.arrow_server.flights = {}

    def get_descriptor(self, data_path):
        """Return a path-based FlightDescriptor identifying ``data_path``."""
        return pyarrow.flight.FlightDescriptor.for_path(data_path)

    def write_data(self, data_path):
        """Upload the CSV at ``data_path`` to the server with ``do_put``.

        Returns the ``pyarrow`` table read from the file, so callers can
        compare it with what the server later returns.
        """
        table = csv.read_csv(data_path)
        descriptor = self.get_descriptor(data_path)
        writer, _ = self.arrow_client.do_put(descriptor, table.schema)
        # The context manager closes the writer even if write_table raises,
        # so a failed upload does not leave the stream open.
        with writer:
            writer.write_table(table)
        return table

    def test_server_do_put(self):
        """One do_put results in exactly one listed flight."""
        self.write_data(self.data_path)
        flight_info = list(self.arrow_server.list_flights(None, None))
        self.assertEqual(len(flight_info), 1)

    def test_server_do_get(self):
        """do_get returns the uploaded table and the flight is then gone."""
        table = self.write_data(self.data_path)
        descriptor = self.get_descriptor(self.data_path)
        self.assertEqual(len(self.arrow_server.flights), 1)
        info = self.arrow_client.get_flight_info(descriptor)
        reader = self.arrow_client.do_get(info.endpoints[0].ticket)
        self.assertTrue(reader.read_all().equals(table))
        # The server is expected to discard a flight once it has been read.
        self.assertEqual(len(self.arrow_server.flights), 0)

    def test_list_flights_on_new_server(self):
        """With nothing uploaded, the server lists no flights."""
        flight_info = list(self.arrow_server.list_flights(None, None))
        self.assertEqual(len(flight_info), 0)
60