Test Failed
Push — master ( 94b1f7...ff0954 )
by Oleksandr
12:55 queued 02:06
created

tabpy.tabpy_server.handlers.query_plane_handler   A

Complexity

Total Complexity 31

Size/Duplication

Total Lines 234
Duplicated Lines 0 %

Test Coverage

Coverage 16.8%

Importance

Changes 0
Metric Value
wmc 31
eloc 146
dl 0
loc 234
ccs 21
cts 125
cp 0.168
rs 9.92
c 0
b 0
f 0

10 Functions

Rating   Name   Duplication   Size   Complexity  
B handlers.query_plane_handler.QueryPlaneHandler._process_query() 0 60 8
A handlers.query_plane_handler.QueryPlaneHandler.post() 0 11 2
A handlers.query_plane_handler.QueryPlaneHandler.options() 0 10 2
A handlers.query_plane_handler.QueryPlaneHandler.initialize() 0 2 1
A handlers.query_plane_handler.QueryPlaneHandler.get() 0 9 2
A handlers.query_plane_handler.QueryPlaneHandler._query() 0 38 2
A handlers.query_plane_handler.QueryPlaneHandler._get_actual_model() 0 26 5
A handlers.query_plane_handler._get_uuid() 0 3 1
A handlers.query_plane_handler.QueryPlaneHandler._handle_result() 0 29 4
A handlers.query_plane_handler.QueryPlaneHandler._sanitize_request_data() 0 14 4
1 1
from tabpy.tabpy_server.handlers import BaseHandler
2 1
import logging
3 1
import time
4 1
from tabpy.tabpy_server.common.messages import (
5
    Query,
6
    QuerySuccessful,
7
    QueryError,
8
    UnknownURI,
9
)
10 1
from hashlib import md5
11 1
import uuid
12 1
import json
13 1
from tabpy.tabpy_server.common.util import format_exception
14 1
import urllib
import urllib.parse
15 1
from tornado import gen
16
17
18 1
def _get_uuid():
    """Return a new random (version 4) UUID rendered as a string."""
    new_id = uuid.uuid4()
    return str(new_id)
