Passed
Push — master ( aea67e...1f152d )
by Max
01:03
created

structured_data.adt.Product.__init_subclass__()   B

Complexity

Conditions 7

Size

Total Lines 54
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
eloc 43
nop 6
dl 0
loc 54
rs 7.448
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a cohesive part of the method body into a new, well-named method.

1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class, that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from . import _adt_constructor
49
from . import _ctor
50
from . import _prewritten_methods
51
52
_T = typing.TypeVar("_T")


if typing.TYPE_CHECKING:  # pragma: nocover

    class Ctor:
        """Placeholder definition that exists only while type-checking."""

    class ConcreteCtor(typing.Generic[_T]):
        """Generic wrapper that exists only while type-checking.

        Its type parameter is expected to be a fixed-size Tuple type.
        A class that still carries this annotation has not been processed
        by the ``adt`` decorator and should not be instantiated.
        """


else:
    # At runtime the real implementation is used instead of the shims above.
    from ._ctor import Ctor
71
72
73
# The locations that use this function should be rewritten to use one with a
74
# clearer name.
75
# Right now, this describes what the function does, not why it should be
76
# called.
77
def _conditional_raise(do_raise, exc_class, *args):
78
    if do_raise:
79
        raise exc_class(*args)
80
81
82
# This is fine.
83
def _name(cls: typing.Type[_T], function) -> str:
    """Return ``__name__`` of what ``function`` yields when looked up on ``cls``.

    ``function`` is resolved through the descriptor protocol
    (``__get__(None, cls)``), so wrappers such as ``staticmethod`` report
    the name of the object they hand back.
    """
    resolved = function.__get__(None, cls)
    return resolved.__name__
86
87
88
# This is mostly fine, though the list of classes is somewhat ad-hoc, to say
89
# the least.
90
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Return the name of the first function that would clobber an attribute.

    An attribute on ``cls`` does not count as a collision when it is missing
    (``None``), is the function itself, or equals what ``object`` or
    ``Product`` provide under the same name. Returns ``None`` when every
    function can be set.
    """
    for function in functions:
        name = _name(cls, function)
        existing = getattr(cls, name, None)
        replaceable = (
            getattr(object, name, None),
            getattr(Product, name, None),
            None,
            function,
        )
        if existing in replaceable:
            continue
        return name
    return None
102
103
104
# Maybe it would make more sense to pull the check logic out of this function,
105
# and require it explicitly where this is currently used.
106
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install ``functions`` on ``cls`` under their own names, all or nothing.

    The collision check runs over every function *before* anything is
    assigned. On a collision nothing is set and the offending name is
    returned; on success ``None`` is returned.
    """
    collision = _cant_set_new_functions(cls, *functions)
    if not collision:
        for function in functions:
            setattr(cls, _name(cls, function), function)
        return None
    return collision
118
119
120
_K = typing.TypeVar("_K")
121
_V = typing.TypeVar("_V")
122
123
124
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
125
    if value is None:
126
        dct.pop(key, typing.cast(_V, None))
127
    else:
128
        dct[key] = value
129
130
131
def _add_methods(cls: typing.Type[_T], do_set, *methods):
    """Install ``methods`` on ``cls`` when ``do_set`` is truthy.

    Returns ``True`` only if installation was requested and
    ``_set_new_functions`` reported no collision; otherwise ``False``.
    """
    if not do_set:
        return False
    return not _set_new_functions(cls, *methods)
