Passed
Push — master ( a26568...eff5fd )
by Max
01:08
created

structured_data.adt._ordering_options_are_valid()   A

Complexity

Conditions 3

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 3
nop 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import types
47
import typing
48
49
from ._adt_constructor import ADTConstructor
50
from ._adt_constructor import make_constructor
51
from ._ctor import get_args
52
from ._prewritten_methods import SUBCLASS_ORDER
53
from ._prewritten_methods import PrewrittenProductMethods
54
from ._prewritten_methods import PrewrittenSumMethods
55
56
# Generic type variable used by the ``typing.Type[_T]`` annotations in this module.
_T = typing.TypeVar("_T")
57
58
59
if not typing.TYPE_CHECKING:
    # At runtime the real ``Ctor`` implementation is used.
    from ._ctor import Ctor
else:  # pragma: nocover
    # Static type checkers see these lightweight stand-ins instead.

    class Ctor:
        """Placeholder seen only by static type checkers."""

    class ConcreteCtor(typing.Generic[_T]):
        """Generic stand-in seen only by static type checkers.

        Its type parameter should be a fixed-size Tuple type.
        A class still carrying this annotation has not been processed by the
        ``adt`` decorator, and must not be instantiated.
        """
75
76
77
def _conditional_raise(do_raise, exc_class, *args):
78
    if do_raise:
79
        raise exc_class(*args)
80
81
82
def _name(cls: typing.Type[_T], function) -> str:
    """Return the ``__name__`` of ``function`` after resolving it as a descriptor on ``cls``.

    Resolving through ``__get__`` unwraps wrappers such as ``staticmethod``.
    """
    resolved = function.__get__(None, cls)
    return resolved.__name__
85
86
87
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Return the name of the first function ``cls`` already customizes, else None.

    For each function, the same-named attribute on ``cls`` is looked up.
    That attribute is free to overwrite in any of these cases:

    - it is missing;
    - it matches what ``object`` or ``Product`` provide under that name;
    - it already is the function itself.

    Anything else is treated as a user-provided definition, and its name is
    returned.
    """
    for candidate in functions:
        attr_name = _name(cls, candidate)
        current = getattr(cls, attr_name, None)
        replaceable = (
            getattr(object, attr_name, None),
            getattr(Product, attr_name, None),
            None,
            candidate,
        )
        if current in replaceable:
            continue
        return attr_name
    return None
99
100
101
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install every function on ``cls`` under its own name, all-or-nothing.

    The whole batch is checked first. If any target attribute is already
    customized on ``cls``, nothing is written and that attribute's name is
    returned. Otherwise all functions are set and None is returned.
    """
    conflict = _cant_set_new_functions(cls, *functions)
    if not conflict:
        for function in functions:
            setattr(cls, _name(cls, function), function)
        return None
    return conflict
113
114
115
# Key type variable for the dictionary accepted by ``_nillable_write``.
_K = typing.TypeVar("_K")
116
# Value type variable for the dictionary accepted by ``_nillable_write``.
_V = typing.TypeVar("_V")
117
118
119
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
    """Store ``value`` under ``key``, or remove ``key`` when ``value`` is None.

    Removing a key that is not present is silently ignored.
    """
    if value is not None:
        dct[key] = value
        return
    dct.pop(key, typing.cast(_V, None))
124
125
126
def _add_methods(cls: typing.Type[_T], do_set, *methods):
    """Optionally install ``methods`` on ``cls`` and report whether they went in.

    Returns False without touching ``cls`` when ``do_set`` is falsy.
    Otherwise returns True only if ``_set_new_functions`` installed every
    method, i.e. none of them was already customized on ``cls``.
    """
    if not do_set:
        return False
    conflict = _set_new_functions(cls, *methods)
    return not conflict
