Passed
Push — master ( ee2acb...c9efce )
by Max
01:05
created

structured_data.adt._unpack_args()   A

Complexity

Conditions 5

Size

Total Lines 9
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 9
nop 5
dl 0
loc 9
rs 9.3333
c 0
b 0
f 0
1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class, that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from . import _adt_constructor
49
from . import _ctor
50
from . import _prewritten_methods
51
52
# Type variable used in the ``typing.Type[_T]`` annotations of the helpers in
# this module.
_T = typing.TypeVar("_T")


if typing.TYPE_CHECKING:  # pragma: nocover
    # Static type checkers see these stand-ins; at runtime only ``Ctor`` is
    # defined (via the ``else`` branch) and ``ConcreteCtor`` does not exist.

    class Ctor:
        """Dummy class for type-checking purposes."""

    class ConcreteCtor(typing.Generic[_T]):
        """Wrapper class for type-checking purposes.

        The type parameter should be a Tuple type of fixed size.
        Classes containing this annotation (meaning they haven't been
        processed by the ``adt`` decorator) should not be instantiated.

        NOTE(review): no ``adt`` decorator is visible in this module (classes
        are processed by subclassing ``Sum``/``Product``); this wording may be
        stale — confirm.
        """


else:
    # Runtime: use the real implementation.
    from ._ctor import Ctor
71
72
73
def _name(cls: typing.Type[_T], function) -> str:
    """Report the ``__name__`` of ``function`` as seen through ``cls``.

    ``function`` is resolved with the descriptor protocol (``__get__`` with no
    instance), so the name reported is that of whatever object attribute
    access on ``cls`` would produce.
    """
    resolved = function.__get__(None, cls)
    return resolved.__name__
77
78
79
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Find the first function whose attribute on ``cls`` must not be replaced.

    An existing attribute is considered replaceable when it is missing (or
    ``None``), is the very same function, or equals the attribute of that name
    on ``object`` or on ``Product``. Anything else is a user definition.

    Returns:
        The name of the first blocking attribute, or ``None`` if every
        function could be installed.
    """
    # NOTE: the set of "harmless" owners (object, Product) is ad hoc.
    for function in functions:
        attr_name = _name(cls, function)
        current = getattr(cls, attr_name, None)
        replaceable = (
            getattr(object, attr_name, None),
            getattr(Product, attr_name, None),
            None,
            function,
        )
        if current not in replaceable:
            return attr_name
    return None
93
94
95
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install every function on ``cls`` under its own name, all or nothing.

    Collisions are detected up front, so when any target attribute is already
    user-defined nothing at all is written.

    Returns:
        ``None`` on success, otherwise the name of the colliding attribute.
    """
    blocked = _cant_set_new_functions(cls, *functions)
    if not blocked:
        for function in functions:
            setattr(cls, _name(cls, function), function)
        return None
    return blocked
107
108
109
_K, _V = typing.TypeVar("_K"), typing.TypeVar("_V")


def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
    """Write ``value`` to ``dct[key]``; a ``None`` value deletes the key instead.

    Deleting a key that is not present is silently ignored.
    """
    if value is not None:
        dct[key] = value
        return
    dct.pop(key, typing.cast(_V, None))
