Passed
Pull Request — master (#652)
by
unknown
21:36 queued 04:55
created

tabpy.tabpy_tools.client   B

Complexity

Total Complexity 45

Size/Duplication

Total Lines 464
Duplicated Lines 0 %

Test Coverage

Coverage 75.61%

Importance

Changes 0
Metric Value
wmc 45
eloc 154
dl 0
loc 464
ccs 93
cts 123
cp 0.7561
rs 8.8
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex classes like tabpy.tabpy_tools.client often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1 1
import copy
2 1
import inspect
3 1
from re import compile
4 1
import time
5 1
import requests
6
7 1
from .rest import RequestsNetworkWrapper, ServiceClient
8
9 1
from .rest_client import RESTServiceClient, Endpoint
10
11 1
from .custom_query_object import CustomQueryObject
12 1
import os
13 1
import logging
14
15 1
# Module logger, named after this module.
logger = logging.getLogger(name=__name__)

# Endpoint names must be one or more word characters (letters, digits,
# underscore), spaces or hyphens.
_name_checker = compile("^[\\w -]+$")
18
19
20 1
def _check_endpoint_type(name):
    """Ensure *name* is a non-empty string.

    Raises
    ------
    TypeError
        If *name* is not a ``str``.
    ValueError
        If *name* is the empty string.
    """
    is_string = isinstance(name, str)
    if is_string and name != "":
        return

    if not is_string:
        raise TypeError("Endpoint name must be a string")
    raise ValueError("Endpoint name cannot be empty")
26
27
28 1
def _check_hostname(name):
    """Validate that *name* is an http(s) URL usable as a Client endpoint.

    The accepted form is ``http(s)://<hostname>[:<port>]`` with an optional
    trailing slash. The hostname may contain word characters (letters,
    digits, underscore), hyphens and dots. The pattern also tolerates a
    slash before the port.

    Raises
    ------
    TypeError
        If *name* is not a string.
    ValueError
        If *name* is empty or does not match the accepted form.
    """
    _check_endpoint_type(name)
    hostname_checker = compile(r"^http(s)?://[\w.-]+(/)?(:\d+)?(/)?$")

    # fullmatch (rather than match) so that "$" cannot accept a trailing
    # newline, e.g. "http://localhost:9004\n".
    if not hostname_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} should be in http(s)://<hostname>"
            "[:<port>] and hostname may consist only of: "
            "a-z, A-Z, 0-9, underscore, hyphens and dots."
        )
38
39
40 1
def _check_endpoint_name(name):
    """Validate an endpoint name.

    The name must be a non-empty string made only of word characters
    (letters, digits, underscore), hyphens and spaces, as defined by the
    module-level ``_name_checker`` pattern.

    Raises
    ------
    TypeError
        If *name* is not a string.
    ValueError
        If *name* is empty or contains any other character.
    """
    _check_endpoint_type(name)

    if _name_checker.match(name) is None:
        raise ValueError(
            f"endpoint name {name} can only contain: a-z, A-Z, 0-9,"
            " underscore, hyphens and spaces."
        )