136
137
138
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Replace ``_cls.__new__`` with a version restricted to ``subclasses``.

    The installed ``__new__`` raises ``TypeError`` for any class not in
    ``subclasses``. Otherwise it delegates to the ``__new__`` found in
    ``_cls.__dict__``, or to the next ``__new__`` after ``_cls`` in the MRO
    when ``_cls`` does not define its own.
    """

    def base(cls, args):
        return super(_cls, cls).__new__(cls, args)

    try:
        new = _cls.__dict__["__new__"]
    except KeyError:
        new = staticmethod(base)

    def __new__(cls, args):
        if cls in subclasses:
            # Resolve through the descriptor protocol so both staticmethod
            # objects and plain functions yield a callable.
            return new.__get__(None, cls)(cls, args)
        raise TypeError

    _cls.__new__ = staticmethod(__new__)  # type: ignore
150
151
152
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a ``__new__`` whose signature lists its fields.

    The installed ``__new__`` forwards to the next ``__new__`` after
    ``_cls`` in the MRO. Its ``__signature__`` is a positional-only ``cls``
    followed by one positional-or-keyword parameter per entry of
    ``annotations``, with the default taken from ``defaults`` where present.
    """

    def __new__(*args, **kwargs):
        cls, *args = args
        return super(_cls, cls).__new__(cls, *args, **kwargs)

    no_default = inspect.Parameter.empty
    parameters = [inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
    for field, annotation in annotations.items():
        parameters.append(
            inspect.Parameter(
                field,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=annotation,
                default=defaults.get(field, no_default),
            )
        )

    __new__.__signature__ = inspect.signature(__new__).replace(parameters=parameters)
    _cls.__new__ = __new__
174
175
176
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Yield ``(declaring_class, name, annotation)`` for ``cls`` and its bases.

    Classes are visited in reverse MRO order, so base classes come first.
    Only annotations stored in each class's own namespace are reported.
    """
    for superclass in cls.__mro__[::-1]:
        own_annotations = vars(superclass).get("__annotations__", {})
        yield from (
            (superclass, key, value) for key, value in own_annotations.items()
        )
182
183
184
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map each annotated name on ``cls`` to the result of ``_ctor.get_args``.

    Each annotation is passed to ``_ctor.get_args`` together with the
    globals of the module of the class that declared it. A ``None`` result
    removes any entry recorded earlier under the same name.
    """
    collected: typing.Dict[str, typing.Tuple] = {}
    for superclass, key, value in _all_annotations(cls):
        module_globals = vars(sys.modules[superclass.__module__])
        _nillable_write(collected, key, _ctor.get_args(value, module_globals))
    return collected
191
192
193
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Map each field name of ``cls`` to its annotation.

    Annotations are gathered from ``cls`` and its bases, base classes
    first. An annotation equal to the string ``"None"``, or one that
    ``_ctor.annotation_is_classvar`` recognises, is not a field: it removes
    any entry recorded earlier under the same name.

    Each annotation is checked against the globals of the module of the
    class that *declared* it, matching ``_sum_args_from_annotations``.
    Using the module of ``cls`` would look up inherited annotations in the
    wrong namespace whenever a base class lives in a different module.
    """
    args: typing.Dict[str, typing.Any] = {}
    for superclass, key, value in _all_annotations(cls):
        declaring_globals = vars(sys.modules[superclass.__module__])
        if value == "None" or _ctor.annotation_is_classvar(value, declaring_globals):
            value = None
        _nillable_write(args, key, value)
    return args
204
205
206
def _ordering_options_are_valid(*, eq, order):
207
    if order and not eq:
208
        raise ValueError("eq must be true if order is true")
209
210
211
def _set_ordering(*, order, can_set, setter, cls, source):
212
    if not order:
213
        return
214
    if not can_set:
215
        raise ValueError("Can't add ordering methods if equality methods are provided.")
216
    collision = setter(cls, source.__lt__, source.__le__, source.__gt__, source.__ge__)
217
    if collision:
218
        raise TypeError(
219
            "Cannot overwrite attribute {collision} in class "
220
            "{name}. Consider using functools.total_ordering".format(
221
                collision=collision, name=cls.__name__
222
            )
223
        )
224
225
226
class Sum:
    """Base class for types made of several disjoint constructors.

    Subclassing inspects the PEP 526 ``__annotations__`` of the new class.
    A Ctor annotation is either the ``adt.Ctor`` class itself or the result
    of indexing it with one type hint or a tuple of type hints; every other
    annotation is ignored.

    Class keyword arguments:

    - ``repr``: if true, a ``__repr__()`` method is added to the class.
    - ``eq``: if true, ``__eq__()`` and ``__ne__()`` are added.
    - ``order``: if true, rich comparison dunder methods are added.

    The resulting class cannot itself be subclassed. Instead it exposes one
    subclass under each name that carried a Ctor annotation, and each of
    those takes a fixed number of arguments matching the type hints in its
    annotation, if any.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        # Only the generated constructor classes may be instantiated.
        if issubclass(cls, _adt_constructor.ADTConstructor):
            return super(Sum, cls).__new__(cls, *args, **kwargs)
        raise TypeError

    # The names ``repr`` and ``eq`` are kept for consistency with modules
    # defined in the stdlib, hence the pylint suppressions.
    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        super().__init_subclass__(**kwargs)
        # Generated constructor classes need none of the setup below.
        if issubclass(cls, _adt_constructor.ADTConstructor):
            return
        _ordering_options_are_valid(eq=eq, order=order)

        prewritten = _prewritten_methods.PrewrittenSumMethods

        # make_constructor receives this list on every call; its final
        # contents are recorded as the constructors of ``cls``.
        constructors: typing.List[typing.Type[_T]] = []
        for name, args in _sum_args_from_annotations(cls).items():
            _adt_constructor.make_constructor(cls, name, args, constructors)
        _prewritten_methods.SUBCLASS_ORDER[cls] = tuple(constructors)

        cls.__init_subclass__ = prewritten.__init_subclass__  # type: ignore

        _sum_new(cls, frozenset(constructors))

        # __setattr__/__delattr__ are installed as a pair; __bool__ on its own.
        _set_new_functions(cls, prewritten.__setattr__, prewritten.__delattr__)
        _set_new_functions(cls, prewritten.__bool__)

        _add_methods(cls, repr, prewritten.__repr__)

        equality_methods_were_set = _add_methods(
            cls, eq, prewritten.__eq__, prewritten.__ne__
        )
        if equality_methods_were_set:
            cls.__hash__ = prewritten.__hash__

        _set_ordering(
            order=order,
            can_set=equality_methods_were_set,
            setter=_set_new_functions,
            cls=cls,
            source=prewritten,
        )
306
307
308
def _collect_trailing_defaults(cls, field_names):
    """Return ``{field: default}`` for the trailing defaulted fields of ``cls``.

    ``field_names`` is walked from last to first. A field has a default when
    ``cls`` has an attribute of that name other than
    ``inspect.Parameter.empty``. Collection stops at the first field without
    a default.

    Raises:
        TypeError: a field with a default comes before a field without one.
    """
    empty = inspect.Parameter.empty
    defaults = {}
    remaining = iter(reversed(tuple(field_names)))
    for field in remaining:
        default = getattr(cls, field, empty)
        if default is empty:
            break
        defaults[field] = default
    # Everything still left precedes a field without a default, so none of
    # these fields may have a default either.
    for field in remaining:
        if getattr(cls, field, empty) is not empty:
            raise TypeError(
                "field {field!r} has a default but precedes a field "
                "without one".format(field=field)
            )
    return defaults


class Product(_adt_constructor.ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    Class keyword arguments (each inherited from the parent class when not
    given):

    - ``repr``: if true, the prewritten ``__repr__()`` is used.
    - ``eq``: if true, the prewritten equality and hash methods are used,
      unless the class already provides its own ``__eq__``/``__ne__``.
    - ``order``: if true, the prewritten rich comparison methods are used;
      requires ``eq``.

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded.
    Fields may have default values, and can be set to
    ``inspect.Parameter.empty`` to indicate "no default". Fields with
    defaults must come after all fields without defaults.

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        """Build an instance from positional and/or keyword field values.

        Raises:
            TypeError: ``Product`` itself is instantiated, too many
                positional arguments are given, a field without a default
                is missing, or an unexpected keyword (including one that
                duplicates a positional value) is given.
        """
        cls, *args = args
        if cls is Product:
            raise TypeError
        # Similar to https://github.com/PyCQA/pylint/issues/1802
        fields = tuple(cls.__fields)  # pylint: disable=protected-access
        if len(args) > len(fields):
            # zip() below would otherwise silently discard the surplus values.
            raise TypeError(
                "{name}() takes at most {limit} positional arguments "
                "({given} given)".format(
                    name=cls.__name__, limit=len(fields), given=len(args)
                )
            )
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        fields_iter = iter(fields)
        # ``args`` comes first in zip() so that no field is consumed from
        # ``fields_iter`` once the positional values run out.
        for arg, field in zip(args, fields_iter):
            values[field] = arg
        for field in fields_iter:
            if field in kwargs:
                values[field] = kwargs.pop(field)
            elif field not in values:
                raise TypeError(
                    "{name}() missing required argument: {field!r}".format(
                        name=cls.__name__, field=field
                    )
                )
        if kwargs:
            raise TypeError(kwargs)
        return super(Product, cls).__new__(cls, [values[field] for field in fields])

    __repr = True
    __eq = True
    __order = False
    __eq_succeeded = None

    # The names ``repr`` and ``eq`` are kept for consistency with modules
    # defined in the stdlib, hence the pylint suppressions.
    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        """Record the options, fields and defaults of a new subclass.

        Raises:
            ValueError: ``order`` is true while ``eq`` is false, or ordering
                is requested but the class provides its own equality.
            TypeError: a defaulted field precedes a non-defaulted one, or an
                ordering method is already defined on the class.
        """
        super().__init_subclass__(**kwargs)

        # Options left as None keep the value inherited from the parent.
        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order

        _ordering_options_are_valid(eq=cls.__eq, order=cls.__order)

        cls.__annotations = _product_args_from_annotations(cls)
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}
        cls.__defaults = _collect_trailing_defaults(cls, cls.__annotations)

        _product_new(cls, cls.__annotations, cls.__defaults)

        # Only a check is made here: the methods themselves are served by
        # the properties below, switched on these flags.
        cls.__eq_succeeded = False
        if cls.__eq:
            cls.__eq_succeeded = not _cant_set_new_functions(
                cls,
                _prewritten_methods.PrewrittenProductMethods.__eq__,
                _prewritten_methods.PrewrittenProductMethods.__ne__,
            )

        _set_ordering(
            order=cls.__order,
            can_set=cls.__eq_succeeded,
            setter=_cant_set_new_functions,
            cls=cls,
            source=_prewritten_methods.PrewrittenProductMethods,
        )

    def __dir__(self):
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        try:
            return super().__getattribute__(name)
        except AttributeError:
            # Fall back to treating ``name`` as a field and reading the
            # matching tuple slot.
            index = self.__fields.get(name)
            if index is None:
                raise
            return tuple.__getitem__(self, index)

    __setattr__ = _prewritten_methods.PrewrittenProductMethods.__setattr__
    __delattr__ = _prewritten_methods.PrewrittenProductMethods.__delattr__
    __bool__ = _prewritten_methods.PrewrittenProductMethods.__bool__

    # Each property below returns a bound method: the prewritten one when
    # the corresponding option is enabled, the inherited one otherwise.

    @property
    def __repr__(self):
        if self.__repr:
            return _prewritten_methods.PrewrittenProductMethods.__repr__.__get__(
                self, type(self)
            )
        return super().__repr__

    @property
    def __hash__(self):
        if self.__eq_succeeded:
            return _prewritten_methods.PrewrittenProductMethods.__hash__.__get__(
                self, type(self)
            )
        return super().__hash__

    @property
    def __eq__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__eq__.__get__(
                self, type(self)
            )
        return super().__eq__

    @property
    def __ne__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__ne__.__get__(
                self, type(self)
            )
        return super().__ne__

    @property
    def __lt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__lt__.__get__(
                self, type(self)
            )
        return super().__lt__

    @property
    def __le__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__le__.__get__(
                self, type(self)
            )
        return super().__le__

    @property
    def __gt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__gt__.__get__(
                self, type(self)
            )
        return super().__gt__

    @property
    def __ge__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__ge__.__get__(
                self, type(self)
            )
        return super().__ge__
505
506
507
# Names exported by a star-import of this module.
__all__ = ["Ctor", "Product", "Sum"]
508