Passed
Push — master ( 53016e...38fde4 )
by Oleksandr
02:52
created

tabpy.tabpy_tools.client.Client.remove()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
import copy
2
from re import compile
3
import time
4
import requests
5
6
from .rest import RequestsNetworkWrapper, ServiceClient
7
8
from .rest_client import RESTServiceClient, Endpoint
9
10
from .custom_query_object import CustomQueryObject
11
import os
12
import logging
13
14
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Allowed endpoint names: word characters (``\w``: letters, digits,
# underscore), spaces and hyphens; at least one character.
_name_checker = compile(pattern=r"^[\w -]+$")
17
18
19
def _check_endpoint_type(name):
    """Ensure *name* is a non-empty ``str``.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If ``name`` is the empty string.
    """
    is_text = isinstance(name, str)
    if not is_text:
        raise TypeError("Endpoint name must be a string")

    # At this point name is a str, so falsiness means exactly "".
    if not name:
        raise ValueError("Endpoint name cannot be empty")
25
26
27
# Accepted server URL shape: http(s)://<hostname>[:<port>][/]
# Compiled once at import time instead of on every call.
_hostname_checker = compile(r"https?://[\w.-]+(:\d+)?/?")


def _check_hostname(name):
    """Validate that *name* is a server URL of the form
    ``http(s)://<hostname>[:<port>]`` with an optional trailing slash.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If ``name`` is empty or is not a URL of the accepted shape.
    """
    _check_endpoint_type(name)

    # fullmatch instead of match + ``$``: ``$`` also matches just before a
    # trailing "\n", which would let "http://host\n" through. The previous
    # pattern also accepted malformed URLs such as "http://host/:9004".
    if not _hostname_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} should be in http(s)://<hostname>"
            "[:<port>] and hostname may consist only of: "
            "a-z, A-Z, 0-9, underscore and hyphens."
        )
37
38
39
def _check_endpoint_name(name):
    """Check that an endpoint name is valid.

    The name must be a non-empty string consisting only of word characters
    (letters, digits, underscore), hyphens and spaces, as defined by the
    module-level ``_name_checker`` pattern.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If ``name`` is empty or contains any other character.
    """
    _check_endpoint_type(name)

    # fullmatch: with plain match, the pattern's ``$`` also matches before a
    # trailing newline, so "model\n" would previously have been accepted.
    if not _name_checker.fullmatch(name):
        raise ValueError(
            f"endpoint name {name} can only contain: a-z, A-Z, 0-9,"
            " underscore, hyphens and spaces."
        )
