EvaluationPlaneHandler.get_arrow_data()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 11
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 9.561

Importance

Changes 0
Metric Value
eloc 11
dl 0
loc 11
ccs 1
cts 10
cp 0.1
rs 9.85
c 0
b 0
f 0
cc 3
nop 2
crap 9.561
1 1
import pandas
2 1
import pyarrow
3 1
import uuid
4
5 1
from tabpy.tabpy_server.handlers import BaseHandler
6 1
import json
7 1
import simplejson
8 1
import logging
9 1
from tabpy.tabpy_server.common.util import format_exception
10 1
import requests
11 1
from tornado import gen
12 1
from datetime import timedelta
13 1
from tabpy.tabpy_server.handlers.util import AuthErrorStates
14
15 1
class RestrictedTabPy:
    """Minimal ``tabpy`` object handed to user scripts as their first argument.

    It only exposes :meth:`query`, which forwards a call to a deployed
    endpoint on the TabPy server listening on ``localhost``.
    """

    def __init__(self, protocol, port, logger, timeout, headers):
        # Connection settings for the local server, the logger used for debug
        # output, the per-request timeout, and the headers replayed on every query.
        self.protocol, self.port = protocol, port
        self.logger = logger
        self.timeout = timeout
        self.headers = headers

    def query(self, name, *args, **kwargs):
        """POST the call arguments to ``/query/<name>`` and return the decoded JSON reply.

        Positional arguments take precedence: keyword arguments are only sent
        when no positional ones were given. The payload is ``{"data": ...}``.
        """
        target = f"{self.protocol}://localhost:{self.port}/query/{name}"
        self.logger.log(logging.DEBUG, f"Querying {target}...")
        payload = json.dumps({"data": args if args else kwargs})
        # NOTE(review): verify=False disables TLS certificate verification;
        # presumably tolerated because the target is always localhost — confirm.
        reply = requests.post(
            url=target,
            data=payload,
            headers=self.headers,
            timeout=self.timeout,
            verify=False,
        )
        return reply.json()
33
34
35 1
class EvaluationPlaneDisabledHandler(BaseHandler):
    """
    Answers every evaluation request with an error stating that ad-hoc
    scripts are turned off for this analytics extension.
    """

    def initialize(self, executor, app):
        super().initialize(app)
        # Stored to keep the same initialize() signature as the enabled handler;
        # post() below never uses it.
        self.executor = executor

    @gen.coroutine
    def post(self):
        # Authentication failures are reported before anything else; an
        # oversized body is left to request_body_size_within_limit() to handle.
        if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
            self.fail_with_auth_error()
        elif self.request_body_size_within_limit():
            self.error_out(
                404,
                "Ad-hoc scripts have been disabled on this analytics extension, "
                "please contact your administrator.",
            )
