Passed
Pull Request — master (#226)
by Steve
02:59
created

AwaitableSingleton.get()   A

Complexity

Conditions 4

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 6
CRAP Score 4

Importance

Changes 0
Metric Value
cc 4
eloc 6
nop 1
dl 0
loc 6
ccs 6
cts 6
cp 1
crap 4
rs 10
c 0
b 0
f 0
1 1
import inspect
2 1
import logging
3 1
from asyncio import Lock
4 1
from contextlib import AsyncExitStack
5 1
from copy import copy
6 1
from typing import (
7
    Optional,
8
    Type,
9
    TypeVar,
10
    Awaitable,
11
    Generic,
12
    Collection,
13
    Union,
14
    ContextManager,
15
    AsyncContextManager,
16
    Iterator,
17
    Generator,
18
    AsyncGenerator,
19
    Callable,
20
    List,
21
)
22
23 1
from lagom.container import Container
24 1
from lagom.definitions import Alias, ConstructionWithContainer, SingletonWrapper
25 1
from lagom.exceptions import InvalidDependencyDefinition, MissingFeature
26 1
from lagom.experimental.definitions import AsyncConstructionWithContainer
27 1
from lagom.interfaces import (
28
    ReadableContainer,
29
    SpecialDepDefinition,
30
    CallTimeContainerUpdate,
31
    ContainerBoundFunction,
32
)
33
34 1
# Module-level type variables: T parameterises AwaitableSingleton's cached
# value, X is the return type of functions bound by AsyncContextContainer.
T, X = TypeVar("T"), TypeVar("X")
36
37
38 1
class AwaitableSingleton(Generic[T]):
    """
    Lazily awaits an async construction once and caches the result.

    The first ``await get()`` awaits ``constructor.get_instance(container)``
    and stores the value in ``instance``. Later calls return the cached value.
    An ``asyncio.Lock`` serialises the first construction so tasks that await
    ``get()`` concurrently share a single construction.
    """

    instance: Optional[T]
    constructor: ConstructionWithContainer[Awaitable[T]]
    container: Container
    _lock: Lock

    def __init__(self, constructor: ConstructionWithContainer, container: Container):
        """
        :param constructor: definition whose ``get_instance(container)`` returns
            an awaitable producing the singleton value
        :param container: container passed to ``constructor.get_instance``
        """
        self.instance = None
        self.constructor = constructor  # type: ignore
        self.container = container
        self._lock = Lock()

    async def get(self) -> T:
        """
        Return the singleton, constructing it on the first call.

        :return: the cached instance
        """
        # Compare against None rather than relying on truthiness. Otherwise a
        # falsy instance (0, "", an empty collection, ...) would never be
        # cached and would be rebuilt on every call.
        if self.instance is None:
            async with self._lock:
                # Re-check under the lock: another task may have finished
                # constructing the instance while this one was waiting.
                if self.instance is None:
                    self.instance = await self.constructor.get_instance(self.container)
        return self.instance
56
57
58 1
class _AsyncContextBoundFunction(ContainerBoundFunction[X]):
    """
    A bound async function that runs each call inside ``async with`` of an
    AsyncContextContainer.

    Calling an instance does not run the function. It returns a coroutine.
    When awaited, that coroutine enters the context container, rebinds the
    wrapped partial to the container yielded by ``__aenter__`` and awaits the
    wrapped function's result.
    """

    async_context_container: "AsyncContextContainer"
    partially_bound_function: ContainerBoundFunction

    def __init__(
        self,
        async_context_container: "AsyncContextContainer",
        partially_bound_function: ContainerBoundFunction,
    ):
        self.partially_bound_function = partially_bound_function
        self.async_context_container = async_context_container

    def __call__(self, *args, **kwargs) -> X:
        # Intentionally a plain method: the caller receives the coroutine
        # produced by __async_call__ and awaits it.
        pending = self.__async_call__(*args, **kwargs)
        return pending

    async def __async_call__(self, *args, **kwargs):
        async with self.async_context_container as scoped_container:
            call_target = self.partially_bound_function.rebind(scoped_container)
            return await call_target(*args, **kwargs)

    def rebind(self, container: ReadableContainer) -> "ContainerBoundFunction[X]":
        """Return a copy whose wrapped partial is rebound to ``container``."""
        outer_container = self.async_context_container
        inner_rebound = self.partially_bound_function.rebind(container)
        return _AsyncContextBoundFunction(outer_container, inner_rebound)
86
87
88 1
class AsyncContextContainer(Container):
    """
    Container whose context types and context singletons can only be
    resolved inside an ``async with`` block.

    Entering the root container (``async with container as c``) returns a
    clone ``c``. Each type in ``context_types`` and ``context_singletons`` is
    registered on the clone with a definition that enters the user-supplied
    context manager when resolved (at most once per clone for context
    singletons). The matching exits are pushed onto the root container's
    ``AsyncExitStack`` and run when the root's ``async with`` block exits.

    Context managers are looked up as ``ContextManager[T]``, ``Iterator[T]``
    or ``Generator[T, None, None]`` (synchronous), or as
    ``AsyncGenerator[T, None]`` / ``AsyncContextManager[T]`` (asynchronous).
    Asynchronous context types are registered as ``Awaitable[T]``, and
    asynchronous context singletons as ``AwaitableSingleton[T]``.
    """

    async_exit_stack: Optional[AsyncExitStack] = None
    _context_types: Collection[Type]
    _context_singletons: Collection[Type]
    # True for the user-facing container; False for the per-invocation clones
    # built in __aenter__.
    _root_context: bool = True

    def __init__(
        self,
        container: Container,
        context_types: Collection[Type],
        context_singletons: Collection[Type] = tuple(),
        log_undefined_deps: Union[bool, logging.Logger] = False,
    ):
        """
        :param container: parent container, forwarded to ``Container.__init__``
        :param context_types: types resolved by entering their context manager
            every time they are requested inside ``async with``
        :param context_singletons: like ``context_types``, but entered at most
            once per ``async with`` block
        :param log_undefined_deps: forwarded to ``Container.__init__``
        """
        super().__init__(container, log_undefined_deps)
        self._context_types = set(context_types)
        self._context_singletons = set(context_singletons)

    def clone(self) -> "AsyncContextContainer":
        """
        Return a new AsyncContextContainer built from this one, with the same
        context types, context singletons and undefined-dependency logger.
        """
        return AsyncContextContainer(
            self,
            context_types=self._context_types,
            context_singletons=self._context_singletons,
            log_undefined_deps=self._undefined_logger,
        )

    async def __aenter__(self):
        """
        On the root container, build and enter an isolated clone holding the
        context definitions, then return that clone. A non-root clone returns
        itself.
        """
        # Clones are entered through the root's exit stack (see below) and
        # simply hand themselves back.
        if not self._root_context:
            return self
        if not self.async_exit_stack:
            self.async_exit_stack = AsyncExitStack()

        # All actual context definitions happen on a clone so that there's
        # isolation between invocations.
        in_context = self.clone()
        in_context.async_exit_stack = AsyncExitStack()
        in_context._root_context = False

        for dep_type in self._context_types:
            managed_dep = self._context_type_def(dep_type)
            # Async definitions are exposed as Awaitable[dep_type] so callers await them.
            key = Awaitable[dep_type] if isinstance(managed_dep, AsyncConstructionWithContainer) else dep_type  # type: ignore
            in_context[key] = managed_dep  # type: ignore
        for dep_type in self._context_singletons:
            managed_singleton = self._singleton_type_def(dep_type)
            key = AwaitableSingleton[dep_type] if isinstance(managed_singleton, AwaitableSingleton) else dep_type  # type: ignore
            in_context[key] = managed_singleton  # type: ignore

        # The parent context manager keeps track of the inner clone.
        # NOTE(review): every root __aenter__ pushes its clone onto the same
        # self.async_exit_stack, and __aexit__ closes that whole stack. It looks
        # like overlapping (concurrent) ``async with`` blocks on one root
        # container would tear down each other's clones. Confirm whether
        # concurrent use is meant to be supported.
        await self.async_exit_stack.enter_async_context(in_context)
        return in_context

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close this container's exit stack (LIFO) and forget it."""
        if self.async_exit_stack:
            await self.async_exit_stack.aclose()
            self.async_exit_stack = None

    @staticmethod
    def _ensure_async_function(func: Callable) -> None:
        """Raise MissingFeature unless ``func`` is a coroutine function."""
        if not inspect.iscoroutinefunction(func):
            raise MissingFeature(
                "AsyncContextManager currently can only deal with async functions"
            )

    def partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> ContainerBoundFunction[X]:
        """
        Like ``Container.partial``, but each call of the result runs inside
        ``async with self``.

        :raises MissingFeature: if ``func`` is not an ``async def`` function
        """
        self._ensure_async_function(func)
        base_partial = super().partial(func, shared, container_updater)
        return _AsyncContextBoundFunction(self, base_partial)

    def magic_partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        keys_to_skip: Optional[List[str]] = None,
        skip_pos_up_to: int = 0,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> ContainerBoundFunction[X]:
        """
        Like ``Container.magic_partial``, but each call of the result runs
        inside ``async with self``.

        :raises MissingFeature: if ``func`` is not an ``async def`` function
        """
        self._ensure_async_function(func)
        base_partial = super().magic_partial(
            func, shared, keys_to_skip, skip_pos_up_to, container_updater
        )
        return _AsyncContextBoundFunction(self, base_partial)

    def _context_type_def(self, dep_type: Type):
        """
        Build a definition that resolves ``dep_type`` by entering the context
        manager defined for it in this container.

        Synchronous definitions (``ContextManager``, ``Iterator``,
        ``Generator``) take precedence over asynchronous ones
        (``AsyncGenerator``, ``AsyncContextManager``). The returned
        construction always matches the kind of definition that was chosen,
        so a synchronous context manager is never passed to
        ``enter_async_context``, even when an async definition for the same
        type also exists.

        :raises InvalidDependencyDefinition: if no context manager definition exists
        """
        sync_def = (
            self.get_definition(ContextManager[dep_type])  # type: ignore
            or self.get_definition(Iterator[dep_type])  # type: ignore
            or self.get_definition(Generator[dep_type, None, None])  # type: ignore
        )
        # Consult the async definitions only when no sync one matched. This
        # keeps the sync-first precedence and records which kind was chosen,
        # so the wrapper returned below agrees with the definition it wraps.
        async_def = None
        if not sync_def:
            async_def = (
                self.get_definition(AsyncGenerator[dep_type, None])  # type: ignore
                or self.get_definition(AsyncContextManager[dep_type])  # type: ignore
            )
        type_def = sync_def or async_def
        if type_def is None:
            raise InvalidDependencyDefinition(
                f"A ContextManager[{dep_type}] should be defined. "
                f"This could be an Iterator[{dep_type}] or Generator[{dep_type}, None, None] "
                f"with the @contextmanager decorator"
            )
        if isinstance(type_def, Alias):
            # Without this we create a definition that points to
            # itself.
            type_def = copy(type_def)
            type_def.skip_definitions = True
        if async_def:
            return AsyncConstructionWithContainer(lambda c: self._async_context_resolver(c, type_def))  # type: ignore
        return ConstructionWithContainer(lambda c: self._context_resolver(c, type_def))  # type: ignore

    def _context_resolver(self, c: ReadableContainer, type_def: SpecialDepDefinition):
        """
        Build the context manager from ``type_def`` using ``c``, enter it on
        this container's exit stack and return the value from ``__enter__``.
        The matching ``__exit__`` runs when that exit stack is closed.
        """
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_context(context_manager)

    def _async_context_resolver(
        self, c: ReadableContainer, type_def: SpecialDepDefinition
    ):
        """
        Build the async context manager from ``type_def`` using ``c`` and
        return the awaitable from ``AsyncExitStack.enter_async_context``.
        Awaiting it yields the value from ``__aenter__``, and ``__aexit__`` is
        registered on this container's exit stack. (Presumably awaited by
        ``AsyncConstructionWithContainer``; confirm.)
        """
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_async_context(context_manager)

    def _singleton_type_def(self, dep_type: Type):
        """
        Same as ``_context_type_def``, but the result is wrapped so that it is
        resolved at most once: ``AwaitableSingleton`` for async definitions,
        ``SingletonWrapper`` otherwise.
        """
        type_def = self._context_type_def(dep_type)
        if isinstance(type_def, AsyncConstructionWithContainer):
            return AwaitableSingleton(type_def, self)
        return SingletonWrapper(type_def)
226