Passed
Push — master ( 7dd739...b601db )
by Steve
02:55
created

AsyncContextContainer.__init__()   A

Complexity

Conditions 5

Size

Total Lines 18
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 12
CRAP Score 5

Importance

Changes 0
Metric Value
cc 5
eloc 17
nop 5
dl 0
loc 18
rs 9.0833
c 0
b 0
f 0
ccs 12
cts 12
cp 1
crap 5
1 1
import logging
2 1
from asyncio import Lock
3 1
from contextlib import AsyncExitStack
4 1
from copy import copy
5 1
from typing import (
6
    Optional,
7
    Type,
8
    TypeVar,
9
    Awaitable,
10
    Generic,
11
    Collection,
12
    Union,
13
    ContextManager,
14
    AsyncContextManager,
15
    Iterator,
16
    Generator,
17
    AsyncGenerator,
18
    List,
19
)
20
21 1
from lagom.container import Container
22 1
from lagom.definitions import Alias, ConstructionWithContainer, SingletonWrapper
23 1
from lagom.exceptions import InvalidDependencyDefinition
24 1
from lagom.experimental.definitions import AsyncConstructionWithContainer
25 1
from lagom.interfaces import ReadableContainer, SpecialDepDefinition
26
27 1
# Type variable for the value produced by an AwaitableSingleton.
T = TypeVar("T")
28
29
30 1
class AwaitableSingleton(Generic[T]):
    """
    Lazily builds a value with an async constructor and caches it.

    The first call to ``get`` awaits ``constructor.get_instance(container)``
    and stores the result. Later calls return the stored value until
    ``reset`` is called. An ``asyncio.Lock`` plus a re-check inside the lock
    stops tasks that call ``get`` concurrently before the value exists from
    each running the constructor.
    """

    # Cached value; None means "not constructed yet".
    instance: Optional[T]
    # Definition whose get_instance(container) returns an awaitable of T.
    constructor: ConstructionWithContainer[Awaitable[T]]
    # Container passed to the constructor when building the value.
    container: Container
    # Serialises the first construction between concurrent tasks.
    _lock: Lock

    def __init__(self, constructor: ConstructionWithContainer, container: Container):
        """
        :param constructor: definition whose ``get_instance`` returns an
            awaitable producing the singleton value.
        :param container: container handed to ``constructor.get_instance``.
        """
        self.instance = None
        self.constructor = constructor  # type: ignore
        self.container = container
        self._lock = Lock()

    async def get(self) -> T:
        """
        Return the cached value, constructing it on first use.
        """
        # Compare against None rather than using truthiness so that falsy
        # values (e.g. 0 or an empty collection) are cached instead of being
        # rebuilt on every call.
        if self.instance is None:
            async with self._lock:
                # Another task may have finished construction while we were
                # waiting for the lock.
                if self.instance is None:
                    self.instance = await self.constructor.get_instance(self.container)
        return self.instance  # type: ignore

    def reset(self):
        """
        Discard the cached value so the next ``get`` constructs a new one.
        """
        self.instance = None
