Passed
Pull Request — master (#652)
by
unknown
17:05
created

tabpy.tabpy_tools.client   B

Complexity

Total Complexity 45

Size/Duplication

Total Lines 464
Duplicated Lines 0 %

Test Coverage

Coverage 75.61%

Importance

Changes 0
Metric Value
wmc 45
eloc 154
dl 0
loc 464
ccs 93
cts 123
cp 0.7561
rs 8.8
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex classes like tabpy.tabpy_tools.client often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1 1
import copy
2 1
import inspect
3 1
from re import compile
4 1
import time
5 1
import requests
6
7 1
from .rest import RequestsNetworkWrapper, ServiceClient
8
9 1
from .rest_client import RESTServiceClient, Endpoint
10
11 1
from .custom_query_object import CustomQueryObject
12 1
import os
13 1
import logging
14
15 1
# Module-level logger; used below to report endpoint deployment progress.
logger = logging.getLogger(__name__)
16
17 1
# Valid endpoint names: one or more word characters (\w), spaces or hyphens.
# NOTE: "$" also matches just before a trailing newline, so match() with this
# pattern accepts e.g. "name\n"; fullmatch() does not.
_name_checker = compile(r"^[\w -]+$")
18
19
20 1
def _check_endpoint_type(name):
21 1
    if not isinstance(name, str):
22 1
        raise TypeError("Endpoint name must be a string")
23
24 1
    if name == "":
25 1
        raise ValueError("Endpoint name cannot be empty")
26
27
28 1
def _check_hostname(name):
    """Validate that *name* looks like ``http(s)://<hostname>[:<port>][/]``.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If ``name`` is empty or does not have the expected URL shape.
    """
    _check_endpoint_type(name)
    hostname_checker = compile(r"^http(s)?://[\w.-]+(/)?(:\d+)?(/)?$")

    # fullmatch rather than match: "$" alone also matches just before a
    # trailing newline, which would let e.g. "http://host:9004\n" through.
    if not hostname_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} should be in http(s)://<hostname>"
            "[:<port>] and hostname may consist only of: "
            "a-z, A-Z, 0-9, underscore, periods and hyphens."
        )
38
39
40 1
def _check_endpoint_name(name):
    """Checks that the endpoint name is valid by comparing it with an RE.

    A valid name is a non-empty string made only of word characters
    (letters, digits, underscore), spaces and hyphens.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If ``name`` is empty or contains any other character.
    """
    _check_endpoint_type(name)

    # fullmatch rather than match: the "$" in _name_checker also matches just
    # before a trailing newline, so match() would accept e.g. "my_model\n".
    if not _name_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} can only contain: a-z, A-Z, 0-9,"
            " underscore, hyphens and spaces."
        )
50
51
52 1
class Client:
53 1
    def __init__(self, endpoint, query_timeout=1000, remote_server=False, localhost_endpoint=None):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL, in the form ``http(s)://<hostname>[:<port>][/]``.
            Required; a TypeError or ValueError is raised if it fails
            validation.

        query_timeout : int or float, optional
            The timeout for query operations. Anything that is not a positive
            ``int``/``float`` (``bool`` included) is replaced by ``0.0``.

        remote_server : bool, optional
            Whether client is a remote TabPy server. When True, ``deploy``
            sends a generated script to the server's ``/evaluate`` endpoint
            instead of deploying directly.

        localhost_endpoint : str, optional
            The localhost endpoint with potentially different protocol and
            port compared to the main endpoint parameter. When set, it is used
            instead of ``endpoint`` inside scripts generated for remote
            deployment.
        """
        _check_hostname(endpoint)

        self._endpoint = endpoint
        self._remote_server = remote_server
        self._localhost_endpoint = localhost_endpoint

        # NOTE(review): TLS certificate verification is disabled for every
        # request on this session and urllib3 warnings are silenced
        # process-wide -- presumably to allow self-signed certificates;
        # confirm this is still intended.
        session = requests.session()
        session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        network_wrapper = RequestsNetworkWrapper(session)
        service_client = ServiceClient(self._endpoint, network_wrapper)

        self._service = RESTServiceClient(service_client)
        # Exact type check (not isinstance) on purpose: it also rejects bools.
        if type(query_timeout) not in (int, float) or query_timeout <= 0:
            query_timeout = 0.0
        self._service.query_timeout = query_timeout
93
94 1
    def __repr__(self):
95
        return (
96
            "<"
97
            + self.__class__.__name__
98
            + " object at "
99
            + hex(id(self))
100
            + " connected to "
101
            + repr(self._endpoint)
102
            + ">"
103
        )
104
105 1
    def get_status(self):
106
        """
