Passed
Push — master (fad680...3bdfcd)
by
unknown
13:18
created

tests.integration.test_arrow_server   A

Complexity

Total Complexity 8

Size/Duplication

Total Lines 60
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 8
eloc 51
dl 0
loc 60
rs 10
c 0
b 0
f 0

8 Methods

Rating   Name   Duplication   Size   Complexity  
A TestArrowServer.setUp() 0 3 1
A TestArrowServer.test_server_do_put() 0 4 1
A TestArrowServer.test_server_do_get() 0 8 1
A TestArrowServer.tearDownClass() 0 3 1
A TestArrowServer.get_descriptor() 0 2 1
A TestArrowServer.write_data() 0 7 1
A TestArrowServer.test_list_flights_on_new_server() 0 3 1
A TestArrowServer.setUpClass() 0 11 1
1
import unittest
2
import threading
3
import _thread
4
import pyarrow
5
import os
6
import pyarrow.csv as csv
7
8
from tabpy.tabpy_server.app.arrow_server import FlightServer
9
import tabpy.tabpy_server.app.arrow_server as pa
10
11
class TestArrowServer(unittest.TestCase):
    """Integration tests for the TabPy Arrow Flight server.

    One FlightServer is started on a background thread for the whole test
    case, and a single FlightClient connected to it is shared by all tests.
    """

    HOST = "localhost"
    PORT = 13620
    SCHEME = "grpc+tcp"
    # Seconds to wait for the server to become reachable / its thread to end.
    STARTUP_TIMEOUT = 5
    SHUTDOWN_TIMEOUT = 5

    @classmethod
    def setUpClass(cls):
        """Start the Flight server in the background and connect a client."""
        location = "{}://{}:{}".format(cls.SCHEME, cls.HOST, cls.PORT)
        cls.arrow_server = FlightServer(cls.HOST, location)
        # A threading.Thread (unlike _thread.start_new_thread) can be joined
        # in tearDownClass; daemon=True keeps a hung server from blocking
        # interpreter exit.
        cls.server_thread = threading.Thread(
            target=pa.start, args=(cls.arrow_server,), daemon=True
        )
        cls.server_thread.start()
        cls.arrow_client = pyarrow.flight.FlightClient(location)
        try:
            # Wait until the server answers instead of letting the first
            # test race server start-up.
            cls.arrow_client.wait_for_available(timeout=cls.STARTUP_TIMEOUT)
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so stop
            # the server here before propagating the failure.
            cls.arrow_server.shutdown()
            raise

    @classmethod
    def tearDownClass(cls):
        """Shut down the shared server and wait for its thread to finish."""
        cls.arrow_server.shutdown()
        # NOTE(review): assumes pa.start() blocks until shutdown() is called,
        # so the join returns promptly; the timeout bounds it either way.
        cls.server_thread.join(timeout=cls.SHUTDOWN_TIMEOUT)

    def setUp(self):
        """Locate test resources and empty the server's flight store."""
        self.resources_path = os.path.join(
            os.path.dirname(__file__), "resources"
        )
        self.data_path = os.path.join(self.resources_path, "data.csv")
        # The server is shared across tests; reset its state for isolation.
        self.arrow_server.flights = {}

    def get_descriptor(self, data_path):
        """Return a path-based FlightDescriptor for *data_path*."""
        return pyarrow.flight.FlightDescriptor.for_path(data_path)

    def write_data(self, data_path):
        """Upload the CSV at *data_path* to the server with ``do_put``.

        Returns the table read from the CSV (the data that was uploaded).
        """
        table = csv.read_csv(data_path)
        descriptor = self.get_descriptor(data_path)
        writer, _ = self.arrow_client.do_put(descriptor, table.schema)
        try:
            writer.write_table(table)
        finally:
            # Close the stream even if the write fails so it is not leaked.
            writer.close()
        return table

    def test_server_do_put(self):
        """A single do_put makes exactly one flight visible on the server."""
        self.write_data(self.data_path)
        flight_info = list(self.arrow_server.list_flights(None, None))
        self.assertEqual(len(flight_info), 1)

    def test_server_do_get(self):
        """do_get returns the uploaded table and removes it from the server."""
        table = self.write_data(self.data_path)
        descriptor = self.get_descriptor(self.data_path)
        self.assertEqual(len(self.arrow_server.flights), 1)
        info = self.arrow_client.get_flight_info(descriptor)
        reader = self.arrow_client.do_get(info.endpoints[0].ticket)
        self.assertTrue(reader.read_all().equals(table))
        # After being read, the flight is expected to be gone from the store.
        self.assertEqual(len(self.arrow_server.flights), 0)

    def test_list_flights_on_new_server(self):
        """A freshly reset server lists no flights."""
        flight_info = list(self.arrow_server.list_flights(None, None))
        self.assertEqual(len(flight_info), 0)
60