Passed
Push — master ( fad680...3bdfcd )
by
unknown
13:18
created

EvaluationPlaneHandler.initialize()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 5
CRAP Score 1

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 6
ccs 5
cts 5
cp 1
rs 10
c 0
b 0
f 0
cc 1
nop 3
crap 1
1 1
import pandas
2 1
import pyarrow
3 1
import uuid
4
5 1
from tabpy.tabpy_server.handlers import BaseHandler
6 1
import json
7 1
import simplejson
8 1
import logging
9 1
from tabpy.tabpy_server.common.util import format_exception
10 1
import requests
11 1
from tornado import gen
12 1
from datetime import timedelta
13 1
from tabpy.tabpy_server.handlers.util import AuthErrorStates
14
15 1
class RestrictedTabPy:
    """
    RestrictedTabPy is the small client object passed to user scripts as
    ``tabpy``; it lets a script query an endpoint on localhost over HTTP.
    """

    def __init__(self, protocol, port, logger, timeout, headers):
        # Connection settings and headers reused by every query() call.
        self.protocol, self.port = protocol, port
        self.logger = logger
        self.timeout = timeout
        self.headers = headers

    def query(self, name, *args, **kwargs):
        """
        POST the call arguments to ``/query/<name>`` on localhost and return
        the decoded JSON body of the response.

        Positional arguments win: keyword arguments are only sent when no
        positional arguments were given.
        """
        endpoint = f"{self.protocol}://localhost:{self.port}/query/{name}"
        self.logger.log(logging.DEBUG, f"Querying {endpoint}...")
        payload = json.dumps({"data": args or kwargs})
        # verify=False: certificate verification is skipped for this
        # loop-back request.
        reply = requests.post(
            url=endpoint,
            data=payload,
            headers=self.headers,
            timeout=self.timeout,
            verify=False,
        )
        return reply.json()
33
34
35 1
class EvaluationPlaneDisabledHandler(BaseHandler):
    """
    Handler used when ad-hoc scripts have been disabled: every authorized
    POST is answered with a 404 error message instead of being evaluated.
    """

    def initialize(self, executor, app):
        super().initialize(app)
        self.executor = executor

    @gen.coroutine
    def post(self):
        # Authorized callers get the "disabled" message; everyone else gets
        # the regular authentication failure response.
        if self.should_fail_with_auth_error() == AuthErrorStates.NONE:
            message = (
                "Ad-hoc scripts have been disabled on this analytics extension, please contact your "
                "administrator."
            )
            self.error_out(404, message)
            return
        self.fail_with_auth_error()
