Passed
Push — master ( a2fe81...4dbf06 )
by Max
52s
created

structured_data.adt._coalesce()   A

Complexity

Conditions 2

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 2
dl 0
loc 4
rs 10
c 0
b 0
f 0
1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from ._adt_constructor import ADTConstructor
49
from ._adt_constructor import make_constructor
50
from ._ctor import get_args
51
from ._prewritten_methods import SUBCLASS_ORDER
52
from ._prewritten_methods import PrewrittenProductMethods
53
from ._prewritten_methods import PrewrittenSumMethods
54
55
# Generic type variable used in the helper signatures of this module.
_T = typing.TypeVar("_T")


if typing.TYPE_CHECKING:  # pragma: nocover

    class Ctor:
        """Dummy class for type-checking purposes.

        At runtime the real ``Ctor`` is imported from ``._ctor`` instead
        (see the ``else`` branch).
        """

    # NOTE(review): this module exposes ``Sum``/``Product`` base classes rather
    # than an ``adt`` decorator; the reference in the docstring below may be
    # stale — confirm against the type-checking plugin that uses ConcreteCtor.
    class ConcreteCtor(typing.Generic[_T]):
        """Wrapper class for type-checking purposes.

        The type parameter should be a Tuple type of fixed size.
        Classes containing this annotation (meaning they haven't been
        processed by the ``adt`` decorator) should not be instantiated.
        """


else:
    # Runtime implementation of the Ctor marker.
    from ._ctor import Ctor
74
75
76
def _name(cls: typing.Type[_T], function) -> str:
    """Return the ``__name__`` of ``function`` once bound through its descriptor.

    Binding with ``__get__(None, cls)`` unwraps wrappers such as
    ``staticmethod`` so the underlying callable's name is reported.
    """
    unwrapped = function.__get__(None, cls)
    return unwrapped.__name__
79
80
81
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Report the first of ``functions`` whose attribute slot on ``cls`` is taken.

    Each function would be installed under the name returned by ``_name``.
    The current value of that attribute on ``cls`` does not count as a
    conflict when it is the same as the attribute on ``object``, the same as
    the attribute on ``Product``, missing/None, or the function itself.

    Returns the conflicting attribute name, or None when every function can
    be installed.
    """
    for candidate in functions:
        attr = _name(cls, candidate)
        current = getattr(cls, attr, None)
        replaceable = (
            getattr(object, attr, None),
            getattr(Product, attr, None),
            None,
            candidate,
        )
        if current in replaceable:
            continue
        return attr
    return None
92
93
94
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install every function in ``functions`` on ``cls``, all or nothing.

    Conflicts are checked before anything is written. If any target
    attribute is already defined, nothing is set and that attribute's name
    is returned. On success, returns None.
    """
    blocked = _cant_set_new_functions(cls, *functions)
    if not blocked:
        for fn in functions:
            setattr(cls, _name(cls, fn), fn)
        return None
    return blocked
106
107
108
_K = typing.TypeVar("_K")
109
_V = typing.TypeVar("_V")
110
111
112
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
113
    if value is None:
114
        dct.pop(key, typing.cast(_V, None))
115
    else:
116
        dct[key] = value
117
118
119
def _add_methods(cls: typing.Type[_T], do_set, *methods):
    """Optionally install ``methods`` on ``cls`` via ``_set_new_functions``.

    Returns True only when ``do_set`` is truthy and every method was
    installed. Returns False when installation was skipped or blocked by an
    existing definition.
    """
    if not do_set:
        return False
    return not _set_new_functions(cls, *methods)
