Passed
Push — master ( fb5575...a0aa6e )
by Max
01:00
created

structured_data.adt.Product.__init_subclass__()   C

Complexity

Conditions 8

Size

Total Lines 66
Code Lines 52

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 8
eloc 52
nop 6
dl 0
loc 66
rs 6.7042
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""Base classes for defining abstract data types.
2
3
This module provides three public members, which are used together.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import typing
47
48
from ._adt_constructor import ADTConstructor
49
from ._adt_constructor import make_constructor
50
from ._ctor import annotation_is_classvar
51
from ._ctor import get_args
52
from ._prewritten_methods import SUBCLASS_ORDER
53
from ._prewritten_methods import PrewrittenProductMethods
54
from ._prewritten_methods import PrewrittenSumMethods
55
56
# Type variable for the class arguments (``typing.Type[_T]``) of the helpers
# in this module.
_T = typing.TypeVar("_T")
57
58
59
if not typing.TYPE_CHECKING:
    # At runtime the real implementation is used.
    from ._ctor import Ctor
else:  # pragma: nocover
    # Only static type checkers see the stand-ins below.

    class Ctor:
        """Placeholder seen by type checkers in place of the runtime Ctor."""

    class ConcreteCtor(typing.Generic[_T]):
        """Type-checking-only wrapper around a fixed-size Tuple type parameter.

        The type parameter should be a Tuple type of fixed size.
        Classes that still carry this annotation (i.e. they were not processed
        by the ``adt`` machinery) should not be instantiated.
        """
75
76
77
# The locations that use this function should be rewritten to use one with a
78
# clearer name.
79
# Right now, this describes what the function does, not why it should be
80
# called.
81
def _conditional_raise(do_raise, exc_class, *args):
82
    if do_raise:
83
        raise exc_class(*args)
84
85
86
# This is fine.
87
def _name(cls: typing.Type[_T], function) -> str:
88
    """Return the name of a function accessed through a descriptor."""
89
    return function.__get__(None, cls).__name__
90
91
92
# NOTE: the set of "harmless" existing values (object's, Product's, None) is
# somewhat ad hoc.
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Return the first name among ``functions`` that ``cls`` already claims.

    A name is considered free when ``cls`` resolves it to nothing (or None),
    to ``object``'s attribute, to ``Product``'s attribute, or to the very
    function being installed. Returns None when every name is free.
    """
    for candidate in functions:
        attr_name = _name(cls, candidate)
        current = getattr(cls, attr_name, None)
        harmless = (
            getattr(object, attr_name, None),
            getattr(Product, attr_name, None),
            None,
            candidate,
        )
        if current not in harmless:
            return attr_name
    return None
106
107
108
# NOTE: the collision check could be pulled out of this function and
# performed explicitly by the callers instead.
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install each of ``functions`` on ``cls`` under its own name, all or nothing.

    The collision check runs first, so nothing is written when any name is
    already taken. Returns the colliding name in that case, otherwise None.
    """
    conflict = _cant_set_new_functions(cls, *functions)
    if not conflict:
        for candidate in functions:
            setattr(cls, _name(cls, candidate), candidate)
        return None
    return conflict
122
123
124
_K = typing.TypeVar("_K")
125
_V = typing.TypeVar("_V")
126
127
128
def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
129
    if value is None:
130
        dct.pop(key, typing.cast(_V, None))
131
    else:
132
        dct[key] = value
133
134
135
def _add_methods(cls: typing.Type[_T], do_set, *methods):
136
    methods_were_set = False
137
    if do_set:
138
        methods_were_set = not _set_new_functions(cls, *methods)
139
    return methods_were_set
140
141
142
def _sum_new(_cls: typing.Type[_T], subclasses):
143
    def base(cls, args):
144
        return super(_cls, cls).__new__(cls, args)
145
146
    new = _cls.__dict__.get("__new__", staticmethod(base))
147
148
    def __new__(cls, args):
149
        if cls not in subclasses:
150
            raise TypeError
151
        return new.__get__(None, cls)(cls, args)
152
153
    _cls.__new__ = staticmethod(__new__)  # type: ignore
154
155
156
def _product_new(
157
    _cls: typing.Type[_T],
158
    annotations: typing.Dict[str, typing.Any],
159
    defaults: typing.Dict[str, typing.Any],
160
):
161
    def __new__(*args, **kwargs):
162
        cls, *args = args
163
        return super(_cls, cls).__new__(cls, *args, **kwargs)
