Passed
Push — master ( c9efce...20fd99 )
by Max
01:12
created

structured_data.adt._values_non_empty()   A

Complexity

Conditions 3

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 6
nop 2
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from . import _adt_constructor
49
from . import _ctor
50
from . import _prewritten_methods
51
52
# Generic placeholder used in this module's ``typing.Type[_T]`` annotations and
# as the parameter of ``ConcreteCtor``.
_T = typing.TypeVar("_T")
53
54
55
# Static type checkers see the placeholder classes below; at runtime the real
# ``Ctor`` is imported from the private ``_ctor`` module instead.
if typing.TYPE_CHECKING:  # pragma: nocover

    class Ctor:
        """Dummy class for type-checking purposes."""

    # NOTE(review): the docstring below mentions an ``adt`` decorator that is
    # not visible in this module -- confirm whether it should refer to ``Sum``.
    class ConcreteCtor(typing.Generic[_T]):
        """Wrapper class for type-checking purposes.

        The type parameter should be a Tuple type of fixed size.
        Classes containing this annotation (meaning they haven't been
        processed by the ``adt`` decorator) should not be instantiated.
        """


else:
    from ._ctor import Ctor
71
72
73
# This is fine.
74
def _name(cls: typing.Type[_T], function) -> str:
75
    """Return the name of a function accessed through a descriptor."""
76
    return function.__get__(None, cls).__name__
77
78
79
# This is mostly fine, though the list of classes is somewhat ad-hoc, to say
80
# the least.
81
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Return the first name among ``functions`` that is already taken on ``cls``.

    A name is considered free when the attribute currently found on ``cls``
    equals what ``object`` or ``Product`` expose under that name, equals
    ``None`` (which also covers "missing"), or equals the candidate itself.
    Returns ``None`` when every name is free. Nothing is modified.
    """
    for candidate in functions:
        attribute = _name(cls, candidate)
        current = getattr(cls, attribute, None)
        acceptable = (
            getattr(object, attribute, None),
            getattr(Product, attribute, None),
            None,
            candidate,
        )
        if current in acceptable:
            continue
        return attribute
    return None
93
94
95
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install ``functions`` on ``cls`` on an all-or-nothing basis.

    The collision check runs over every function first; if any name is
    already defined, that name is returned and ``cls`` is left untouched.
    Otherwise each function is assigned under its own name and ``None`` is
    returned.
    """
    collision = _cant_set_new_functions(cls, *functions)
    if not collision:
        for function in functions:
            setattr(cls, _name(cls, function), function)
        return None
    return collision
107
108
109
_K = typing.TypeVar("_K")
110
_V = typing.TypeVar("_V")
111
112
113
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
114
    if value is None:
115
        dct.pop(key, typing.cast(_V, None))
116
    else:
117
        dct[key] = value
118
119
120
def _sum_new(_cls: typing.Type[_T], subclasses):
121
    def base(cls, args):
122
        return super(_cls, cls).__new__(cls, args)
123
124
    new = _cls.__dict__.get("__new__", staticmethod(base))
125
126
    def __new__(cls, args):
127
        if cls not in subclasses:
128
            raise TypeError
129
        return new.__get__(None, cls)(cls, args)
130
131
    _cls.__new__ = staticmethod(__new__)  # type: ignore
132
133
134
def _product_new(
135
    _cls: typing.Type[_T],
136
    annotations: typing.Dict[str, typing.Any],
137
    defaults: typing.Dict[str, typing.Any],
138
):
139
    def __new__(*args, **kwargs):
140
        cls, *args = args
141
        return super(_cls, cls).__new__(cls, *args, **kwargs)