124
125
126
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Install a ``__new__`` on ``_cls`` that only instantiates ``subclasses``.

    ``subclasses`` is the collection of classes that may be instantiated. In
    this module it is the frozenset of constructor classes built by
    ``Sum.__init_subclass__``. Instantiating any other class that inherits
    this ``__new__``, including ``_cls`` itself, raises TypeError.

    If ``_cls`` defines its own ``__new__``, that implementation is kept and
    called for permitted classes. Otherwise the next ``__new__`` after ``_cls``
    in the MRO is used.
    """

    def base(cls, args):
        return super(_cls, cls).__new__(cls, args)

    # A ``__new__`` defined in a class body is stored as a staticmethod, so the
    # fallback is wrapped the same way; ``__get__`` below unwraps either one.
    new = _cls.__dict__.get("__new__", staticmethod(base))

    def __new__(cls, args):
        if cls not in subclasses:
            raise TypeError(
                "Cannot instantiate {cls}: only the constructors of {base} "
                "may be instantiated".format(
                    cls=cls.__qualname__, base=_cls.__qualname__
                )
            )
        return new.__get__(None, cls)(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
138
139
140
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Install a pass-through ``__new__`` on ``_cls`` with a descriptive signature.

    The installed function forwards everything to the next ``__new__`` after
    ``_cls`` in the MRO. Its ``__signature__`` advertises a positional-only
    ``cls`` followed by one positional-or-keyword parameter per entry of
    ``annotations``. Each parameter carries the annotation and, when present
    in ``defaults``, the default value.
    """

    def __new__(*args, **kwargs):
        cls, *rest = args
        return super(_cls, cls).__new__(cls, *rest, **kwargs)

    base_signature = inspect.signature(__new__)
    no_default = inspect.Parameter.empty
    cls_parameter = inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)
    field_parameters = [
        inspect.Parameter(
            name,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=hint,
            default=defaults.get(name, no_default),
        )
        for name, hint in annotations.items()
    ]
    __new__.__signature__ = base_signature.replace(
        parameters=[cls_parameter, *field_parameters]
    )
    _cls.__new__ = __new__
162
163
164
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Yield ``(owner, name, annotation)`` for annotations across ``cls``'s MRO.

    Classes are visited from the end of the MRO back to ``cls`` itself. Only
    annotations stored directly on each class are reported. A caller that
    writes the results into a dict therefore lets more-derived classes
    override base-class entries.
    """
    for owner in cls.__mro__[::-1]:
        own_annotations = vars(owner).get("__annotations__", {})
        yield from ((owner, name, hint) for name, hint in own_annotations.items())
170
171
172
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map each constructor name to its argument types, per ``get_args``.

    Each annotation is resolved against the globals of the module that
    defined the annotating class. Names for which ``get_args`` returns None
    are dropped, including ones a base class had previously registered.
    """
    constructors: typing.Dict[str, typing.Tuple] = {}
    for owner, name, hint in _all_annotations(cls):
        namespace = vars(sys.modules[owner.__module__])
        _nillable_write(constructors, name, get_args(hint, namespace))
    return constructors
