Passed
Push — master ( f616a2...91f7f1 )
by Max
59s
created

structured_data.adt.Product.__init_subclass__()   B

Complexity

Conditions 5

Size

Total Lines 70
Code Lines 52

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 52
nop 6
dl 0
loc 70
rs 8.1042
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: moving a cohesive part of a method's body into a new, well-named method.

1
"""Base classes for defining abstract data types.
2
3
This module provides three public members: ``Sum`` and ``Ctor``, which are used together, and ``Product``.
4
5
Given a structure, possibly a choice of different structures, that you'd like
6
to associate with a type:
7
8
- First, create a class that subclasses the Sum class.
9
- Then, for each possible structure, add an attribute annotation to the class
10
  with the desired name of the constructor, and a type of ``Ctor``, with the
11
  types within the constructor as arguments.
12
13
To look inside an ADT instance, use the functions from the
14
:mod:`structured_data.match` module.
15
16
Putting it together:
17
18
>>> from structured_data import match
19
>>> class Example(Sum):
20
...     FirstConstructor: Ctor[int, str]
21
...     SecondConstructor: Ctor[bytes]
22
...     ThirdConstructor: Ctor
23
...     def __iter__(self):
24
...         matchable = match.Matchable(self)
25
...         if matchable(Example.FirstConstructor(match.pat.count, match.pat.string)):
26
...             count, string = matchable[match.pat.count, match.pat.string]
27
...             for _ in range(count):
28
...                 yield string
29
...         elif matchable(Example.SecondConstructor(match.pat.bytes)):
30
...             bytes_ = matchable[match.pat.bytes]
31
...             for byte in bytes_:
32
...                 yield chr(byte)
33
...         elif matchable(Example.ThirdConstructor()):
34
...             yield "Third"
35
...             yield "Constructor"
36
>>> list(Example.FirstConstructor(5, "abc"))
37
['abc', 'abc', 'abc', 'abc', 'abc']
38
>>> list(Example.SecondConstructor(b"abc"))
39
['a', 'b', 'c']
40
>>> list(Example.ThirdConstructor())
41
['Third', 'Constructor']
42
"""
43
44
import inspect
45
import sys
46
import types
47
import typing
48
49
from ._adt_constructor import ADTConstructor
50
from ._adt_constructor import make_constructor
51
from ._ctor import get_args
52
from ._prewritten_methods import SUBCLASS_ORDER
53
from ._prewritten_methods import PrewrittenProductMethods
54
from ._prewritten_methods import PrewrittenSumMethods
55
56
_T = typing.TypeVar("_T")


if not typing.TYPE_CHECKING:
    # At runtime, the real Ctor implementation is used.
    from ._ctor import Ctor
else:  # pragma: nocover
    # Static type checkers see these lightweight stand-ins instead.

    class Ctor:
        """Placeholder for ``Ctor`` that only type checkers ever see."""

    class ConcreteCtor(typing.Generic[_T]):
        """Generic wrapper that only type checkers ever see.

        Its type parameter should be a fixed-size Tuple type. A class whose
        annotations still contain this wrapper (meaning it was not processed
        by the ``adt`` decorator) should not be instantiated.
        """
75
76
77
def _conditional_raise(do_raise, exc_class, *args):
    """Raise ``exc_class(*args)`` if *do_raise* is truthy; otherwise do nothing.

    This lets callers write a validation check as a single statement.
    """
    if not do_raise:
        return
    raise exc_class(*args)
80
81
82
def _name(cls: typing.Type[_T], function) -> str:
    """Return the ``__name__`` of *function* as seen through class-level access.

    *function* may be a plain function or any descriptor (staticmethod,
    classmethod, ...). It is bound with ``__get__(None, cls)`` first, so the
    name comes from whatever the descriptor produces for *cls*.
    """
    bound = function.__get__(None, cls)
    return bound.__name__
