Passed
Push — master ( 227024...1d0698 )
by Oleksandr
02:44
created

tabpy.tabpy_server.handlers.query_plane_handler   A

Complexity

Total Complexity 31

Size/Duplication

Total Lines 233
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 31
eloc 152
dl 0
loc 233
rs 9.92
c 0
b 0
f 0

9 Methods

Rating   Name   Duplication   Size   Complexity  
A QueryPlaneHandler.initialize() 0 2 1
A QueryPlaneHandler.get() 0 9 2
B QueryPlaneHandler._get_actual_model() 0 26 5
A QueryPlaneHandler.post() 0 12 2
A QueryPlaneHandler._sanitize_request_data() 0 14 4
A QueryPlaneHandler._handle_result() 0 25 4
B QueryPlaneHandler._process_query() 0 60 8
A QueryPlaneHandler.options() 0 12 2
A QueryPlaneHandler._query() 0 41 2

1 Function

Rating   Name   Duplication   Size   Complexity  
A _get_uuid() 0 3 1
1
import json
import logging
import time
import urllib
import urllib.parse
import uuid
from hashlib import md5

import tornado.web
from tornado import gen

from tabpy.tabpy_server.common.messages import (
    Query, QuerySuccessful, QueryError, UnknownURI)
from tabpy.tabpy_server.common.util import format_exception
from tabpy.tabpy_server.handlers import BaseHandler
13
14
15
def _get_uuid():
    """Return a newly generated random (version 4) UUID as a string."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
18
19
20
class QueryPlaneHandler(BaseHandler):
    """Request handler for ``/query/<endpoint>``.

    GET and POST requests read a JSON body, resolve the requested endpoint
    (following aliases) to a model, forward the sanitized payload to
    ``self.python_service.ps.query`` and write the result back to the
    client. OPTIONS requests only emit the CORS headers.
    """

    def initialize(self, app):
        """Delegate handler set-up to ``BaseHandler.initialize``."""
        super().initialize(app)

    def _query(self, po_name, data, uid, qry):
        """
        Run the query and classify the outcome.

        Parameters
        ----------
        po_name : str
            The name of the query object to query

        data : dict
            The deserialized request body

        uid : str
            A unique identifier for the request

        qry : Query
            The incoming query object. This object maintains the
            raw incoming request, which is different from the sanitized
            data. It is currently not used by this method.

        Returns
        -------
        out : (result type, dict, int)
            A triple containing a result type, the result message
            as a dictionary, and the time in seconds that it took to complete
            the request.
        """
        self.logger.log(logging.DEBUG,
                        f'Collecting query info for {po_name}...')
        start_time = time.time()
        response = self.python_service.ps.query(po_name, data, uid)
        gls_time = time.time() - start_time
        self.logger.log(logging.DEBUG, f'Query info: {response}')

        if not isinstance(response, QuerySuccessful):
            self.logger.log(
                logging.ERROR,
                f'Failed query, response: {response}')
            return (type(response), response.for_json(), gls_time)

        # The Etag is derived from the serialized response so that
        # identical results carry an identical tag.
        response_json = response.to_json()
        md5_tag = md5(response_json.encode('utf-8')).hexdigest()
        self.set_header("Etag", f'"{md5_tag}"')
        return (QuerySuccessful, response.for_json(), gls_time)

    def options(self, pred_name):
        """Handle HTTP OPTIONS requests to support CORS.

        Writes an empty JSON object together with the CORS headers.

        NOTE(review): the authorization check below is applied to OPTIONS
        requests as well; confirm that pre-flight requests are expected to
        carry credentials.
        """
        if self.should_fail_with_not_authorized():
            self.fail_with_not_authorized()
            return

        self.logger.log(
            logging.DEBUG,
            f'Processing OPTIONS for /query/{pred_name}')

        # add CORS headers if TabPy has a cors_origin specified
        self._add_CORS_header()
        self.write({})

    def _handle_result(self, po_name, data, qry, uid):
        """Run the query and write the HTTP response for its outcome.

        Returns
        -------
        (gls_time, result) on success, where ``result`` is the
        ``'response'`` entry of the query response; ``(None, None)`` after
        an error response (404, 400 or 500) has been sent.
        """
        (response_type, response, gls_time) = \
            self._query(po_name, data, uid, qry)

        if response_type == QuerySuccessful:
            result_dict = {
                'response': response['response'],
                'version': response['version'],
                'model': po_name,
                'uuid': uid
            }
            self.write(result_dict)
            self.finish()
            return (gls_time, response['response'])

        if response_type == UnknownURI:
            self.error_out(404, 'UnknownURI',
                           info=('No query object has been registered'
                                 f' with the name "{po_name}"'))
        elif response_type == QueryError:
            self.error_out(400, 'QueryError', info=response)
        else:
            self.error_out(500, 'Error querying GLS', info=response)

        return (None, None)

    def _sanitize_request_data(self, data):
        """Extract the payload to forward from the decoded request body.

        Returns a dict with the ``"data"`` and ``"method"`` entries when
        ``"method"`` is present, otherwise the bare ``"data"`` entry.

        Raises
        ------
        RuntimeError
            If ``data`` is not a dict, or has neither a ``"method"`` nor a
            ``"data"`` key.
        """
        if not isinstance(data, dict):
            msg = 'Input data must be a dictionary'
            self.logger.log(logging.CRITICAL, msg)
            raise RuntimeError(msg)

        if "method" in data:
            return {"data": data.get("data"), "method": data.get("method")}

        if "data" in data:
            return data.get("data")

        msg = 'Input data must be a dictionary with a key called "data"'
        self.logger.log(logging.CRITICAL, msg)
        raise RuntimeError(msg)

    def _process_query(self, endpoint_name, start):
        """Validate the request body, resolve the endpoint and run the query.

        Every outcome is reported to the client from here or from the
        helpers it calls; nothing is returned. ``start`` is accepted for
        interface compatibility and is currently unused.
        """
        self.logger.log(logging.DEBUG,
                        f'Processing query {endpoint_name}...')
        try:
            self._add_CORS_header()

            # extract request data explicitly for caching purpose.
            # An empty body is treated as an empty JSON object so that the
            # sanitizer rejects it with its own, accurate message.
            raw_body = self.request.body
            request_json = raw_body.decode('utf-8') if raw_body else '{}'

            # Sanitize input data
            data = self._sanitize_request_data(json.loads(request_json))
        except Exception as e:
            err_msg = format_exception(e, "Invalid Input Data")
            self.error_out(400, err_msg)
            return

        try:
            resolved = self._get_actual_model(endpoint_name)
            if resolved is None:
                # _get_actual_model has already sent an error response;
                # responding again would write to a finished request.
                return
            (po_name, _) = resolved

            # po_name is None if self.python_service.ps.query_objects.get(
            # endpoint_name) is None
            if not po_name:
                self.error_out(
                    404,
                    'UnknownURI',
                    info=f'Endpoint "{endpoint_name}" does not exist')
                return

            po_obj = self.python_service.ps.query_objects.get(po_name)

            if not po_obj:
                self.error_out(404, 'UnknownURI',
                               info=f'Endpoint "{po_name}" does not exist')
                return

            if po_name != endpoint_name:
                self.logger.log(
                    logging.INFO,
                    f'Querying actual model: po_name={po_name}')

            uid = _get_uuid()

            # keep the raw request together with the resolved endpoint name
            qry = Query(po_name, request_json)

            # send a query to PythonService; _handle_result writes the
            # response (success or error) itself.
            self._handle_result(po_name, data, qry, uid)

        except Exception as e:
            err_msg = format_exception(e, 'process query')
            self.error_out(500, 'Error processing query', info=err_msg)
            return

    def _get_actual_model(self, endpoint_name):
        """Find the actual model to run for the given endpoint.

        Follows ``'alias'`` entries in
        ``self.python_service.ps.query_objects`` until an entry of type
        ``'model'`` (the default type) is reached.

        Returns
        -------
        (model_name, visited_names)
            When a model is found; ``visited_names`` lists every endpoint
            name looked up along the way, in order.
        (None, None)
            When a name in the chain is not registered. No response has
            been sent in this case.
        None
            When an error response (unknown endpoint type or circular
            alias chain) has already been sent to the client.
        """
        all_endpoint_names = []

        while True:
            endpoint_info = self.python_service.ps.query_objects.get(
                endpoint_name)
            if not endpoint_info:
                return (None, None)

            if endpoint_name in all_endpoint_names:
                # An alias chain that revisits a name would never end.
                chain = ' -> '.join(
                    str(name)
                    for name in all_endpoint_names + [endpoint_name])
                self.error_out(
                    500,
                    'Circular endpoint alias',
                    info=f'Alias chain "{chain}" never reaches a model')
                return None

            all_endpoint_names.append(endpoint_name)

            endpoint_type = endpoint_info.get('type', 'model')

            if endpoint_type == 'model':
                return (endpoint_name, all_endpoint_names)

            if endpoint_type != 'alias':
                self.error_out(
                    500,
                    'Unknown endpoint type',
                    info=f'Endpoint type "{endpoint_type}" does not exist')
                return None

            # alias: continue with the endpoint it points at
            endpoint_name = endpoint_info['endpoint_obj']

    @gen.coroutine
    def get(self, endpoint_name):
        """Handle GET ``/query/<endpoint_name>`` after the authorization check."""
        if self.should_fail_with_not_authorized():
            self.fail_with_not_authorized()
            return

        start = time.time()
        endpoint_name = urllib.parse.unquote(endpoint_name)
        self._process_query(endpoint_name, start)

    @gen.coroutine
    def post(self, endpoint_name):
        """Handle POST ``/query/<endpoint_name>`` after the authorization check."""
        self.logger.log(logging.DEBUG,
                        f'Processing POST for /query/{endpoint_name}...')

        if self.should_fail_with_not_authorized():
            self.fail_with_not_authorized()
            return

        start = time.time()
        endpoint_name = urllib.parse.unquote(endpoint_name)
        self._process_query(endpoint_name, start)
233