Passed
Pull Request — main (#158)
by
unknown
03:04 queued 01:39
created

pincer.utils.api_object   A

Complexity

Total Complexity 18

Size/Duplication

Total Lines 122
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 18
eloc 59
dl 0
loc 122
rs 10
c 0
b 0
f 0
1
# Copyright Pincer 2021-Present
2
# Full MIT License can be found in `LICENSE` at the project root.
3
4
from __future__ import annotations
5
6
import copy
7
import logging
8
from dataclasses import dataclass, fields, _is_dataclass_instance
9
from enum import Enum, EnumMeta
10
from inspect import getfullargspec
11
from sys import modules
12
from typing import (
13
    Dict, Tuple, Union, Generic, TypeVar, Any, TYPE_CHECKING,
14
    List, get_type_hints, get_origin, get_args
15
)
16
17
from .conversion import convert
18
from .types import MissingType, Singleton, MISSING
19
from ..exceptions import InvalidAnnotation
20
21
if TYPE_CHECKING:
22
    from ..client import Client
23
    from ..core.http import HTTPClient
24
25
T = TypeVar("T")
26
27
_log = logging.getLogger(__package__)
28
29
30
def _asdict_ignore_none(obj: Generic[T]) -> Union[Tuple, Dict, T]:
    """
    Recursively turn a dataclass instance into a plain dict, leaving
    out every field whose (converted) value is a ``MissingType``.
    Adapted from ``_asdict_inner`` in the ``dataclasses`` module.

    Enum members stored directly in a dataclass field are replaced by
    their ``.value``. Named tuples, lists, tuples and dicts are rebuilt
    with the same container type and converted element by element.
    Anything else is returned as a deep copy.

    :param obj:
        Dataclass instance, container or plain value to convert.

    :return:
        The converted structure.
    """
    if _is_dataclass_instance(obj):
        serialised = {}

        for field in fields(obj):
            value = _asdict_ignore_none(getattr(obj, field.name))

            if isinstance(value, Enum):
                serialised[field.name] = value.value
            # Deviation from the stdlib helper: missing values are
            # skipped instead of being written to the result.
            elif not isinstance(value, MissingType):
                serialised[field.name] = value

        return serialised

    # Named tuples take their items as positional arguments.
    if isinstance(obj, tuple) and hasattr(obj, '_fields'):
        return type(obj)(*map(_asdict_ignore_none, obj))

    if isinstance(obj, (list, tuple)):
        return type(obj)(map(_asdict_ignore_none, obj))

    if isinstance(obj, dict):
        return type(obj)(
            (_asdict_ignore_none(key), _asdict_ignore_none(item))
            for key, item in obj.items()
        )

    return copy.deepcopy(obj)
68
69
70
class HTTPMeta(type):
71
    __meta_items: List[str] = ["_client", "_http"]
72
    __ori_annotations: Dict[str, type] = {}
73
74
    def __new__(mcs, name, base, mapping):
75
        for key in HTTPMeta.__meta_items:
76
            if mapping.get("__annotations__") and \
77
                    (value := mapping["__annotations__"].get(key)):
0 ignored issues
show
introduced by
invalid syntax (<unknown>, line 77)
Loading history...
78
                HTTPMeta.__ori_annotations.update({key: value})
79
                del mapping["__annotations__"][key]
80
81
        http_object = super().__new__(mcs, name, base, mapping)
82
83
        if getattr(http_object, "__annotations__", None):
84
            for k, v in HTTPMeta.__ori_annotations.items():
85
                http_object.__annotations__[k] = v
86
                setattr(http_object, k, None)
87
88
        return http_object
89
90
91
class TypeCache(metaclass=Singleton):
    """
    Gathers the namespaces of all loaded modules whose name starts
    with ``pincer`` into the class-level ``cache`` dict.

    NOTE(review): ``Singleton`` is defined elsewhere; presumably it makes
    repeated ``TypeCache()`` calls reuse one instance - confirm.
    """
    # name -> object, merged from every matching module's ``__dict__``.
    cache = {}

    def __init__(self):
        # Walk a snapshot of ``sys.modules`` rather than the live dict.
        for module_name, module in modules.copy().items():
            if module_name.startswith("pincer"):
                TypeCache.cache.update(module.__dict__)
101
102
103
@dataclass
class APIObject(metaclass=HTTPMeta):
    """
    Base class for objects which have been fetched from the Discord API.

    Once the dataclass ``__init__`` has run, every attribute whose name
    does not start with an underscore is passed through ``convert``
    using the type taken from its annotation.
    """
    _client: Client
    _http: HTTPClient

    def __get_types(self, attr: str, arg_type: type) -> Tuple[Any]:
        """
        Return the candidate types of an annotation.

        :param attr:
            Attribute name, only used in the error message.

        :param arg_type:
            The resolved annotation of the attribute.

        :return:
            The members of a ``Union`` with two or three members, or a
            one-element tuple holding ``arg_type`` itself.

        :raises InvalidAnnotation:
            When a ``Union`` has another number of members.
        """
        if get_origin(arg_type) is not Union:
            return (arg_type,)

        args = get_args(arg_type)

        if len(args) not in (2, 3):
            raise InvalidAnnotation(
                f"Attribute `{attr}` in `{type(self).__name__}` has too many "
                f"or not enough arguments! (got {len(args)} expected 2-3)"
            )

        return args

    def __attr_convert(self, attr: str, attr_type: T) -> T:
        """
        Run the current value of ``attr`` through ``convert``.

        :param attr:
            Name of the attribute to read from this object.

        :param attr_type:
            Target type; its ``__factory__`` is handed to ``convert``
            as the factory when it has a truthy one, otherwise the
            type itself is.
        """
        factory = getattr(attr_type, "__factory__", None) or attr_type

        return convert(
            getattr(self, attr),
            factory,
            attr_type,
            self._client
        )

    def __post_init__(self):
        """
        Convert every public attribute according to its type hint.

        :raises InvalidAnnotation:
            When no usable type is left for an attribute.
        """
        TypeCache()

        hints = get_type_hints(self, globalns=TypeCache.cache)

        for attr, attr_type in hints.items():
            if attr.startswith('_'):
                continue

            # NOTE(review): this drops members that are the objects
            # `None` / `MISSING` themselves; whether type members such
            # as NoneType should be dropped too needs confirming.
            types = tuple(
                tpe for tpe in self.__get_types(attr, attr_type)
                if tpe is not None and tpe is not MISSING
            )

            if not types:
                raise InvalidAnnotation(
                    f"Attribute `{attr}` in `{type(self).__name__}` only "
                    "consisted of missing/optional type!"
                )

            # The first remaining candidate decides the conversion; a
            # subscripted generic is reduced to its origin.
            specific_tp = get_origin(types[0]) or types[0]

            if isinstance(specific_tp, EnumMeta) and not getattr(self, attr):
                # A falsy value for an enum-typed attribute is stored as
                # MISSING instead of being converted.
                setattr(self, attr, MISSING)
            else:
                setattr(self, attr, self.__attr_convert(attr, specific_tp))

    @classmethod
    def __factory__(cls: Generic[T], *args, **kwargs) -> T:
        """
        Factory hook; forwards all arguments to :meth:`from_dict`.
        """
        return cls.from_dict(*args, **kwargs)

    @classmethod
    def from_dict(
            cls: Generic[T],
            data: Dict[str, Union[str, bool, int, Any]]
    ) -> T:
        """
        Build an instance of this class from a dictionary.

        :param data:
            Mapping of attribute names to values. Only keys that match
            an ``__init__`` argument and hold a non-``None`` value are
            used; Enum members are replaced by their ``.value``. An
            object that already is an instance of ``cls`` is returned
            unchanged.
        """
        if isinstance(data, cls):
            return data

        init_kwargs = {}

        for argument in getfullargspec(cls.__init__).args:
            value = data.get(argument)

            if value is None:
                continue

            if isinstance(value, Enum):
                value = value.value

            init_kwargs[argument] = value

        # The inherited classes accept these keyword arguments, the IDE
        # inspection just cannot see that:
        # noinspection PyArgumentList
        return cls(**init_kwargs)

    def to_dict(self) -> Dict:
        """
        Return this object as a dictionary, built by
        ``_asdict_ignore_none``.
        """
        return _asdict_ignore_none(self)
208