Test Failed
Pull Request — master (#593)
by
unknown
13:45
created

push_data()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 10
rs 9.9
c 0
b 0
f 0
cc 1
nop 3
1
# Licensed to the Apache Software Foundation (ASF) under one
2
# or more contributor license agreements.  See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership.  The ASF licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License.  You may obtain a copy of the License at
8
#
9
#   http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied.  See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17
18
"""An example Flight CLI client."""
19
20
import argparse
21
import sys
22
23
import pyarrow
24
import pyarrow.flight
25
import pyarrow.csv as csv
26
27
28
def list_flights(args, client, connection_args=None):
    """Print every flight and every action reported by ``client``.

    For each flight the descriptor (path or command), total records,
    total bytes, endpoint count and schema are printed.  Negative
    record/byte totals are shown as ``Unknown``.

    Args:
        args: Parsed command-line arguments.  Not used by this command;
            accepted so it can be called like the other sub-commands.
        client: Object providing ``list_flights()`` and ``list_actions()``.
        connection_args: Not used by this command.  Defaults to ``None``
            instead of a shared mutable ``{}``.
    """
    print('Flights\n=======')
    for flight in client.list_flights():
        descriptor = flight.descriptor
        kind = descriptor.descriptor_type
        if kind == pyarrow.flight.DescriptorType.PATH:
            print("Path:", descriptor.path)
        elif kind == pyarrow.flight.DescriptorType.CMD:
            print("Command:", descriptor.command)
        else:
            print("Unknown descriptor type")

        # Negative totals are reported as "Unknown" rather than printed.
        records = flight.total_records
        print("Total records:", records if records >= 0 else "Unknown")

        size = flight.total_bytes
        print("Total bytes:", size if size >= 0 else "Unknown")

        print("Number of endpoints:", len(flight.endpoints))
        print("Schema:")
        print(flight.schema)
        print('---')

    print('\nActions\n=======')
    for action in client.list_actions():
        print("Type:", action.type)
        print("Description:", action.description)
        print('---')
61
62
63
def do_action(args, client, connection_args=None):
    """Run the action named ``args.action_type`` and print each result.

    The action is sent with an empty body.  A ``pyarrow.lib.ArrowIOError``
    raised while building or running the action is reported on stdout
    and not re-raised.

    Args:
        args: Parsed command-line arguments; ``args.action_type`` is read.
        client: Flight client whose ``do_action`` is called.
        connection_args: Not used by this command.  Defaults to ``None``
            instead of a shared mutable ``{}``.
    """
    try:
        empty_body = pyarrow.allocate_buffer(0)
        action = pyarrow.flight.Action(args.action_type, empty_body)
        print('Running action', args.action_type)
        for result in client.do_action(action):
            print("Got result", result.body.to_pybytes())
    except pyarrow.lib.ArrowIOError as e:
        print("Error calling action:", e)
72
73
74
def push_data(args, client, connection_args=None):
    """Read the CSV file ``args.file`` and upload it with ``do_put``.

    The file name, row count and the first rows (via pandas) are printed
    before the upload.  The flight descriptor path is the file name.

    Args:
        args: Parsed command-line arguments; ``args.file`` is read.
        client: Flight client whose ``do_put`` is called.
        connection_args: Not used by this command.  Defaults to ``None``
            instead of a shared mutable ``{}``.
    """
    print('File Name:', args.file)
    my_table = csv.read_csv(args.file)
    print('Table rows=', str(len(my_table)))
    print(my_table.to_pandas().head())
    writer, _ = client.do_put(
        pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema)
    try:
        writer.write_table(my_table)
    finally:
        # Close the stream even when the write fails so it is not leaked.
        writer.close()
84
85
86
def upload_data(client, data, filename, metadata=None):
    """Convert ``data`` to an Arrow table and upload it with ``do_put``.

    Args:
        client: Flight client whose ``do_put`` is called.
        data: Input passed to ``pyarrow.table``; it must also provide
            ``head()``, which is printed before the upload.
        filename: Path used for the flight descriptor.
        metadata: Optional mapping of schema metadata to attach to the
            uploaded table.  Existing schema metadata is kept; entries
            in ``metadata`` are added on top.
    """
    my_table = pyarrow.table(data)
    if metadata is not None:
        # Arrow schemas and tables are immutable: the previous
        # ``schema.with_metadata(...)`` call returned a new schema that
        # was thrown away, so the metadata never reached the server.
        # Merge instead of replace so metadata already on the schema
        # (e.g. produced by the table conversion) is not lost.
        merged = dict(my_table.schema.metadata or {})
        merged.update(metadata)
        my_table = my_table.replace_schema_metadata(merged)
    print('Table rows=', str(len(my_table)))
    print("Uploading", data.head())
    writer, _ = client.do_put(
        pyarrow.flight.FlightDescriptor.for_path(filename), my_table.schema)
    try:
        writer.write_table(my_table)
    finally:
        # Close the stream even when the write fails so it is not leaked.
        writer.close()
96
97
98
def get_flight_by_path(path, client, connection_args=None):
    """Fetch the flight stored under ``path`` and return it as a DataFrame.

    The flight info is looked up on ``client``; the data is then read
    from the first location of the first endpoint that lists one, using
    a new ``FlightClient`` for that location.

    Args:
        path: Path used to build the flight descriptor.
        client: Flight client used for ``get_flight_info``.
        connection_args: Optional keyword arguments forwarded to the
            ``FlightClient`` created for the endpoint location.
            Defaults to ``None`` (no extra arguments) instead of a
            shared mutable ``{}``.

    Returns:
        The result of ``read_pandas()`` for the first reachable
        location, or the empty string ``''`` when no endpoint lists a
        location.
    """
    if connection_args is None:
        connection_args = {}

    descriptor = pyarrow.flight.FlightDescriptor.for_path(path)

    info = client.get_flight_info(descriptor)
    for endpoint in info.endpoints:
        print('Ticket:', endpoint.ticket)
        for location in endpoint.locations:
            print(location)
            get_client = pyarrow.flight.FlightClient(location,
                                                     **connection_args)
            reader = get_client.do_get(endpoint.ticket)
            df = reader.read_pandas()
            print(df)
            # Only the first location is read; the rest are ignored.
            return df
    print("no data found for get")
    return ''
114
115
def _add_common_arguments(parser):
116
    parser.add_argument('--tls', action='store_true',
117
                        help='Enable transport-level security')
118
    parser.add_argument('--tls-roots', default=None,
119
                        help='Path to trusted TLS certificate(s)')
120
    parser.add_argument("--mtls", nargs=2, default=None,
121
                        metavar=('CERTFILE', 'KEYFILE'),
122
                        help="Enable transport-level security")
123
    parser.add_argument('host', type=str,
124
                        help="Address or hostname to connect to")
125
126
127
def main():
    """Parse the command line, connect to the Flight server, run a command.

    Sub-commands: ``list``, ``do ACTION_TYPE``, ``put FILE`` and
    ``get (-p PATH | -c COMMAND)``.  Each takes the common connection
    options and a ``HOST:PORT`` argument.  Before the command runs, a
    ``healthcheck`` action is retried until it succeeds.

    Exits with status 1 after printing help when no sub-command is given.
    """
    parser = argparse.ArgumentParser()
    subcommands = parser.add_subparsers()

    cmd_list = subcommands.add_parser('list')
    cmd_list.set_defaults(action='list')
    _add_common_arguments(cmd_list)
    cmd_list.add_argument('-l', '--list', action='store_true',
                          help="Print more details.")

    cmd_do = subcommands.add_parser('do')
    cmd_do.set_defaults(action='do')
    _add_common_arguments(cmd_do)
    cmd_do.add_argument('action_type', type=str,
                        help="The action type to run.")

    cmd_put = subcommands.add_parser('put')
    cmd_put.set_defaults(action='put')
    _add_common_arguments(cmd_put)
    cmd_put.add_argument('file', type=str,
                         help="CSV file to upload.")

    cmd_get = subcommands.add_parser('get')
    cmd_get.set_defaults(action='get')
    _add_common_arguments(cmd_get)
    cmd_get_descriptor = cmd_get.add_mutually_exclusive_group(required=True)
    cmd_get_descriptor.add_argument('-p', '--path', type=str, action='append',
                                    help="The path for the descriptor.")
    cmd_get_descriptor.add_argument('-c', '--command', type=str,
                                    help="The command for the descriptor.")

    args = parser.parse_args()
    if not hasattr(args, 'action'):
        parser.print_help()
        sys.exit(1)

    # Handlers that take (args, client, connection_args).  'get' is
    # dispatched separately below because get_flight_by_path expects a
    # path string as its first argument, not the parsed arguments.
    commands = {
        'list': list_flights,
        'do': do_action,
        'put': push_data,
    }
    if args.action == 'get' and not args.path:
        # get_flight_by_path only builds path descriptors.
        parser.error("'get' only supports --path descriptors")

    # Split on the last colon only, so a host that itself contains
    # colons (e.g. a bracketed IPv6 address) keeps working, and report
    # a missing or non-numeric port as a usage error.
    host, sep, port = args.host.rpartition(':')
    try:
        port = int(port)
    except ValueError:
        port = None
    if not sep or port is None:
        parser.error(f"host must be given as HOST:PORT, got {args.host!r}")

    scheme = "grpc+tcp"
    connection_args = {}
    # Mutual TLS is still TLS: a client certificate is of no use on a
    # plaintext connection, so --mtls selects the TLS scheme as well.
    if args.tls or args.mtls:
        scheme = "grpc+tls"
        if args.tls_roots:
            with open(args.tls_roots, "rb") as root_certs:
                connection_args["tls_root_certs"] = root_certs.read()
    if args.mtls:
        with open(args.mtls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(args.mtls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        connection_args["cert_chain"] = tls_cert_chain
        connection_args["private_key"] = tls_private_key
    client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}",
                                         **connection_args)

    # Retry the healthcheck until it succeeds.  Every ArrowIOError
    # triggers another attempt; only deadline errors print a message.
    while True:
        try:
            action = pyarrow.flight.Action("healthcheck", b"")
            options = pyarrow.flight.FlightCallOptions(timeout=1)
            list(client.do_action(action, options=options))
            break
        except pyarrow.ArrowIOError as e:
            if "Deadline" in str(e):
                print("Server is not ready, waiting...")

    if args.action == 'get':
        for path in args.path:
            get_flight_by_path(path, client, connection_args)
    else:
        commands[args.action](args, client, connection_args)
197
198
199
200
# Run the command-line client only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
202