51
52
53 1
class AsyncContextContainer(Container):
    """
    A container that resolves selected types by entering context managers,
    and exits all of them when its own ``async with`` block ends.

    For every requested type ``T`` the wrapped container must define one of
    ``ContextManager[T]``, ``Iterator[T]``, ``Generator[T, None, None]``,
    ``AsyncGenerator[T, None]`` or ``AsyncContextManager[T]``.

    Registration key depends on the kind of definition found:

    - Synchronous definitions are registered under ``T``.
    - Asynchronous context types are registered under ``Awaitable[T]``.
    - Asynchronous context singletons are registered under
      ``AwaitableSingleton[T]``.

    Context-managed dependencies can only be resolved inside
    ``async with container:``.
    """

    async_exit_stack: Optional[AsyncExitStack] = None
    # Singleton wrappers whose cached values are cleared in __aexit__.
    _managed_singletons: List[Union[SingletonWrapper, AwaitableSingleton]]

    def __init__(
        self,
        container: Container,
        context_types: Collection[Type],
        context_singletons: Collection[Type] = tuple(),
        log_undefined_deps: Union[bool, logging.Logger] = False,
    ):
        """
        :param container: container holding the context manager definitions;
            forwarded to ``Container.__init__``.
        :param context_types: types resolved by entering their context manager.
        :param context_singletons: types whose resolved value is reused until
            this container's ``async with`` block exits.
        :param log_undefined_deps: forwarded to ``Container.__init__``.
        :raises InvalidDependencyDefinition: if a requested type has no
            context manager definition.
        """
        super().__init__(container, log_undefined_deps)
        self._managed_singletons = []
        for dep_type in set(context_types):
            managed_dep = self._context_type_def(dep_type)
            # Async definitions produce an awaitable, so they are registered
            # under Awaitable[T] rather than T.
            key = Awaitable[dep_type] if isinstance(managed_dep, AsyncConstructionWithContainer) else dep_type  # type: ignore
            self[key] = managed_dep  # type: ignore
        for dep_type in set(context_singletons):
            managed_singleton = self._singleton_type_def(dep_type)
            # Remembered so __aexit__ can clear the cached value.
            self._managed_singletons.append(managed_singleton)
            key = AwaitableSingleton[dep_type] if isinstance(managed_singleton, AwaitableSingleton) else dep_type  # type: ignore
            self[key] = managed_singleton  # type: ignore

    async def __aenter__(self):
        """
        Start an exit stack for context managers entered while resolving.

        If an exit stack is already active (nested entry), it is reused.
        """
        if not self.async_exit_stack:
            self.async_exit_stack = AsyncExitStack()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """
        Exit every context manager entered during this block and reset the
        managed singletons.

        The cleanup in ``finally`` runs even when one of the exit callbacks
        raises. This guarantees the next block starts with a fresh exit stack
        and fresh singleton values.
        """
        try:
            if self.async_exit_stack:
                await self.async_exit_stack.aclose()
        finally:
            self.async_exit_stack = None
            for managed_singleton in self._managed_singletons:
                managed_singleton.reset()

    def _context_type_def(self, dep_type: Type):
        """
        Build the definition used to resolve ``dep_type`` via a context manager.

        Lookup order:

        1. The synchronous definitions (``ContextManager``, ``Iterator``,
           ``Generator``) are tried first.
        2. Only if none exists are the asynchronous ones (``AsyncGenerator``,
           ``AsyncContextManager``) tried.

        The returned wrapper always matches the kind of definition selected.

        :returns: ``ConstructionWithContainer`` for a synchronous definition or
            ``AsyncConstructionWithContainer`` for an asynchronous one.
        :raises InvalidDependencyDefinition: if no suitable definition exists.
        """
        type_def = (
            self.get_definition(ContextManager[dep_type])  # type: ignore
            or self.get_definition(Iterator[dep_type])  # type: ignore
            or self.get_definition(Generator[dep_type, None, None])  # type: ignore
        )
        is_async = False
        if not type_def:
            type_def = self.get_definition(
                AsyncGenerator[dep_type, None]  # type: ignore
            ) or self.get_definition(
                AsyncContextManager[dep_type]  # type: ignore
            )
            # Derive sync/async from the definition actually selected, so a
            # synchronous context manager is never handed to
            # enter_async_context when both kinds are defined for the type.
            is_async = bool(type_def)
        if type_def is None:
            raise InvalidDependencyDefinition(
                f"A ContextManager[{dep_type}] should be defined. "
                f"This could be an Iterator[{dep_type}] or Generator[{dep_type}, None, None] "
                f"with the @contextmanager decorator"
            )
        if isinstance(type_def, Alias):
            # Without this we create a definition that points to
            # itself.
            type_def = copy(type_def)
            type_def.skip_definitions = True
        if is_async:
            return AsyncConstructionWithContainer(lambda c: self._async_context_resolver(c, type_def))  # type: ignore
        return ConstructionWithContainer(lambda c: self._context_resolver(c, type_def))  # type: ignore

    def _context_resolver(self, c: ReadableContainer, type_def: SpecialDepDefinition):
        """
        Takes an existing definition which must be a context manager. Returns
        the value of the context manager from __enter__ and then places the
        __exit__ in this container's exit stack.

        :raises AssertionError: if called outside ``async with``.
        """
        # Explicit raise (not ``assert``) so the guard survives ``python -O``.
        if not self.async_exit_stack:
            raise AssertionError("Types can only be resolved within an async with")
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_context(context_manager)

    def _async_context_resolver(
        self, c: ReadableContainer, type_def: SpecialDepDefinition
    ):
        """
        Takes an existing definition which must be an async context manager.
        Returns the awaitable from ``enter_async_context``, which yields the
        value of __aenter__ and places the __aexit__ in this container's
        exit stack.

        :raises AssertionError: if called outside ``async with``.
        """
        # Explicit raise (not ``assert``) so the guard survives ``python -O``.
        if not self.async_exit_stack:
            raise AssertionError("Types can only be resolved within an async with")
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_async_context(context_manager)

    def _singleton_type_def(self, dep_type: Type):
        """
        The same as ``_context_type_def`` but the resolved value is cached.

        The cache is cleared when this container's ``async with`` block exits.
        Asynchronous definitions are wrapped in ``AwaitableSingleton``;
        synchronous ones in ``SingletonWrapper``.
        """
        type_def = self._context_type_def(dep_type)
        if isinstance(type_def, AsyncConstructionWithContainer):
            return AwaitableSingleton(type_def, self)
        return SingletonWrapper(type_def)
135