Passed
Push — master ( 305552...00a4d3 )
by Oleksandr
02:44
created

tabpy.tabpy_tools.client   A

Complexity

Total Complexity 34

Size/Duplication

Total Lines 398
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 34
eloc 141
dl 0
loc 398
rs 9.68
c 0
b 0
f 0

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _check_endpoint_name() 0 8 2
A _check_hostname() 0 8 2
A _check_endpoint_type() 0 6 3

12 Methods

Rating   Name   Duplication   Size   Complexity  
A Client.__init__() 0 34 3
A Client.__repr__() 0 5 1
A Client._gen_endpoint() 0 69 3
A Client._get_endpoint_upload_destination() 0 3 1
A Client.query() 0 24 1
A Client._upload_endpoint() 0 14 1
A Client.get_endpoints() 0 27 1
A Client.deploy() 0 57 4
A Client.get_status() 0 23 1
A Client.set_credentials() 0 14 1
B Client._wait_for_endpoint_deployment() 0 41 8
A Client.query_timeout() 0 4 1
1
from re import compile
2
import time
3
import sys
4
import requests
5
6
from .rest import (
7
    RequestsNetworkWrapper,
8
    ServiceClient
9
)
10
11
from .rest_client import (
12
    RESTServiceClient,
13
    Endpoint,
14
    AliasEndpoint
15
)
16
17
from .custom_query_object import CustomQueryObject
18
19
import os
20
import logging
21
22
# Module-wide logger, registered under this module's dotted name.
logger = logging.getLogger(name=__name__)

# Valid endpoint names: one or more word characters, spaces or hyphens,
# anchored at both ends of the pattern.
_name_checker = compile(pattern=r'^[\w -]+$')
25
26
27
def _check_endpoint_type(name):
    """Ensure *name* is a non-empty ``str``.

    Raises
    ------
    TypeError
        If *name* is not a string.
    ValueError
        If *name* is the empty string.
    """
    is_text = isinstance(name, str)
    if is_text and name != '':
        return

    if not is_text:
        raise TypeError("Endpoint name must be a string")
    raise ValueError("Endpoint name cannot be empty")
33
34
35
def _check_hostname(name):
    """Validate a TabPy server URL.

    The value must first pass ``_check_endpoint_type`` (a non-empty
    ``str``) and must then have the form ``http(s)://<hostname>[:<port>]``.
    The hostname may contain only ASCII letters, digits, hyphens,
    underscores and dots. An optional ``/`` is tolerated before and/or
    after the port.

    Raises
    ------
    TypeError
        If *name* is not a string.
    ValueError
        If *name* is empty or does not have the expected form.
    """
    _check_endpoint_type(name)
    hostname_checker = compile(
        r'^http(s)?://[a-zA-Z0-9-_\.]+(/)?(:[0-9]+)?(/)?$')

    # Use fullmatch rather than match: with match(), ``$`` also matches just
    # before a trailing "\n", so "http://host:9004\n" would be accepted.
    if not hostname_checker.fullmatch(name):
        raise ValueError(
            f'endpoint name {name} should be in http(s)://<hostname>'
            '[:<port>] and hostname may consist only of: '
            'a-z, A-Z, 0-9, underscore and hyphens.')
45
46
47
def _check_endpoint_name(name):
    """Checks that the endpoint name is a non-empty string made only of
    (Unicode) word characters, spaces and hyphens, per ``_name_checker``.

    Raises
    ------
    TypeError
        If *name* is not a string.
    ValueError
        If *name* is empty or contains any other character, including a
        trailing newline.
    """
    _check_endpoint_type(name)

    # fullmatch: with match(), the pattern's ``$`` would still accept a name
    # ending in "\n".
    if not _name_checker.fullmatch(name):
        raise ValueError(
            f'endpoint name {name} can only contain: a-z, A-Z, 0-9,'
            ' underscore, hyphens and spaces.')
