Passed
Push — master ( 7f6082...f616a2 )
by Max
51s
created

structured_data.adt._conditional_raise()   A

Complexity

Conditions 2

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 3
nop 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""Base classes for defining abstract data types.

This module provides three public members, which are used together.

Given a structure, possibly a choice of different structures, that you'd like
to associate with a type:

- First, create a class that subclasses the Sum class.
- Then, for each possible structure, add an attribute annotation to the class
  with the desired name of the constructor, and a type of ``Ctor``, with the
  types within the constructor as arguments.

To look inside an ADT instance, use the functions from the
:mod:`structured_data.match` module.

Putting it together:

>>> from structured_data import match
>>> class Example(Sum):
...     FirstConstructor: Ctor[int, str]
...     SecondConstructor: Ctor[bytes]
...     ThirdConstructor: Ctor
...     def __iter__(self):
...         matchable = match.Matchable(self)
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
...             count, string = matchable[match.pat.count, match.pat.string]
...             for _ in range(count):
...                 yield string
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
...             bytes_ = matchable[match.pat.bytes]
...             for byte in bytes_:
...                 yield chr(byte)
...         elif matchable(Example.ThirdConstructor()):
...             yield "Third"
...             yield "Constructor"
>>> list(Example.FirstConstructor(5, "abc"))
['abc', 'abc', 'abc', 'abc', 'abc']
>>> list(Example.SecondConstructor(b"abc"))
['a', 'b', 'c']
>>> list(Example.ThirdConstructor())
['Third', 'Constructor']
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from ._adt_constructor import ADTConstructor
49
from ._adt_constructor import make_constructor
50
from ._ctor import get_args
51
from ._prewritten_methods import SUBCLASS_ORDER
52
from ._prewritten_methods import PrewrittenProductMethods
53
from ._prewritten_methods import PrewrittenSumMethods
54
55
# Generic placeholder type used in the annotations throughout this module.
_T = typing.TypeVar("_T")


# Static type checkers see the placeholder classes below; at runtime the real
# ``Ctor`` is imported from the private ``_ctor`` module instead.
if typing.TYPE_CHECKING:  # pragma: nocover

    class Ctor:
        """Dummy class for type-checking purposes."""

    class ConcreteCtor(typing.Generic[_T]):
        """Wrapper class for type-checking purposes.

        The type parameter should be a Tuple type of fixed size.
        Classes containing this annotation (meaning they haven't been
        processed by the ``adt`` decorator) should not be instantiated.
        """


else:
    from ._ctor import Ctor
74
75
76
def _conditional_raise(do_raise, exc_class, *args):
77
    if do_raise:
78
        raise exc_class(*args)
79
80
81
def _name(cls: typing.Type[_T], function) -> str:
    """Return the name of a function accessed through a descriptor.

    ``function`` is resolved via its ``__get__`` as if it had been looked up
    on ``cls`` itself (no instance), and the ``__name__`` of the result is
    returned.
    """
    return function.__get__(None, cls).__name__
84
85
86
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Return the first function name that would clobber an attribute of cls.

    For each function, the attribute currently found on ``cls`` under that
    function's name is acceptable when it is ``None``/missing, equals the
    attribute of the same name on ``object`` or on ``Product``, or already is
    the function itself. The name of the first function whose current
    attribute is anything else is returned.

    Returns:
        The offending attribute name, or ``None`` if every function could be
        set without overwriting a custom definition.
    """
    for function in functions:
        name = _name(cls, function)
        existing = getattr(cls, name, None)
        if existing not in (
            getattr(object, name, None),
            getattr(Product, name, None),
            None,
            function,
        ):
            return name
    return None
98
99
100
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Attempt to set the attributes corresponding to the functions on cls.

    If any attributes are already defined, fail *before* setting any, and
    return the already-defined name.

    Returns:
        The name that blocked the operation, or ``None`` once every function
        has been assigned to ``cls`` under its own name.
    """
    cant_set = _cant_set_new_functions(cls, *functions)
    if cant_set:
        return cant_set
    for function in functions:
        setattr(cls, _name(cls, function), function)
    return None
112
113
114
_K = typing.TypeVar("_K")
115
_V = typing.TypeVar("_V")
116
117
118
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
119
    if value is None:
120
        dct.pop(key, typing.cast(_V, None))
121
    else:
122
        dct[key] = value
123
124
125
def _add_methods(cls: typing.Type[_T], do_set, *methods) -> bool:
    """Set ``methods`` on ``cls`` when requested and report whether it worked.

    Returns:
        ``True`` only if ``do_set`` is truthy and ``_set_new_functions``
        reported no blocking attribute; ``False`` otherwise (including when
        nothing was attempted).
    """
    return bool(do_set) and not _set_new_functions(cls, *methods)
130
131
132
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Replace ``_cls.__new__`` with a version restricted to ``subclasses``.

    The installed ``__new__`` raises ``TypeError`` unless the class being
    instantiated is a member of ``subclasses``. Otherwise it delegates to the
    ``__new__`` found in ``_cls``'s own ``__dict__`` if there is one, and to
    the next ``__new__`` after ``_cls`` in the MRO if there is not.
    """

    def base(cls, args):
        return super(_cls, cls).__new__(cls, args)

    # Whatever is stored in the class dict is a descriptor (or is wrapped as
    # one here), so it is resolved through __get__ at call time below.
    new = _cls.__dict__.get("__new__", staticmethod(base))

    def __new__(cls, args):
        _conditional_raise(cls not in subclasses, TypeError)
        return new.__get__(None, cls)(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
143
144
145
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a ``__new__`` whose signature advertises its fields.

    The installed ``__new__`` simply forwards to the next ``__new__`` after
    ``_cls`` in the MRO. Its ``__signature__`` is rewritten to a
    positional-only ``cls`` followed by one positional-or-keyword parameter
    per entry of ``annotations`` (in iteration order), using the matching
    entry of ``defaults`` as the parameter default when there is one.
    """

    def __new__(*args, **kwargs):
        cls, *args = args
        return super(_cls, cls).__new__(cls, *args, **kwargs)

    field_parameters = [
        inspect.Parameter(
            field,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=annotation,
            default=defaults.get(field, inspect.Parameter.empty),
        )
        for (field, annotation) in annotations.items()
    ]
    __new__.__signature__ = inspect.signature(__new__).replace(
        parameters=[inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
        + field_parameters
    )
    _cls.__new__ = __new__
167
168
169
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Yield ``(owner, name, annotation)`` for every class in ``cls``'s MRO.

    Classes are visited from the end of the MRO towards ``cls`` itself, so
    annotations of base classes are produced before those of derived classes.
    Only annotations defined directly on each class are reported for it.
    """
    # inspect.get_annotations (Python 3.10+) is the supported way to read a
    # class's own annotations; reading the class __dict__ directly is kept as
    # the fallback for older interpreters.
    get_own = getattr(inspect, "get_annotations", None)
    for superclass in reversed(cls.__mro__):
        if get_own is not None:
            own = get_own(superclass)
        else:
            own = vars(superclass).get("__annotations__", {})
        for key, value in own.items():
            yield (superclass, key, value)
175
176
177
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map annotated names on ``cls`` to the result of ``get_args``.

    Every annotation in the MRO is passed to ``get_args`` together with the
    globals of the module that defines the owning class. A ``None`` result
    removes the name from the mapping; any other result replaces the entry,
    so annotations on derived classes override those of their bases.
    """
    args: typing.Dict[str, typing.Tuple] = {}
    for superclass, key, value in _all_annotations(cls):
        module_globals = vars(sys.modules[superclass.__module__])
        _nillable_write(args, key, get_args(value, module_globals))
    return args
184
185
186
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Map field names on ``cls`` to their annotations.

    Annotations are gathered across the MRO, with derived classes overriding
    their bases. An annotation that is ``None`` or equal to the string
    ``"None"`` removes the field from the mapping.
    """
    args: typing.Dict[str, typing.Any] = {}
    for _, key, value in _all_annotations(cls):
        if value == "None":
            # Treat the string spelling the same as a literal None.
            value = None
        _nillable_write(args, key, value)
    return args
195
196
197
class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    If repr is true, a __repr__() method is added to the class.
    If eq is true, __eq__() and __ne__() are added, along with __hash__(),
    unless the class already provides its own equality methods.
    If order is true, rich comparison dunder methods are added.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        # Only classes that are also ADTConstructor subclasses may be
        # instantiated.
        _conditional_raise(not issubclass(cls, ADTConstructor), TypeError)
        return super(Sum, cls).__new__(cls, *args, **kwargs)

    # The ``repr`` and ``eq`` parameter names (and the Pylint suppressions
    # they require) are kept for consistency with modules defined in the
    # stdlib.
    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        """Turn the Ctor annotations of a new subclass into constructors.

        Raises:
            ValueError: If ``order`` is true while ``eq`` is false, or if
                ``order`` is true but the equality methods could not be set.
            TypeError: If ``order`` is true and the class already defines one
                of the ordering methods.
        """
        super().__init_subclass__(**kwargs)
        if issubclass(cls, ADTConstructor):
            # Constructor classes are not themselves processed.
            return
        _conditional_raise(
            order and not eq, ValueError, "eq must be true if order is true"
        )

        subclass_order: typing.List[typing.Type[_T]] = []

        for name, args in _sum_args_from_annotations(cls).items():
            make_constructor(cls, name, args, subclass_order)

        SUBCLASS_ORDER[cls] = tuple(subclass_order)

        cls.__init_subclass__ = PrewrittenSumMethods.__init_subclass__  # type: ignore

        _sum_new(cls, frozenset(subclass_order))

        # These are best-effort: an existing custom definition is left alone.
        _set_new_functions(
            cls, PrewrittenSumMethods.__setattr__, PrewrittenSumMethods.__delattr__
        )
        _set_new_functions(cls, PrewrittenSumMethods.__bool__)

        _add_methods(cls, repr, PrewrittenSumMethods.__repr__)

        equality_methods_were_set = _add_methods(
            cls, eq, PrewrittenSumMethods.__eq__, PrewrittenSumMethods.__ne__
        )

        if equality_methods_were_set:
            cls.__hash__ = PrewrittenSumMethods.__hash__

        if order:

            _conditional_raise(
                not equality_methods_were_set,
                ValueError,
                "Can't add ordering methods if equality methods are provided.",
            )
            collision = _set_new_functions(
                cls,
                PrewrittenSumMethods.__lt__,
                PrewrittenSumMethods.__le__,
                PrewrittenSumMethods.__gt__,
                PrewrittenSumMethods.__ge__,
            )
            _conditional_raise(
                collision,
                TypeError,
                "Cannot overwrite attribute {collision} in class "
                "{name}. Consider using functools.total_ordering".format(
                    collision=collision, name=cls.__name__
                ),
            )
286
287
288
class Product(ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If order is true, rich comparison dunder methods are added.

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded.
    Fields may have default values, and can be set to
    inspect.Parameter.empty to indicate "no default".

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        """Build an instance from positional and/or keyword field values.

        Raises:
            TypeError: If ``Product`` itself is instantiated, if more
                positional arguments than fields are given, or if keyword
                arguments remain that were not consumed as fields.
            KeyError: If a field without a default is not supplied.
        """
        # The class is unpacked from *args rather than declared as a
        # parameter, presumably so that a field may itself be named ``cls``.
        cls, *args = args
        _conditional_raise(cls is Product, TypeError)
        # Similar to https://github.com/PyCQA/pylint/issues/1802
        fields = tuple(cls.__fields)  # pylint: disable=protected-access
        # Surplus positional arguments must be rejected explicitly: pairing
        # them with the fields would otherwise drop them without notice.
        _conditional_raise(
            len(args) > len(fields),
            TypeError,
            "{name}() takes at most {limit} positional arguments "
            "({given} given)".format(
                name=cls.__name__, limit=len(fields), given=len(args)
            ),
        )
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        values.update(zip(fields, args))
        for field in fields[len(args):]:
            if field in values and field not in kwargs:
                # Keep the default.
                continue
            values[field] = kwargs.pop(field)
        # Anything left names an unknown field or one already given
        # positionally.
        _conditional_raise(kwargs, TypeError, kwargs)
        return super(Product, cls).__new__(cls, [values[field] for field in fields])

    # Effective options; subclasses inherit them unless they override them
    # through the class keyword arguments.
    __repr = True
    __eq = True
    __order = False
    __eq_succeeded = None

    # The ``repr`` and ``eq`` parameter names (and the Pylint suppressions
    # they require) are kept for consistency with modules defined in the
    # stdlib.
    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        """Record the fields, defaults and options of a new subclass.

        Options left as ``None`` keep the value inherited from the parent
        class.

        Raises:
            ValueError: If ordering is in effect while equality is disabled,
                or while the class provides its own equality methods.
            TypeError: If a field with a default precedes a field without
                one, or if ordering is in effect and the class already
                defines one of the ordering methods.
        """
        super().__init_subclass__(**kwargs)

        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order

        _conditional_raise(
            cls.__order and not cls.__eq, ValueError, "eq must be true if order is true"
        )

        cls.__annotations = _product_args_from_annotations(cls)
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}

        # Walk the fields from last to first, collecting defaults until the
        # first field that has none ...
        cls.__defaults = {}
        field_names = iter(reversed(tuple(cls.__annotations)))
        for field in field_names:
            default = getattr(cls, field, inspect.Parameter.empty)
            if default is inspect.Parameter.empty:
                break
            cls.__defaults[field] = default
        # ... after which no earlier field may have a default.
        _conditional_raise(
            any(
                getattr(cls, field, inspect.Parameter.empty)
                is not inspect.Parameter.empty
                for field in field_names
            ),
            TypeError,
            "Fields without defaults cannot follow fields with defaults",
        )

        _product_new(cls, cls.__annotations, cls.__defaults)

        cls.__eq_succeeded = False
        if cls.__eq:
            cls.__eq_succeeded = not _cant_set_new_functions(
                cls, PrewrittenProductMethods.__eq__, PrewrittenProductMethods.__ne__
            )

        # Validate against the effective setting, so that ordering inherited
        # from a parent class is checked just like ordering requested here.
        if cls.__order:

            _conditional_raise(
                not cls.__eq_succeeded,
                ValueError,
                "Can't add ordering methods if equality methods are provided.",
            )
            collision = _cant_set_new_functions(
                cls,
                PrewrittenProductMethods.__lt__,
                PrewrittenProductMethods.__le__,
                PrewrittenProductMethods.__gt__,
                PrewrittenProductMethods.__ge__,
            )
            _conditional_raise(
                collision,
                TypeError,
                "Cannot overwrite attribute {collision} in class "
                "{name}. Consider using functools.total_ordering".format(
                    collision=collision, name=cls.__name__
                ),
            )

    def __dir__(self):
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        # Fields take precedence over ordinary lookup: a field's default is
        # stored as a class attribute of the same name and would otherwise
        # shadow the value held by the instance.
        index = super().__getattribute__("_Product__fields").get(name)
        if index is None:
            return super().__getattribute__(name)
        return tuple.__getitem__(self, index)

    __setattr__ = PrewrittenProductMethods.__setattr__
    __delattr__ = PrewrittenProductMethods.__delattr__
    __bool__ = PrewrittenProductMethods.__bool__

    # The special methods below are properties so that each instance picks
    # either the prewritten implementation or the inherited one, according
    # to the options in effect for its class.

    @property
    def __repr__(self):
        if self.__repr:
            return PrewrittenProductMethods.__repr__.__get__(self, type(self))
        return super().__repr__

    @property
    def __hash__(self):
        if self.__eq_succeeded:
            return PrewrittenProductMethods.__hash__.__get__(self, type(self))
        return super().__hash__

    @property
    def __eq__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__eq__.__get__(self, type(self))
        return super().__eq__

    @property
    def __ne__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__ne__.__get__(self, type(self))
        return super().__ne__

    @property
    def __lt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__lt__.__get__(self, type(self))
        return super().__lt__

    @property
    def __le__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__le__.__get__(self, type(self))
        return super().__le__

    @property
    def __gt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__gt__.__get__(self, type(self))
        return super().__gt__

    @property
    def __ge__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__ge__.__get__(self, type(self))
        return super().__ge__
484
485
486
# Names exported as the public API of this module.
__all__ = ["Ctor", "Product", "Sum"]
487