Passed
Branch master (459d08)
by Max
46s
created

structured_data.enum._add_eq()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 2
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
"""Class decorator for defining abstract data types."""
2
3
import ast
4
import sys
5
import typing
6
import weakref
7
8
import astor
9
10
from ._enum_constructor import make_constructor
11
from ._prewritten_methods import SUBCLASS_ORDER
12
from ._prewritten_methods import PrewrittenMethods
13
14
# Interning table: maps a constructor argument tuple to the first Ctor
# instance created for it (see Ctor.__new__), so equal subscripts such as
# Ctor[int] and Ctor[int] yield the same object. It is a plain dict, and
# nothing in this module removes entries from it.
_CTOR_CACHE = {}
15
16
17
# Maps each Ctor marker to the argument tuple it was subscripted with.
# Keys are held weakly, so a marker that is never kept (e.g. a duplicate
# discarded by the _CTOR_CACHE lookup) drops out once garbage-collected.
ARGS = weakref.WeakKeyDictionary()
18
19
20
class Ctor:
    """Annotation marker describing a single enum constructor.

    Subscript it with the constructor's field types, e.g. ``Ctor[int, str]``,
    and use the result to annotate a name inside a class decorated with
    ``enum``. ``Ctor[()]`` evaluates to the bare ``Ctor`` class itself, and
    subscripting twice with equal arguments returns the same cached object.
    Subclassing ``Ctor`` is not allowed.
    """

    def __new__(cls, args):
        if args == ():
            # A constructor without fields is represented by the class itself.
            return cls
        marker = object.__new__(cls)
        ARGS[marker] = args
        # Keep the first marker ever built for these args; a losing duplicate
        # is discarded, and its weak ARGS entry disappears with it.
        return _CTOR_CACHE.setdefault(args, marker)

    def __init_subclass__(cls, **kwargs):
        raise TypeError

    def __class_getitem__(cls, args):
        # ``Ctor[x]`` hands over ``x`` itself; ``Ctor[x, y]`` hands over a tuple.
        packed = args if isinstance(args, tuple) else (args,)
        return cls(packed)
41
42
43
# The unsubscripted Ctor class (which is also what Ctor[()] returns) stands
# for a constructor with no arguments.
ARGS[Ctor] = ()
44
45
46
def _interpret_args_from_non_string(constructor):
    """Return the argument tuple registered in ARGS for *constructor*.

    Yields None when *constructor* is not a registered Ctor marker, including
    when it cannot be weakly referenced (such objects can never be ARGS keys,
    and the lookup raises TypeError for them).
    """
    try:
        registered = ARGS.get(constructor)
    except TypeError:
        registered = None
    return registered
51
52
53
def _parse_constructor(constructor):
54
    try:
55
        return ast.parse(constructor, mode='eval')
56
    except Exception:
57
        raise ValueError('parsing annotation failed')
58
59
60
def _get_args_from_index(index):
    """Convert a subscript index back into a tuple of source strings.

    A tuple index (``Ctor[a, b]``) yields one string per element. Any other
    index yields a one-element tuple. The text comes from ``astor.to_source``.
    """
    elements = index.elts if isinstance(index, ast.Tuple) else [index]
    return tuple(map(astor.to_source, elements))
64
65
66
def _checked_eval(source, global_ns):
67
    try:
68
        return eval(source, global_ns)
69
    except Exception:
70
        return None
71
72
73
def _subscript_index(subscript):
    """Return the plain index expression of an ``ast.Subscript``, or None.

    Python < 3.9 wraps a non-slice index in ``ast.Index``. Python 3.9+ stores
    the expression directly in ``subscript.slice`` and never creates
    ``ast.Index`` nodes. Slice forms (``x[a:b]``, or a tuple containing
    slices) yield None on every version.
    """
    index = subscript.slice
    if sys.version_info < (3, 9):
        # Anything other than ast.Index here is ast.Slice / ast.ExtSlice.
        return index.value if isinstance(index, ast.Index) else None
    if isinstance(index, ast.Slice):
        return None
    if isinstance(index, ast.Tuple) and any(
            isinstance(elt, ast.Slice) for elt in index.elts):
        return None
    return index