107
        Gets the status of the deployed endpoints.
108
109
        Returns
110
        -------
111
        dict
112
            Keys are endpoints and values are dicts describing the state of
113
            the endpoint.
114
115
        Examples
116
        --------
117
        .. sourcecode:: python
118
            {
119
                u'foo': {
120
                    u'status': u'LoadFailed',
121
                    u'last_error': u'error mesasge',
122
                    u'version': 1,
123
                    u'type': u'model',
124
                },
125
            }
126
        """
127 1
        return self._service.get_status()
128
129
    #
130
    # Query
131
    #
132
133 1
    @property
134 1
    def query_timeout(self):
135
        """The timeout for queries in milliseconds."""
136 1
        return self._service.query_timeout
137
138 1
    @query_timeout.setter
139 1
    def query_timeout(self, value):
140 1
        if type(value) in (int, float) and value > 0:
141 1
            self._service.query_timeout = value
142
143 1
    def query(self, name, *args, **kwargs):
144
        """Query an endpoint.
145
146
        Parameters
147
        ----------
148
        name : str
149
            The name of the endpoint.
150
151
        *args : list of anything
152
            Ordered parameters to the endpoint.
153
154
        **kwargs : dict of anything
155
            Named parameters to the endpoint.
156
157
        Returns
158
        -------
159
        dict
160
            Keys are:
161
                model: the name of the endpoint
162
                version: the version used.
163
                response: the response to the query.
164
                uuid : a unique id for the request.
165
        """
166 1
        return self._service.query(name, *args, **kwargs)
167
168
    #
169
    # Endpoints
170
    #
171
172 1
    def get_endpoints(self, type=None):
173
        """Returns all deployed endpoints.
174
175
        Examples
176
        --------
177
        .. sourcecode:: python
178
            {"clustering":
179
              {"description": "",
180
               "docstring": "-- no docstring found in query function --",
181
               "creation_time": 1469511182,
182
               "version": 1,
183
               "dependencies": [],
184
               "last_modified_time": 1469511182,
185
               "type": "model",
186
               "target": null,
187
               "is_public": True}
188
            "add": {
189
              "description": "",
190
              "docstring": "-- no docstring found in query function --",
191
              "creation_time": 1469505967,
192
              "version": 1,
193
              "dependencies": [],
194
              "last_modified_time": 1469505967,
195
              "type": "model",
196
              "target": null,
197
              "is_public": False}
198
            }
199
        """
200 1
        return self._service.get_endpoints(type)
201
202 1
    def _get_endpoint_upload_destination(self):
203
        """Returns the endpoint upload destination."""
204 1
        return self._service.get_endpoint_upload_destination()["path"]
205
206 1
    def deploy(self, name, obj, description="", schema=None, override=False, is_public=False):
207
        """Deploys a Python function as an endpoint in the server.
208
209
        Parameters
210
        ----------
211
        name : str
212
            A unique identifier for the endpoint.
213
214
        obj :  function
215
            Refers to a user-defined function with any signature. However both
216
            input and output of the function need to be JSON serializable.
217
218
        description : str, optional
