ReflectionUtils.injection()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 22
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 22
rs 9.95
c 0
b 0
f 0
cc 2
nop 3
1
# SPDX-FileCopyrightText: Copyright 2020-2023, Contributors to pocketutils
2
# SPDX-PackageHomePage: https://github.com/dmyersturnbull/pocketutils
3
# SPDX-License-Identifier: Apache-2.0
4
"""
5
6
"""
7
8
import inspect
9
import sys
10
import typing
11
from collections.abc import Callable, Generator, Mapping
12
from dataclasses import dataclass
13
from inspect import Parameter
14
from types import MappingProxyType
15
from typing import Any, Self, TypeVar
16
17
# Public API of this module: the utility class and the module-level instance of it.
__all__ = ["ReflectionUtils", "ReflectionTools"]
18
19
T_co = TypeVar("T_co", covariant=True)
20
21
22
@dataclass(slots=True, frozen=True)
23
class ReflectionUtils:
24
    def get_generic_arg(self: Self, clazz: type[T_co], bound: type[T_co] | None = None) -> type[T_co]:
25
        """
26
        Finds the generic argument (specific TypeVar) of a `typing.Generic` class.
27
        **Assumes that `clazz` only has one type parameter. Always returns the first.**
28
29
        Args:
30
            clazz: The Generic class
31
            bound: If non-None, requires the returned type to be a subclass of `bound` (or equal to it)
32
33
        Returns:
34
            The class
35
36
        Raises:
37
            AssertionError: For most errors
38
        """
39
        bases = clazz.__orig_bases__
40
        try:
41
            param = typing.get_args(bases[0])[0]
42
        except KeyError:
43
            msg = f"Failed to get generic type on {clazz}"
44
            raise AssertionError(msg)
45
        if not issubclass(param, bound):
46
            msg = f"{param} is not a {bound}"
47
            raise AssertionError(msg)
48
        return param
49
50
    def subclass_dict(self: Self, clazz: type[T_co], concrete: bool = False) -> Mapping[str, type[T_co]]:
51
        return {c.__name__: c for c in self.subclasses(clazz, concrete=concrete)}
52
53
    def subclasses(self: Self, clazz: type[T_co], concrete: bool = False) -> Generator[type[T_co], None, None]:
54
        for subclass in clazz.__subclasses__():
55
            yield from self.subclasses(subclass, concrete=concrete)
56
            if not concrete or not inspect.isabstract(subclass) and not subclass.__name__.startswith("_"):
57
                yield subclass
58
59
    def default_arg_values(self: Self, func: Callable[..., Any]) -> Mapping[str, Any | None]:
60
        return {k: v.default for k, v in self.optional_args(func).items()}
61
62
    def required_args(self: Self, func: Callable[..., Any]) -> Mapping[str, MappingProxyType]:
63
        """
64
        Finds parameters that lack default values.
65
66
        Args:
67
            func: A function or method
68
69
        Returns:
70
            A dict mapping parameter names to instances of `MappingProxyType`,
71
            just as `inspect.signature(func).parameters` does.
72
        """
73
        return self._args(func, True)
74
75
    def optional_args(self: Self, func: Callable[..., Any]) -> Mapping[str, Parameter]:
76
        """
77
        Finds parameters that have default values.
78
79
        Args:
80
            func: A function or method
81
82
        Returns:
83
            A dict mapping parameter names to instances of `MappingProxyType`,
84
            just as `inspect.signature(func).parameters` does.
85
        """
86
        return self._args(func, False)
87
88
    def injection(self: Self, fully_qualified: str, clazz: type[T_co]) -> type[T_co]:
89
        """
90
        Gets a **class** by its fully-resolved class name.
91
92
        Args:
93
            fully_qualified: Dotted syntax
94
            clazz: Class
95
96
        Returns:
97
            The Type
98
99
        Raises:
100
            InjectionError: If the class was not found
101
        """
102
        s = fully_qualified
103
        mod = s[: s.rfind(".")]
104
        clz = s[s.rfind(".") :]
105
        try:
106
            return getattr(sys.modules[mod], clz)
107
        except AttributeError:
108
            msg = f"Did not find {clazz} by fully-qualified class name {fully_qualified}"
109
            raise LookupError(msg) from None
110
111
    def _args(self: Self, func: Callable[..., Any], req: bool) -> dict[str, Parameter]:
112
        signature = inspect.signature(func)
113
        return {
114
            k: v
115
            for k, v in signature.parameters.items()
116
            if req and v.default is Parameter.empty or not req and v.default is not Parameter.empty
117
        }
118
119
120
# Module-level instance of `ReflectionUtils`, exported in `__all__`.
ReflectionTools = ReflectionUtils()
121