1 | 1 | import logging |
|
2 | 1 | from contextlib import ExitStack |
|
3 | 1 | from copy import copy |
|
4 | 1 | from typing import ( |
|
5 | Collection, |
||
6 | Union, |
||
7 | Type, |
||
8 | TypeVar, |
||
9 | Optional, |
||
10 | cast, |
||
11 | ContextManager, |
||
12 | Iterator, |
||
13 | Generator, |
||
14 | Callable, |
||
15 | List, |
||
16 | ) |
||
17 | |||
18 | 1 | from lagom import Container |
|
19 | 1 | from lagom.compilaton import mypyc_attr |
|
20 | 1 | from lagom.definitions import ConstructionWithContainer, SingletonWrapper, Alias |
|
21 | 1 | from lagom.exceptions import InvalidDependencyDefinition |
|
22 | 1 | from lagom.interfaces import ( |
|
23 | ReadableContainer, |
||
24 | SpecialDepDefinition, |
||
25 | CallTimeContainerUpdate, |
||
26 | ) |
||
27 | |||
# Type variable for the return type of callables wrapped by partial/magic_partial.
X = TypeVar("X")
|
29 | |||
30 | |||
@mypyc_attr(allow_interpreted_subclasses=True)
class ContextContainer(Container):
    """
    Wraps a regular container but is a ContextManager for use within a `with`.

    Entering the container clones it. On the clone, each type listed in
    ``context_types`` is defined so that resolving it enters the
    ``ContextManager[T]`` (or ``Iterator[T]`` / ``Generator[T, None, None]``)
    definition registered for it and yields that manager's value.
    ``context_singletons`` work the same way but act as singletons within the
    entered clone. Every manager entered this way is exited when the `with`
    block ends.

    >>> from tests.examples import SomeClass, SomeClassManager
    >>> from lagom import Container
    >>> from typing import ContextManager
    >>>
    >>> # The regular container
    >>> c = Container()
    >>>
    >>> # register a context manager for SomeClass
    >>> c[ContextManager[SomeClass]] = SomeClassManager
    >>>
    >>> context_c = ContextContainer(c, context_types=[SomeClass])
    >>> with context_c as c:
    ...     c[SomeClass]
    <tests.examples.SomeClass object at ...>
    """

    # Non-None only while this container is entered. It holds the per-`with`
    # clone plus every context manager entered while resolving context types.
    exit_stack: Optional[ExitStack] = None
    _context_types: Collection[Type]
    _context_singletons: Collection[Type]

    def __init__(
        self,
        container: Container,
        context_types: Collection[Type],
        context_singletons: Collection[Type] = tuple(),
        log_undefined_deps: Union[bool, logging.Logger] = False,
    ):
        """
        :param container: the container whose definitions are wrapped
        :param context_types: types resolved through a registered context
            manager while inside a `with`
        :param context_singletons: like context_types, but wrapped so they act
            as singletons within the entered container
        :param log_undefined_deps: passed straight through to Container
        """
        self._context_types = context_types
        self._context_singletons = context_singletons
        super().__init__(container, log_undefined_deps)

    def clone(self) -> "ContextContainer":
        """returns a copy of the container
        :return: a new ContextContainer wrapping this one with the same
            context types, context singletons and undefined-dependency logger
        """
        return ContextContainer(
            self,
            context_types=self._context_types,
            context_singletons=self._context_singletons,
            log_undefined_deps=self._undefined_logger,
        )

    def __enter__(self):
        """
        On first entry, prepare and return a clone carrying the context
        definitions. If this container is already entered, return self.
        """
        if not self.exit_stack:
            # All actual context definitions happen on a clone so that there's
            # isolation between invocations.
            in_context = self.clone()
            for dep_type in set(self._context_types):
                in_context[dep_type] = self._context_type_def(dep_type)
            for dep_type in set(self._context_singletons):
                in_context[dep_type] = self._singleton_type_def(dep_type)
            # Pre-setting the clone's stack makes the enter_context call below
            # take the "already entered" branch, so the clone returns itself.
            in_context.exit_stack = ExitStack()

            # The parent context manager keeps track of the inner clone.
            self.exit_stack = ExitStack()
            self.exit_stack.enter_context(in_context)
            return in_context
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Unwind everything entered during this `with`; never suppresses."""
        if self.exit_stack:
            try:
                # NOTE(review): ExitStack.close() exits the registered managers
                # without exception details, so they are not told about an
                # exception raised inside the `with`. Confirm this is intended.
                self.exit_stack.close()
            finally:
                # Reset even if a manager's __exit__ raised. An ExitStack is
                # always truthy, so a stale one would make the next __enter__
                # return this outer container instead of a prepared clone.
                self.exit_stack = None

    def partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> Callable[..., X]:
        """
        Like Container.partial, but each call of the returned function runs
        inside its own `with self`. Context managers resolved for that call
        are therefore exited when the call returns.
        """

        def _with_context(*args, **kwargs):
            with self as c:
                # TODO: Try and move this partial outside the function as this is expensive
                base_partial = super(ContextContainer, c).partial(
                    func, shared, container_updater
                )
                return base_partial(*args, **kwargs)

        return _with_context

    def magic_partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        keys_to_skip: Optional[List[str]] = None,
        skip_pos_up_to: int = 0,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> Callable[..., X]:
        """
        Like Container.magic_partial, but each call of the returned function
        runs inside its own `with self`. Context managers resolved for that
        call are therefore exited when the call returns.
        """

        def _with_context(*args, **kwargs):
            with self as c:
                # TODO: Try and move this partial outside the function as this is expensive
                base_partial = super(ContextContainer, c).magic_partial(
                    func, shared, keys_to_skip, skip_pos_up_to, container_updater
                )
                return base_partial(*args, **kwargs)

        return _with_context

    def _context_type_def(self, dep_type: Type):
        """
        Build a definition for dep_type that enters its registered context
        manager on resolution.

        :raises InvalidDependencyDefinition: if no ContextManager, Iterator or
            Generator definition exists for dep_type
        """
        # The first definition found wins, in this order.
        type_def = (
            self.get_definition(ContextManager[dep_type])  # type: ignore
            or self.get_definition(Iterator[dep_type])  # type: ignore
            or self.get_definition(Generator[dep_type, None, None])  # type: ignore
        )
        if type_def is None:
            raise InvalidDependencyDefinition(
                f"A ContextManager[{dep_type}] should be defined. "
                f"This could be an Iterator[{dep_type}] or Generator[{dep_type}, None, None] "
                f"with the @contextmanager decorator"
            )
        if isinstance(type_def, Alias):
            # Without this we create a definition that points to
            # itself.
            type_def = copy(type_def)
            type_def.skip_definitions = True
        return ConstructionWithContainer(lambda c: self._context_resolver(c, type_def))  # type: ignore

    def _singleton_type_def(self, dep_type: Type):
        """
        The same as context_type_def but acts as a singleton within this container
        """
        return SingletonWrapper(self._context_type_def(dep_type))

    def _context_resolver(self, c: ReadableContainer, type_def: SpecialDepDefinition):
        """
        Takes an existing definition which must be a context manager. Returns
        the value of the context manager from __enter__ and then places the
        __exit__ in this container's exit stack

        :raises AssertionError: if called while this container is not entered
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would leave an opaque AttributeError on None.
        # AssertionError is kept so callers see the same exception type.
        if not self.exit_stack:
            raise AssertionError("Types can only be resolved within a with")
        context_manager = cast(ContextManager, type_def.get_instance(c))
        return self.exit_stack.enter_context(context_manager)