Passed
Push — master ( 243c32...83a312 )
by Max
50s
created

structured_data.enum   A

Complexity

Total Complexity 29

Size/Duplication

Total Lines 147
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 95
dl 0
loc 147
rs 10
c 0
b 0
f 0
wmc 29

12 Functions

Rating   Name   Duplication   Size   Complexity  
A _set_new_functions() 0 12 4
A enum() 0 11 2
A _add_methods() 0 5 2
A _set_hash() 0 3 2
A _process_class() 0 30 4
A _name() 0 3 1
A _args_from_annotations() 0 7 3
A _make_nested_new() 0 7 2
A _nillable_write() 0 5 2
A _custom_new() 0 5 2
A _enum_super() 0 4 1
A _add_order() 0 17 4
1
"""Class decorator for defining abstract data types."""
2
3
import sys
4
import typing
5
6
from ._ctor import Ctor
7
from ._ctor import get_args
8
from ._enum_constructor import make_constructor
9
from ._prewritten_methods import SUBCLASS_ORDER
10
from ._prewritten_methods import PrewrittenMethods
11
12
13
def _name(cls, function) -> str:
    """Return the ``__name__`` of *function* after descriptor resolution on *cls*.

    The object is first unwrapped with ``__get__(None, cls)`` so that wrappers
    such as a ``staticmethod`` report the name of the function they wrap.
    """
    resolved = function.__get__(None, cls)
    return resolved.__name__
16
17
18
def _set_new_functions(cls, *functions) -> typing.Optional[str]:
    """Install each of *functions* on *cls* under its own name, all or nothing.

    The check runs before any attribute is written. If *cls* already defines
    (in its own ``__dict__``) the name of any of the functions, nothing is set
    and the first such name is returned. Otherwise every function is attached
    and ``None`` is returned.
    """
    clash = next(
        (attr for attr in (_name(cls, fn) for fn in functions)
         if attr in cls.__dict__),
        None)
    if clash is not None:
        return clash
    for fn in functions:
        setattr(cls, _name(cls, fn), fn)
    return None
30
31
32
def _enum_super(_cls):
    """Return ``base(cls, args)``, which creates *cls* via the class after *_cls*.

    ``super(_cls, cls)`` starts the lookup after ``_cls`` in ``cls``'s MRO, so
    any ``__new__`` defined on ``_cls`` itself is skipped.
    """
    def base(cls, args):
        parent = super(_cls, cls)
        return parent.__new__(cls, args)

    return base
36
37
38
def _make_nested_new(_cls, subclasses, base__new__):
    """Build a ``__new__`` that only permits instantiating registered subclasses.

    Args:
        _cls: the decorated class; used only to make the error message useful.
        subclasses: container of the classes that may be instantiated.
        base__new__: callable ``(cls, args)`` that actually creates the instance.

    Returns:
        A ``staticmethod`` suitable for assignment as ``__new__``. Calling it
        with a class not in *subclasses* raises ``TypeError``.
    """
    @staticmethod
    def __new__(cls, args):
        if cls not in subclasses:
            # A bare TypeError gave no hint which class was rejected or why.
            raise TypeError(
                '{} cannot be instantiated directly; it is not a registered '
                'subclass of {}'.format(cls.__qualname__, _cls.__qualname__))
        return base__new__(cls, args)

    return __new__
45
46
47
def _nillable_write(dct, key, value):
    """Store *value* under *key* in *dct*; a ``None`` value deletes the key.

    Deleting a key that is absent is a no-op.
    """
    if value is not None:
        dct[key] = value
        return
    dct.pop(key, None)
52
53
54
def _add_methods(cls, do_set, *methods):
    """Optionally attach *methods* to *cls*; report whether they were attached.

    Returns ``False`` when *do_set* is falsy. Also returns ``False`` when *cls*
    already defines one of the method names, in which case none are attached.
    Returns ``True`` only when every method was installed.
    """
    if not do_set:
        return False
    collision = _set_new_functions(cls, *methods)
    return not collision
59
60
61
def _set_hash(cls, set_hash):
    """Assign ``PrewrittenMethods.__hash__`` to *cls* when *set_hash* is truthy.

    The assignment is unconditional with respect to any existing ``__hash__``.
    """
    if not set_hash:
        return
    cls.__hash__ = PrewrittenMethods.__hash__
64
65
66
def _add_order(cls, set_order, equality_methods_were_set):
    """Attach ``__lt__``/``__le__``/``__gt__``/``__ge__`` when *set_order* is truthy.

    Raises:
        ValueError: ordering was requested but the decorator did not install
            the equality methods.
        TypeError: *cls* already defines one of the ordering methods; none of
            them are installed in that case.
    """
    if not set_order:
        return
    if not equality_methods_were_set:
        raise ValueError(
            "Can't add ordering methods if equality methods are provided.")
    comparisons = (
        PrewrittenMethods.__lt__,
        PrewrittenMethods.__le__,
        PrewrittenMethods.__gt__,
        PrewrittenMethods.__ge__,
    )
    collision = _set_new_functions(cls, *comparisons)
    if not collision:
        return
    raise TypeError(
        'Cannot overwrite attribute {collision} in class {name}. '
        'Consider using functools.total_ordering'.format(
            collision=collision, name=cls.__name__))
