Passed
Push — master ( 53016e...38fde4 )
by Oleksandr
02:52
created

TabPyState._get_config_value()   B

Complexity

Conditions 5

Size

Total Lines 27
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 27
rs 8.9332
c 0
b 0
f 0
cc 5
nop 5
1
try:
2
    from ConfigParser import ConfigParser
3
except ImportError:
4
    from configparser import ConfigParser
5
import json
6
import logging
7
from tabpy.tabpy_server.management.util import write_state_config
8
from threading import Lock
9
from time import time
10
11
12
logger = logging.getLogger(__name__)


# State File Config Section Names
_DEPLOYMENT_SECTION_NAME = "Query Objects Service Versions"
_QUERY_OBJECT_DOCSTRING = "Query Objects Docstrings"
_SERVICE_INFO_SECTION_NAME = "Service Info"
_META_SECTION_NAME = "Meta"

# Directory Names
_QUERY_OBJECT_DIR = "query_objects"

# Module-wide lock serializing changes to the TabPy state; it is taken by the
# ``state_lock`` decorator defined just below.
_PS_STATE_LOCK = Lock()


def state_lock(func):
    """
    Decorator: run *func* while holding the module-wide state lock.

    ``_PS_STATE_LOCK`` is a plain (non-reentrant) ``threading.Lock``, so a
    decorated function must not call another decorated function on the same
    thread, or it will deadlock.

    The returned wrapper keeps *func*'s name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # The lock is acquired before entering the protected region, so a
        # failed or interrupted acquire never leads to releasing a lock this
        # thread does not hold. ``with`` always releases it on exit, even
        # when *func* raises.
        with _PS_STATE_LOCK:
            return func(self, *args, **kwargs)

    return wrapper
43
44
45
def _get_root_path(state_path):
    """
    Return *state_path* with a "/" appended unless it already ends in "/".

    An empty *state_path* raises IndexError (its last character is read
    unconditionally).
    """
    already_terminated = state_path[-1] == "/"
    return state_path if already_terminated else state_path + "/"
50
51
52
def get_query_object_path(state_file_path, name, version):
    """
    Build the path of a query object under the state directory.

    The result is the root path (*state_file_path* with a trailing "/"
    ensured) followed by ``_QUERY_OBJECT_DIR``, *name* and, unless *version*
    is None, ``str(version)``, all joined with "/".
    """
    base = _get_root_path(state_file_path)
    segments = [_QUERY_OBJECT_DIR, name]
    if version is not None:
        segments.append(str(version))
    return base + "/".join(segments)
64
65
66
class TabPyState:
    """
    The TabPy state object: stores attributes about this TabPy instance and
    performs GET/SET operations on them.

    Attributes:
        - name
        - description
        - endpoints (name, description, docstring, version, target)
        - revision number

    The state is held in a ConfigParser (``self.config``); each attribute is
    an option in one of its sections. The mutating helpers persist the whole
    config through ``_write_state`` after every change.
    """

    def __init__(self, settings, config=None):
        """
        Parameters
        ----------
        settings
            Application settings; stored on the instance and forwarded to
            ``write_state_config`` whenever the state is written.
        config : ConfigParser
            Initial state. Although it defaults to None, a ConfigParser is
            required: ``set_config`` raises ValueError for anything else.
        """
        self.settings = settings
        # Adopt the initial config without writing it back out.
        self.set_config(config, _update=False)

    @state_lock
    def set_config(self, config, logger=logging.getLogger(__name__), _update=True):
        """
        Replace the local ConfigParser; the new one becomes the current state.

        Parameters
        ----------
        config : ConfigParser
            The new state.
        logger : logging.Logger, optional
            Logger passed to ``_write_state``.
        _update : bool, optional
            When True (default), the new config is written out immediately.

        Raises
        ------
        ValueError
            If *config* is not a ConfigParser.
        """
        if not isinstance(config, ConfigParser):
            raise ValueError("Invalid config")
        self.config = config
        if _update:
            self._write_state(logger)

    def get_endpoints(self, name=None):
        """
        Return a dictionary of endpoints.

        Parameters
        ----------
        name : str, optional
            The name of the endpoint. If given, only the information about
            that endpoint is returned.

        Returns
        -------
        endpoints : dict
            Maps endpoint names to their info dicts: the JSON-decoded
            deployment entry (description, type, version, dependencies,
            target, creation/modification times, schema, as written by
            ``add_endpoint``/``update_endpoint``) plus a "docstring" key,
            which is "" when no docstring is stored.
            An empty dict is returned, and the error logged, when the
            endpoint list or the named endpoint cannot be read.
        """
        try:
            endpoint_names = self._get_config_value(_DEPLOYMENT_SECTION_NAME, name)
        except Exception as e:
            logger.error(f"error in get_endpoints: {str(e)}")
            return {}

        if name:
            # With a name, the lookup above returned that endpoint's JSON
            # blob rather than a list of names.
            endpoint_info = json.loads(endpoint_names)
            endpoint_info["docstring"] = self._load_docstring(name)
            endpoints = {name: endpoint_info}
        else:
            endpoints = {}
            for endpoint_name in endpoint_names:
                endpoint_info = json.loads(
                    self._get_config_value(_DEPLOYMENT_SECTION_NAME, endpoint_name)
                )
                endpoint_info["docstring"] = self._load_docstring(endpoint_name)
                endpoints[endpoint_name] = endpoint_info
        logger.debug(f"Collected endpoints: {endpoints}")
        return endpoints

    def _load_docstring(self, endpoint_name):
        """Return the stored docstring of *endpoint_name*, or "" if none is stored."""
        docstring = self._get_config_value(
            _QUERY_OBJECT_DOCSTRING, endpoint_name, True, ""
        )
        # NOTE(review): decoding UTF-8 bytes with "unicode_escape" treats them
        # as Latin-1, so non-ASCII docstrings come back mojibake'd. It is kept
        # for compatibility with already-persisted state; confirm before
        # changing the storage encoding.
        return str(bytes(docstring, "utf-8").decode("unicode_escape"))

    @state_lock
    def add_endpoint(
        self,
        name,
        description=None,
        docstring=None,
        endpoint_type=None,
        methods=None,
        target=None,
        dependencies=None,
        schema=None,
    ):
        """
        Add a new endpoint to the TabPy.

        Parameters
        ----------
        name : str
            Name of the endpoint.
        description : str, optional
            Description of this endpoint.
        docstring : str, optional
            The doc string for this endpoint, if needed.
        endpoint_type : str
            The endpoint type (model, alias).
        methods : optional
            Not used by this method.
        target : str, optional
            The target endpoint name for the alias to be added; it must
            already exist.
        dependencies : list, optional
            Names of endpoints this endpoint depends on.
        schema : optional
            Stored as given.

        Raises
        ------
        ValueError
            If validation fails or the endpoint already exists (logged and
            re-raised, like any other error).

        Note:
        The version of this endpoint will be set to 1 since it is a new
        endpoint.
        """
        try:
            endpoints = self.get_endpoints()
            if not isinstance(name, str) or not name:
                raise ValueError("name of the endpoint must be a valid string.")
            elif name in endpoints:
                raise ValueError(f"endpoint {name} already exists.")
            if description and not isinstance(description, str):
                raise ValueError("description must be a string.")
            elif not description:
                description = ""
            if docstring and not isinstance(docstring, str):
                raise ValueError("docstring must be a string.")
            elif not docstring:
                docstring = "-- no docstring found in query function --"
            if not endpoint_type or not isinstance(endpoint_type, str):
                raise ValueError("endpoint type must be a string.")
            if dependencies and not isinstance(dependencies, list):
                raise ValueError("dependencies must be a list.")
            elif not dependencies:
                dependencies = []
            if target and not isinstance(target, str):
                raise ValueError("target must be a string.")
            elif target and target not in endpoints:
                raise ValueError("target endpoint is not valid.")

            # Read the clock once so creation and modification times agree.
            now = int(time())
            endpoint_info = {
                "description": description,
                "docstring": docstring,
                "type": endpoint_type,
                "version": 1,
                "dependencies": dependencies,
                "target": target,
                "creation_time": now,
                "last_modified_time": now,
                "schema": schema,
            }

            endpoints[name] = endpoint_info
            self._add_update_endpoints_config(endpoints)
        except Exception as e:
            logger.error(f"Error in add_endpoint: {e}")
            raise

    def _add_update_endpoints_config(self, endpoints):
        """
        Write every endpoint in *endpoints* to the state config.

        Each endpoint's "docstring" goes to the docstrings section, and the
        rest of its info goes, JSON-encoded, to the deployment section. The
        caller's dicts are not modified. Errors are logged and re-raised.

        NOTE(review): every endpoint is re-written, and each docstring is
        passed through "unicode_escape" again, including docstrings that
        ``get_endpoints`` already decoded. Stored backslash sequences and
        non-ASCII text in unrelated endpoints can therefore drift on each
        save. Confirm before narrowing this to the changed endpoint, because
        that would also change how many revision bumps occur.
        """
        for endpoint_name, endpoint_info in endpoints.items():
            try:
                # Work on a copy so the caller's dict keeps its "docstring".
                info = dict(endpoint_info)
                docstring = info.pop("docstring")
                self._set_config_value(
                    _QUERY_OBJECT_DOCSTRING,
                    endpoint_name,
                    str(bytes(docstring, "utf-8").decode("unicode_escape")),
                    _update_revision=False,
                )
                self._set_config_value(
                    _DEPLOYMENT_SECTION_NAME, endpoint_name, json.dumps(info)
                )
            except Exception as e:
                logger.error(f"Unable to write endpoints config: {e}")
                raise

    @state_lock
    def update_endpoint(
        self,
        name,
        description=None,
        docstring=None,
        endpoint_type=None,
        version=None,
        methods=None,
        target=None,
        dependencies=None,
        schema=None,
    ):
        """
        Update an existing endpoint on the TabPy.

        Parameters
        ----------
        name : str
            Name of the endpoint.
        description : str, optional
            Description of this endpoint.
        docstring : str, optional
            The doc string for this endpoint, if needed.
        endpoint_type : str, optional
            The endpoint type (model, alias).
        version : int, optional
            The version of this endpoint.
        methods : optional
            Not used by this method.
        target : str, optional
            The target endpoint name for the alias; it must already exist.
        dependencies : list, optional
            List of dependent endpoints for this existing endpoint.
        schema : optional
            Stored as given. Unlike the other parameters, omitting it stores
            None in place of any previous schema.

        Raises
        ------
        ValueError
            If validation fails or the endpoint does not exist (logged and
            re-raised, like any other error).

        Note:
        For the other parameters, values that are not specified keep their
        current values.
        """
        try:
            endpoints = self.get_endpoints()
            if not name or not isinstance(name, str):
                raise ValueError("name of the endpoint must be string.")
            elif name not in endpoints:
                raise ValueError(f"endpoint {name} does not exist.")

            endpoint_info = endpoints[name]

            if description and not isinstance(description, str):
                raise ValueError("description must be a string.")
            elif not description:
                description = endpoint_info["description"]
            if docstring and not isinstance(docstring, str):
                raise ValueError("docstring must be a string.")
            elif not docstring:
                docstring = endpoint_info["docstring"]
            if endpoint_type and not isinstance(endpoint_type, str):
                raise ValueError("endpoint type must be a string.")
            elif not endpoint_type:
                endpoint_type = endpoint_info["type"]
            if version and not isinstance(version, int):
                raise ValueError("version must be an int.")
            elif not version:
                version = endpoint_info["version"]
            if dependencies and not isinstance(dependencies, list):
                raise ValueError("dependencies must be a list.")
            elif not dependencies:
                dependencies = endpoint_info.get("dependencies", [])
            if target and not isinstance(target, str):
                raise ValueError("target must be a string.")
            elif target and target not in endpoints:
                raise ValueError("target endpoint is not valid.")
            elif not target:
                target = endpoint_info["target"]

            # NOTE(review): ``schema`` is not defaulted from the stored value,
            # so an update without a schema clears it. Confirm this is intended
            # (e.g. a redeployed function invalidating its old schema).
            endpoints[name] = {
                "description": description,
                "docstring": docstring,
                "type": endpoint_type,
                "version": version,
                "dependencies": dependencies,
                "target": target,
                "creation_time": endpoint_info["creation_time"],
                "last_modified_time": int(time()),
                "schema": schema,
            }
            self._add_update_endpoints_config(endpoints)
        except Exception as e:
            logger.error(f"Error in update_endpoint: {e}")
            raise

    @state_lock
    def delete_endpoint(self, name):
        """
        Delete an existing endpoint on the TabPy.

        Parameters
        ----------
        name : str
            The name of the endpoint to be deleted.

        Returns
        -------
        dict
            The deleted endpoint's info, as returned by ``get_endpoints``.

        Raises
        ------
        ValueError
            If *name* is empty, the endpoint does not exist, other endpoints
            depend on it, or removing it from the config fails.

        Note:
        Cannot delete this endpoint if other endpoints are currently
        depending on this endpoint.
        """
        if not name:
            raise ValueError("Name of the endpoint must be a valid string.")
        endpoints = self.get_endpoints()
        if name not in endpoints:
            raise ValueError(f"Endpoint {name} does not exist.")

        endpoint_to_delete = endpoints[name]

        # Collect the other endpoints that list this one as a dependency.
        deps = {
            other
            for other, info in endpoints.items()
            if other != name and name in info.get("dependencies", [])
        }
        if deps:
            raise ValueError(
                f"Cannot remove endpoint {name}, it is currently "
                f"used by {list(deps)} endpoints."
            )

        # Delete the endpoint from state. Only the second removal bumps the
        # revision number, so a delete counts as a single revision.
        try:
            self._remove_config_option(
                _QUERY_OBJECT_DOCSTRING, name, _update_revision=False
            )
            self._remove_config_option(_DEPLOYMENT_SECTION_NAME, name)

            return endpoint_to_delete
        except Exception as e:
            logger.error(f"Unable to delete endpoint {e}")
            raise ValueError(f"Unable to delete endpoint: {e}")

    @property
    def name(self):
        """
        Returns the name of the TabPy service, or None if it cannot be read.
        """
        name = None
        try:
            name = self._get_config_value(_SERVICE_INFO_SECTION_NAME, "Name")
        except Exception as e:
            logger.error(f"Unable to get name: {e}")
        return name

    @property
    def creation_time(self):
        """
        Returns the creation time of the TabPy service as stored in the
        config (a string), or 0 if it cannot be read.
        """
        creation_time = 0
        try:
            creation_time = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Creation Time"
            )
        except Exception as e:
            logger.error(f"Unable to get creation time: {e}")
        return creation_time

    @state_lock
    def set_name(self, name):
        """
        Set the name of this TabPy service.

        Parameters
        ----------
        name : str
            Name of TabPy service.

        Raises ValueError if *name* is not a string. Write errors are
        logged, not raised.
        """
        if not isinstance(name, str):
            raise ValueError("name must be a string.")
        try:
            self._set_config_value(_SERVICE_INFO_SECTION_NAME, "Name", name)
        except Exception as e:
            logger.error(f"Unable to set name: {e}")

    def get_description(self):
        """
        Returns the description of the TabPy service, or None if it cannot
        be read.
        """
        description = None
        try:
            description = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Description"
            )
        except Exception as e:
            logger.error(f"Unable to get description: {e}")
        return description

    @state_lock
    def set_description(self, description):
        """
        Set the description of this TabPy service.

        Parameters
        ----------
        description : str
            Description of TabPy service.

        Raises ValueError if *description* is not a string. Write errors are
        logged, not raised.
        """
        if not isinstance(description, str):
            raise ValueError("Description must be a string.")
        try:
            self._set_config_value(
                _SERVICE_INFO_SECTION_NAME, "Description", description
            )
        except Exception as e:
            logger.error(f"Unable to set description: {e}")

    def get_revision_number(self):
        """
        Returns the revision number of this TabPy service, or -1 if it
        cannot be read.
        """
        rev = -1
        try:
            rev = int(self._get_config_value(_META_SECTION_NAME, "Revision Number"))
        except Exception as e:
            logger.error(f"Unable to get revision number: {e}")
        return rev

    def get_access_control_allow_origin(self):
        """
        Returns Access-Control-Allow-Origin of this TabPy service, or "" if
        it cannot be read (the error is logged).
        """
        _cors_origin = ""
        try:
            logger.debug("Collecting Access-Control-Allow-Origin from state file ...")
            _cors_origin = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Access-Control-Allow-Origin"
            )
        except Exception as e:
            logger.error(e)
        return _cors_origin

    def get_access_control_allow_headers(self):
        """
        Returns Access-Control-Allow-Headers of this TabPy service, or "" if
        it cannot be read.
        """
        _cors_headers = ""
        try:
            _cors_headers = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Access-Control-Allow-Headers"
            )
        except Exception as e:
            # Best effort: fall back to "" but leave a trace for debugging.
            logger.debug(f"Unable to get Access-Control-Allow-Headers: {e}")
        return _cors_headers

    def get_access_control_allow_methods(self):
        """
        Returns Access-Control-Allow-Methods of this TabPy service, or "" if
        it cannot be read.
        """
        _cors_methods = ""
        try:
            _cors_methods = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Access-Control-Allow-Methods"
            )
        except Exception as e:
            # Best effort: fall back to "" but leave a trace for debugging.
            logger.debug(f"Unable to get Access-Control-Allow-Methods: {e}")
        return _cors_methods

    def _set_revision_number(self, revision_number):
        """
        Set the revision number of this TabPy service.

        Raises ValueError if *revision_number* is not an int. Write errors
        are logged, not raised.
        """
        if not isinstance(revision_number, int):
            raise ValueError("revision number must be an int.")
        try:
            # ConfigParser.set() only accepts string values. Skip the
            # automatic revision bump, which would otherwise overwrite the
            # value being set with revision_number + 1.
            self._set_config_value(
                _META_SECTION_NAME,
                "Revision Number",
                str(revision_number),
                _update_revision=False,
            )
        except Exception as e:
            logger.error(f"Unable to set revision number: {e}")

    def _remove_config_option(
        self,
        section_name,
        option_name,
        logger=logging.getLogger(__name__),
        _update_revision=True,
    ):
        """
        Remove *option_name* from *section_name*, optionally bump the
        revision number, then write the state out.
        """
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        self.config.remove_option(section_name, option_name)
        # update revision number
        if _update_revision:
            self._increase_revision_number()
        self._write_state(logger=logger)

    def _has_config_value(self, section_name, option_name):
        """Return whether *section_name* contains *option_name*."""
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        return self.config.has_option(section_name, option_name)

    def _increase_revision_number(self):
        """Increment the stored revision number in memory; does not write state."""
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        cur_rev = int(self.config.get(_META_SECTION_NAME, "Revision Number"))
        self.config.set(_META_SECTION_NAME, "Revision Number", str(cur_rev + 1))

    def _set_config_value(
        self,
        section_name,
        option_name,
        option_value,
        logger=logging.getLogger(__name__),
        _update_revision=True,
    ):
        """
        Set *option_name* in *section_name*, creating the section if needed,
        optionally bump the revision number, then write the state out.

        *option_value* must be a str: ConfigParser.set() rejects other types.
        """
        if not self.config:
            raise ValueError("State configuration not yet loaded.")

        if not self.config.has_section(section_name):
            logger.log(logging.DEBUG, f"Adding config section {section_name}")
            self.config.add_section(section_name)

        self.config.set(section_name, option_name, option_value)
        # update revision number
        if _update_revision:
            self._increase_revision_number()
        self._write_state(logger=logger)

    def _get_config_items(self, section_name):
        """Return the (name, value) pairs of *section_name*."""
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        return self.config.items(section_name)

    def _get_config_value(
        self, section_name, option_name, optional=False, default_value=None
    ):
        """
        Read a value from the state config.

        Parameters
        ----------
        section_name : str
            Config section to read from.
        option_name : str or None
            Option to read. When falsy, the list of option names in
            *section_name* is returned instead (configparser.NoSectionError
            propagates if that section is missing).
        optional : bool
            When True, a missing option yields *default_value* instead of
            raising.
        default_value
            Value returned for a missing optional option.

        Raises
        ------
        ValueError
            If the config is not loaded, or a required option is missing.
        """
        logger.log(
            logging.DEBUG,
            f"Loading option '{option_name}' from section [{section_name}]...",
        )

        if not self.config:
            msg = "State configuration not yet loaded."
            # logging.log() requires a level as its first argument; calling
            # it with only the message raised TypeError and hid this error.
            logger.error(msg)
            raise ValueError(msg)

        if not option_name:
            res = self.config.options(section_name)
        elif self.config.has_option(section_name, option_name):
            res = self.config.get(section_name, option_name)
        elif optional:
            res = default_value
        else:
            raise ValueError(
                f"Cannot find option name {option_name} "
                f"under section {section_name}"
            )

        logger.log(logging.DEBUG, f"Returning value '{res}'")
        return res

    def _write_state(self, logger=logging.getLogger(__name__)):
        """
        Write the state (ConfigParser) out via ``write_state_config``, passing
        it this instance's settings and *logger*.
        """
        logger.log(logging.INFO, "Writing state to config")
        write_state_config(self.config, self.settings, logger=logger)
625