55
56
57 1
class EvaluationPlaneHandler(BaseHandler):
    """
    EvaluationPlaneHandler is responsible for running arbitrary python scripts.

    The POST body is JSON with a ``script`` string and, optionally, either a
    ``data`` dictionary (legacy scenario) or a ``dataPath`` naming a table held
    by ``app.arrow_server`` (arrow flight scenario). Script arguments must be
    named ``_arg1`` .. ``_argN``; the script itself is wrapped into a function
    ``_user_script(tabpy, _arg1, ...)`` and executed on ``self.executor``.
    """

    def initialize(self, executor, app):
        """Store the script executor and the (possibly ``None``) arrow server.

        ``self.eval_timeout`` is read here, so it is presumably populated by
        ``BaseHandler.initialize`` — confirm if that base class changes.
        """
        super(EvaluationPlaneHandler, self).initialize(app)
        self.arrow_server = app.arrow_server
        self.executor = executor
        self._error_message_timeout = (
            f"User defined script timed out. "
            f"Timeout is set to {self.eval_timeout} s."
        )

    @gen.coroutine
    def _post_impl(self):
        """Validate the request, run the user script and write its result.

        Responds 400 for a missing/empty script or malformed arguments, 408 when
        the script times out; otherwise writes the JSON-encoded result (or
        ``null``) and finishes the request.
        """
        body = json.loads(self.request.body.decode("utf-8"))
        self.logger.log(logging.DEBUG, "Processing POST request...")
        # An empty script would become a function without a body (SyntaxError at
        # exec time -> 500), so reject it with the same 400 as a missing one.
        if "script" not in body or not body["script"]:
            self.error_out(400, "Script is empty.")
            return

        user_code = body["script"]
        use_arrow = self.arrow_server is not None and "dataPath" in body

        arguments = None
        if use_arrow:
            # arrow flight scenario
            arrow_data = self.get_arrow_data(body["dataPath"])
            if arrow_data is not None:
                arguments = {"_arg1": arrow_data}
        elif "data" in body:
            # legacy scenario
            arguments = body["data"]

        arguments_str = ""
        if arguments is not None:
            if not isinstance(arguments, dict):
                self.error_out(
                    400, "Script parameters need to be provided as a dictionary."
                )
                return
            args_in = sorted(arguments.keys())
            expected = sorted(f"_arg{i}" for i in range(1, len(arguments) + 1))
            if args_in != expected:
                self.error_out(
                    400,
                    "Variables names should follow "
                    "the format _arg1, _arg2, _argN",
                )
                return
            arguments_str = "".join(f", {name}" for name in args_in)

        # Transforming user script into a proper function: each script line is
        # indented by one space to become the body of _user_script.
        header = f"def _user_script(tabpy{arguments_str}):\n"
        script_body = "".join(f" {line}\n" for line in user_code.splitlines())
        function_to_evaluate = header + script_body

        self.logger.log(
            logging.INFO, f"function to evaluate={function_to_evaluate}"
        )

        try:
            result = yield self._call_subprocess(function_to_evaluate, arguments)
        except (
            gen.TimeoutError,
            requests.exceptions.ConnectTimeout,
            requests.exceptions.ReadTimeout,
        ):
            self.logger.log(logging.ERROR, self._error_message_timeout)
            self.error_out(408, self._error_message_timeout)
            return

        if result is None:
            self.write("null")
        else:
            if use_arrow:
                # arrow flight scenario: store the result on the flight server and
                # reply with the id it was stored under instead of the data itself.
                output_data_id = str(uuid.uuid4())
                self.upload_arrow_data(result, output_data_id, {
                    'removeOnDelete': 'True',
                    'linkedIDs': body["dataPath"]
                })
                result = {'outputDataPath': output_data_id}
                self.logger.log(logging.WARNING, f'outputDataPath={output_data_id}')
            elif isinstance(result, pandas.DataFrame):
                result = result.to_dict(orient='list')
            self.write(simplejson.dumps(result, ignore_nan=True))
        self.finish()

    @staticmethod
    def _flight_key(descriptor):
        """Build the key under which tables are stored in ``arrow_server.flights``."""
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or tuple()))

    def get_arrow_data(self, filename):
        """Remove and return the table stored for ``filename`` as a DataFrame.

        The table is popped from ``self.arrow_server.flights``, so it can be
        consumed only once. Returns ``None`` (and logs) when the flight info has
        no endpoint location or no table is stored under the path; callers test
        for ``None`` to decide whether data was supplied.
        """
        # NOTE(review): pyarrow.flight is reached as an attribute of pyarrow;
        # presumably the submodule was imported by whatever built arrow_server.
        descriptor = pyarrow.flight.FlightDescriptor.for_path(filename)
        info = self.arrow_server.get_flight_info(None, descriptor)
        # Only the presence of a location matters: the table lives in-process.
        if any(endpoint.locations for endpoint in info.endpoints):
            table = self.arrow_server.flights.pop(self._flight_key(descriptor), None)
            if table is not None:
                return table.to_pandas()
        self.logger.log(logging.INFO, f'no data found for {filename}')
        return None

    def upload_arrow_data(self, data, filename, metadata):
        """Store ``data`` as an Arrow table in ``arrow_server.flights`` under ``filename``.

        ``metadata`` (str -> str), when given, is merged into the table's schema
        metadata; existing entries (e.g. pandas metadata) are kept.
        """
        table = pyarrow.table(data)
        if metadata is not None:
            # Schema.with_metadata() returns a new schema and leaves the table
            # untouched, so the table itself must be replaced to keep metadata.
            merged = dict(table.schema.metadata or {})
            merged.update(metadata)
            table = table.replace_schema_metadata(merged)
        descriptor = pyarrow.flight.FlightDescriptor.for_path(filename)
        self.arrow_server.flights[self._flight_key(descriptor)] = table

    @gen.coroutine
    def post(self):
        """Handle an evaluation request after auth and body-size checks.

        Any exception from processing is logged with its traceback. A
        ``KeyError('response')`` (presumably a script reading ``['response']``
        from a ``tabpy.query`` reply) maps to 404; everything else to 500.
        """
        if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
            self.fail_with_auth_error()
            return

        if not self.request_body_size_within_limit():
            return

        self._add_CORS_header()
        try:
            yield self._post_impl()
        except Exception as e:
            import traceback
            self.logger.log(logging.ERROR, traceback.format_exc())
            err_msg = f"{e.__class__.__name__} : {str(e)}"
            if err_msg != "KeyError : 'response'":
                err_msg = format_exception(e, "POST /evaluate")
                self.error_out(500, "Error processing script", info=err_msg)
            else:
                self.error_out(
                    404,
                    "Error processing script",
                    info="The endpoint you're "
                    "trying to query did not respond. Please make sure the "
                    "endpoint exists and the correct set of arguments are "
                    "provided.",
                )

    @gen.coroutine
    def _call_subprocess(self, function_to_evaluate, arguments):
        """Define ``_user_script`` from source and run it on ``self.executor``.

        ``arguments`` may be ``None`` (script takes only ``tabpy``) or a dict of
        ``_argN`` keyword arguments. Raises ``gen.TimeoutError`` when the result
        is not ready within ``self.eval_timeout`` seconds; per Tornado docs the
        timeout stops waiting but does not cancel the submitted work.
        """
        restricted_tabpy = RestrictedTabPy(
            self.protocol, self.port, self.logger, self.eval_timeout, self.request.headers
        )
        # SECURITY: executes client-supplied code by design; deployments that
        # must not run ad-hoc scripts use EvaluationPlaneDisabledHandler.
        # exec() only defines the function (it does not run it, so it does not
        # block). It rebinds the module-level name _user_script; there is no
        # yield between exec and submit, so another request on the same event
        # loop cannot swap the function in between.
        exec(function_to_evaluate, globals())

        # 'noqa' tells flake8 to ignore the undefined _user_script name - it is
        # created by the exec() call above.
        future = self.executor.submit(
            _user_script,  # noqa: F821
            restricted_tabpy,
            # Unpacking None raised TypeError; no data means no keyword args.
            **(arguments if arguments is not None else {}),
        )

        ret = yield gen.with_timeout(timedelta(seconds=self.eval_timeout), future)
        raise gen.Return(ret)
214