219
            The description for the endpoint. This string will be returned by
220
            the ``endpoints`` API.
221
222
        schema : dict, optional
223
            The schema of the function, containing information about input and
224
            output parameters, and respective examples. Providing a schema for
225
            a deployed function lets other users of the service discover how to
226
            use it. Refer to schema.generate_schema for more information on
227
            how to generate the schema.
228
229
        override : bool
230
            Whether to override (update) an existing endpoint. If False and
231
            there is already an endpoint with that name, it will raise a
232
            RuntimeError. If True and there is already an endpoint with that
233
            name, it will deploy a new version on top of it.
234
235
        is_public : bool, optional
236
            Whether a function should be public for viewing from within tableau. If
237
            False, function will not appear in the custom functions explorer within
238
            Tableau. If True, function will be visible ta anyone on a site with this
239
            analytics extension configured
240
241
        See Also
242
        --------
243
        remove, get_endpoints
244
        """
245 1
        if self._remote_server:
246
            self._remote_deploy(
247
                name, obj, 
248
                description=description, schema=schema, override=override, is_public=is_public
249
            )
250
            return
251
252 1
        endpoint = self.get_endpoints().get(name)
253 1
        version = 1
254 1
        if endpoint:
255 1
            if not override:
256
                raise RuntimeError(
257
                    f"An endpoint with that name ({name}) already"
258
                    ' exists. Use "override = True" to force update '
259
                    "an existing endpoint."
260
                )
261
262 1
            version = endpoint.version + 1
263
264 1
        obj = self._gen_endpoint(name, obj, description, version, schema, is_public)
265
266 1
        self._upload_endpoint(obj)
267
268 1
        if version == 1:
269 1
            self._service.add_endpoint(Endpoint(**obj))
270
        else:
271 1
            self._service.set_endpoint(Endpoint(**obj), should_update_version=True)
272
273 1
        self._wait_for_endpoint_deployment(obj["name"], obj["version"])
274
275 1
    def remove(self, name):
276
        '''Removes an endpoint dict.
277
278
        Parameters
279
        ----------
280
        name : str
281
            Endpoint name to remove'''
282
        self._service.remove_endpoint(name)
283
284 1
    def update_endpoint_info(self, name, description=None, schema=None, is_public=None):
285
        '''Updates description, schema, or is public for an existing endpoint
286
287
        Parameters
288
        ----------
289
        name : str
290
            Name of the endpoint that to be updated. If endpoint does not exist
291
            runtime error will be thrown
292
293
        description : str, optional
294
            The description for the endpoint. This string will be returned by
295
            the ``endpoints`` API.
296
297
        schema : dict, optional
298
            The schema of the function, containing information about input and
299
            output parameters, and respective examples. Providing a schema for
300
            a deployed function lets other users of the service discover how to
301
            use it. Refer to schema.generate_schema for more information on
302
            how to generate the schema.
303
304
        is_public : bool, optional
305
            Whether a function should be public for viewing from within tableau. If
306
            False, function will not appear in the custom functions explorer within
307
            Tableau. If True, function will be visible to anyone on a site with this
308
            analytics extension configured
309
        '''
310
311
        endpoint = self.get_endpoints().get(name)
312
313
        if not endpoint:
314
            raise RuntimeError(
315
                f"No endpoint with that name ({name}) exists"
316
                " Please select an existing endpoint to update"
317
            )
318
319
        if description is not None:
320
            if type(description) is not str:
321
                raise RuntimeError(
322
                    f"Type of description must be string"
323
                )
324
            endpoint.description = description
325
        if schema is not None:
326
            if type(schema) is not dict:
327
                raise RuntimeError(
328
                    f"Type of schema must be dictionary"
329
                )
330
            endpoint.schema = schema
331
        if is_public is not None:
332
            if type(is_public) is not bool:
333
                raise RuntimeError(
334
                    f"Type of is_public must be bool"
335
                )
336
            endpoint.is_public = is_public
337
338
        dest_path = self._get_endpoint_upload_destination()
339
340
        endpoint.src_path = os.path.join(
341
            dest_path, "endpoints", endpoint.name, str(endpoint.version)
342
        )
343
        self._service.set_endpoint(endpoint, should_update_version=False)
344
345 1
    def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_public=False):
        """Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update. Validated with
            ``_check_endpoint_name``.

        obj : function
            Object that backs the endpoint. See ``deploy()`` for a complete
            description.

        description : str or None
            Description of the endpoint. If None, the stripped docstring of
            ``obj`` is used, or "" when ``obj.__doc__`` is not a string.

        version : int
            The version. Defaults to 1.

        schema : dict, optional
            Schema of the function; a shallow copy is stored in the result.

        is_public : bool
            True if function should be visible in the custom functions explorer
            within Tableau

        Returns
        -------
        dict
            Keys:
                name, version, description, is_public : as provided/derived.
                type : always "model".
                endpoint_obj : CustomQueryObject wrapping ``obj``; can be used
                    to generate the code and dependencies for the endpoint.
                dependencies, methods : taken from ``endpoint_obj``.
                required_files, required_packages : empty lists.
                docstring : ``inspect.getdoc(obj)`` or a placeholder text.
                schema : shallow copy of ``schema``.

        Raises
        ------
        TypeError, ValueError
            Raised by ``_check_endpoint_name`` for an invalid ``name``.
            NOTE(review): CustomQueryObject may also reject unsupported
            ``obj`` values -- not visible from this file.
        """
        # check for invalid PO names
        _check_endpoint_name(name)

        if description is None:
            # Fall back to the function's own docstring (if it has one).
            doc = obj.__doc__
            description = doc.strip() if isinstance(doc, str) else ""

        endpoint_object = CustomQueryObject(query=obj, description=description)
        docstring = inspect.getdoc(obj) or "-- no docstring found in query function --"

        return {
            "name": name,
            "version": version,
            "description": description,
            "type": "model",
            "endpoint_obj": endpoint_object,
            "dependencies": endpoint_object.get_dependencies(),
            "methods": endpoint_object.get_methods(),
            "required_files": [],
            "required_packages": [],
            "docstring": docstring,
            "schema": copy.copy(schema),
            "is_public": is_public,
        }
415
416 1
    def _upload_endpoint(self, obj):
417
        """Sends the endpoint across the wire."""
418 1
        endpoint_obj = obj["endpoint_obj"]
419
420 1
        dest_path = self._get_endpoint_upload_destination()
421
422
        # Upload the endpoint
423 1
        obj["src_path"] = os.path.join(
424
            dest_path, "endpoints", obj["name"], str(obj["version"])
425
        )
426
427 1
        endpoint_obj.save(obj["src_path"])
428
429 1
    def _wait_for_endpoint_deployment(
        self, endpoint_name, version=1, interval=1.0, timeout=10,
    ):
        """
        Waits for the endpoint to be deployed by polling get_status().

        Sleeps ``interval`` seconds, then repeatedly checks the status of
        ``endpoint_name`` until the server reports ``LoadSuccessful`` with a
        version greater than or equal to ``version``, sleeping ``interval``
        seconds between polls. Uses time.sleep().

        Parameters
        ----------
        endpoint_name : str
            Name of the endpoint to wait for.
        version : int
            Minimum version that must be reported as loaded. Defaults to 1.
        interval : float
            Seconds to sleep before the first poll and between polls.
        timeout : float
            Seconds, counted from after the initial sleep, before giving up.
            Defaults to 10.

        Raises
        ------
        RuntimeError
            If the server reports ``LoadFailed`` for the endpoint, or if the
            endpoint has not loaded within ``timeout`` seconds.
        """
        logger.info(
            "Waiting for endpoint %s to deploy to version %s", endpoint_name, version
        )
        time.sleep(interval)
        # monotonic() is not affected by wall-clock changes during the wait.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info("Endpoint %s doesn't exist in endpoints yet", endpoint_name)
            else:
                logger.info("ep=%s", ep)

                if ep["status"] == "LoadFailed":
                    raise RuntimeError(f'LoadFailed: {ep["last_error"]}')

                if ep["status"] == "LoadSuccessful":
                    if ep["version"] >= version:
                        logger.info("LoadSuccessful")
                        return
                    logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                raise RuntimeError(f"Waited more than {timeout}s for deployment")

            logger.info("Sleeping %s...", interval)
            time.sleep(interval)
469
470 1
    def _remote_deploy(self, name, obj, description="", schema=None, override=False, is_public=False):
471
        """
