Passed
Pull Request — master (#652)
by
unknown
16:42
created

tabpy.tabpy_tools.client._check_hostname()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 5
CRAP Score 2

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 7
ccs 5
cts 5
cp 1
rs 10
c 0
b 0
f 0
cc 2
nop 1
crap 2
1 1
import copy
2 1
import inspect
3 1
from re import compile
4 1
import time
5 1
import requests
6
7 1
from .rest import RequestsNetworkWrapper, ServiceClient
8
9 1
from .rest_client import RESTServiceClient, Endpoint
10
11 1
from .custom_query_object import CustomQueryObject
12 1
import os
13 1
import logging
14
15 1
# Module-level logger named after this module; the client uses it to report
# endpoint deployment progress.
logger = logging.getLogger(__name__)
16
17 1
# Pattern for valid endpoint names: one or more word characters, spaces or
# hyphens. NOTE: "$" also matches just before a trailing newline, so a caller
# that must reject "name\n" needs fullmatch() rather than match().
_name_checker = compile(r"^[\w -]+$")
18
19
20 1
def _check_endpoint_type(name):
21 1
    if not isinstance(name, str):
22 1
        raise TypeError("Endpoint name must be a string")
23
24 1
    if name == "":
25 1
        raise ValueError("Endpoint name cannot be empty")
26
27
28 1
# Accepted server URL shape: http(s)://<hostname>[:<port>] with an optional
# trailing slash. Compiled once at import time instead of on every call.
# NOTE(review): the optional "/" *before* the port (e.g. "http://host/:80") is
# kept only because the previous pattern accepted it - confirm it is intended.
_hostname_checker = compile(r"https?://[\w.-]+/?(:\d+)?/?")


def _check_hostname(name):
    """Validate that ``name`` looks like ``http(s)://<hostname>[:<port>]``.

    Parameters
    ----------
    name : str
        The server URL to validate.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If ``name`` is empty or does not match the expected URL shape.
    """
    _check_endpoint_type(name)

    # fullmatch() anchors both ends strictly. The previous match() + "$"
    # combination also accepted a URL followed by a trailing newline.
    if not _hostname_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} should be in http(s)://<hostname>"
            "[:<port>] and hostname may consist only of: "
            "a-z, A-Z, 0-9, underscore, hyphens and dots."
        )
38
39
40 1
def _check_endpoint_name(name):
    """Validate an endpoint name against the allowed character set.

    Parameters
    ----------
    name : str
        The endpoint name to validate.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If ``name`` is empty or contains anything other than word
        characters, hyphens and spaces.
    """
    _check_endpoint_type(name)

    # fullmatch() instead of match(): the "$" in the pattern also matches just
    # before a trailing newline, which let names such as "model\n" through.
    if not _name_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} can only contain: a-z, A-Z, 0-9,"
            " underscore, hyphens and spaces."
        )
50
51
52 1
class Client:
53 1
    def __init__(
        self, endpoint, query_timeout=1000, remote_server=False, localhost_endpoint=None
    ):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL. Validated with ``_check_hostname``.

        query_timeout : float, optional
            The timeout for query operations. Any value that is not an int
            or float, or that is not positive, is replaced by 0.0.

        remote_server : bool, optional
            Whether client is a remote TabPy server. When true, ``deploy``
            delegates to ``_remote_deploy``.

        localhost_endpoint : str, optional
            The localhost endpoint with potentially different protocol and
            port compared to the main endpoint parameter. When set it is
            used instead of ``endpoint`` in the generated remote script.
        """
        _check_hostname(endpoint)

        self._endpoint = endpoint
        self._remote_server = remote_server
        self._localhost_endpoint = localhost_endpoint

        # NOTE(review): certificate verification is switched off for the whole
        # session and urllib3 warnings are silenced process-wide.
        http_session = requests.session()
        http_session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        self._service = RESTServiceClient(
            ServiceClient(self._endpoint, RequestsNetworkWrapper(http_session))
        )

        # Exact type check (not isinstance) so that bool is rejected as well.
        timeout_is_invalid = (
            type(query_timeout) not in (int, float) or query_timeout <= 0
        )
        self._service.query_timeout = 0.0 if timeout_is_invalid else query_timeout
95
96 1
    def __repr__(self):
97
        return (
98
            "<"
99
            + self.__class__.__name__
100
            + " object at "
101
            + hex(id(self))
102
            + " connected to "
103
            + repr(self._endpoint)
104
            + ">"
105
        )
106
107 1
    def get_status(self):
108
        """
109
        Gets the status of the deployed endpoints.
110
111
        Returns
112
        -------
113
        dict
114
            Keys are endpoints and values are dicts describing the state of
115
            the endpoint.
116
117
        Examples
118
        --------