49
50
51
class Client:
    """Client for deploying, querying and removing endpoints on a TabPy server."""

    def __init__(self, endpoint, query_timeout=1000):
        """
        Connects to a running server.

        The class constructor takes a server address which is then used to
        connect for all subsequent member APIs.

        Parameters
        ----------
        endpoint : str
            The server URL, in ``http(s)://<hostname>[:<port>]`` form.

        query_timeout : float, optional
            The timeout for query operations. Values that are not a positive
            ``int`` or ``float`` are replaced with ``0.0``.

        Raises
        ------
        TypeError, ValueError
            If ``endpoint`` is not a valid server URL.
        """
        _check_hostname(endpoint)

        self._endpoint = endpoint

        session = requests.session()
        # NOTE(review): TLS certificate verification is disabled and urllib3
        # warnings are silenced process-wide, presumably to support servers
        # with self-signed certificates. Consider making this configurable.
        session.verify = False
        requests.packages.urllib3.disable_warnings()

        # Setup the communications layer.
        network_wrapper = RequestsNetworkWrapper(session)
        service_client = ServiceClient(self._endpoint, network_wrapper)

        self._service = RESTServiceClient(service_client)
        if self._is_valid_timeout(query_timeout):
            self._service.query_timeout = query_timeout
        else:
            self._service.query_timeout = 0.0

    @staticmethod
    def _is_valid_timeout(value):
        """Return True if *value* is a positive int or float.

        ``type(...) in (int, float)`` is used instead of ``isinstance`` on
        purpose so that ``bool`` values (a subclass of ``int``) are rejected.
        """
        return type(value) in (int, float) and value > 0

    def __repr__(self):
        return (
            f"<{self.__class__.__name__} object at {hex(id(self))}"
            f" connected to {self._endpoint!r}>"
        )

    def get_status(self):
        """
        Gets the status of the deployed endpoints.

        Returns
        -------
        dict
            Keys are endpoints and values are dicts describing the state of
            the endpoint.

        Examples
        --------
        .. sourcecode:: python
            {
                u'foo': {
                    u'status': u'LoadFailed',
                    u'last_error': u'error message',
                    u'version': 1,
                    u'type': u'model',
                },
            }
        """
        return self._service.get_status()

    #
    # Query
    #

    @property
    def query_timeout(self):
        """The timeout for queries in milliseconds."""
        return self._service.query_timeout

    @query_timeout.setter
    def query_timeout(self, value):
        # Invalid values (non-numeric, bool, zero or negative) are ignored
        # and the current timeout is kept.
        if self._is_valid_timeout(value):
            self._service.query_timeout = value

    def query(self, name, *args, **kwargs):
        """Query an endpoint.

        Parameters
        ----------
        name : str
            The name of the endpoint.

        *args : list of anything
            Ordered parameters to the endpoint.

        **kwargs : dict of anything
            Named parameters to the endpoint.

        Returns
        -------
        dict
            Keys are:
                model: the name of the endpoint
                version: the version used.
                response: the response to the query.
                uuid : a unique id for the request.
        """
        return self._service.query(name, *args, **kwargs)

    #
    # Endpoints
    #

    def get_endpoints(self, type=None):
        """Returns all deployed endpoints.

        Parameters
        ----------
        type : str, optional
            Passed through unchanged to the service client's
            ``get_endpoints``; presumably filters endpoints by type
            (e.g. ``"model"``) -- TODO confirm.

        Returns
        -------
        dict
            Keyed by endpoint name. NOTE(review): ``deploy`` reads
            ``.version`` from the values, so they are presumably endpoint
            objects rather than plain dicts as the example below suggests.

        Examples
        --------
        .. sourcecode:: python
            {"clustering":
              {"description": "",
               "docstring": "-- no docstring found in query function --",
               "creation_time": 1469511182,
               "version": 1,
               "dependencies": [],
               "last_modified_time": 1469511182,
               "type": "model",
               "target": null},
            "add": {
              "description": "",
              "docstring": "-- no docstring found in query function --",
              "creation_time": 1469505967,
              "version": 1,
              "dependencies": [],
              "last_modified_time": 1469505967,
              "type": "model",
              "target": null}
            }
        """
        return self._service.get_endpoints(type)

    def _get_endpoint_upload_destination(self):
        """Returns the endpoint upload destination."""
        return self._service.get_endpoint_upload_destination()["path"]

    def deploy(self, name, obj, description="", schema=None, override=False):
        """Deploys a Python function as an endpoint in the server.

        Parameters
        ----------
        name : str
            A unique identifier for the endpoint.

        obj :  function
            Refers to a user-defined function with any signature. However both
            input and output of the function need to be JSON serializable.

        description : str, optional
            The description for the endpoint. This string will be returned by
            the ``endpoints`` API. If ``None``, the stripped docstring of
            ``obj`` is used (or ``""`` if it has none).

        schema : dict, optional
            The schema of the function, containing information about input and
            output parameters, and respective examples. Providing a schema for
            a deployed function lets other users of the service discover how to
            use it. Refer to schema.generate_schema for more information on
            how to generate the schema.

        override : bool
            Whether to override (update) an existing endpoint. If False and
            there is already an endpoint with that name, it will raise a
            RuntimeError. If True and there is already an endpoint with that
            name, it will deploy a new version on top of it.

        Raises
        ------
        RuntimeError
            If the endpoint exists and ``override`` is False, if the server
            reports the load failed, or if deployment does not finish in time.

        See Also
        --------
        remove, get_endpoints
        """
        endpoint = self.get_endpoints().get(name)
        if endpoint and not override:
            raise RuntimeError(
                f"An endpoint with that name ({name}) already"
                ' exists. Use "override = True" to force update '
                "an existing endpoint."
            )
        version = endpoint.version + 1 if endpoint else 1

        obj = self._gen_endpoint(name, obj, description, version, schema)

        # Writes the serialized endpoint and records obj["src_path"].
        self._upload_endpoint(obj)

        # A first version registers a new endpoint; later ones update it.
        if version == 1:
            self._service.add_endpoint(Endpoint(**obj))
        else:
            self._service.set_endpoint(Endpoint(**obj))

        self._wait_for_endpoint_deployment(obj["name"], obj["version"])

    def remove(self, name):
        """Removes a deployed endpoint from the server.

        Parameters
        ----------
        name : str
            Name of the endpoint to remove.
        """
        self._service.remove_endpoint(name)

    def _gen_endpoint(self, name, obj, description, version=1, schema=None):
        """Generates an endpoint dict.

        Parameters
        ----------
        name : str
            Endpoint name to add or update.

        obj :  func
            Object that backs the endpoint. See deploy() for a complete
            description.

        description : str or None
            Description of the endpoint. If ``None``, the stripped docstring
            of ``obj`` is used, or ``""`` when ``obj`` has no docstring.

        version : int
            The version. Defaults to 1.

        schema : dict, optional
            Endpoint schema; a shallow copy is stored.

        Returns
        -------
        dict
            Keys:
                name : str
                    The name provided.

                version : int
                    The version provided.

                description : str
                    The (possibly derived) description.

                type : str
                    The type of the endpoint (always ``"model"``).

                endpoint_obj : object
                    The wrapper around the obj provided that can be used to
                    generate the code and dependencies for the endpoint.

                dependencies, methods : as reported by ``endpoint_obj``.

                required_files, required_packages : list
                    Always empty.

                schema : dict or None
                    Shallow copy of ``schema``.

        Raises
        ------
        TypeError, ValueError
            If ``name`` is not a valid endpoint name.
        TypeError
            Presumably raised by ``CustomQueryObject`` when obj is not one of
            the expected types -- TODO confirm.
        """
        # check for invalid PO names
        _check_endpoint_name(name)

        if description is None:
            doc = obj.__doc__
            description = doc.strip() if isinstance(doc, str) else ""

        endpoint_object = CustomQueryObject(query=obj, description=description)

        return {
            "name": name,
            "version": version,
            "description": description,
            "type": "model",
            "endpoint_obj": endpoint_object,
            "dependencies": endpoint_object.get_dependencies(),
            "methods": endpoint_object.get_methods(),
            "required_files": [],
            "required_packages": [],
            "schema": copy.copy(schema),
        }

    def _upload_endpoint(self, obj):
        """Sends the endpoint across the wire.

        Saves ``obj["endpoint_obj"]`` under
        ``<upload destination>/endpoints/<name>/<version>`` and records that
        path in ``obj["src_path"]`` (the dict is modified in place).
        """
        endpoint_obj = obj["endpoint_obj"]

        dest_path = self._get_endpoint_upload_destination()

        # Upload the endpoint
        obj["src_path"] = os.path.join(
            dest_path, "endpoints", obj["name"], str(obj["version"])
        )

        endpoint_obj.save(obj["src_path"])

    def _wait_for_endpoint_deployment(
        self, endpoint_name, version=1, interval=1.0, timeout=10,
    ):
        """
        Waits for the endpoint to be deployed by calling get_status() and
        checking the deployed version of the endpoint against the expected
        version. Returns once the status is ``LoadSuccessful`` with a version
        equal to or greater than ``version``. Uses time.sleep().

        Parameters
        ----------
        endpoint_name : str
            Endpoint to watch.
        version : int
            Minimum version that must be loaded.
        interval : float
            Seconds to sleep between status polls.
        timeout : float
            Seconds to wait before giving up. Defaults to 10.

        Raises
        ------
        RuntimeError
            If the server reports ``LoadFailed`` or ``timeout`` elapses.
        """
        logger.info(
            "Waiting for endpoint %s to deploy to version %s",
            endpoint_name,
            version,
        )
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        start = time.monotonic()
        while True:
            ep_status = self.get_status()
            try:
                ep = ep_status[endpoint_name]
            except KeyError:
                logger.info(
                    "Endpoint %s doesn't exist in endpoints yet", endpoint_name
                )
            else:
                logger.info("ep=%s", ep)

                if ep["status"] == "LoadFailed":
                    raise RuntimeError(f'LoadFailed: {ep["last_error"]}')

                if ep["status"] == "LoadSuccessful":
                    if ep["version"] >= version:
                        logger.info("LoadSuccessful")
                        return
                    logger.info("LoadSuccessful but wrong version")

            if time.monotonic() - start > timeout:
                raise RuntimeError(f"Waited more than {timeout}s for deployment")

            logger.info("Sleeping %s...", interval)
            time.sleep(interval)

    def set_credentials(self, username, password):
        """
        Set credentials for all the TabPy client-server communication
        where client is tabpy-tools and server is tabpy-server.

        Parameters
        ----------
        username : str
            User name (login). Username is case insensitive.

        password : str
            Password in plain text.
        """
        self._service.set_credentials(username, password)
396