Passed
Push — master ( e565cb...3e847f )
by Steve
03:15
created

AsyncContextContainer.__aenter__()   C

Complexity

Conditions 9

Size

Total Lines 23
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 18
CRAP Score 9

Importance

Changes 0
Metric Value
cc 9
eloc 18
nop 1
dl 0
loc 23
ccs 18
cts 18
cp 1
crap 9
rs 6.6666
c 0
b 0
f 0
1 1
import logging
2 1
from asyncio import Lock
3 1
from contextlib import AsyncExitStack
4 1
from copy import copy
5 1
from typing import (
6
    Optional,
7
    Type,
8
    TypeVar,
9
    Awaitable,
10
    Generic,
11
    Collection,
12
    Union,
13
    ContextManager,
14
    AsyncContextManager,
15
    Iterator,
16
    Generator,
17
    AsyncGenerator,
18
)
19
20 1
from lagom.container import Container
21 1
from lagom.definitions import Alias, ConstructionWithContainer, SingletonWrapper
22 1
from lagom.exceptions import InvalidDependencyDefinition
23 1
from lagom.experimental.definitions import AsyncConstructionWithContainer
24 1
from lagom.interfaces import ReadableContainer, SpecialDepDefinition
25
26 1
T = TypeVar("T")


class AwaitableSingleton(Generic[T]):
    """Lazily awaits ``constructor`` once and caches the result.

    ``get()`` uses double-checked locking around an :class:`asyncio.Lock` so
    concurrent callers in the same event loop share a single construction.
    """

    instance: Optional[T]
    constructor: "ConstructionWithContainer[Awaitable[T]]"
    container: "Container"
    _lock: Lock

    def __init__(self, constructor: "ConstructionWithContainer", container: "Container"):
        """
        :param constructor: definition whose ``get_instance(container)`` returns an awaitable
        :param container: the container passed to ``constructor.get_instance``
        """
        self.instance = None
        self.constructor = constructor  # type: ignore
        self.container = container
        self._lock = Lock()

    async def get(self) -> T:
        """Return the cached instance, awaiting the constructor on first use.

        The cache check uses ``is None`` rather than truthiness so that falsy
        results (``0``, ``""``, empty collections, ...) are constructed only
        once. A constructor that resolves to ``None`` is awaited again on
        every call.
        """
        if self.instance is None:
            async with self._lock:
                # Re-check under the lock: another task may have finished the
                # construction while this one was waiting to acquire it.
                if self.instance is None:
                    self.instance = await self.constructor.get_instance(self.container)
        return self.instance  # type: ignore
