Passed
Pull Request — master (#646)
by
unknown
15:35
created

tabpy.tabpy_tools.client.Client.update()   A

Complexity

Conditions 5

Size

Total Lines 49
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 24.6646

Importance

Changes 0
Metric Value
eloc 15
dl 0
loc 49
ccs 1
cts 13
cp 0.0769
rs 9.1832
c 0
b 0
f 0
cc 5
nop 5
crap 24.6646
1 1
import copy
2 1
from re import compile
3 1
import time
4 1
import requests
5
6 1
from .rest import RequestsNetworkWrapper, ServiceClient
7
8 1
from .rest_client import RESTServiceClient, Endpoint
9
10 1
from .custom_query_object import CustomQueryObject
11 1
import os
12 1
import logging
13
14 1
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
15
16 1
# Pattern for endpoint names: one or more word characters, spaces or hyphens.
_name_checker = compile(r"^[\w -]+$")
17
18
19 1
def _check_endpoint_type(name):
    """Validate that ``name`` is a non-empty string.

    Parameters
    ----------
    name : object
        The value to validate.

    Raises
    ------
    TypeError
        If ``name`` is not a ``str``.
    ValueError
        If ``name`` is the empty string.
    """
    if isinstance(name, str):
        if name == "":
            raise ValueError("Endpoint name cannot be empty")
        return

    raise TypeError("Endpoint name must be a string")
25
26
27 1
# Accepted server URL shape: http(s)://<hostname>[:<port>] with an optional
# trailing slash. Compiled once at import time instead of on every call.
# NOTE(review): the optional "/" before the port also admits URLs such as
# "http://host/:80"; kept as-is for backward compatibility.
_hostname_checker = compile(r"http(s)?://[\w.-]+(/)?(:\d+)?(/)?")


def _check_hostname(name):
    """Validate that ``name`` looks like ``http(s)://<hostname>[:<port>]``.

    Parameters
    ----------
    name : str
        The server URL to validate.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If ``name`` is empty or does not match the expected URL shape.
    """
    _check_endpoint_type(name)

    # fullmatch() is used instead of match() with a trailing "$": "$" also
    # matches just before a final newline, which let "http://host\n" through.
    if not _hostname_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} should be in http(s)://<hostname>"
            "[:<port>] and hostname may consist only of: "
            "a-z, A-Z, 0-9, underscore and hyphens."
        )
37
38
39 1
def _check_endpoint_name(name):
    """Validate an endpoint name against the allowed character set.

    Parameters
    ----------
    name : str
        The endpoint name to validate.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If ``name`` is empty or contains anything other than word
        characters, hyphens and spaces.
    """
    _check_endpoint_type(name)

    # fullmatch() rather than match(): the pattern's trailing "$" also matches
    # just before a final newline, so match() accepted names such as "foo\n".
    if not _name_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} can only contain: a-z, A-Z, 0-9,"
            " underscore, hyphens and spaces."
        )
