Passed
Push — master ( ba1484...e45894 )
by Oleksandr
11:55
created

tabpy.tabpy_tools.client   A

Complexity

Total Complexity 37

Size/Duplication

Total Lines 390
Duplicated Lines 0 %

Test Coverage

Coverage 88.24%

Importance

Changes 0
Metric Value
wmc 37
eloc 129
dl 0
loc 390
ccs 90
cts 102
cp 0.8824
rs 9.44
c 0
b 0
f 0

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _check_endpoint_name() 0 8 2
A _check_hostname() 0 7 2
A _check_endpoint_type() 0 6 3

13 Methods

Rating   Name   Duplication   Size   Complexity  
A Client.__repr__() 0 9 1
A Client.remove() 0 8 1
A Client._gen_endpoint() 0 62 3
A Client._get_endpoint_upload_destination() 0 3 1
A Client.query() 0 24 1
A Client._upload_endpoint() 0 12 1
A Client.__init__() 0 31 3
A Client.get_endpoints() 0 27 1
A Client.get_status() 0 23 1
A Client.set_credentials() 0 14 1
B Client._wait_for_endpoint_deployment() 0 39 8
A Client.query_timeout() 0 4 3
A Client.deploy() 0 55 4
1 1
import copy
2 1
from re import compile
3 1
import time
4 1
import requests
5
6 1
from .rest import RequestsNetworkWrapper, ServiceClient
7
8 1
from .rest_client import RESTServiceClient, Endpoint
9
10 1
from .custom_query_object import CustomQueryObject
11 1
import os
12 1
import logging
13
14 1
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
15
16 1
# Pattern for endpoint names: one or more characters, each being a regex
# word character (\w), a space or a hyphen.
# NOTE: with match(), "$" also succeeds just before a single trailing
# newline, so "name\n" satisfies this pattern under match().
_name_checker = compile(r"^[\w -]+$")
17
18
19 1
def _check_endpoint_type(name):
20 1
    if not isinstance(name, str):
21 1
        raise TypeError("Endpoint name must be a string")
22
23 1
    if name == "":
24 1
        raise ValueError("Endpoint name cannot be empty")
25
26
27 1
# Server URL shape: http(s)://<hostname>[:<port>] with an optional trailing
# slash. Compiled once at import time; used with fullmatch(), so no explicit
# anchors are needed.
_hostname_checker = compile(r"http(s)?://[\w.-]+(:\d+)?(/)?")


def _check_hostname(name):
    """Validate that ``name`` looks like ``http(s)://<hostname>[:<port>]``.

    A single trailing slash after the host (or after the port) is accepted.

    Parameters
    ----------
    name : str
        The server URL to validate.

    Raises
    ------
    TypeError
        If ``name`` is not a string (raised by ``_check_endpoint_type``).

    ValueError
        If ``name`` is empty or does not match the expected URL shape.
    """
    _check_endpoint_type(name)

    # fullmatch() requires the whole string to be consumed. Unlike match()
    # with a "$" anchor, it rejects a URL followed by a trailing newline.
    # The pattern also no longer allows a slash between the host and the
    # port ("http://host/:80"), which is not a valid authority.
    if not _hostname_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} should be in http(s)://<hostname>"
            "[:<port>] and hostname may consist only of: "
            "a-z, A-Z, 0-9, underscore and hyphens."
        )
37
38
39 1
def _check_endpoint_name(name):
    """Check that the endpoint name is a valid, non-empty string.

    The name is compared against the module-level ``_name_checker``
    pattern. No reserved-name check is performed here.

    Parameters
    ----------
    name : str
        The endpoint name to validate.

    Raises
    ------
    TypeError
        If ``name`` is not a string (raised by ``_check_endpoint_type``).

    ValueError
        If ``name`` is empty or contains characters the pattern does not
        allow.
    """
    _check_endpoint_type(name)

    # fullmatch() instead of match(): with match(), the pattern's "$" also
    # succeeds right before a trailing newline, so "name\n" used to be
    # accepted.
    if not _name_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} can only contain: a-z, A-Z, 0-9,"
            " underscore, hyphens and spaces."
        )
