Test Failed
Pull Request — master (#226)
by Steve
02:44
created

AsyncContextContainer.magic_partial()   A

Complexity

Conditions 2

Size

Total Lines 17
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 9
CRAP Score 2

Importance

Changes 0
Metric Value
cc 2
eloc 13
nop 6
dl 0
loc 17
ccs 9
cts 9
cp 1
crap 2
rs 9.75
c 0
b 0
f 0
1 1
import inspect
2 1
import logging
3 1
from asyncio import Lock
4 1
from contextlib import AsyncExitStack
5 1
from copy import copy
6 1
from functools import wraps
7
from typing import (
8
    Optional,
9
    Type,
10
    TypeVar,
11
    Awaitable,
12
    Generic,
13
    Collection,
14
    Union,
15
    ContextManager,
16
    AsyncContextManager,
17
    Iterator,
18
    Generator,
19
    AsyncGenerator,
20
    Callable,
21
    List,
22
)
23 1
24 1
from lagom.container import Container
25 1
from lagom.definitions import Alias, ConstructionWithContainer, SingletonWrapper
26 1
from lagom.exceptions import InvalidDependencyDefinition, MissingFeature
27 1
from lagom.experimental.definitions import AsyncConstructionWithContainer
28
from lagom.interfaces import (
29
    ReadableContainer,
30
    SpecialDepDefinition,
31
    CallTimeContainerUpdate,
32
    ContainerBoundFunction,
33 1
)
34 1
35
# Generic placeholders: T is the value type produced by AwaitableSingleton,
# X is the return type of functions bound to a container.
T, X = TypeVar("T"), TypeVar("X")
37 1
38 1
39 1
class AwaitableSingleton(Generic[T]):
    """
    Awaits an async construction once and caches the result.

    ``get`` serialises concurrent callers with an ``asyncio.Lock`` so the
    constructor is awaited at most once per instance. If the constructor
    raises, nothing is cached and the next call retries.
    """

    instance: Optional[T]
    constructor: ConstructionWithContainer[Awaitable[T]]
    container: Container
    _lock: Lock
    # Whether ``instance`` holds a constructed value. This is tracked
    # separately so that falsy (or None) results are cached as well.
    _constructed: bool

    def __init__(self, constructor: ConstructionWithContainer, container: Container):
        """
        :param constructor: definition whose ``get_instance(container)``
            returns an awaitable producing the singleton value
        :param container: the container passed to ``constructor.get_instance``
        """
        self.instance = None
        self.constructor = constructor  # type: ignore
        self.container = container
        self._lock = Lock()
        self._constructed = False

    async def get(self) -> T:
        """Return the cached value, awaiting the constructor on first use."""
        if not self._constructed:
            async with self._lock:
                # Re-check under the lock: another task may have finished
                # constructing while this one was waiting.
                if not self._constructed:
                    self.instance = await self.constructor.get_instance(self.container)
                    self._constructed = True
        return self.instance  # type: ignore
57 1
58 1
59 1
class _AsyncContextBoundFunction(ContainerBoundFunction[X]):
    """
    A function bound to an AsyncContextContainer.

    Calling it returns a coroutine that enters ``async_context_container``
    and then runs ``partially_bound_function`` rebound to the container
    yielded by that ``async with``.
    """

    __slots__ = ("async_context_container", "partially_bound_function", "__dict__")

    async_context_container: "AsyncContextContainer"
    partially_bound_function: ContainerBoundFunction

    def __init__(
        self,
        async_context_container: "AsyncContextContainer",
        partially_bound_function: ContainerBoundFunction,
    ):
        self.async_context_container = async_context_container
        self.partially_bound_function = partially_bound_function

    def __call__(self, *args, **kwargs) -> X:
        # The returned coroutine must be awaited by the caller.
        pending_call = self.__async_call__(*args, **kwargs)
        return pending_call

    async def __async_call__(self, *args, **kwargs):
        async with self.async_context_container as scoped_container:
            target = self.partially_bound_function.rebind(scoped_container)
            return await target(*args, **kwargs)

    def rebind(self, container: ReadableContainer) -> "ContainerBoundFunction[X]":
        """Return a copy whose inner function is rebound to ``container``."""
        copy_wrapper = wraps(self.partially_bound_function)
        rebound_inner = self.partially_bound_function.rebind(container)
        return copy_wrapper(
            _AsyncContextBoundFunction(self.async_context_container, rebound_inner)
        )
91 1
92 1
93 1
class AsyncContextContainer(Container):
    """
    A container whose ``async with`` block scopes context-managed dependencies.

    Entering a root container creates an isolated clone. In that clone, each
    type in ``context_types`` (and ``context_singletons``) is resolved by
    entering its registered context manager, which may be sync or async.
    Everything entered is exited when the root's ``async with`` block ends.

    NOTE(review): the root reuses a single ``async_exit_stack`` across
    entries, and ``__aexit__`` closes that whole stack. If the same root is
    entered concurrently (e.g. overlapping calls of a function bound via
    ``partial``), the first exit also tears down the other invocation's clone.
    Confirm whether concurrent use needs to be supported.
    """

    async_exit_stack: Optional[AsyncExitStack] = None
    _context_types: Collection[Type]
    _context_singletons: Collection[Type]
    _root_context: bool = True

    def __init__(
        self,
        container: Container,
        context_types: Collection[Type],
        context_singletons: Collection[Type] = tuple(),
        log_undefined_deps: Union[bool, logging.Logger] = False,
    ):
        """
        :param container: parent container the definitions are taken from
        :param context_types: types resolved freshly via their context manager
        :param context_singletons: types resolved via their context manager
            at most once per ``async with`` block
        :param log_undefined_deps: forwarded to ``Container``
        """
        super().__init__(container, log_undefined_deps)
        self._context_types = set(context_types)
        self._context_singletons = set(context_singletons)

    def clone(self) -> "AsyncContextContainer":
        """Return a new AsyncContextContainer with this one as its parent.

        The clone uses the same context types, singletons and logger.
        """
        return AsyncContextContainer(
            self,
            context_types=self._context_types,
            context_singletons=self._context_singletons,
            log_undefined_deps=self._undefined_logger,
        )

    async def __aenter__(self):
        """
        Enter the context.

        On a root container, builds and returns a clone in which the context
        types are defined. On a clone (non-root), simply returns ``self``.
        """
        if not self._root_context:
            return self

        if not self.async_exit_stack:
            self.async_exit_stack = AsyncExitStack()

        # All actual context definitions happen on a clone so that there's
        # isolation between invocations.
        in_context = self.clone()
        in_context.async_exit_stack = AsyncExitStack()
        in_context._root_context = False

        # The definitions are built on the clone, so resolved context managers
        # are pushed onto the clone's own exit stack. Async singletons also
        # resolve their dependencies from the clone (matching the sync
        # SingletonWrapper path) rather than from the root.
        for dep_type in self._context_types:
            managed_dep = in_context._context_type_def(dep_type)
            key = Awaitable[dep_type] if isinstance(managed_dep, AsyncConstructionWithContainer) else dep_type  # type: ignore
            in_context[key] = managed_dep  # type: ignore
        for dep_type in self._context_singletons:
            managed_singleton = in_context._singleton_type_def(dep_type)
            key = AwaitableSingleton[dep_type] if isinstance(managed_singleton, AwaitableSingleton) else dep_type  # type: ignore
            in_context[key] = managed_singleton  # type: ignore

        # The parent context manager keeps track of the inner clone.
        await self.async_exit_stack.enter_async_context(in_context)
        return in_context

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close this container's exit stack, if any. Exceptions are not suppressed."""
        if self.async_exit_stack:
            try:
                await self.async_exit_stack.aclose()
            finally:
                # Reset even if an exit callback raised, so the next entry
                # starts with a fresh stack.
                self.async_exit_stack = None

    @staticmethod
    def _ensure_coroutine_function(func: Callable) -> None:
        """Raise MissingFeature unless ``func`` is an ``async def`` function."""
        if not inspect.iscoroutinefunction(func):
            raise MissingFeature(
                "AsyncContextManager currently can only deal with async functions"
            )

    def partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> ContainerBoundFunction[X]:
        """
        Bind ``func`` so that each call runs inside ``async with self``.

        :raises MissingFeature: if ``func`` is not a coroutine function
        """
        self._ensure_coroutine_function(func)
        base_partial = super().partial(func, shared, container_updater)
        return wraps(base_partial)(_AsyncContextBoundFunction(self, base_partial))

    def magic_partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        keys_to_skip: Optional[List[str]] = None,
        skip_pos_up_to: int = 0,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> ContainerBoundFunction[X]:
        """
        Magic-bind ``func`` so that each call runs inside ``async with self``.

        :raises MissingFeature: if ``func`` is not a coroutine function
        """
        self._ensure_coroutine_function(func)
        base_partial = super().magic_partial(
            func, shared, keys_to_skip, skip_pos_up_to, container_updater
        )
        return wraps(base_partial)(_AsyncContextBoundFunction(self, base_partial))

    def _context_type_def(self, dep_type: Type):
        """
        Build a definition for ``dep_type`` from its context-manager definition.

        The lookup order is ContextManager, Iterator, Generator, AsyncGenerator
        and then AsyncContextManager of ``dep_type``, and the first one defined
        is used. A synchronous match produces a ConstructionWithContainer; an
        asynchronous match produces an AsyncConstructionWithContainer.

        :raises InvalidDependencyDefinition: if none of them is defined
        """
        candidates = (
            (ContextManager[dep_type], False),  # type: ignore
            (Iterator[dep_type], False),  # type: ignore
            (Generator[dep_type, None, None], False),  # type: ignore
            (AsyncGenerator[dep_type, None], True),  # type: ignore
            (AsyncContextManager[dep_type], True),  # type: ignore
        )
        type_def = None
        is_async = False
        for candidate, candidate_is_async in candidates:
            type_def = self.get_definition(candidate)
            if type_def is not None:
                # Async-ness follows the definition actually chosen. If both a
                # sync and an async manager are registered, the sync one wins
                # and must not be entered with enter_async_context.
                is_async = candidate_is_async
                break
        if type_def is None:
            raise InvalidDependencyDefinition(
                f"A ContextManager[{dep_type}] should be defined. "
                f"This could be an Iterator[{dep_type}] or Generator[{dep_type}, None, None] "
                f"with the @contextmanager decorator"
            )
        if isinstance(type_def, Alias):
            # Without this we create a definition that points to
            # itself.
            type_def = copy(type_def)
            type_def.skip_definitions = True
        if is_async:
            return AsyncConstructionWithContainer(lambda c: self._async_context_resolver(c, type_def))  # type: ignore
        return ConstructionWithContainer(lambda c: self._context_resolver(c, type_def))  # type: ignore

    def _context_resolver(self, c: ReadableContainer, type_def: SpecialDepDefinition):
        """
        Resolve a sync context-manager definition.

        Takes an existing definition that must be a context manager. Returns
        the value of the context manager from __enter__ and then places the
        __exit__ in this container's exit stack.
        """
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_context(context_manager)

    def _async_context_resolver(
        self, c: ReadableContainer, type_def: SpecialDepDefinition
    ):
        """
        Resolve an async context-manager definition.

        Takes an existing definition that must be an async context manager.
        Returns an awaitable for the value from __aenter__, and places the
        __aexit__ in this container's exit stack.
        """
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_async_context(context_manager)

    def _singleton_type_def(self, dep_type: Type):
        """
        The same as _context_type_def, but acts as a singleton within this container.

        Async definitions are wrapped in AwaitableSingleton (bound to this
        container); sync definitions are wrapped in SingletonWrapper.
        """
        type_def = self._context_type_def(dep_type)
        if isinstance(type_def, AsyncConstructionWithContainer):
            return AwaitableSingleton(type_def, self)
        return SingletonWrapper(type_def)
231