Passed
Push — main ( 60479c...7a4347 )
by Eran
02:10 queued 28s
created

graphinate.converters.decode()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 6
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
import ast
2
import base64
3
import decimal
4
import math
5
from types import MappingProxyType
6
from typing import Any, NewType, Union
7
8
import strawberry
9
10
from .constants import DEFAULT_EDGE_DELIMITER, DEFAULT_NODE_DELIMITER
11
12
# Names exported by ``from <module> import *``, listed alphabetically.
# NOTE(review): ``encode`` and ``decode`` are also defined in this module
# without a leading underscore but are not listed here — confirm intent.
__all__ = [
    'InfNumber', 'decode_edge_id', 'decode_id',
    'edge_label_converter', 'encode_edge_id', 'encode_id',
    'infnum_to_value', 'label_converter', 'node_label_converter',
    'value_to_infnum',
]
24
25
# A number that may also be one of the IEEE-754 infinities.
# NOTE(review): static type checkers reject NewType over a Union; it is kept
# unchanged here because other modules may call InfNumber(...) at runtime.
InfNumber = NewType("InfNumber", Union[float, int, decimal.Decimal])

# Textual spellings of infinity accepted on input -> float infinities.
INFINITY_MAPPING = MappingProxyType({
    'Infinity': math.inf,
    '+Infinity': math.inf,
    '-Infinity': -math.inf,
})

# Float infinities -> the text emitted on output.
MATH_INF_MAPPING = MappingProxyType({
    math.inf: 'Infinity',
    -math.inf: '-Infinity',
})


def value_to_infnum(value: Any) -> InfNumber:
    """Convert a textual infinity into a float infinity.

    Args:
        value: Any value. The strings ``'Infinity'``, ``'+Infinity'`` and
            ``'-Infinity'`` are mapped to ``math.inf`` / ``-math.inf``.

    Returns:
        The matching float infinity, or ``value`` unchanged when it is not
        one of those strings.
    """
    # Every key of INFINITY_MAPPING is a str. Checking the type first keeps
    # the lookup from raising TypeError on unhashable input (e.g. a list).
    if isinstance(value, str):
        return INFINITY_MAPPING.get(value, value)
    return value


def infnum_to_value(value: InfNumber) -> Union[InfNumber, str]:
    """Convert a float infinity into its textual form.

    Args:
        value: Any value. Values equal to ``math.inf`` / ``-math.inf``
            (including ``decimal.Decimal('Infinity')``) are mapped to
            ``'Infinity'`` / ``'-Infinity'``.

    Returns:
        The textual infinity, or ``value`` unchanged otherwise.
    """
    try:
        return MATH_INF_MAPPING.get(value, value)
    except TypeError:
        # Unhashable values cannot be infinities; pass them through.
        return value
45
46
47
def label_converter(value, delimiter: str):
    """Render an identifier as a display label.

    Args:
        value: The identifier. A tuple is rendered by joining ``str()`` of
            each element with ``delimiter``; any other value is rendered
            with ``str()``.
        delimiter: Separator placed between tuple elements.

    Returns:
        The label string. Empty non-numeric values (``None``, ``''``,
        ``()``, ``[]`` ...) are returned unchanged. Numbers are always
        rendered, so ``0`` becomes ``'0'`` like any other number (``False``,
        being an ``int``, becomes ``'False'``).
    """
    # A bare truthiness test would pass numeric zero through as a number:
    # node id 0 would get the int label 0 while node id 1 gets '1'.
    if not value and not isinstance(value, (int, float, decimal.Decimal)):
        return value

    if isinstance(value, tuple):
        return delimiter.join(str(v) for v in value)
    return str(value)
52
53
54
def node_label_converter(value):
    """Render a node identifier as a label.

    Delegates to ``label_converter`` using ``DEFAULT_NODE_DELIMITER`` as the
    separator between the parts of a tuple identifier.
    """
    separator = DEFAULT_NODE_DELIMITER
    return label_converter(value, separator)
56
57
58
def edge_label_converter(value):
    """Render an edge as a label.

    Each endpoint in ``value`` is first rendered with
    ``node_label_converter``; the endpoint labels are then joined with
    ``DEFAULT_EDGE_DELIMITER`` via ``label_converter``.
    """
    endpoint_labels = tuple(map(node_label_converter, value))
    return label_converter(endpoint_labels, delimiter=DEFAULT_EDGE_DELIMITER)