118
119
120
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Guard ``_cls.__new__`` so that only the given classes can be instantiated.

    Args:
        _cls: the Sum subclass being set up.
        subclasses: collection (tested with ``in``) of the classes that may be
            instantiated.

    A ``__new__`` found in ``_cls``'s own ``__dict__`` is kept and called after
    the check; otherwise ``super(_cls, cls).__new__`` is used.

    Raises:
        TypeError: at instantiation time, when the class being instantiated is
            not in ``subclasses``.
    """

    def base(cls, args):
        return super(_cls, cls).__new__(cls, args)

    # Read the class dict directly so only a __new__ written in _cls's own body
    # is preserved; anything inherited is reached through ``base`` instead.
    new = _cls.__dict__.get("__new__", staticmethod(base))

    def __new__(cls, args):
        if cls not in subclasses:
            raise TypeError(
                "{cls} is not a constructor of {sum_cls} and cannot be "
                "instantiated".format(cls=cls.__name__, sum_cls=_cls.__name__)
            )
        return new.__get__(None, cls)(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
132
133
134
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a delegating ``__new__`` that advertises its field signature.

    The installed ``__new__`` forwards all arguments to the next ``__new__``
    after ``_cls`` in the MRO. Its ``__signature__`` is a positional-only
    ``cls`` followed by one positional-or-keyword parameter per field, in
    ``annotations`` order, carrying the annotation and any default found in
    ``defaults``.
    """

    def __new__(*args, **kwargs):
        cls, *rest = args
        return super(_cls, cls).__new__(cls, *rest, **kwargs)

    params = [inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
    for field, annotation in annotations.items():
        params.append(
            inspect.Parameter(
                field,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=annotation,
                default=defaults.get(field, inspect.Parameter.empty),
            )
        )
    __new__.__signature__ = inspect.signature(__new__).replace(parameters=params)
    _cls.__new__ = __new__
156
157
158
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Lazily yield ``(declaring_class, name, annotation)`` across ``cls``'s MRO.

    Classes are visited from the most basic to ``cls`` itself, so entries for
    a repeated name arrive in base-to-derived order. Only each class's own
    ``__annotations__`` (looked up in its ``vars()``) is consulted.
    """
    for klass in cls.__mro__[::-1]:
        own = vars(klass).get("__annotations__", {})
        yield from ((klass, name, hint) for name, hint in own.items())
164
165
166
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map constructor names to the args ``_ctor.get_args`` extracts for them.

    Each annotation in the MRO is passed to ``_ctor.get_args`` along with the
    globals of the module that declared it. Later (more derived) annotations
    overwrite earlier ones, and a ``None`` result removes the name.
    """
    constructors: typing.Dict[str, typing.Tuple] = {}
    for owner, name, annotation in _all_annotations(cls):
        module_globals = vars(sys.modules[owner.__module__])
        _nillable_write(constructors, name, _ctor.get_args(annotation, module_globals))
    return constructors
173
174
175
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Map each field name of a Product class to its annotation.

    Annotations are gathered across the MRO, base classes first, so a more
    derived annotation for the same name replaces an earlier one. An
    annotation equal to the string ``"None"``, or one that
    ``_ctor.annotation_is_classvar`` reports as a ClassVar, removes the name
    so that it is not a field.
    """
    args: typing.Dict[str, typing.Any] = {}
    for superclass, key, value in _all_annotations(cls):
        # annotation_is_classvar is given a module namespace (presumably to
        # resolve string annotations). Use the module that *declared* the
        # annotation, as _sum_args_from_annotations does; cls's own module may
        # lack names used by an annotation inherited from another module.
        if value == "None" or _ctor.annotation_is_classvar(
            value, vars(sys.modules[superclass.__module__])
        ):
            value = None
        _nillable_write(args, key, value)
    return args
186
187
188
def _ordering_options_are_valid(*, eq, order):
    """Reject a request for ordering methods without equality methods.

    Raises:
        ValueError: if ``order`` is truthy while ``eq`` is falsy.
    """
    if not order or eq:
        return
    raise ValueError("eq must be true if order is true")
191
192
193
def _set_ordering(*, can_set, setter, cls, source):
    """Hand ``source``'s four rich comparisons to ``setter`` for ``cls``.

    ``setter`` is called as ``setter(cls, __lt__, __le__, __gt__, __ge__)`` and
    must return the name of a colliding attribute, or a falsy value.

    Raises:
        ValueError: if ``can_set`` is false (equality methods were not
            installed, so ordering cannot be either).
        TypeError: if ``setter`` reports a collision.
    """
    if can_set:
        comparisons = (source.__lt__, source.__le__, source.__gt__, source.__ge__)
        collision = setter(cls, *comparisons)
        if not collision:
            return
        message = (
            "Cannot overwrite attribute {collision} in class "
            "{name}. Consider using functools.total_ordering"
        ).format(collision=collision, name=cls.__name__)
        raise TypeError(message)
    raise ValueError("Can't add ordering methods if equality methods are provided.")
204
205
206
def _extract_defaults(*, cls, annotations):
    """Collect the default values of the trailing fields of a Product class.

    Field names are walked from last to first. ``getattr(cls, field)`` is taken
    as the default until the first field that has no such attribute, or whose
    attribute is ``inspect.Parameter.empty``. Defaults must therefore form a
    contiguous suffix of the fields, as with ordinary function parameters.

    Args:
        cls: class whose attributes supply the defaults.
        annotations: mapping whose keys, in order, are the field names.

    Returns:
        dict mapping each defaulted field name to its default value.

    Raises:
        TypeError: if a field with a default precedes a field without one.
    """
    defaults = {}
    required = None
    field_names = iter(reversed(tuple(annotations)))
    for field in field_names:
        default = getattr(cls, field, inspect.Parameter.empty)
        if default is inspect.Parameter.empty:
            required = field
            break
        defaults[field] = default
    # Any earlier field that still has a default would precede the required
    # field ``required`` in the generated signature, which is not allowed.
    for field in field_names:
        if getattr(cls, field, inspect.Parameter.empty) is not inspect.Parameter.empty:
            raise TypeError(
                "non-default field {required!r} follows default field "
                "{field!r}".format(required=required, field=field)
            )
    return defaults
218
219
220
def _unpack_args(*, args, kwargs, fields, values):
    """Bind positional and keyword arguments to field names, updating ``values``.

    Positional ``args`` fill ``fields`` in order. Each remaining field is then
    taken from ``kwargs``, unless it already has an entry in ``values`` (a
    default) and was not passed by keyword. Matched keywords are popped from
    ``kwargs``.

    Args:
        args: positional arguments, in field order.
        kwargs: keyword arguments; mutated (matched keys are removed).
        fields: iterable of field names in declaration order.
        values: mapping pre-populated with defaults; updated in place.

    Raises:
        TypeError: if more positional arguments are given than there are
            fields, or if any keyword is left unmatched (unknown, or
            duplicating a field already given positionally).
        KeyError: if a field without a default is supplied neither
            positionally nor by keyword.
    """
    args = tuple(args)
    fields = tuple(fields)
    if len(args) > len(fields):
        # zip() below truncates, which would silently discard the surplus.
        raise TypeError(
            "takes {expected} positional arguments but {given} were given".format(
                expected=len(fields), given=len(args)
            )
        )
    fields_iter = iter(fields)
    values.update({field: arg for (arg, field) in zip(args, fields_iter)})
    for field in fields_iter:
        if field in values and field not in kwargs:
            continue
        # NOTE(review): a missing required field surfaces as KeyError here
        # rather than TypeError; kept for compatibility with existing callers.
        values[field] = kwargs.pop(field)
    if kwargs:
        raise TypeError(kwargs)
229
230
231
class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    If repr is true, a __repr__() method is added to the class.
    If eq is true (the default), __eq__/__ne__ and a matching __hash__ are
    added, unless the class already defines its own __eq__ or __ne__.
    If order is true, rich comparison dunder methods are added; this requires
    the equality methods to have been added as well.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        """Allow instantiation only of ADTConstructor subclasses.

        Raises:
            TypeError: if ``cls`` is not a subclass of
                ``_adt_constructor.ADTConstructor``.
        """
        # ``cls`` is taken from ``*args`` instead of being a named parameter,
        # presumably so no parameter name can clash with a keyword argument.
        cls, *args = args
        if not issubclass(cls, _adt_constructor.ADTConstructor):
            raise TypeError
        return super(Sum, cls).__new__(cls, *args, **kwargs)

    # The keyword names ``repr`` and ``eq`` shadow a builtin / are very short;
    # they are kept for consistency with stdlib APIs (presumably
    # ``dataclasses.dataclass``), hence the pylint disables.
    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        """Turn a new Sum subclass into an ADT.

        Builds one constructor per ``Ctor`` annotation via
        ``_adt_constructor.make_constructor``, records their order, restricts
        instantiation to those constructors, and installs the prewritten
        methods selected by ``repr``/``eq``/``order``.

        Raises:
            ValueError: if ``order`` is true and ``eq`` is false, or if
                ``order`` is true but the equality methods were not installed.
            TypeError: if an ordering method would overwrite an existing
                attribute.
        """
        super().__init_subclass__(**kwargs)
        # ADTConstructor subclasses (presumably the constructor classes built
        # by make_constructor below) are not new sums; leave them alone.
        if issubclass(cls, _adt_constructor.ADTConstructor):
            return
        _ordering_options_are_valid(eq=eq, order=order)

        subclass_order: typing.List[typing.Type[_T]] = []

        # make_constructor appends to subclass_order as it goes.
        for name, args in _sum_args_from_annotations(cls).items():
            _adt_constructor.make_constructor(cls, name, args, subclass_order)

        # Registry keyed by the sum class, read by _prewritten_methods.
        _prewritten_methods.SUBCLASS_ORDER[cls] = tuple(subclass_order)

        # Any further subclassing of this sum now goes through the prewritten
        # hook instead of this one.
        cls.__init_subclass__ = (
            _prewritten_methods.PrewrittenSumMethods.__init_subclass__
        )  # type: ignore

        # From here on only the constructor classes can be instantiated.
        _sum_new(cls, frozenset(subclass_order))

        # Return values are ignored: if the class already defines any of
        # these, its own versions are kept and nothing is installed.
        _set_new_functions(
            cls,
            _prewritten_methods.PrewrittenSumMethods.__setattr__,
            _prewritten_methods.PrewrittenSumMethods.__delattr__,
        )
        _set_new_functions(cls, _prewritten_methods.PrewrittenSumMethods.__bool__)

        if repr:
            _set_new_functions(cls, _prewritten_methods.PrewrittenSumMethods.__repr__)

        # Equality is installed only if neither __eq__ nor __ne__ is
        # user-defined; hashing and ordering depend on that outcome.
        equality_methods_were_set = False
        if eq:
            equality_methods_were_set = not _set_new_functions(
                cls,
                _prewritten_methods.PrewrittenSumMethods.__eq__,
                _prewritten_methods.PrewrittenSumMethods.__ne__,
            )

        if equality_methods_were_set:
            cls.__hash__ = _prewritten_methods.PrewrittenSumMethods.__hash__

        if order:
            _set_ordering(
                can_set=equality_methods_were_set,
                setter=_set_new_functions,
                cls=cls,
                source=_prewritten_methods.PrewrittenSumMethods,
            )
313
314
315
class Product(_adt_constructor.ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If order is true, rich comparison dunder methods are added.

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded, as are annotations
    recognized as ClassVar.
    Fields may have default values, and can be set to
    ``inspect.Parameter.empty`` to indicate "no default".

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        """Build an instance from field values given positionally or by keyword.

        Missing fields are filled from the class's defaults; the tuple is then
        laid out in field-declaration order.

        Raises:
            TypeError: if ``Product`` itself is instantiated, or on surplus or
                unexpected arguments (see ``_unpack_args``).
        """
        cls, *args = args
        if cls is Product:
            raise TypeError
        # NOTE(review): ``cls`` is popped off ``*args`` by hand, which appears
        # to emulate a positional-only ``cls`` so that a field named ``cls``
        # could still be passed by keyword.
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        _unpack_args(
            args=args,
            kwargs=kwargs,
            fields=cls.__fields,  # pylint: disable=protected-access
            values=values,
        )
        return super(Product, cls).__new__(
            cls,
            [
                values[field]
                for field in cls.__fields  # pylint: disable=protected-access
            ],
        )

    # Per-class settings. The double underscore mangles these to
    # ``_Product__repr`` etc.; __init_subclass__ assigns them on each subclass,
    # and a subclass that omits a keyword inherits its parent's value.
    __repr = True
    __eq = True
    __order = False
    __eq_succeeded = None

    # The keyword names ``repr`` and ``eq`` shadow a builtin / are very short;
    # they are kept for consistency with stdlib APIs (presumably
    # ``dataclasses.dataclass``), hence the pylint disables.
    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        """Configure a new Product subclass.

        ``repr``, ``eq`` and ``order`` default to ``None``, meaning "keep the
        inherited setting". Computes the fields and their defaults from the
        annotations, installs a ``__new__`` with a matching signature, and
        records whether the prewritten equality/ordering methods may be used.

        Raises:
            ValueError: if ordering is requested without equality, or if
                ordering is requested while __eq__/__ne__ are user-defined.
            TypeError: if a field with a default precedes one without, or if an
                ordering method would overwrite an existing attribute.
        """
        super().__init_subclass__(**kwargs)

        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order

        _ordering_options_are_valid(eq=cls.__eq, order=cls.__order)

        # Field name -> annotation, and field name -> tuple index.
        cls.__annotations = _product_args_from_annotations(cls)
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}

        cls.__defaults = _extract_defaults(cls=cls, annotations=cls.__annotations)

        _product_new(cls, cls.__annotations, cls.__defaults)

        # Nothing is installed here: only collisions are checked, and the
        # dunder properties below consult these flags at lookup time.
        cls.__eq_succeeded = False
        if cls.__eq:
            cls.__eq_succeeded = not _cant_set_new_functions(
                cls,
                _prewritten_methods.PrewrittenProductMethods.__eq__,
                _prewritten_methods.PrewrittenProductMethods.__ne__,
            )

        if cls.__order:
            _set_ordering(
                can_set=cls.__eq_succeeded,
                setter=_cant_set_new_functions,
                cls=cls,
                source=_prewritten_methods.PrewrittenProductMethods,
            )

    def __dir__(self):
        """Include field names, which are served by __getattribute__."""
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        """Fall back to the tuple slot of a field when normal lookup fails."""
        try:
            return super().__getattribute__(name)
        except AttributeError:
            index = self.__fields.get(name)
            if index is None:
                raise
            return tuple.__getitem__(self, index)

    __setattr__ = _prewritten_methods.PrewrittenProductMethods.__setattr__
    __delattr__ = _prewritten_methods.PrewrittenProductMethods.__delattr__
    __bool__ = _prewritten_methods.PrewrittenProductMethods.__bool__

    # The dunder methods below are properties: each getter returns either the
    # prewritten method bound to ``self`` or the inherited tuple method,
    # depending on the flags set for the instance's class in __init_subclass__.

    @property
    def __repr__(self):
        if self.__repr:
            return _prewritten_methods.PrewrittenProductMethods.__repr__.__get__(
                self, type(self)
            )
        return super().__repr__

    @property
    def __hash__(self):
        # Tied to the same flag as __eq__ so hashing stays consistent with
        # whichever equality is in use.
        if self.__eq_succeeded:
            return _prewritten_methods.PrewrittenProductMethods.__hash__.__get__(
                self, type(self)
            )
        return super().__hash__

    @property
    def __eq__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__eq__.__get__(
                self, type(self)
            )
        return super().__eq__

    @property
    def __ne__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__ne__.__get__(
                self, type(self)
            )
        return super().__ne__

    @property
    def __lt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__lt__.__get__(
                self, type(self)
            )
        return super().__lt__

    @property
    def __le__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__le__.__get__(
                self, type(self)
            )
        return super().__le__

    @property
    def __gt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__gt__.__get__(
                self, type(self)
            )
        return super().__gt__

    @property
    def __ge__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__ge__.__get__(
                self, type(self)
            )
        return super().__ge__
497
498
499
# Public names of this module (those exported by a star import).
__all__ = ["Ctor", "Product", "Sum"]
500