49
50
51 1
class Client:
    """Client for deploying, querying and managing endpoints on a server.

    All operations are delegated to a ``RESTServiceClient`` that is built
    once in the constructor and stored in ``self._service``.
    """

    def __init__(self, endpoint, query_timeout=1000):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL, in the form ``http(s)://<hostname>[:<port>]``.

        query_timeout : int or float, optional
            The timeout for query operations. Any value that is not a
            positive ``int`` or ``float`` is replaced with ``0.0``.

        Raises
        ------
        TypeError
            If ``endpoint`` is not a string.

        ValueError
            If ``endpoint`` is empty or is not a valid server URL.
        """
        _check_hostname(endpoint)

        self._endpoint = endpoint

        session = requests.session()
        # NOTE(security): certificate verification is switched off for every
        # request made through this session, and urllib3 warnings are
        # silenced. This is kept as-is because callers may rely on it (e.g.
        # servers with self-signed certificates), but it leaves HTTPS
        # connections open to man-in-the-middle attacks.
        session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        network_wrapper = RequestsNetworkWrapper(session)
        service_client = ServiceClient(self._endpoint, network_wrapper)

        self._service = RESTServiceClient(service_client)
        # Exact type check on purpose: bool and numeric subclasses are
        # treated as invalid, same as any non-positive value.
        if type(query_timeout) not in (int, float) or query_timeout <= 0:
            query_timeout = 0.0
        self._service.query_timeout = query_timeout

    def __repr__(self):
        """Return a debug string with the class name, id and server URL."""
        return (
            f"<{self.__class__.__name__} object at {hex(id(self))}"
            f" connected to {self._endpoint!r}>"
        )

    def get_status(self):
        """
        Gets the status of the deployed endpoints.

        Returns
        -------
        dict
            Keys are endpoints and values are dicts describing the state of
            the endpoint.

        Examples
        --------
        .. sourcecode:: python
            {
                u'foo': {
                    u'status': u'LoadFailed',
                    u'last_error': u'error message',
                    u'version': 1,
                    u'type': u'model',
                },
            }
        """
        return self._service.get_status()

    #
    # Query
    #

    @property
    def query_timeout(self):
        """The timeout for queries.

        NOTE(review): previously documented as milliseconds; the unit is
        determined by the underlying service client -- confirm there.
        """
        return self._service.query_timeout

    @query_timeout.setter
    def query_timeout(self, value):
        # Invalid values (non int/float, zero or negative) are silently
        # ignored and the current timeout is kept.
        if type(value) in (int, float) and value > 0:
            self._service.query_timeout = value

    def query(self, name, *args, **kwargs):
        """Query an endpoint.

        Parameters
        ----------
        name : str
            The name of the endpoint.

        *args : list of anything
            Ordered parameters to the endpoint.

        **kwargs : dict of anything
            Named parameters to the endpoint.

        Returns
        -------
        dict
            Keys are:
                model: the name of the endpoint
                version: the version used.
                response: the response to the query.
                uuid : a unique id for the request.
        """
        return self._service.query(name, *args, **kwargs)

    #
    # Endpoints
    #

    def get_endpoints(self, type=None):
        """Returns all deployed endpoints.

        Parameters
        ----------
        type : optional
            Passed through unchanged to the service's ``get_endpoints``
            (presumably an endpoint-type filter -- confirm against the
            service client).

        Examples
        --------
        .. sourcecode:: python
            {"clustering":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469511182,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469511182,
               "type": "model",
               "target": null},
            "add": {
              "description": "",
              "docstring": "-- no docstring found in query function --",
              "creation_time": 1469505967,
              "version": 1,
              "dependencies": [],
              "last_modified_time": 1469505967,
              "type": "model",
              "target": null}
            }
        """
        return self._service.get_endpoints(type)

    def _get_endpoint_upload_destination(self):
        """Returns the endpoint upload destination."""
        return self._service.get_endpoint_upload_destination()["path"]

    def deploy(self, name, obj, description="", schema=None, override=False):
        """Deploys a Python function as an endpoint in the server.

        Parameters
        ----------
        name : str
            A unique identifier for the endpoint.

        obj :  function
            Refers to a user-defined function with any signature. However both
            input and output of the function need to be JSON serializable.

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API.

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        override : bool
            Whether to override (update) an existing endpoint. If False and
            there is already an endpoint with that name, it will raise a
            RuntimeError. If True and there is already an endpoint with that
            name, it will deploy a new version on top of it.

        Raises
        ------
        TypeError
            If ``name`` is not a string.

        ValueError
            If ``name`` is empty or contains disallowed characters.

        RuntimeError
            If the endpoint already exists and ``override`` is False, or if
            the deployment fails or does not complete in time.

        See Also
        --------
        remove, get_endpoints
        """
        # Fail fast on an invalid name before talking to the server.
        _check_endpoint_name(name)

        endpoint = self.get_endpoints().get(name)
        version = 1
        if endpoint:
            if not override:
                raise RuntimeError(
                    f"An endpoint with that name ({name}) already"
                    ' exists. Use "override = True" to force update '
                    "an existing endpoint."
                )

            version = endpoint.version + 1

        obj = self._gen_endpoint(name, obj, description, version, schema)

        self._upload_endpoint(obj)

        # Version 1 means the endpoint is new; anything else is an update.
        if version == 1:
            self._service.add_endpoint(Endpoint(**obj))
        else:
            self._service.set_endpoint(Endpoint(**obj))

        self._wait_for_endpoint_deployment(obj["name"], obj["version"])

    def remove(self, name):
        """Removes an endpoint.

        Parameters
        ----------
        name : str
            Endpoint name to remove
        """
        self._service.remove_endpoint(name)

    def _gen_endpoint(self, name, obj, description, version=1, schema=None):
        """Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update

        obj :  func
            Object that backs the endpoint. See deploy() for a complete
            description.

        description : str or None
            Description of the endpoint. When None, the stripped docstring
            of ``obj`` is used (or an empty string if it has none).

        version : int
            The version. Defaults to 1.

        schema : dict, optional
            The schema of the function; a shallow copy is stored in the
            result.

        Returns
        -------
        dict
            Keys:
                name : str
                    The name provided.

                version : int
                    The version provided.

                description : str
                    The provided description.

                type : str
                    The type of the endpoint (always "model").

                endpoint_obj : object
                    The wrapper around the obj provided that can be used to
                    generate the code and dependencies for the endpoint.

                dependencies, methods
                    As reported by the wrapper object.

                required_files, required_packages : list
                    Always empty lists.

                schema
                    A shallow copy of the provided schema.

        Raises
        ------
        TypeError
            When name is not a string.

        ValueError
            When name is empty or contains disallowed characters.
        """
        # check for invalid PO names
        _check_endpoint_name(name)

        if description is None:
            doc = obj.__doc__
            description = doc.strip() if isinstance(doc, str) else ""

        endpoint_object = CustomQueryObject(query=obj, description=description)

        return {
            "name": name,
            "version": version,
            "description": description,
            "type": "model",
            "endpoint_obj": endpoint_object,
            "dependencies": endpoint_object.get_dependencies(),
            "methods": endpoint_object.get_methods(),
            "required_files": [],
            "required_packages": [],
            "schema": copy.copy(schema),
        }

    def _upload_endpoint(self, obj):
        """Sends the endpoint across the wire.

        Saves ``obj["endpoint_obj"]`` under
        ``<upload destination>/endpoints/<name>/<version>`` and records that
        path in ``obj["src_path"]`` (the dict is modified in place).
        """
        endpoint_obj = obj["endpoint_obj"]

        dest_path = self._get_endpoint_upload_destination()

        # Upload the endpoint
        obj["src_path"] = os.path.join(
            dest_path, "endpoints", obj["name"], str(obj["version"])
        )

        endpoint_obj.save(obj["src_path"])

    def _wait_for_endpoint_deployment(
        self, endpoint_name, version=1, interval=1.0, timeout=10,
    ):
        """
        Waits for the endpoint to be deployed by polling get_status().

        Returns once the endpoint reports status "LoadSuccessful" with a
        version equal to or greater than the expected one. Uses
        time.sleep() between polls.

        Parameters
        ----------
        endpoint_name : str
            Name of the endpoint to wait for.

        version : int
            Minimum version that must be reported as loaded.

        interval : float
            Seconds to sleep between two status polls.

        timeout : int or float
            Maximum number of seconds to keep polling. Defaults to 10.

        Raises
        ------
        RuntimeError
            If the endpoint reports status "LoadFailed", or if it is not
            deployed within ``timeout`` seconds.
        """
        logger.info(
            f"Waiting for endpoint {endpoint_name} to deploy to version {version}"
        )
        # Monotonic clock: elapsed time must not be affected by wall-clock
        # adjustments while we are polling.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info(
                    f"Endpoint {endpoint_name} doesn't exist in endpoints yet"
                )
            else:
                logger.info(f"ep={ep}")

                if ep["status"] == "LoadFailed":
                    raise RuntimeError(f'LoadFailed: {ep["last_error"]}')

                if ep["status"] == "LoadSuccessful":
                    if ep["version"] >= version:
                        logger.info("LoadSuccessful")
                        break
                    logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                raise RuntimeError(f"Waited more then {timeout}s for deployment")

            logger.info(f"Sleeping {interval}...")
            time.sleep(interval)

    def set_credentials(self, username, password):
        """
        Set credentials for all the TabPy client-server communication
        where client is tabpy-tools and server is tabpy-server.

        Parameters
        ----------
        username : str
            User name (login). Username is case insensitive.

        password : str
            Password in plain text.
        """
        self._service.set_credentials(username, password)
390