Passed
Push — master ( 3e847f...cc3d3c )
by Steve
03:10
created

AsyncContextContainer.partial()   A

Complexity

Conditions 3

Size

Total Lines 20
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 7
CRAP Score 3.0175

Importance

Changes 0
Metric Value
cc 3
eloc 14
nop 4
dl 0
loc 20
rs 9.7
c 0
b 0
f 0
ccs 7
cts 8
cp 0.875
crap 3.0175
1 1
import inspect
2 1
import logging
3 1
from asyncio import Lock
4 1
from contextlib import AsyncExitStack
5 1
from copy import copy
6 1
from typing import (
7
    Optional,
8
    Type,
9
    TypeVar,
10
    Awaitable,
11
    Generic,
12
    Collection,
13
    Union,
14
    ContextManager,
15
    AsyncContextManager,
16
    Iterator,
17
    Generator,
18
    AsyncGenerator,
19
    Callable,
20
    List,
21
)
22
23 1
from lagom.container import Container
24 1
from lagom.definitions import Alias, ConstructionWithContainer, SingletonWrapper
25 1
from lagom.exceptions import InvalidDependencyDefinition, MissingFeature
26 1
from lagom.experimental.definitions import AsyncConstructionWithContainer
27 1
from lagom.interfaces import (
28
    ReadableContainer,
29
    SpecialDepDefinition,
30
    CallTimeContainerUpdate,
31
)
32
33 1
# Type of the value an AwaitableSingleton resolves to (see AwaitableSingleton.get).
T = TypeVar("T")
34 1
# Return type of the functions wrapped by AsyncContextContainer.partial/magic_partial.
X = TypeVar("X")
35
36
37 1
class AwaitableSingleton(Generic[T]):
    """Resolve an async construction at most once and cache the result.

    Concurrent callers of :meth:`get` are serialised on an ``asyncio.Lock`` so
    that only one of them awaits the constructor; the rest reuse its result.
    """

    instance: Optional[T]
    constructor: ConstructionWithContainer[Awaitable[T]]
    container: Container
    _lock: Lock

    def __init__(self, constructor: ConstructionWithContainer, container: Container):
        """
        :param constructor: definition whose ``get_instance(container)`` returns an awaitable
        :param container: the container passed to ``constructor.get_instance``
        """
        self.instance = None
        self.constructor = constructor  # type: ignore
        self.container = container
        self._lock = Lock()

    async def get(self) -> T:
        """Return the cached value, constructing it on the first call.

        An ``is None`` check (rather than truthiness) is used so that a
        constructed value which happens to be falsy -- an empty collection,
        ``0``, an object whose ``__len__`` returns 0 -- is still cached instead
        of being rebuilt on every call.
        """
        # Fast path: once the value exists there is no need to take the lock.
        if self.instance is None:
            async with self._lock:
                # Re-check: another task may have built the value while we waited.
                if self.instance is None:
                    self.instance = await self.constructor.get_instance(self.container)
        return self.instance  # type: ignore
