Passed
Push — master ( 84e84b...408aae )
by Max
01:17
created

structured_data.adt._can_set_ordering()   A

Complexity

Conditions 2

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 3
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class, that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import typing
46
47
from . import _adt_constructor
48
from . import _annotations
49
from . import _conditional_method
50
from . import _prewritten_methods
51
52
_T = typing.TypeVar("_T")


if not typing.TYPE_CHECKING:
    from ._ctor import Ctor
else:  # pragma: nocover
    # Stand-ins seen only by static type checkers; at runtime ``Ctor`` is the
    # real class imported in the branch above.

    class Ctor:
        """Placeholder for ``Ctor``, visible to type checkers only."""

    class ConcreteCtor(typing.Generic[_T]):
        """Type-checking-only wrapper around a constructor's field types.

        The type parameter should be a fixed-size Tuple type. Classes that
        still carry this annotation (i.e. the ``adt`` decorator has not
        processed them) should not be instantiated.
        """
71
72
73
# This is fine.
74
def _name(cls: type, function) -> str:
75
    """Return the name of a function accessed through a descriptor."""
76
    return function.__get__(None, cls).__name__
77
78
79
# This is mostly fine, though the list of classes is somewhat ad-hoc, to say
80
# the least.
81
def _cant_set_new_functions(cls: type, *functions) -> typing.Optional[str]:
    """Return the first attribute name that installing ``functions`` would clobber.

    Each function's target name comes from ``_name``. An existing attribute on
    ``cls`` is acceptable only if it is missing or ``None``, is the very same
    function, or equals what ``object`` or ``Product`` provide under that name.
    Returns None when every function could be installed.
    """
    for candidate in functions:
        attr_name = _name(cls, candidate)
        current = getattr(cls, attr_name, None)
        harmless = (
            getattr(object, attr_name, None),
            getattr(Product, attr_name, None),
            None,
            candidate,
        )
        if current in harmless:
            continue
        return attr_name
    return None
93
94
95
# Sentinel meaning "not found" for ``inspect.getattr_static`` lookups.
MISSING = object()


def cant_modify(self, name):
    """Raise ``AttributeError`` for an attempt to set or delete ``name``.

    The wording follows Python's own messages: "has no attribute" when
    ``name`` cannot be found statically on ``self``, otherwise "is read-only".
    """
    details = {
        "class_repr": repr(self.__class__.__name__),
        "name_repr": repr(name),
    }
    if inspect.getattr_static(self, name, MISSING) is MISSING:
        template = "{class_repr} object has no attribute {name_repr}"
    else:
        template = "{class_repr} object attribute {name_repr} is read-only"
    raise AttributeError(template.format(**details))
107
108
109
def _set_new_functions(cls: type, *functions) -> typing.Optional[str]:
    """Install every function in ``functions`` on ``cls``, or none of them.

    All target names are checked with ``_cant_set_new_functions`` before
    anything is assigned. On a conflict nothing is set and the conflicting
    name is returned; on success None is returned.
    """
    blocked = _cant_set_new_functions(cls, *functions)
    if not blocked:
        for method in functions:
            setattr(cls, _name(cls, method), method)
        return None
    return blocked
