Passed
Pull Request — master (#226)
by Steve
02:53
created

_ContextBoundFunction.__call__()   A

Complexity

Conditions 2

Size

Total Lines 3
Code Lines 3

Duplication

Lines 3
Ratio 100 %

Code Coverage

Tests 3
CRAP Score 2

Importance

Changes 0
Metric Value
cc 2
eloc 3
nop 3
dl 3
loc 3
ccs 3
cts 3
cp 1
crap 2
rs 10
c 0
b 0
f 0
1 1
import logging
2 1
from contextlib import ExitStack
3 1
from copy import copy
4 1
from functools import wraps
5 1
from typing import (
6
    Collection,
7
    Union,
8
    Type,
9
    TypeVar,
10
    Optional,
11
    cast,
12
    ContextManager,
13
    Iterator,
14
    Generator,
15
    Callable,
16
    List,
17
    Dict,
18
)
19
20 1
from lagom import Container
21 1
from lagom.compilaton import mypyc_attr
22 1
from lagom.definitions import ConstructionWithContainer, SingletonWrapper, Alias
23 1
from lagom.exceptions import InvalidDependencyDefinition
24 1
from lagom.interfaces import (
25
    ReadableContainer,
26
    SpecialDepDefinition,
27
    CallTimeContainerUpdate,
28
    ContainerBoundFunction,
29
)
30
31 1
# Return type of the functions bound by the containers in this module
# (used as ``Callable[..., X]`` -> ``ContainerBoundFunction[X]``).
X = TypeVar("X")
32
33
34 1 View Code Duplication
class _ContextBoundFunction(ContainerBoundFunction[X]):
    """
    Represents an instance of a function bound to a context container

    Every call enters ``context_container`` with a ``with`` statement,
    rebinds ``partially_bound_function`` to the container returned by
    ``__enter__`` and invokes it, so each invocation runs inside its own
    context which is exited as soon as the call returns or raises.
    """

    __slots__ = ("context_container", "partially_bound_function")

    context_container: "ContextContainer"
    partially_bound_function: ContainerBoundFunction

    # NOTE(review): a class-level ``__dict__`` shadows any per-instance one,
    # so ``obj.__dict__`` -- including the ``wrapper.__dict__.update(...)``
    # performed by ``functools.wraps`` -- reads and writes this single dict
    # shared by every instance. Kept unchanged in case mypyc compilation
    # relies on it; confirm before removing.
    __dict__: Dict = dict()

    def __init__(
        self,
        context_container: "ContextContainer",
        partially_bound_function: ContainerBoundFunction,
    ):
        """
        :param context_container: entered on every call; the container it
            yields is the one the function is resolved against
        :param partially_bound_function: the bound function that is rebound
            to that per-call container and then invoked
        """
        self.context_container = context_container
        self.partially_bound_function = partially_bound_function

    def __call__(self, *args, **kwargs) -> X:
        # Enter the context container for the duration of this single call
        # and run the function against the container it yields.
        with self.context_container as c:
            return self.partially_bound_function.rebind(c)(*args, **kwargs)

    def rebind(self, container: ReadableContainer) -> "ContainerBoundFunction[X]":
        """
        Returns a new context bound function that uses the same context
        container but whose inner function is rebound to ``container``.
        Function metadata is copied from the inner function via ``wraps``.
        """
        return wraps(self.partially_bound_function)(
            _ContextBoundFunction(
                self.context_container, self.partially_bound_function.rebind(container)
            )
        )

    def __getattr__(self, item):
        """
        Fallback invoked only after normal attribute lookup has failed.

        Always raises ``AttributeError`` so that ``hasattr``, ``getattr`` with
        a default, ``copy.copy`` (which probes ``__setstate__``) and
        ``functools.wraps`` (which skips attributes raising AttributeError)
        keep working. Reaching this method for one of the slot names means
        the slot was never assigned; reading it again here would re-enter
        this method and recurse without bound, so it is reported instead.

        :raises AttributeError: always
        """
        if item not in self.__slots__:
            raise AttributeError(f"{item} doesn't exist")
        raise AttributeError(f"{item} has not been set")
72
73
74 1
@mypyc_attr(allow_interpreted_subclasses=True)
class ContextContainer(Container):
    """
    Wraps a regular container but is a ContextManager for use within a `with`.

    Entering the container builds an isolated clone in which every type in
    ``context_types`` is resolved by entering its ``ContextManager[T]`` (or
    ``Iterator[T]`` / ``Generator[T, None, None]``) definition, and every
    type in ``context_singletons`` is resolved the same way but acts as a
    singleton within that clone. Every manager entered this way is exited
    when the `with` block ends.

    >>> from tests.examples import SomeClass, SomeClassManager
    >>> from lagom import Container
    >>> from typing import ContextManager
    >>>
    >>> # The regular container
    >>> c = Container()
    >>>
    >>> # register a context manager for SomeClass
    >>> c[ContextManager[SomeClass]] = SomeClassManager
    >>>
    >>> context_c = ContextContainer(c, context_types=[SomeClass])
    >>> with context_c as c:
    ...     c[SomeClass]
    <tests.examples.SomeClass object at ...>
    """

    # Set while this container is entered; collects the __exit__ callbacks
    # that must run when the `with` block ends.
    exit_stack: Optional[ExitStack] = None
    # Types resolved from a context manager definition inside each `with`.
    _context_types: Collection[Type]
    # Types resolved the same way but wrapped as singletons inside each `with`.
    _context_singletons: Collection[Type]

    def __init__(
        self,
        container: Container,
        context_types: Collection[Type],
        context_singletons: Collection[Type] = tuple(),
        log_undefined_deps: Union[bool, logging.Logger] = False,
    ):
        """
        :param container: the container whose definitions are wrapped
        :param context_types: types to resolve from their context manager
            definition on every resolution inside a `with`
        :param context_singletons: types to resolve from their context manager
            definition once per `with`
        :param log_undefined_deps: passed straight through to ``Container``
        """
        self._context_types = context_types
        self._context_singletons = context_singletons
        super().__init__(container, log_undefined_deps)

    def clone(self) -> "ContextContainer":
        """returns a new ContextContainer that wraps this one, configured with
        the same context types, context singletons and undefined-dependency
        logging
        :return:
        """
        return ContextContainer(
            self,
            context_types=self._context_types,
            context_singletons=self._context_singletons,
            log_undefined_deps=self._undefined_logger,
        )

    def __enter__(self):
        """
        Returns a container in which the context types can be resolved.

        On first entry a fresh clone is built and returned; all context
        definitions live on that clone so separate `with` blocks are isolated.
        """
        if self.exit_stack is not None:
            # Already entered. This is the path the inner clone takes when the
            # outer container enters it below (its exit stack is preset).
            # NOTE(review): it is also taken if the outer container is entered
            # again before exiting (nested or concurrent use). That returns the
            # outer container, which has no context definitions registered, and
            # the matching __exit__ tears down the shared exit stack early --
            # confirm such use is not expected.
            return self

        # All actual context definitions happen on a clone so that there's isolation between invocations
        in_context = self.clone()
        for dep_type in set(self._context_types):
            in_context[dep_type] = self._context_type_def(dep_type)
        for dep_type in set(self._context_singletons):
            in_context[dep_type] = self._singleton_type_def(dep_type)
        in_context.exit_stack = ExitStack()

        # The parent context manager keeps track of the inner clone
        self.exit_stack = ExitStack()
        self.exit_stack.enter_context(in_context)
        return in_context

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Exits every context manager entered while this container was active.

        The exception raised inside the `with` block (if any) is forwarded to
        those managers, just as nested `with` statements would do, so e.g. a
        ``@contextmanager`` generator sees the error at its ``yield``.
        Returns True only if one of them suppressed that exception.
        """
        # Detach before unwinding so that an error raised by one of the
        # managers cannot leave this container stuck in the "entered" state.
        exit_stack, self.exit_stack = self.exit_stack, None
        if exit_stack is None:
            return False
        return exit_stack.__exit__(exc_type, exc_val, exc_tb)

    def partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> ContainerBoundFunction[X]:
        """
        Same as ``Container.partial`` but the returned function enters this
        container (a `with` block) on every call.
        """
        base_partial = super(ContextContainer, self).partial(
            func, shared, container_updater
        )

        return wraps(base_partial)(_ContextBoundFunction(self, base_partial))

    def magic_partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        keys_to_skip: Optional[List[str]] = None,
        skip_pos_up_to: int = 0,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> ContainerBoundFunction[X]:
        """
        Same as ``Container.magic_partial`` but the returned function enters
        this container (a `with` block) on every call.
        """
        base_partial = super(ContextContainer, self).magic_partial(
            func, shared, keys_to_skip, skip_pos_up_to, container_updater
        )

        return wraps(base_partial)(_ContextBoundFunction(self, base_partial))

    def _context_type_def(self, dep_type: Type):
        """
        Builds the definition used for ``dep_type`` inside a `with`.

        Looks up ``ContextManager[dep_type]``, then ``Iterator[dep_type]``,
        then ``Generator[dep_type, None, None]``, and returns a definition
        that enters the first one defined and yields its ``__enter__`` value.

        :raises InvalidDependencyDefinition: if none of them is defined
        """
        type_def = (
            self.get_definition(ContextManager[dep_type])  # type: ignore
            or self.get_definition(Iterator[dep_type])  # type: ignore
            or self.get_definition(Generator[dep_type, None, None])  # type: ignore
        )
        if type_def is None:
            raise InvalidDependencyDefinition(
                f"A ContextManager[{dep_type}] should be defined. "
                f"This could be an Iterator[{dep_type}] or Generator[{dep_type}, None, None] "
                f"with the @contextmanager decorator"
            )
        if isinstance(type_def, Alias):
            # Without this we create a definition that points to
            # itself.
            type_def = copy(type_def)
            type_def.skip_definitions = True
        # The resolver closes over ``self`` (the container building this
        # definition), so managers are entered on *its* exit stack.
        return ConstructionWithContainer(lambda c: self._context_resolver(c, type_def))  # type: ignore

    def _singleton_type_def(self, dep_type: Type):
        """
        The same as context_type_def but acts as a singleton within this container
        """
        return SingletonWrapper(self._context_type_def(dep_type))

    def _context_resolver(self, c: ReadableContainer, type_def: SpecialDepDefinition):
        """
        Takes an existing definition which must be a context manager. Returns
        the value of the context manager from __enter__ and then places the
        __exit__ in this container's exit stack

        :raises AssertionError: if this container is not currently inside a
            `with` (e.g. resolving on a clone after its block has ended)
        """
        if self.exit_stack is None:
            # Explicit check instead of `assert` so it still applies under `python -O`.
            raise AssertionError("Types can only be resolved within a with")
        context_manager = cast(ContextManager, type_def.get_instance(c))
        return self.exit_stack.enter_context(context_manager)
198