55
56
57 1
class AsyncContextContainer(Container):
    """A container whose context types live for the duration of an ``async with``.

    Each type in ``context_types`` is built from a context-manager definition:
    ``ContextManager[T]``, ``Iterator[T]`` or ``Generator[T, None, None]`` (sync)
    or ``AsyncGenerator[T, None]`` / ``AsyncContextManager[T]`` (async; these are
    registered under ``Awaitable[T]``). Types in ``context_singletons`` are built
    the same way but at most once per context.

    Entering a root container yields a clone that carries these definitions;
    the context managers opened while resolving them are exited when the
    ``async with`` block ends.
    """

    async_exit_stack: Optional[AsyncExitStack] = None
    _context_types: Collection[Type]
    _context_singletons: Collection[Type]
    # False on the per-invocation clones produced by ``_new_context_clone``.
    _root_context: bool = True

    def __init__(
        self,
        container: Container,
        context_types: Collection[Type],
        context_singletons: Collection[Type] = tuple(),
        log_undefined_deps: Union[bool, logging.Logger] = False,
    ):
        """
        :param container: container whose definitions are copied
        :param context_types: types resolved from a context-manager definition per context
        :param context_singletons: like ``context_types`` but built at most once per context
        :param log_undefined_deps: forwarded to :class:`Container`
        """
        super().__init__(container, log_undefined_deps)
        self._context_types = set(context_types)
        self._context_singletons = set(context_singletons)

    def clone(self) -> "AsyncContextContainer":
        """Return a copy of this container with the same context configuration.

        :return: a new AsyncContextContainer built from this one
        """
        return AsyncContextContainer(
            self,
            context_types=self._context_types,
            context_singletons=self._context_singletons,
            log_undefined_deps=self._undefined_logger,
        )

    async def __aenter__(self):
        """Enter a context.

        On a root container this creates (if needed) the root's exit stack,
        builds an isolated clone with the context definitions, registers the
        clone on the root's stack and returns the clone. Entering a clone
        just returns the clone itself.
        """
        if not self._root_context:
            # Per-invocation clones already own their exit stack.
            return self

        if not self.async_exit_stack:
            self.async_exit_stack = AsyncExitStack()

        # All actual context definitions happen on a clone so that there's isolation between invocations
        in_context = self._new_context_clone()
        # The parent context manager keeps track of the inner clone: exiting the
        # root exits the clone, which in turn exits what the clone opened.
        await self.async_exit_stack.enter_async_context(in_context)
        return in_context

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close everything registered on this container's exit stack.

        NOTE: ``aclose()`` is used, so the managed context managers are not
        given the exception details of the ``async with`` body, and this method
        never suppresses an exception.
        """
        if self.async_exit_stack:
            try:
                await self.async_exit_stack.aclose()
            finally:
                # Reset even if closing failed so a stale stack is never reused.
                self.async_exit_stack = None

    def partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> Callable[..., X]:
        """Wrap an async ``func`` so every call runs inside its own context.

        :param func: coroutine function to wrap
        :param shared: forwarded to :meth:`Container.partial`
        :param container_updater: forwarded to :meth:`Container.partial`
        :return: an async callable that opens a context, calls ``func`` and closes it
        :raises MissingFeature: if ``func`` is not a coroutine function
        """
        if not inspect.iscoroutinefunction(func):
            raise MissingFeature(
                "AsyncContextManager currently can only deal with async functions"
            )

        async def _with_context(*args, **kwargs):
            async with self._invocation_context() as c:
                # TODO: Try and move this partial outside the function as this is expensive
                base_partial = super(AsyncContextContainer, c).partial(
                    func, shared, container_updater
                )
                return await base_partial(*args, **kwargs)  # type: ignore

        return _with_context

    def magic_partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        keys_to_skip: Optional[List[str]] = None,
        skip_pos_up_to: int = 0,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> Callable[..., X]:
        """Like :meth:`partial` but delegating to :meth:`Container.magic_partial`.

        :param func: coroutine function to wrap
        :param shared: forwarded to :meth:`Container.magic_partial`
        :param keys_to_skip: forwarded to :meth:`Container.magic_partial`
        :param skip_pos_up_to: forwarded to :meth:`Container.magic_partial`
        :param container_updater: forwarded to :meth:`Container.magic_partial`
        :return: an async callable that opens a context, calls ``func`` and closes it
        :raises MissingFeature: if ``func`` is not a coroutine function
        """
        if not inspect.iscoroutinefunction(func):
            raise MissingFeature(
                "AsyncContextManager currently can only deal with async functions"
            )

        async def _with_context(*args, **kwargs):
            async with self._invocation_context() as c:
                # TODO: Try and move this partial outside the function as this is expensive
                base_partial = super(AsyncContextContainer, c).magic_partial(
                    func, shared, keys_to_skip, skip_pos_up_to, container_updater
                )
                return await base_partial(*args, **kwargs)  # type: ignore

        return _with_context

    def _invocation_context(self) -> "AsyncContextContainer":
        """Return the container to enter for a single wrapped-function call.

        A root container hands out a fresh clone with its own exit stack, so
        concurrent or nested calls never share -- and close -- each other's
        stack. A clone is entered directly, as before.
        """
        if self._root_context:
            return self._new_context_clone()
        return self

    def _new_context_clone(self) -> "AsyncContextContainer":
        """Build a non-root clone with every context type defined on it."""
        in_context = self.clone()
        in_context.async_exit_stack = AsyncExitStack()
        in_context._root_context = False

        # The clone builds its own definitions so the resolvers they close over
        # use the clone's exit stack rather than this container's.
        for dep_type in self._context_types:
            managed_dep = in_context._context_type_def(dep_type)
            if isinstance(managed_dep, AsyncConstructionWithContainer):
                key = Awaitable[dep_type]  # type: ignore
            else:
                key = dep_type
            in_context[key] = managed_dep  # type: ignore

        for dep_type in self._context_singletons:
            managed_singleton = in_context._singleton_type_def(dep_type)
            if isinstance(managed_singleton, AwaitableSingleton):
                key = AwaitableSingleton[dep_type]  # type: ignore
            else:
                key = dep_type
            in_context[key] = managed_singleton  # type: ignore

        return in_context

    def _find_context_manager_def(self, dep_type: Type):
        """Look up the context-manager definition for ``dep_type``.

        Candidate keys are tried in order, sync ones first.

        :return: ``(definition, is_async)``, or ``(None, False)`` if none is defined
        """
        candidates = (
            (ContextManager[dep_type], False),  # type: ignore
            (Iterator[dep_type], False),  # type: ignore
            (Generator[dep_type, None, None], False),  # type: ignore
            (AsyncGenerator[dep_type, None], True),  # type: ignore
            (AsyncContextManager[dep_type], True),  # type: ignore
        )
        for key, is_async in candidates:
            definition = self.get_definition(key)
            if definition is not None:
                return definition, is_async
        return None, False

    def _context_type_def(self, dep_type: Type):
        """Build a definition that resolves ``dep_type`` by entering its context manager.

        The returned definition is async (an AsyncConstructionWithContainer)
        exactly when the matched definition was an async one, so a sync
        context manager is never handed to the async resolver.

        :raises InvalidDependencyDefinition: if no context-manager definition exists
        """
        type_def, is_async = self._find_context_manager_def(dep_type)
        if type_def is None:
            raise InvalidDependencyDefinition(
                f"A ContextManager[{dep_type}] should be defined. "
                f"This could be an Iterator[{dep_type}] or Generator[{dep_type}, None, None] "
                f"with the @contextmanager decorator"
            )
        if isinstance(type_def, Alias):
            # Without this we create a definition that points to
            # itself.
            type_def = copy(type_def)
            type_def.skip_definitions = True
        if is_async:
            return AsyncConstructionWithContainer(lambda c: self._async_context_resolver(c, type_def))  # type: ignore
        return ConstructionWithContainer(lambda c: self._context_resolver(c, type_def))  # type: ignore

    def _context_resolver(self, c: ReadableContainer, type_def: SpecialDepDefinition):
        """
        Takes an existing definition which must be a context manager. Returns
        the value of the context manager from __enter__ and then places the
        __exit__ in this container's exit stack
        """
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_context(context_manager)

    def _async_context_resolver(
        self, c: ReadableContainer, type_def: SpecialDepDefinition
    ):
        """
        Takes an existing definition which must be an async context manager.
        Returns an awaitable that yields the value from __aenter__ and places
        the __aexit__ in this container's exit stack
        """
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_async_context(context_manager)

    def _singleton_type_def(self, dep_type: Type):
        """
        The same as _context_type_def but acts as a singleton within this container:
        async definitions are wrapped in an AwaitableSingleton, sync ones in a
        SingletonWrapper.
        """
        type_def = self._context_type_def(dep_type)
        if isinstance(type_def, AsyncConstructionWithContainer):
            return AwaitableSingleton(type_def, self)
        return SingletonWrapper(type_def)
205