51
52
53 1
class EvaluationPlaneHandler(BaseHandler):
    """
    EvaluationPlaneHandler is responsible for running arbitrary python scripts.

    The script text from the POST body is wrapped into a ``_user_script``
    function, submitted to the executor and its result written back as JSON
    (or, in the arrow flight scenario, stored on the arrow server and
    referenced by an ``outputDataPath`` id).
    """

    def initialize(self, executor, app):
        super(EvaluationPlaneHandler, self).initialize(app)
        self.arrow_server = app.arrow_server
        self.executor = executor
        self._error_message_timeout = (
            f"User defined script timed out. "
            f"Timeout is set to {self.eval_timeout} s."
        )

    @gen.coroutine
    def _post_impl(self):
        """
        Parse the request body, validate the script arguments, run the
        script and write the response.

        Responds with 400 when the script is missing or the arguments are
        malformed, and with 408 when the script does not finish within
        ``self.eval_timeout`` seconds.
        """
        body = json.loads(self.request.body.decode("utf-8"))
        self.logger.log(logging.DEBUG, "Processing POST request...")
        if "script" not in body:
            self.error_out(400, "Script is empty.")
            return

        # Transforming user script into a proper function.
        user_code = body["script"]
        arguments = None
        arguments_str = ""
        uses_arrow = self.arrow_server is not None and "dataPath" in body
        if uses_arrow:
            # arrow flight scenario
            arrow_data = self.get_arrow_data(body["dataPath"])
            if arrow_data is not None:
                arguments = {"_arg1": arrow_data}
        elif "data" in body:
            # legacy scenario
            arguments = body["data"]

        if arguments is not None:
            if not isinstance(arguments, dict):
                self.error_out(
                    400, "Script parameters need to be provided as a dictionary."
                )
                return
            # The argument names must be exactly _arg1 .. _argN; both lists
            # are sorted as strings so they compare in the same order.
            args_in = sorted(arguments.keys())
            n = len(arguments)
            if sorted('_arg' + str(i + 1) for i in range(n)) == args_in:
                arguments_str = ", " + ", ".join(args_in)
            else:
                self.error_out(
                    400,
                    "Variables names should follow "
                    "the format _arg1, _arg2, _argN",
                )
                return

        # Every script line is indented by one space so that it becomes the
        # body of the generated function.
        function_to_evaluate = (
            f"def _user_script(tabpy{arguments_str}):\n"
            + "".join(" " + line + "\n" for line in user_code.splitlines())
        )

        self.logger.log(
            logging.INFO, f"function to evaluate={function_to_evaluate}"
        )

        try:
            result = yield self._call_subprocess(function_to_evaluate, arguments)
        except (
            gen.TimeoutError,
            requests.exceptions.ConnectTimeout,
            requests.exceptions.ReadTimeout,
        ):
            self.logger.log(logging.ERROR, self._error_message_timeout)
            self.error_out(408, self._error_message_timeout)
            return

        if result is not None:
            if uses_arrow:
                # arrow flight scenario: the result is stored on the arrow
                # server and only its id is returned to the client.
                output_data_id = str(uuid.uuid4())
                self.upload_arrow_data(result, output_data_id, {
                    'removeOnDelete': 'True',
                    'linkedIDs': body["dataPath"]
                })
                result = {'outputDataPath': output_data_id}
                self.logger.log(logging.WARN, f'outputDataPath={output_data_id}')
            else:
                if isinstance(result, pandas.DataFrame):
                    result = result.to_dict(orient='list')
            self.write(simplejson.dumps(result, ignore_nan=True))
        else:
            self.write("null")
        self.finish()

    def get_arrow_data(self, filename):
        """
        Remove the flight stored under ``filename`` from the arrow server
        and return it as a pandas DataFrame.

        Returns an empty string when the flight info lists no endpoint
        location for ``filename``.
        """
        descriptor = pyarrow.flight.FlightDescriptor.for_path(filename)
        info = self.arrow_server.get_flight_info(None, descriptor)
        for endpoint in info.endpoints:
            for location in endpoint.locations:
                # The first location found is enough: the data is looked up
                # by the descriptor key, not by the location itself.
                key = (descriptor.descriptor_type.value, descriptor.command,
                       tuple(descriptor.path or tuple()))
                df = self.arrow_server.flights.pop(key).to_pandas()
                return df
        self.logger.log(logging.INFO, f'no data found for {filename}')
        return ''

    def upload_arrow_data(self, data, filename, metadata):
        """
        Convert ``data`` to an arrow table and store it on the arrow server
        under the path ``filename``.

        When ``metadata`` is not None it is attached to the table's schema.
        """
        my_table = pyarrow.table(data)
        if metadata is not None:
            # Tables and schemas are immutable: the table carrying the
            # metadata is a new object and has to be kept.
            my_table = my_table.replace_schema_metadata(metadata)
        descriptor = pyarrow.flight.FlightDescriptor.for_path(filename)
        key = (descriptor.descriptor_type.value, descriptor.command,
               tuple(descriptor.path or tuple()))
        self.arrow_server.flights[key] = my_table

    @gen.coroutine
    def post(self):
        """
        Entry point for POST requests: check authentication, then delegate
        to ``_post_impl`` and turn any exception into an error response.
        """
        if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
            self.fail_with_auth_error()
            return

        self._add_CORS_header()
        try:
            yield self._post_impl()
        except Exception as e:
            import traceback
            self.logger.log(logging.ERROR, traceback.format_exc())
            err_msg = f"{e.__class__.__name__} : {str(e)}"
            if err_msg != "KeyError : 'response'":
                err_msg = format_exception(e, "POST /evaluate")
                self.error_out(500, "Error processing script", info=err_msg)
            else:
                # A missing 'response' key is reported as an endpoint that
                # did not answer rather than as a script failure.
                self.error_out(
                    404,
                    "Error processing script",
                    info="The endpoint you're "
                    "trying to query did not respond. Please make sure the "
                    "endpoint exists and the correct set of arguments are "
                    "provided.",
                )

    @gen.coroutine
    def _call_subprocess(self, function_to_evaluate, arguments):
        """
        Define the generated ``_user_script`` function and run it on the
        executor, waiting at most ``self.eval_timeout`` seconds.

        ``arguments`` is the dict of ``_argN`` values passed to the script
        as keyword arguments, or None when the request carried none.
        Raises ``gen.TimeoutError`` when the script does not finish in time.
        """
        restricted_tabpy = RestrictedTabPy(
            self.protocol, self.port, self.logger, self.eval_timeout, self.request.headers
        )
        # Exec does not run the function, so it does not block.
        # NOTE(review): this executes client-supplied code and defines
        # _user_script in this module's globals - access to this handler
        # must stay restricted to trusted callers.
        exec(function_to_evaluate, globals())

        # A script without arguments must be called with no keyword
        # arguments at all; unpacking None with ** raises TypeError.
        script_kwargs = arguments if arguments is not None else {}

        # 'noqa' comments below tell flake8 to ignore undefined _user_script
        # name - the name is actually defined with user script being wrapped
        # in _user_script function (constructed as a string) and then executed
        # with exec() call above.
        future = self.executor.submit(_user_script,  # noqa: F821
                                      restricted_tabpy,
                                      **script_kwargs)

        ret = yield gen.with_timeout(timedelta(seconds=self.eval_timeout), future)
        raise gen.Return(ret)
207