Passed
Push — master ( bd4f18...4a0f6d )
by Max
41s
created

structured_data.enum._args()   A

Complexity

Conditions 3

Size

Total Lines 7
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 7
nop 2
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
"""Class decorator for defining algebraic data types."""
2
3
import ast
4
import sys
5
import typing
6
import weakref
7
8
import astor
9
10
from ._enum_constructor import make_constructor
11
from ._prewritten_methods import SUBCLASS_ORDER
12
from ._prewritten_methods import PrewrittenMethods
13
14
# Interning table: maps an argument tuple to the Ctor instance created for it.
_CTOR_CACHE = {}
15
16
17
# Weak mapping from a constructor marker (a Ctor instance, or Ctor itself) to
# the tuple of arguments it was indexed with.
ARGS = weakref.WeakKeyDictionary()
18
19
20
class Ctor:
    """Marker class for enum constructors.

    To use, index with a sequence of types, and annotate a variable in an
    enum-decorated class with it.

    Indexing with equal arguments yields the same instance every time, and
    indexing with the empty tuple yields ``Ctor`` itself. Subclassing raises
    ``TypeError``.
    """

    def __new__(cls, args):
        """Return the marker for ``args``, creating it on first use.

        ``args`` must be hashable because it is used as a cache key; an
        unhashable value raises ``TypeError``.
        """
        if args == ():
            return cls
        # Look the marker up first so a repeated index does not allocate and
        # register an instance that would be thrown away immediately.
        cached = _CTOR_CACHE.get(args)
        if cached is not None:
            return cached
        self = object.__new__(cls)
        ARGS[self] = args
        # setdefault returns whichever instance is stored for ``args``.
        return _CTOR_CACHE.setdefault(args, self)

    def __init_subclass__(cls, **kwargs):
        """Reject every attempt to subclass ``Ctor``."""
        raise TypeError('Ctor cannot be subclassed')

    def __class_getitem__(cls, args):
        """Support ``Ctor[T]`` and ``Ctor[T, U]`` by normalizing to a tuple."""
        if not isinstance(args, tuple):
            args = (args,)
        return cls(args)
41
42
43
# The bare Ctor class is what indexing with the empty tuple returns, so it is
# registered with an empty argument tuple.
ARGS[Ctor] = ()
44
45
46
def _interpret_args_from_non_string(constructor):
    """Return the argument tuple registered for ``constructor``, or ``None``.

    ``constructor`` is looked up in ``ARGS``. ``None`` is returned when it is
    not registered, and also when the lookup raises ``TypeError`` (a
    ``WeakKeyDictionary`` raises it for objects that cannot be weakly
    referenced).
    """
    try:
        return ARGS.get(constructor)
    except TypeError:
        return None
51
52
53
def _parse_constructor(constructor):
54
    try:
55
        return ast.parse(constructor, mode='eval')
56
    except Exception:
57
        raise ValueError('parsing annotation failed')
58
59
60
def _get_args_from_index(index):
    """Return the source text of each argument in a subscript index node.

    An ``ast.Tuple`` index yields one string per element; any other node is
    treated as a single argument and yields a one-element tuple. The text is
    whatever ``astor.to_source`` renders for the node.
    """
    if isinstance(index, ast.Tuple):
        elements = index.elts
    else:
        elements = [index]
    return tuple(astor.to_source(element) for element in elements)
64
65
66
def _checked_eval(source, global_ns):
67
    try:
68
        return eval(source, global_ns)
69
    except Exception:
70
        return None
71
72
73
def _plain_index(slice_node):
    """Return the expression node of a non-slicing subscript, else ``None``."""
    if sys.version_info < (3, 9):
        # Older parsers wrap a plain index in ``ast.Index``; the other slice
        # node kinds there describe real slicing.
        if isinstance(slice_node, ast.Index):
            return slice_node.value
        return None
    # From Python 3.9 the slice attribute is the index expression itself.
    if isinstance(slice_node, ast.Slice):
        return None
    if isinstance(slice_node, ast.Tuple) and any(
            isinstance(element, ast.Slice) for element in slice_node.elts):
        return None
    return slice_node


def _extract_tuple_ast(constructor, global_ns):
    """Work out constructor arguments from the annotation string.

    When the annotation has the form ``<expr>[<index>]`` and ``<expr>``
    evaluates to ``Ctor`` in ``global_ns``, the arguments are returned as
    source strings without being evaluated. When ``<expr>`` cannot be
    evaluated at all, ``None`` is returned. Otherwise the whole annotation is
    evaluated and looked up like a non-string constructor.

    Raises:
        ValueError: propagated from ``_parse_constructor`` when the
            annotation cannot be parsed.
    """
    ctor_ast = _parse_constructor(constructor)
    subscript = ctor_ast.body
    if isinstance(subscript, ast.Subscript):
        index = _plain_index(subscript.slice)
        if index is not None:
            # Reuse the parsed tree to evaluate only the subscripted value.
            ctor_ast.body = subscript.value
            value = _checked_eval(
                compile(ctor_ast, '<annotation>', 'eval'), global_ns)
            if value is Ctor:
                return _get_args_from_index(index)
            if value is None:
                return None
    return _interpret_args_from_non_string(_checked_eval(constructor, global_ns))
86
87
88
def _args(constructor, global_ns):
    """Return the constructor arguments described by an annotation, or ``None``.

    String annotations are parsed and resolved against ``global_ns``; a string
    that cannot be parsed yields ``None``. Any other object is looked up
    directly as a constructor marker.
    """
    if isinstance(constructor, str):
        try:
            return _extract_tuple_ast(constructor, global_ns)
        except ValueError:
            return None
    return _interpret_args_from_non_string(constructor)
