Passed
Push — master ( 3ca2e8...ce73ec )
by Max
53s
created

structured_data.adt._ProductMethod.__delete__()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class, that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from . import _adt_constructor
49
from . import _ctor
50
from . import _prewritten_methods
51
52
# Generic type variable shared by the helper signatures in this module.
_T = typing.TypeVar("_T")


if typing.TYPE_CHECKING:  # pragma: nocover

    class Ctor:
        """Placeholder that only static type checkers ever see.

        At runtime the real ``Ctor`` is imported from ``._ctor`` instead.
        """

    class ConcreteCtor(typing.Generic[_T]):
        """Generic marker that exists only for static type checking.

        Its type parameter is expected to be a fixed-size ``Tuple`` type.
        A class that still carries this annotation (that is, one not yet
        processed by the ``adt`` decorator) should not be instantiated.
        """

else:
    # Runtime branch: bind the real constructor marker.
    from ._ctor import Ctor
71
72
73
def _name(cls: typing.Type[_T], function) -> str:
    """Return the ``__name__`` of *function* as resolved through *cls*.

    *function* is first passed through the descriptor protocol
    (``__get__(None, cls)``), so descriptor wrappers resolve to the object
    they produce on class access before the name is read.
    """
    resolved = function.__get__(None, cls)
    return resolved.__name__
77
78
79
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Return the first attribute name that *cls* already customizes, if any.

    For each function, the attribute with the same name is looked up on
    *cls*. It is considered free to replace when it is missing, is the
    function itself, or equals what ``object`` or ``Product`` provide under
    that name. Any other value is treated as a deliberate definition and its
    name is returned. ``None`` means every name is free.
    """
    # NOTE: the "inherited, therefore replaceable" sources are an ad hoc
    # list (just ``object`` and ``Product``).
    for function in functions:
        attr_name = _name(cls, function)
        current = getattr(cls, attr_name, None)
        replaceable = (
            getattr(object, attr_name, None),
            getattr(Product, attr_name, None),
            None,
            function,
        )
        if current not in replaceable:
            return attr_name
    return None
93
94
95
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install every function on *cls* under its own name, all or nothing.

    Collisions are checked for all functions before any assignment. On
    failure *cls* is left untouched and the conflicting attribute name is
    returned; after a successful install the result is ``None``.
    """
    conflict = _cant_set_new_functions(cls, *functions)
    if not conflict:
        for function in functions:
            setattr(cls, _name(cls, function), function)
        return None
    return conflict
107
108
109
_K = typing.TypeVar("_K")
110
_V = typing.TypeVar("_V")
111
112
113
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
114
    if value is None:
115
        dct.pop(key, typing.cast(_V, None))
116
    else:
117
        dct[key] = value
118
119
120
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Replace ``_cls.__new__`` with a guard that only admits *subclasses*.

    Instantiating a class that is not in *subclasses* raises ``TypeError``.
    Admitted classes are delegated to the ``__new__`` that ``_cls`` defined
    in its own body, or, failing that, to the next ``__new__`` after ``_cls``
    in the MRO.
    """

    def fallback(cls, args):
        # Start the lookup after _cls so the guard below is not re-entered.
        return super(_cls, cls).__new__(cls, args)

    own_new = _cls.__dict__.get("__new__", staticmethod(fallback))

    def __new__(cls, args):
        if cls in subclasses:
            return own_new.__get__(None, cls)(cls, args)
        raise TypeError

    _cls.__new__ = staticmethod(__new__)  # type: ignore
132
133
134
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a pass-through ``__new__`` with an introspectable signature.

    The installed ``__new__`` forwards every argument to the next ``__new__``
    after ``_cls`` in the MRO. Its ``__signature__`` lists ``cls``
    (positional-only) followed by one positional-or-keyword parameter per
    field, in annotation order. Each parameter carries the field's annotation
    and, when the field appears in *defaults*, its default value. The
    signature is descriptive only: the wrapper itself accepts
    ``*args, **kwargs``.
    """

    def __new__(*args, **kwargs):
        cls, *rest = args
        return super(_cls, cls).__new__(cls, *rest, **kwargs)

    cls_param = inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)
    field_params = [
        inspect.Parameter(
            field,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=annotation,
            default=defaults.get(field, inspect.Parameter.empty),
        )
        for field, annotation in annotations.items()
    ]
    __new__.__signature__ = inspect.signature(__new__).replace(
        parameters=[cls_param, *field_params]
    )
    _cls.__new__ = __new__