131
132
133
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Install a ``__new__`` on ``_cls`` that only admits the given constructor classes.

    If the body of ``_cls`` defined its own ``__new__`` (found in
    ``_cls.__dict__``), that one is delegated to. Otherwise the parent's
    ``__new__`` is called with ``(cls, args)``. Any class not in
    ``subclasses`` gets a bare TypeError.
    """

    def fallback(cls, args):
        return super(_cls, cls).__new__(cls, args)

    inner_new = _cls.__dict__.get("__new__", staticmethod(fallback))

    def __new__(cls, args):
        if cls not in subclasses:
            raise TypeError
        return inner_new.__get__(None, cls)(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
144
145
146
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a forwarding ``__new__`` whose signature lists the fields.

    The installed ``__new__`` passes all arguments to the parent's
    ``__new__``. Its ``__signature__`` is rewritten for introspection:

    - a positional-only ``cls``;
    - one positional-or-keyword parameter per entry of ``annotations``,
      carrying that annotation;
    - the default from ``defaults``, where one exists.
    """

    def __new__(*args, **kwargs):
        cls, *rest = args
        return super(_cls, cls).__new__(cls, *rest, **kwargs)

    params = [inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
    for field, annotation in annotations.items():
        params.append(
            inspect.Parameter(
                field,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=annotation,
                default=defaults.get(field, inspect.Parameter.empty),
            )
        )
    __new__.__signature__ = inspect.signature(__new__).replace(parameters=params)
    _cls.__new__ = __new__
168
169
170
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Lazily yield ``(owner, name, annotation)`` across the MRO of ``cls``.

    Classes are visited from the most basic (``object``) to ``cls`` itself.
    Only the ``__annotations__`` stored in each class's own ``__dict__`` are
    read, so each declaration is reported by the class that made it.
    """
    for owner in cls.__mro__[::-1]:
        own = vars(owner).get("__annotations__", {})
        yield from ((owner, name, hint) for name, hint in own.items())
176
177
178
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map each annotated name on ``cls`` (and its bases) to ``get_args`` of its annotation.

    Each annotation is resolved with ``get_args`` against the globals of the
    module that declared it. Later, more-derived declarations overwrite
    earlier ones. When ``get_args`` yields None for a name, that name is
    removed from the result.
    """
    result: typing.Dict[str, typing.Tuple] = {}
    for owner, name, hint in _all_annotations(cls):
        namespace = vars(sys.modules[owner.__module__])
        _nillable_write(result, name, get_args(hint, namespace))
    return result
185
186
187
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Collect the field annotations of ``cls`` and its bases, in declaration order.

    Derived declarations override inherited ones. An annotation that is
    ``None`` or the string ``"None"`` removes the field.
    """
    fields: typing.Dict[str, typing.Any] = {}
    for _owner, name, hint in _all_annotations(cls):
        _nillable_write(fields, name, None if hint == "None" else hint)
    return fields
196
197
198
def _conditional_update(obj, attrs_and_values):
199
    for key, value in attrs_and_values.items():
200
        if value is not None:
201
            setattr(obj, key, value)
202
203
204
def _ordering_options_are_valid(*, eq, order):
205
    if order and not eq:
206
        raise ValueError("eq must be true if order is true")
207
208
209
class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    If repr is true, a __repr__() method is added to the class.
    If eq is true (the default), __eq__() and __ne__() are added, together
    with a matching __hash__(), unless the class already defines __eq__() or
    __ne__().
    If order is true, rich comparison dunder methods are added. This requires
    eq to be true and the equality methods to have actually been added.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        # Only the generated constructor classes (ADTConstructor subclasses)
        # may be instantiated; the Sum class itself may not.
        _conditional_raise(not issubclass(cls, ADTConstructor), TypeError)
        return super(Sum, cls).__new__(cls, *args, **kwargs)

    # The ``repr``/``eq`` keyword names (which trip pylint) are kept for
    # consistency with stdlib modules, presumably ``dataclasses``.
    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        super().__init_subclass__(**kwargs)
        # The constructor classes created below are themselves subclasses;
        # they skip all of this processing.
        if issubclass(cls, ADTConstructor):
            return
        _ordering_options_are_valid(eq=eq, order=order)

        # Filled in by make_constructor with one class per Ctor annotation,
        # in annotation order.
        subclass_order: typing.List[typing.Type[_T]] = []

        for name, args in _sum_args_from_annotations(cls).items():
            make_constructor(cls, name, args, subclass_order)

        # Record the declaration order of the constructors for this Sum.
        SUBCLASS_ORDER[cls] = tuple(subclass_order)

        # Replace the hook so further subclassing goes through the prewritten
        # version. NOTE(review): presumably this is what makes the class
        # non-subclassable — confirm in _prewritten_methods.
        cls.__init_subclass__ = PrewrittenSumMethods.__init_subclass__  # type: ignore

        # Restrict instantiation to exactly the constructor classes.
        _sum_new(cls, frozenset(subclass_order))

        # These installs are best-effort: a name the class already customizes
        # is left alone, and the result is deliberately ignored.
        _set_new_functions(
            cls, PrewrittenSumMethods.__setattr__, PrewrittenSumMethods.__delattr__
        )
        _set_new_functions(cls, PrewrittenSumMethods.__bool__)

        _add_methods(cls, repr, PrewrittenSumMethods.__repr__)

        # True only when eq was requested and neither __eq__ nor __ne__ was
        # already customized on the class.
        equality_methods_were_set = _add_methods(
            cls, eq, PrewrittenSumMethods.__eq__, PrewrittenSumMethods.__ne__
        )

        # The methods were attached after class creation, so Python did not
        # adjust __hash__ automatically; pair it with the new __eq__ here.
        if equality_methods_were_set:
            cls.__hash__ = PrewrittenSumMethods.__hash__

        if order:

            # Ordering relies on the generated equality methods.
            _conditional_raise(
                not equality_methods_were_set,
                ValueError,
                "Can't add ordering methods if equality methods are provided.",
            )
            # All four are installed, or none (the name of the first
            # conflicting attribute is returned instead).
            collision = _set_new_functions(
                cls,
                PrewrittenSumMethods.__lt__,
                PrewrittenSumMethods.__le__,
                PrewrittenSumMethods.__gt__,
                PrewrittenSumMethods.__ge__,
            )
            _conditional_raise(
                collision,
                TypeError,
                "Cannot overwrite attribute {collision} in class "
                "{name}. Consider using functools.total_ordering".format(
                    collision=collision, name=cls.__name__
                ),
            )
296
297
298
class Product(ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If eq is true, __eq__(), __ne__() and __hash__() are added, unless the
    class already defines __eq__() or __ne__().
    If order is true, rich comparison dunder methods are added. This requires
    the equality methods to have been added.
    Options not passed to a subclass are inherited from its parent.

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded.
    Fields may have default values, and can be set to inspect.Parameter.empty
    to indicate "no default". Fields with defaults must come after all fields
    without defaults, otherwise TypeError is raised.

    Instances are built from positional arguments (in field order) followed
    by keyword arguments. Any of the following raises TypeError:

    - more positional arguments than fields;
    - an unknown keyword;
    - the same field passed both positionally and by keyword.

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        _conditional_raise(cls is Product, TypeError)
        # Similar to https://github.com/PyCQA/pylint/issues/1802
        fields = cls.__fields  # pylint: disable=protected-access
        if len(args) > len(fields):
            # zip() below would silently drop surplus positional arguments,
            # so reject them the way a normal call would.
            raise TypeError(
                "{name}() takes {expected} positional arguments but {given} "
                "were given".format(
                    name=cls.__name__, expected=len(fields), given=len(args)
                )
            )
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        fields_iter = iter(fields)
        # ``args`` comes first in zip() on purpose. When it runs out, no field
        # has been pulled from ``fields_iter``, so the keyword loop below
        # still sees every remaining field.
        for arg, field in zip(args, fields_iter):
            values[field] = arg
        for field in fields_iter:
            if field in values and field not in kwargs:
                # Keep the default.
                continue
            # NOTE: a missing required field currently surfaces as KeyError.
            values[field] = kwargs.pop(field)
        # Leftovers are unknown keywords, or fields that were also passed
        # positionally.
        _conditional_raise(kwargs, TypeError, kwargs)
        return super(Product, cls).__new__(
            cls,
            [
                values[field]
                for field in fields
            ],
        )

    # Per-class option flags, inherited by subclasses unless overridden by the
    # class keywords handled in __init_subclass__.
    __repr = True
    __eq = True
    __order = False
    # Set by __init_subclass__: whether the generated __eq__/__ne__ (and
    # __hash__) are in effect for the class.
    __eq_succeeded = None

    # The ``repr``/``eq`` keyword names (which trip pylint) are kept for
    # consistency with stdlib modules, presumably ``dataclasses``.
    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        super().__init_subclass__(**kwargs)

        # Gather the options under their name-mangled attribute names
        # (``_Product__repr`` etc.). None means "inherit from the parent".
        overrides = types.SimpleNamespace()
        # This is really gross, but it seems to work, and it reduces the
        # cyclomatic complexity, so it must be good!
        overrides.__repr = repr  # pylint: disable=protected-access
        overrides.__eq = eq  # pylint: disable=protected-access
        overrides.__order = order  # pylint: disable=protected-access

        _conditional_update(cls, vars(overrides))

        _ordering_options_are_valid(eq=cls.__eq, order=cls.__order)

        cls.__annotations = _product_args_from_annotations(cls)
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}

        # Walk the fields from last to first, collecting class-level defaults
        # until the first field without one.
        cls.__defaults = {}
        field_names = iter(reversed(tuple(cls.__annotations)))
        for field in field_names:
            default = getattr(cls, field, inspect.Parameter.empty)
            if default is inspect.Parameter.empty:
                break
            cls.__defaults[field] = default
        # Any earlier field that still has a default would precede a required
        # field.
        _conditional_raise(
            any(
                getattr(cls, field, inspect.Parameter.empty)
                is not inspect.Parameter.empty
                for field in field_names
            ),
            TypeError,
        )

        _product_new(cls, cls.__annotations, cls.__defaults)

        cls.__eq_succeeded = False
        if cls.__eq:
            cls.__eq_succeeded = not _cant_set_new_functions(
                cls, PrewrittenProductMethods.__eq__, PrewrittenProductMethods.__ne__
            )

        # Check the effective (possibly inherited) option, like the checks
        # above. The comparison properties below consult ``self.__order``, so
        # a subclass inheriting order=True must be validated as well.
        if cls.__order:

            _conditional_raise(
                not cls.__eq_succeeded,
                ValueError,
                "Can't add ordering methods if equality methods are provided.",
            )
            collision = _cant_set_new_functions(
                cls,
                PrewrittenProductMethods.__lt__,
                PrewrittenProductMethods.__le__,
                PrewrittenProductMethods.__gt__,
                PrewrittenProductMethods.__ge__,
            )
            _conditional_raise(
                collision,
                TypeError,
                "Cannot overwrite attribute {collision} in class "
                "{name}. Consider using functools.total_ordering".format(
                    collision=collision, name=cls.__name__
                ),
            )

    def __dir__(self):
        # Advertise the field names alongside the regular attributes.
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        # Normal lookup wins; only on AttributeError is a field name mapped to
        # its tuple slot.
        try:
            return super().__getattribute__(name)
        except AttributeError:
            index = self.__fields.get(name)
            if index is None:
                raise
            return tuple.__getitem__(self, index)

    __setattr__ = PrewrittenProductMethods.__setattr__
    __delattr__ = PrewrittenProductMethods.__delattr__
    __bool__ = PrewrittenProductMethods.__bool__

    # The dunder methods below are properties that return a bound method,
    # which the interpreter then calls. This lets the per-class flags choose
    # between the prewritten implementation and tuple's.

    @property
    def __repr__(self):
        if self.__repr:
            return PrewrittenProductMethods.__repr__.__get__(self, type(self))
        return super().__repr__

    @property
    def __hash__(self):
        if self.__eq_succeeded:
            return PrewrittenProductMethods.__hash__.__get__(self, type(self))
        return super().__hash__

    @property
    def __eq__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__eq__.__get__(self, type(self))
        return super().__eq__

    @property
    def __ne__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__ne__.__get__(self, type(self))
        return super().__ne__

    @property
    def __lt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__lt__.__get__(self, type(self))
        return super().__lt__

    @property
    def __le__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__le__.__get__(self, type(self))
        return super().__le__

    @property
    def __gt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__gt__.__get__(self, type(self))
        return super().__gt__

    @property
    def __ge__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__ge__.__get__(self, type(self))
        return super().__ge__
494
495
496
# Public names exported by ``from structured_data.adt import *``.
__all__ = ["Ctor", "Product", "Sum"]
497