Passed
Push — master ( 1e4141...3bde1c )
by Max
57s
created

structured_data.adt.Product.__new__()   B

Complexity

Conditions 6

Size

Total Lines 14
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 14
nop 2
dl 0
loc 14
rs 8.6666
c 0
b 0
f 0
"""Base classes for defining abstract data types.

This module provides three public members: ``Sum`` and ``Ctor``, which are
used together to define sum types, and ``Product``, a base class for
tuple-based types with named fields.

Given a structure, possibly a choice of different structures, that you'd like
to associate with a type:

- First, create a class that subclasses the Sum class.
- Then, for each possible structure, add an attribute annotation to the class
  with the desired name of the constructor, and a type of ``Ctor``, with the
  types within the constructor as arguments.

To look inside an ADT instance, use the functions from the
:mod:`structured_data.match` module.

Putting it together:

>>> from structured_data import match
>>> class Example(Sum):
...     FirstConstructor: Ctor[int, str]
...     SecondConstructor: Ctor[bytes]
...     ThirdConstructor: Ctor
...     def __iter__(self):
...         matchable = match.Matchable(self)
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
...             count, string = matchable[match.pat.count, match.pat.string]
...             for _ in range(count):
...                 yield string
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
...             bytes_ = matchable[match.pat.bytes]
...             for byte in bytes_:
...                 yield chr(byte)
...         elif matchable(Example.ThirdConstructor()):
...             yield "Third"
...             yield "Constructor"
>>> list(Example.FirstConstructor(5, "abc"))
['abc', 'abc', 'abc', 'abc', 'abc']
>>> list(Example.SecondConstructor(b"abc"))
['a', 'b', 'c']
>>> list(Example.ThirdConstructor())
['Third', 'Constructor']
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from ._adt_constructor import ADTConstructor
49
from ._adt_constructor import make_constructor
50
from ._ctor import get_args
51
from ._prewritten_methods import SUBCLASS_ORDER
52
from ._prewritten_methods import PrewrittenProductMethods
53
from ._prewritten_methods import PrewrittenSumMethods
54
55
# Generic type variable used in this module's ``typing.Type[_T]`` annotations
# and as the parameter of the type-checking-only ``ConcreteCtor``.
_T = typing.TypeVar("_T")
56
57
58
if typing.TYPE_CHECKING:  # pragma: nocover

    class Ctor:
        """Placeholder visible only to static type checkers."""

    class ConcreteCtor(typing.Generic[_T]):
        """Generic wrapper that exists only for static type checking.

        Parameterize it with a ``Tuple`` type of fixed length. A class that
        still carries this annotation has not been processed by the ``adt``
        decorator, and must not be instantiated.
        """

else:
    # At runtime the real marker class comes from the sibling module.
    from ._ctor import Ctor
74
75
76
def _name(cls: typing.Type[_T], function) -> str:
    """Report the ``__name__`` of ``function`` as seen through ``cls``.

    The descriptor is bound with ``__get__(None, cls)`` (no instance), and the
    name of the resulting object is returned.
    """
    resolved = function.__get__(None, cls)
    return resolved.__name__
79
80
81
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install ``functions`` on ``cls`` under their own names, all or nothing.

    Each name is checked before anything is assigned. A name counts as taken
    when the attribute ``cls`` exposes under it is not the very object that
    ``object`` exposes (``None`` stands in for a missing attribute on either
    side). The first taken name is returned and ``cls`` is left untouched.
    When every name is free, all the functions are set and ``None`` is
    returned.
    """
    staged = []
    for function in functions:
        name = _name(cls, function)
        if getattr(object, name, None) is not getattr(cls, name, None):
            return name
        staged.append((name, function))
    for name, function in staged:
        setattr(cls, name, function)
    return None
94
95
96
_K = typing.TypeVar("_K")
97
_V = typing.TypeVar("_V")
98
99
100
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
    """Store ``value`` under ``key``, or remove ``key`` when ``value`` is None.

    Removing a key that is not present is silently ignored.
    """
    if value is not None:
        dct[key] = value
        return
    dct.pop(key, typing.cast(_V, None))
105
106
107
def _add_methods(cls: typing.Type[_T], do_set, *methods):
    """Optionally install ``methods`` on ``cls`` and report whether it happened.

    Returns False when ``do_set`` is falsy, or when ``_set_new_functions``
    reported a name that ``cls`` already defines. Otherwise returns True.
    """
    if not do_set:
        return False
    return not _set_new_functions(cls, *methods)