121
122
123
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Install a ``__new__`` on ``_cls`` that only admits its constructors.

    ``subclasses`` holds the classes allowed to be instantiated (checked with
    ``in``). For those, construction is delegated to the ``__new__`` defined
    in ``_cls``'s own body if there is one, otherwise to the next ``__new__``
    after ``_cls`` in the MRO.

    The installed ``__new__`` raises ``TypeError`` for any class not in
    ``subclasses``.
    """

    def base(cls: typing.Type[_T], args):
        return super(_cls, cls).__new__(cls, args)  # type: ignore

    new = vars(_cls).get("__new__", staticmethod(base))

    def __new__(cls: typing.Type[_T], args):
        if cls not in subclasses:
            # Previously a bare TypeError with no hint about what went wrong.
            raise TypeError(
                "Cannot instantiate {name}: only the constructors of {owner} "
                "can be instantiated.".format(
                    name=cls.__name__, owner=_cls.__name__
                )
            )
        return new.__get__(None, cls)(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
135
136
137
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Replace ``_cls.__new__`` with a wrapper that forwards up the MRO.

    If ``_cls`` defines ``__new__`` in its own body, that function still
    handles instantiation of ``_cls`` itself. Instantiating a subclass skips
    it and goes to the next ``__new__`` after ``_cls`` in the MRO. In that case
    the wrapper reports the user function's signature.

    Otherwise every call is forwarded up the MRO. The wrapper then advertises
    a positional-only ``cls`` followed by one parameter per entry in
    ``annotations``, with defaults taken from ``defaults``.
    """
    if "__new__" not in vars(_cls):

        def __new__(*args, **kwargs):
            cls, *rest = args
            return super(_cls, cls).__new__(cls, *rest, **kwargs)

        leading = inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)
        per_field = [
            inspect.Parameter(
                name,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=hint,
                default=defaults.get(name, inspect.Parameter.empty),
            )
            for name, hint in annotations.items()
        ]
        signature = inspect.signature(__new__).replace(parameters=[leading, *per_field])
    else:
        user_new = _cls.__new__

        def __new__(*args, **kwargs):
            cls, *rest = args
            if cls is not _cls:
                return super(_cls, cls).__new__(cls, *rest, **kwargs)
            return user_new(cls, *rest, **kwargs)

        signature = inspect.signature(user_new)

    __new__.__signature__ = signature  # type: ignore
    _cls.__new__ = __new__  # type: ignore
172
173
174
def _ordering_options_are_valid(
175
    *, eq: bool, order: bool  # pylint: disable=invalid-name
176
):
177
    if order and not eq:
178
        raise ValueError("eq must be true if order is true")
179
180
181
def _can_set_ordering(*, can_set: bool):
182
    if not can_set:
183
        raise ValueError("Can't add ordering methods if equality methods are provided.")
184
185
186
def _set_ordering(*, setter, cls: type, source: type):
187
    collision = setter(
188
        cls, source.__lt__, source.__le__, source.__gt__, source.__ge__  # type: ignore
189
    )
190
    if collision:
191
        raise TypeError(
192
            "Cannot overwrite attribute {collision} in class "
193
            "{name}. Consider using functools.total_ordering".format(
194
                collision=collision, name=cls.__name__
195
            )
196
        )
197
198
199
def _values_non_empty(
200
    cls: type, field_names: typing.Iterator[str]
201
) -> typing.Iterator[typing.Tuple[str, typing.Any]]:
202
    for field in field_names:
203
        default = getattr(cls, field, inspect.Parameter.empty)
204
        if default is inspect.Parameter.empty:
205
            return
206
        yield (field, default)
207
208
209
def _values_until_non_empty(
210
    cls: type, field_names: typing.Iterator[str]
211
) -> typing.Iterator:
212
    for field in field_names:
213
        default = getattr(cls, field, inspect.Parameter.empty)
214
        if default is not inspect.Parameter.empty:
215
            yield
216
217
218
def _extract_defaults(*, cls: type, annotations: typing.Iterable[str]):
219
    field_names = iter(reversed(tuple(annotations)))
220
    defaults = dict(_values_non_empty(cls, field_names))
221
    for _ in _values_until_non_empty(cls, field_names):
222
        raise TypeError
223
    return defaults
224
225
226
def _unpack_args(
227
    *,
228
    args: typing.Tuple[typing.Any, ...],
229
    kwargs: typing.Dict[str, typing.Any],
230
    fields: typing.Iterable[str],
231
    values: typing.Dict[str, typing.Any],
232
):
233
    fields_iter = iter(fields)
234
    values.update({field: arg for (arg, field) in zip(args, fields_iter)})
235
    for field in fields_iter:
236
        if field in values and field not in kwargs:
237
            continue
238
        values[field] = kwargs.pop(field)
239
    if kwargs:
240
        raise TypeError(kwargs)