83
84
85
def _custom_new(cls, subclasses):
    """Install a ``__new__`` on *cls* that only admits classes in *subclasses*.

    If *cls* has no ``__new__`` of its own, the guard delegates to the
    ``__new__`` found after *cls* in the MRO. If *cls* does define one, that
    user-supplied ``__new__`` is wrapped by the guard instead.
    """
    default_new = _make_nested_new(cls, subclasses, _enum_super(cls))
    collision = _set_new_functions(cls, default_new)
    if not collision:
        return
    # cls already has its own __new__: keep it, but put the guard in front.
    cls.__new__ = _make_nested_new(cls, subclasses, cls.__new__)
90
91
92
def _own_annotations(klass):
    """Return the annotations declared directly on *klass*, ignoring bases.

    ``getattr(klass, '__annotations__', {})`` is not safe here: on Python 3.9
    and older, a class with no annotations of its own returns its base's
    ``__annotations__`` through normal attribute lookup.
    """
    try:
        from inspect import get_annotations  # Python 3.10+
    except ImportError:
        return klass.__dict__.get('__annotations__', {})
    return get_annotations(klass)


def _args_from_annotations(cls):
    """Collect argument specs from the annotations of every class in the MRO.

    The MRO is walked from the most basic class to *cls*, so annotations on
    more derived classes overwrite those of their bases. Each class
    contributes only the annotations it declares itself. Each annotation is
    converted with ``get_args``, using the globals of the module that defined
    the annotating class. A ``None`` result removes the key instead of
    storing it.

    Returns:
        dict mapping each annotation name to the value ``get_args`` produced.
    """
    args = {}
    for superclass in reversed(cls.__mro__):
        annotations = _own_annotations(superclass)
        if not annotations:
            continue
        # Resolve names in the defining module's namespace, looked up once per class.
        namespace = vars(sys.modules[superclass.__module__])
        for key, value in annotations.items():
            _nillable_write(args, key, get_args(value, namespace))
    return args
99
100
101
def _process_class(_cls, _repr, eq, order):
    """Convert *_cls* in place into an algebraic data type and return it.

    Steps, in order:
      1. reject ``order`` without ``eq``;
      2. pass every annotation collected from the MRO to ``make_constructor``
         and store the order list it fills in ``SUBCLASS_ORDER[_cls]``;
      3. install ``__init_subclass__`` and a guarded ``__new__``;
      4. add ``__setattr__``/``__delattr__`` (only if neither exists) and
         ``__bool__`` (only if absent);
      5. optionally add ``__repr__``, then ``__eq__``/``__ne__`` together with
         ``__hash__``, then the ordering methods.

    Raises:
        ValueError: ``order`` without ``eq``, or ordering requested while the
            class supplies its own equality methods.
        TypeError: an ordering method would overwrite an existing attribute.
    """
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    constructor_classes = set()
    constructor_order = []
    annotated = _args_from_annotations(_cls)
    for ctor_name, ctor_args in annotated.items():
        make_constructor(
            _cls, ctor_name, ctor_args, constructor_classes, constructor_order)
    SUBCLASS_ORDER[_cls] = tuple(constructor_order)

    _cls.__init_subclass__ = PrewrittenMethods.__init_subclass__
    _custom_new(_cls, constructor_classes)

    # Each group is all-or-nothing; an existing definition is left untouched.
    unconditional_groups = (
        (PrewrittenMethods.__setattr__, PrewrittenMethods.__delattr__),
        (PrewrittenMethods.__bool__,),
    )
    for group in unconditional_groups:
        _set_new_functions(_cls, *group)

    _add_methods(_cls, _repr, PrewrittenMethods.__repr__)
    eq_installed = _add_methods(
        _cls, eq, PrewrittenMethods.__eq__, PrewrittenMethods.__ne__)
    _set_hash(_cls, eq_installed)
    _add_order(_cls, order, eq_installed)

    return _cls
131
132
133
def enum(_cls=None, *, repr=True, eq=True, order=False):
    """Decorate a class to be an algebraic data type.

    Works both bare (``@enum``) and with options (``@enum(order=True)``).

    Args:
        _cls: the class being decorated when used bare; ``None`` otherwise.
        repr: add ``__repr__`` unless the class defines it.
        eq: add ``__eq__``/``__ne__`` (plus ``__hash__``) unless the class
            defines either of them.
        order: add ordering methods; requires ``eq``.

    Returns:
        The processed class, or a decorator that will process one.
    """
    def wrap(cls):
        """Return the processed class."""
        return _process_class(cls, repr, eq, order)

    return wrap if _cls is None else wrap(_cls)
144
145
146
__all__ = ['Ctor', 'enum']
147