def _extract_tuple_ast(constructor, global_ns):
    """Return the constructor argument tuple described by an annotation string.

    When the annotation is a subscript whose subscripted object evaluates to
    ``Ctor``, the index elements are returned as source strings (via
    ``_get_args_from_index``) and the index itself is never evaluated. When
    that object evaluates to None (including when evaluating it raises),
    None is returned. In every other case the whole string is evaluated in
    *global_ns* and looked up as a non-string constructor.

    Raises:
        ValueError: from ``_parse_constructor`` if the string is not a valid
            expression.
    """
    ctor_ast = _parse_constructor(constructor)
    body = ctor_ast.body
    if isinstance(body, ast.Subscript):
        index = _subscript_index(body)
        if index is not None:
            # Evaluate only the subscripted object; the index is kept as text.
            ctor_ast.body = body.value
            value = _checked_eval(
                compile(ctor_ast, '<annotation>', 'eval'), global_ns)
            if value is Ctor:
                return _get_args_from_index(index)
            if value is None:
                return None
    return _interpret_args_from_non_string(
        _checked_eval(constructor, global_ns))
86
87
88
def _args(constructor, global_ns):
    """Return the constructor argument tuple for an annotation, or None.

    Non-string annotations are looked up directly. String annotations are
    parsed and resolved against *global_ns*. A string that fails to parse
    yields None rather than an error.
    """
    if not isinstance(constructor, str):
        return _interpret_args_from_non_string(constructor)
    try:
        return _extract_tuple_ast(constructor, global_ns)
    except ValueError:
        return None
95
96
97
def _name(cls, function) -> str:
98
    """Return the name of a function accessed through a descriptor."""
99
    return function.__get__(None, cls).__name__
100
101
102
def _set_new_functions(cls, *functions) -> typing.Optional[str]:
    """Install every one of *functions* on *cls* under its own name, or none.

    If *cls* already defines (in its own ``__dict__``) any of the names,
    nothing is installed and the first clashing name is returned. Otherwise
    all functions are set and None is returned.
    """
    own_attrs = cls.__dict__
    for candidate in functions:
        clash = _name(cls, candidate)
        if clash in own_attrs:
            return clash
    for candidate in functions:
        setattr(cls, _name(cls, candidate), candidate)
    return None
114
115
116
def _enum_super(_cls):
117
    def base(cls, args):
118
        return super(_cls, cls).__new__(cls, args)
119
    return base
120
121
122
def _make_nested_new(_cls, subclasses, base__new__):
123
    @staticmethod
124
    def __new__(cls, args):
125
        if cls not in subclasses:
126
            raise TypeError
127
        return base__new__(cls, args)
128
    return __new__
129
130
131
def _nillable_write(dct, key, value):
132
    if value is None:
133
        dct.pop(key, None)
134
    else:
135
        dct[key] = value
136
137
138
def _add_repr(cls, set_repr):
    """Install the prewritten ``__repr__`` on *cls* when *set_repr* is true.

    A ``__repr__`` already defined directly on *cls* is left untouched.
    """
    if not set_repr:
        return
    _set_new_functions(cls, PrewrittenMethods.__repr__)
141
142
143
def _add_eq(cls, set_eq):
    """Install the prewritten ``__eq__``/``__ne__`` pair on *cls* if requested.

    Returns True only when both methods were installed. Returns False when
    *set_eq* is false, or when *cls* already defines either method directly
    (in which case neither is installed).
    """
    if not set_eq:
        return False
    clash = _set_new_functions(
        cls, PrewrittenMethods.__eq__, PrewrittenMethods.__ne__)
    return not clash
149
150
151
def _own_annotations(cls):
    """Return the annotations declared in *cls*'s own body (``{}`` if none).

    On Python 3.9 and older, ``cls.__annotations__`` on a class that declares
    no annotations itself resolves to a base class's dict. Reading it for
    every class in an MRO would then re-evaluate a base's annotations against
    the wrong module's globals, which can turn them into None and drop them.
    Reading the class ``__dict__`` avoids that. From Python 3.10 the attribute
    is per-class, so the plain attribute lookup is kept there.
    """
    if sys.version_info < (3, 10):
        return cls.__dict__.get('__annotations__', {})
    return getattr(cls, '__annotations__', {})