95
96
97
def _name(cls, function) -> str:
98
    """Return the name of a function accessed through a descriptor."""
99
    return function.__get__(None, cls).__name__
100
101
102
def _set_new_functions(cls, *functions) -> typing.Optional[str]:
    """Attempt to set the attributes corresponding to the functions on cls.

    If any attributes are already defined, fail *before* setting any, and
    return the already-defined name. Only names defined directly on ``cls``
    (in ``cls.__dict__``) count as already defined. Returns ``None`` when
    every function was set.
    """
    # Resolve each attribute name once; it is needed for both passes.
    names = [_name(cls, function) for function in functions]
    for name in names:
        if name in cls.__dict__:
            return name
    for name, function in zip(names, functions):
        setattr(cls, name, function)
    return None
114
115
116
def _enum_super(_cls):
117
    def base(cls, args):
118
        return super(_cls, cls).__new__(cls, args)
119
    return base
120
121
122
def _make_nested_new(_cls, subclasses, base__new__):
123
    @staticmethod
124
    def __new__(cls, args):
125
        if cls not in subclasses:
126
            raise TypeError
127
        return base__new__(cls, args)
128
    return __new__
129
130
131
def _nillable_write(dct, key, value):
132
    if value is None:
133
        dct.pop(key, None)
134
    else:
135
        dct[key] = value
136
137
138
def _own_annotations(cls):
    """Return the annotations declared directly on ``cls``, never inherited."""
    if sys.version_info >= (3, 10):
        # From 3.10 a class without annotations reports its own empty dict.
        return getattr(cls, '__annotations__', {})
    # Before 3.10 attribute access would fall back to a base class's dict.
    return vars(cls).get('__annotations__', {})


def _process_class(_cls, _repr, eq, order):
    """Turn ``_cls`` into an enum class and return it.

    Collects constructor annotations from every class in the MRO of ``_cls``
    (base classes first, so later classes override or remove earlier
    entries), builds a constructor for each one, and installs the prewritten
    methods selected by the flags. Methods that ``_cls`` already defines are
    left alone, except that an existing ``__new__`` is wrapped.

    Args:
        _cls: the class being decorated; modified in place.
        _repr: install the prewritten ``__repr__``.
        eq: install the prewritten ``__eq__``/``__ne__`` and, when they were
            installed, ``__hash__``.
        order: install the prewritten ordering methods; requires ``eq``.

    Raises:
        ValueError: if ``order`` is true while ``eq`` is false, or if
            ``order`` is true but ``_cls`` defines its own equality methods.
        TypeError: if ``order`` is true and ``_cls`` already defines one of
            the ordering methods.
    """
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    args = {}
    subclasses = set()
    subclass_order = []
    for cls in reversed(_cls.__mro__):
        # Each class's annotations are resolved in the module that defines it.
        for key, value in _own_annotations(cls).items():
            _nillable_write(
                args, key, _args(value, vars(sys.modules[cls.__module__])))

    for name, args_ in args.items():
        make_constructor(_cls, name, args_, subclasses, subclass_order)

    SUBCLASS_ORDER[_cls] = tuple(subclass_order)

    _cls.__init_subclass__ = PrewrittenMethods.__init_subclass__

    # A truthy result means _cls defines __new__ itself: wrap that one instead.
    if _set_new_functions(
            _cls, _make_nested_new(_cls, subclasses, _enum_super(_cls))):
        _cls.__new__ = _make_nested_new(_cls, subclasses, _cls.__new__)

    _set_new_functions(
        _cls, PrewrittenMethods.__setattr__, PrewrittenMethods.__delattr__)
    _set_new_functions(_cls, PrewrittenMethods.__bool__)

    if _repr:
        _set_new_functions(_cls, PrewrittenMethods.__repr__)

    equality_methods_were_set = False

    if eq:
        equality_methods_were_set = not _set_new_functions(
            _cls, PrewrittenMethods.__eq__, PrewrittenMethods.__ne__)

    # The prewritten hash is installed only with the prewritten equality.
    if equality_methods_were_set:
        _cls.__hash__ = PrewrittenMethods.__hash__

    if order:
        if not equality_methods_were_set:
            raise ValueError(
                "Can't add ordering methods if equality methods are provided.")
        collision = _set_new_functions(
            _cls,
            PrewrittenMethods.__lt__,
            PrewrittenMethods.__le__,
            PrewrittenMethods.__gt__,
            PrewrittenMethods.__ge__
            )
        if collision:
            raise TypeError(
                'Cannot overwrite attribute {collision} in class '
                '{name}. Consider using functools.total_ordering'.format(
                    collision=collision, name=_cls.__name__))

    return _cls
194
195
196
def enum(_cls=None, *, repr=True, eq=True, order=False):
    """Decorate a class to be an algebraic data type.

    Usable both bare (``@enum``) and with keyword options
    (``@enum(order=True)``); in the second form a decorator is returned.

    Args:
        _cls: the class to process, or ``None`` when called with options only.
        repr: add the prewritten ``__repr__``.
        eq: add the prewritten equality methods.
        order: add the prewritten ordering methods; requires ``eq``.
    """

    def wrap(cls):
        """Return the processed class."""
        return _process_class(cls, repr, eq, order)

    if _cls is None:
        return wrap

    return wrap(_cls)
207
208
209
# Names exported by ``from <module> import *``.
__all__ = ['Ctor', 'enum']
210