Passed
Push — master ( 227024...1d0698 )
by Oleksandr
02:44
created

tabpy.tabpy_tools.client   B

Complexity

Total Complexity 45

Size/Duplication

Total Lines 530
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 45
eloc 183
dl 0
loc 530
rs 8.8
c 0
b 0
f 0

16 Methods

Rating   Name   Duplication   Size   Complexity  
A Client.remove() 0 24 2
A Client._gen_endpoint() 0 69 3
A Client._get_endpoint_upload_destination() 0 3 1
B Client.alias() 0 50 5
A Client.query() 0 24 1
A Client._upload_endpoint() 0 14 1
A Client.__init__() 0 34 3
A Client.get_endpoint_dependencies() 0 39 3
A Client.__repr__() 0 5 1
A Client.get_endpoints() 0 27 1
A Client.get_status() 0 23 1
A Client.deploy() 0 57 4
A Client.get_info() 0 14 1
A Client.set_credentials() 0 14 1
B Client._wait_for_endpoint_deployment() 0 41 8
A Client.query_timeout() 0 4 1

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _check_endpoint_name() 0 8 2
A _check_hostname() 0 8 2
A _check_endpoint_type() 0 6 3

How to fix   Complexity   

Complexity

Complex classes like Client in tabpy.tabpy_tools.client often do many different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from re import compile
2
import time
3
import sys
4
import requests
5
6
from .rest import (
7
    RequestsNetworkWrapper,
8
    ServiceClient
9
)
10
11
from .rest_client import (
12
    RESTServiceClient,
13
    Endpoint,
14
    AliasEndpoint
15
)
16
17
from .custom_query_object import CustomQueryObject
18
19
import os
20
import logging
21
22
# Module logger, named after this module.
logger = logging.getLogger(name=__name__)

# Allowed endpoint names: one or more word characters, spaces or hyphens.
_name_checker = compile(pattern=r'^[\w -]+$')
25
26
27
def _check_endpoint_type(name):
    """Ensure *name* is a non-empty string.

    Raises
    ------
    TypeError
        If *name* is not a ``str``.
    ValueError
        If *name* is the empty string.
    """
    if isinstance(name, str):
        if name == '':
            raise ValueError("Endpoint name cannot be empty")
        return
    raise TypeError("Endpoint name must be a string")
33
34
35
def _check_hostname(name):
    """Validate a server URL passed to the client.

    The URL must look like ``http(s)://<hostname>[:<port>]``, optionally
    with a ``/`` before and/or after the port. The hostname may contain
    ASCII letters, digits, underscores, hyphens and dots.

    Raises
    ------
    TypeError
        If *name* is not a string.
    ValueError
        If *name* is empty or does not have the expected URL shape.
    """
    _check_endpoint_type(name)
    # '-' is placed last in the character class so it is unambiguously a
    # literal; the accepted character set is the same as before.
    hostname_checker = compile(
        r'^http(s)?://[a-zA-Z0-9_.-]+(/)?(:[0-9]+)?(/)?$')

    # fullmatch (not match): '$' also matches right before a trailing
    # newline, so match() would accept a value such as 'http://host\n'.
    if not hostname_checker.fullmatch(name):
        raise ValueError(
            f'endpoint name {name} should be in http(s)://<hostname>'
            '[:<port>] and hostname may consist only of: '
            'a-z, A-Z, 0-9, underscore, hyphens and dots.')
45
46
47
def _check_endpoint_name(name):
    """Check that *name* is a valid endpoint name.

    A valid name is a non-empty string made only of word characters,
    spaces and hyphens. Because the pattern is a ``str`` pattern, "word
    characters" is Unicode-aware: non-ASCII letters and digits are accepted
    as well as ASCII ones and the underscore.

    Raises
    ------
    TypeError
        If *name* is not a string.
    ValueError
        If *name* is empty or contains any other character.
    """
    _check_endpoint_type(name)

    # fullmatch rather than match: the pattern ends with '$', which also
    # matches just before a trailing newline, so match() would let a name
    # such as 'model\n' through.
    if not _name_checker.fullmatch(name):
        raise ValueError(
            f'endpoint name {name} can only contain: a-z, A-Z, 0-9,'
            ' underscore, hyphens and spaces.')
