Passed
Push — master ( 38d9d4...f1e085 )
by Max
01:04
created

structured_data.adt._conditional_method()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import functools
45
import inspect
46
import typing
47
48
from . import _adt_constructor
49
from . import _annotations
50
from . import _attribute_constructor
51
from . import _prewritten_methods
52
53
# Generic placeholder type used in the annotations throughout this module.
_T = typing.TypeVar("_T")


# Static type checkers only ever see the stub classes defined in the first
# branch; at runtime the real ``Ctor`` is imported from the private ``_ctor``
# module instead.
if typing.TYPE_CHECKING:  # pragma: nocover

    class Ctor:
        """Dummy class for type-checking purposes."""

    class ConcreteCtor(typing.Generic[_T]):
        """Wrapper class for type-checking purposes.

        The type parameter should be a Tuple type of fixed size.
        Classes containing this annotation (meaning they haven't been
        processed by the ``adt`` decorator) should not be instantiated.
        """


else:
    from ._ctor import Ctor
72
73
74
# This is fine.
75
def _name(cls: typing.Type[_T], function) -> str:
    """Return ``__name__`` of ``function`` as seen when accessed on ``cls``.

    ``function`` is run through the descriptor protocol (class-level access,
    no instance) before its name is read.
    """
    accessed = function.__get__(None, cls)
    return accessed.__name__
78
79
80
# This is mostly fine, though the list of classes is somewhat ad-hoc, to say
81
# the least.
82
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Return the name of the first function that would collide on ``cls``.

    An attribute already present on ``cls`` is not a collision when it is the
    one ``object`` or ``Product`` provide under that name, when it is
    ``None``/absent, or when it already is the function in question.
    Returns ``None`` when every function could be set.
    """
    for candidate in functions:
        attribute = _name(cls, candidate)
        current = getattr(cls, attribute, None)
        harmless = (
            getattr(object, attribute, None),
            getattr(Product, attribute, None),
            None,
            candidate,
        )
        if current in harmless:
            continue
        return attribute
    return None
94
95
96
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Set each function on ``cls`` under its own name, all or nothing.

    The collision check runs over every function first; if one of the names
    is already taken, nothing is assigned and that name is returned.
    Otherwise all functions are assigned and ``None`` is returned.
    """
    blocked = _cant_set_new_functions(cls, *functions)
    if not blocked:
        for function in functions:
            setattr(cls, _name(cls, function), function)
        return None
    return blocked
