Passed
Push — master ( f6c774...627e72 )
by
unknown
15:38
created

TabPyState.get_access_control_allow_headers()   A

Complexity

Conditions 2

Size

Total Lines 12
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 5
CRAP Score 2.0932

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 12
ccs 5
cts 7
cp 0.7143
rs 10
c 0
b 0
f 0
cc 2
nop 1
crap 2.0932
1 1
try:
2 1
    from ConfigParser import ConfigParser
3 1
except ImportError:
4 1
    from configparser import ConfigParser
5 1
import json
6 1
import logging
7 1
from tabpy.tabpy_server.management.util import write_state_config
8 1
from threading import Lock
9 1
from time import time
10
11
12 1
logger = logging.getLogger(__name__)
13
14
# State File Config Section Names
15 1
_DEPLOYMENT_SECTION_NAME = "Query Objects Service Versions"
16 1
_QUERY_OBJECT_DOCSTRING = "Query Objects Docstrings"
17 1
_SERVICE_INFO_SECTION_NAME = "Service Info"
18 1
_META_SECTION_NAME = "Meta"
19
20
# Directory Names
21 1
_QUERY_OBJECT_DIR = "query_objects"
22
23 1
"""
24
Lock to change the TabPy State.
25
"""
26 1
_PS_STATE_LOCK = Lock()
27
28
29 1
def state_lock(func):
    """
    Decorator that serializes calls through the module-level state mutex.

    The wrapped method runs while holding ``_PS_STATE_LOCK``. The lock is
    released on every exit path, including exceptions. ``functools.wraps``
    keeps the wrapped function's name, docstring and ``__wrapped__``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # ``with`` only releases a lock that was actually acquired, unlike
        # acquiring inside ``try`` and releasing in ``finally``.
        with _PS_STATE_LOCK:
            return func(self, *args, **kwargs)

    return wrapper
43
44
45 1
def _get_root_path(state_path):
46 1
    if state_path[-1] != "/":
47 1
        state_path += "/"
48
49 1
    return state_path
50
51
52 1
def get_query_object_path(state_file_path, name, version):
    """
    Build the path of a query object under the state directory.

    The result is ``<root>query_objects/<name>``. When ``version`` is not
    None, ``/<str(version)>`` is appended. ``<root>`` is
    ``state_file_path`` as returned by ``_get_root_path``, which makes
    sure it ends with "/".
    """
    root = _get_root_path(state_file_path)
    if version is None:
        segments = (_QUERY_OBJECT_DIR, name)
    else:
        segments = (_QUERY_OBJECT_DIR, name, str(version))
    return root + "/".join(segments)
64
65
66 1
class TabPyState:
    """
    The TabPy state object that stores attributes about this TabPy and
    performs GET/SET on these attributes.

    Attributes:
        - name
        - description
        - endpoints (name, description, docstring, version, target)
        - revision number

    The state is held in a ConfigParser (``self.config``). Each attribute
    above is stored as an option in one of the config's sections.
    Mutations are persisted through ``write_state_config``.
    """

    def __init__(self, settings, config=None):
        self.settings = settings
        # set_config raises ValueError unless config is a ConfigParser, so
        # the default None is rejected here.
        self.set_config(config, _update=False)

    @state_lock
    def set_config(self, config, logger=logging.getLogger(__name__), _update=True):
        """
        Set the local ConfigParser manually.

        This new ConfigParser will be used as current state. When
        ``_update`` is true the state is also written out.

        Raises
        ------
        ValueError
            If ``config`` is not a ConfigParser.
        """
        if not isinstance(config, ConfigParser):
            raise ValueError("Invalid config")
        self.config = config
        if _update:
            self._write_state(logger)

    @staticmethod
    def _decode_docstring(docstring):
        """
        Run a stored docstring through the "unicode_escape" codec.

        The same transform is applied when docstrings are written
        (``_add_update_endpoints_config``) and when they are read
        (``get_endpoints``).
        """
        # NOTE(review): decoding UTF-8 bytes with unicode_escape mangles
        # non-ASCII text. It is kept for compatibility with existing state
        # files.
        return bytes(docstring, "utf-8").decode("unicode_escape")

    def get_endpoints(self, name=None):
        """
        Return a dictionary of endpoints.

        Parameters
        ----------
        name : str
            The name of the endpoint.
            If "name" is specified, only the information about that endpoint
            will be returned.

        Returns
        -------
        endpoints : dict
            The dictionary containing information about each endpoint.
            The keys are the endpoint names.
            The values are the stored endpoint info dicts plus a "docstring"
            entry. The docstring is "" when none is stored. Stored keys
            include description, type, version, dependencies, target, etc.
            An empty dict is returned if the lookup fails, e.g. an unknown
            name or a missing section.
        """
        try:
            # With a name this is the endpoint's JSON info; without one it
            # is the list of option (endpoint) names in the section.
            endpoint_names = self._get_config_value(_DEPLOYMENT_SECTION_NAME, name)
        except Exception as e:
            logger.error(f"error in get_endpoints: {str(e)}")
            return {}

        if name:
            raw_entries = [(name, endpoint_names)]
        else:
            raw_entries = [
                (endpoint_name,
                 self._get_config_value(_DEPLOYMENT_SECTION_NAME, endpoint_name))
                for endpoint_name in endpoint_names
            ]

        endpoints = {}
        for endpoint_name, raw_info in raw_entries:
            endpoint_info = json.loads(raw_info)
            # A missing docstring yields "" on both the single-endpoint and
            # list paths.
            docstring = self._get_config_value(
                _QUERY_OBJECT_DOCSTRING, endpoint_name, True, ""
            )
            endpoint_info["docstring"] = self._decode_docstring(docstring)
            endpoints[endpoint_name] = endpoint_info
        logger.debug(f"Collected endpoints: {endpoints}")
        return endpoints

    def _check_endpoint_exists(self, name):
        """
        Return True if an endpoint called ``name`` exists.

        Raises ValueError if ``name`` is not a non-empty string.
        """
        # Validate first so bad input does not trigger a state read.
        if not name or not isinstance(name, str):
            raise ValueError("name of the endpoint must be a valid string.")
        return name in self.get_endpoints()

    def _check_and_set_endpoint_str_value(self, param, paramName, defaultValue):
        """
        Return ``param`` if it is a non-empty string.

        If ``param`` is falsy and a non-None default is given, the default is
        returned. Otherwise a ValueError naming ``paramName`` is raised.
        """
        if not param and defaultValue is not None:
            return defaultValue

        if not param or not isinstance(param, str):
            raise ValueError(f"{paramName} must be a string.")

        return param

    def _check_and_set_endpoint_description(self, description, defaultValue):
        return self._check_and_set_endpoint_str_value(description, "description", defaultValue)

    def _check_and_set_endpoint_docstring(self, docstring, defaultValue):
        return self._check_and_set_endpoint_str_value(docstring, "docstring", defaultValue)

    def _check_and_set_endpoint_type(self, endpoint_type, defaultValue):
        return self._check_and_set_endpoint_str_value(
            endpoint_type, "endpoint type", defaultValue)

    def _check_target(self, target):
        """Raise ValueError if a truthy ``target`` is not a string."""
        if target and not isinstance(target, str):
            raise ValueError("target must be a string.")

    def _check_and_set_dependencies(self, dependencies, defaultValue):
        """
        Return ``dependencies``, or ``defaultValue`` if it is empty/None.

        Raises ValueError if a non-empty value is not a list.
        """
        if not dependencies:
            return defaultValue

        # Only reject non-list values. The old check also had
        # "dependencies or ...", which rejected every non-empty list.
        if not isinstance(dependencies, list):
            raise ValueError("dependencies must be a list.")

        return dependencies

    def _check_and_set_is_public(self, isPublic, defaultValue):
        """
        Return ``isPublic``, or ``defaultValue`` if it was not specified.
        """
        # Only None means "not specified". An explicit False must win so an
        # endpoint can be made private again through update_endpoint.
        if isPublic is None:
            return defaultValue

        return isPublic

    @state_lock
    def add_endpoint(
        self,
        name,
        description=None,
        docstring=None,
        endpoint_type=None,
        methods=None,
        target=None,
        dependencies=None,
        schema=None,
        isPublic=None,
    ):
        """
        Add a new endpoint to the TabPy.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str, optional
            Description of this endpoint. Defaults to "".
        docstring : str, optional
            The doc string for this endpoint, if needed.
        endpoint_type : str
            The endpoint type (model, alias). Required.
        methods : optional
            Accepted for interface compatibility; not stored.
        target : str, optional
            The target endpoint name for the alias to be added. It must be
            an existing endpoint.
        dependencies : list, optional
            Names of endpoints this endpoint depends on. Defaults to [].
        schema : optional
            Stored as-is.
        isPublic : bool, optional
            Defaults to False.

        Note:
        The version of this endpoint will be set to 1 since it is a new
        endpoint.

        Raises
        ------
        ValueError
            If the endpoint exists already or any argument is invalid.
        """
        try:
            if self._check_endpoint_exists(name):
                raise ValueError(f"endpoint {name} already exists.")

            endpoints = self.get_endpoints()

            description = self._check_and_set_endpoint_description(description, "")
            docstring = self._check_and_set_endpoint_docstring(
                docstring, "-- no docstring found in query function --")
            endpoint_type = self._check_and_set_endpoint_type(endpoint_type, None)
            dependencies = self._check_and_set_dependencies(dependencies, [])
            isPublic = self._check_and_set_is_public(isPublic, False)

            self._check_target(target)
            if target and target not in endpoints:
                raise ValueError("target endpoint is not valid.")

            # One timestamp so creation and modification times agree.
            now = int(time())
            endpoint_info = {
                "description": description,
                "docstring": docstring,
                "type": endpoint_type,
                "version": 1,
                "dependencies": dependencies,
                "target": target,
                "creation_time": now,
                "last_modified_time": now,
                "schema": schema,
                "isPublic": isPublic,
            }

            # Write only the new endpoint. Rewriting all endpoints would
            # re-run the docstring codec on unchanged entries and bump the
            # revision once per endpoint.
            self._add_update_endpoints_config({name: endpoint_info})
        except Exception as e:
            logger.error(f"Error in add_endpoint: {e}")
            raise

    def _add_update_endpoints_config(self, endpoints):
        """
        Save endpoint info dicts (keyed by endpoint name) to the config.

        For each endpoint:
        - the "docstring" entry goes to the docstrings section, without a
          revision bump;
        - the remaining info is JSON-encoded into the deployment section,
          which bumps the revision.
        Each ``_set_config_value`` call writes the state. The caller's dicts
        are not modified.
        """
        for endpoint_name, endpoint_info in endpoints.items():
            try:
                info = dict(endpoint_info)
                dstring = self._decode_docstring(info.pop("docstring"))
                self._set_config_value(
                    _QUERY_OBJECT_DOCSTRING,
                    endpoint_name,
                    dstring,
                    _update_revision=False,
                )
                self._set_config_value(
                    _DEPLOYMENT_SECTION_NAME, endpoint_name, json.dumps(info)
                )
            except Exception as e:
                logger.error(f"Unable to write endpoints config: {e}")
                raise

    @state_lock
    def update_endpoint(
        self,
        name,
        description=None,
        docstring=None,
        endpoint_type=None,
        version=None,
        methods=None,
        target=None,
        dependencies=None,
        schema=None,
        isPublic=None,
    ):
        """
        Update an existing endpoint on the TabPy.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str, optional
            Description of this endpoint
        docstring : str, optional
            The doc string for this endpoint, if needed.
        endpoint_type : str, optional
            The endpoint type (model, alias)
        version : int, optional
            The version of this endpoint
        methods : optional
            Accepted for interface compatibility; not stored.
        dependencies : list, optional
            List of dependent endpoints for this existing endpoint
        target : str, optional
            The target endpoint name for the alias.
        schema : optional
            Stored as given. Passing None stores None.
        isPublic : bool, optional
            An explicit False is honoured; None keeps the current value.

        Note:
        Except for ``schema``, parameters that are not specified keep
        their current values.

        Raises
        ------
        ValueError
            If the endpoint does not exist or any argument is invalid.
        """
        try:
            if not self._check_endpoint_exists(name):
                raise ValueError(f"endpoint {name} does not exist.")

            endpoints = self.get_endpoints()
            endpoint_info = endpoints[name]

            description = self._check_and_set_endpoint_description(
                description, endpoint_info["description"])
            docstring = self._check_and_set_endpoint_docstring(
                docstring, endpoint_info["docstring"])
            endpoint_type = self._check_and_set_endpoint_type(
                endpoint_type, endpoint_info["type"])
            dependencies = self._check_and_set_dependencies(
                dependencies, endpoint_info.get("dependencies", []))
            # Entries stored without "isPublic" default to False instead of
            # raising KeyError.
            isPublic = self._check_and_set_is_public(
                isPublic, endpoint_info.get("isPublic", False))

            self._check_target(target)
            if target and target not in endpoints:
                raise ValueError("target endpoint is not valid.")
            elif not target:
                target = endpoint_info["target"]

            if version and not isinstance(version, int):
                raise ValueError("version must be an int.")
            elif not version:
                version = endpoint_info["version"]

            updated_info = {
                "description": description,
                "docstring": docstring,
                "type": endpoint_type,
                "version": version,
                "dependencies": dependencies,
                "target": target,
                "creation_time": endpoint_info["creation_time"],
                "last_modified_time": int(time()),
                "schema": schema,
                "isPublic": isPublic,
            }

            # Write only the changed endpoint (see add_endpoint).
            self._add_update_endpoints_config({name: updated_info})
        except Exception as e:
            logger.error(f"Error in update_endpoint: {e}")
            raise

    @state_lock
    def delete_endpoint(self, name):
        """
        Delete an existing endpoint on the TabPy

        Parameters
        ----------
        name : str
            The name of the endpoint to be deleted.

        Returns
        -------
        deleted endpoint object

        Note:
        Cannot delete this endpoint if other endpoints list it in their
        "dependencies".

        Raises
        ------
        ValueError
            If the name is empty, the endpoint does not exist, other
            endpoints depend on it, or removing it from the config fails.
        """
        if not name:
            raise ValueError("Name of the endpoint must be a valid string.")
        endpoints = self.get_endpoints()
        if name not in endpoints:
            raise ValueError(f"Endpoint {name} does not exist.")

        endpoint_to_delete = endpoints[name]

        # Collect other endpoints that list this one as a dependency.
        # NOTE(review): aliases whose "target" is this endpoint are not
        # checked here.
        deps = {
            endpoint_name
            for endpoint_name, info in endpoints.items()
            if endpoint_name != name and name in info.get("dependencies", [])
        }

        if deps:
            raise ValueError(
                f"Cannot remove endpoint {name}, it is currently "
                f"used by {list(deps)} endpoints."
            )

        # delete the endpoint from state
        try:
            self._remove_config_option(
                _QUERY_OBJECT_DOCSTRING, name, _update_revision=False
            )
            self._remove_config_option(_DEPLOYMENT_SECTION_NAME, name)

            return endpoint_to_delete
        except Exception as e:
            logger.error(f"Unable to delete endpoint {e}")
            raise ValueError(f"Unable to delete endpoint: {e}") from e

    @property
    def name(self):
        """
        Returns the name of the TabPy service, or None if it cannot be read.
        """
        name = None
        try:
            name = self._get_config_value(_SERVICE_INFO_SECTION_NAME, "Name")
        except Exception as e:
            logger.error(f"Unable to get name: {e}")
        return name

    @property
    def creation_time(self):
        """
        Returns the creation time of the TabPy service.

        This is the raw config value, or 0 if it cannot be read.
        """
        creation_time = 0
        try:
            creation_time = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Creation Time"
            )
        except Exception as e:
            logger.error(f"Unable to get creation time: {e}")
        return creation_time

    @state_lock
    def set_name(self, name):
        """
        Set the name of this TabPy service.

        Parameters
        ----------
        name : str
            Name of TabPy service.

        Raises ValueError for a non-string name. Write failures are logged,
        not raised.
        """
        if not isinstance(name, str):
            raise ValueError("name must be a string.")
        try:
            self._set_config_value(_SERVICE_INFO_SECTION_NAME, "Name", name)
        except Exception as e:
            logger.error(f"Unable to set name: {e}")

    def get_description(self):
        """
        Returns the description of the TabPy service, or None if it cannot
        be read.
        """
        description = None
        try:
            description = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Description"
            )
        except Exception as e:
            logger.error(f"Unable to get description: {e}")
        return description

    @state_lock
    def set_description(self, description):
        """
        Set the description of this TabPy service.

        Parameters
        ----------
        description : str
            Description of TabPy service.

        Raises ValueError for a non-string description. Write failures are
        logged, not raised.
        """
        if not isinstance(description, str):
            raise ValueError("Description must be a string.")
        try:
            self._set_config_value(
                _SERVICE_INFO_SECTION_NAME, "Description", description
            )
        except Exception as e:
            logger.error(f"Unable to set description: {e}")

    def get_revision_number(self):
        """
        Returns the revision number of this TabPy service, or -1 if it
        cannot be read.
        """
        rev = -1
        try:
            rev = int(self._get_config_value(_META_SECTION_NAME, "Revision Number"))
        except Exception as e:
            logger.error(f"Unable to get revision number: {e}")
        return rev

    def get_access_control_allow_origin(self):
        """
        Returns Access-Control-Allow-Origin of this TabPy service.

        Returns "" if it cannot be read.
        """
        _cors_origin = ""
        try:
            logger.debug("Collecting Access-Control-Allow-Origin from state file ...")
            _cors_origin = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Access-Control-Allow-Origin"
            )
        except Exception as e:
            logger.error(e)
        return _cors_origin

    def get_access_control_allow_headers(self):
        """
        Returns Access-Control-Allow-Headers of this TabPy service.

        Returns "" if it cannot be read.
        """
        _cors_headers = ""
        try:
            _cors_headers = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Access-Control-Allow-Headers"
            )
        except Exception as e:
            # Best-effort lookup: keep the "" fallback, but record why.
            logger.debug(f"Access-Control-Allow-Headers not available: {e}")
        return _cors_headers

    def get_access_control_allow_methods(self):
        """
        Returns Access-Control-Allow-Methods of this TabPy service.

        Returns "" if it cannot be read.
        """
        _cors_methods = ""
        try:
            _cors_methods = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Access-Control-Allow-Methods"
            )
        except Exception as e:
            # Best-effort lookup: keep the "" fallback, but record why.
            logger.debug(f"Access-Control-Allow-Methods not available: {e}")
        return _cors_methods

    def _set_revision_number(self, revision_number):
        """
        Set the revision number of this TabPy service to exactly
        ``revision_number``.

        Raises ValueError for a non-int value. Write failures are logged,
        not raised.
        """
        if not isinstance(revision_number, int):
            raise ValueError("revision number must be an int.")
        try:
            # ConfigParser.set only accepts string values. Skip the
            # automatic bump so the stored number is the one requested.
            self._set_config_value(
                _META_SECTION_NAME,
                "Revision Number",
                str(revision_number),
                _update_revision=False,
            )
        except Exception as e:
            logger.error(f"Unable to set revision number: {e}")

    def _remove_config_option(
        self,
        section_name,
        option_name,
        logger=logging.getLogger(__name__),
        _update_revision=True,
    ):
        """
        Remove an option.

        The revision is bumped only if ``_update_revision`` is true. The
        state is always written.
        """
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        self.config.remove_option(section_name, option_name)
        # update revision number
        if _update_revision:
            self._increase_revision_number()
        self._write_state(logger=logger)

    def _has_config_value(self, section_name, option_name):
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        return self.config.has_option(section_name, option_name)

    def _increase_revision_number(self):
        """Increment the "Revision Number" option in the Meta section by 1."""
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        cur_rev = int(self.config.get(_META_SECTION_NAME, "Revision Number"))
        self.config.set(_META_SECTION_NAME, "Revision Number", str(cur_rev + 1))

    def _set_config_value(
        self,
        section_name,
        option_name,
        option_value,
        logger=logging.getLogger(__name__),
        _update_revision=True,
    ):
        """
        Set an option, creating the section if needed.

        The revision is bumped only if ``_update_revision`` is true. The
        state is always written. ``option_value`` must be a str, because
        ConfigParser.set rejects other types.
        """
        if not self.config:
            raise ValueError("State configuration not yet loaded.")

        if not self.config.has_section(section_name):
            logger.log(logging.DEBUG, f"Adding config section {section_name}")
            self.config.add_section(section_name)

        self.config.set(section_name, option_name, option_value)
        # update revision number
        if _update_revision:
            self._increase_revision_number()
        self._write_state(logger=logger)

    def _get_config_items(self, section_name):
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        return self.config.items(section_name)

    def _get_config_value(
        self, section_name, option_name, optional=False, default_value=None
    ):
        """
        Read a value from the config.

        - If ``option_name`` is falsy, return the section's option names.
        - Otherwise return the option's value if present.
        - Otherwise return ``default_value`` when ``optional`` is true.
        - Otherwise raise ValueError.
        """
        logger.log(
            logging.DEBUG,
            f"Loading option '{option_name}' from section [{section_name}]...")

        if not self.config:
            msg = "State configuration not yet loaded."
            # logging.log requires a level; calling it with only the message
            # raised TypeError instead of the intended ValueError below.
            logger.error(msg)
            raise ValueError(msg)

        res = None
        if not option_name:
            res = self.config.options(section_name)
        elif self.config.has_option(section_name, option_name):
            res = self.config.get(section_name, option_name)
        elif optional:
            res = default_value
        else:
            raise ValueError(
                f"Cannot find option name {option_name} "
                f"under section {section_name}"
            )

        logger.log(logging.DEBUG, f"Returning value '{res}'")
        return res

    def _write_state(self, logger=logging.getLogger(__name__)):
        """
        Persist the current state (ConfigParser) via ``write_state_config``
        using this object's settings.
        """
        logger.log(logging.INFO, "Writing state to config")
        write_state_config(self.config, self.settings, logger=logger)
656