Passed
Push — master ( fe0a6e...5716a6 )
by
unknown
13:10 queued 15s
created

tabpy.tabpy_server.handlers.query_plane_handler   A

Complexity

Total Complexity 33

Size/Duplication

Total Lines 241
Duplicated Lines 0 %

Test Coverage

Coverage 16.92%

Importance

Changes 0
Metric Value
wmc 33
eloc 151
dl 0
loc 241
ccs 22
cts 130
cp 0.1692
rs 9.76
c 0
b 0
f 0

9 Methods

Rating   Name   Duplication   Size   Complexity  
A QueryPlaneHandler.initialize() 0 2 1
A QueryPlaneHandler._get_actual_model() 0 26 5
A QueryPlaneHandler._sanitize_request_data() 0 14 4
A QueryPlaneHandler._handle_result() 0 29 4
B QueryPlaneHandler._process_query() 0 60 8
A QueryPlaneHandler.options() 0 10 2
A QueryPlaneHandler._query() 0 38 2
A QueryPlaneHandler.get() 0 12 3
A QueryPlaneHandler.post() 0 14 3

1 Function

Rating   Name   Duplication   Size   Complexity  
A _get_uuid() 0 3 1
1 1
import json
import logging
import time
import urllib
import urllib.parse
import uuid
from hashlib import md5

from tornado import gen

from tabpy.tabpy_server.handlers import BaseHandler
from tabpy.tabpy_server.common.messages import (
    Query,
    QuerySuccessful,
    QueryError,
    UnknownURI,
)
from tabpy.tabpy_server.common.util import format_exception
from tabpy.tabpy_server.handlers.util import AuthErrorStates
17
18
19 1
def _get_uuid() -> str:
    """Return a freshly generated random UUID (version 4) as a string."""
    return str(uuid.uuid4())