179
180
181
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Map each field name to its annotation, across the whole MRO.

    An annotation equal to the string ``"None"`` removes the field, including
    one inherited from a base class.
    """
    fields: typing.Dict[str, typing.Any] = {}
    for _owner, name, hint in _all_annotations(cls):
        _nillable_write(fields, name, None if hint == "None" else hint)
    return fields
190
191
192
class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    If repr is true, a __repr__() method is added to the class.
    If order is true, rich comparison dunder methods are added.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __init_subclass__(cls, *, repr=True, eq=True, order=False, **kwargs):
        """Turn a direct subclass into a closed sum type.

        One constructor class is created per Ctor annotation via
        ``make_constructor``. The prewritten dunder methods are then installed
        according to the keyword flags.

        Args:
            repr: install ``PrewrittenSumMethods.__repr__`` unless the class
                already defines its own.
            eq: install ``__eq__``/``__ne__``. When both are installed,
                ``__hash__`` is installed as well.
            order: install ``__lt__``/``__le__``/``__gt__``/``__ge__``.
                Requires ``eq``.

        Raises:
            ValueError: if ``order`` is true and ``eq`` is false, or if
                ``order`` is true but the equality methods could not be
                installed because the class defines its own.
            TypeError: if ``order`` is true and an ordering method is already
                defined on the class.
        """
        super().__init_subclass__(**kwargs)
        # Constructor classes are ADTConstructor subclasses of the sum type
        # (presumably created by make_constructor below); they skip this setup.
        if issubclass(cls, ADTConstructor):
            return
        if order and not eq:
            raise ValueError("eq must be true if order is true")

        # make_constructor is handed this list and presumably appends each
        # constructor class it creates.
        subclass_order: typing.List[typing.Type[_T]] = []

        for name, args in _sum_args_from_annotations(cls).items():
            make_constructor(cls, name, args, subclass_order)

        SUBCLASS_ORDER[cls] = tuple(subclass_order)

        # From now on, subclassing this sum type goes through the prewritten
        # hook (presumably rejecting it, per the class docstring).
        cls.__init_subclass__ = PrewrittenSumMethods.__init_subclass__  # type: ignore

        # Only the constructor classes may be instantiated.
        _sum_new(cls, frozenset(subclass_order))

        # Return values are ignored: a class that defines its own versions of
        # these keeps them.
        _set_new_functions(
            cls, PrewrittenSumMethods.__setattr__, PrewrittenSumMethods.__delattr__
        )
        _set_new_functions(cls, PrewrittenSumMethods.__bool__)

        _add_methods(cls, repr, PrewrittenSumMethods.__repr__)

        equality_methods_were_set = _add_methods(
            cls, eq, PrewrittenSumMethods.__eq__, PrewrittenSumMethods.__ne__
        )

        # Pair the prewritten __eq__ with the matching prewritten __hash__.
        if equality_methods_were_set:
            cls.__hash__ = PrewrittenSumMethods.__hash__

        if order:

            # The prewritten ordering methods are only installed alongside the
            # prewritten equality methods.
            if not equality_methods_were_set:
                raise ValueError(
                    "Can't add ordering methods if equality methods are provided."
                )
            collision = _set_new_functions(
                cls,
                PrewrittenSumMethods.__lt__,
                PrewrittenSumMethods.__le__,
                PrewrittenSumMethods.__gt__,
                PrewrittenSumMethods.__ge__,
            )
            if collision:
                raise TypeError(
                    "Cannot overwrite attribute {collision} in class "
                    "{name}. Consider using functools.total_ordering".format(
                        collision=collision, name=cls.__name__
                    )
                )
264
265
266
def _coalesce(new, old):
267
    if new is None:
268
        return old
269
    return new
270
271
272
class Product(ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If order is true, rich comparison dunder methods are added.
    A subclass that does not pass ``repr``, ``eq`` or ``order`` inherits the
    setting of its parent class.

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded.
    Fields may have default values, and can be set to inspect.Parameter.empty
    to indicate "no default". Fields with defaults must come after every
    field without one.

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):
        """Build an instance from positional and keyword field values.

        Raises:
            TypeError: if more positional arguments are given than there are
                fields, a field without a default is not supplied, or an
                unknown (or already positionally supplied) keyword is passed.
        """
        # ``cls`` is taken from ``*args`` rather than named in the signature,
        # presumably so a field called ``cls`` can still be passed by keyword.
        cls, *args = args
        fields = cls.__annotations
        # Reject surplus arguments instead of letting zip() silently drop them.
        if len(args) > len(fields):
            raise TypeError(
                "{name}() takes {expected} positional arguments but {given} "
                "were given".format(
                    name=cls.__name__, expected=len(fields), given=len(args)
                )
            )
        values = cls.__defaults.copy()
        fields_iter = iter(fields)
        # Positional arguments fill the leading fields in declaration order.
        for arg, field in zip(args, fields_iter):
            values[field] = arg
        # The remaining fields come from keywords, falling back to defaults.
        for field in fields_iter:
            if field in kwargs:
                values[field] = kwargs.pop(field)
            elif field not in values:
                raise TypeError(
                    "{name}() missing required argument: {field!r}".format(
                        name=cls.__name__, field=field
                    )
                )
        # Leftover keywords are unknown, or were also supplied positionally.
        if kwargs:
            raise TypeError(kwargs)
        return super(Product, cls).__new__(cls, [values[field] for field in fields])

    # Per-class configuration. Subclasses inherit these unless overridden in
    # __init_subclass__.
    __repr = True
    __eq = True
    __order = False
    # True when the prewritten __eq__/__ne__ are in effect for the class.
    __eq_succeeded = None

    def __init_subclass__(cls, *, repr=None, eq=None, order=None, **kwargs):
        """Record the configuration and field layout of a new product class.

        Args:
            repr: use the prewritten ``__repr__``. None means inherit.
            eq: use the prewritten ``__eq__``/``__ne__``/``__hash__``.
                None means inherit.
            order: use the prewritten ordering methods. None means inherit.
                Requires ``eq``.

        Raises:
            ValueError: if ordering is enabled without equality, or if
                ordering is enabled but the class defines its own
                ``__eq__``/``__ne__``.
            TypeError: if a field with a default precedes one without, or if
                ordering is enabled and the class defines its own ordering
                method.
        """
        super().__init_subclass__(**kwargs)

        cls.__repr = _coalesce(repr, cls.__repr)
        cls.__eq = _coalesce(eq, cls.__eq)
        cls.__order = _coalesce(order, cls.__order)

        if cls.__order and not cls.__eq:
            raise ValueError("eq must be true if order is true")

        cls.__annotations = _product_args_from_annotations(cls)
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}

        # Walk the fields backwards, collecting class-attribute defaults until
        # the first field that has none.
        cls.__defaults = {}
        field_names = iter(reversed(tuple(cls.__annotations)))
        for field in field_names:
            default = getattr(cls, field, inspect.Parameter.empty)
            if default is inspect.Parameter.empty:
                break
            cls.__defaults[field] = default
        # Any earlier field with a default would precede a non-default field.
        for field in field_names:
            if getattr(cls, field, inspect.Parameter.empty) is not inspect.Parameter.empty:
                raise TypeError(
                    "Field {field!r} of class {name} has a default value, but "
                    "a later field does not".format(field=field, name=cls.__name__)
                )

        _product_new(cls, cls.__annotations, cls.__defaults)

        # The prewritten equality only applies when the class does not define
        # its own __eq__/__ne__; the properties below consult this flag.
        cls.__eq_succeeded = False
        if cls.__eq:
            cls.__eq_succeeded = not _cant_set_new_functions(
                cls, PrewrittenProductMethods.__eq__, PrewrittenProductMethods.__ne__
            )

        # Use the effective (possibly inherited) setting, which is what the
        # ordering properties below consult.
        if cls.__order:

            if not cls.__eq_succeeded:
                raise ValueError(
                    "Can't add ordering methods if equality methods are provided."
                )
            collision = _cant_set_new_functions(
                cls,
                PrewrittenProductMethods.__lt__,
                PrewrittenProductMethods.__le__,
                PrewrittenProductMethods.__gt__,
                PrewrittenProductMethods.__ge__,
            )
            if collision:
                raise TypeError(
                    "Cannot overwrite attribute {collision} in class "
                    "{name}. Consider using functools.total_ordering".format(
                        collision=collision, name=cls.__name__
                    )
                )

    def __dir__(self):
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        # Field names are not real attributes; fall back to the tuple slot
        # recorded for the field.
        try:
            return super().__getattribute__(name)
        except AttributeError:
            index = self.__fields.get(name)
            if index is None:
                raise
            return tuple.__getitem__(self, index)

    __setattr__ = PrewrittenProductMethods.__setattr__
    __delattr__ = PrewrittenProductMethods.__delattr__
    __bool__ = PrewrittenProductMethods.__bool__

    # The dunders below are properties, so the per-class flags set in
    # __init_subclass__ choose between the prewritten implementation and the
    # inherited one at lookup time.

    @property
    def __repr__(self):
        if self.__repr:
            return PrewrittenProductMethods.__repr__.__get__(self, type(self))
        return super().__repr__

    @property
    def __hash__(self):
        if self.__eq_succeeded:
            return PrewrittenProductMethods.__hash__.__get__(self, type(self))
        return super().__hash__

    @property
    def __eq__(self):
        if self.__eq_succeeded:
            return PrewrittenProductMethods.__eq__.__get__(self, type(self))
        return super().__eq__

    @property
    def __ne__(self):
        if self.__eq_succeeded:
            return PrewrittenProductMethods.__ne__.__get__(self, type(self))
        return super().__ne__

    @property
    def __lt__(self):
        if self.__order:
            return PrewrittenProductMethods.__lt__.__get__(self, type(self))
        return super().__lt__

    @property
    def __le__(self):
        if self.__order:
            return PrewrittenProductMethods.__le__.__get__(self, type(self))
        return super().__le__

    @property
    def __gt__(self):
        if self.__order:
            return PrewrittenProductMethods.__gt__.__get__(self, type(self))
        return super().__gt__

    @property
    def __ge__(self):
        if self.__order:
            return PrewrittenProductMethods.__ge__.__get__(self, type(self))
        return super().__ge__
433
434
435
# Public API: the Ctor marker and the two ADT base classes.
__all__ = [
    "Ctor",
    "Product",
    "Sum",
]
436