108
109
110
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Install a ``__new__`` on ``_cls`` that only accepts known subclasses.

    A ``__new__`` found directly in ``_cls.__dict__`` is kept and delegated
    to; otherwise a default that forwards to the next ``__new__`` after
    ``_cls`` in the MRO is used. The installed wrapper raises ``TypeError``
    when the class being instantiated is not a member of ``subclasses``.
    """

    def base(cls, args):
        # Fallback constructor: defer to whatever follows ``_cls`` in the MRO.
        return super(_cls, cls).__new__(cls, args)

    # Look in __dict__ rather than using getattr so that only a __new__
    # defined on _cls itself is picked up, without descriptor unwrapping.
    new = _cls.__dict__.get("__new__", staticmethod(base))

    def __new__(cls, args):
        if cls not in subclasses:
            raise TypeError
        # ``new`` is the raw __dict__ entry, so go through the descriptor
        # protocol to obtain the callable before invoking it.
        return new.__get__(None, cls)(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
122
123
124
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a ``__new__`` whose reported signature lists its fields.

    The installed function only forwards its arguments to the next
    ``__new__`` after ``_cls`` in the MRO. Its ``__signature__`` is replaced
    with a positional-only ``cls`` parameter followed by one
    positional-or-keyword parameter per entry of ``annotations``, carrying
    the annotation and, when present in ``defaults``, the default value.
    """

    def __new__(*args, **kwargs):
        # ``cls`` is peeled off *args instead of being a named parameter;
        # presumably so that no keyword argument can collide with it.
        cls, *args = args
        return super(_cls, cls).__new__(cls, *args, **kwargs)

    # Only introspection sees this signature; the function body above still
    # accepts arbitrary arguments.
    __new__.__signature__ = inspect.signature(__new__).replace(
        parameters=[inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
        + [
            inspect.Parameter(
                field,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=annotation,
                default=defaults.get(field, inspect.Parameter.empty),
            )
            for (field, annotation) in annotations.items()
        ]
    )
    _cls.__new__ = __new__
146
147
148
def _ordering_options_are_valid(*, eq, order):
149
    if order and not eq:
150
        raise ValueError("eq must be true if order is true")
151
152
153
def _set_ordering(*, can_set, setter, cls, source):
154
    if not can_set:
155
        raise ValueError("Can't add ordering methods if equality methods are provided.")
156
    collision = setter(cls, source.__lt__, source.__le__, source.__gt__, source.__ge__)
157
    if collision:
158
        raise TypeError(
159
            "Cannot overwrite attribute {collision} in class "
160
            "{name}. Consider using functools.total_ordering".format(
161
                collision=collision, name=cls.__name__
162
            )
163
        )
164
165
166
def _values_non_empty(cls, field_names):
167
    for field in field_names:
168
        default = getattr(cls, field, inspect.Parameter.empty)
169
        if default is inspect.Parameter.empty:
170
            return
171
        yield (field, default)
172
173
174
def _values_until_non_empty(cls, field_names):
175
    for field in field_names:
176
        default = getattr(cls, field, inspect.Parameter.empty)
177
        if default is not inspect.Parameter.empty:
178
            yield
179
180
181
def _extract_defaults(*, cls, annotations):
    """Collect the defaults of the trailing fields of ``cls``.

    The field names in ``annotations`` are walked from last to first. Every
    field that has a default is recorded until the first field without one;
    no field before that point may have a default.

    Returns:
        A dict mapping field name to default value, in reversed field order.

    Raises:
        TypeError: if a field with a default comes before a field without
            one.
    """
    field_names = iter(reversed(tuple(annotations)))
    defaults = dict(_values_non_empty(cls, field_names))
    # The shared iterator now points past the last field lacking a default;
    # any default found among the remaining (earlier) fields is an error.
    for _ in _values_until_non_empty(cls, field_names):
        raise TypeError(
            "field without a default follows a field with a default in "
            "{cls!r}".format(cls=cls)
        )
    return defaults
187
188
189
def _unpack_args(*, args, kwargs, fields, values):
190
    fields_iter = iter(fields)
191
    values.update({field: arg for (arg, field) in zip(args, fields_iter)})
192
    for field in fields_iter:
193
        if field in values and field not in kwargs:
194
            continue
195
        values[field] = kwargs.pop(field)
196
    if kwargs:
197
        raise TypeError(kwargs)
198
199
200
class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    If repr is true, a __repr__() method is added to the class.
    If order is true, rich comparison dunder methods are added.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        """Create an instance; only ADTConstructor subclasses may be created.

        Raises TypeError when the class being instantiated is not a subclass
        of ``_adt_constructor.ADTConstructor``.
        """
        # ``cls`` is peeled off *args instead of being a named parameter.
        cls, *args = args
        if not issubclass(cls, _adt_constructor.ADTConstructor):
            raise TypeError
        return super(Sum, cls).__new__(cls, *args, **kwargs)

    # The ``repr`` and ``eq`` parameter names (hence the two pylint disables
    # below) are presumably the "both of these" kept for consistency with
    # modules defined in the stdlib.
    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        """Turn a direct subclass into a sum type.

        Creates the constructors for ``cls``, restricts instantiation to
        them, and installs the prewritten methods selected by ``repr``,
        ``eq`` and ``order``.

        Raises:
            ValueError: if ``order`` is true while ``eq`` is false, or if
                ordering is requested but the equality methods could not be
                set.
            TypeError: if an ordering method would overwrite an existing
                attribute.
        """
        super().__init_subclass__(**kwargs)
        # Classes that are themselves ADTConstructor subclasses get none of
        # the processing below.
        if issubclass(cls, _adt_constructor.ADTConstructor):
            return
        _ordering_options_are_valid(eq=eq, order=order)

        # Record the constructors made for cls in the shared registry.
        _prewritten_methods.SUBCLASS_ORDER[cls] = _adt_constructor.make_constructors(
            cls
        )

        source = _prewritten_methods.PrewrittenSumMethods

        # Subclasses of cls go through the prewritten hook instead of this one.
        cls.__init_subclass__ = source.__init_subclass__  # type: ignore

        # Only the registered constructors may be instantiated through cls.
        _sum_new(cls, frozenset(_prewritten_methods.SUBCLASS_ORDER[cls]))

        # The return values (names of colliding attributes) are ignored here:
        # a collision just means the existing attribute is left in place.
        _set_new_functions(cls, source.__setattr__, source.__delattr__)
        _set_new_functions(cls, source.__bool__)

        if repr:
            _set_new_functions(cls, source.__repr__)

        equality_methods_were_set = False
        if eq:
            # _set_new_functions returns None on success, a name on collision.
            equality_methods_were_set = not _set_new_functions(
                cls, source.__eq__, source.__ne__
            )

        # __hash__ is only replaced together with the equality methods.
        if equality_methods_were_set:
            cls.__hash__ = source.__hash__

        if order:
            _set_ordering(
                can_set=equality_methods_were_set,
                setter=_set_new_functions,
                cls=cls,
                source=source,
            )
273
274
275
def _conditional_method(source):
    """Wrap a ``_ConditionalMethod`` factory bound to ``source``.

    The returned ``AttributeConstructor`` is given a callable that still
    needs the ``field_check`` argument to build the descriptor.
    """
    descriptor_factory = functools.partial(_ConditionalMethod, source)
    return _attribute_constructor.AttributeConstructor(descriptor_factory)
279
280
281
class _ConditionalMethod:
282
    name = None
283
    field_check = None
284
285
    def __init__(self, source, field_check):
286
        self.source = source
287
        self.field_check = field_check
288
289
    def __set_name__(self, owner, name):
290
        self.__objclass__ = owner
291
        self.name = name
292
293
    def __get__(self, instance, owner):
294
        if getattr(owner, self.field_check):
295
            return getattr(self.source, self.name).__get__(instance, owner)
296
        target = owner if instance is None else instance
297
        return getattr(super(self.__objclass__, target), self.name)
298
299
    def __set__(self, instance, value):
300
        # Don't care about this coverage
301
        raise AttributeError  # pragma: nocover
302
303
    def __delete__(self, instance):
304
        # Don't care about this coverage
305
        raise AttributeError  # pragma: nocover
306
307
308
class Product(_adt_constructor.ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If order is true, rich comparison dunder methods are added.

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded.
    Fields may have default values, and can be set to
    inspect.Parameter.empty to indicate "no default".

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        """Build the underlying tuple from positional and keyword arguments.

        Raises TypeError when called on ``Product`` itself; errors raised by
        ``_unpack_args`` for unmatched arguments propagate unchanged.
        """
        cls, *args = args
        if cls is Product:
            raise TypeError
        # Probably a result of not having positional-only args.
        # Start from a copy of the defaults so the class-level dict is never
        # mutated by an individual call.
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        _unpack_args(
            args=args,
            kwargs=kwargs,
            fields=cls.__fields,  # pylint: disable=protected-access
            values=values,
        )
        # The tuple contents follow the iteration order of the fields mapping.
        return super(Product, cls).__new__(
            cls,
            [
                values[field]
                for field in cls.__fields  # pylint: disable=protected-access
            ],
        )

    # Class-level option flags (name-mangled to ``_Product__...``). Subclasses
    # overwrite them in __init_subclass__ and otherwise inherit these values.
    __repr = True
    __eq = True
    __order = False
    __eq_succeeded = None

    # The ``repr`` and ``eq`` parameter names (hence the two pylint disables
    # below) are presumably the "both of these" kept for consistency with
    # modules defined in the stdlib.
    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        """Configure a subclass: options, fields, defaults and ``__new__``.

        Options left as ``None`` keep the value inherited from the parent
        class.

        Raises:
            ValueError: if ordering is enabled while equality is disabled, or
                if ordering is enabled but the equality methods collide with
                existing attributes.
            TypeError: if a field default is misordered or an ordering method
                collides with an existing attribute.
        """
        super().__init_subclass__(**kwargs)

        # Only explicitly passed options override the inherited flags.
        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order

        _ordering_options_are_valid(eq=cls.__eq, order=cls.__order)

        cls.__annotations = _annotations._product_args_from_annotations(cls)
        # Map each field name to its position in the tuple.
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}

        cls.__defaults = _extract_defaults(cls=cls, annotations=cls.__annotations)

        _product_new(cls, cls.__annotations, cls.__defaults)

        source = _prewritten_methods.PrewrittenProductMethods

        # Nothing is assigned to cls here: _cant_set_new_functions only checks
        # for collisions. The flags recorded below are read by the
        # _ConditionalMethod descriptors defined at the end of this class.
        cls.__eq_succeeded = False
        if cls.__eq:
            cls.__eq_succeeded = not _cant_set_new_functions(
                cls, source.__eq__, source.__ne__
            )

        if cls.__order:
            _set_ordering(
                can_set=cls.__eq_succeeded,
                setter=_cant_set_new_functions,
                cls=cls,
                source=source,
            )

    def __dir__(self):
        """Return the default ``dir()`` listing plus the field names."""
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        """Resolve ``name`` normally, falling back to field lookup by name.

        Re-raises the original AttributeError when ``name`` is not a field.
        """
        try:
            return super().__getattribute__(name)
        except AttributeError:
            index = self.__fields.get(name)
            if index is None:
                raise
            # Bypass any overridden __getitem__ and read the tuple slot.
            return tuple.__getitem__(self, index)

    # Temporary class-body name; removed again by ``del source`` below.
    source = _prewritten_methods.PrewrittenProductMethods

    __setattr__ = source.__setattr__
    __delattr__ = source.__delattr__
    __bool__ = source.__bool__

    # Each attribute access below is name-mangled inside this class body
    # (``.__repr`` becomes ``._Product__repr``), matching the flags set above.
    # NOTE(review): presumably AttributeConstructor passes the accessed name
    # on as ``field_check`` - confirm against _attribute_constructor.
    __repr__ = _conditional_method(source).__repr
    __hash__ = _conditional_method(source).__eq_succeeded
    __eq__ = _conditional_method(source).__eq_succeeded
    __ne__ = _conditional_method(source).__eq_succeeded
    __lt__ = _conditional_method(source).__order
    __le__ = _conditional_method(source).__order
    __gt__ = _conditional_method(source).__order
    __ge__ = _conditional_method(source).__order

    del source
427
428
429
# Names exported as this module's public API.
__all__ = ["Ctor", "Product", "Sum"]
430