Passed
Pull Request — master (#213)
by Steve
02:52
created

EmptyDefinitionSet.get_definition()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 2
CRAP Score 1

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 7
ccs 2
cts 2
cp 1
crap 1
rs 10
c 0
b 0
f 0
1 1
import functools
2 1
import io
3 1
import logging
4 1
import typing
5
6 1
from types import FunctionType, MethodType
7 1
from typing import (
8
    Dict,
9
    Type,
10
    Any,
11
    TypeVar,
12
    Callable,
13
    Set,
14
    List,
15
    Optional,
16
    cast,
17
    Union,
18
)
19
20 1
from .definitions import (
21
    normalise,
22
    Singleton,
23
    Alias,
24
    ConstructionWithoutContainer,
25
    UnresolvableTypeDefinition,
26
)
27 1
from .exceptions import (
28
    UnresolvableType,
29
    DuplicateDefinition,
30
    InvalidDependencyDefinition,
31
    RecursiveDefinitionError,
32
    DependencyNotDefined,
33
    TypeOnlyAvailableAsAwaitable,
34
)
35 1
from .interfaces import (
36
    SpecialDepDefinition,
37
    WriteableContainer,
38
    TypeResolver,
39
    DefinitionsSource,
40
    ExtendableContainer,
41
    ContainerDebugInfo,
42
    CallTimeContainerUpdate,
43
)
44 1
from .markers import injectable
45 1
from .updaters import update_container_singletons
46 1
from .util.logging import NullLogger
47 1
from .util.reflection import (
48
    FunctionSpec,
49
    CachingReflector,
50
    remove_optional_type,
51
    remove_awaitable_type,
52
)
53 1
from .wrapping import apply_argument_updater
54
55 1
# Types that can never be defined on, nor built by, a container:
# Container.define rejects them with InvalidDependencyDefinition and the
# reflection fallback refuses them with UnresolvableType.
UNRESOLVABLE_TYPES = [
    # builtin scalar and byte-sequence types
    str, int, float, bool, bytes, bytearray,
    # concrete and abstract stream classes from the io module
    io.BytesIO, io.BufferedIOBase, io.BufferedRandom, io.BufferedReader,
    io.BufferedRWPair, io.BufferedWriter, io.FileIO, io.IOBase,
    io.RawIOBase, io.TextIOBase,
    # typing aliases for streams
    typing.IO, typing.TextIO, typing.BinaryIO,
]

# Generic placeholder for "the type being requested / constructed".
X = TypeVar("X")

# Private sentinel: lets a dict lookup distinguish "no entry" from any
# stored value (see Container.get_definition).
Unset: Any = object()
80
81
82 1
class Container(
    WriteableContainer, ExtendableContainer, DefinitionsSource, ContainerDebugInfo
):
    """Dependency injection container

    Lagom is a dependency injection container designed to give you "just enough"
    help with building your dependencies. The intention is that almost
    all of your code doesn't know about or rely on lagom. Lagom will
    only be involved at the top level to pull everything together.

    >>> from tests.examples import SomeClass
    >>> c = Container()
    >>> c[SomeClass]
    <tests.examples.SomeClass object at ...>

    Objects are constructed as they are needed

    >>> from tests.examples import SomeClass
    >>> c = Container()
    >>> first = c[SomeClass]
    >>> second = c[SomeClass]
    >>> first != second
    True

    And construction logic can be defined
    >>> from tests.examples import SomeClass, SomeExtendedClass
    >>> c = Container()
    >>> c[SomeClass] = SomeExtendedClass
    >>> c[SomeClass]
    <tests.examples.SomeExtendedClass object at ...>
    """

    # Definitions registered directly on this container (not its parent).
    _registered_types: Dict[Type, SpecialDepDefinition]
    # Fallback source consulted when this container has no local definition.
    _parent_definitions: DefinitionsSource
    # Reflection cache; shared with the parent container when there is one.
    _reflector: CachingReflector
    # Receives a warning each time a type is built via reflection.
    _undefined_logger: logging.Logger

    def __init_subclass__(cls, **kwargs):
        # Only lagom's own subclasses (which set ``_lagom_class``) are allowed.
        super().__init_subclass__(**kwargs)
        if not hasattr(cls, "_lagom_class"):
            raise TypeError(
                f"Container can not be subclassed - on some platforms this is a compiled class "
                f"and it would be impossible: {type(cls)} with {dir(cls)}"
            )

    def __init__(
        self,
        container: Optional["Container"] = None,
        log_undefined_deps: Union[bool, logging.Logger] = False,
    ):
        """
        :param container: Optional parent container. Its definitions are not
            copied: any type without a local definition is looked up in the
            parent at resolution time, and the parent's reflection cache is
            shared.
        :param log_undefined_deps: indicates if a log message should be emitted
            when an undefined dep is built via reflection. False disables
            logging, True logs to this module's logger, and a logging.Logger
            instance is used as-is.
        """

        # ContainerDebugInfo is always registered
        # This means consumers can consume an overview of the container
        # without hacking anything custom together.
        self._registered_types = {
            ContainerDebugInfo: ConstructionWithoutContainer(lambda: self)
        }

        if container is not None:
            self._parent_definitions = container
            self._reflector = container._reflector
        else:
            self._parent_definitions = EmptyDefinitionSet()
            self._reflector = CachingReflector()

        if not log_undefined_deps:
            self._undefined_logger = NullLogger()
        elif log_undefined_deps is True:
            self._undefined_logger = logging.getLogger(__name__)
        else:
            self._undefined_logger = cast(logging.Logger, log_undefined_deps)

    def define(self, dep: Type[X], resolver: TypeResolver[X]) -> SpecialDepDefinition:
        """Register how to construct an object of type X

        >>> from tests.examples import SomeClass
        >>> c = Container()
        >>> c.define(SomeClass, lambda: SomeClass())
        <lagom.definitions.ConstructionWithoutContainer ...>

        :param dep: The type to be constructed
        :param resolver: A definition of how to construct it
        :return: the normalised definition that was registered
        :raises InvalidDependencyDefinition: if dep is one of UNRESOLVABLE_TYPES
        :raises DuplicateDefinition: if this container already defines dep
        """
        if dep in UNRESOLVABLE_TYPES:
            raise InvalidDependencyDefinition()
        if dep in self._registered_types:
            raise DuplicateDefinition()
        if dep is resolver:
            # This is a special case for things like container[Foo] = Foo
            return self.define(dep, Alias(dep, skip_definitions=True))
        definition = normalise(resolver)
        self._registered_types[dep] = definition
        # Requests for Optional[dep] are served by the same definition.
        self._registered_types[Optional[dep]] = definition  # type: ignore

        # For awaitables we add a convenience exception to be thrown if code hints on the type
        # without the awaitable.
        awaitable_type = remove_awaitable_type(dep)
        if awaitable_type:
            # Unless there's already a sync version defined.
            if awaitable_type not in self.defined_types:
                self._registered_types[awaitable_type] = UnresolvableTypeDefinition(
                    TypeOnlyAvailableAsAwaitable(awaitable_type)
                )
        return definition

    @property
    def defined_types(self) -> Set[Type]:
        """The types the container has explicit build instructions for

        :return: a new set combining the parent's defined types with the
            types registered on this container
        """
        return self._parent_definitions.defined_types.union(
            self._registered_types.keys()
        )

    @property
    def reflection_cache_overview(self) -> Dict[str, str]:
        """Overview of the reflector cache (shared with parent/child containers)."""
        return self._reflector.overview_of_cache

    def temporary_singletons(
        self, singletons: Optional[List[Type]] = None
    ) -> "_TemporaryInjectionContext":
        """
        Returns a context that loads a new container with singletons that only exist
        for the context.

        >>> from tests.examples import SomeClass
        >>> base_container = Container()
        >>> def my_func():
        ...     with base_container.temporary_singletons([SomeClass]) as c:
        ...         assert c[SomeClass] is c[SomeClass]
        >>> my_func()

        :param singletons: items which should be considered singletons within the context
        :return:
        """
        # With no singletons requested the context just clones this container.
        updater = (
            functools.partial(update_container_singletons, singletons=singletons)
            if singletons
            else None
        )
        return _TemporaryInjectionContext(self, updater)

    def resolve(
        self, dep_type: Type[X], suppress_error=False, skip_definitions=False
    ) -> X:
        """Constructs an object of type X

        If the object can't be constructed an exception will be raised unless
        suppress_error is true

        >>> from tests.examples import SomeClass
        >>> c = Container()
        >>> c.resolve(SomeClass)
        <tests.examples.SomeClass object at ...>

        >>> from tests.examples import SomeClass
        >>> c = Container()
        >>> c.resolve(int)
        Traceback (most recent call last):
        ...
        lagom.exceptions.UnresolvableType: ...

        Optional wrappers are stripped out to be what is being asked for
        >>> from tests.examples import SomeClass
        >>> c = Container()
        >>> c.resolve(Optional[SomeClass])
        <tests.examples.SomeClass object at ...>

        :param dep_type: The type of object to construct
        :param suppress_error: if true returns None on failure
        :param skip_definitions: if true any registered definition for
            dep_type is ignored and the object is built via reflection
        :return:
        """
        if not skip_definitions:
            definition = self.get_definition(dep_type)
            if definition:
                return definition.get_instance(self)

        optional_dep_type = remove_optional_type(dep_type)
        if optional_dep_type:
            # Optional[T]: try T, and fall back to None if T can't be built.
            return self.resolve(optional_dep_type, suppress_error=True)

        return self._reflection_build_with_err_handling(dep_type, suppress_error)

    def partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> Callable[..., X]:
        """Takes a callable and returns a callable bound to the container
        When invoking the new callable if any arguments have a default set
        to the special marker object "injectable" then they will be constructed by
        the container. For automatic injection without the marker use "magic_partial"
        >>> from tests.examples import SomeClass
        >>> c = Container()
        >>> def my_func(something: SomeClass = injectable):
        ...     return f"Successfully called with {something}"
        >>> bound_func = c.partial(my_func)
        >>> bound_func()
        'Successfully called with <tests.examples.SomeClass object at ...>'

        :param func: the function to bind to the container
        :param shared: items which should be considered singletons on a per call level
        :param container_updater: An optional callable to update the container before resolution
        :return:
        """
        spec = self._get_spec_without_self(func)
        keys_to_bind = (
            key for (key, arg) in spec.defaults.items() if arg is injectable
        )
        keys_and_types = [(key, spec.annotations[key]) for key in keys_to_bind]

        _injection_context = self.temporary_singletons(shared)
        update_container = container_updater if container_updater else _update_nothing

        def _update_args(supplied_args, supplied_kwargs):
            # Anything the caller supplied (by name or position) is not injected.
            keys_to_skip = set(supplied_kwargs.keys())
            keys_to_skip.update(spec.args[0 : len(supplied_args)])
            with _injection_context as invocation_container:
                update_container(invocation_container, supplied_args, supplied_kwargs)
                kwargs = {
                    key: invocation_container.resolve(dep_type)
                    for (key, dep_type) in keys_and_types
                    if key not in keys_to_skip
                }
            kwargs.update(supplied_kwargs)
            return supplied_args, kwargs

        return apply_argument_updater(func, _update_args, spec)

    def magic_partial(
        self,
        func: Callable[..., X],
        shared: Optional[List[Type]] = None,
        keys_to_skip: Optional[List[str]] = None,
        skip_pos_up_to: int = 0,
        container_updater: Optional[CallTimeContainerUpdate] = None,
    ) -> Callable[..., X]:
        """Takes a callable and returns a callable bound to the container
        When invoking the new callable if any arguments can be constructed by the container
        then they can be omitted.
        >>> from tests.examples import SomeClass
        >>> c = Container()
        >>> def my_func(something: SomeClass):
        ...   return f"Successfully called with {something}"
        >>> bound_func = c.magic_partial(my_func)
        >>> bound_func()
        'Successfully called with <tests.examples.SomeClass object at ...>'

        :param func: the function to bind to the container
        :param shared: items which should be considered singletons on a per call level
        :param keys_to_skip: named arguments which the container shouldn't build
        :param skip_pos_up_to: positional arguments which the container shouldn't build
        :param container_updater: An optional callable to update the container before resolution
        :return:
        """
        spec = self._get_spec_without_self(func)

        update_container = container_updater if container_updater else _update_nothing
        _injection_context = self.temporary_singletons(shared)

        def _update_args(supplied_args, supplied_kwargs):
            final_keys_to_skip = (keys_to_skip or []) + list(supplied_kwargs.keys())
            final_skip_pos_up_to = max(skip_pos_up_to, len(supplied_args))
            with _injection_context as invocation_container:
                update_container(invocation_container, supplied_args, supplied_kwargs)
                # suppress_error: unresolvable args are left for the caller.
                kwargs = invocation_container._infer_dependencies(
                    spec,
                    suppress_error=True,
                    keys_to_skip=final_keys_to_skip,
                    skip_pos_up_to=final_skip_pos_up_to,
                )
            kwargs.update(supplied_kwargs)
            return supplied_args, kwargs

        return apply_argument_updater(func, _update_args, spec, catch_errors=True)

    def clone(self) -> "Container":
        """returns a new child container of this one

        The child sees every definition of this container (including ones
        added later) via parent lookup, shares its reflection cache and
        undefined-dependency logger, but definitions added to the child do
        not affect this container.
        :return:
        """
        return Container(self, log_undefined_deps=self._undefined_logger)

    def get_definition(self, dep_type: Type[X]) -> Optional[SpecialDepDefinition[X]]:
        """
        Will return the definition in this container. If none has been defined any
        definition in the parent container will be used.

        :param dep_type:
        :return: the definition, or whatever the parent returns (None when no
            container in the chain defines dep_type)
        """
        # Sentinel rather than None so a local miss is unambiguous.
        definition = self._registered_types.get(dep_type, Unset)
        if definition is Unset:
            return self._parent_definitions.get_definition(dep_type)
        return definition

    def __getitem__(self, dep: Type[X]) -> X:
        return self.resolve(dep)

    def __setitem__(self, dep: Type[X], resolver: TypeResolver[X]):
        self.define(dep, resolver)

    def _reflection_build_with_err_handling(
        self, dep_type: Type[X], suppress_error: bool
    ) -> X:
        """Build dep_type via reflection, translating failures.

        Returns None instead of raising UnresolvableType when suppress_error
        is true. A RecursionError is always re-raised as
        RecursiveDefinitionError.
        """
        try:
            if dep_type in UNRESOLVABLE_TYPES:
                raise UnresolvableType(dep_type)
            return self._reflection_build(dep_type)
        except UnresolvableType as inner_error:
            if not suppress_error:
                raise UnresolvableType(dep_type) from inner_error
            return None  # type: ignore
        except RecursionError as recursion_error:
            raise RecursiveDefinitionError(dep_type) from recursion_error

    def _reflection_build(self, dep_type: Type[X]) -> X:
        """Construct dep_type by resolving its __init__ parameters from this container."""
        self._undefined_logger.warning(
            f"Undefined dependency. Using reflection for {dep_type}",
            extra={"undefined_dependency": dep_type},
        )
        spec = self._reflector.get_function_spec(dep_type.__init__)
        # Skip parameters typed as dep_type itself; resolving them would
        # re-enter this build.
        sub_deps = self._infer_dependencies(spec, types_to_skip={dep_type})
        try:
            return dep_type(**sub_deps)  # type: ignore
        except TypeError as type_error:
            raise UnresolvableType(dep_type) from type_error

    def _infer_dependencies(
        self,
        spec: FunctionSpec,
        suppress_error=False,
        keys_to_skip: Optional[List[str]] = None,
        skip_pos_up_to=0,
        types_to_skip: Optional[Set[Type]] = None,
    ):
        """Resolve keyword arguments for the annotated parameters in spec.

        Parameters annotated as Any, named in keys_to_skip, among the first
        skip_pos_up_to positional args, or whose type is in types_to_skip are
        not resolved. Parameters that resolve to None are omitted from the
        returned dict.
        """
        dep_keys_to_skip = set(spec.args[0:skip_pos_up_to])
        dep_keys_to_skip.update(keys_to_skip or [])
        types_to_skip = types_to_skip or set()
        sub_deps = {
            key: self.resolve(sub_dep_type, suppress_error=suppress_error)
            for (key, sub_dep_type) in spec.annotations.items()
            if sub_dep_type != Any
            and (key not in dep_keys_to_skip)
            and (sub_dep_type not in types_to_skip)
        }
        return {key: dep for (key, dep) in sub_deps.items() if dep is not None}

    def _get_spec_without_self(self, func: Callable[..., X]) -> FunctionSpec:
        """Spec for func; non-function callables are treated as classes and
        described by their __init__ with the "self" argument removed."""
        if isinstance(func, (FunctionType, MethodType)):
            return self._reflector.get_function_spec(func)
        t = cast(Type[X], func)
        return self._reflector.get_function_spec(t.__init__).without_argument("self")
443
444
445 1
class ExplicitContainer(Container):
    """Container that only constructs types it has an explicit definition for.

    It never falls back to reflection. Resolving a type with no definition
    (here or in a parent) raises DependencyNotDefined, or returns None when
    suppress_error is true. Alias definitions, including an Alias wrapped in a
    Singleton, are rejected when defined.
    """

    # Marks this as a lagom-provided subclass so Container.__init_subclass__
    # permits it.
    _lagom_class: typing.ClassVar[bool] = True

    def resolve(
        self, dep_type: Type[X], suppress_error=False, skip_definitions=False
    ) -> X:
        """Constructs an object of type X using only explicit definitions

        :param dep_type: The type of object to construct
        :param suppress_error: if true returns None instead of raising when
            no definition exists
        :param skip_definitions: accepted for interface compatibility with
            Container.resolve; ignored here
        :return:
        :raises DependencyNotDefined: if no definition exists and
            suppress_error is false
        """
        definition = self.get_definition(dep_type)
        if not definition:
            if suppress_error:
                return None  # type: ignore
            raise DependencyNotDefined(dep_type)
        return definition.get_instance(self)

    def define(self, dep, resolver):
        """Register how to construct dep, rejecting Alias definitions.

        If the definition is rejected, the container's registrations are
        restored to what they were before the call. A corrected definition
        for the same type can therefore still be registered afterwards,
        rather than failing with DuplicateDefinition.

        :raises InvalidDependencyDefinition: if the resolver normalises to an
            Alias, or to a Singleton wrapping an Alias
        """
        # Container.define stores the definition (plus related keys such as
        # Optional[dep]) before handing it back, so snapshot the registrations
        # in order to undo those writes if the definition is invalid.
        registered_before = dict(self._registered_types)
        definition = super().define(dep, resolver)
        if isinstance(definition, Alias):
            problem = "Aliases are not valid in an explicit container"
        elif isinstance(definition, Singleton) and isinstance(
            definition.singleton_type, Alias
        ):
            problem = "Aliases are not valid inside singletons in an explicit container"
        else:
            return definition
        # Restore in place so the dict object itself is unchanged.
        self._registered_types.clear()
        self._registered_types.update(registered_before)
        raise InvalidDependencyDefinition(problem)

    def clone(self):
        """returns a new child ExplicitContainer of this one

        The child sees this container's definitions via parent lookup and
        shares its undefined-dependency logger.
        :return:
        """
        return ExplicitContainer(self, log_undefined_deps=self._undefined_logger)
477
478
479 1
class EmptyDefinitionSet(DefinitionsSource):
    """
    The definitions source used by a container created without a parent.

    It holds no definitions at all, so every lookup against it misses and a
    plain Container has to build everything with reflection.
    """

    def get_definition(self, dep_type: Type[X]) -> Optional[SpecialDepDefinition[X]]:
        """
        Look up a definition for dep_type; this set never has one.

        :param dep_type: the type being looked up (ignored)
        :return: always None
        """
        nothing_defined = None
        return nothing_defined

    @property
    def defined_types(self) -> Set[Type]:
        """A new, empty set on every access: nothing is defined here."""
        no_types: Set[Type] = set()
        return no_types
496
497
498 1
class _TemporaryInjectionContext:
    """Context manager that yields a throwaway container derived from a base one.

    On entry it either passes the base container to the supplied update
    function and returns that function's result, or (with no update function)
    returns base_container.clone(). Exiting does nothing.
    """

    _base_container: Container

    def __init__(
        self,
        container: Container,
        update_function: Optional[Callable[[Container], Container]] = None,
    ):
        self._base_container = container
        self._update_function = update_function

    def _build_temporary_container(self) -> Container:
        # A falsy/None update function means "just clone the base container".
        if not self._update_function:
            return self._base_container.clone()
        return self._update_function(self._base_container)

    def __enter__(self) -> Container:
        return self._build_temporary_container()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to tear down; the temporary container is simply dropped.
        pass
519
520
521 1
def _update_nothing(_c: WriteableContainer, _a: typing.Collection, _k: Dict):
    """No-op container updater.

    Used by Container.partial / Container.magic_partial when no
    container_updater is supplied. It receives the invocation container,
    the supplied positional args and the supplied keyword args, and ignores
    all of them.
    """
    return
523