49
50
51 1
class Client:
    """Client for deploying, updating, querying and removing server endpoints.

    All communication is delegated to the ``RESTServiceClient`` instance
    created in the constructor and stored as ``self._service``.
    """

    def __init__(self, endpoint, query_timeout=1000):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL, in the form ``http(s)://<hostname>[:<port>]``.

        query_timeout : float, optional
            The timeout for query operations. Any value that is not a
            positive int or float is replaced with 0.0.

        Raises
        ------
        TypeError
            If ``endpoint`` is not a string.
        ValueError
            If ``endpoint`` is empty or is not a valid server URL.
        """
        _check_hostname(endpoint)

        self._endpoint = endpoint

        session = requests.session()
        # SECURITY NOTE(review): TLS certificate verification is switched off
        # for every request made through this session, and the urllib3
        # warnings are silenced on the next line. Confirm this is intended
        # for all deployments before relying on https:// endpoints.
        session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        network_wrapper = RequestsNetworkWrapper(session)
        service_client = ServiceClient(self._endpoint, network_wrapper)

        self._service = RESTServiceClient(service_client)
        # Exact type check (not isinstance): bool, although an int subclass,
        # is treated as an invalid timeout.
        if type(query_timeout) not in (int, float) or query_timeout <= 0:
            query_timeout = 0.0
        self._service.query_timeout = query_timeout

    def __repr__(self):
        """Return a debug representation including the server URL."""
        return (
            f"<{self.__class__.__name__} object at {hex(id(self))}"
            f" connected to {self._endpoint!r}>"
        )

    def get_status(self):
        """
        Gets the status of the deployed endpoints.

        Returns
        -------
        dict
            Keys are endpoints and values are dicts describing the state of
            the endpoint.

        Examples
        --------
        .. sourcecode:: python
            {
                u'foo': {
                    u'status': u'LoadFailed',
                    u'last_error': u'error message',
                    u'version': 1,
                    u'type': u'model',
                },
            }
        """
        return self._service.get_status()

    #
    # Query
    #

    @property
    def query_timeout(self):
        """The timeout for queries in milliseconds."""
        return self._service.query_timeout

    @query_timeout.setter
    def query_timeout(self, value):
        # Values that are not a positive int/float are ignored and the
        # previously configured timeout stays in effect.
        if type(value) in (int, float) and value > 0:
            self._service.query_timeout = value

    def query(self, name, *args, **kwargs):
        """Query an endpoint.

        Parameters
        ----------
        name : str
            The name of the endpoint.

        *args : list of anything
            Ordered parameters to the endpoint.

        **kwargs : dict of anything
            Named parameters to the endpoint.

        Returns
        -------
        dict
            Keys are:
                model: the name of the endpoint
                version: the version used.
                response: the response to the query.
                uuid : a unique id for the request.
        """
        return self._service.query(name, *args, **kwargs)

    #
    # Endpoints
    #

    def get_endpoints(self, type=None):
        """Returns all deployed endpoints.

        Parameters
        ----------
        type : optional
            Passed through unchanged to the service client's
            ``get_endpoints`` call.

        Examples
        --------
        .. sourcecode:: python
            {"clustering":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469511182,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469511182,
               "type": "model",
               "target": null,
               "is_public": True},
            "add": {
              "description": "",
              "docstring": "-- no docstring found in query function --",
              "creation_time": 1469505967,
              "version": 1,
              "dependencies": [],
              "last_modified_time": 1469505967,
              "type": "model",
              "target": null,
              "is_public": False}
            }
        """
        return self._service.get_endpoints(type)

    def _get_endpoint_upload_destination(self):
        """Returns the endpoint upload destination path reported by the service."""
        return self._service.get_endpoint_upload_destination()["path"]

    def deploy(self, name, obj, description="", schema=None, override=False, is_public=False):
        """Deploys a Python function as an endpoint in the server.

        Parameters
        ----------
        name : str
            A unique identifier for the endpoint.

        obj :  function
            Refers to a user-defined function with any signature. However both
            input and output of the function need to be JSON serializable.

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API.

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        override : bool
            Whether to override (update) an existing endpoint. If False and
            there is already an endpoint with that name, it will raise a
            RuntimeError. If True and there is already an endpoint with that
            name, it will deploy a new version on top of it.

        is_public : bool, optional
            Whether a function should be public for viewing from within tableau. If
            False, function will not appear in the custom functions explorer within
            Tableau. If True, function will be visible to anyone on a site with this
            analytics extension configured

        Raises
        ------
        RuntimeError
            If the endpoint already exists and ``override`` is False, if the
            server reports ``LoadFailed`` for the endpoint, or if the
            deployment does not complete in time.

        See Also
        --------
        remove, get_endpoints
        """
        endpoint = self.get_endpoints().get(name)
        version = 1
        if endpoint:
            if not override:
                raise RuntimeError(
                    f"An endpoint with that name ({name}) already"
                    ' exists. Use "override = True" to force update '
                    "an existing endpoint."
                )

            # Overriding deploys the next version on top of the existing one.
            version = endpoint.version + 1

        obj = self._gen_endpoint(name, obj, description, version, schema, is_public)

        self._upload_endpoint(obj)

        # Version 1 means a brand-new endpoint; anything else replaces one.
        if version == 1:
            self._service.add_endpoint(Endpoint(**obj))
        else:
            self._service.set_endpoint(Endpoint(**obj))

        self._wait_for_endpoint_deployment(obj["name"], obj["version"])

    def remove(self, name):
        """Removes an endpoint.

        Parameters
        ----------
        name : str
            Endpoint name to remove
        """
        self._service.remove_endpoint(name)

    def update(self, name, description=None, schema=None, is_public=None):
        """Updates description, schema, or is public for an existing endpoint

        Arguments left as ``None`` keep the endpoint's current value.

        Parameters
        ----------
        name : str
            Name of the endpoint to be updated. If the endpoint does not
            exist a RuntimeError is raised.

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API.

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        is_public : bool, optional
            Whether a function should be public for viewing from within tableau. If
            False, function will not appear in the custom functions explorer within
            Tableau. If True, function will be visible to anyone on a site with this
            analytics extension configured

        Raises
        ------
        RuntimeError
            If no endpoint named ``name`` exists.
        """
        endpoint = self.get_endpoints().get(name)

        if not endpoint:
            raise RuntimeError(
                f"No endpoint with that name ({name}) exists."
                " Please select an existing endpoint to update"
            )

        if description is not None:
            endpoint.description = description
        if schema is not None:
            endpoint.schema = schema
        if is_public is not None:
            endpoint.is_public = is_public

        dest_path = self._get_endpoint_upload_destination()

        # Same path layout that _upload_endpoint() uses for this version.
        endpoint.src_path = os.path.join(
            dest_path, "endpoints", endpoint.name, str(endpoint.version)
        )

        self._service.set_endpoint(endpoint)

    def _gen_endpoint(self, name, obj, description, version=1, schema=None, is_public=False):
        """Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update

        obj :  func
            Object that backs the endpoint. See deploy() for a complete
            description.

        description : str
            Description of the endpoint. When None, the stripped docstring
            of ``obj`` is used (or "" if it has none).

        version : int
            The version. Defaults to 1.

        schema : dict, optional
            Schema of the function; a shallow copy is stored in the result.

        is_public : bool
            True if function should be visible in the custom functions explorer
            within Tableau

        Returns
        -------
        dict
            Keys:
                name : str
                    The name provided.

                version : int
                    The version provided.

                description : str
                    The provided description.

                type : str
                    The type of the endpoint.

                endpoint_obj : object
                    The wrapper around the obj provided that can be used to
                    generate the code and dependencies for the endpoint.

                dependencies, methods :
                    Values reported by the wrapper object.

                required_files, required_packages : list
                    Always empty lists.

                schema : dict or None
                    Shallow copy of the provided schema.

                is_public : bool
                    The provided visibility flag.

        Raises
        ------
        TypeError
            When name is not a string.
        ValueError
            When name is empty or contains characters that are not allowed.
        """
        # check for invalid PO names
        _check_endpoint_name(name)

        if description is None:
            doc = obj.__doc__
            description = (doc.strip() or "") if isinstance(doc, str) else ""

        endpoint_object = CustomQueryObject(query=obj, description=description,)

        return {
            "name": name,
            "version": version,
            "description": description,
            "type": "model",
            "endpoint_obj": endpoint_object,
            "dependencies": endpoint_object.get_dependencies(),
            "methods": endpoint_object.get_methods(),
            "required_files": [],
            "required_packages": [],
            "schema": copy.copy(schema),
            "is_public": is_public,
        }

    def _upload_endpoint(self, obj):
        """Saves the endpoint object under the upload destination.

        The target path is also recorded in ``obj["src_path"]``.
        """
        endpoint_obj = obj["endpoint_obj"]

        dest_path = self._get_endpoint_upload_destination()

        # Upload the endpoint
        obj["src_path"] = os.path.join(
            dest_path, "endpoints", obj["name"], str(obj["version"])
        )

        endpoint_obj.save(obj["src_path"])

    def _wait_for_endpoint_deployment(
        self, endpoint_name, version=1, interval=1.0, timeout=10.0,
    ):
        """
        Waits for the endpoint to be deployed by polling get_status() and
        comparing the deployed version of the endpoint against the expected
        version. Returns once the endpoint reports ``LoadSuccessful`` with a
        version equal to or greater than the expected one. Uses time.sleep().

        Parameters
        ----------
        endpoint_name : str
            Name of the endpoint to wait for.

        version : int
            Minimum version that must be reported as loaded.

        interval : float
            Seconds to sleep between two status polls.

        timeout : float
            Maximum number of seconds to wait. Defaults to 10.

        Raises
        ------
        RuntimeError
            If the endpoint reports ``LoadFailed`` or the timeout elapses.
        """
        logger.info(
            f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}"
        )
        # Monotonic clock: the deadline must not move when the system
        # wall clock is adjusted while we are waiting.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info(
                    f"Endpoint {endpoint_name} doesn't " "exist in endpoints yet"
                )
            else:
                logger.info(f"ep={ep}")

                if ep["status"] == "LoadFailed":
                    raise RuntimeError(f'LoadFailed: {ep["last_error"]}')

                elif ep["status"] == "LoadSuccessful":
                    if ep["version"] >= version:
                        logger.info("LoadSuccessful")
                        break
                    else:
                        logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                raise RuntimeError(
                    f"Waited more than {timeout:g}s for deployment"
                )

            logger.info(f"Sleeping {interval}...")
            time.sleep(interval)

    def set_credentials(self, username, password):
        """
        Set credentials for all the TabPy client-server communication
        where client is tabpy-tools and server is tabpy-server.

        Parameters
        ----------
        username : str
            User name (login). Username is case insensitive.

        password : str
            Password in plain text.
        """
        self._service.set_credentials(username, password)
453