164
165
    __new__.__signature__ = inspect.signature(__new__).replace(
166
        parameters=[inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
167
        + [
168
            inspect.Parameter(
169
                field,
170
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
171
                annotation=annotation,
172
                default=defaults.get(field, inspect.Parameter.empty),
173
            )
174
            for (field, annotation) in annotations.items()
175
        ]
176
    )
177
    _cls.__new__ = __new__
178
179
180
def _all_annotations(
181
    cls: typing.Type[_T]
182
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
183
    for superclass in reversed(cls.__mro__):
184
        for key, value in vars(superclass).get("__annotations__", {}).items():
185
            yield (superclass, key, value)
186
187
188
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map each constructor name declared on ``cls`` (or its bases) to its args.

    Each annotation, visited base classes first, is passed to ``get_args``
    with the globals of the module whose class declared it. A None result
    removes any entry an earlier base recorded under the same name.
    """
    constructor_args: typing.Dict[str, typing.Tuple] = {}
    for owner, attribute, annotation in _all_annotations(cls):
        owner_globals = vars(sys.modules[owner.__module__])
        _nillable_write(constructor_args, attribute, get_args(annotation, owner_globals))
    return constructor_args
195
196
197
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Map each field name of ``cls`` to its annotation, in declaration order.

    Annotations are collected from the MRO, base classes first. An annotation
    that is the string ``"None"`` or a ClassVar annotation is not a field, and
    it removes any entry an earlier base recorded under the same name.
    """
    args: typing.Dict[str, typing.Any] = {}
    for owner, key, value in _all_annotations(cls):
        # Resolve names in the module that *declared* the annotation (as
        # _sum_args_from_annotations does); an inherited annotation written
        # in another module may not resolve in cls's own module.
        owner_globals = vars(sys.modules[owner.__module__])
        if value == "None" or annotation_is_classvar(value, owner_globals):
            value = None
        _nillable_write(args, key, value)
    return args
208
209
210
def _ordering_options_are_valid(*, eq, order):
211
    if order and not eq:
212
        raise ValueError("eq must be true if order is true")
213
214
215
def _make_sum_constructors(cls: typing.Type[_T]) -> typing.List[typing.Type[_T]]:
    """Create the constructor classes for ``cls`` and record their order.

    ``make_constructor`` is called once per Ctor annotation with a shared
    list (presumably appended to with each new constructor). The final order
    is stored in ``SUBCLASS_ORDER[cls]``, and the list is returned.
    """
    subclass_order: typing.List[typing.Type[_T]] = []
    for name, args in _sum_args_from_annotations(cls).items():
        make_constructor(cls, name, args, subclass_order)
    SUBCLASS_ORDER[cls] = tuple(subclass_order)
    return subclass_order


def _add_sum_equality(cls: typing.Type[_T], eq) -> bool:
    """Install the prewritten ``__eq__``/``__ne__`` on ``cls`` when ``eq`` is true.

    ``__hash__`` is replaced only when both equality methods were installed.
    Returns whether they were installed. They are not installed when ``eq`` is
    falsy or ``cls`` already has a different ``__eq__``/``__ne__``.
    """
    installed = _add_methods(
        cls, eq, PrewrittenSumMethods.__eq__, PrewrittenSumMethods.__ne__
    )
    if installed:
        cls.__hash__ = PrewrittenSumMethods.__hash__
    return installed


def _add_sum_ordering(cls: typing.Type[_T], equality_methods_were_set: bool) -> None:
    """Install the prewritten rich comparisons on ``cls``.

    Raises:
        ValueError: if the prewritten equality methods were not installed.
        TypeError: if ``cls`` already defines one of the comparison methods;
            none of them is installed in that case.
    """
    if not equality_methods_were_set:
        raise ValueError("Can't add ordering methods if equality methods are provided.")
    collision = _set_new_functions(
        cls,
        PrewrittenSumMethods.__lt__,
        PrewrittenSumMethods.__le__,
        PrewrittenSumMethods.__gt__,
        PrewrittenSumMethods.__ge__,
    )
    if collision:
        raise TypeError(
            "Cannot overwrite attribute {collision} in class "
            "{name}. Consider using functools.total_ordering".format(
                collision=collision, name=cls.__name__
            )
        )


class Sum:
    """Base class of classes with disjoint constructors.

    Examines PEP 526 __annotations__ to determine subclasses.

    If repr is true, a __repr__() method is added to the class.
    If eq is true, __eq__(), __ne__() and __hash__() are added, unless the
    class already provides its own equality methods.
    If order is true, rich comparison dunder methods are added.

    The Sum class examines the class to find Ctor annotations.
    A Ctor annotation is the adt.Ctor class itself, or the result of indexing
    the class, either with a single type hint, or a tuple of type hints.
    All other annotations are ignored.

    The subclass is not subclassable, but has subclasses at each of the
    names that had Ctor annotations. Each subclass takes a fixed number of
    arguments, corresponding to the type hints given to its annotation, if any.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        if not issubclass(cls, ADTConstructor):
            raise TypeError(
                "{name} cannot be instantiated directly; instantiate one of its "
                "constructors instead".format(name=cls.__name__)
            )
        return super(Sum, cls).__new__(cls, *args, **kwargs)

    # ``repr`` and ``eq`` break pylint's naming rules on purpose: the keyword
    # names match the ones the standard library uses (e.g. dataclasses).
    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        """Turn a direct Sum subclass into an ADT with one constructor per Ctor.

        ADTConstructor subclasses are skipped.

        Raises:
            ValueError: if ``order`` is requested without equality methods.
            TypeError: if ordering methods would overwrite existing ones.
        """
        super().__init_subclass__(**kwargs)
        if issubclass(cls, ADTConstructor):
            return
        _ordering_options_are_valid(eq=eq, order=order)

        constructors = _make_sum_constructors(cls)

        # Swapped in only after the constructors (subclasses of cls) exist.
        cls.__init_subclass__ = PrewrittenSumMethods.__init_subclass__  # type: ignore

        _sum_new(cls, frozenset(constructors))

        # Best effort: these are silently skipped if cls already defines them.
        _set_new_functions(
            cls, PrewrittenSumMethods.__setattr__, PrewrittenSumMethods.__delattr__
        )
        _set_new_functions(cls, PrewrittenSumMethods.__bool__)
        _add_methods(cls, repr, PrewrittenSumMethods.__repr__)

        equality_methods_were_set = _add_sum_equality(cls, eq)
        if order:
            _add_sum_ordering(cls, equality_methods_were_set)
303
304
305
def _product_bind_arguments(
    name: str,
    fields: typing.Dict[str, int],
    defaults: typing.Dict[str, typing.Any],
    args: typing.Sequence[typing.Any],
    kwargs: typing.Dict[str, typing.Any],
) -> typing.List[typing.Any]:
    """Return the value of every field, in field order, from call arguments.

    ``fields`` maps field names to positions and ``defaults`` maps field names
    to default values. Positional arguments fill fields from the first one
    onwards. Each remaining field takes its keyword argument if given, else
    its default.

    Raises:
        TypeError: if there are more positional arguments than fields, a field
            without a default gets no value, or a keyword argument does not
            name a field left unfilled by the positional arguments.
    """
    if len(args) > len(fields):
        raise TypeError(
            "{name}() takes {expected} positional argument(s) but {given} were "
            "given".format(name=name, expected=len(fields), given=len(args))
        )
    unused_kwargs = dict(kwargs)
    values = dict(defaults)
    fields_iter = iter(fields)
    # ``args`` comes first in zip() so that, when it runs out, no field is
    # consumed from ``fields_iter`` and skipped by the loop below.
    for arg, field in zip(args, fields_iter):
        values[field] = arg
    for field in fields_iter:
        if field in unused_kwargs:
            values[field] = unused_kwargs.pop(field)
        elif field not in values:
            raise TypeError(
                "{name}() missing required argument: {field!r}".format(
                    name=name, field=field
                )
            )
    if unused_kwargs:
        raise TypeError(
            "{name}() got unexpected keyword argument(s): {names}".format(
                name=name, names=", ".join(map(repr, unused_kwargs))
            )
        )
    return [values[field] for field in fields]


def _product_defaults(
    cls: typing.Type[_T], field_names: typing.Iterable[str]
) -> typing.Dict[str, typing.Any]:
    """Return the defaults (class attributes) of the trailing fields of ``cls``.

    Fields are scanned from last to first. Collection stops at the first field
    whose attribute is missing or is ``inspect.Parameter.empty``.

    Raises:
        TypeError: if any earlier field nevertheless has a default.
    """
    defaults: typing.Dict[str, typing.Any] = {}
    backwards = iter(reversed(tuple(field_names)))
    required_field = None
    for field in backwards:
        default = getattr(cls, field, inspect.Parameter.empty)
        if default is inspect.Parameter.empty:
            required_field = field
            break
        defaults[field] = default
    # Whatever is left in ``backwards`` precedes ``required_field``.
    for field in backwards:
        if getattr(cls, field, inspect.Parameter.empty) is not inspect.Parameter.empty:
            raise TypeError(
                "{name}: field {required!r} without a default follows field "
                "{field!r} with a default".format(
                    name=cls.__name__, required=required_field, field=field
                )
            )
    return defaults


def _product_equality_available(cls: typing.Type[_T], eq) -> bool:
    """Return whether the prewritten ``__eq__``/``__ne__`` should be used for ``cls``.

    That is the case when ``eq`` is true and ``cls`` does not already provide
    different equality methods.
    """
    if not eq:
        return False
    return not _cant_set_new_functions(
        cls, PrewrittenProductMethods.__eq__, PrewrittenProductMethods.__ne__
    )


def _product_check_ordering(cls: typing.Type[_T], eq_succeeded: bool) -> None:
    """Validate that the prewritten rich comparisons can be used for ``cls``.

    Raises:
        ValueError: if the prewritten equality methods are not in use.
        TypeError: if ``cls`` already defines one of the comparison methods.
    """
    if not eq_succeeded:
        raise ValueError("Can't add ordering methods if equality methods are provided.")
    collision = _cant_set_new_functions(
        cls,
        PrewrittenProductMethods.__lt__,
        PrewrittenProductMethods.__le__,
        PrewrittenProductMethods.__gt__,
        PrewrittenProductMethods.__ge__,
    )
    if collision:
        raise TypeError(
            "Cannot overwrite attribute {collision} in class "
            "{name}. Consider using functools.total_ordering".format(
                collision=collision, name=cls.__name__
            )
        )


class Product(ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If eq is true, __eq__(), __ne__() and __hash__() are provided, unless the
    class already defines its own equality methods.
    If order is true, rich comparison dunder methods are added.
    Options not passed to a subclass are inherited from its base.

    The Product class examines the class to find annotations.
    Annotations with the string value "None", and ClassVar annotations, do not
    define fields.
    Fields may have default values (class attributes). Only a trailing run of
    fields may have defaults. A default of inspect.Parameter.empty means
    "no default".

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *args = args
        if cls is Product:
            raise TypeError("Product cannot be instantiated directly; subclass it")
        # Similar to https://github.com/PyCQA/pylint/issues/1802
        field_values = _product_bind_arguments(
            cls.__name__,
            cls.__fields,  # pylint: disable=protected-access
            cls.__defaults,  # pylint: disable=protected-access
            args,
            kwargs,
        )
        return super(Product, cls).__new__(cls, field_values)

    # Per-class option values. Subclasses inherit these unless they override
    # them through __init_subclass__ keywords.
    __repr = True
    __eq = True
    __order = False
    __eq_succeeded = None

    # ``repr`` and ``eq`` break pylint's naming rules on purpose: the keyword
    # names match the ones the standard library uses (e.g. dataclasses).
    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        """Compute the fields, defaults and comparison behaviour of a subclass.

        Raises:
            ValueError: if ordering is enabled without equality, or the class
                provides its own equality methods while ordering is enabled.
            TypeError: if defaults are not a trailing run of fields, or
                ordering would clash with existing comparison methods.
        """
        super().__init_subclass__(**kwargs)

        # None means "inherit the base class's setting".
        if repr is not None:
            cls.__repr = repr
        if eq is not None:
            cls.__eq = eq
        if order is not None:
            cls.__order = order
        _ordering_options_are_valid(eq=cls.__eq, order=cls.__order)

        cls.__annotations = _product_args_from_annotations(cls)
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}
        cls.__defaults = _product_defaults(cls, cls.__annotations)
        _product_new(cls, cls.__annotations, cls.__defaults)

        cls.__eq_succeeded = _product_equality_available(cls, cls.__eq)
        # Check the effective (possibly inherited) option, not just the keyword:
        # the comparison properties below also read the inherited value.
        if cls.__order:
            _product_check_ordering(cls, cls.__eq_succeeded)

    def __dir__(self):
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        # Regular attributes win; unknown names fall back to field lookup.
        try:
            return super().__getattribute__(name)
        except AttributeError:
            index = self.__fields.get(name)
            if index is None:
                raise
            return tuple.__getitem__(self, index)

    __setattr__ = PrewrittenProductMethods.__setattr__
    __delattr__ = PrewrittenProductMethods.__delattr__
    __bool__ = PrewrittenProductMethods.__bool__

    # The dunder methods below are properties so that the implementation is
    # chosen per instance from the class's options at access time.

    @property
    def __repr__(self):
        if self.__repr:
            return PrewrittenProductMethods.__repr__.__get__(self, type(self))
        return super().__repr__

    @property
    def __hash__(self):
        if self.__eq_succeeded:
            return PrewrittenProductMethods.__hash__.__get__(self, type(self))
        return super().__hash__

    @property
    def __eq__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__eq__.__get__(self, type(self))
        return super().__eq__

    @property
    def __ne__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__ne__.__get__(self, type(self))
        return super().__ne__

    @property
    def __lt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__lt__.__get__(self, type(self))
        return super().__lt__

    @property
    def __le__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__le__.__get__(self, type(self))
        return super().__le__

    @property
    def __gt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__gt__.__get__(self, type(self))
        return super().__gt__

    @property
    def __ge__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__ge__.__get__(self, type(self))
        return super().__ge__
499
500
501
# Public API of this module: the names exported by a star-import.
__all__ = ["Ctor", "Product", "Sum"]
502