def _process_class(_cls, _repr, eq, order):
    """Turn *_cls* into an algebraic data type in place and return it.

    Args:
        _cls: the class being decorated.
        _repr: whether to install the prewritten ``__repr__``.
        eq: whether to install the prewritten ``__eq__``/``__ne__``. The
            prewritten ``__hash__`` is added only when both were installed.
        order: whether to install the prewritten ``__lt__``/``__le__``/
            ``__gt__``/``__ge__``.

    Raises:
        ValueError: if *order* is requested without *eq*, or if *order* is
            requested while the class already defines ``__eq__``/``__ne__``.
        TypeError: if *order* is requested and an ordering method is already
            defined on the class.
        The last two errors are raised after the class has already been
        partially modified.
    """
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    # Walk the MRO from ``object`` down to ``_cls`` so that a more derived
    # class's annotation for a name replaces a base's. An annotation that is
    # not a constructor (``_args`` returns None) removes the name again.
    args = {}
    subclasses = set()
    subclass_order = []
    for cls in reversed(_cls.__mro__):
        annotations = _own_annotations(cls)
        if not annotations:
            continue
        # String annotations are resolved in the declaring class's module.
        global_ns = vars(sys.modules[cls.__module__])
        for key, value in annotations.items():
            _nillable_write(args, key, _args(value, global_ns))

    # NOTE(review): make_constructor presumably fills ``subclasses`` and
    # ``subclass_order`` with the generated constructor classes; confirm in
    # ._enum_constructor.
    for name, args_ in args.items():
        make_constructor(_cls, name, args_, subclasses, subclass_order)

    SUBCLASS_ORDER[_cls] = tuple(subclass_order)

    _cls.__init_subclass__ = PrewrittenMethods.__init_subclass__

    # Install a __new__ that rejects any class not in ``subclasses``. If the
    # class already defines __new__ itself, wrap that one instead.
    if _set_new_functions(
            _cls, _make_nested_new(_cls, subclasses, _enum_super(_cls))):
        _cls.__new__ = _make_nested_new(_cls, subclasses, _cls.__new__)

    _set_new_functions(
        _cls, PrewrittenMethods.__setattr__, PrewrittenMethods.__delattr__)
    _set_new_functions(_cls, PrewrittenMethods.__bool__)

    _add_repr(_cls, _repr)

    equality_methods_were_set = _add_eq(_cls, eq)

    # Supply the prewritten hash only alongside the prewritten equality
    # (presumably so hashing stays consistent with __eq__).
    if equality_methods_were_set:
        _cls.__hash__ = PrewrittenMethods.__hash__

    if order:
        if not equality_methods_were_set:
            raise ValueError(
                "Can't add ordering methods if equality methods are provided.")
        collision = _set_new_functions(
            _cls,
            PrewrittenMethods.__lt__,
            PrewrittenMethods.__le__,
            PrewrittenMethods.__gt__,
            PrewrittenMethods.__ge__,
        )
        if collision:
            raise TypeError(
                'Cannot overwrite attribute {collision} in class '
                '{name}. Consider using functools.total_ordering'.format(
                    collision=collision, name=_cls.__name__))

    return _cls
202
203
204
def enum(_cls=None, *, repr=True, eq=True, order=False):
    """Decorate a class to be an algebraic data type.

    Usable bare (``@enum``) or with keyword options
    (``@enum(repr=..., eq=..., order=...)``). The options are forwarded
    unchanged to ``_process_class``.
    """
    def wrap(cls):
        """Return the processed class."""
        return _process_class(cls, repr, eq, order)

    return wrap if _cls is None else wrap(_cls)
215
216
217
# Public API: the Ctor annotation marker and the enum class decorator.
__all__ = ['Ctor', 'enum']
218