119
        .. sourcecode:: python
120
            {
121
                u'foo': {
122
                    u'status': u'LoadFailed',
123
                    u'last_error': u'error mesasge',
124
                    u'version': 1,
125
                    u'type': u'model',
126
                },
127
            }
128
        """
129 1
        return self._service.get_status()
130
131
    #
132
    # Query
133
    #
134
135 1
    @property
136 1
    def query_timeout(self):
137
        """The timeout for queries in milliseconds."""
138 1
        return self._service.query_timeout
139
140 1
    @query_timeout.setter
141 1
    def query_timeout(self, value):
142 1
        if type(value) in (int, float) and value > 0:
143 1
            self._service.query_timeout = value
144
145 1
    def query(self, name, *args, **kwargs):
146
        """Query an endpoint.
147
148
        Parameters
149
        ----------
150
        name : str
151
            The name of the endpoint.
152
153
        *args : list of anything
154
            Ordered parameters to the endpoint.
155
156
        **kwargs : dict of anything
157
            Named parameters to the endpoint.
158
159
        Returns
160
        -------
161
        dict
162
            Keys are:
163
                model: the name of the endpoint
164
                version: the version used.
165
                response: the response to the query.
166
                uuid : a unique id for the request.
167
        """
168 1
        return self._service.query(name, *args, **kwargs)
169
170
    #
171
    # Endpoints
172
    #
173
174 1
    def get_endpoints(self, type=None):
175
        """Returns all deployed endpoints.
176
177
        Examples
178
        --------
179
        .. sourcecode:: python
180
            {"clustering":
181
              {"description": "",
182
               "docstring": "-- no docstring found in query function --",
183
               "creation_time": 1469511182,
184
               "version": 1,
185
               "dependencies": [],
186
               "last_modified_time": 1469511182,
187
               "type": "model",
188
               "target": null,
189
               "is_public": True}
190
            "add": {
191
              "description": "",
192
              "docstring": "-- no docstring found in query function --",
193
              "creation_time": 1469505967,
194
              "version": 1,
195
              "dependencies": [],
196
              "last_modified_time": 1469505967,
197
              "type": "model",
198
              "target": null,
199
              "is_public": False}
200
            }
201
        """
202 1
        return self._service.get_endpoints(type)
203
204 1
    def _get_endpoint_upload_destination(self):
205
        """Returns the endpoint upload destination."""
206 1
        return self._service.get_endpoint_upload_destination()["path"]
207
208 1
    def deploy(self, name, obj, description="", schema=None, override=False, is_public=False):
209
        """Deploys a Python function as an endpoint in the server.
210
211
        Parameters
212
        ----------
213
        name : str
214
            A unique identifier for the endpoint.
215
216
        obj :  function
217
            Refers to a user-defined function with any signature. However both
218
            input and output of the function need to be JSON serializable.
219
220
        description : str, optional
221
            The description for the endpoint. This string will be returned by
222
            the ``endpoints`` API.
223
224
        schema : dict, optional
225
            The schema of the function, containing information about input and
226
            output parameters, and respective examples. Providing a schema for
227
            a deployed function lets other users of the service discover how to
228
            use it. Refer to schema.generate_schema for more information on
229
            how to generate the schema.
230
231
        override : bool
232
            Whether to override (update) an existing endpoint. If False and
233
            there is already an endpoint with that name, it will raise a
234
            RuntimeError. If True and there is already an endpoint with that
235
            name, it will deploy a new version on top of it.
236
237
        is_public : bool, optional
238
            Whether a function should be public for viewing from within tableau. If
239
            False, function will not appear in the custom functions explorer within
240
            Tableau. If True, function will be visible ta anyone on a site with this
241
            analytics extension configured
242
243
        See Also
244
        --------
245
        remove, get_endpoints
