Passed
Branch async-context-container (98d6fa)
by Steve
03:16
created

AwaitableSingleton.get()   A

Complexity

Conditions 2

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 4
CRAP Score 2

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 1
dl 0
loc 4
rs 10
c 0
b 0
f 0
ccs 4
cts 4
cp 1
crap 2
1 1
import logging
2 1
from contextlib import AsyncExitStack
3 1
from copy import copy
4 1
from typing import (
5
    Optional,
6
    Type,
7
    TypeVar,
8
    Awaitable,
9
    Generic,
10
    Collection,
11
    Union,
12
    ContextManager,
13
    AsyncContextManager,
14
    Iterator,
15
    Generator,
16
    AsyncGenerator,
17
    List,
18
)
19
20 1
from lagom.container import Container
21 1
from lagom.definitions import Alias, ConstructionWithContainer, SingletonWrapper
22 1
from lagom.exceptions import InvalidDependencyDefinition
23 1
from lagom.experimental.definitions import AsyncConstructionWithContainer
24 1
from lagom.interfaces import ReadableContainer, SpecialDepDefinition
25
26 1
# Generic type parameter for the value produced by AwaitableSingleton.
T = TypeVar("T")
27
28
29 1
class AwaitableSingleton(Generic[T]):
    """Awaits an async construction once and caches the result.

    On the first call to :meth:`get`, ``constructor.get_instance(container)``
    is awaited and the result is stored. Later calls return the stored value
    until :meth:`reset` clears it.
    """

    instance: Optional[T]
    constructor: ConstructionWithContainer[Awaitable[T]]
    container: Container

    def __init__(self, constructor: ConstructionWithContainer, container: Container):
        """
        :param constructor: definition whose ``get_instance(container)`` returns
            an awaitable that produces the value.
        :param container: container passed to ``constructor.get_instance``.
        """
        self.instance = None
        self.constructor = constructor  # type: ignore
        self.container = container

    async def get(self) -> T:
        """Return the cached instance, awaiting the constructor on first use."""
        # Compare against None instead of relying on truthiness: a falsy
        # value (0, "", an empty collection, an object whose __bool__/__len__
        # is falsy) must still be cached rather than rebuilt on every call.
        if self.instance is None:
            self.instance = await self.constructor.get_instance(self.container)
        return self.instance

    def reset(self):
        """Drop the cached instance so the next :meth:`get` constructs again."""
        self.instance = None
