Passed
Push — master ( b076f2...d88a76 )
by Max
01:00
created

structured_data.adt._sum_new()   A

Complexity

Conditions 2

Size

Total Lines 12
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 9
nop 2
dl 0
loc 12
rs 9.95
c 0
b 0
f 0
1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from ._adt_constructor import ADTConstructor
49
from ._adt_constructor import make_constructor
50
from ._ctor import get_args
51
from ._prewritten_methods import SUBCLASS_ORDER
52
from ._prewritten_methods import PrewrittenProductMethods
53
from ._prewritten_methods import PrewrittenSumMethods
54
55
# Generic type variable used by the helper annotations in this module.
_T = typing.TypeVar("_T")


if typing.TYPE_CHECKING:  # pragma: nocover
    # Static type checkers see these lightweight stand-ins instead of the
    # runtime implementation imported in the ``else`` branch.

    class Ctor:
        """Dummy class for type-checking purposes."""

    # NOTE(review): the docstring below refers to an ``adt`` decorator, but in
    # this module classes are processed by ``Sum``/``Product``
    # ``__init_subclass__``; confirm whether that reference is stale.
    class ConcreteCtor(typing.Generic[_T]):
        """Wrapper class for type-checking purposes.

        The type parameter should be a Tuple type of fixed size.
        Classes containing this annotation (meaning they haven't been
        processed by the ``adt`` decorator) should not be instantiated.
        """


else:
    # At runtime the real ``Ctor`` comes from the package-private module.
    from ._ctor import Ctor
74
75
76
def _name(cls: typing.Type[_T], function) -> str:
    """Return the ``__name__`` of ``function`` as seen through ``cls``.

    ``function`` may be a plain function or a descriptor such as
    ``staticmethod``/``classmethod``. It is first bound against ``cls``
    (with no instance), and then the bound object's name is read.
    """
    bound = function.__get__(None, cls)
    return bound.__name__