60
61
62
def encode(value: Any, encoding: str = 'utf-8') -> str:
    """Serialize ``value`` into a URL-safe base64 string.

    ``repr(value)`` is encoded to bytes with ``encoding``, then base64-encoded
    with the URL-safe alphabet. The companion ``decode`` parses the text back
    with ``ast.literal_eval``, so only values whose ``repr`` is a Python
    literal round-trip.

    Args:
        value: The object to serialize.
        encoding: Text encoding applied to ``repr(value)``.

    Returns:
        An ASCII-only, URL-safe base64 string.
    """
    payload: bytes = repr(value).encode(encoding)
    # Base64 output is always ASCII. Decoding it with the caller's encoding
    # (e.g. 'utf-16') would turn the ID into non-URL-safe garbage text.
    return base64.urlsafe_b64encode(payload).decode('ascii')
68
69
70
def decode(value: str, encoding: str = 'utf-8') -> Any:
    """Reverse ``encode``: base64-decode ``value`` and parse it as a literal.

    Args:
        value: URL-safe base64 text.
        encoding: Encoding used for the text both before and after the base64
            step.

    Returns:
        The Python literal (number, string, tuple, ...) that was encoded.

    Raises:
        May raise ``binascii.Error`` for malformed base64,
        ``UnicodeDecodeError`` for bytes invalid in ``encoding``, and
        ``ValueError`` / ``SyntaxError`` when the text is not a literal.

    NOTE(review): IDs may come from API clients. ``ast.literal_eval`` does
    not execute code, but very large or deeply nested input can still
    exhaust memory or the recursion limit. Consider bounding ``len(value)``.
    """
    literal_text = base64.urlsafe_b64decode(value.encode(encoding)).decode(encoding)
    return ast.literal_eval(literal_text)
76
77
78
def encode_id(graph_node_id: tuple,
              encoding: str = 'utf-8') -> str:
    """Encode a graph node identifier as an opaque GraphQL ID string.

    Thin wrapper around ``encode``, which base64-encodes ``repr`` of the id.
    """
    opaque_id: str = encode(value=graph_node_id, encoding=encoding)
    return opaque_id
81
82
83
def decode_id(graphql_node_id: strawberry.ID,
              encoding: str = 'utf-8') -> tuple[str, ...]:
    """Decode a GraphQL ID produced by ``encode_id`` back into a node id.

    Thin wrapper around ``decode``.

    NOTE(review): the return annotation says ``tuple[str, ...]``, but the
    result is whatever literal was encoded. It may hold non-str elements or
    not be a tuple at all. Confirm against callers.
    """
    node_id = decode(value=graphql_node_id, encoding=encoding)
    return node_id
86
87
88
def encode_edge_id(edge: tuple, encoding: str = 'utf-8'):
    """Encode an edge (a tuple of node ids) as an opaque GraphQL ID.

    Each endpoint is encoded separately with ``encode_id``. The resulting
    tuple of strings is then encoded once more as a whole.
    """
    endpoint_ids = [encode_id(node, encoding) for node in edge]
    return encode_id(tuple(endpoint_ids), encoding)
91
92
93
def decode_edge_id(graphql_edge_id: strawberry.ID, encoding: str = 'utf-8') -> tuple:
    """Decode an ID produced by ``encode_edge_id`` back into an edge tuple.

    The outer ID decodes to a tuple of per-endpoint IDs. Each of those is
    decoded in turn.

    Args:
        graphql_edge_id: The opaque edge ID.
        encoding: The same text encoding that was passed to ``encode_edge_id``.

    Returns:
        A tuple with one decoded node id per endpoint.
    """
    encoded_edge: tuple = decode_id(graphql_edge_id, encoding)
    # Endpoints were encoded with the same ``encoding`` as the outer ID, so it
    # must be forwarded here too; previously they were always decoded as
    # UTF-8, breaking the round trip for any other encoding.
    return tuple(decode_id(enc_node, encoding) for enc_node in encoded_edge)
96