56
57
58
class Client(object):
    """Client for a running TabPy server.

    Wraps a ``RESTServiceClient`` and exposes helpers to query, list and
    deploy endpoints on that server.
    """

    # Sentinel for _gen_endpoint's ``schema`` default. It tells "schema not
    # passed" apart from an explicit ``schema=None``, which deploy() forwards
    # unchanged. This keeps the historical ``[]`` default without sharing one
    # mutable list between calls.
    _SCHEMA_NOT_GIVEN = object()

    def __init__(self,
                 endpoint,
                 query_timeout=1000):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL, in the form ``http(s)://<hostname>[:<port>]``.
            It is validated by ``_check_hostname``, which raises
            ``TypeError``/``ValueError`` for malformed values.

        query_timeout : float, optional
            The timeout for query operations. ``None`` or a value that is
            not positive is stored as ``0.0``.
        """
        _check_hostname(endpoint)

        self._endpoint = endpoint

        session = requests.session()
        # SECURITY NOTE(review): TLS certificate verification is disabled for
        # every request sent through this session, and urllib3 warnings are
        # switched off globally. Any server certificate (e.g. a self-signed
        # one) is accepted, with no protection against man-in-the-middle
        # attacks. Kept unchanged for backward compatibility.
        session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        network_wrapper = RequestsNetworkWrapper(session)
        service_client = ServiceClient(self._endpoint, network_wrapper)

        self._service = RESTServiceClient(service_client)
        if query_timeout is not None and query_timeout > 0:
            self.query_timeout = query_timeout
        else:
            self.query_timeout = 0.0

    def __repr__(self):
        return (
            f"<{self.__class__.__name__} object at {hex(id(self))}"
            f" connected to {self._endpoint!r}>")

    def get_status(self):
        '''
        Gets the status of the deployed endpoints.

        Returns
        -------
        dict
            Keys are endpoints and values are dicts describing the state of
            the endpoint.

        Examples
        --------
        .. sourcecode:: python
            {
                u'foo': {
                    u'status': u'LoadFailed',
                    u'last_error': u'error message',
                    u'version': 1,
                    u'type': u'model',
                },
            }
        '''
        return self._service.get_status()

    #
    # Query
    #

    @property
    def query_timeout(self):
        """The timeout for queries in seconds."""
        return self._service.query_timeout

    @query_timeout.setter
    def query_timeout(self, value):
        self._service.query_timeout = value

    def query(self, name, *args, **kwargs):
        """Query an endpoint.

        Parameters
        ----------
        name : str
            The name of the endpoint.

        *args : list of anything
            Ordered parameters to the endpoint.

        **kwargs : dict of anything
            Named parameters to the endpoint.

        Returns
        -------
        dict
            Keys are:
                model: the name of the endpoint
                version: the version used.
                response: the response to the query.
                uuid : a unique id for the request.
        """
        return self._service.query(name, *args, **kwargs)

    #
    # Endpoints
    #

    def get_endpoints(self, type=None):
        """Returns all deployed endpoints.

        Parameters
        ----------
        type : str, optional
            Passed unchanged to the service's ``get_endpoints``; presumably
            filters the result by endpoint type (e.g. ``"model"``). TODO
            confirm against RESTServiceClient.

        Examples
        --------
        .. sourcecode:: python
            {"clustering":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469511182,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469511182,
               "type": "model",
               "target": null},
            "add": {
              "description": "",
              "docstring": "-- no docstring found in query function --",
              "creation_time": 1469505967,
              "version": 1,
              "dependencies": [],
              "last_modified_time": 1469505967,
              "type": "model",
              "target": null}
            }
        """
        return self._service.get_endpoints(type)

    def _get_endpoint_upload_destination(self):
        """Returns the endpoint upload destination (the ``'path'`` entry of
        the service's response)."""
        return self._service.get_endpoint_upload_destination()['path']

    def deploy(self,
               name, obj, description='', schema=None,
               override=False):
        """Deploys a Python function as an endpoint in the server.

        Parameters
        ----------
        name : str
            A unique identifier for the endpoint.

        obj :  function
            Refers to a user-defined function with any signature. However both
            input and output of the function need to be JSON serializable.

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API. Pass ``None`` to use the stripped docstring
            of ``obj`` instead (or ``''`` if it has none).

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        override : bool
            Whether to override (update) an existing endpoint. If False and
            there is already an endpoint with that name, it will raise a
            RuntimeError. If True and there is already an endpoint with that
            name, it will deploy a new version on top of it.

        Raises
        ------
        RuntimeError
            If the endpoint exists and ``override`` is False, or if the
            server reports a load failure or the deployment does not
            complete in time.

        See Also
        --------
        get_endpoints
        """
        endpoint = self.get_endpoints().get(name)
        if endpoint:
            if not override:
                raise RuntimeError(
                    f'An endpoint with that name ({name}) already'
                    ' exists. Use "override = True" to force update '
                    'an existing endpoint.')

            version = endpoint.version + 1
        else:
            version = 1

        endpoint_dict = self._gen_endpoint(
            name, obj, description, version, schema)

        self._upload_endpoint(endpoint_dict)

        # Version 1 means the endpoint is new; otherwise update it in place.
        if version == 1:
            self._service.add_endpoint(Endpoint(**endpoint_dict))
        else:
            self._service.set_endpoint(Endpoint(**endpoint_dict))

        self._wait_for_endpoint_deployment(
            endpoint_dict['name'], endpoint_dict['version'])

    def _gen_endpoint(self, name, obj, description, version=1,
                      schema=_SCHEMA_NOT_GIVEN):
        '''Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update. Validated by
            ``_check_endpoint_name``.

        obj :  func
            Object that backs the endpoint. See deploy() for a complete
            description.

        description : str or None
            Description of the endpoint. If None, the stripped docstring of
            ``obj`` is used, or ``''`` when ``obj`` has no string docstring.

        version : int
            The version. Defaults to 1.

        schema : optional
            Stored unchanged under ``'schema'``. Defaults to a fresh empty
            list when not passed.

        Returns
        -------
        dict
            Keys:
                name : str
                    The name provided.

                version : int
                    The version provided.

                description : str
                    The provided (or derived) description.

                type : str
                    The type of the endpoint (always ``'model'``).

                endpoint_obj : object
                    The wrapper around the obj provided that can be used to
                    generate the code and dependencies for the endpoint.

                dependencies, methods : taken from ``endpoint_obj``.

                required_files, required_packages : empty lists.

                schema : the schema argument.

        Raises
        ------
        TypeError, ValueError
            When ``name`` is not a valid endpoint name.
        '''
        # check for invalid PO names
        _check_endpoint_name(name)

        if schema is self._SCHEMA_NOT_GIVEN:
            schema = []

        if description is None:
            if isinstance(obj.__doc__, str):
                # extract doc string
                description = obj.__doc__.strip()
            else:
                description = ''

        endpoint_object = CustomQueryObject(
            query=obj,
            description=description,
        )

        return {
            'name': name,
            'version': version,
            'description': description,
            'type': 'model',
            'endpoint_obj': endpoint_object,
            'dependencies': endpoint_object.get_dependencies(),
            'methods': endpoint_object.get_methods(),
            'required_files': [],
            'required_packages': [],
            'schema': schema
        }

    def _upload_endpoint(self, obj):
        """Saves the endpoint object under the server's upload destination.

        The target path is ``<dest>/endpoints/<name>/<version>``. It is
        stored in ``obj['src_path']`` (mutating ``obj``) before
        ``obj['endpoint_obj'].save`` is called with it.
        """
        endpoint_obj = obj['endpoint_obj']

        dest_path = self._get_endpoint_upload_destination()

        # Upload the endpoint
        obj['src_path'] = os.path.join(
            dest_path,
            'endpoints',
            obj['name'],
            str(obj['version']))

        endpoint_obj.save(obj['src_path'])

    def _wait_for_endpoint_deployment(self,
                                      endpoint_name,
                                      version=1,
                                      interval=1.0,
                                      timeout=10.0,
                                      ):
        """
        Waits for the endpoint to be deployed by polling get_status().

        Returns once the endpoint reports ``'LoadSuccessful'`` with a version
        equal to or greater than ``version``. Between polls it sleeps
        ``interval`` seconds using time.sleep().

        Parameters
        ----------
        endpoint_name : str
            Endpoint to wait for.
        version : int
            Minimum version that must be loaded.
        interval : float
            Seconds to sleep between polls.
        timeout : float
            Maximum number of seconds to wait (default 10).

        Raises
        ------
        RuntimeError
            If the server reports ``'LoadFailed'`` for the endpoint, or if
            more than ``timeout`` seconds elapse.
        """
        logger.info(
            f'Waiting for endpoint {endpoint_name} to deploy to '
            f'version {version}')
        # Monotonic clock: elapsed time is unaffected by wall-clock changes.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info(f'Endpoint {endpoint_name} doesn\'t '
                            'exist in endpoints yet')
            else:
                logger.info(f'ep={ep}')

                if ep['status'] == 'LoadFailed':
                    raise RuntimeError(
                        f'LoadFailed: {ep["last_error"]}')

                elif ep['status'] == 'LoadSuccessful':
                    if ep['version'] >= version:
                        logger.info("LoadSuccessful")
                        break
                    else:
                        logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                # ':g' renders the default 10.0 as "10".
                raise RuntimeError(
                    f'Waited more than {timeout:g}s for deployment')

            logger.info(f'Sleeping {interval}...')
            time.sleep(interval)

    def set_credentials(self, username, password):
        '''
        Set credentials for all the TabPy client-server communication
        where client is tabpy-tools and server is tabpy-server.

        Parameters
        ----------
        username : str
            User name (login). Username is case insensitive.

        password : str
            Password in plain text.
        '''
        self._service.set_credentials(username, password)
398