85
86
87
def _cant_set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Return the first name among *functions* that is already taken on *cls*.

    A function's name (see ``_name``) counts as free when looking it up on
    *cls* yields nothing, the attribute ``object`` or ``Product`` provides
    under that name, or the function itself. If every name is free, return
    ``None``.
    """
    for candidate in functions:
        attr_name = _name(cls, candidate)
        current = getattr(cls, attr_name, None)
        allowed = (
            getattr(object, attr_name, None),
            getattr(Product, attr_name, None),
            None,
            candidate,
        )
        if current in allowed:
            continue
        return attr_name
    return None
99
100
101
def _set_new_functions(cls: typing.Type[_T], *functions) -> typing.Optional[str]:
    """Install every function on *cls* under its own name, all or nothing.

    All names are checked before anything is assigned. If one is already
    taken, nothing is set and that name is returned. Otherwise every function
    is assigned and ``None`` is returned.
    """
    blocked = _cant_set_new_functions(cls, *functions)
    if not blocked:
        for function in functions:
            setattr(cls, _name(cls, function), function)
        return None
    return blocked
113
114
115
_K = typing.TypeVar("_K")
_V = typing.TypeVar("_V")


def _nillable_write(dct: typing.Dict[_K, _V], key: _K, value: typing.Optional[_V]):
    """Store *value* under *key*, treating ``None`` as "remove *key*".

    Removing a key that is not present is silently allowed.
    """
    if value is not None:
        dct[key] = value
        return
    dct.pop(key, None)
124
125
126
def _add_methods(cls: typing.Type[_T], do_set, *methods):
    """Install *methods* on *cls*, but only if *do_set* is truthy.

    Return True only if installation was requested and succeeded, meaning
    no name was already taken (see ``_set_new_functions``). Otherwise return
    False.
    """
    if not do_set:
        return False
    return not _set_new_functions(cls, *methods)
131
132
133
def _sum_new(_cls: typing.Type[_T], subclasses):
    """Replace ``_cls.__new__`` with a guard that only admits *subclasses*.

    Instantiating any class that is not in *subclasses* raises TypeError.
    Otherwise construction is delegated to the ``__new__`` that ``_cls``
    itself defined, or, if it defined none, to the next ``__new__`` after
    ``_cls`` in the MRO.
    """

    def fallback(cls, args):
        return super(_cls, cls).__new__(cls, args)

    original_new = _cls.__dict__.get("__new__", staticmethod(fallback))

    def __new__(cls, args):
        _conditional_raise(cls not in subclasses, TypeError)
        underlying = original_new.__get__(None, cls)
        return underlying(cls, args)

    _cls.__new__ = staticmethod(__new__)  # type: ignore
144
145
146
def _product_new(
    _cls: typing.Type[_T],
    annotations: typing.Dict[str, typing.Any],
    defaults: typing.Dict[str, typing.Any],
):
    """Give ``_cls`` a pass-through ``__new__`` with an introspectable signature.

    The installed ``__new__`` forwards every argument unchanged to the next
    ``__new__`` after ``_cls`` in the MRO. Its ``__signature__`` advertises
    a positional-only ``cls`` followed by one positional-or-keyword parameter
    per field in *annotations*, each annotated with its type and carrying its
    default from *defaults* when one exists. The signature is metadata only;
    no argument checking happens here.
    """

    def __new__(*args, **kwargs):
        cls, *rest = args
        return super(_cls, cls).__new__(cls, *rest, **kwargs)

    template = inspect.signature(__new__)
    parameters = [inspect.Parameter("cls", inspect.Parameter.POSITIONAL_ONLY)]
    for field, annotation in annotations.items():
        parameters.append(
            inspect.Parameter(
                field,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=annotation,
                default=defaults.get(field, inspect.Parameter.empty),
            )
        )
    __new__.__signature__ = template.replace(parameters=parameters)
    _cls.__new__ = __new__
168
169
170
def _all_annotations(
    cls: typing.Type[_T]
) -> typing.Iterator[typing.Tuple[typing.Type[_T], str, typing.Any]]:
    """Yield ``(owner, name, annotation)`` for every annotation along the MRO.

    Classes are visited from the root of the MRO down to *cls* itself. Only
    annotations stored directly in each class's own ``__dict__`` are used, so
    a name annotated in several classes is yielded once per class, with the
    most derived occurrence last.
    """
    for owner in cls.__mro__[::-1]:
        own_annotations = owner.__dict__.get("__annotations__", {})
        for name, hint in own_annotations.items():
            yield owner, name, hint
176
177
178
def _sum_args_from_annotations(cls: typing.Type[_T]) -> typing.Dict[str, typing.Tuple]:
    """Map each constructor name declared on *cls* (or its bases) to its args.

    Each annotation is passed to ``get_args`` together with the globals of
    the module that defined the annotating class (presumably so that string
    annotations can be resolved there). When ``get_args`` returns ``None``,
    the name is not recorded, and any entry a base class contributed under
    that name is dropped.
    """
    constructors: typing.Dict[str, typing.Tuple] = {}
    for owner, name, hint in _all_annotations(cls):
        module_namespace = vars(sys.modules[owner.__module__])
        _nillable_write(constructors, name, get_args(hint, module_namespace))
    return constructors
185
186
187
def _product_args_from_annotations(
    cls: typing.Type[_T]
) -> typing.Dict[str, typing.Any]:
    """Map each field name to its annotation, collected across the whole MRO.

    An annotation of ``None``, or of the string ``"None"``, removes the field,
    including one inherited from a base class. A more derived annotation for
    an existing field replaces its value but keeps the field's position.
    """
    fields: typing.Dict[str, typing.Any] = {}
    for _, name, hint in _all_annotations(cls):
        _nillable_write(fields, name, None if hint == "None" else hint)
    return fields
196
197
198
def _conditional_update(obj, **kwargs):
    """Set each keyword argument as an attribute of *obj*, skipping ``None`` values."""
    updates = {name: value for name, value in kwargs.items() if value is not None}
    for name, value in updates.items():
        setattr(obj, name, value)
202
203
204
class Sum:
    """Base class of classes with disjoint constructors.

    A subclass declares its constructors with PEP 526 annotations. An
    annotation counts as a constructor when its value is the adt.Ctor class
    itself, or the result of indexing Ctor with a single type hint or with a
    tuple of type hints. Every other annotation is ignored.

    The resulting subclass cannot be subclassed again. Instead, it gains one
    subclass per Ctor annotation, reachable under the annotated name. Each of
    those takes a fixed number of arguments, one per type hint given to its
    annotation (none for a bare ``Ctor``).

    Class keyword options:
        repr: when true, a __repr__() method is added to the class.
        eq: when true, __eq__() and __ne__() are added (and then __hash__()).
        order: when true, rich comparison dunder methods are added.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        cls, *rest = args
        # Only ADTConstructor subclasses may be instantiated through this
        # method.
        _conditional_raise(not issubclass(cls, ADTConstructor), TypeError)
        return super(Sum, cls).__new__(cls, *rest, **kwargs)

    # ``repr`` and ``eq`` deliberately shadow a builtin / break pylint's naming
    # rules, so the option names match the stdlib's (as in dataclasses).
    def __init_subclass__(
        cls,
        *,
        repr=True,  # pylint: disable=redefined-builtin
        eq=True,  # pylint: disable=invalid-name
        order=False,
        **kwargs
    ):
        super().__init_subclass__(**kwargs)

        # ADTConstructor subclasses (presumably the constructor classes that
        # make_constructor creates) receive none of the setup below.
        if issubclass(cls, ADTConstructor):
            return

        _conditional_raise(
            order and not eq, ValueError, "eq must be true if order is true"
        )

        constructors: typing.List[typing.Type[_T]] = []
        for ctor_name, ctor_args in _sum_args_from_annotations(cls).items():
            make_constructor(cls, ctor_name, ctor_args, constructors)
        SUBCLASS_ORDER[cls] = tuple(constructors)

        # From now on, subclassing cls goes through the prewritten hook.
        cls.__init_subclass__ = PrewrittenSumMethods.__init_subclass__  # type: ignore

        # Restrict instantiation to the collected constructor classes.
        _sum_new(cls, frozenset(constructors))

        # Always-on methods. If either name of a group is already taken,
        # that group is left untouched.
        _set_new_functions(
            cls, PrewrittenSumMethods.__setattr__, PrewrittenSumMethods.__delattr__
        )
        _set_new_functions(cls, PrewrittenSumMethods.__bool__)

        # Optional methods requested through the class keywords.
        _add_methods(cls, repr, PrewrittenSumMethods.__repr__)
        has_prewritten_eq = _add_methods(
            cls, eq, PrewrittenSumMethods.__eq__, PrewrittenSumMethods.__ne__
        )
        if has_prewritten_eq:
            cls.__hash__ = PrewrittenSumMethods.__hash__

        if not order:
            return

        _conditional_raise(
            not has_prewritten_eq,
            ValueError,
            "Can't add ordering methods if equality methods are provided.",
        )
        clash = _set_new_functions(
            cls,
            PrewrittenSumMethods.__lt__,
            PrewrittenSumMethods.__le__,
            PrewrittenSumMethods.__gt__,
            PrewrittenSumMethods.__ge__,
        )
        _conditional_raise(
            clash,
            TypeError,
            "Cannot overwrite attribute {collision} in class "
            "{name}. Consider using functools.total_ordering".format(
                collision=clash, name=cls.__name__
            ),
        )