47
48
49 1
class AsyncContextContainer(Container):
    """Container whose context types are managed by an ``async with`` block.

    Entering a root container creates a clone on which every type in
    ``context_types`` and ``context_singletons`` is defined so that resolving
    it enters the matching context manager. The clone is registered on the
    root's exit stack, and the clone's own exit stack collects the context
    managers entered while resolving, so leaving the ``async with`` block
    exits all of them.

    Registration keys on the clone:

    * sync context types -> ``dep_type``
    * async context types -> ``Awaitable[dep_type]``
    * sync context singletons -> ``dep_type``
    * async context singletons -> ``AwaitableSingleton[dep_type]``
    """

    async_exit_stack: Optional[AsyncExitStack] = None
    _context_types: Collection[Type]
    _context_singletons: Collection[Type]
    _root_context: bool = True

    def __init__(
        self,
        container: Container,
        context_types: Collection[Type],
        context_singletons: Collection[Type] = tuple(),
        log_undefined_deps: Union[bool, logging.Logger] = False,
    ):
        """
        :param container: the container this one is built on
        :param context_types: types whose resolution enters a context manager
        :param context_singletons: like ``context_types`` but wrapped as a
            singleton within one ``async with`` block
        :param log_undefined_deps: passed through to :class:`Container`
        """
        super().__init__(container, log_undefined_deps)
        # Stored as sets so each type is defined only once per context.
        self._context_types = set(context_types)
        self._context_singletons = set(context_singletons)

    def clone(self) -> "AsyncContextContainer":
        """Return a new AsyncContextContainer built on this one with the same
        context types, context singletons and undefined-dependency logger.

        :return: the new container
        """
        return AsyncContextContainer(
            self,
            context_types=self._context_types,
            context_singletons=self._context_singletons,
            log_undefined_deps=self._undefined_logger,
        )

    async def __aenter__(self):
        """Enter the container.

        On a root container this returns a fresh clone holding the context
        definitions. On a clone it returns the clone itself.
        """
        if not self._root_context:
            return self

        if not self.async_exit_stack:
            self.async_exit_stack = AsyncExitStack()

        # All actual context definitions happen on a clone so that there's
        # isolation between invocations.
        in_context = self.clone()
        in_context.async_exit_stack = AsyncExitStack()
        in_context._root_context = False

        # The definitions are built by the clone, not the root. This way the
        # resolver callbacks push onto the clone's exit stack, and async
        # singletons resolve against the clone, which has the context
        # definitions.
        for dep_type in self._context_types:
            managed_dep = in_context._context_type_def(dep_type)
            key = Awaitable[dep_type] if isinstance(managed_dep, AsyncConstructionWithContainer) else dep_type  # type: ignore
            in_context[key] = managed_dep  # type: ignore
        for dep_type in self._context_singletons:
            managed_singleton = in_context._singleton_type_def(dep_type)
            key = AwaitableSingleton[dep_type] if isinstance(managed_singleton, AwaitableSingleton) else dep_type  # type: ignore
            in_context[key] = managed_singleton  # type: ignore

        # The parent context manager keeps track of the inner clone
        await self.async_exit_stack.enter_async_context(in_context)
        return in_context

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit every context manager registered on this container's exit stack."""
        # NOTE(review): aclose() exits the registered managers as if no
        # exception occurred; exc_type/exc_val/exc_tb are not forwarded to them.
        if self.async_exit_stack:
            await self.async_exit_stack.aclose()
            self.async_exit_stack = None

    def _find_context_definition(self, dep_type: Type):
        """Find the first context-manager style definition for ``dep_type``.

        Sync forms are checked before async ones.

        :return: ``(definition, is_async)``, or ``(None, False)`` if nothing is defined
        """
        candidates = (
            (ContextManager[dep_type], False),  # type: ignore
            (Iterator[dep_type], False),  # type: ignore
            (Generator[dep_type, None, None], False),  # type: ignore
            (AsyncGenerator[dep_type, None], True),  # type: ignore
            (AsyncContextManager[dep_type], True),  # type: ignore
        )
        for key, is_async in candidates:
            definition = self.get_definition(key)
            if definition is not None:
                return definition, is_async
        return None, False

    def _context_type_def(self, dep_type: Type):
        """Build the definition that resolves ``dep_type`` by entering its context manager.

        :raises InvalidDependencyDefinition: if no context-manager style definition exists
        :return: an AsyncConstructionWithContainer when the selected definition is
            an async generator / async context manager, else a ConstructionWithContainer
        """
        type_def, is_async = self._find_context_definition(dep_type)
        if type_def is None:
            raise InvalidDependencyDefinition(
                f"A ContextManager[{dep_type}] should be defined. "
                f"This could be an Iterator[{dep_type}] or Generator[{dep_type}, None, None] "
                f"with the @contextmanager decorator"
            )
        if isinstance(type_def, Alias):
            # Without this we create a definition that points to
            # itself.
            type_def = copy(type_def)
            type_def.skip_definitions = True
        # Decide sync vs async from the definition that was actually selected.
        # This ensures a sync context manager is never handed to
        # enter_async_context when both forms are defined.
        if is_async:
            return AsyncConstructionWithContainer(lambda c: self._async_context_resolver(c, type_def))  # type: ignore
        return ConstructionWithContainer(lambda c: self._context_resolver(c, type_def))  # type: ignore

    def _context_resolver(self, c: ReadableContainer, type_def: SpecialDepDefinition):
        """
        Takes an existing definition which must be a context manager. Returns
        the value of the context manager from __enter__ and then places the
        __exit__ in this container's exit stack
        """
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_context(context_manager)

    def _async_context_resolver(
        self, c: ReadableContainer, type_def: SpecialDepDefinition
    ):
        """
        Takes an existing definition which must be an async context manager.
        Returns an awaitable of the value from __aenter__ and places the
        __aexit__ in this container's exit stack
        """
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_async_context(context_manager)

    def _singleton_type_def(self, dep_type: Type):
        """
        The same as _context_type_def but acts as a singleton within this container
        """
        type_def = self._context_type_def(dep_type)
        if isinstance(type_def, AsyncConstructionWithContainer):
            return AwaitableSingleton(type_def, self)
        return SingletonWrapper(type_def)
153