21
22
23 1
class QueryPlaneHandler(BaseHandler):
    """Handle ``/query/<endpoint>`` requests.

    GET and POST check authorization, validate the JSON request body, resolve
    the endpoint (following alias endpoints) and run it through
    ``self.python_service.ps.query``. OPTIONS answers CORS preflight requests.
    """

    class _EndpointResolutionError(Exception):
        """Raised by ``_get_actual_model`` when an endpoint chain is invalid.

        Carries the ``message``/``info`` pair that ``_process_query`` passes
        to ``error_out``, so the error response is written exactly once.
        """

        def __init__(self, message, info):
            super().__init__(info)
            self.message = message
            self.info = info

    def initialize(self, app):
        super().initialize(app)

    def _query(self, po_name, data, uid, qry):
        """Run a query against the Python service and time the call.

        Parameters
        ----------
        po_name : str
            The name of the query object to query.

        data : object
            The sanitized request data, as returned by
            ``_sanitize_request_data``.

        uid : str
            A unique identifier for the request.

        qry : Query
            The incoming query message wrapping the raw request body (as
            opposed to the sanitized ``data``). Not used by this method.

        Returns
        -------
        out : (type, dict, float)
            A triple of the response message type, the response as returned
            by its ``for_json()``, and the time in seconds the service call
            took.
        """
        self.logger.log(logging.DEBUG, f"Collecting query info for {po_name}...")
        start_time = time.time()
        response = self.python_service.ps.query(po_name, data, uid)
        gls_time = time.time() - start_time
        self.logger.log(logging.DEBUG, f"Query info: {response}")

        if isinstance(response, QuerySuccessful):
            # Set an ETag derived from the serialized response body; md5 is
            # used as a content checksum here, not for security.
            response_json = response.to_json()
            md5_tag = md5(response_json.encode("utf-8")).hexdigest()
            self.set_header("Etag", f'"{md5_tag}"')
            return (QuerySuccessful, response.for_json(), gls_time)

        self.logger.log(logging.ERROR, f"Failed query, response: {response}")
        return (type(response), response.for_json(), gls_time)

    # handle HTTP Options requests to support CORS
    # don't check API key (client does not send or receive data for OPTIONS,
    # it just allows the client to subsequently make a POST request)
    def options(self, pred_name):
        if self.should_fail_with_not_authorized():
            self.fail_with_not_authorized()
            return

        self.logger.log(logging.DEBUG, f"Processing OPTIONS for /query/{pred_name}")

        # add CORS headers if TabPy has a cors_origin specified
        self._add_CORS_header()
        self.write({})

    def _handle_result(self, po_name, data, qry, uid):
        """Execute the query and write the HTTP response.

        On success writes ``response``, ``version``, ``model`` and ``uuid``
        and finishes the request. On failure responds through ``error_out``
        with 404 for ``UnknownURI``, 400 for ``QueryError`` and 500 for any
        other response type.

        Returns
        -------
        (float, object) or (None, None)
            ``(gls_time, response["response"])`` on success, ``(None, None)``
            when an error response was written.
        """
        response_type, response, gls_time = self._query(po_name, data, uid, qry)

        if response_type == QuerySuccessful:
            result_dict = {
                "response": response["response"],
                "version": response["version"],
                "model": po_name,
                "uuid": uid,
            }
            self.write(result_dict)
            self.finish()
            return (gls_time, response["response"])

        if response_type == UnknownURI:
            self.error_out(
                404,
                "UnknownURI",
                info=(
                    "No query object has been registered"
                    f' with the name "{po_name}"'
                ),
            )
        elif response_type == QueryError:
            self.error_out(400, "QueryError", info=response)
        else:
            self.error_out(500, "Error querying GLS", info=response)

        return (None, None)

    def _sanitize_request_data(self, data):
        """Extract the payload to pass to the query object.

        Returns ``{"data": ..., "method": ...}`` when the request contains a
        ``method`` key, otherwise the value of ``data["data"]``.

        Raises
        ------
        RuntimeError
            If ``data`` is not a dict, or has neither a ``method`` nor a
            ``data`` key.
        """
        if not isinstance(data, dict):
            msg = "Input data must be a dictionary"
            self.logger.log(logging.CRITICAL, msg)
            raise RuntimeError(msg)

        if "method" in data:
            return {"data": data.get("data"), "method": data.get("method")}
        elif "data" in data:
            return data.get("data")
        else:
            msg = 'Input data must be a dictionary with a key called "data"'
            self.logger.log(logging.CRITICAL, msg)
            raise RuntimeError(msg)

    def _process_query(self, endpoint_name, start):
        """Validate the request, resolve the endpoint and run the query.

        Responds with 400 for unparsable or malformed input, 404 when the
        endpoint is not registered, 500 for invalid alias chains or unexpected
        errors; otherwise ``_handle_result`` writes the response.

        Parameters
        ----------
        endpoint_name : str
            The URL-unquoted endpoint name requested by the client.
        start : float
            Request start time from ``time.time()``. Not used here; kept so
            existing callers keep working.
        """
        self.logger.log(logging.DEBUG, f"Processing query {endpoint_name}...")
        try:
            self._add_CORS_header()

            # Treat an empty body as an empty JSON object; it is then rejected
            # by _sanitize_request_data with a descriptive message. The
            # request object itself is left untouched.
            raw_body = self.request.body
            request_json = raw_body.decode("utf-8") if raw_body else "{}"

            # Sanitize input data
            data = self._sanitize_request_data(json.loads(request_json))
        except Exception as e:
            self.logger.log(logging.ERROR, str(e))
            err_msg = format_exception(e, "Invalid Input Data")
            self.error_out(400, err_msg)
            return

        try:
            po_name, _ = self._get_actual_model(endpoint_name)

            # po_name is None when endpoint_name, or an alias it points to,
            # is not registered in self.python_service.ps.query_objects.
            if not po_name:
                self.error_out(
                    404, "UnknownURI", info=f'Endpoint "{endpoint_name}" does not exist'
                )
                return

            po_obj = self.python_service.ps.query_objects.get(po_name)

            if not po_obj:
                self.error_out(
                    404, "UnknownURI", info=f'Endpoint "{po_name}" does not exist'
                )
                return

            if po_name != endpoint_name:
                self.logger.log(
                    logging.INFO, f"Querying actual model: po_name={po_name}"
                )

            uid = _get_uuid()

            # Wrap the raw request body in a Query message for _handle_result.
            qry = Query(po_name, request_json)

            # Send the query to PythonService; _handle_result writes either
            # the success or the error response itself.
            self._handle_result(po_name, data, qry, uid)
        except self._EndpointResolutionError as e:
            self.error_out(500, e.message, info=e.info)
        except Exception as e:
            self.logger.log(logging.ERROR, str(e))
            err_msg = format_exception(e, "process query")
            self.error_out(500, "Error processing query", info=err_msg)

    def _get_actual_model(self, endpoint_name):
        """Resolve ``endpoint_name`` to the model it ultimately refers to.

        Endpoints of type ``"alias"`` are followed through their
        ``endpoint_obj`` entry until an endpoint of type ``"model"`` (the
        type assumed when none is recorded) is reached.

        Returns
        -------
        (str, list of str) or (None, None)
            The model name and every endpoint name visited on the way (the
            requested name first), or ``(None, None)`` if a name in the chain
            is not registered.

        Raises
        ------
        QueryPlaneHandler._EndpointResolutionError
            If an endpoint has an unknown type, or the aliases form a cycle
            (which would otherwise loop forever).
        """
        all_endpoint_names = []

        while True:
            if endpoint_name in all_endpoint_names:
                chain = " -> ".join(all_endpoint_names + [endpoint_name])
                raise self._EndpointResolutionError(
                    "Alias cycle detected",
                    f"Endpoint aliases form a cycle: {chain}",
                )

            endpoint_info = self.python_service.ps.query_objects.get(endpoint_name)
            if not endpoint_info:
                return (None, None)

            all_endpoint_names.append(endpoint_name)

            endpoint_type = endpoint_info.get("type", "model")

            if endpoint_type == "alias":
                endpoint_name = endpoint_info["endpoint_obj"]
            elif endpoint_type == "model":
                return (endpoint_name, all_endpoint_names)
            else:
                raise self._EndpointResolutionError(
                    "Unknown endpoint type",
                    f'Endpoint type "{endpoint_type}" does not exist',
                )

    @gen.coroutine
    def get(self, endpoint_name):
        """Handle ``GET /query/<endpoint_name>``."""
        self.logger.log(logging.DEBUG, f"Processing GET for /query/{endpoint_name}...")

        if self.should_fail_with_not_authorized():
            self.fail_with_not_authorized()
            return

        start = time.time()
        endpoint_name = urllib.parse.unquote(endpoint_name)
        self._process_query(endpoint_name, start)

    @gen.coroutine
    def post(self, endpoint_name):
        """Handle ``POST /query/<endpoint_name>``."""
        self.logger.log(logging.DEBUG, f"Processing POST for /query/{endpoint_name}...")

        if self.should_fail_with_not_authorized():
            self.fail_with_not_authorized()
            return

        start = time.time()
        endpoint_name = urllib.parse.unquote(endpoint_name)
        self._process_query(endpoint_name, start)
234