79
80
81
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install each of ``functions`` on ``cls`` under its own name, all or nothing.

    A name counts as "already defined" when ``cls``'s attribute of that name
    is not the very same object as ``object``'s attribute of that name (or
    when only one of the two exists).

    Returns:
        ``None`` if every function was installed. Otherwise, returns the first
        already-defined name, and in that case ``cls`` is left unmodified.

    NOTE(review): for classes deriving from a builtin that overrides a dunder
    (``tuple`` overrides ``__repr__``/``__eq__``, for instance), the inherited
    attribute is not ``object``'s. The name is therefore reported as defined
    even if the user wrote nothing. Confirm this is intended for such classes.
    """
    # Check every name before touching the class, so a collision never
    # leaves ``cls`` half-updated.
    first_collision = next(
        (
            attr_name
            for attr_name in (_name(cls, function) for function in functions)
            if getattr(object, attr_name, None) is not getattr(cls, attr_name, None)
        ),
        None,
    )
    if first_collision is not None:
        return first_collision

    for function in functions:
        setattr(cls, _name(cls, function), function)
    return None
94
95
96
_K = typing.TypeVar("_K")
97
_V = typing.TypeVar("_V")
98
99
100
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
101
    if value is None:
102
        dct.pop(key, typing.cast(_V, None))
103
    else:
104
        dct[key] = value
105
106
107
def _add_methods(cls: typing.Type[_T], do_set, *methods):
    """Install ``methods`` on ``cls`` when ``do_set`` is truthy.

    Returns:
        ``True`` only if ``do_set`` is truthy and none of the method names was
        already defined on ``cls``, so that all of them were installed.
        Otherwise returns ``False``, and ``cls`` is left unmodified.
    """
    if not do_set:
        return False
    collision = _set_new_functions(cls, *methods)
    return not collision
112
113
114
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Install a ``__new__`` on the Sum class ``_cls`` that only admits constructors.

    Args:
        _cls: The Sum subclass being processed.
        subclasses: The classes that may be instantiated. ``Sum`` passes a
            frozenset of its generated constructor classes.

    If ``_cls`` defined its own ``__new__`` in its class body, that
    implementation is kept and delegated to. Otherwise, instances are created
    by the next ``__new__`` after ``_cls`` in the MRO.

    The installed ``__new__`` raises TypeError when asked to create any class
    that is not in ``subclasses``, for example ``_cls`` itself.
    """

    def base(cls, args):
        return super(_cls, cls).__new__(cls, args)

    # A ``__new__`` written in a class body is stored as a staticmethod, so the
    # fallback is wrapped the same way and both are unwrapped uniformly below.
    new = _cls.__dict__.get("__new__", staticmethod(base))

    def __new__(cls, args):
        if cls not in subclasses:
            raise TypeError(
                "{name} cannot be instantiated directly; instantiate one of "
                "its constructors instead".format(name=cls.__name__)
            )
        return new.__get__(None, cls)(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
126
127
128
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a pass-through ``__new__`` that advertises its fields.

    The installed ``__new__`` forwards every argument unchanged to the next
    ``__new__`` after ``_cls`` in the MRO. It exists so that it can carry a
    ``__signature__`` listing ``cls`` (positional-only) followed by one
    positional-or-keyword parameter per entry of ``annotations``. Each
    parameter is annotated with that entry's value and uses the matching
    entry of ``defaults``, if any, as its default.
    """

    def __new__(*args, **kwargs):
        cls, *rest = args
        return super(_cls, cls).__new__(cls, *rest, **kwargs)

    parameters = [inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
    parameters.extend(
        inspect.Parameter(
            field_name,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=field_type,
            default=defaults.get(field_name, inspect.Parameter.empty),
        )
        for field_name, field_type in annotations.items()
    )
    __new__.__signature__ = inspect.signature(__new__).replace(parameters=parameters)

    # Assigned as a plain function: when instantiating, Python looks
    # ``__new__`` up on the class and passes the class explicitly.
    _cls.__new__ = __new__
150
151
152
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Yield ``(owner, name, annotation)`` for each annotation along ``cls``'s MRO.

    Classes are visited in reverse MRO order (base-most first, ``cls`` last).
    A subclass's annotation for a name is therefore yielded after any base
    class annotation for that name. Only annotations stored directly in
    each class's ``__dict__`` are read.

    NOTE(review): on interpreters that evaluate class annotations lazily
    (PEP 649/749), ``__annotations__`` may not live in the class ``__dict__``.
    Confirm this still finds them there.
    """
    base_first = reversed(cls.__mro__)
    for owner in base_first:
        own_annotations = vars(owner).get("__annotations__", {})
        for attr_name, annotation in own_annotations.items():
            yield owner, attr_name, annotation
158
159
160
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map each constructor name declared on ``cls`` or its bases to its arguments.

    Each annotation is passed to ``get_args``, together with the global
    namespace of the module that defined the annotating class (presumably so
    that string annotations resolve there). A ``None`` result removes the
    name; per the ``Sum`` docstring, non-``Ctor`` annotations are ignored.
    Later (more derived) annotations override earlier ones.
    """
    constructors: typing.Dict[str, typing.Tuple] = {}
    for owner, attr_name, annotation in _all_annotations(cls):
        owner_namespace = vars(sys.modules[owner.__module__])
        _nillable_write(constructors, attr_name, get_args(annotation, owner_namespace))
    return constructors
167
168
169
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Map each field name of ``cls``, including inherited ones, to its annotation.

    Annotations are read in reverse MRO order, so a subclass may re-annotate an
    inherited field. An annotation that is ``None``, or the string ``"None"``,
    removes the field from the result.
    """
    fields: typing.Dict[str, typing.Any] = {}
    for _owner, field_name, annotation in _all_annotations(cls):
        effective = None if annotation == "None" else annotation
        _nillable_write(fields, field_name, effective)
    return fields
178
179
180
def _add_prewritten_methods(_cls: typing.Type[_T], _repr, eq, order, src):
    """Copy the standard dunder methods from ``src`` onto ``_cls``.

    Args:
        _cls: The class being set up.
        _repr: Whether to install ``src.__repr__``.
        eq: Whether to install ``src.__eq__``/``src.__ne__``. When they are
            installed, ``src.__hash__`` is installed too.
        order: Whether to install ``src``'s four rich comparison methods.
        src: The class that holds the prewritten implementations.

    A method whose name ``_cls`` already defines is not overwritten.

    Raises:
        ValueError: ``order`` is requested but the equality methods were not
            installed.
        TypeError: ``order`` is requested but ``_cls`` already defines one of
            the comparison methods.
    """
    # ``__setattr__``/``__delattr__`` and ``__bool__`` are installed only where
    # the class has not customized them. A collision here is ignored.
    _set_new_functions(_cls, src.__setattr__, src.__delattr__)
    _set_new_functions(_cls, src.__bool__)

    _add_methods(_cls, _repr, src.__repr__)

    have_equality = _add_methods(_cls, eq, src.__eq__, src.__ne__)
    if have_equality:
        # Keep ``__hash__`` consistent with the ``__eq__`` just installed.
        _cls.__hash__ = src.__hash__

    if order:
        if not have_equality:
            raise ValueError("Can't add ordering methods if equality methods are provided.")
        clash = _set_new_functions(_cls, src.__lt__, src.__le__, src.__gt__, src.__ge__)
        if clash:
            raise TypeError(
                f"Cannot overwrite attribute {clash} in class "
                f"{_cls.__name__}. Consider using functools.total_ordering"
            )
204
205
206
class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    If repr is true, a __repr__() method is added to the class.
    If eq is true, __eq__() and __ne__() (and a matching __hash__()) are
    added, unless the class already defines them.
    If order is true, rich comparison dunder methods are added; this requires
    eq to be true.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __init_subclass__(cls, *, repr=True, eq=True, order=False, **kwargs):
        """Turn a freshly defined Sum subclass into an ADT.

        Raises:
            ValueError: ``order`` is true while ``eq`` is false.
        """
        super().__init_subclass__(**kwargs)
        # Constructor classes (ADTConstructor subclasses), presumably created
        # by ``make_constructor`` below, also pass through this hook; they
        # need no processing of their own.
        if issubclass(cls, ADTConstructor):
            return
        if order and not eq:
            raise ValueError("eq must be true if order is true")

        # Filled in by make_constructor, assumed to append each constructor
        # class in annotation order (TODO confirm in _adt_constructor).
        subclass_order: typing.List[typing.Type[_T]] = []

        for name, args in _sum_args_from_annotations(cls).items():
            make_constructor(cls, name, args, subclass_order)

        SUBCLASS_ORDER[cls] = tuple(subclass_order)

        # NOTE(review): the hook is replaced only after the constructors have
        # been created, presumably so that creating them is not blocked;
        # further subclassing then goes through PrewrittenSumMethods.
        cls.__init_subclass__ = PrewrittenSumMethods.__init_subclass__  # type: ignore

        # Only the constructor classes collected above may be instantiated.
        _sum_new(cls, frozenset(subclass_order))

        _add_prewritten_methods(cls, repr, eq, order, PrewrittenSumMethods)
245
246
247
class Product(ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If order is true, rich comparison dunder methods are added.

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded.
    Fields may have default values, and can be set to
    ``inspect.Parameter.empty`` to indicate "no default". Fields with
    defaults must come after all fields without defaults.

    Instances are created by passing field values positionally, by keyword,
    or both. TypeError is raised for any of the following:

    - too many positional arguments;
    - a missing field that has no default;
    - an unknown keyword;
    - a keyword duplicating a positional argument.

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):
        # ``cls`` is unpacked from ``args`` rather than named in the signature,
        # presumably so that a field named ``cls`` can still be passed by keyword.
        cls, *args = args
        field_names = tuple(cls.__annotations)
        # Reject surplus positional arguments instead of silently dropping them.
        if len(args) > len(field_names):
            raise TypeError(
                "{name}() takes {expected} positional arguments but {given} "
                "were given".format(
                    name=cls.__name__, expected=len(field_names), given=len(args)
                )
            )
        values = cls.__defaults.copy()
        values.update(zip(field_names, args))
        # Fields not filled positionally come from kwargs, else their default.
        for field in field_names[len(args):]:
            if field in kwargs:
                values[field] = kwargs.pop(field)
            elif field not in values:
                raise TypeError(
                    "{name}() missing required argument: {field!r}".format(
                        name=cls.__name__, field=field
                    )
                )
        # Leftovers are unknown names, or fields already given positionally.
        if kwargs:
            raise TypeError(kwargs)
        return super(Product, cls).__new__(
            cls, [values[field] for field in field_names]
        )

    # Class-level option defaults. A subclass inherits its parent's values
    # unless it passes repr/eq/order explicitly.
    __repr = True
    __eq = True
    __order = False

    def __init_subclass__(cls, *, repr=None, eq=None, order=None, **kwargs):
        """Compute fields and defaults for a new Product subclass and install methods.

        Raises:
            ValueError: ordering is enabled while equality is disabled.
            TypeError: a field with a default precedes a field without one.
        """
        super().__init_subclass__(**kwargs)
        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order
        if cls.__order and not cls.__eq:
            raise ValueError("eq must be true if order is true")

        cls.__annotations = _product_args_from_annotations(cls)
        # Maps each field name to its index in the underlying tuple.
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}

        # Collect trailing defaults: scan fields from last to first and stop at
        # the first field without one.
        # NOTE(review): getattr also sees inherited attributes, so a field
        # named like a tuple method (e.g. ``count``) is treated as defaulted.
        cls.__defaults = {}
        field_names = iter(reversed(tuple(cls.__annotations)))
        for field in field_names:
            default = getattr(cls, field, inspect.Parameter.empty)
            if default is inspect.Parameter.empty:
                break
            cls.__defaults[field] = default
        # Every remaining (earlier) field must then have no default.
        for field in field_names:
            if (
                getattr(cls, field, inspect.Parameter.empty)
                is not inspect.Parameter.empty
            ):
                raise TypeError(
                    "Field {field!r} of {name} has a default, but a later "
                    "field does not; fields with defaults must come last".format(
                        field=field, name=cls.__name__
                    )
                )

        _product_new(cls, cls.__annotations, cls.__defaults)

        _add_prewritten_methods(
            cls, cls.__repr, cls.__eq, cls.__order, PrewrittenProductMethods
        )

    def __dir__(self):
        # Advertise field names alongside the normal attributes.
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        # Regular attributes take precedence. Field names are resolved only as
        # a fallback, by indexing into the underlying tuple.
        try:
            return super().__getattribute__(name)
        except AttributeError:
            index = self.__fields.get(name)
            if index is None:
                raise
            return tuple.__getitem__(self, index)
332
333
334
# Public API: the Ctor annotation marker and the two ADT base classes.
__all__ = [
    "Ctor",
    "Product",
    "Sum",
]
335