22
23
24 1
class QueryPlaneHandler(BaseHandler):
    """Request handler that runs a deployed query object for ``/query/<name>``.

    GET and POST both authenticate the request, check the body size limit,
    URL-unquote the endpoint name and then delegate to ``_process_query``.
    OPTIONS answers CORS pre-flight requests.
    """

    def initialize(self, app):
        """Forward the application object to ``BaseHandler.initialize``."""
        super().initialize(app)

    def _query(self, po_name, data, uid, qry):
        """
        Run a query object through the Python service and time the call.

        Parameters
        ----------
        po_name : str
            The name of the query object to query

        data : dict
            The deserialized request body

        uid: str
            A unique identifier for the request

        qry: str
            The incoming query object. This object maintains
            raw incoming request, which is different from the sanitized data.
            It is accepted for interface compatibility; this method does
            not read it.

        Returns
        -------
        out : (result type, dict, float)
            A triple containing a result type, the result message
            (the value of ``response.for_json()``), and the time in
            seconds that it took to complete the request.
        """
        self.logger.log(logging.DEBUG, f"Collecting query info for {po_name}...")
        start_time = time.time()
        response = self.python_service.ps.query(po_name, data, uid)
        gls_time = time.time() - start_time
        self.logger.log(logging.DEBUG, f"Query info: {response}")

        if isinstance(response, QuerySuccessful):
            # The Etag is the MD5 hex digest of the serialized response.
            response_json = response.to_json()
            md5_tag = md5(response_json.encode("utf-8")).hexdigest()
            self.set_header("Etag", f'"{md5_tag}"')
            return (QuerySuccessful, response.for_json(), gls_time)

        self.logger.log(logging.ERROR, f"Failed query, response: {response}")
        return (type(response), response.for_json(), gls_time)

    def options(self, pred_name):
        """Handle HTTP OPTIONS requests to support CORS.

        Runs the authentication check, adds the CORS headers and writes an
        empty JSON object.

        NOTE(review): the previous comment on this method said the API key
        is not checked for OPTIONS, but the code has always run the auth
        check first. The behavior is kept as-is; confirm which one is
        intended.
        """
        if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
            self.fail_with_auth_error()
            return

        self.logger.log(logging.DEBUG, f"Processing OPTIONS for /query/{pred_name}")

        # add CORS headers if TabPy has a cors_origin specified
        self._add_CORS_header()
        self.write({})

    def _handle_result(self, po_name, data, qry, uid):
        """Run the query and write either the result or an error response.

        On success the response body is a dict with the keys ``response``,
        ``version``, ``model`` and ``uuid`` and the request is finished.
        On failure ``error_out`` is called with 404 for ``UnknownURI``,
        400 for ``QueryError`` and 500 for any other result type.

        Returns
        -------
        (float, object) on success: the elapsed query time and the query's
        ``response`` value; ``(None, None)`` on failure.
        """
        (response_type, response, gls_time) = self._query(po_name, data, uid, qry)

        if response_type is QuerySuccessful:
            result_dict = {
                "response": response["response"],
                "version": response["version"],
                "model": po_name,
                "uuid": uid,
            }
            self.write(result_dict)
            self.finish()
            return (gls_time, response["response"])

        if response_type is UnknownURI:
            self.error_out(
                404,
                "UnknownURI",
                info=(
                    "No query object has been registered"
                    f' with the name "{po_name}"'
                ),
            )
        elif response_type is QueryError:
            self.error_out(400, "QueryError", info=response)
        else:
            self.error_out(500, f"Error querying function '{po_name}'", info=response)

        return (None, None)

    def _sanitize_request_data(self, data):
        """Extract the query payload from the deserialized request body.

        Returns ``{"data": ..., "method": ...}`` when the body has a
        ``"method"`` key, otherwise the value stored under ``"data"``.

        Raises
        ------
        RuntimeError
            If ``data`` is not a dict, or has neither a ``"method"`` nor a
            ``"data"`` key.
        """
        if not isinstance(data, dict):
            msg = "Input data must be a dictionary"
            self.logger.log(logging.CRITICAL, msg)
            raise RuntimeError(msg)

        if "method" in data:
            return {"data": data.get("data"), "method": data.get("method")}

        if "data" in data:
            return data.get("data")

        msg = 'Input data must be a dictionary with a key called "data"'
        self.logger.log(logging.CRITICAL, msg)
        raise RuntimeError(msg)

    def _process_query(self, endpoint_name, start):
        """Parse the request body, resolve the endpoint and run the query.

        Responds with 400 when the body cannot be decoded, parsed or
        sanitized, 404 when the endpoint does not resolve to a registered
        query object, and 500 on alias-resolution problems or any
        unexpected exception.

        ``start`` is the timestamp taken by the caller; it is accepted for
        interface compatibility and is not read here.
        """
        self.logger.log(logging.DEBUG, f"Processing query {endpoint_name}...")
        try:
            self._add_CORS_header()

            # An empty body is parsed as an empty JSON object so the client
            # gets the sanitizer's 'key called "data"' message. The request
            # object itself is left untouched.
            raw_body = self.request.body or b"{}"

            # extract request data explicitly for caching purpose
            request_json = raw_body.decode("utf-8")

            # Sanitize input data
            data = self._sanitize_request_data(json.loads(request_json))
        except Exception as e:
            self.logger.log(logging.ERROR, str(e))
            err_msg = format_exception(e, "Invalid Input Data")
            self.error_out(400, err_msg)
            return

        try:
            (po_name, _, failure) = self._resolve_endpoint(endpoint_name)

            if failure is not None:
                (message, info) = failure
                self.error_out(500, message, info=info)
                return

            # po_name is None if self.python_service.ps.query_objects.get(
            # endpoint_name) is None
            if not po_name:
                self.error_out(
                    404, "UnknownURI", info=f"Endpoint '{endpoint_name}' does not exist"
                )
                return

            po_obj = self.python_service.ps.query_objects.get(po_name)

            if not po_obj:
                self.error_out(
                    404, "UnknownURI", info=f'Endpoint "{po_name}" does not exist'
                )
                return

            if po_name != endpoint_name:
                self.logger.log(
                    logging.INFO, f"Querying actual model: po_name={po_name}"
                )

            uid = _get_uuid()

            # record query w/ request ID in query log
            qry = Query(po_name, request_json)

            # send a query to PythonService; _handle_result writes the
            # success or error response itself, so nothing is left to do
            # with its return value here.
            self._handle_result(po_name, data, qry, uid)

        except Exception as e:
            self.logger.log(logging.ERROR, str(e))
            err_msg = format_exception(e, "process query")
            self.error_out(500, "Error processing query", info=err_msg)
            return

    def _resolve_endpoint(self, endpoint_name):
        """Follow alias endpoints until a model endpoint is reached.

        This helper never writes to the response, so each caller can
        report a problem exactly once.

        Returns
        -------
        (name, visited, failure)
            ``name`` is the resolved model endpoint name, or None when a
            lookup in ``query_objects`` finds nothing or resolution fails.
            ``visited`` is the list of endpoint names walked, or None
            whenever ``name`` is None.
            ``failure`` is None, or a ``(message, info)`` pair describing an
            alias cycle or an unknown endpoint type.
        """
        query_objects = self.python_service.ps.query_objects
        visited = []

        while True:
            endpoint_info = query_objects.get(endpoint_name)
            if not endpoint_info:
                return (None, None, None)

            # Seeing a name twice means the aliases point at each other;
            # without this check the loop would never terminate.
            if endpoint_name in visited:
                return (
                    None,
                    None,
                    (
                        "Alias cycle detected",
                        f'Endpoint "{endpoint_name}" is part of an alias cycle',
                    ),
                )

            visited.append(endpoint_name)

            # Entries without a "type" key are treated as models.
            endpoint_type = endpoint_info.get("type", "model")

            if endpoint_type == "model":
                return (endpoint_name, visited, None)

            if endpoint_type != "alias":
                return (
                    None,
                    None,
                    (
                        "Unknown endpoint type",
                        f'Endpoint type "{endpoint_type}" does not exist',
                    ),
                )

            endpoint_name = endpoint_info["endpoint_obj"]

    def _get_actual_model(self, endpoint_name):
        """Find the actual query to run from the given endpoint.

        Returns a ``(name, all_endpoint_names)`` tuple, or ``(None, None)``
        when the endpoint cannot be resolved. On an alias cycle or an
        unknown endpoint type a 500 response is sent via ``error_out``
        before ``(None, None)`` is returned.
        """
        (name, all_endpoint_names, failure) = self._resolve_endpoint(endpoint_name)

        if failure is not None:
            (message, info) = failure
            self.error_out(500, message, info=info)

        return (name, all_endpoint_names)

    def _authorize_and_process(self, endpoint_name):
        """Shared GET/POST path: auth check, size check, unquote, process."""
        if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
            self.fail_with_auth_error()
            return

        # NOTE(review): presumably request_body_size_within_limit() sends
        # its own error response when the limit is exceeded - confirm.
        if not self.request_body_size_within_limit():
            return

        start = time.time()
        endpoint_name = urllib.parse.unquote(endpoint_name)
        self._process_query(endpoint_name, start)

    @gen.coroutine
    def get(self, endpoint_name):
        """Handle ``GET /query/<endpoint_name>``."""
        self._authorize_and_process(endpoint_name)

    @gen.coroutine
    def post(self, endpoint_name):
        """Handle ``POST /query/<endpoint_name>``."""
        self.logger.log(logging.DEBUG, f"Processing POST for /query/{endpoint_name}...")
        self._authorize_and_process(endpoint_name)
241