112
113
114
def _set_hash(cls: typing.Type[_T], set_hash, src):
    """Copy ``src.__hash__`` onto ``cls`` when ``set_hash`` is truthy."""
    if not set_hash:
        return
    cls.__hash__ = src.__hash__  # type: ignore
117
118
119
def _add_order(cls: typing.Type[_T], set_order, equality_methods_were_set, src):
    """Install ``src``'s ``__lt__``/``__le__``/``__gt__``/``__ge__`` on ``cls``.

    Does nothing unless ``set_order`` is truthy.

    Raises:
        ValueError: ordering was requested but the caller reports that the
            equality methods were not installed.
        TypeError: ``cls`` already defines one of the ordering methods.
    """
    if not set_order:
        return
    if not equality_methods_were_set:
        raise ValueError(
            "Can't add ordering methods if equality methods are provided."
        )
    ordering = (src.__lt__, src.__le__, src.__gt__, src.__ge__)
    collision = _set_new_functions(cls, *ordering)
    if collision:
        message = (
            "Cannot overwrite attribute {collision} in class "
            "{name}. Consider using functools.total_ordering"
        ).format(collision=collision, name=cls.__name__)
        raise TypeError(message)
135
136
137
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Restrict instantiation of ``_cls`` to the classes in ``subclasses``.

    ``_cls.__new__`` is replaced by a guard. Calling it for a class that is
    not in ``subclasses`` (including ``_cls`` itself, unless listed) raises
    TypeError. Otherwise construction is delegated, as ``__new__(cls, args)``,
    to the ``__new__`` that ``_cls`` defined in its own ``__dict__`` if any,
    or else to the next ``__new__`` after ``_cls`` in the MRO.

    Raises (from the installed ``__new__``):
        TypeError: the class being instantiated is not one of ``subclasses``.
    """

    def base(cls, args):
        # Fallback: defer to whatever follows _cls in the MRO.
        return super(_cls, cls).__new__(cls, args)

    # Capture any __new__ defined directly on _cls before overwriting it below.
    # A class-body __new__ is implicitly a staticmethod, so both it and the
    # fallback are unwrapped through __get__ at call time.
    new = _cls.__dict__.get("__new__", staticmethod(base))

    def __new__(cls, args):
        if cls not in subclasses:
            raise TypeError(
                "cannot instantiate {}: only the constructors of {} "
                "may be instantiated".format(cls.__qualname__, _cls.__qualname__)
            )
        return new.__get__(None, cls)(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
147
148
149
# Unique marker meaning "the class has no attribute of this name". It lets
# Product tell a missing default apart from a default that is None (or any
# other value).
_SENTINEL = object()
150
151
152
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a pass-through ``__new__`` with a descriptive signature.

    The installed ``__new__`` forwards every argument unchanged to the next
    ``__new__`` after ``_cls`` in the MRO. Its ``__signature__`` is rewritten
    for introspection. It shows a positional-only ``cls``, followed by one
    positional-or-keyword parameter per entry in ``annotations``. Each such
    parameter is annotated with its type and carries the matching entry of
    ``defaults``, when there is one.
    """

    def __new__(*args, **kwargs):
        # The class arrives as the first positional argument.
        cls, *rest = args
        return super(_cls, cls).__new__(cls, *rest, **kwargs)

    field_parameters = [
        inspect.Parameter(
            field,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=annotation,
            default=defaults.get(field, inspect.Parameter.empty),
        )
        for field, annotation in annotations.items()
    ]
    cls_parameter = inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)
    __new__.__signature__ = inspect.signature(__new__).replace(
        parameters=[cls_parameter, *field_parameters]
    )
    _cls.__new__ = __new__
174
175
176
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Yield ``(defining_class, name, annotation)`` for every class in the MRO.

    Classes are visited from the root of ``cls.__mro__`` down to ``cls``
    itself, so entries from more-derived classes come later. Only each
    class's own ``__annotations__`` (read from its ``__dict__``) is used.
    """
    for superclass in reversed(cls.__mro__):
        own = vars(superclass).get("__annotations__", {})
        yield from ((superclass, name, hint) for name, hint in own.items())
182
183
184
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map each annotated name on ``cls`` to the result of ``get_args``.

    Every annotation in the MRO, base classes first, is handed to ``get_args``.
    It comes with the namespace of the module that defined the annotating
    class, presumably so string annotations resolve in the right scope; see
    ``get_args``. A ``None`` result drops the name, including any entry an
    ancestor recorded earlier.
    """
    result: typing.Dict[str, typing.Tuple] = {}
    for owner, name, annotation in _all_annotations(cls):
        namespace = vars(sys.modules[owner.__module__])
        _nillable_write(result, name, get_args(annotation, namespace))
    return result