246
        """
247 1
        if self._remote_server:
248 1
            return self._remote_deploy(
249
                name, obj,
250
                description=description, schema=schema, override=override, is_public=is_public
251
            )
252
253 1
        endpoint = self.get_endpoints().get(name)
254 1
        version = 1
255 1
        if endpoint:
256 1
            if not override:
257
                raise RuntimeError(
258
                    f"An endpoint with that name ({name}) already"
259
                    ' exists. Use "override = True" to force update '
260
                    "an existing endpoint."
261
                )
262
263 1
            version = endpoint.version + 1
264
265 1
        obj = self._gen_endpoint(name, obj, description, version, schema, is_public)
266
267 1
        self._upload_endpoint(obj)
268
269 1
        if version == 1:
270 1
            self._service.add_endpoint(Endpoint(**obj))
271
        else:
272 1
            self._service.set_endpoint(Endpoint(**obj), should_update_version=True)
273
274 1
        self._wait_for_endpoint_deployment(obj["name"], obj["version"])
275
276 1
    def remove(self, name):
277
        '''Removes an endpoint dict.
278
279
        Parameters
280
        ----------
281
        name : str
282
            Endpoint name to remove'''
283
        self._service.remove_endpoint(name)
284
285 1
    def update_endpoint_info(self, name, description=None, schema=None, is_public=None):
286
        '''Updates description, schema, or is public for an existing endpoint
287
288
        Parameters
289
        ----------
290
        name : str
291
            Name of the endpoint that to be updated. If endpoint does not exist
292
            runtime error will be thrown
293
294
        description : str, optional
295
            The description for the endpoint. This string will be returned by
296
            the ``endpoints`` API.
297
298
        schema : dict, optional
299
            The schema of the function, containing information about input and
300
            output parameters, and respective examples. Providing a schema for
301
            a deployed function lets other users of the service discover how to
302
            use it. Refer to schema.generate_schema for more information on
303
            how to generate the schema.
304
305
        is_public : bool, optional
306
            Whether a function should be public for viewing from within tableau. If
307
            False, function will not appear in the custom functions explorer within
308
            Tableau. If True, function will be visible to anyone on a site with this
309
            analytics extension configured
310
        '''
311
312
        endpoint = self.get_endpoints().get(name)
313
314
        if not endpoint:
315
            raise RuntimeError(
316
                f"No endpoint with that name ({name}) exists"
317
                " Please select an existing endpoint to update"
318
            )
319
320
        if description is not None:
321
            if type(description) is not str:
322
                raise RuntimeError(
323
                    f"Type of description must be string"
324
                )
325
            endpoint.description = description
326
        if schema is not None:
327
            if type(schema) is not dict:
328
                raise RuntimeError(
329
                    f"Type of schema must be dictionary"
330
                )
331
            endpoint.schema = schema
332
        if is_public is not None:
333
            if type(is_public) is not bool:
334
                raise RuntimeError(
335
                    f"Type of is_public must be bool"
336
                )
337
            endpoint.is_public = is_public
338
339
        dest_path = self._get_endpoint_upload_destination()
340
341
        endpoint.src_path = os.path.join(
342
            dest_path, "endpoints", endpoint.name, str(endpoint.version)
343
        )
344
        self._service.set_endpoint(endpoint, should_update_version=False)
345
346 1
    def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_public=False):
        """Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update

        obj :  func
            Object that backs the endpoint. See deploy() for a complete
            description.

        description : str
            Description of the endpoint. When None, the stripped docstring
            of ``obj`` is used (or an empty string if there is none).

        version : int
            The version. Defaults to 1.

        schema : dict, optional
            Schema of the function; a shallow copy is stored in the result.

        is_public : bool
            True if function should be visible in the custom functions explorer
            within Tableau

        Returns
        -------
        dict
            Keys:
                name : str
                    The name provided.

                version : int
                    The version provided.

                description : str
                    The provided description.

                type : str
                    The type of the endpoint (always "model").

                endpoint_obj : object
                    The wrapper around the obj provided that can be used to
                    generate the code and dependencies for the endpoint.

                dependencies, methods, docstring
                    Values reported by the wrapper object.

                required_files, required_packages : list
                    Always empty lists.

                schema, is_public
                    The schema copy and visibility flag.

        Raises
        ------
        TypeError
            When name is not a string.
        ValueError
            When name is empty or contains characters that are not allowed.
        """
        # check for invalid PO names
        _check_endpoint_name(name)

        if description is None:
            doc = obj.__doc__
            description = (doc.strip() or "") if isinstance(doc, str) else ""

        query_object = CustomQueryObject(query=obj, description=description)

        endpoint_info = {
            "name": name,
            "version": version,
            "description": description,
            "type": "model",
            "endpoint_obj": query_object,
            "dependencies": query_object.get_dependencies(),
            "methods": query_object.get_methods(),
            "required_files": [],
            "required_packages": [],
            "docstring": query_object.get_docstring(),
            "schema": copy.copy(schema),
            "is_public": is_public,
        }
        return endpoint_info
415
416 1
    def _upload_endpoint(self, obj):
417
        """Sends the endpoint across the wire."""
418 1
        endpoint_obj = obj["endpoint_obj"]
419
420 1
        dest_path = self._get_endpoint_upload_destination()
421
422
        # Upload the endpoint
423 1
        obj["src_path"] = os.path.join(
424
            dest_path, "endpoints", obj["name"], str(obj["version"])
425
        )
426
427 1
        endpoint_obj.save(obj["src_path"])
428
429 1
    def _wait_for_endpoint_deployment(
        self, endpoint_name, version=1, interval=1.0, timeout=10,
    ):
        """
        Waits for the endpoint to be deployed by polling get_status().

        Sleeps ``interval`` seconds once before the first poll and again
        between polls (uses time.sleep()). Returns as soon as the endpoint
        reports ``LoadSuccessful`` with a version equal to or greater than
        the expected one.

        Parameters
        ----------
        endpoint_name : str
            Name of the endpoint to look up in the status dict.

        version : int, optional
            Minimum version that must be reported as loaded. Defaults to 1.

        interval : float, optional
            Seconds to sleep between polls. Defaults to 1.0.

        timeout : float, optional
            Maximum number of seconds to keep polling, measured from the
            first poll. Defaults to 10, the previously hard-coded limit.

        Raises
        ------
        RuntimeError
            If the endpoint reports ``LoadFailed``, or if it has not loaded
            the expected version within ``timeout`` seconds.
        """
        logger.info(
            f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}"
        )
        time.sleep(interval)
        # Monotonic clock: the elapsed-time check must not be affected by
        # wall-clock adjustments while we wait.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info(
                    f"Endpoint {endpoint_name} doesn't " "exist in endpoints yet"
                )
            else:
                logger.info(f"ep={ep}")

                if ep["status"] == "LoadFailed":
                    raise RuntimeError(f'LoadFailed: {ep["last_error"]}')

                if ep["status"] == "LoadSuccessful":
                    if ep["version"] >= version:
                        logger.info("LoadSuccessful")
                        return
                    logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                raise RuntimeError(f"Waited more then {timeout}s for deployment")

            logger.info(f"Sleeping {interval}...")
            time.sleep(interval)
469
470 1
    def _remote_deploy(
471
        self, name, obj, description="", schema=None, override=False, is_public=False
472
    ):
473
        """