156
157
158
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Yield ``(defining_class, name, annotation)`` for *cls* and its bases.

    Classes are visited from the root of the MRO down to *cls*, so a more
    derived entry for a name comes after the one it overrides. Only the
    annotations each class declares in its own ``__dict__`` are reported.
    """
    for klass in cls.__mro__[::-1]:
        own = vars(klass).get("__annotations__", {})
        for name, hint in own.items():
            yield klass, name, hint
164
165
166
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map each constructor name declared on *cls* or its bases to its args.

    Every annotation is handed to ``_ctor.get_args`` together with the
    globals of the module that defined the annotating class (presumably so
    string annotations resolve there). A ``None`` result removes the name, so
    a later non-constructor annotation hides an earlier one of the same name.
    """
    result: typing.Dict[str, typing.Tuple] = {}
    for owner, name, annotation in _all_annotations(cls):
        module_namespace = vars(sys.modules[owner.__module__])
        ctor_args = _ctor.get_args(annotation, module_namespace)
        _nillable_write(result, name, ctor_args)
    return result
173
174
175
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Collect the field annotations of *cls*, base classes first.

    Annotations spelled as the string ``"None"`` and annotations that
    ``_ctor.annotation_is_classvar`` accepts (evaluated in the defining
    module's namespace) are not fields. They also remove any earlier field of
    the same name.
    """
    fields: typing.Dict[str, typing.Any] = {}
    for owner, name, annotation in _all_annotations(cls):
        is_excluded = annotation == "None" or _ctor.annotation_is_classvar(
            annotation, vars(sys.modules[owner.__module__])
        )
        _nillable_write(fields, name, None if is_excluded else annotation)
    return fields
186
187
188
def _ordering_options_are_valid(*, eq, order):
189
    if order and not eq:
190
        raise ValueError("eq must be true if order is true")
191
192
193
def _set_ordering(*, can_set, setter, cls, source):
194
    if not can_set:
195
        raise ValueError("Can't add ordering methods if equality methods are provided.")
196
    collision = setter(cls, source.__lt__, source.__le__, source.__gt__, source.__ge__)
197
    if collision:
198
        raise TypeError(
199
            "Cannot overwrite attribute {collision} in class "
200
            "{name}. Consider using functools.total_ordering".format(
201
                collision=collision, name=cls.__name__
202
            )
203
        )
204
205
206
def _values_non_empty(cls, field_names):
207
    for field in field_names:
208
        default = getattr(cls, field, inspect.Parameter.empty)
209
        if default is inspect.Parameter.empty:
210
            return
211
        yield (field, default)
212
213
214
def _values_until_non_empty(cls, field_names):
215
    for field in field_names:
216
        default = getattr(cls, field, inspect.Parameter.empty)
217
        if default is not inspect.Parameter.empty:
218
            yield
219
220
221
def _extract_defaults(*, cls, annotations):
    """Return ``{field: default}`` for the trailing run of defaulted fields.

    Fields are scanned from last to first, and defaults are read from *cls*.
    Once a field without a default is reached, no earlier field may have one
    (the same rule Python applies to function parameters). Otherwise
    ``TypeError`` is raised.
    """
    remaining = reversed(tuple(annotations))
    defaults = dict(_values_non_empty(cls, remaining))
    if any(True for _ in _values_until_non_empty(cls, remaining)):
        raise TypeError
    return defaults
227
228
229
def _unpack_args(*, args, kwargs, fields, values):
230
    fields_iter = iter(fields)
231
    values.update({field: arg for (arg, field) in zip(args, fields_iter)})
232
    for field in fields_iter:
233
        if field in values and field not in kwargs:
234
            continue
235
        values[field] = kwargs.pop(field)
236
    if kwargs:
237
        raise TypeError(kwargs)
238
239
240
class Sum:
    """Base class for types whose values come from disjoint constructors.

    Subclasses declare their constructors with PEP 526 annotations. An
    annotation is a constructor when it is the ``Ctor`` class itself, or
    ``Ctor`` indexed with a single type hint or a tuple of type hints. All
    other annotations are ignored.

    The direct subclass of ``Sum`` is not subclassable. Instead, it gets a
    subclass under each name that carried a ``Ctor`` annotation. Each of
    these takes a fixed number of arguments, matching the type hints given in
    its annotation, if any.

    Class keyword options:

    - ``repr`` (default true): add a ``__repr__`` method.
    - ``eq`` (default true): add ``__eq__``/``__ne__`` and, when those could
      be added, ``__hash__``.
    - ``order`` (default false): add the rich comparison dunder methods.
      This requires ``eq``.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        # Only ADTConstructor subclasses (presumably the generated
        # constructor classes) may be instantiated.
        cls, *rest = args
        if issubclass(cls, _adt_constructor.ADTConstructor):
            return super(Sum, cls).__new__(cls, *rest, **kwargs)
        raise TypeError

    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        # ``repr`` and ``eq`` deliberately mirror the option names used by
        # ``dataclasses``, hence the pylint suppressions above.
        super().__init_subclass__(**kwargs)
        if issubclass(cls, _adt_constructor.ADTConstructor):
            # Constructor classes need none of the setup below.
            return
        _ordering_options_are_valid(eq=eq, order=order)

        constructors: typing.List[typing.Type[_T]] = []
        for ctor_name, ctor_args in _sum_args_from_annotations(cls).items():
            # make_constructor is handed the list; its contents become the
            # recorded constructor order.
            _adt_constructor.make_constructor(cls, ctor_name, ctor_args, constructors)
        _prewritten_methods.SUBCLASS_ORDER[cls] = tuple(constructors)

        prewritten = _prewritten_methods.PrewrittenSumMethods

        # Swap in the prewritten __init_subclass__, then restrict
        # instantiation to the constructor classes.
        cls.__init_subclass__ = prewritten.__init_subclass__  # type: ignore
        _sum_new(cls, frozenset(constructors))

        # Best effort: each group is skipped if the class already defines
        # one of its names.
        _set_new_functions(cls, prewritten.__setattr__, prewritten.__delattr__)
        _set_new_functions(cls, prewritten.__bool__)
        if repr:
            _set_new_functions(cls, prewritten.__repr__)

        equality_installed = bool(eq) and not _set_new_functions(
            cls, prewritten.__eq__, prewritten.__ne__
        )
        if equality_installed:
            cls.__hash__ = prewritten.__hash__

        if order:
            _set_ordering(
                can_set=equality_installed,
                setter=_set_new_functions,
                cls=cls,
                source=prewritten,
            )
316
317
318
class _ConditionalMethod:
319
    name = None
320
    field_check = None
321
322
    def __init__(self, source):
323
        self.source = source
324
325
    def __getattr__(self, name):
326
        self.field_check = name
327
        return self
328
329
    def __set_name__(self, owner, name):
330
        self.__objclass__ = owner
331
        self.name = name
332
333
    def __get__(self, instance, owner):
334
        if getattr(owner, self.field_check):
335
            return getattr(self.source, self.name).__get__(instance, owner)
336
        target = owner if instance is None else instance
337
        return getattr(super(self.__objclass__, target), self.name)
338
339
    def __set__(self, instance, value):
340
        # Don't care about this coverage
341
        raise AttributeError  # pragma: nocover
342
343
    def __delete__(self, instance):
344
        # Don't care about this coverage
345
        raise AttributeError  # pragma: nocover
346
347
348
class Product(_adt_constructor.ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ (base classes first) to determine
    fields. Annotations with a value of "None", and ``ClassVar``
    annotations, are discarded.

    Class keyword options are inherited by further subclasses unless
    overridden:

    - ``repr`` (default true): use the prewritten ``__repr__``.
    - ``eq`` (default true): use the prewritten ``__eq__``, ``__ne__`` and
      ``__hash__``, unless ``__eq__`` or ``__ne__`` is already customized on
      the class.
    - ``order`` (default false): use the prewritten rich comparison methods.
      This requires ``eq``.

    Fields may have default values, given as class attributes. Only trailing
    fields may have them: a field without a default may not follow one with
    a default. A default of ``inspect.Parameter.empty`` means "no default".

    Instances are tuples of the field values. Each field is also readable as
    an attribute, and the stored value takes precedence over any class
    attribute of the same name, such as the field's default.

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        # ``cls`` is unpacked from *args rather than named, presumably so a
        # field called ``cls`` can still be passed by keyword.
        cls, *args = args
        if cls is Product:
            raise TypeError
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        _unpack_args(
            args=args,
            kwargs=kwargs,
            fields=cls.__fields,  # pylint: disable=protected-access
            values=values,
        )
        return super(Product, cls).__new__(
            cls,
            [
                values[field]
                for field in cls.__fields  # pylint: disable=protected-access
            ],
        )

    # Option values. __init_subclass__ overrides them per class, and a
    # subclass inherits whatever its base ended up with.
    __repr = True
    __eq = True
    __order = False
    # Set by __init_subclass__: whether the prewritten equality methods apply.
    __eq_succeeded = None

    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        # ``repr`` and ``eq`` mirror the option names used by ``dataclasses``,
        # hence the pylint suppressions. ``None`` means "keep the inherited
        # setting".
        super().__init_subclass__(**kwargs)

        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order

        _ordering_options_are_valid(eq=cls.__eq, order=cls.__order)

        cls.__annotations = _product_args_from_annotations(cls)
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}

        cls.__defaults = _extract_defaults(cls=cls, annotations=cls.__annotations)

        _product_new(cls, cls.__annotations, cls.__defaults)

        source = _prewritten_methods.PrewrittenProductMethods

        # Must be False before the check below: the __eq__/__ne__ descriptors
        # consult this flag, and while it is off they resolve to the
        # inherited implementation, so only a customized method on the class
        # registers as a collision.
        cls.__eq_succeeded = False
        if cls.__eq:
            cls.__eq_succeeded = not _cant_set_new_functions(
                cls, source.__eq__, source.__ne__
            )

        if cls.__order:
            # Collision check only; the comparison methods themselves are
            # provided by the descriptors defined at the bottom of Product.
            _set_ordering(
                can_set=cls.__eq_succeeded,
                setter=_cant_set_new_functions,
                cls=cls,
                source=source,
            )

    def __dir__(self):
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        # Fields are resolved first. A field with a default is also a class
        # attribute, and ordinary lookup would return that default (or a
        # same-named tuple method such as ``count``) instead of the value
        # stored in this instance.
        index = super().__getattribute__("_Product__fields").get(name)
        if index is None:
            return super().__getattribute__(name)
        return tuple.__getitem__(self, index)

    source = _prewritten_methods.PrewrittenProductMethods

    __setattr__ = source.__setattr__
    __delattr__ = source.__delattr__
    __bool__ = source.__bool__

    # Each descriptor checks the named (mangled) class flag on every lookup.
    # If the flag is truthy, it uses the prewritten method from ``source``;
    # otherwise it uses whatever comes after Product in the MRO.
    __repr__ = _ConditionalMethod(source).__repr
    __hash__ = _ConditionalMethod(source).__eq_succeeded
    __eq__ = _ConditionalMethod(source).__eq_succeeded
    __ne__ = _ConditionalMethod(source).__eq_succeeded
    __lt__ = _ConditionalMethod(source).__order
    __le__ = _ConditionalMethod(source).__order
    __gt__ = _ConditionalMethod(source).__order
    __ge__ = _ConditionalMethod(source).__order

    del source
467
468
469
# Public API of this module.
__all__ = [
    "Ctor",
    "Product",
    "Sum",
]
470