191
192
193
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Collect the field annotations of ``cls`` across its MRO, bases first.

    An annotation that is ``None``, or the string ``"None"``, removes the
    field. This lets a subclass drop a field declared by an ancestor.
    """
    fields: typing.Dict[str, typing.Any] = {}
    for _owner, name, annotation in _all_annotations(cls):
        value = None if annotation == "None" else annotation
        _nillable_write(fields, name, value)
    return fields
202
203
204
def _tuple_getter(index: int):
205
    # TODO: __name__ and __qualname__
206
    @property
207
    def getter(self):
208
        return tuple.__getitem__(self, index)
209
210
    return getter
211
212
213
def _process_class(_cls: typing.Type[_T], _repr, eq, order) -> typing.Type[_T]:
    """Turn ``_cls`` into a sum type in place and return it.

    Steps:

    1. Call ``make_constructor`` for every entry found by
       ``_sum_args_from_annotations``. Each call is passed a shared list,
       which ``make_constructor`` is expected to fill. That list becomes
       ``SUBCLASS_ORDER[_cls]`` and the set of classes ``_sum_new`` admits.
    2. Install the prewritten ``__init_subclass__`` and the guarded ``__new__``.
       Then add ``__setattr__``/``__delattr__`` and ``__bool__``; each group is
       skipped if ``_cls`` already defines one of its names.
    3. Optionally add ``__repr__``, then ``__eq__``/``__ne__`` (with
       ``__hash__``), then the ordering methods.

    Raises:
        ValueError: ``order`` is true while ``eq`` is false, or from
            ``_add_order``.
        TypeError: from ``_add_order`` when an ordering method already exists.
    """
    if order and not eq:
        raise ValueError("eq must be true if order is true")

    methods = PrewrittenSumMethods

    constructors: typing.List[typing.Type[_T]] = []
    for ctor_name, ctor_args in _sum_args_from_annotations(_cls).items():
        make_constructor(_cls, ctor_name, ctor_args, constructors)
    SUBCLASS_ORDER[_cls] = tuple(constructors)

    _cls.__init_subclass__ = methods.__init_subclass__  # type: ignore
    _sum_new(_cls, frozenset(constructors))

    _set_new_functions(_cls, methods.__setattr__, methods.__delattr__)
    _set_new_functions(_cls, methods.__bool__)

    _add_methods(_cls, _repr, methods.__repr__)
    eq_was_set = _add_methods(_cls, eq, methods.__eq__, methods.__ne__)
    _set_hash(_cls, eq_was_set, methods)
    _add_order(_cls, order, eq_was_set, methods)

    return _cls
244
245
246
class Sum:
    """Base class for types made of disjoint constructors.

    Subclassing ``Sum`` inspects the PEP 526 ``__annotations__`` of the new
    class and its bases. An annotation counts as a Ctor annotation if it is
    ``Ctor`` itself, or ``Ctor`` indexed by one type hint or a tuple of type
    hints. Each such annotation becomes a constructor subclass available
    under the annotated name. A constructor accepts a fixed number of
    arguments, one per type hint given (if any). All other annotations are
    ignored. The resulting class is not itself meant to be subclassed.

    Class keyword arguments:

    - ``repr`` (default True): add a ``__repr__()`` method.
    - ``eq`` (default True): add equality methods (and ``__hash__``).
    - ``order`` (default False): add rich comparison dunder methods; this
      requires ``eq``.
    """

    __slots__ = ()

    def __init_subclass__(cls, *, repr=True, eq=True, order=False, **kwargs):
        super().__init_subclass__(**kwargs)
        # ADTConstructor subclasses (presumably the generated constructors)
        # are not themselves processed as sum types.
        if issubclass(cls, ADTConstructor):
            return
        _process_class(cls, repr, eq, order)
270
271
272
class Product(ADTConstructor, tuple):
    """Base class for immutable, tuple-based types with named fields.

    Subclassing ``Product`` with PEP 526 annotations turns every annotated
    name into a field, in MRO order with base classes first. Inherited
    annotations are included, and an annotation of ``None`` removes a field.
    Class attributes assigned to a trailing run of fields become their
    defaults. Each field is then exposed as a read-only property over the
    underlying tuple.

    Class keyword arguments ``repr``, ``eq`` (both default True) and
    ``order`` (default False) control which prewritten methods are added.
    ``order`` requires ``eq``.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):
        """Bind positional and keyword arguments to this class's fields.

        Raises:
            TypeError: more positional arguments than fields; a field with no
                default that received no value; or keyword arguments that
                name no remaining field (unknown, or already given
                positionally).
        """
        # ``cls`` is peeled off ``*args`` instead of being a named parameter,
        # presumably so no keyword argument can collide with it.
        cls, *args = args
        fields = tuple(cls.__annotations)
        if len(args) > len(fields):
            raise TypeError(
                "{}() takes {} positional argument(s) but {} were given".format(
                    cls.__name__, len(fields), len(args)
                )
            )
        values = cls.__defaults.copy()
        values.update(zip(fields, args))
        missing = []
        for field in fields[len(args):]:
            if field in kwargs:
                values[field] = kwargs.pop(field)
            elif field not in values:
                missing.append(field)
        if missing:
            raise TypeError(
                "{}() missing required argument(s): {}".format(
                    cls.__name__, ", ".join(map(repr, missing))
                )
            )
        if kwargs:
            # Also catches a field passed both positionally and by keyword:
            # such keys are never popped above.
            raise TypeError(
                "{}() got unexpected or duplicate keyword argument(s): {}".format(
                    cls.__name__, ", ".join(map(repr, kwargs))
                )
            )
        return super(Product, cls).__new__(
            cls, [values[field] for field in fields]
        )

    def __init_subclass__(cls, *, repr=True, eq=True, order=False, **kwargs):
        """Set up a newly defined subclass as a Product type.

        Subclasses that declare no annotations of their own are left as-is.

        Raises:
            ValueError: ``order`` is true while ``eq`` is false.
            TypeError: a field with a default precedes a field without one,
                or from ``_add_order``.
        """
        super().__init_subclass__(**kwargs)
        if "__annotations__" not in vars(cls):
            return
        if order and not eq:
            raise ValueError("eq must be true if order is true")

        cls.__annotations = _product_args_from_annotations(cls)

        # Defaults are the class attributes on the trailing run of fields,
        # found by walking the fields backwards until one has no attribute.
        # NOTE(review): for a subclass of an existing Product, inherited
        # fields resolve to the parent's field properties (installed below),
        # so they look like defaults here. Confirm whether extending a
        # Product with new fields is meant to be supported.
        cls.__defaults = {}
        required_field = None
        reversed_fields = iter(reversed(tuple(cls.__annotations)))
        for field in reversed_fields:
            default = getattr(cls, field, _SENTINEL)
            if default is _SENTINEL:
                required_field = field
                break
            cls.__defaults[field] = default
        # Whatever the iterator still holds is declared before a required
        # field, so it must not have a default either.
        for field in reversed_fields:
            if getattr(cls, field, _SENTINEL) is not _SENTINEL:
                raise TypeError(
                    "field {!r} of {} has a default but precedes field {!r}, "
                    "which has none".format(field, cls.__name__, required_field)
                )

        _product_new(cls, cls.__annotations, cls.__defaults)

        # Replace each field's class attribute (e.g. its default) with a
        # read-only accessor for the matching tuple slot.
        for index, field in enumerate(cls.__annotations):
            setattr(cls, field, _tuple_getter(index))

        _set_new_functions(
            cls,
            PrewrittenProductMethods.__setattr__,
            PrewrittenProductMethods.__delattr__,
        )
        _set_new_functions(cls, PrewrittenProductMethods.__bool__)

        # NOTE(review): on a tuple subclass, __repr__/__eq__/__ne__ resolve to
        # tuple's (or ADTConstructor's) implementations rather than object's.
        # _set_new_functions treats that as "already defined", so these may
        # never be installed, and order=True would then raise ValueError.
        # Verify.
        _add_methods(cls, repr, PrewrittenProductMethods.__repr__)

        equality_methods_were_set = _add_methods(
            cls, eq, PrewrittenProductMethods.__eq__, PrewrittenProductMethods.__ne__
        )

        _set_hash(cls, equality_methods_were_set, PrewrittenProductMethods)

        _add_order(cls, order, equality_methods_were_set, PrewrittenProductMethods)
333
334
335
# Public API: the ``Ctor`` annotation marker and the ``Product``/``Sum`` bases.
__all__ = ["Ctor", "Product", "Sum"]
336