Passed
Pull Request — master (#653)
by
unknown
16:32 queued 51s
created

tabpy.tabpy_tools.client.Client.remove()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 1.125

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 8
ccs 1
cts 2
cp 0.5
rs 10
c 0
b 0
f 0
cc 1
nop 2
crap 1.125
1 1
import copy
2 1
import inspect
3 1
from re import compile
4 1
import time
5 1
import requests
6
7 1
from .rest import RequestsNetworkWrapper, ServiceClient
8
9 1
from .rest_client import RESTServiceClient, Endpoint
10
11 1
from .custom_query_object import CustomQueryObject
12 1
import os
13 1
import logging
14
15 1
logger = logging.getLogger(__name__)
16
17 1
_name_checker = compile(r"^[\w -]+$")
18
19
20 1
def _check_endpoint_type(name):
21 1
    if not isinstance(name, str):
22 1
        raise TypeError("Endpoint name must be a string")
23
24 1
    if name == "":
25 1
        raise ValueError("Endpoint name cannot be empty")
26
27
28 1
def _check_hostname(name):
29 1
    _check_endpoint_type(name)
30 1
    hostname_checker = compile(r"^^http(s)?://[\w.-]+(/)?(:\d+)?(/)?$")
31
32 1
    if not hostname_checker.match(name):
33 1
        raise ValueError(
34
            f"endpoint name {name} should be in http(s)://<hostname>"
35
            "[:<port>] and hostname may consist only of: "
36
            "a-z, A-Z, 0-9, underscore and hyphens."
37
        )
38
39
40 1
def _check_endpoint_name(name):
41
    """Checks that the endpoint name is valid by comparing it with an RE and
42
    checking that it is not reserved."""
43 1
    _check_endpoint_type(name)
44
45 1
    if not _name_checker.match(name):
46 1
        raise ValueError(
47
            f"endpoint name {name} can only contain: a-z, A-Z, 0-9,"
48
            " underscore, hyphens and spaces."
49
        )
50
51
52 1
class Client:
    """Client for deploying, querying and managing endpoints on a TabPy server."""

    def __init__(self, endpoint, query_timeout=1000):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL, in the form ``http(s)://<hostname>[:<port>]``.

        query_timeout : float, optional
            The timeout for query operations. Anything that is not a positive
            ``int`` or ``float`` (``bool`` included) is replaced with 0.0.
        """
        _check_hostname(endpoint)

        self._endpoint = endpoint

        session = requests.session()
        # NOTE(review): TLS certificate verification is disabled for every
        # request made through this session, and urllib3 warnings are
        # silenced process-wide. HTTPS traffic is therefore open to MITM;
        # consider making verification configurable.
        session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        network_wrapper = RequestsNetworkWrapper(session)
        service_client = ServiceClient(self._endpoint, network_wrapper)

        self._service = RESTServiceClient(service_client)
        # An exact type() check (rather than isinstance) deliberately
        # rejects bool, which is a subclass of int.
        if type(query_timeout) not in (int, float) or query_timeout <= 0:
            query_timeout = 0.0
        self._service.query_timeout = query_timeout

    def __repr__(self):
        return (
            f"<{self.__class__.__name__} object at {hex(id(self))}"
            f" connected to {self._endpoint!r}>"
        )

    def get_status(self):
        """
        Gets the status of the deployed endpoints.

        Returns
        -------
        dict
            Keys are endpoints and values are dicts describing the state of
            the endpoint.

        Examples
        --------
        .. sourcecode:: python
            {
                u'foo': {
                    u'status': u'LoadFailed',
                    u'last_error': u'error message',
                    u'version': 1,
                    u'type': u'model',
                },
            }
        """
        return self._service.get_status()

    #
    # Query
    #

    @property
    def query_timeout(self):
        """The timeout for queries in milliseconds."""
        return self._service.query_timeout

    @query_timeout.setter
    def query_timeout(self, value):
        # Values that are not a positive int/float (bool included) are
        # silently ignored and the previous timeout is kept.
        if type(value) in (int, float) and value > 0:
            self._service.query_timeout = value

    def query(self, name, *args, **kwargs):
        """Query an endpoint.

        Parameters
        ----------
        name : str
            The name of the endpoint.

        *args : list of anything
            Ordered parameters to the endpoint.

        **kwargs : dict of anything
            Named parameters to the endpoint.

        Returns
        -------
        dict
            Keys are:
                model: the name of the endpoint
                version: the version used.
                response: the response to the query.
                uuid : a unique id for the request.
        """
        return self._service.query(name, *args, **kwargs)

    #
    # Endpoints
    #

    def get_endpoints(self, type=None):
        """Returns all deployed endpoints.

        Parameters
        ----------
        type : str, optional
            Forwarded unchanged to the service's ``get_endpoints``
            (presumably an endpoint-type filter — confirm in rest_client).

        Examples
        --------
        .. sourcecode:: python
            {"clustering":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469511182,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469511182,
               "type": "model",
               "target": None,
               "is_public": True},
             "add":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469505967,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469505967,
               "type": "model",
               "target": None,
               "is_public": False}
            }
        """
        return self._service.get_endpoints(type)

    def _get_endpoint_upload_destination(self):
        """Returns the endpoint upload destination."""
        return self._service.get_endpoint_upload_destination()["path"]

    def deploy(self, name, obj, description="", schema=None, override=False, is_public=False):
        """Deploys a Python function as an endpoint in the server.

        Parameters
        ----------
        name : str
            A unique identifier for the endpoint.

        obj :  function
            Refers to a user-defined function with any signature. However both
            input and output of the function need to be JSON serializable.

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API. If ``None``, the function's docstring is
            used instead.

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        override : bool
            Whether to override (update) an existing endpoint. If False and
            there is already an endpoint with that name, it will raise a
            RuntimeError. If True and there is already an endpoint with that
            name, it will deploy a new version on top of it.

        is_public : bool, optional
            Whether a function should be public for viewing from within tableau. If
            False, function will not appear in the custom functions explorer within
            Tableau. If True, function will be visible to anyone on a site with this
            analytics extension configured

        Raises
        ------
        RuntimeError
            If the endpoint exists and ``override`` is False, or if the
            deployment fails or does not finish in time.

        See Also
        --------
        remove, get_endpoints
        """
        endpoint = self.get_endpoints().get(name)
        version = 1
        if endpoint:
            if not override:
                raise RuntimeError(
                    f"An endpoint with that name ({name}) already"
                    ' exists. Use "override = True" to force update '
                    "an existing endpoint."
                )

            version = endpoint.version + 1

        obj = self._gen_endpoint(name, obj, description, version, schema, is_public)

        self._upload_endpoint(obj)

        if version == 1:
            self._service.add_endpoint(Endpoint(**obj))
        else:
            self._service.set_endpoint(Endpoint(**obj), should_update_version=True)

        self._wait_for_endpoint_deployment(obj["name"], obj["version"])

    def remove(self, name):
        """Removes an endpoint.

        Parameters
        ----------
        name : str
            Endpoint name to remove
        """
        self._service.remove_endpoint(name)

    def update_endpoint_info(self, name, description=None, schema=None, is_public=None):
        """Updates description, schema, or is_public for an existing endpoint.

        The endpoint version is not changed.

        Parameters
        ----------
        name : str
            Name of the endpoint to be updated. If the endpoint does not exist
            a RuntimeError is raised.

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API.

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        is_public : bool, optional
            Whether a function should be public for viewing from within tableau. If
            False, function will not appear in the custom functions explorer within
            Tableau. If True, function will be visible to anyone on a site with this
            analytics extension configured

        Raises
        ------
        RuntimeError
            If the endpoint does not exist or an argument has the wrong type.
        """

        endpoint = self.get_endpoints().get(name)

        if not endpoint:
            raise RuntimeError(
                f"No endpoint with that name ({name}) exists"
                " Please select an existing endpoint to update"
            )

        if description is not None:
            if type(description) is not str:
                raise RuntimeError("Type of description must be string")
            endpoint.description = description
        if schema is not None:
            if type(schema) is not dict:
                raise RuntimeError("Type of schema must be dictionary")
            endpoint.schema = schema
        if is_public is not None:
            if type(is_public) is not bool:
                raise RuntimeError("Type of is_public must be bool")
            endpoint.is_public = is_public

        dest_path = self._get_endpoint_upload_destination()

        endpoint.src_path = os.path.join(
            dest_path, "endpoints", endpoint.name, str(endpoint.version)
        )
        self._service.set_endpoint(endpoint, should_update_version=False)

    def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_public=False):
        """Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update

        obj :  func
            Object that backs the endpoint. See deploy() for a complete
            description.

        description : str or None
            Description of the endpoint. If None, the stripped docstring of
            ``obj`` is used, or "" when ``obj`` has no docstring.

        version : int
            The version. Defaults to 1.

        schema : dict, optional
            Schema of the function; a shallow copy is stored.

        is_public : bool
            True if function should be visible in the custom functions explorer
            within Tableau

        Returns
        -------
        dict
            Keys:
                name : str
                    The name provided.

                version : int
                    The version provided.

                description : str
                    The provided description.

                type : str
                    The type of the endpoint.

                endpoint_obj : object
                    The wrapper around the obj provided that can be used to
                    generate the code and dependencies for the endpoint.

        Raises
        ------
        TypeError, ValueError
            When ``name`` is not a valid endpoint name.
        """
        # check for invalid PO names
        _check_endpoint_name(name)

        if description is None:
            # Previously ``doc.strip() or "" if ... else ""`` parsed as
            # ``doc.strip() or (...)`` and crashed when __doc__ was None.
            doc = obj.__doc__
            description = doc.strip() if isinstance(doc, str) else ""

        endpoint_object = CustomQueryObject(query=obj, description=description,)

        return {
            "name": name,
            "version": version,
            "description": description,
            "type": "model",
            "endpoint_obj": endpoint_object,
            "dependencies": endpoint_object.get_dependencies(),
            "methods": endpoint_object.get_methods(),
            "required_files": [],
            "required_packages": [],
            "docstring": endpoint_object.get_doc_string(),
            "schema": copy.copy(schema),
            "is_public": is_public,
        }

    def _upload_endpoint(self, obj):
        """Sends the endpoint across the wire.

        Adds a ``"src_path"`` key to ``obj`` (mutating it in place) and saves
        the wrapped endpoint object there.
        """
        endpoint_obj = obj["endpoint_obj"]

        dest_path = self._get_endpoint_upload_destination()

        # Upload the endpoint
        obj["src_path"] = os.path.join(
            dest_path, "endpoints", obj["name"], str(obj["version"])
        )

        endpoint_obj.save(obj["src_path"])

    def _wait_for_endpoint_deployment(
        self, endpoint_name, version=1, interval=1.0, timeout=10.0,
    ):
        """
        Waits for the endpoint to be deployed by calling get_status() and
        checking the deployed version of the endpoint against the expected
        version. Returns once the status is "LoadSuccessful" with a version
        equal to or greater than the one expected. Uses time.sleep().

        Parameters
        ----------
        endpoint_name : str
            Endpoint to watch.
        version : int
            Minimum version that must be loaded.
        interval : float
            Seconds to sleep between polls.
        timeout : float
            Seconds to wait before giving up. Defaults to 10.

        Raises
        ------
        RuntimeError
            If the server reports "LoadFailed" or the timeout elapses.
        """
        logger.info(
            f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}"
        )
        # Monotonic clock: elapsed time is unaffected by wall-clock changes.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info(
                    f"Endpoint {endpoint_name} doesn't " "exist in endpoints yet"
                )
            else:
                logger.info(f"ep={ep}")

                if ep["status"] == "LoadFailed":
                    raise RuntimeError(f'LoadFailed: {ep["last_error"]}')

                elif ep["status"] == "LoadSuccessful":
                    if ep["version"] >= version:
                        logger.info("LoadSuccessful")
                        break
                    else:
                        logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                raise RuntimeError(f"Waited more than {timeout:g}s for deployment")

            logger.info(f"Sleeping {interval}...")
            time.sleep(interval)

    def set_credentials(self, username, password):
        """
        Set credentials for all the TabPy client-server communication
        where client is tabpy-tools and server is tabpy-server.

        Parameters
        ----------
        username : str
            User name (login). Username is case insensitive.

        password : str
            Password in plain text.
        """
        self._service.set_credentials(username, password)
466