142
143
    __new__.__signature__ = inspect.signature(__new__).replace(
144
        parameters=[inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
145
        + [
146
            inspect.Parameter(
147
                field,
148
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
149
                annotation=annotation,
150
                default=defaults.get(field, inspect.Parameter.empty),
151
            )
152
            for (field, annotation) in annotations.items()
153
        ]
154
    )
155
    _cls.__new__ = __new__
156
157
158
def _all_annotations(
159
    cls: typing.Type[_T]
160
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
161
    for superclass in reversed(cls.__mro__):
162
        for key, value in vars(superclass).get("__annotations__", {}).items():
163
            yield (superclass, key, value)
164
165
166
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map each constructor name on ``cls`` to the result of ``_ctor.get_args``.

    Every annotation found along the MRO is passed to ``_ctor.get_args``
    together with the globals of the module that defines the declaring
    class. A ``None`` result removes the name from the mapping, so a later
    class can cancel an entry made by an earlier one.
    """
    constructors: typing.Dict[str, typing.Tuple] = {}
    for owner, name, annotation in _all_annotations(cls):
        module_globals = vars(sys.modules[owner.__module__])
        constructor_args = _ctor.get_args(annotation, module_globals)
        _nillable_write(constructors, name, constructor_args)
    return constructors
173
174
175
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Collect the field annotations of ``cls``, honouring overrides.

    Annotations are gathered along the MRO, base classes first. An
    annotation equal to the string ``"None"``, or one that
    ``_ctor.annotation_is_classvar`` reports as a class variable, removes the
    name from the result instead of adding it.

    The class-variable check is given the globals of the module that defines
    the class *declaring* the annotation, not those of ``cls``. An inherited
    annotation is therefore interpreted in the namespace it was written in,
    which matches ``_sum_args_from_annotations``.
    """
    args: typing.Dict[str, typing.Any] = {}
    for superclass, key, value in _all_annotations(cls):
        if value == "None" or _ctor.annotation_is_classvar(
            value, vars(sys.modules[superclass.__module__])
        ):
            value = None
        _nillable_write(args, key, value)
    return args
186
187
188
def _ordering_options_are_valid(*, eq, order):
189
    if order and not eq:
190
        raise ValueError("eq must be true if order is true")
191
192
193
def _set_ordering(*, can_set, setter, cls, source):
194
    if not can_set:
195
        raise ValueError("Can't add ordering methods if equality methods are provided.")
196
    collision = setter(cls, source.__lt__, source.__le__, source.__gt__, source.__ge__)
197
    if collision:
198
        raise TypeError(
199
            "Cannot overwrite attribute {collision} in class "
200
            "{name}. Consider using functools.total_ordering".format(
201
                collision=collision, name=cls.__name__
202
            )
203
        )
204
205
206
def _values_non_empty(cls, field_names):
207
    for field in field_names:
208
        default = getattr(cls, field, inspect.Parameter.empty)
209
        if default is inspect.Parameter.empty:
210
            return
211
        yield (field, default)
212
213
214
def _values_until_non_empty(cls, field_names):
215
    for field in field_names:
216
        default = getattr(cls, field, inspect.Parameter.empty)
217
        if default is not inspect.Parameter.empty:
218
            yield
219
220
221
def _extract_defaults(*, cls, annotations):
222
    field_names = iter(reversed(tuple(annotations)))
223
    defaults = {
224
        field: default for (field, default) in _values_non_empty(cls, field_names)
225
    }
226
    for _ in _values_until_non_empty(cls, field_names):
227
        raise TypeError
228
    return defaults
229
230
231
def _unpack_args(*, args, kwargs, fields, values):
232
    fields_iter = iter(fields)
233
    values.update({field: arg for (arg, field) in zip(args, fields_iter)})
234
    for field in fields_iter:
235
        if field in values and field not in kwargs:
236
            continue
237
        values[field] = kwargs.pop(field)
238
    if kwargs:
239
        raise TypeError(kwargs)
240
241
242
class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    If repr is true, a __repr__() method is added to the class.
    If eq is true, __eq__() and __ne__() methods are added to the class, and
    __hash__() is set as well when both could be installed.
    If order is true, rich comparison dunder methods are added.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        """Create an instance; only ``ADTConstructor`` subclasses are allowed.

        The class is taken from ``*args`` rather than named as a parameter.
        Raises ``TypeError`` when ``cls`` does not derive from
        ``_adt_constructor.ADTConstructor``.
        """
        cls, *args = args
        if not issubclass(cls, _adt_constructor.ADTConstructor):
            raise TypeError
        return super(Sum, cls).__new__(cls, *args, **kwargs)

    # The ``repr`` and ``eq`` parameter names below trip pylint
    # (redefined-builtin, invalid-name); presumably both are kept for
    # consistency with modules defined in the stdlib.
    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        """Turn the annotated subclass ``cls`` into a sum type.

        Raises ``ValueError`` when ``order`` is requested without ``eq`` or
        when the equality methods could not be installed, and ``TypeError``
        when an ordering method is already defined on ``cls``.
        """
        super().__init_subclass__(**kwargs)
        # Subclasses that are themselves ADTConstructor classes get no
        # processing of their own.
        if issubclass(cls, _adt_constructor.ADTConstructor):
            return
        _ordering_options_are_valid(eq=eq, order=order)

        # Passed to make_constructor for every constructor, then frozen below.
        subclass_order: typing.List[typing.Type[_T]] = []

        for name, args in _sum_args_from_annotations(cls).items():
            _adt_constructor.make_constructor(cls, name, args, subclass_order)

        _prewritten_methods.SUBCLASS_ORDER[cls] = tuple(subclass_order)

        # From here on, subclassing ``cls`` runs the prewritten hook instead
        # of this one.
        cls.__init_subclass__ = (
            _prewritten_methods.PrewrittenSumMethods.__init_subclass__
        )  # type: ignore

        # Restrict instantiation to the classes collected in subclass_order.
        _sum_new(cls, frozenset(subclass_order))

        # The results of these calls are ignored: if ``cls`` already defines
        # one of the names, the existing definition is silently kept.
        _set_new_functions(
            cls,
            _prewritten_methods.PrewrittenSumMethods.__setattr__,
            _prewritten_methods.PrewrittenSumMethods.__delattr__,
        )
        _set_new_functions(cls, _prewritten_methods.PrewrittenSumMethods.__bool__)

        if repr:
            _set_new_functions(cls, _prewritten_methods.PrewrittenSumMethods.__repr__)

        equality_methods_were_set = False
        if eq:
            # _set_new_functions returns a name on collision, None on success.
            equality_methods_were_set = not _set_new_functions(
                cls,
                _prewritten_methods.PrewrittenSumMethods.__eq__,
                _prewritten_methods.PrewrittenSumMethods.__ne__,
            )

        # Only pair the prewritten __hash__ with the prewritten __eq__.
        if equality_methods_were_set:
            cls.__hash__ = _prewritten_methods.PrewrittenSumMethods.__hash__

        if order:
            _set_ordering(
                can_set=equality_methods_were_set,
                setter=_set_new_functions,
                cls=cls,
                source=_prewritten_methods.PrewrittenSumMethods,
            )
324
325
326
class Product(_adt_constructor.ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If eq is true, the prewritten __eq__(), __ne__() and __hash__() are used,
    provided the class does not already define its own equality methods.
    If order is true, rich comparison dunder methods are added.
    An option that is not given to a subclass is inherited from its parent.

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded.
    Fields may have default values, and can be set to
    inspect.Parameter.empty to indicate "no default".

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        """Build the tuple of field values from the call arguments.

        Raises ``TypeError`` when ``Product`` itself is instantiated;
        argument errors are raised by ``_unpack_args``.
        """
        cls, *args = args
        if cls is Product:
            raise TypeError
        # Probably a result of not having positional-only args.
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        _unpack_args(
            args=args,
            kwargs=kwargs,
            fields=cls.__fields,  # pylint: disable=protected-access
            values=values,
        )
        return super(Product, cls).__new__(
            cls,
            [
                values[field]
                for field in cls.__fields  # pylint: disable=protected-access
            ],
        )

    # Option defaults. The names are mangled to ``_Product__*``, so a subclass
    # that does not pass an option sees the value of its nearest ancestor.
    __repr = True
    __eq = True
    __order = False
    __eq_succeeded = None

    # The ``repr`` and ``eq`` parameter names below trip pylint
    # (redefined-builtin, invalid-name); presumably both are kept for
    # consistency with modules defined in the stdlib.
    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        """Record the options, fields and defaults of the new subclass.

        Raises ``ValueError`` when ordering is enabled without equality or
        when the class defines its own equality methods, and ``TypeError``
        when defaults are not trailing or an ordering method is already
        defined.
        """
        super().__init_subclass__(**kwargs)

        # ``None`` means "not specified": keep the inherited setting.
        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order

        _ordering_options_are_valid(eq=cls.__eq, order=cls.__order)

        cls.__annotations = _product_args_from_annotations(cls)
        # Field name -> position in the underlying tuple.
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}

        cls.__defaults = _extract_defaults(cls=cls, annotations=cls.__annotations)

        _product_new(cls, cls.__annotations, cls.__defaults)

        # Nothing is assigned for equality or ordering here: the properties
        # below dispatch on these flags, so a collision check is enough.
        cls.__eq_succeeded = False
        if cls.__eq:
            cls.__eq_succeeded = not _cant_set_new_functions(
                cls,
                _prewritten_methods.PrewrittenProductMethods.__eq__,
                _prewritten_methods.PrewrittenProductMethods.__ne__,
            )

        if cls.__order:
            _set_ordering(
                can_set=cls.__eq_succeeded,
                setter=_cant_set_new_functions,
                cls=cls,
                source=_prewritten_methods.PrewrittenProductMethods,
            )

    def __dir__(self):
        """List the regular attributes followed by the field names."""
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        """Return the stored value for a field name, else do a normal lookup.

        Fields are checked first. A field default lives on the class under
        the field's own name (and ``tuple`` contributes names such as
        ``count`` and ``index``), so running the ordinary lookup first would
        return that class attribute instead of the value held by this
        instance.
        """
        # Read the mapping from the class so that this lookup does not
        # re-enter __getattribute__.
        index = type(self).__fields.get(name)  # pylint: disable=protected-access
        if index is not None:
            return tuple.__getitem__(self, index)
        return super().__getattribute__(name)

    __setattr__ = _prewritten_methods.PrewrittenProductMethods.__setattr__
    __delattr__ = _prewritten_methods.PrewrittenProductMethods.__delattr__
    __bool__ = _prewritten_methods.PrewrittenProductMethods.__bool__

    # Each dunder below is a property, so the implementation is selected per
    # class at lookup time: the prewritten method when the matching option is
    # active, the inherited one otherwise.

    @property
    def __repr__(self):
        """Prewritten ``__repr__`` when the ``repr`` option is on."""
        if self.__repr:
            return _prewritten_methods.PrewrittenProductMethods.__repr__.__get__(
                self, type(self)
            )
        return super().__repr__

    @property
    def __hash__(self):
        """Prewritten ``__hash__`` when the prewritten equality is in use."""
        if self.__eq_succeeded:
            return _prewritten_methods.PrewrittenProductMethods.__hash__.__get__(
                self, type(self)
            )
        return super().__hash__

    @property
    def __eq__(self):  # pylint: disable=unexpected-special-method-signature
        """Prewritten ``__eq__`` when the prewritten equality is in use."""
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__eq__.__get__(
                self, type(self)
            )
        return super().__eq__

    @property
    def __ne__(self):  # pylint: disable=unexpected-special-method-signature
        """Prewritten ``__ne__`` when the prewritten equality is in use."""
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__ne__.__get__(
                self, type(self)
            )
        return super().__ne__

    @property
    def __lt__(self):  # pylint: disable=unexpected-special-method-signature
        """Prewritten ``__lt__`` when the ``order`` option is on."""
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__lt__.__get__(
                self, type(self)
            )
        return super().__lt__

    @property
    def __le__(self):  # pylint: disable=unexpected-special-method-signature
        """Prewritten ``__le__`` when the ``order`` option is on."""
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__le__.__get__(
                self, type(self)
            )
        return super().__le__

    @property
    def __gt__(self):  # pylint: disable=unexpected-special-method-signature
        """Prewritten ``__gt__`` when the ``order`` option is on."""
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__gt__.__get__(
                self, type(self)
            )
        return super().__gt__

    @property
    def __ge__(self):  # pylint: disable=unexpected-special-method-signature
        """Prewritten ``__ge__`` when the ``order`` option is on."""
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__ge__.__get__(
                self, type(self)
            )
        return super().__ge__
508
509
510
# The three public members described in the module docstring.
__all__ = ["Ctor", "Product", "Sum"]
511