472
        Remotely deploy a Python function using the /evaluate endpoint. Takes the same inputs
473
        as deploy.
474
        """
475
        remote_script = self._gen_remote_script()
476
        remote_script += f"{inspect.getsource(obj)}\n"
477
478
        remote_script += (
479
            f"client.deploy("
480
            f"'{name}', {obj.__name__}, '{description}', "
481
            f"override={override}, is_public={is_public}, schema={schema}"
482
            f")"
483
        )
484
485
        self._evaluate_remote_script(remote_script)
486
487 1
    def _gen_remote_script(self):
488
        """
489
        Generates a remote script for TabPy client connection with credential handling.
490
491
        Returns:
492
            str: A Python script to establish a TabPy client connection
493
        """
494
        remote_script = [
495
            "from tabpy.tabpy_tools.client import Client",
496
            f"client = Client('{self._localhost_endpoint or self._endpoint}')"
497
        ]
498
499
        remote_script.append(
500
            f"client.set_credentials('{auth.username}', '{auth.password}')"
501
        ) if (auth := self._service.service_client.network_wrapper.auth) else None
502
        
503
        return "\n".join(remote_script) + "\n"
504
    
505 1
    def _evaluate_remote_script(self, remote_script):
        """
        Uses TabPy /evaluate endpoint to execute a remote TabPy client script.

        The script is echoed to stdout first, with any
        ``client.set_credentials(...)`` line redacted. After the request, the
        HTTP status and response text are printed, with every ``null`` in
        the text shown as ``success``.

        Parameters
        ----------
        remote_script : str
            The script to execute remotely.
        """
        # Generated scripts can carry the plain-text password on a
        # ``client.set_credentials(...)`` line; never echo it to stdout.
        redacted = "\n".join(
            "client.set_credentials(<redacted>)"
            if line.startswith("client.set_credentials(")
            else line
            for line in remote_script.split("\n")
        )
        print(f"Remote script:\n{redacted}")

        # _check_hostname accepts endpoints with or without a trailing "/",
        # so join with exactly one slash.
        url = self._endpoint.rstrip("/") + "/evaluate"
        headers = {"Content-Type": "application/json"}
        payload = {"data": {}, "script": remote_script}

        # NOTE(review): no timeout is passed, so this call can block for as
        # long as the server takes to respond.
        response = requests.post(
            url,
            headers=headers,
            auth=self._service.service_client.network_wrapper.auth,
            json=payload
        )

        log_message = response.text.replace('null', 'success')
        if "Ad-hoc scripts have been disabled" in log_message:
            log_message += "\n[Connecting to this TabPy server with remote_server=True is not allowed.]"
        print(f"\n{response.status_code} - {log_message}\n")
530
531 1
    def set_credentials(self, username, password):
532
        """
533
        Set credentials for all the TabPy client-server communication
534
        where client is tabpy-tools and server is tabpy-server.
535
536
        Parameters
537
        ----------
538
        username : str
539
            User name (login). Username is case insensitive.
540
541
        password : str
542
            Password in plain text.
543
        """
544
        self._service.set_credentials(username, password)
545