241
242
243
class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    If repr is true, a __repr__() method is added to the class.
    If eq is true, __eq__/__ne__ (and a matching __hash__) are added when they
    do not conflict with existing definitions.
    If order is true, rich comparison dunder methods are added; this requires
    the equality methods to have been added.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        if not issubclass(cls, _adt_constructor.ADTConstructor):
            raise TypeError(
                "{name} cannot be instantiated directly; instantiate one of "
                "the constructors declared on a Sum subclass instead.".format(
                    name=cls.__name__
                )
            )
        return super(Sum, cls).__new__(cls, *args, **kwargs)

    # ``repr`` and ``eq`` mirror the keyword names used by standard-library
    # modules such as ``dataclasses``, hence the pylint disables below.
    def __init_subclass__(
        cls,
        *,
        repr: bool = True,  # pylint: disable=redefined-builtin
        eq: bool = True,  # pylint: disable=invalid-name
        order: bool = False,
        **kwargs,
    ):
        """Build the constructors and generated methods for a new Sum subclass.

        Classes that are themselves constructors (``ADTConstructor``
        subclasses) return right after the ``super()`` call.

        Raises:
            ValueError: if ``order`` is true and ``eq`` is false, or if
                ``order`` is true but the equality methods were not added.
            TypeError: if an ordering method would overwrite an existing
                attribute.
        """
        super().__init_subclass__(**kwargs)  # type: ignore
        if issubclass(cls, _adt_constructor.ADTConstructor):
            return
        _ordering_options_are_valid(eq=eq, order=order)

        _prewritten_methods.SUBCLASS_ORDER[cls] = _adt_constructor.make_constructors(
            cls
        )

        source = _prewritten_methods.PrewrittenSumMethods

        # Subsequent subclasses of ``cls`` go through the prewritten hook
        # instead of this method.
        cls.__init_subclass__ = source.__init_subclass__  # type: ignore

        # Only the constructors just built may be instantiated.
        _sum_new(cls, frozenset(_prewritten_methods.SUBCLASS_ORDER[cls]))

        if repr:
            # On a conflict the existing __repr__ is kept (result ignored).
            _set_new_functions(cls, source.__repr__)

        equality_methods_were_set = False
        if eq:
            equality_methods_were_set = not _set_new_functions(
                cls, source.__eq__, source.__ne__
            )

        if equality_methods_were_set:
            # Pair the generated __eq__ with a matching __hash__.
            cls.__hash__ = source.__hash__  # type: ignore

        if order:
            _can_set_ordering(can_set=equality_methods_were_set)
            _set_ordering(setter=_set_new_functions, cls=cls, source=source)

    def __bool__(self):
        # Instances are always truthy, whatever they contain.
        return True

    def __setattr__(self, name, value):
        # Only data descriptors may be assigned; everything else is read-only.
        if not inspect.isdatadescriptor(inspect.getattr_static(self, name, MISSING)):
            cant_modify(self, name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        # Same rule as __setattr__: only data descriptors may be deleted.
        if not inspect.isdatadescriptor(inspect.getattr_static(self, name, MISSING)):
            cant_modify(self, name)
        super().__delattr__(name)
322
323
324
class Product(_adt_constructor.ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If order is true, rich comparison dunder methods are added.
    The repr/eq/order options are inherited: a subclass that does not pass
    one keeps its parent's setting (defaults: repr=True, eq=True,
    order=False).

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded.
    Fields may have default values, and can be set to inspect.Parameter.empty
    to indicate "no default". Fields with defaults must come after all fields
    without defaults.

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        if cls is Product:
            raise TypeError(
                "Product cannot be instantiated directly; define a subclass "
                "with annotated fields instead."
            )
        # Probably a result of not having positional-only args.
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        _unpack_args(
            args=args,
            kwargs=kwargs,
            fields=cls.__fields,  # pylint: disable=protected-access
            values=values,
        )
        return super(Product, cls).__new__(
            cls,
            [
                values[field]
                for field in cls.__fields  # pylint: disable=protected-access
            ],
        )

    # Per-class option flags written by __init_subclass__. The
    # conditional_method declarations at the bottom of this class body
    # refer to them by name.
    __repr: typing.ClassVar[bool] = True
    __eq: typing.ClassVar[bool] = True
    __order: typing.ClassVar[bool] = False
    __eq_succeeded = None

    # ``repr`` and ``eq`` mirror the keyword names used by standard-library
    # modules such as ``dataclasses``, hence the pylint disables below.
    def __init_subclass__(
        cls,
        *,
        repr: typing.Optional[bool] = None,  # pylint: disable=redefined-builtin
        eq: typing.Optional[bool] = None,  # pylint: disable=invalid-name
        order: typing.Optional[bool] = None,
        **kwargs,
    ):
        """Record options, compute fields and defaults, and install ``__new__``.

        Raises:
            ValueError: if ordering is enabled without equality, or if the
                equality methods could not be provided while ordering is on.
            TypeError: if a field with a default precedes one without, or if
                an ordering method would overwrite an existing attribute.
        """
        super().__init_subclass__(**kwargs)  # type: ignore

        # Options left as None keep the value inherited from the parent class.
        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order

        _ordering_options_are_valid(eq=cls.__eq, order=cls.__order)

        cls.__annotations = _annotations.product_args_from_annotations(cls)
        # Map each field name to its position in the underlying tuple.
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}

        cls.__defaults = _extract_defaults(cls=cls, annotations=cls.__annotations)

        _product_new(cls, cls.__annotations, cls.__defaults)

        source = _prewritten_methods.PrewrittenProductMethods

        # Records whether the prewritten __eq__/__ne__ fit without clobbering
        # an existing definition on the class.
        cls.__eq_succeeded = False
        if cls.__eq:
            cls.__eq_succeeded = not _cant_set_new_functions(
                cls, source.__eq__, source.__ne__
            )

        if cls.__order:
            _can_set_ordering(can_set=cls.__eq_succeeded)
            # Collision check only: nothing is assigned here, because the
            # ordering methods are declared on Product itself below.
            _set_ordering(setter=_cant_set_new_functions, cls=cls, source=source)

    def __dir__(self):
        names = super().__dir__()
        # Fields with defaults are also class attributes, so super().__dir__()
        # may already list them; don't report them twice.
        return names + [field for field in self.__fields if field not in names]

    def __getattribute__(self, name):
        # Field names resolve to tuple positions before normal lookup, so they
        # shadow class attributes (such as defaults) of the same name.
        index = object.__getattribute__(self, "_Product__fields").get(name)
        if index is None:
            return super().__getattribute__(name)
        return tuple.__getitem__(self, index)

    def __setattr__(self, name, value):
        # Only data descriptors may be assigned; everything else is read-only.
        if not inspect.isdatadescriptor(inspect.getattr_static(self, name, MISSING)):
            cant_modify(self, name)
        super().__setattr__(name, value)

    def __delattr__(self, name):
        # Same rule as __setattr__: only data descriptors may be deleted.
        if not inspect.isdatadescriptor(inspect.getattr_static(self, name, MISSING)):
            cant_modify(self, name)
        super().__delattr__(name)

    def __bool__(self):
        # Always truthy, even with no fields (which would be an empty tuple).
        return True

    source = _prewritten_methods.PrewrittenProductMethods

    # pylint: disable=protected-access
    __repr__ = _conditional_method.conditional_method(source).__repr  # type: ignore
    __hash__ = _conditional_method.conditional_method(  # type: ignore
        source
    ).__eq_succeeded
    __eq__ = _conditional_method.conditional_method(  # type: ignore
        source
    ).__eq_succeeded
    __ne__ = _conditional_method.conditional_method(  # type: ignore
        source
    ).__eq_succeeded
    __lt__ = _conditional_method.conditional_method(source).__order  # type: ignore
    __le__ = _conditional_method.conditional_method(source).__order  # type: ignore
    __gt__ = _conditional_method.conditional_method(source).__order  # type: ignore
    __ge__ = _conditional_method.conditional_method(source).__order  # type: ignore

    del source
452
453
454
# Public API of this module (what ``from ... import *`` exports).
__all__ = ["Ctor", "Product", "Sum"]
455