293
294
295
def _product_defaults(
    cls: typing.Type[_T], field_names: typing.Iterable[str]
) -> typing.Dict[str, typing.Any]:
    """Collect the default values for the trailing fields of a Product subclass.

    A field's default is the class attribute of the same name, inherited
    attributes included. ``inspect.Parameter.empty`` (or no attribute at all)
    means "no default". Fields are scanned from last to first, and
    collection stops at the first field without a default.

    Raises:
        TypeError: if a field before that point has a default, i.e. a field
            without a default follows one with a default.
    """
    empty = inspect.Parameter.empty
    defaults: typing.Dict[str, typing.Any] = {}
    remaining = iter(reversed(tuple(field_names)))
    for field in remaining:
        default = getattr(cls, field, empty)
        if default is empty:
            break
        defaults[field] = default
    # ``remaining`` now yields only fields declared before the first
    # non-default field (counting from the end); none of them may have a
    # default.
    for field in remaining:
        if getattr(cls, field, empty) is not empty:
            raise TypeError(
                "Field {field!r} of class {name} has a default value, but a "
                "later field does not.".format(field=field, name=cls.__name__)
            )
    return defaults


def _product_eq_available(cls: typing.Type[_T], eq_requested) -> bool:
    """Return whether the prewritten ``__eq__``/``__ne__`` should apply to *cls*.

    That requires *eq_requested* to be truthy, and neither name may resolve
    on *cls* to anything ``_cant_set_new_functions`` would reject (for
    example, a user-defined ``__eq__``).
    """
    if not eq_requested:
        return False
    return not _cant_set_new_functions(
        cls, PrewrittenProductMethods.__eq__, PrewrittenProductMethods.__ne__
    )