50
51
52 1
class Client:
    """Client for a running TabPy server.

    Wraps a REST connection to the server and exposes status, query and
    endpoint-management (deploy, update, remove, list) APIs.
    """

    def __init__(
        self, endpoint, query_timeout=1000, remote_server=False, localhost_endpoint=None
    ):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL, of the form ``http(s)://<hostname>[:<port>][/]``.

        query_timeout : float, optional
            The timeout for query operations. A value that is not a positive
            int or float is replaced with 0.0.

        remote_server : bool, optional
            Whether client is a remote TabPy server. When True, ``deploy``
            sends a deployment script to the server's ``/evaluate`` endpoint
            instead of uploading the function directly.

        localhost_endpoint : str, optional
            The localhost endpoint with potentially different protocol and
            port compared to the main endpoint parameter. When set, it is the
            address the generated remote-deployment script connects to.

        Raises
        ------
        TypeError, ValueError
            If ``endpoint`` is not a valid http(s) URL.
        """
        _check_hostname(endpoint)

        self._endpoint = endpoint
        self._remote_server = remote_server
        self._localhost_endpoint = localhost_endpoint

        session = requests.session()
        # Certificate verification is disabled for every request made through
        # this session; urllib3's warnings are silenced process-wide to match.
        session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        network_wrapper = RequestsNetworkWrapper(session)
        service_client = ServiceClient(self._endpoint, network_wrapper)

        self._service = RESTServiceClient(service_client)
        # Exact type check: bool (a subclass of int) is rejected on purpose.
        if type(query_timeout) not in (int, float) or query_timeout <= 0:
            query_timeout = 0.0
        self._service.query_timeout = query_timeout

    def __repr__(self):
        return (
            f"<{self.__class__.__name__} object at {hex(id(self))}"
            f" connected to {self._endpoint!r}>"
        )

    def get_status(self):
        """
        Gets the status of the deployed endpoints.

        Returns
        -------
        dict
            Keys are endpoints and values are dicts describing the state of
            the endpoint.

        Examples
        --------
        .. sourcecode:: python
            {
                u'foo': {
                    u'status': u'LoadFailed',
                    u'last_error': u'error message',
                    u'version': 1,
                    u'type': u'model',
                },
            }
        """
        return self._service.get_status()

    #
    # Query
    #

    @property
    def query_timeout(self):
        """The timeout for queries in milliseconds."""
        return self._service.query_timeout

    @query_timeout.setter
    def query_timeout(self, value):
        # Values that are not a positive int/float (including bools) are
        # ignored, and the previous timeout is kept.
        if type(value) in (int, float) and value > 0:
            self._service.query_timeout = value

    def query(self, name, *args, **kwargs):
        """Query an endpoint.

        Parameters
        ----------
        name : str
            The name of the endpoint.

        *args : list of anything
            Ordered parameters to the endpoint.

        **kwargs : dict of anything
            Named parameters to the endpoint.

        Returns
        -------
        dict
            Keys are:
                model: the name of the endpoint
                version: the version used.
                response: the response to the query.
                uuid : a unique id for the request.
        """
        return self._service.query(name, *args, **kwargs)

    #
    # Endpoints
    #

    def get_endpoints(self, type=None):
        """Returns all deployed endpoints.

        Parameters
        ----------
        type : str, optional
            Forwarded unchanged to the service's ``get_endpoints`` (for
            example ``"model"``).

        Returns
        -------
        dict
            Maps endpoint name to an endpoint object. Attributes such as
            ``name``, ``version``, ``description``, ``schema`` and
            ``is_public`` are read and written by this class.

        Examples
        --------
        .. sourcecode:: python
            {"clustering":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469511182,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469511182,
               "type": "model",
               "target": None,
               "is_public": True},
            "add":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469505967,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469505967,
               "type": "model",
               "target": None,
               "is_public": False}
            }
        """
        return self._service.get_endpoints(type)

    def _get_endpoint_upload_destination(self):
        """Returns the endpoint upload destination path reported by the server."""
        return self._service.get_endpoint_upload_destination()["path"]

    def _get_endpoint_src_path(self, name, version):
        """Build ``<upload destination>/endpoints/<name>/<version>`` for an endpoint."""
        dest_path = self._get_endpoint_upload_destination()
        return os.path.join(dest_path, "endpoints", name, str(version))

    def deploy(self, name, obj, description="", schema=None, override=False, is_public=False):
        """Deploys a Python function as an endpoint in the server.

        Parameters
        ----------
        name : str
            A unique identifier for the endpoint.

        obj :  function
            Refers to a user-defined function with any signature. However both
            input and output of the function need to be JSON serializable.

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API.

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        override : bool
            Whether to override (update) an existing endpoint. If False and
            there is already an endpoint with that name, it will raise a
            RuntimeError. If True and there is already an endpoint with that
            name, it will deploy a new version on top of it.

        is_public : bool, optional
            Whether a function should be public for viewing from within tableau. If
            False, function will not appear in the custom functions explorer within
            Tableau. If True, function will be visible to anyone on a site with this
            analytics extension configured

        Raises
        ------
        RuntimeError
            If the endpoint exists and ``override`` is False, or if the
            deployment fails to load or times out.

        See Also
        --------
        remove, get_endpoints
        """
        if self._remote_server:
            self._remote_deploy(
                name, obj,
                description=description, schema=schema, override=override, is_public=is_public,
            )
            return

        existing = self.get_endpoints().get(name)
        version = 1
        if existing:
            if not override:
                raise RuntimeError(
                    f"An endpoint with that name ({name}) already"
                    ' exists. Use "override = True" to force update '
                    "an existing endpoint."
                )
            version = existing.version + 1

        endpoint_dict = self._gen_endpoint(name, obj, description, version, schema, is_public)
        self._upload_endpoint(endpoint_dict)

        if version == 1:
            self._service.add_endpoint(Endpoint(**endpoint_dict))
        else:
            self._service.set_endpoint(Endpoint(**endpoint_dict), should_update_version=True)

        self._wait_for_endpoint_deployment(endpoint_dict["name"], endpoint_dict["version"])

    def remove(self, name):
        """Removes an endpoint.

        Parameters
        ----------
        name : str
            Endpoint name to remove.
        """
        self._service.remove_endpoint(name)

    def update_endpoint_info(self, name, description=None, schema=None, is_public=None):
        """Updates description, schema, or is public for an existing endpoint.

        Arguments left as None keep their current value. The endpoint version
        is not bumped.

        Parameters
        ----------
        name : str
            Name of the endpoint that to be updated. If endpoint does not exist
            runtime error will be thrown

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API.

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        is_public : bool, optional
            Whether a function should be public for viewing from within tableau. If
            False, function will not appear in the custom functions explorer within
            Tableau. If True, function will be visible to anyone on a site with this
            analytics extension configured

        Raises
        ------
        RuntimeError
            If the endpoint does not exist or an argument has the wrong type.
        """
        endpoint = self.get_endpoints().get(name)

        if not endpoint:
            raise RuntimeError(
                f"No endpoint with that name ({name}) exists."
                " Please select an existing endpoint to update"
            )

        # Validate every argument before mutating the endpoint.
        if description is not None and not isinstance(description, str):
            raise RuntimeError("Type of description must be string")
        if schema is not None and not isinstance(schema, dict):
            raise RuntimeError("Type of schema must be dictionary")
        if is_public is not None and not isinstance(is_public, bool):
            raise RuntimeError("Type of is_public must be bool")

        if description is not None:
            endpoint.description = description
        if schema is not None:
            endpoint.schema = schema
        if is_public is not None:
            endpoint.is_public = is_public

        endpoint.src_path = self._get_endpoint_src_path(endpoint.name, endpoint.version)
        self._service.set_endpoint(endpoint, should_update_version=False)

    def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_public=False):
        """Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update

        obj :  func
            Object that backs the endpoint. See deploy() for a complete
            description.

        description : str or None
            Description of the endpoint. If None, the stripped docstring of
            ``obj`` is used (or "" when ``obj`` has no docstring).

        version : int
            The version. Defaults to 1.

        schema : dict, optional
            Endpoint schema; a shallow copy is stored in the result.

        is_public : bool
            True if function should be visible in the custom functions explorer
            within Tableau

        Returns
        -------
        dict
            Keys:
                name : str
                    The name provided.

                version : int
                    The version provided.

                description : str
                    The provided (or derived) description.

                type : str
                    The type of the endpoint (always "model").

                endpoint_obj : object
                    The wrapper around the obj provided that can be used to
                    generate the code and dependencies for the endpoint.

            Also: dependencies, methods, required_files, required_packages,
            docstring, schema and is_public.

        Raises
        ------
        TypeError, ValueError
            If ``name`` is not a valid endpoint name. ``CustomQueryObject``
            may raise for unsupported ``obj`` values (not verified here).
        """
        # check for invalid PO names
        _check_endpoint_name(name)

        if description is None:
            doc = obj.__doc__
            # Guard before calling .strip(): obj.__doc__ is None when the
            # function has no docstring.
            description = doc.strip() if isinstance(doc, str) else ""

        endpoint_object = CustomQueryObject(query=obj, description=description)
        docstring = inspect.getdoc(obj) or "-- no docstring found in query function --"

        return {
            "name": name,
            "version": version,
            "description": description,
            "type": "model",
            "endpoint_obj": endpoint_object,
            "dependencies": endpoint_object.get_dependencies(),
            "methods": endpoint_object.get_methods(),
            "required_files": [],
            "required_packages": [],
            "docstring": docstring,
            "schema": copy.copy(schema),
            "is_public": is_public,
        }

    def _upload_endpoint(self, obj):
        """Save the endpoint object under the server's upload destination.

        Sets ``obj["src_path"]`` to ``<destination>/endpoints/<name>/<version>``
        and calls ``obj["endpoint_obj"].save`` with that path.
        """
        endpoint_obj = obj["endpoint_obj"]
        obj["src_path"] = self._get_endpoint_src_path(obj["name"], obj["version"])
        endpoint_obj.save(obj["src_path"])

    def _wait_for_endpoint_deployment(
        self, endpoint_name, version=1, interval=1.0, timeout=10.0,
    ):
        """
        Waits for the endpoint to be deployed by polling get_status().

        Returns once the endpoint reports status ``LoadSuccessful`` with a
        version equal to or greater than ``version``. Uses time.sleep().

        Parameters
        ----------
        endpoint_name : str
            Name of the endpoint to wait for.
        version : int, optional
            Minimum version that must be loaded. Defaults to 1.
        interval : float, optional
            Seconds to sleep before the first poll and between polls.
        timeout : float, optional
            Seconds of polling after which to give up. Defaults to 10.

        Raises
        ------
        RuntimeError
            If the endpoint reports ``LoadFailed`` or the timeout elapses.
        """
        logger.info(
            f"Waiting for endpoint {endpoint_name} to deploy to version {version}"
        )
        time.sleep(interval)
        # Monotonic clock: elapsed time is unaffected by wall-clock changes.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info(f"Endpoint {endpoint_name} doesn't exist in endpoints yet")
            else:
                logger.info(f"ep={ep}")

                if ep["status"] == "LoadFailed":
                    raise RuntimeError(f'LoadFailed: {ep["last_error"]}')

                if ep["status"] == "LoadSuccessful":
                    if ep["version"] >= version:
                        logger.info("LoadSuccessful")
                        return
                    logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                raise RuntimeError(f"Waited more than {timeout:g}s for deployment")

            logger.info(f"Sleeping {interval}...")
            time.sleep(interval)

    def _remote_deploy(
        self, name, obj, description="", schema=None, override=False, is_public=False
    ):
        """
        Remotely deploy a Python function using the /evaluate endpoint. Takes the same inputs
        as deploy.
        """
        remote_script = self._gen_remote_deploy_script(
            name, obj,
            description=description, schema=schema, override=override, is_public=is_public,
        )
        self._evaluate_remote_script(remote_script)

    def _gen_remote_deploy_script(
        self, name, obj, description="", schema=None, override=False, is_public=False
    ):
        """
        Build the script that ``_remote_deploy`` sends to /evaluate.

        The script connects a client on the server side, defines ``obj`` from
        its source, and calls ``client.deploy`` with the given arguments.

        Returns
        -------
        str
            The complete Python script.
        """
        import textwrap

        # Dedent so that functions defined in an indented scope (e.g. under
        # ``if __name__ == "__main__":``) are still valid top-level code.
        source = textwrap.dedent(inspect.getsource(obj))
        # repr() quotes the string arguments so that quotes or backslashes in
        # them cannot break (or inject code into) the generated script.
        deploy_call = (
            f"client.deploy("
            f"{name!r}, {obj.__name__}, {description!r}, "
            f"override={override}, is_public={is_public}, schema={schema!r}"
            f")"
        )
        return f"{self._gen_remote_script()}{source}\n{deploy_call}"

    def _gen_remote_script(self):
        """
        Generates a remote script for TabPy client connection with credential handling.

        The script connects to ``localhost_endpoint`` if one was given,
        otherwise to ``endpoint``. It also calls ``set_credentials`` when
        credentials are configured on this client.

        Returns:
            str: A Python script to establish a TabPy client connection
        """
        server = self._localhost_endpoint or self._endpoint
        lines = [
            "from tabpy.tabpy_tools.client import Client",
            f"client = Client({server!r})",
        ]

        auth = self._service.service_client.network_wrapper.auth
        if auth:
            lines.append(f"client.set_credentials({auth.username!r}, {auth.password!r})")

        return "\n".join(lines) + "\n"

    def _evaluate_remote_script(self, remote_script):
        """
        Uses TabPy /evaluate endpoint to execute a remote TabPy client script.

        The script and the server's response are printed to stdout. Nothing
        is returned, and HTTP error statuses are not raised.

        Parameters
        ----------
        remote_script : str
            The script to execute remotely.
        """
        # NOTE(review): when credentials are set, the script contains the
        # password (see _gen_remote_script), so this echoes it in plain text.
        print(f"Remote script:\n{remote_script}")
        # Join with exactly one slash; the endpoint may or may not end with "/".
        base_url = self._endpoint.rstrip("/")
        url = f"{base_url}/evaluate"
        headers = {"Content-Type": "application/json"}
        payload = {"data": {}, "script": remote_script}

        # NOTE(review): this uses requests' default certificate verification,
        # unlike the session built in __init__ (verify=False).
        response = requests.post(
            url,
            headers=headers,
            auth=self._service.service_client.network_wrapper.auth,
            json=payload
        )

        # A script with no return value yields a JSON "null" body, which is
        # shown as "success". Note that every "null" substring is replaced.
        msg = response.text.replace('null', 'success')
        if "Ad-hoc scripts have been disabled" in msg:
            msg += "\n[Remote TabPy client not allowed.]"
        print(f"\n{response.status_code} - {msg}\n")

    def set_credentials(self, username, password):
        """
        Set credentials for all the TabPy client-server communication
        where client is tabpy-tools and server is tabpy-server.

        Parameters
        ----------
        username : str
            User name (login). Username is case insensitive.

        password : str
            Password in plain text.
        """
        self._service.set_credentials(username, password)
549