Passed
Push — master ( 1f152d...934b7f )
by Max
01:05
created

structured_data.adt.Product.__init_subclass__()   C

Complexity

Conditions 8

Size

Total Lines 54
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 8
eloc 43
nop 6
dl 0
loc 54
rs 6.9813
c 0
b 0
f 0

How to fix: Long Method

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method.

1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from . import _adt_constructor
49
from . import _ctor
50
from . import _prewritten_methods
51
52
# Type variable used for the class arguments (``typing.Type[_T]``) of the
# helpers below and as the parameter of the type-checking-only ConcreteCtor.
_T = typing.TypeVar("_T")
53
54
55
if typing.TYPE_CHECKING:  # pragma: nocover

    class Ctor:
        """Dummy class for type-checking purposes.

        At runtime the real ``Ctor`` is imported from ``_ctor`` instead.
        """

    class ConcreteCtor(typing.Generic[_T]):
        """Wrapper class for type-checking purposes.

        The type parameter should be a Tuple type of fixed size.
        Classes containing this annotation (meaning they haven't been
        processed by ``Sum.__init_subclass__``) should not be instantiated.
        """


else:
    from ._ctor import Ctor
71
72
73
# The locations that use this function should be rewritten to use one with a
74
# clearer name.
75
# Right now, this describes what the function does, not why it should be
76
# called.
77
def _conditional_raise(do_raise, exc_class, *args):
78
    if do_raise:
79
        raise exc_class(*args)
80
81
82
# This is fine.
83
def _name(cls: typing.Type[_T], function) -> str:
    """Return the ``__name__`` of ``function`` as resolved through ``cls``.

    ``function`` is passed through its descriptor protocol
    (``__get__`` with no instance and ``cls`` as owner) before its name is
    read.
    """
    resolved = function.__get__(None, cls)
    return resolved.__name__
86
87
88
# This is mostly fine, though the list of classes is somewhat ad-hoc, to say
89
# the least.
90
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Return the first name among ``functions`` that ``cls`` already customizes.

    For each function, its attribute name is computed with ``_name``. The name
    is considered free when the attribute currently visible on ``cls`` is
    missing, equals the attribute of that name on ``object`` or ``Product``,
    or equals the function itself. The first non-free name is returned;
    ``None`` means every name is free. Nothing is modified.
    """
    for candidate in functions:
        attr_name = _name(cls, candidate)
        current = getattr(cls, attr_name, None)
        permitted = (
            getattr(object, attr_name, None),
            getattr(Product, attr_name, None),
            None,
            candidate,
        )
        if current not in permitted:
            return attr_name
    return None
102
103
104
# Maybe it would make more sense to pull the check logic out of this function,
105
# and require it explicitly where this is currently used.
106
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install every function in ``functions`` on ``cls``, all or nothing.

    The check via ``_cant_set_new_functions`` happens first. If any target
    name is already customized on ``cls``, nothing is written and that name
    is returned. Otherwise each function is assigned under its own name and
    ``None`` is returned.
    """
    blocked = _cant_set_new_functions(cls, *functions)
    if not blocked:
        for func in functions:
            setattr(cls, _name(cls, func), func)
        return None
    return blocked
118
119
120
# Key type variable for the dictionary handled by _nillable_write.
_K = typing.TypeVar("_K")
121
# Value type variable for the dictionary handled by _nillable_write.
_V = typing.TypeVar("_V")
122
123
124
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
    """Store ``value`` under ``key``, treating ``None`` as "remove the key".

    A non-``None`` value overwrites the entry. A ``None`` value deletes ``key``
    from ``dct`` if present, and is a no-op (no error) if it is absent.
    """
    if value is not None:
        dct[key] = value
        return
    dct.pop(key, typing.cast(_V, None))