def _check_product_order(cls: typing.Type[_T], eq_succeeded: bool) -> None:
    """Validate that the prewritten ordering methods can be used on *cls*.

    Raises:
        ValueError: if the prewritten equality methods are not in use.
        TypeError: if *cls* already provides one of the ordering methods.
    """
    _conditional_raise(
        not eq_succeeded,
        ValueError,
        "Can't add ordering methods if equality methods are provided.",
    )
    collision = _cant_set_new_functions(
        cls,
        PrewrittenProductMethods.__lt__,
        PrewrittenProductMethods.__le__,
        PrewrittenProductMethods.__gt__,
        PrewrittenProductMethods.__ge__,
    )
    _conditional_raise(
        collision,
        TypeError,
        "Cannot overwrite attribute {collision} in class "
        "{name}. Consider using functools.total_ordering".format(
            collision=collision, name=cls.__name__
        ),
    )


class Product(ADTConstructor, tuple):
    """Base class of classes with typed fields.

    Examines PEP 526 __annotations__ to determine fields.

    If repr is true, a __repr__() method is added to the class.
    If eq is true, __eq__(), __ne__() and __hash__() are added, unless the
    class already provides its own __eq__ or __ne__.
    If order is true, rich comparison dunder methods are added.
    Options that are not passed are inherited from the nearest Product base.

    The Product class examines the class to find annotations.
    Annotations with a value of "None" are discarded.
    Fields may have default values, and can be set to inspect.Parameter.empty
    to indicate "no default". Fields with defaults must come after all fields
    without them.

    The subclass is subclassable. The implementation was designed with a focus
    on flexibility over ideals of purity, and therefore provides various
    optional facilities that conflict with, for example, Liskov
    substitutability. For the purposes of matching, each class is considered
    distinct.
    """

    __slots__ = ()

    def __new__(*args, **kwargs):  # pylint: disable=no-method-argument
        """Build an instance from positional and keyword field values.

        Positional arguments fill fields in declaration order. The remaining
        fields are taken from keywords, falling back to their defaults.

        Raises:
            TypeError: when instantiating Product itself, when given more
                positional arguments than there are fields, or when keywords
                are left over (unknown names, or fields already given
                positionally).
            KeyError: when a field without a default receives no value.
        """
        cls, *args = args
        _conditional_raise(cls is Product, TypeError)
        # Similar to https://github.com/PyCQA/pylint/issues/1802
        fields = cls.__fields  # pylint: disable=protected-access
        if len(args) > len(fields):
            # Without this check, zip() below would silently drop the
            # surplus arguments.
            raise TypeError(
                "{name}() takes at most {count} positional arguments "
                "({given} given)".format(
                    name=cls.__name__, count=len(fields), given=len(args)
                )
            )
        values = cls.__defaults.copy()  # pylint: disable=protected-access
        fields_iter = iter(fields)
        # zip() pulls from args first, so once args runs out, no further
        # field is consumed from fields_iter.
        for arg, field in zip(args, fields_iter):
            values[field] = arg
        for field in fields_iter:
            if field in values and field not in kwargs:
                continue
            values[field] = kwargs.pop(field)
        _conditional_raise(kwargs, TypeError, kwargs)
        return super(Product, cls).__new__(cls, [values[field] for field in fields])

    # Class-level option values. __init_subclass__ overrides them on a
    # subclass when the corresponding keyword is passed; otherwise the
    # subclass inherits them.
    __repr = True
    __eq = True
    __order = False
    __eq_succeeded = None

    # ``repr`` and ``eq`` deliberately shadow a builtin / break pylint's naming
    # rules, so the option names match the stdlib's (as in dataclasses).
    def __init_subclass__(
        cls,
        *,
        repr=None,  # pylint: disable=redefined-builtin
        eq=None,  # pylint: disable=invalid-name
        order=None,
        **kwargs
    ):
        super().__init_subclass__(**kwargs)

        # Record only the options passed explicitly (None means "inherit").
        # Inside this class body the attribute names below are mangled to
        # ``_Product__repr`` etc., so vars(overrides) yields exactly the
        # class-attribute names the rest of the class reads.
        overrides = types.SimpleNamespace()
        overrides.__repr = repr  # pylint: disable=protected-access
        overrides.__eq = eq  # pylint: disable=protected-access
        overrides.__order = order  # pylint: disable=protected-access
        _conditional_update(cls, **vars(overrides))

        _conditional_raise(
            cls.__order and not cls.__eq, ValueError, "eq must be true if order is true"
        )

        cls.__annotations = _product_args_from_annotations(cls)
        cls.__fields = {field: index for (index, field) in enumerate(cls.__annotations)}
        cls.__defaults = _product_defaults(cls, cls.__annotations)

        _product_new(cls, cls.__annotations, cls.__defaults)

        cls.__eq_succeeded = _product_eq_available(cls, cls.__eq)

        # Use the effective (possibly inherited) option, just like the
        # eq/order consistency check above, not only the keyword passed here.
        if cls.__order:
            _check_product_order(cls, cls.__eq_succeeded)

    def __dir__(self):
        """List the normal attributes plus the field names."""
        return super().__dir__() + list(self.__fields)

    def __getattribute__(self, name):
        """Resolve normal attributes first, then fall back to field names."""
        try:
            return super().__getattribute__(name)
        except AttributeError:
            index = self.__fields.get(name)
            if index is None:
                raise
            return tuple.__getitem__(self, index)

    __setattr__ = PrewrittenProductMethods.__setattr__
    __delattr__ = PrewrittenProductMethods.__delattr__
    __bool__ = PrewrittenProductMethods.__bool__

    # The special methods below are properties. Looking one up on an instance
    # returns the prewritten implementation when the matching option is
    # active for the instance's class, and the inherited one otherwise.

    @property
    def __repr__(self):
        if self.__repr:
            return PrewrittenProductMethods.__repr__.__get__(self, type(self))
        return super().__repr__

    @property
    def __hash__(self):
        if self.__eq_succeeded:
            return PrewrittenProductMethods.__hash__.__get__(self, type(self))
        return super().__hash__

    @property
    def __eq__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__eq__.__get__(self, type(self))
        return super().__eq__

    @property
    def __ne__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__eq_succeeded:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__ne__.__get__(self, type(self))
        return super().__ne__

    @property
    def __lt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__lt__.__get__(self, type(self))
        return super().__lt__

    @property
    def __le__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__le__.__get__(self, type(self))
        return super().__le__

    @property
    def __gt__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__gt__.__get__(self, type(self))
        return super().__gt__

    @property
    def __ge__(self):  # pylint: disable=unexpected-special-method-signature
        if self.__order:
            # I think this is a Pylint bug, but I'm not sure how to reduce it.
            # pylint: disable=no-value-for-parameter
            return PrewrittenProductMethods.__ge__.__get__(self, type(self))
        return super().__ge__
493
494
495
# The public API of this module.
__all__ = [
    "Ctor",
    "Product",
    "Sum",
]
496