46
47
48 1
class AsyncContextContainer(Container):
    """A Container whose listed types are resolved by entering context managers.

    Types in ``context_types`` are resolved by looking up a context-manager
    style definition (sync or async) and entering it. The matching exit is
    registered on an ``AsyncExitStack`` that is closed when the
    ``async with`` block ends. Types in ``context_singletons`` are resolved
    the same way, but the result is wrapped in a singleton that is reset in
    ``__aexit__``.
    """

    async_exit_stack: Optional[AsyncExitStack] = None
    _managed_singletons: List[Union[SingletonWrapper, AwaitableSingleton]]

    def __init__(
        self,
        container: Container,
        context_types: Collection[Type],
        context_singletons: Collection[Type] = tuple(),
        log_undefined_deps: Union[bool, logging.Logger] = False,
    ):
        """
        :param container: container passed to ``Container.__init__``; the
            context manager definitions are looked up via ``get_definition``.
        :param context_types: types resolved by entering their context manager
            on every resolution.
        :param context_singletons: types resolved by entering their context
            manager once, cached until ``__aexit__`` resets them.
        :param log_undefined_deps: forwarded to ``Container.__init__``.
        :raises InvalidDependencyDefinition: if a listed type has no
            context-manager style definition.
        """
        super().__init__(container, log_undefined_deps)
        self._managed_singletons = []
        for dep_type in set(context_types):
            managed_dep = self._context_type_def(dep_type)
            # Async-backed definitions can only be resolved by awaiting, so
            # they are registered under Awaitable[T] rather than T.
            key = Awaitable[dep_type] if isinstance(managed_dep, AsyncConstructionWithContainer) else dep_type  # type: ignore
            self[key] = managed_dep  # type: ignore
        for dep_type in set(context_singletons):
            managed_singleton = self._singleton_type_def(dep_type)
            # Tracked so __aexit__ can reset the cached value.
            self._managed_singletons.append(managed_singleton)
            key = AwaitableSingleton[dep_type] if isinstance(managed_singleton, AwaitableSingleton) else dep_type  # type: ignore
            self[key] = managed_singleton  # type: ignore

    async def __aenter__(self):
        """Open a resolution scope, creating an exit stack if none is active."""
        if not self.async_exit_stack:
            self.async_exit_stack = AsyncExitStack()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit every context entered in this scope and reset managed singletons."""
        # NOTE(review): aclose() exits the managed contexts as if no exception
        # occurred; exc_type/exc_val/exc_tb from the async-with body are not
        # forwarded to them.
        try:
            if self.async_exit_stack:
                await self.async_exit_stack.aclose()
        finally:
            # Clean up even if a managed context raised while exiting, so the
            # next scope does not reuse already-exited singleton values.
            self.async_exit_stack = None
            for managed_singleton in self._managed_singletons:
                managed_singleton.reset()

    def _context_type_def(self, dep_type: Type):
        """Build a definition that resolves ``dep_type`` by entering its context manager.

        Synchronous definitions (``ContextManager[T]``, ``Iterator[T]``,
        ``Generator[T, None, None]``) are looked up first. Asynchronous ones
        (``AsyncGenerator[T, None]``, ``AsyncContextManager[T]``) are looked
        up only if no synchronous one exists. The returned wrapper always
        matches the kind of definition chosen: ``ConstructionWithContainer``
        for sync, ``AsyncConstructionWithContainer`` (whose result must be
        awaited) for async.

        :raises InvalidDependencyDefinition: if none of the above is defined.
        """
        type_def = (
            self.get_definition(ContextManager[dep_type])  # type: ignore
            or self.get_definition(Iterator[dep_type])  # type: ignore
            or self.get_definition(Generator[dep_type, None, None])  # type: ignore
        )
        is_async = False
        if type_def is None:
            type_def = (
                self.get_definition(AsyncGenerator[dep_type, None])  # type: ignore
                or self.get_definition(AsyncContextManager[dep_type])  # type: ignore
            )
            # Decide the wrapper from the definition actually selected; a
            # separate "does any async def exist" check would wrap a sync
            # context manager as async when both kinds are defined.
            is_async = type_def is not None
        if type_def is None:
            raise InvalidDependencyDefinition(
                f"A ContextManager[{dep_type}] should be defined. "
                f"This could be an Iterator[{dep_type}] or Generator[{dep_type}, None, None] "
                f"with the @contextmanager decorator"
            )
        if isinstance(type_def, Alias):
            # Without this we create a definition that points to
            # itself.
            type_def = copy(type_def)
            type_def.skip_definitions = True
        if is_async:
            return AsyncConstructionWithContainer(lambda c: self._async_context_resolver(c, type_def))  # type: ignore
        return ConstructionWithContainer(lambda c: self._context_resolver(c, type_def))  # type: ignore

    def _context_resolver(self, c: ReadableContainer, type_def: SpecialDepDefinition):
        """
        Takes an existing definition which must be a context manager. Returns
        the value of the context manager from __enter__ and then places the
        __exit__ in this container's exit stack.
        """
        # NOTE(review): assert is stripped under `python -O`; without it a
        # resolution outside `async with` fails with AttributeError on None.
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_context(context_manager)

    def _async_context_resolver(
        self, c: ReadableContainer, type_def: SpecialDepDefinition
    ):
        """
        Takes an existing definition which must be an async context manager.
        Returns the awaitable from ``enter_async_context``, which yields the
        value of __aenter__ once awaited and places the __aexit__ in this
        container's exit stack.
        """
        assert self.async_exit_stack, "Types can only be resolved within an async with"
        context_manager = type_def.get_instance(c)
        return self.async_exit_stack.enter_async_context(context_manager)

    def _singleton_type_def(self, dep_type: Type):
        """
        The same as _context_type_def but acts as a singleton within this
        container: async definitions are wrapped in AwaitableSingleton, sync
        ones in SingletonWrapper.
        """
        type_def = self._context_type_def(dep_type)
        if isinstance(type_def, AsyncConstructionWithContainer):
            return AwaitableSingleton(type_def, self)
        return SingletonWrapper(type_def)
130