129
130
131
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Restrict instantiation of the Sum subclass ``_cls`` to ``subclasses``.

    Installs a ``__new__`` on ``_cls`` that raises ``TypeError`` unless the
    class being instantiated is one of ``subclasses`` (its constructor
    classes). Allowed calls are delegated to a ``__new__`` defined directly in
    ``_cls.__dict__`` if there is one. Otherwise they go to the next
    ``__new__`` after ``_cls`` in the MRO.

    Args:
        _cls: the Sum subclass being processed.
        subclasses: container (a frozenset at the call site) of the classes
            that may be instantiated.

    Raises (from the installed ``__new__``):
        TypeError: if ``cls`` is not one of ``subclasses``.
    """

    def base(cls, args):
        return super(_cls, cls).__new__(cls, args)

    # Python stores a class-body __new__ as a staticmethod, so both branches
    # are descriptors that __get__ turns into a plain callable.
    new = _cls.__dict__.get("__new__", staticmethod(base))

    def __new__(cls, args):
        if cls not in subclasses:
            raise TypeError(
                "Cannot instantiate {name} directly; "
                "use one of its constructors".format(name=cls.__name__)
            )
        return new.__get__(None, cls)(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
143
144
145
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a pass-through ``__new__`` that advertises its fields.

    The installed ``__new__`` forwards all arguments to the next ``__new__``
    after ``_cls`` in the MRO. Its ``__signature__`` is replaced so that
    ``inspect.signature`` reports a positional-only ``cls`` followed by one
    positional-or-keyword parameter per entry of ``annotations``. Each
    parameter carries the field's annotation and its default from
    ``defaults``, if any.
    """
    missing = inspect.Parameter.empty

    def __new__(*args, **kwargs):
        cls, *rest = args
        return super(_cls, cls).__new__(cls, *rest, **kwargs)

    parameters = [inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
    for field, annotation in annotations.items():
        parameters.append(
            inspect.Parameter(
                field,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=annotation,
                default=defaults.get(field, missing),
            )
        )
    __new__.__signature__ = inspect.signature(__new__).replace(parameters=parameters)
    _cls.__new__ = __new__
167
168
169
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Lazily yield ``(owner, name, annotation)`` for annotations in ``cls``'s MRO.

    Classes are visited from the root of the MRO down to ``cls`` itself. Only
    each class's *own* ``__annotations__`` (from its ``__dict__``) are read,
    so a name re-annotated in a subclass is yielded again later.
    """
    for owner in cls.__mro__[::-1]:
        own_annotations = vars(owner).get("__annotations__", {})
        yield from ((owner, name, hint) for name, hint in own_annotations.items())
175
176
177
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map annotated names in ``cls``'s MRO to their ``_ctor.get_args`` result.

    Each annotation is passed to ``_ctor.get_args`` along with the globals of
    the module that defines the annotating class. A ``None`` result removes
    the name (presumably this is what non-``Ctor`` annotations produce —
    confirm in ``_ctor``). Because annotations arrive root-first, more-derived
    classes override earlier entries.
    """
    result: typing.Dict[str, typing.Tuple] = {}
    for owner, attr, hint in _all_annotations(cls):
        module_globals = vars(sys.modules[owner.__module__])
        _nillable_write(result, attr, _ctor.get_args(hint, module_globals))
    return result
184
185
186
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Map each field name of a Product class to its annotation.

    Annotations are collected across the MRO, root first, so more-derived
    classes override inherited entries. An annotation equal to the string
    ``"None"``, or one that ``_ctor.annotation_is_classvar`` recognizes,
    removes the name, which lets a subclass drop an inherited field.
    """
    args: typing.Dict[str, typing.Any] = {}
    for superclass, key, value in _all_annotations(cls):
        # Resolve names in the globals of the module where the annotation was
        # written (as _sum_args_from_annotations and typing.get_type_hints do),
        # not the module of the subclass being processed. The namespace is
        # only built when the "None" shortcut does not apply.
        if value == "None" or _ctor.annotation_is_classvar(
            value, vars(sys.modules[superclass.__module__])
        ):
            value = None
        _nillable_write(args, key, value)
    return args
197
198
199
def _ordering_options_are_valid(*, eq, order):
200
    if order and not eq:
201
        raise ValueError("eq must be true if order is true")
202
203
204
def _set_ordering(*, can_set, setter, cls, source):
205
    if not can_set:
206
        raise ValueError("Can't add ordering methods if equality methods are provided.")
207
    collision = setter(cls, source.__lt__, source.__le__, source.__gt__, source.__ge__)
208
    if collision:
209
        raise TypeError(
210
            "Cannot overwrite attribute {collision} in class "
211
            "{name}. Consider using functools.total_ordering".format(
212
                collision=collision, name=cls.__name__
213
            )
214
        )
215
216
217
def _install_sum_methods(cls, *, add_repr, eq, order):
    """Install the prewritten dunder methods on a freshly processed Sum subclass.

    Method groups and when they are offered:

    - ``__setattr__``/``__delattr__``: always.
    - ``__bool__``: always.
    - ``__repr__``: only when ``add_repr`` is true.
    - ``__eq__``/``__ne__``: only when ``eq`` is true.
    - ``__lt__``/``__le__``/``__gt__``/``__ge__``: only when ``order`` is true.

    ``_set_new_functions`` silently skips a group when ``cls`` already
    customizes one of its names. The ordering group is the exception: it
    raises instead. ``__hash__`` is set only if the equality methods were
    actually installed.

    Raises:
        ValueError: ``order`` was requested but the equality methods were not
            installed.
        TypeError: ``order`` was requested but an ordering method already
            exists on ``cls``.
    """
    methods = _prewritten_methods.PrewrittenSumMethods
    _set_new_functions(cls, methods.__setattr__, methods.__delattr__)
    _set_new_functions(cls, methods.__bool__)
    if add_repr:
        _set_new_functions(cls, methods.__repr__)

    equality_methods_were_set = False
    if eq:
        equality_methods_were_set = not _set_new_functions(
            cls, methods.__eq__, methods.__ne__
        )
    if equality_methods_were_set:
        # Pair the prewritten __hash__ only with the prewritten __eq__,
        # presumably so the two stay consistent with each other.
        cls.__hash__ = methods.__hash__

    if order:
        _set_ordering(
            can_set=equality_methods_were_set,
            setter=_set_new_functions,
            cls=cls,
            source=methods,
        )


class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    Keyword options accepted in the subclass's class statement:

    - ``repr`` (default true): add a __repr__() method.
    - ``eq`` (default true): add __eq__()/__ne__() and a matching __hash__().
    - ``order`` (default false): add rich comparison dunder methods; requires
      ``eq``.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        if not issubclass(cls, _adt_constructor.ADTConstructor):
            # Only the generated constructor classes may be instantiated.
            raise TypeError(
                "Cannot instantiate {name} directly; "
                "use one of its constructors".format(name=cls.__name__)
            )
        return super(Sum, cls).__new__(cls, *args, **kwargs)

    # The option names follow stdlib conventions (compare the repr/eq/order
    # parameters of dataclasses.dataclass), hence the pylint disables.
    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        super().__init_subclass__(**kwargs)
        if issubclass(cls, _adt_constructor.ADTConstructor):
            # Constructor classes (presumably those created by make_constructor
            # below) also pass through here and need no processing.
            return
        _ordering_options_are_valid(eq=eq, order=order)

        # Build one constructor class per Ctor annotation, recording the order.
        subclass_order: typing.List[typing.Type[_T]] = []
        for name, args in _sum_args_from_annotations(cls).items():
            _adt_constructor.make_constructor(cls, name, args, subclass_order)
        _prewritten_methods.SUBCLASS_ORDER[cls] = tuple(subclass_order)

        # From now on, subclassing cls goes through the prewritten hook
        # instead of this processing step.
        cls.__init_subclass__ = (
            _prewritten_methods.PrewrittenSumMethods.__init_subclass__
        )  # type: ignore

        _sum_new(cls, frozenset(subclass_order))
        _install_sum_methods(cls, add_repr=repr, eq=eq, order=order)
299
300
301
def _product_defaults(cls: type, field_names) -> typing.Dict[str, typing.Any]:
    """Collect the default values of the trailing fields of a Product subclass.

    A field's default is whatever ``getattr(cls, field)`` returns, so it may be
    inherited. A missing attribute, or one set to ``inspect.Parameter.empty``,
    means "no default". Fields are scanned from last to first, and defaults are
    collected until the first field without one.

    Raises:
        TypeError: if a field with a default precedes a field without one
            (the same rule Python applies to function parameters).
    """
    empty = inspect.Parameter.empty
    defaults: typing.Dict[str, typing.Any] = {}
    remaining = iter(reversed(tuple(field_names)))
    for field in remaining:
        default = getattr(cls, field, empty)
        if default is empty:
            break
        defaults[field] = default
    # Whatever is left in ``remaining`` precedes the first field without a
    # default, so none of those fields may have a default either.
    for earlier in remaining:
        if getattr(cls, earlier, empty) is not empty:
            raise TypeError(
                "In {name}, field {field!r} has no default but follows "
                "field {earlier!r}, which has one".format(
                    name=cls.__name__,
                    field=field,  # pylint: disable=undefined-loop-variable
                    earlier=earlier,
                )
            )
    return defaults


def _product_eq_available(cls: type, eq) -> bool:
    """Report whether the prewritten equality methods should serve ``cls``.

    True only when ``eq`` is truthy and ``cls`` does not already customize
    ``__eq__`` or ``__ne__``. This is only a check: nothing is assigned,
    because Product exposes those methods through properties.
    """
    if not eq:
        return False
    return not _cant_set_new_functions(
        cls,
        _prewritten_methods.PrewrittenProductMethods.__eq__,
        _prewritten_methods.PrewrittenProductMethods.__ne__,
    )


class Product(_adt_constructor.ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    Keyword options accepted in the subclass's class statement. An option that
    is not given is inherited from the parent class:

    - ``repr`` (default true): use the prewritten __repr__().
    - ``eq`` (default true): use the prewritten __eq__()/__ne__()/__hash__(),
      unless the class customizes __eq__ or __ne__ itself.
    - ``order`` (default false): use the prewritten rich comparison methods;
      requires ``eq``.

    The Product class examines the class to find annotations.
    Annotations with a value of "None", and ClassVar annotations, are
    discarded. Fields may have default values, and can be set to
    inspect.Parameter.empty to indicate "no default". Fields with defaults
    must come after all fields without defaults.

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        if cls is Product:
            raise TypeError("Product cannot be instantiated directly; subclass it")
        # Similar to https://github.com/PyCQA/pylint/issues/1802
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        fields_iter = iter(cls.__fields)  # pylint: disable=protected-access
        # NOTE(review): zip() draws from ``args`` first, so positional
        # arguments beyond the number of fields are consumed and silently
        # ignored.
        for arg, field in zip(args, fields_iter):
            values[field] = arg
        for field in fields_iter:
            if field in values and field not in kwargs:
                continue
            # NOTE(review): a missing required field surfaces as KeyError
            # from kwargs.pop, not as TypeError.
            values[field] = kwargs.pop(field)
        if kwargs:
            # Leftovers are unknown names, or fields already given positionally.
            raise TypeError(
                "{name}() got unexpected or repeated keyword arguments: "
                "{kwargs!r}".format(name=cls.__name__, kwargs=kwargs)
            )
        return super(Product, cls).__new__(
            cls,
            [
                values[field]
                for field in cls.__fields  # pylint: disable=protected-access
            ],
        )

    # Per-class option flags. Subclasses override them in __init_subclass__;
    # __eq_succeeded records whether the prewritten equality methods apply.
    __repr = True
    __eq = True
    __order = False
    __eq_succeeded = None

    # The option names follow stdlib conventions (compare the repr/eq/order
    # parameters of dataclasses.dataclass), hence the pylint disables.
    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        super().__init_subclass__(**kwargs)

        # ``None`` means "keep the setting inherited from the parent class".
        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order
        _ordering_options_are_valid(eq=cls.__eq, order=cls.__order)

        cls.__annotations = _product_args_from_annotations(cls)
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}
        cls.__defaults = _product_defaults(cls, cls.__annotations)
        _product_new(cls, cls.__annotations, cls.__defaults)

        cls.__eq_succeeded = _product_eq_available(cls, cls.__eq)
        if cls.__order:
            # Only checks for collisions; the ordering methods themselves are
            # provided by the properties below.
            _set_ordering(
                can_set=cls.__eq_succeeded,
                setter=_cant_set_new_functions,
                cls=cls,
                source=_prewritten_methods.PrewrittenProductMethods,
            )

    def __dir__(self):
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        # Real attributes win; fields are looked up only when normal
        # attribute access fails.
        try:
            return super().__getattribute__(name)
        except AttributeError:
            index = self.__fields.get(name)
            if index is None:
                raise
            return tuple.__getitem__(self, index)

    __setattr__ = _prewritten_methods.PrewrittenProductMethods.__setattr__
    __delattr__ = _prewritten_methods.PrewrittenProductMethods.__delattr__
    __bool__ = _prewritten_methods.PrewrittenProductMethods.__bool__

    # Each of the following properties returns the prewritten method bound to
    # self when the class's flags enable it, else the inherited implementation.

    @property
    def __repr__(self):
        if self.__repr:
            return _prewritten_methods.PrewrittenProductMethods.__repr__.__get__(
                self, type(self)
            )
        return super().__repr__

    @property
    def __hash__(self):
        if self.__eq_succeeded:
            return _prewritten_methods.PrewrittenProductMethods.__hash__.__get__(
                self, type(self)
            )
        return super().__hash__

    @property
    def __eq__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__eq__.__get__(
                self, type(self)
            )
        return super().__eq__

    @property
    def __ne__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__ne__.__get__(
                self, type(self)
            )
        return super().__ne__

    @property
    def __lt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__lt__.__get__(
                self, type(self)
            )
        return super().__lt__

    @property
    def __le__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__le__.__get__(
                self, type(self)
            )
        return super().__le__

    @property
    def __gt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__gt__.__get__(
                self, type(self)
            )
        return super().__gt__

    @property
    def __ge__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return _prewritten_methods.PrewrittenProductMethods.__ge__.__get__(
                self, type(self)
            )
        return super().__ge__
498
499
500
# Public API: the Ctor annotation marker and the Sum/Product base classes.
__all__ = ["Ctor", "Product", "Sum"]
501