1
|
1 |
|
import functools |
2
|
1 |
|
import inspect |
3
|
1 |
|
from typing import Dict, Type, Union, Any, TypeVar, Callable |
4
|
|
|
|
5
|
1 |
|
from .interfaces import SpecialDepDefinition |
6
|
1 |
|
from .exceptions import UnresolvableType, DuplicateDefinition |
7
|
1 |
|
from .definitions import normalise |
8
|
|
|
|
9
|
1 |
|
# Built-in scalar types the container refuses to construct: ``Container.resolve``
# raises ``UnresolvableType`` for any of them instead of reflecting on ``__init__``.
UNRESOLVABLE_TYPES = [
    str,
    int,
    float,
    bool,
]

# Type variable linking a requested type to the instance handed back for it.
X = TypeVar("X")

# Alias for resolver definitions passed to ``Container.define``; kept as ``Any``
# because the accepted forms are decided by ``normalise``.
DepDefinition = Any
14
|
|
|
|
15
|
|
|
|
16
|
1 |
|
class Container:
    """Dependency-injection container that builds objects from type annotations.

    Types can be registered explicitly with :meth:`define` (or
    ``container[SomeType] = resolver``). Any type that is not registered is
    constructed by reflecting on its ``__init__`` annotations and recursively
    resolving each annotated parameter.
    """

    # Maps a requested type (or key) to its normalised definition.
    _registered_types: Dict[Type, SpecialDepDefinition]

    def __init__(self):
        self._registered_types = {}

    def define(self, dep: Union[Type[X], Type], resolver: DepDefinition) -> None:
        """Register ``resolver`` as the way to build ``dep``.

        :param dep: the type (key) being registered.
        :param resolver: a definition, normalised via ``normalise``.
        :raises DuplicateDefinition: if ``dep`` already has a definition.
        """
        if dep in self._registered_types:
            raise DuplicateDefinition()
        self._registered_types[dep] = normalise(resolver, self)

    def resolve(self, dep_type: Type[X], suppress_error=False) -> X:
        """Return an instance of ``dep_type``.

        A registered definition is used if present; otherwise the type is built
        by reflection on its ``__init__`` annotations.

        :param dep_type: the type to construct.
        :param suppress_error: when true, return ``None`` instead of raising if
            the type cannot be constructed.
        :raises UnresolvableType: if the type (or one of its dependencies)
            cannot be constructed and ``suppress_error`` is false.
        """
        try:
            if dep_type in UNRESOLVABLE_TYPES:
                raise UnresolvableType(f"Cannot construct type {dep_type}")
            registered_type = self._registered_types.get(dep_type, dep_type)
            return self._build(registered_type)
        except UnresolvableType as inner_error:
            if not suppress_error:
                # Not every key/annotation has ``__name__`` (e.g. some typing
                # constructs, or non-type keys given to ``define``); fall back to
                # repr so the real failure is reported, not an AttributeError.
                type_name = getattr(dep_type, "__name__", repr(dep_type))
                raise UnresolvableType(
                    f"Cannot construct type {type_name}"
                ) from inner_error
            return None  # type: ignore

    def partial(self, func: Callable[..., X], keys_to_skip=None) -> Callable[..., X]:
        """Return ``func`` with every resolvable annotated parameter pre-bound.

        Parameters whose types cannot be resolved are left unbound so the caller
        can supply them; parameter names listed in ``keys_to_skip`` are never
        bound.
        """
        spec = inspect.getfullargspec(func)
        bindable_deps = self._infer_dependencies(
            spec, suppress_error=True, keys_to_skip=keys_to_skip
        )
        return functools.partial(func, **bindable_deps)

    def __getitem__(self, dep: Type[X]) -> X:
        """``container[SomeType]`` is shorthand for ``container.resolve(SomeType)``."""
        return self.resolve(dep)

    def __setitem__(self, dep: Type, resolver: DepDefinition):
        """``container[SomeType] = resolver`` is shorthand for :meth:`define`."""
        self.define(dep, resolver)

    def _build(self, dep_type: Any) -> Any:
        """Build ``dep_type``: delegate to a special definition, else reflect."""
        if isinstance(dep_type, SpecialDepDefinition):
            # The definition is handed ``_build``, presumably so it can construct
            # whatever it wraps through this container.
            return dep_type.get_instance(self._build)
        return self._reflection_build(dep_type)

    def _reflection_build(self, dep_type: Type[X]) -> X:
        """Instantiate ``dep_type`` by resolving its annotated ``__init__`` parameters.

        :raises UnresolvableType: if any annotated parameter cannot be resolved.
        """
        spec = inspect.getfullargspec(dep_type.__init__)
        sub_deps = self._infer_dependencies(spec)
        return dep_type(**sub_deps)  # type: ignore

    def _infer_dependencies(
        self, spec: inspect.FullArgSpec, suppress_error=False, keys_to_skip=None
    ) -> Dict[str, Any]:
        """Resolve every annotated parameter in ``spec`` to a keyword argument.

        The ``return`` annotation, parameters annotated as ``Any`` and names in
        ``keys_to_skip`` are ignored.

        :param spec: argument spec of the callable whose dependencies to resolve.
        :param suppress_error: when true, parameters whose type cannot be
            resolved are omitted from the result instead of raising.
        :param keys_to_skip: parameter names that must not be resolved.
        :returns: mapping of parameter name to resolved instance.
        :raises UnresolvableType: if a dependency cannot be built and
            ``suppress_error`` is false.
        """
        skip = keys_to_skip or ()
        sub_deps = {}
        for key, sub_dep_type in spec.annotations.items():
            if key == "return" or sub_dep_type is Any or key in skip:
                continue
            try:
                sub_deps[key] = self.resolve(sub_dep_type)
            except UnresolvableType:
                # Only genuine resolution failures are skipped, so a dependency
                # that legitimately resolves to ``None`` is still passed through.
                if not suppress_error:
                    raise
        return sub_deps
73
|
|
|
|