474
        Remotely deploy a Python function using the /evaluate endpoint. Takes the same inputs
475
        as deploy.
476
        """
477 1
        remote_script = self._gen_remote_script()
478 1
        remote_script += f"{inspect.getsource(obj)}\n"
479
480 1
        remote_script += (
481
            f"client.deploy("
482
            f"'{name}', {obj.__name__}, '{description}', "
483
            f"override={override}, is_public={is_public}, schema={schema}"
484
            f")"
485
        )
486
487 1
        return self._evaluate_remote_script(remote_script)
488
489 1
    def _gen_remote_script(self):
490
        """
491
        Generates a remote script for TabPy client connection with credential handling.
492
493
        Returns:
494
            str: A Python script to establish a TabPy client connection
495
        """
496 1
        remote_script = [
497
            "from tabpy.tabpy_tools.client import Client",
498
            f"client = Client('{self._localhost_endpoint or self._endpoint}')"
499
        ]
500
501 1
        remote_script.append(
502
            f"client.set_credentials('{auth.username}', '{auth.password}')"
503
        ) if (auth := self._service.service_client.network_wrapper.auth) else None
504
505 1
        return "\n".join(remote_script) + "\n"
506
507 1
    def _evaluate_remote_script(self, remote_script):
        """
        Uses TabPy /evaluate endpoint to execute a remote TabPy client script.

        Parameters
        ----------
        remote_script : str
            The script to execute remotely.

        Returns
        -------
        str
            ``"<status code> - <message>\\n"``; the same text is printed.
        """
        # The script is deliberately not echoed: it can contain the plain-text
        # credentials added by _gen_remote_script().
        logger.debug("Evaluating remote script (%d characters)", len(remote_script))

        # The endpoint is accepted with or without a trailing slash, so
        # normalise it before appending the path.
        url = f"{self._endpoint.rstrip('/')}/evaluate"
        headers = {"Content-Type": "application/json"}
        payload = {"data": {}, "script": remote_script}

        # NOTE(review): this bypasses the session configured in __init__ and
        # sets no timeout - confirm that is intended.
        response = requests.post(
            url,
            headers=headers,
            auth=self._service.service_client.network_wrapper.auth,
            json=payload
        )

        # A bare JSON null is the success reply. Only that exact body is
        # rewritten, so error text that merely contains "null" stays intact.
        body = response.text
        msg = "Success" if body.strip() == "null" else body
        if "Ad-hoc scripts have been disabled" in msg:
            msg += "\n[Deployment to remote tabpy client not allowed.]"

        status_message = f"{response.status_code} - {msg}\n"
        print(status_message)
        return status_message
535
536 1
    def set_credentials(self, username, password):
537
        """
538
        Set credentials for all the TabPy client-server communication
539
        where client is tabpy-tools and server is tabpy-server.
540
541
        Parameters
542
        ----------
543
        username : str
544
            User name (login). Username is case insensitive.
545
546
        password : str
547
            Password in plain text.
548
        """
549
        self._service.set_credentials(username, password)
550