Test Failed
Pull Request — master (#566)
by
unknown
03:46
created

tabpy.tabpy_server.handlers.query_plane_handler.QueryPlaneHandler.post()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 4.6796

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 11
ccs 1
cts 8
cp 0.125
rs 9.95
c 0
b 0
f 0
cc 2
nop 2
crap 4.6796
1 1
from tabpy.tabpy_server.handlers import BaseHandler
2 1
import logging
3 1
import time
4 1
from tabpy.tabpy_server.common.messages import (
5
    Query,
6
    QuerySuccessful,
7
    QueryError,
8
    UnknownURI,
9
)
10 1
from hashlib import md5
11 1
import uuid
12 1
import json
13 1
from tabpy.tabpy_server.common.util import format_exception
14 1
import urllib
15 1
from tornado import gen
16 1
from tabpy.tabpy_server.handlers.util import AuthErrorStates
17
18
19 1
def _get_uuid():
20
    """Generate a unique identifier string"""
21
    return str(uuid.uuid4())
22
23
24 1
class QueryPlaneHandler(BaseHandler):
    """Request handler that evaluates a deployed query object.

    GET and POST requests resolve the endpoint name from the URL (following
    aliases), pass the sanitized request payload to the Python service and
    write the outcome back to the client. OPTIONS requests only emit the
    CORS headers.
    """

    def initialize(self, app):
        """Forward handler initialization to ``BaseHandler``."""
        super().initialize(app)

    def _query(self, po_name, data, uid, qry):
        """
        Run a query object through the Python service and time the call.

        Parameters
        ----------
        po_name : str
            The name of the query object to query

        data : dict
            The deserialized request body

        uid: str
            A unique identifier for the request

        qry: str
            The incoming query object. This object maintains
            raw incoming request, which is different from the sanitized data.
            It is not used by this method; it is kept for signature
            compatibility.

        Returns
        -------
        out : (result type, dict, int)
            A triple containing a result type, the result message
            as a dictionary, and the time in seconds that it took to complete
            the request.
        """
        self.logger.log(logging.DEBUG, f"Collecting query info for {po_name}...")
        start_time = time.time()
        response = self.python_service.ps.query(po_name, data, uid)
        gls_time = time.time() - start_time
        self.logger.log(logging.DEBUG, f"Query info: {response}")

        if isinstance(response, QuerySuccessful):
            response_json = response.to_json()
            # MD5 is only used to derive a cache validator (ETag) from the
            # payload, not for anything security related.
            md5_tag = md5(response_json.encode("utf-8")).hexdigest()
            self.set_header("Etag", f'"{md5_tag}"')
            # NOTE: these must be plain ASCII quotes; typographic quotes
            # here are a SyntaxError that prevents the module from loading.
            self.set_header("Strict-Transport-Security", "preload; max-age=2592000")
            return (QuerySuccessful, response.for_json(), gls_time)

        self.logger.log(logging.ERROR, f"Failed query, response: {response}")
        return (type(response), response.for_json(), gls_time)

    # handle HTTP Options requests to support CORS
    # don't check API key (client does not send or receive data for OPTIONS,
    # it just allows the client to subsequently make a POST request)
    def options(self, pred_name):
        """Answer an OPTIONS request with the CORS headers and an empty body."""
        if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
            self.fail_with_auth_error()
            return

        self.logger.log(logging.DEBUG, f"Processing OPTIONS for /query/{pred_name}")

        # add CORS headers if TabPy has a cors_origin specified
        self._add_CORS_header()
        self.write({})

    def _handle_result(self, po_name, data, qry, uid):
        """
        Execute the query and write either the result or an error response.

        Returns
        -------
        out : (float, object) or (None, None)
            The query duration in seconds and the raw query response when the
            query succeeded, ``(None, None)`` when an error was reported.
        """
        (response_type, response, gls_time) = self._query(po_name, data, uid, qry)

        if response_type is QuerySuccessful:
            result_dict = {
                "response": response["response"],
                "version": response["version"],
                "model": po_name,
                "uuid": uid,
            }
            self.write(result_dict)
            self.finish()
            return (gls_time, response["response"])

        if response_type is UnknownURI:
            self.error_out(
                404,
                "UnknownURI",
                info=(
                    "No query object has been registered"
                    f' with the name "{po_name}"'
                ),
            )
        elif response_type is QueryError:
            self.error_out(400, "QueryError", info=response)
        else:
            self.error_out(500, f"Error querying function '{po_name}'", info=response)

        return (None, None)

    def _sanitize_request_data(self, data):
        """
        Reduce the deserialized request body to the parts the query needs.

        Returns a dict with the ``data`` and ``method`` entries when the
        request names a method, otherwise the bare ``data`` entry.

        Raises
        ------
        RuntimeError
            If ``data`` is not a dictionary or has neither a ``method`` nor a
            ``data`` key.
        """
        if not isinstance(data, dict):
            msg = "Input data must be a dictionary"
            self.logger.log(logging.CRITICAL, msg)
            raise RuntimeError(msg)

        if "method" in data:
            return {"data": data.get("data"), "method": data.get("method")}

        if "data" in data:
            return data.get("data")

        msg = 'Input data must be a dictionary with a key called "data"'
        self.logger.log(logging.CRITICAL, msg)
        raise RuntimeError(msg)

    def _process_query(self, endpoint_name, start):
        """
        Validate the request body, resolve the endpoint and run the query.

        Invalid input is reported as HTTP 400, an unknown endpoint as HTTP 404
        and any other failure as HTTP 500.

        Parameters
        ----------
        endpoint_name : str
            The (already URL-decoded) endpoint name taken from the request.

        start : float
            Timestamp taken when the request arrived. Currently unused; kept
            so existing callers keep working.
        """
        self.logger.log(logging.DEBUG, f"Processing query {endpoint_name}...")
        try:
            self._add_CORS_header()

            # Treat a missing body as an empty JSON object so the client gets
            # the regular 'key called "data"' validation error. The previous
            # code swapped the body for a dict and then failed on .decode().
            raw_body = self.request.body or b"{}"

            # extract request data explicitly for caching purpose
            request_json = raw_body.decode("utf-8")

            # Sanitize input data
            data = self._sanitize_request_data(json.loads(request_json))
        except Exception as e:
            self.logger.log(logging.ERROR, str(e))
            err_msg = format_exception(e, "Invalid Input Data")
            self.error_out(400, err_msg)
            return

        try:
            (po_name, _) = self._get_actual_model(endpoint_name)

            # po_name is None if self.python_service.ps.query_objects.get(
            # endpoint_name) is None
            if not po_name:
                self.error_out(
                    404, "UnknownURI", info=f"Endpoint '{endpoint_name}' does not exist"
                )
                return

            po_obj = self.python_service.ps.query_objects.get(po_name)

            if not po_obj:
                self.error_out(
                    404, "UnknownURI", info=f'Endpoint "{po_name}" does not exist'
                )
                return

            if po_name != endpoint_name:
                self.logger.log(
                    logging.INFO, f"Querying actual model: po_name={po_name}"
                )

            uid = _get_uuid()

            # record query w/ request ID in query log
            qry = Query(po_name, request_json)

            # send a query to PythonService; _handle_result writes the
            # response (success or error) itself, so nothing is left to do.
            self._handle_result(po_name, data, qry, uid)
        except Exception as e:
            self.logger.log(logging.ERROR, str(e))
            err_msg = format_exception(e, "process query")
            self.error_out(500, "Error processing query", info=err_msg)

    def _get_actual_model(self, endpoint_name):
        """
        Follow alias endpoints until the underlying model endpoint is found.

        Returns
        -------
        out : (str, list of str) or (None, None)
            The name of the model endpoint and every endpoint name visited on
            the way, or ``(None, None)`` when an endpoint in the chain is not
            registered.

        Raises
        ------
        ValueError
            If an endpoint has an unknown type or the aliases form a cycle.
            Raising (instead of writing an error response and returning
            nothing) leaves a single place, the caller's exception handler,
            responsible for the HTTP 500 response.
        """
        # Find the actual query to run from given endpoint
        all_endpoint_names = []

        while True:
            endpoint_info = self.python_service.ps.query_objects.get(endpoint_name)
            if not endpoint_info:
                return (None, None)

            # Without this check a chain of aliases that points back to
            # itself would spin here forever.
            if endpoint_name in all_endpoint_names:
                raise ValueError(
                    f'Alias cycle detected while resolving endpoint "{endpoint_name}"'
                )
            all_endpoint_names.append(endpoint_name)

            endpoint_type = endpoint_info.get("type", "model")

            if endpoint_type == "alias":
                endpoint_name = endpoint_info["endpoint_obj"]
            elif endpoint_type == "model":
                break
            else:
                raise ValueError(f'Endpoint type "{endpoint_type}" does not exist')

        return (endpoint_name, all_endpoint_names)

    @gen.coroutine
    def get(self, endpoint_name):
        """Handle GET /query/<endpoint_name>."""
        if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
            self.fail_with_auth_error()
            return

        start = time.time()
        endpoint_name = urllib.parse.unquote(endpoint_name)
        self._process_query(endpoint_name, start)

    @gen.coroutine
    def post(self, endpoint_name):
        """Handle POST /query/<endpoint_name>."""
        self.logger.log(logging.DEBUG, f"Processing POST for /query/{endpoint_name}...")

        if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
            self.fail_with_auth_error()
            return

        start = time.time()
        endpoint_name = urllib.parse.unquote(endpoint_name)
        self._process_query(endpoint_name, start)
236