56
57
58
class Client(object):
    """Client for a running TabPy server.

    Every public call is delegated to a ``RESTServiceClient`` that is built
    in ``__init__`` and stored as ``self._service``.
    """

    # Sentinel default for ``_gen_endpoint(schema=...)``. An omitted schema
    # gets a fresh empty list on every call, instead of one shared mutable
    # default list. An explicit ``None`` is still passed through unchanged.
    _SCHEMA_NOT_GIVEN = object()

    def __init__(self,
                 endpoint,
                 query_timeout=1000):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL, in the form ``http(s)://<hostname>[:<port>]``.

        query_timeout : float, optional
            The timeout for query operations. ``None`` or a non-positive
            value is stored as ``0.0``.

        Raises
        ------
        TypeError, ValueError
            If ``endpoint`` is not a valid server URL.
        """
        _check_hostname(endpoint)

        self._endpoint = endpoint

        session = requests.session()
        # NOTE(review): TLS certificate verification is turned off and urllib3
        # warnings are silenced process-wide, so any certificate is accepted
        # for https servers. Kept as-is to preserve existing behavior.
        session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        network_wrapper = RequestsNetworkWrapper(session)
        service_client = ServiceClient(self._endpoint, network_wrapper)

        self._service = RESTServiceClient(service_client)
        if query_timeout is not None and query_timeout > 0:
            self.query_timeout = query_timeout
        else:
            self.query_timeout = 0.0

    def __repr__(self):
        return (f"<{self.__class__.__name__} object at {hex(id(self))}"
                f" connected to {self._endpoint!r}>")

    def get_info(self):
        """Returns a dict containing information about the service.

        Returns
        -------
        dict
            Keys are:
            * name: The name of the service
            * creation_time: The creation time in seconds since 1970-01-01
            * description: Description of the service
            * server_version: The version of the service used
            * state_path: Where the state file is stored.
        """
        return self._service.get_info()

    def get_status(self):
        '''
        Gets the status of the deployed endpoints.

        Returns
        -------
        dict
            Keys are endpoints and values are dicts describing the state of
            the endpoint.

        Examples
        --------
        .. sourcecode:: python
            {
                u'foo': {
                    u'status': u'LoadFailed',
                    u'last_error': u'error message',
                    u'version': 1,
                    u'type': u'model',
                },
            }
        '''
        return self._service.get_status()

    #
    # Query
    #

    @property
    def query_timeout(self):
        """The timeout for queries in seconds."""
        return self._service.query_timeout

    @query_timeout.setter
    def query_timeout(self, value):
        self._service.query_timeout = value

    def query(self, name, *args, **kwargs):
        """Query an endpoint.

        Parameters
        ----------
        name : str
            The name of the endpoint.

        *args : list of anything
            Ordered parameters to the endpoint.

        **kwargs : dict of anything
            Named parameters to the endpoint.

        Returns
        -------
        dict
            Keys are:
                model: the name of the endpoint
                version: the version used.
                response: the response to the query.
                uuid : a unique id for the request.
        """
        return self._service.query(name, *args, **kwargs)

    #
    # Endpoints
    #

    def get_endpoints(self, type=None):
        """Returns all deployed endpoints.

        Parameters
        ----------
        type : str, optional
            Passed through to the service unchanged; presumably filters the
            listing by endpoint type (e.g. 'model', 'alias') — confirm
            against the server.

        Examples
        --------
        .. sourcecode:: python
            {"clustering":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469511182,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469511182,
               "type": "model",
               "target": null},
            "add": {
              "description": "",
              "docstring": "-- no docstring found in query function --",
              "creation_time": 1469505967,
              "version": 1,
              "dependencies": [],
              "last_modified_time": 1469505967,
              "type": "model",
              "target": null}
            }
        """
        return self._service.get_endpoints(type)

    def _get_endpoint_upload_destination(self):
        """Returns the endpoint upload destination."""
        return self._service.get_endpoint_upload_destination()['path']

    def alias(self, alias, existing_endpoint_name, description=None):
        '''
        Create a new endpoint to redirect to an existing endpoint, or update an
        existing alias to point to a different existing endpoint.

        Parameters
        ----------
        alias : str
            The new endpoint name or an existing alias endpoint name.

        existing_endpoint_name : str
            A name of an existing endpoint to redirect the alias to.

        description : str, optional
            A description for the alias. Defaults to
            ``'Alias for <existing_endpoint_name>'`` when empty.

        Raises
        ------
        ValueError
            If ``existing_endpoint_name`` does not exist, or ``alias`` is not
            a valid endpoint name.
        RuntimeError
            If ``alias`` is already used by an endpoint that is not an alias,
            or if the deployment does not complete (see
            ``_wait_for_endpoint_deployment``).
        '''
        # check for invalid PO names
        _check_endpoint_name(alias)

        if not description:
            description = f'Alias for {existing_endpoint_name}'

        # Fetch the listing once so the existence check and the overwrite
        # check see the same snapshot (and cost a single round trip).
        endpoints = self.get_endpoints()
        if existing_endpoint_name not in endpoints:
            raise ValueError(
                f'Endpoint "{existing_endpoint_name}" does not exist.')

        # Can only overwrite existing alias
        existing_endpoint = endpoints.get(alias)
        endpoint = AliasEndpoint(
            name=alias,
            type='alias',
            description=description,
            target=existing_endpoint_name,
            cache_state='disabled',
            version=1,
        )

        if existing_endpoint:
            if existing_endpoint.type != 'alias':
                raise RuntimeError(
                    f'Name "{alias}" is already in use by another '
                    'endpoint.')

            endpoint.version = existing_endpoint.version + 1

            self._service.set_endpoint(endpoint)
        else:
            self._service.add_endpoint(endpoint)

        self._wait_for_endpoint_deployment(alias, endpoint.version)

    def deploy(self,
               name, obj, description='', schema=None,
               override=False):
        """Deploys a Python function as an endpoint in the server.

        Parameters
        ----------
        name : str
            A unique identifier for the endpoint.

        obj :  function
            Refers to a user-defined function with any signature. However both
            input and output of the function need to be JSON serializable.

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API. Pass ``None`` to use the stripped docstring
            of ``obj`` instead.

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        override : bool
            Whether to override (update) an existing endpoint. If False and
            there is already an endpoint with that name, it will raise a
            RuntimeError. If True and there is already an endpoint with that
            name, it will deploy a new version on top of it.

        See Also
        --------
        remove, get_endpoints
        """
        endpoint = self.get_endpoints().get(name)
        if endpoint:
            if not override:
                raise RuntimeError(
                    f'An endpoint with that name ({name}) already'
                    ' exists. Use "override = True" to force update '
                    'an existing endpoint.')

            version = endpoint.version + 1
        else:
            version = 1

        # ``obj`` is the user's function; ``endpoint_spec`` is the endpoint
        # dict built from it (kept under a separate name for clarity).
        endpoint_spec = self._gen_endpoint(
            name, obj, description, version, schema)

        # Adds 'src_path' to endpoint_spec as a side effect.
        self._upload_endpoint(endpoint_spec)

        if version == 1:
            self._service.add_endpoint(Endpoint(**endpoint_spec))
        else:
            self._service.set_endpoint(Endpoint(**endpoint_spec))

        self._wait_for_endpoint_deployment(
            endpoint_spec['name'], endpoint_spec['version'])

    def _gen_endpoint(self, name, obj, description, version=1,
                      schema=_SCHEMA_NOT_GIVEN):
        '''Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update

        obj :  func
            Object that backs the endpoint. See deploy() for a complete
            description.

        description : str
            Description of the endpoint. If None, the stripped docstring of
            ``obj`` is used (or '' when it has none).

        version : int
            The version. Defaults to 1.

        schema : optional
            Stored as-is under 'schema'. When omitted, a new empty list is
            used.

        Returns
        -------
        dict
            Keys:
                name : str
                    The name provided.

                version : int
                    The version provided.

                description : str
                    The provided (or derived) description.

                type : str
                    The type of the endpoint (always 'model').

                endpoint_obj : object
                    The wrapper around the obj provided that can be used to
                    generate the code and dependencies for the endpoint.

                dependencies, methods :
                    As reported by the wrapper object.

                required_files, required_packages : list
                    Always empty lists.

                schema :
                    The schema provided.

        Raises
        ------
        TypeError
            When name is not a string.
        ValueError
            When name is empty or contains invalid characters.
        '''
        # check for invalid PO names
        _check_endpoint_name(name)

        if schema is self._SCHEMA_NOT_GIVEN:
            schema = []

        if description is None:
            if isinstance(obj.__doc__, str):
                # extract doc string
                description = obj.__doc__.strip() or ''
            else:
                description = ''

        endpoint_object = CustomQueryObject(
            query=obj,
            description=description,
        )

        return {
            'name': name,
            'version': version,
            'description': description,
            'type': 'model',
            'endpoint_obj': endpoint_object,
            'dependencies': endpoint_object.get_dependencies(),
            'methods': endpoint_object.get_methods(),
            'required_files': [],
            'required_packages': [],
            'schema': schema
        }

    def _upload_endpoint(self, obj):
        """Save the endpoint object to the server's upload destination.

        The target directory is ``<dest>/endpoints/<name>/<version>``, where
        ``<dest>`` comes from the service. The path is also stored in
        ``obj['src_path']`` (``obj`` is mutated).
        """
        endpoint_obj = obj['endpoint_obj']

        dest_path = self._get_endpoint_upload_destination()

        # Upload the endpoint
        obj['src_path'] = os.path.join(
            dest_path,
            'endpoints',
            obj['name'],
            str(obj['version']))

        endpoint_obj.save(obj['src_path'])

    def _wait_for_endpoint_deployment(self,
                                      endpoint_name,
                                      version=1,
                                      interval=1.0,
                                      timeout=10,
                                      ):
        """
        Waits for the endpoint to be deployed by calling get_status() and
        checking the versions deployed of the endpoint against the expected
        version. If all the versions are equal to or greater than the version
        expected, then it will return. Uses time.sleep().

        Parameters
        ----------
        endpoint_name : str
            Endpoint to watch.
        version : int
            Minimum version that must report 'LoadSuccessful'.
        interval : float
            Seconds to sleep between status polls.
        timeout : float
            Seconds to wait before giving up (default 10).

        Raises
        ------
        RuntimeError
            If the endpoint reports 'LoadFailed', or the timeout elapses.
        """
        logger.info('Waiting for endpoint %s to deploy to version %s',
                    endpoint_name, version)
        # monotonic clock: elapsed time is unaffected by wall-clock changes.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info("Endpoint %s doesn't exist in endpoints yet",
                            endpoint_name)
            else:
                logger.info('ep=%s', ep)

                if ep['status'] == 'LoadFailed':
                    raise RuntimeError(
                        f'LoadFailed: {ep["last_error"]}')

                if ep['status'] == 'LoadSuccessful':
                    if ep['version'] >= version:
                        logger.info("LoadSuccessful")
                        break
                    logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                raise RuntimeError(
                    f'Waited more than {timeout}s for deployment')

            logger.info('Sleeping %s...', interval)
            time.sleep(interval)

    def remove(self, name):
        '''
        Remove the endpoint that has the specified name.

        After issuing the removal, this blocks and polls get_endpoints() once
        per second until the name is no longer listed. There is no timeout.

        Parameters
        ----------
        name : str
            The name of the endpoint to be removed.

        Notes
        -----
        This could fail if the endpoint does not exist, or if the endpoint is
        in use by an alias. To check all endpoints
        that are depending on this endpoint, use `get_endpoint_dependencies`.

        See Also
        --------
        deploy, get_endpoint_dependencies
        '''
        self._service.remove_endpoint(name)

        # Wait for the endpoint to be removed
        while name in self.get_endpoints():
            time.sleep(1.0)

    def get_endpoint_dependencies(self, endpoint_name=None):
        '''
        Get all endpoints that depend on the given endpoint. The only
        dependency that is recorded is aliases on the endpoint they refer to.
        This will not return internal dependencies, as when you have an
        endpoint that calls another endpoint from within its code body.

        The result is computed as the transitive closure of each endpoint's
        ``dependencies`` attribute, as returned by ``get_endpoints()``.
        Dependency cycles are tolerated. A recorded dependency name that is
        not currently listed is included but not followed further.

        Parameters
        ----------
        endpoint_name : str, optional
            The name of the endpoint to find dependent endpoints. If not given,
            find all dependent endpoints for all endpoints.

        Returns
        -------
        set or dict
            If endpoint_name is given, returns a set of endpoint names that
            depend on the given endpoint.

            If endpoint_name is not given, returns a dictionary where key is
            the endpoint name and value is a set of endpoints that depend on
            the endpoint specified by the key.

        Raises
        ------
        KeyError
            If endpoint_name is given but is not among get_endpoints().
        '''
        endpoints = self.get_endpoints()

        def get_dependencies(endpoint):
            # Iterative depth-first walk with a visited set: no recursion
            # limit, and a dependency cycle cannot loop forever.
            result = set()
            pending = list(endpoints[endpoint].dependencies)
            while pending:
                dep = pending.pop()
                if dep in result:
                    continue
                result.add(dep)
                known = endpoints.get(dep)
                if known is not None:
                    pending.extend(known.dependencies)
            return result

        if endpoint_name:
            return get_dependencies(endpoint_name)

        else:
            return {
                endpoint: get_dependencies(endpoint)
                for endpoint in endpoints
            }

    def set_credentials(self, username, password):
        '''
        Set credentials for all the TabPy client-server communication
        where client is tabpy-tools and server is tabpy-server.

        Parameters
        ----------
        username : str
            User name (login). Username is case insensitive.

        password : str
            Password in plain text.
        '''
        self._service.set_credentials(username, password)
530