TabPyState._check_target()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 5.667

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
ccs 1
cts 3
cp 0.3333
cc 3
nop 2
crap 5.667
1 1
try:
2 1
    from ConfigParser import ConfigParser
3 1
except ImportError:
4 1
    from configparser import ConfigParser
5 1
import json
6 1
import logging
7 1
from tabpy.tabpy_server.management.util import write_state_config
8 1
from threading import Lock
9 1
from time import time
10
11
12 1
# Module-wide logger for the state management code.
logger = logging.getLogger(__name__)

# Section names inside the state file (a ConfigParser document).
# Endpoint name -> JSON-encoded endpoint info (everything except the docstring).
_DEPLOYMENT_SECTION_NAME = "Query Objects Service Versions"
# Endpoint name -> endpoint docstring.
_QUERY_OBJECT_DOCSTRING = "Query Objects Docstrings"
# Service-wide values: name, description, creation time, CORS options.
_SERVICE_INFO_SECTION_NAME = "Service Info"
# Bookkeeping values such as the revision number.
_META_SECTION_NAME = "Meta"

# Directory (below the state root) that holds saved query objects.
_QUERY_OBJECT_DIR = "query_objects"

# Mutex that serializes every change to the TabPy state; taken by the
# state_lock decorator.
_PS_STATE_LOCK = Lock()
27
28
29 1
def state_lock(func):
    """
    Decorator that serializes calls to ``func`` through ``_PS_STATE_LOCK``.

    The mutex is held for the whole call and released both on normal return
    and when ``func`` raises. ``_PS_STATE_LOCK`` is a plain ``threading.Lock``
    (not reentrant), so a decorated method must not call another decorated
    method on the same thread.

    Parameters
    ----------
    func : callable
        A method taking ``self`` as its first argument.

    Returns
    -------
    callable
        The wrapped method, carrying ``func``'s name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # ``with`` acquires before entering the protected region, so the lock
        # is released only if it was actually acquired.
        with _PS_STATE_LOCK:
            return func(self, *args, **kwargs)

    return wrapper
43
44
45 1
def _get_root_path(state_path):
    """
    Return ``state_path`` guaranteed to end with "/".

    Parameters
    ----------
    state_path : str
        Path of the TabPy state directory.

    Returns
    -------
    str
        ``state_path`` unchanged if it already ends with "/", otherwise
        ``state_path + "/"``.

    Raises
    ------
    ValueError
        If ``state_path`` is empty (or otherwise falsy).
    """
    if not state_path:
        # Indexing an empty string would fail with an unhelpful IndexError.
        raise ValueError("state path must be a non-empty string.")
    if not state_path.endswith("/"):
        state_path += "/"
    return state_path
50
51
52 1
def get_query_object_path(state_file_path, name, version):
    """
    Build the path of a saved query object.

    The result is ``<state root>/query_objects/<name>`` followed by
    ``/<version>`` when ``version`` is not None.

    Parameters
    ----------
    state_file_path : str
        Root path of the TabPy state.
    name : str
        Query object (endpoint) name.
    version : object or None
        Version; converted with ``str()``. Omitted from the path when None.

    Returns
    -------
    str
        The query object path.
    """
    root = _get_root_path(state_file_path)
    version_part = [] if version is None else [str(version)]
    return root + "/".join([_QUERY_OBJECT_DIR, name, *version_part])
64
65
66 1
class TabPyState:
    """
    The TabPy state object that stores attributes about this TabPy
    instance and performs GET/SET operations on them.

    Attributes:
        - name
        - description
        - endpoints (name, description, docstring, version, target)
        - revision number

    The state is held in a ConfigParser (``self.config``); every attribute
    is stored as an option in one of the state-file sections. Mutating
    methods persist the whole ConfigParser via ``write_state_config``.
    """

    def __init__(self, settings, config=None):
        # Passed to write_state_config whenever the state is persisted.
        self.settings = settings
        # Validate and install the initial state without writing it back.
        self.set_config(config, _update=False)

    @state_lock
    def set_config(self, config, logger=logging.getLogger(__name__), _update=True):
        """
        Set the local ConfigParser manually.

        This new ConfigParser will be used as the current state.

        Parameters
        ----------
        config : ConfigParser
            The new state.
        logger : logging.Logger, optional
            Logger used while persisting the state.
        _update : bool, optional
            When True (default), the new state is written out immediately.

        Raises
        ------
        ValueError
            If ``config`` is not a ConfigParser.
        """
        if not isinstance(config, ConfigParser):
            raise ValueError("Invalid config")
        self.config = config
        if _update:
            self._write_state(logger)

    @staticmethod
    def _decode_docstring(docstring):
        """
        Interpret backslash escape sequences (such as ``\\n``) in ``docstring``.

        ``unicode_escape`` decodes every non-escape byte as Latin-1, so the
        text is encoded as Latin-1 first, with characters outside Latin-1
        turned into ``\\uXXXX`` escapes. This keeps non-ASCII text intact;
        encoding as UTF-8 instead would split each non-ASCII character into
        several Latin-1 characters.
        """
        return docstring.encode("latin-1", "backslashreplace").decode(
            "unicode_escape"
        )

    def get_endpoints(self, name=None):
        """
        Return a dictionary of endpoints.

        Parameters
        ----------
        name : str
            The name of the endpoint.
            If "name" is specified, only the information about that endpoint
            will be returned.

        Returns
        -------
        endpoints : dict
            The dictionary containing information about each endpoint.
            The keys are the endpoint names.
            The values for each include:
                - description
                - doc string
                - type
                - target
            An empty dict is returned if the lookup fails (for example, the
            named endpoint or the deployment section does not exist).
        """
        try:
            # With a name this is that endpoint's JSON blob; without one it is
            # the list of endpoint (option) names in the deployment section.
            raw = self._get_config_value(_DEPLOYMENT_SECTION_NAME, name)
        except Exception as e:
            logger.error(f"error in get_endpoints: {str(e)}")
            return {}

        if name:
            endpoints = {name: self._load_endpoint_info(name, raw)}
        else:
            endpoints = {
                endpoint_name: self._load_endpoint_info(
                    endpoint_name,
                    self._get_config_value(_DEPLOYMENT_SECTION_NAME, endpoint_name),
                )
                for endpoint_name in raw
            }
        logger.debug(f"Collected endpoints: {endpoints}")
        return endpoints

    def _load_endpoint_info(self, name, endpoint_json):
        """Parse an endpoint's stored JSON and attach its decoded docstring."""
        endpoint_info = json.loads(endpoint_json)
        # Not every endpoint has a docstring entry; fall back to "" instead of
        # failing (used for both single-endpoint and full listings).
        docstring = self._get_config_value(_QUERY_OBJECT_DOCSTRING, name, True, "")
        endpoint_info["docstring"] = self._decode_docstring(docstring)
        return endpoint_info

    def _check_endpoint_exists(self, name):
        """
        Return True if an endpoint called ``name`` is already registered.

        Raises
        ------
        ValueError
            If ``name`` is not a non-empty string.
        """
        # Validate before reading the state so bad input fails fast.
        if not name or not isinstance(name, str):
            raise ValueError("name of the endpoint must be a valid string.")
        return name in self.get_endpoints()

    def _check_and_set_endpoint_str_value(self, param, paramName, defaultValue):
        """
        Validate a string endpoint attribute.

        Returns ``defaultValue`` when ``param`` is falsy and a default is
        given. Otherwise ``param`` must be a non-empty string.

        Raises
        ------
        ValueError
            If ``param`` is falsy without a default, or is not a string.
        """
        if not param and defaultValue is not None:
            return defaultValue

        if not param or not isinstance(param, str):
            raise ValueError(f"{paramName} must be a string.")

        return param

    def _check_and_set_endpoint_description(self, description, defaultValue):
        """Validate ``description``; see _check_and_set_endpoint_str_value."""
        return self._check_and_set_endpoint_str_value(description, "description", defaultValue)

    def _check_and_set_endpoint_docstring(self, docstring, defaultValue):
        """Validate ``docstring``; see _check_and_set_endpoint_str_value."""
        return self._check_and_set_endpoint_str_value(docstring, "docstring", defaultValue)

    def _check_and_set_endpoint_type(self, endpoint_type, defaultValue):
        """Validate ``endpoint_type``; see _check_and_set_endpoint_str_value."""
        return self._check_and_set_endpoint_str_value(
            endpoint_type, "endpoint type", defaultValue)

    def _check_target(self, target):
        """Raise ValueError if ``target`` is given (truthy) but not a string."""
        if target and not isinstance(target, str):
            raise ValueError("target must be a string.")

    def _check_and_set_dependencies(self, dependencies, defaultValue):
        """
        Validate an endpoint's dependency list.

        Returns ``defaultValue`` when ``dependencies`` is falsy, otherwise
        ``dependencies`` itself.

        Raises
        ------
        ValueError
            If ``dependencies`` is given but is not a list.
        """
        if not dependencies:
            return defaultValue

        # Only the type matters here: a truthy list is valid.
        if not isinstance(dependencies, list):
            raise ValueError("dependencies must be a list.")

        return dependencies

    def _check_and_set_is_public(self, is_public, defaultValue):
        """Return ``is_public``, or ``defaultValue`` when it is None."""
        if is_public is None:
            return defaultValue

        return is_public

    @state_lock
    def add_endpoint(
        self,
        name,
        description=None,
        docstring=None,
        endpoint_type=None,
        methods=None,
        target=None,
        dependencies=None,
        schema=None,
        is_public=None,
    ):
        """
        Add a new endpoint to the TabPy.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str, optional
            Description of this endpoint
        docstring : str, optional
            The doc string for this endpoint, if needed.
        endpoint_type : str
            The endpoint type (model, alias)
        methods : optional
            Accepted for API compatibility; not used.
        target : str, optional
            The target endpoint name for the alias to be added.
        dependencies : list, optional
            Names of endpoints this endpoint depends on.
        schema : optional
            Stored as-is with the endpoint.
        is_public : bool, optional
            Defaults to False.

        Raises
        ------
        ValueError
            If the endpoint already exists or an argument is invalid.

        Note:
        The version of this endpoint will be set to 1 since it is a new
        endpoint.
        """
        try:
            if self._check_endpoint_exists(name):
                raise ValueError(f"endpoint {name} already exists.")

            endpoints = self.get_endpoints()

            description = self._check_and_set_endpoint_description(description, "")
            docstring = self._check_and_set_endpoint_docstring(
                docstring, "-- no docstring found in query function --")
            endpoint_type = self._check_and_set_endpoint_type(endpoint_type, None)
            dependencies = self._check_and_set_dependencies(dependencies, [])
            is_public = self._check_and_set_is_public(is_public, False)

            self._check_target(target)
            if target and target not in endpoints:
                raise ValueError("target endpoint is not valid.")

            # One timestamp so creation and modification times agree.
            now = int(time())
            endpoint_info = {
                "description": description,
                "docstring": docstring,
                "type": endpoint_type,
                "version": 1,
                "dependencies": dependencies,
                "target": target,
                "creation_time": now,
                "last_modified_time": now,
                "schema": schema,
                "is_public": is_public,
            }

            # Persist only the new endpoint: rewriting all endpoints bumps the
            # revision and writes the state file once per endpoint.
            self._add_update_endpoints_config({name: endpoint_info})
        except Exception as e:
            logger.error(f"Error in add_endpoint: {e}")
            raise

    def _add_update_endpoints_config(self, endpoints):
        """
        Save ``endpoints`` (endpoint name -> info dict) into the state config.

        Each docstring is escape-decoded and stored in the docstring section
        without bumping the revision. The remaining fields are stored as JSON
        in the deployment section, which bumps the revision and persists the
        state. The caller's dictionaries are not modified.
        """
        try:
            for endpoint_name, info in endpoints.items():
                self._set_config_value(
                    _QUERY_OBJECT_DOCSTRING,
                    endpoint_name,
                    self._decode_docstring(info["docstring"]),
                    _update_revision=False,
                )
                # The docstring lives in its own section, not in the JSON.
                info_without_docstring = {
                    key: value for key, value in info.items() if key != "docstring"
                }
                self._set_config_value(
                    _DEPLOYMENT_SECTION_NAME,
                    endpoint_name,
                    json.dumps(info_without_docstring),
                )
        except Exception as e:
            logger.error(f"Unable to write endpoints config: {e}")
            raise

    @state_lock
    def update_endpoint(
        self,
        name,
        description=None,
        docstring=None,
        endpoint_type=None,
        version=None,
        methods=None,
        target=None,
        dependencies=None,
        schema=None,
        is_public=None,
    ):
        """
        Update an existing endpoint on the TabPy.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str, optional
            Description of this endpoint
        docstring : str, optional
            The doc string for this endpoint, if needed.
        endpoint_type : str, optional
            The endpoint type (model, alias)
        version : int, optional
            The version of this endpoint
        methods : optional
            Accepted for API compatibility; not used.
        dependencies : list, optional
            List of dependent endpoints for this existing endpoint
        target : str, optional
            The target endpoint name for the alias.
        schema : optional
            Replaces the stored schema (None clears it).
        is_public : bool, optional
            Whether the endpoint is public.

        Raises
        ------
        ValueError
            If the endpoint does not exist or an argument is invalid.

        Note:
        For those parameters that are not specified, those values will not
        get changed (except ``schema``, which is always replaced).
        """
        try:
            if not self._check_endpoint_exists(name):
                raise ValueError(f"endpoint {name} does not exist.")

            endpoints = self.get_endpoints()
            current = endpoints[name]

            description = self._check_and_set_endpoint_description(
                description, current["description"])
            docstring = self._check_and_set_endpoint_docstring(
                docstring, current["docstring"])
            endpoint_type = self._check_and_set_endpoint_type(
                endpoint_type, current["type"])
            dependencies = self._check_and_set_dependencies(
                dependencies, current.get("dependencies", []))
            # Endpoints created before is_public existed have no such key, so
            # default to False. ``current`` is a dict: the stored flag must be
            # read with .get(); getattr() would always return the default.
            is_public = self._check_and_set_is_public(
                is_public, current.get("is_public", False))

            self._check_target(target)
            if target and target not in endpoints:
                raise ValueError("target endpoint is not valid.")
            elif not target:
                target = current["target"]

            if version and not isinstance(version, int):
                raise ValueError("version must be an int.")
            elif not version:
                version = current["version"]

            updated = {
                "description": description,
                "docstring": docstring,
                "type": endpoint_type,
                "version": version,
                "dependencies": dependencies,
                "target": target,
                "creation_time": current["creation_time"],
                "last_modified_time": int(time()),
                "schema": schema,
                "is_public": is_public,
            }

            # Persist only the changed endpoint (one revision bump).
            self._add_update_endpoints_config({name: updated})
        except Exception as e:
            logger.error(f"Error in update_endpoint: {e}")
            raise

    @state_lock
    def delete_endpoint(self, name):
        """
        Delete an existing endpoint on the TabPy

        Parameters
        ----------
        name : str
            The name of the endpoint to be deleted.

        Returns
        -------
        deleted endpoint object

        Raises
        ------
        ValueError
            If ``name`` is empty, the endpoint does not exist, other
            endpoints list it as a dependency, or removal from the state
            fails.

        Note:
        Cannot delete this endpoint if other endpoints are currently
        depending on this endpoint.
        """
        if not name:
            raise ValueError("Name of the endpoint must be a valid string.")
        endpoints = self.get_endpoints()
        if name not in endpoints:
            raise ValueError(f"Endpoint {name} does not exist.")

        endpoint_to_delete = endpoints[name]

        # Other endpoints that list this one in their "dependencies".
        # NOTE(review): endpoints whose "target" is this endpoint (aliases)
        # are not checked here.
        deps = {
            endpoint_name
            for endpoint_name, info in endpoints.items()
            if endpoint_name != name and name in info.get("dependencies", [])
        }
        if deps:
            raise ValueError(
                f"Cannot remove endpoint {name}, it is currently "
                f"used by {list(deps)} endpoints."
            )

        # delete the endpoint from state
        try:
            # Only the second removal bumps the revision number.
            self._remove_config_option(
                _QUERY_OBJECT_DOCSTRING, name, _update_revision=False
            )
            self._remove_config_option(_DEPLOYMENT_SECTION_NAME, name)

            return endpoint_to_delete
        except Exception as e:
            logger.error(f"Unable to delete endpoint {e}")
            raise ValueError(f"Unable to delete endpoint: {e}") from e

    @property
    def name(self):
        """
        Returns the name of the TabPy service, or None if it is not set.
        """
        name = None
        try:
            name = self._get_config_value(_SERVICE_INFO_SECTION_NAME, "Name")
        except Exception as e:
            logger.error(f"Unable to get name: {e}")
        return name

    @property
    def creation_time(self):
        """
        Returns the creation time of the TabPy service as stored in the
        state (a string), or 0 if it is not available.
        """
        creation_time = 0
        try:
            creation_time = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Creation Time"
            )
        except Exception as e:
            logger.error(f"Unable to get creation time: {e}")
        return creation_time

    @state_lock
    def set_name(self, name):
        """
        Set the name of this TabPy service.

        Parameters
        ----------
        name : str
            Name of TabPy service.

        Raises
        ------
        ValueError
            If ``name`` is not a string. Errors while storing are logged,
            not raised.
        """
        if not isinstance(name, str):
            raise ValueError("name must be a string.")
        try:
            self._set_config_value(_SERVICE_INFO_SECTION_NAME, "Name", name)
        except Exception as e:
            logger.error(f"Unable to set name: {e}")

    def get_description(self):
        """
        Returns the description of the TabPy service, or None if not set.
        """
        description = None
        try:
            description = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Description"
            )
        except Exception as e:
            logger.error(f"Unable to get description: {e}")
        return description

    @state_lock
    def set_description(self, description):
        """
        Set the description of this TabPy service.

        Parameters
        ----------
        description : str
            Description of TabPy service.

        Raises
        ------
        ValueError
            If ``description`` is not a string. Errors while storing are
            logged, not raised.
        """
        if not isinstance(description, str):
            raise ValueError("Description must be a string.")
        try:
            self._set_config_value(
                _SERVICE_INFO_SECTION_NAME, "Description", description
            )
        except Exception as e:
            logger.error(f"Unable to set description: {e}")

    def get_revision_number(self):
        """
        Returns the revision number of this TabPy service, or -1 if it
        cannot be read.
        """
        rev = -1
        try:
            rev = int(self._get_config_value(_META_SECTION_NAME, "Revision Number"))
        except Exception as e:
            logger.error(f"Unable to get revision number: {e}")
        return rev

    def get_access_control_allow_origin(self):
        """
        Returns Access-Control-Allow-Origin of this TabPy service ("" if unset).
        """
        _cors_origin = ""
        try:
            logger.debug("Collecting Access-Control-Allow-Origin from state file ...")
            _cors_origin = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Access-Control-Allow-Origin"
            )
        except Exception as e:
            logger.error(e)
        return _cors_origin

    def get_access_control_allow_headers(self):
        """
        Returns Access-Control-Allow-Headers of this TabPy service ("" if unset).
        """
        _cors_headers = ""
        try:
            _cors_headers = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Access-Control-Allow-Headers"
            )
        except Exception as e:
            # Optional setting: fall back to "" but leave a trace.
            logger.debug(f"Access-Control-Allow-Headers not available: {e}")
        return _cors_headers

    def get_access_control_allow_methods(self):
        """
        Returns Access-Control-Allow-Methods of this TabPy service ("" if unset).
        """
        _cors_methods = ""
        try:
            _cors_methods = self._get_config_value(
                _SERVICE_INFO_SECTION_NAME, "Access-Control-Allow-Methods"
            )
        except Exception as e:
            # Optional setting: fall back to "" but leave a trace.
            logger.debug(f"Access-Control-Allow-Methods not available: {e}")
        return _cors_methods

    def _set_revision_number(self, revision_number):
        """
        Set the revision number of this TabPy service.

        Raises
        ------
        ValueError
            If ``revision_number`` is not an int. Errors while storing are
            logged, not raised.
        """
        if not isinstance(revision_number, int):
            raise ValueError("revision number must be an int.")
        try:
            # ConfigParser.set only accepts strings, and the default revision
            # bump in _set_config_value would immediately overwrite the value
            # being set with revision_number + 1.
            self._set_config_value(
                _META_SECTION_NAME,
                "Revision Number",
                str(revision_number),
                _update_revision=False,
            )
        except Exception as e:
            logger.error(f"Unable to set revision number: {e}")

    def _remove_config_option(
        self,
        section_name,
        option_name,
        logger=logging.getLogger(__name__),
        _update_revision=True,
    ):
        """
        Remove ``option_name`` from ``section_name`` and persist the state.

        The revision number is incremented first unless ``_update_revision``
        is False.
        """
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        self.config.remove_option(section_name, option_name)
        # update revision number
        if _update_revision:
            self._increase_revision_number()
        self._write_state(logger=logger)

    def _has_config_value(self, section_name, option_name):
        """Return True if ``section_name`` has option ``option_name``."""
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        return self.config.has_option(section_name, option_name)

    def _increase_revision_number(self):
        """Increment the stored revision number by one (not persisted here)."""
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        cur_rev = int(self.config.get(_META_SECTION_NAME, "Revision Number"))
        self.config.set(_META_SECTION_NAME, "Revision Number", str(cur_rev + 1))

    def _set_config_value(
        self,
        section_name,
        option_name,
        option_value,
        logger=logging.getLogger(__name__),
        _update_revision=True,
    ):
        """
        Set ``option_name`` in ``section_name`` and persist the state.

        The section is created if missing. The revision number is incremented
        unless ``_update_revision`` is False. ``option_value`` is passed
        directly to ``ConfigParser.set``.
        """
        if not self.config:
            raise ValueError("State configuration not yet loaded.")

        if not self.config.has_section(section_name):
            logger.log(logging.DEBUG, f"Adding config section {section_name}")
            self.config.add_section(section_name)

        self.config.set(section_name, option_name, option_value)
        # update revision number
        if _update_revision:
            self._increase_revision_number()
        self._write_state(logger=logger)

    def _get_config_items(self, section_name):
        """Return the (name, value) pairs of ``section_name``."""
        if not self.config:
            raise ValueError("State configuration not yet loaded.")
        return self.config.items(section_name)

    def _get_config_value(
        self, section_name, option_name, optional=False, default_value=None
    ):
        """
        Read a value from the state config.

        Parameters
        ----------
        section_name : str
            Section to read from.
        option_name : str or None
            Option to read. When falsy, the list of option names in the
            section is returned instead.
        optional : bool, optional
            When True, return ``default_value`` if the option is missing.
        default_value : optional
            Value returned for a missing optional option.

        Raises
        ------
        ValueError
            If the config is not loaded, or a required option is missing.
        """
        logger.log(
            logging.DEBUG,
            f"Loading option '{option_name}' from section [{section_name}]...")

        if not self.config:
            msg = "State configuration not yet loaded."
            # logging.log() requires a level; use the module logger instead.
            logger.error(msg)
            raise ValueError(msg)

        res = None
        if not option_name:
            res = self.config.options(section_name)
        elif self.config.has_option(section_name, option_name):
            res = self.config.get(section_name, option_name)
        elif optional:
            res = default_value
        else:
            raise ValueError(
                f"Cannot find option name {option_name} "
                f"under section {section_name}"
            )

        logger.log(logging.DEBUG, f"Returning value '{res}'")
        return res

    def _write_state(self, logger=logging.getLogger(__name__)):
        """
        Persist the state (ConfigParser) via ``write_state_config``, using
        ``self.settings``.
        """
        logger.log(logging.INFO, "Writing state to config")
        write_state_config(self.config, self.settings, logger=logger)
659