1 | """ |
||
2 | For use with Starlette (https://www.starlette.io/) |
||
3 | """ |
||
4 | |||
5 | 1 | from typing import List, Type, Callable, Optional, Union |
|
6 | |||
7 | 1 | from starlette.routing import Route, WebSocketRoute |
|
8 | 1 | from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint |
|
9 | |||
10 | 1 | from ..interfaces import ExtendableContainer |
|
11 | |||
# Handler-method names (lower-cased HTTP verbs) on a user's HTTPEndpoint
# subclass that the HTTP endpoint proxy wraps with dependency injection.
# Frozen so that importing code cannot change, at runtime and globally,
# which methods get injected.
OVERRIDE_HTTP_METHODS = frozenset(
    {
        "get",
        "head",
        "post",
        "put",
        "delete",
        "connect",
        "options",
        "trace",
        "patch",
    }
)

# Handler-method names on a user's WebSocketEndpoint subclass that the
# WebSocket endpoint proxy wraps with dependency injection.
OVERRIDE_WEBSOCKET_METHODS = frozenset({"on_connect", "on_receive", "on_disconnect"})
|
25 | |||
26 | |||
class StarletteIntegration:
    """
    Wraps a container and a route method for use in the Starlette framework.

    Each ``*route`` factory binds the given endpoint to the container so its
    dependencies can be injected. Plain callables are bound directly;
    ``HTTPEndpoint`` / ``WebSocketEndpoint`` subclasses are replaced by a proxy
    class that binds their handler methods on every lookup.
    """

    _request_singletons: List[Type]
    _container: ExtendableContainer

    def __init__(
        self,
        container: ExtendableContainer,
        request_singletons: Optional[List[Type]] = None,
    ):
        """
        :param container: container whose ``partial`` / ``magic_partial`` is used
            to bind every endpoint registered through this integration
        :param request_singletons: List of types that will be singletons for a
            request; passed as ``shared=`` to the container's partial providers.
            Defaults to an empty list.
        """
        self._request_singletons = request_singletons or []
        self._container = container

    def route(
        self,
        path: str,
        endpoint: Callable,
        *,
        methods: Optional[List[str]] = None,
        name: Optional[str] = None,
        include_in_schema: bool = True,
    ) -> Route:
        """Returns an instance of a starlette Route.

        The callable endpoint is bound to the container (via ``partial``) so
        dependencies can be injected. All other arguments are passed on to
        starlette unchanged.

        :param path: URL path of the route
        :param endpoint: a callable, or a subclass of ``HTTPEndpoint``
        :param methods: passed to ``Route``
        :param name: passed to ``Route``
        :param include_in_schema: passed to ``Route``
        :return: the constructed ``Route``
        """
        wrapped = self.wrapped_endpoint_factory(endpoint, self._container.partial)

        return Route(
            path,
            wrapped,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )

    def magic_route(
        self,
        path: str,
        endpoint: Callable,
        *,
        methods: Optional[List[str]] = None,
        name: Optional[str] = None,
        include_in_schema: bool = True,
    ) -> Route:
        """Returns an instance of a starlette Route.

        The callable endpoint is bound to the container (via ``magic_partial``)
        so dependencies can be auto injected. All other arguments are passed on
        to starlette unchanged.

        :param path: URL path of the route
        :param endpoint: a callable, or a subclass of ``HTTPEndpoint``
        :param methods: passed to ``Route``
        :param name: passed to ``Route``
        :param include_in_schema: passed to ``Route``
        :return: the constructed ``Route``
        """
        wrapped = self.wrapped_endpoint_factory(endpoint, self._container.magic_partial)

        return Route(
            path,
            wrapped,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )

    def ws_route(
        self,
        path: str,
        endpoint: Callable,
        *,
        name: Optional[str] = None,
    ) -> WebSocketRoute:
        """Returns an instance of a starlette WebSocketRoute.

        The callable endpoint is bound to the container (via ``partial``) so
        dependencies can be injected. All other arguments are passed on to
        starlette unchanged.

        :param path: URL path of the route
        :param endpoint: a callable, or a subclass of ``WebSocketEndpoint``
        :param name: passed to ``WebSocketRoute``
        :return: the constructed ``WebSocketRoute``
        """
        wrapped = self.wrapped_endpoint_factory(endpoint, self._container.partial)

        return WebSocketRoute(path, wrapped, name=name)

    def ws_magic_route(
        self,
        path: str,
        endpoint: Callable,
        *,
        name: Optional[str] = None,
    ) -> WebSocketRoute:
        """Returns an instance of a starlette WebSocketRoute.

        The callable endpoint is bound to the container (via ``magic_partial``)
        so dependencies can be auto injected. All other arguments are passed on
        to starlette unchanged.

        :param path: URL path of the route
        :param endpoint: a callable, or a subclass of ``WebSocketEndpoint``
        :param name: passed to ``WebSocketRoute``
        :return: the constructed ``WebSocketRoute``
        """
        wrapped = self.wrapped_endpoint_factory(endpoint, self._container.magic_partial)

        return WebSocketRoute(path, wrapped, name=name)

    def wrapped_endpoint_factory(
        self,
        endpoint: Union[Callable, Type[HTTPEndpoint], Type[WebSocketEndpoint]],
        partial_provider: Callable,
    ):
        """Bind an endpoint to the container so dependencies can be injected.

        :param endpoint: a plain callable, or a subclass of Starlette's
            ``HTTPEndpoint`` or ``WebSocketEndpoint``
        :param partial_provider: the container's ``partial`` or ``magic_partial``;
            called with ``shared=`` set to this integration's request singletons
        :return: the bound callable for a plain callable, otherwise a proxy
            subclass of the matching Starlette endpoint class
        :raises TypeError: if ``endpoint`` is a class that is neither an
            ``HTTPEndpoint`` nor a ``WebSocketEndpoint`` subclass
        """
        if not isinstance(endpoint, type):
            return partial_provider(endpoint, shared=self._request_singletons)

        if issubclass(endpoint, HTTPEndpoint):
            return self.create_http_endpoint_proxy(
                endpoint, partial_provider, self._request_singletons
            )

        if issubclass(endpoint, WebSocketEndpoint):
            return self.create_websocket_endpoint_proxy(
                endpoint, partial_provider, self._request_singletons
            )

        # Falling through would hand ``None`` to the route as its endpoint;
        # reject the class here so the misconfiguration surfaces at
        # registration time rather than on the first request.
        raise TypeError(
            f"Unsupported endpoint class {endpoint!r}: expected a callable or a "
            "subclass of starlette's HTTPEndpoint or WebSocketEndpoint"
        )

    @staticmethod
    def create_http_endpoint_proxy(
        endpoint_cls: Type[HTTPEndpoint],
        partial_provider: Callable,
        request_singletons: List[Type],
    ) -> Type[HTTPEndpoint]:
        """Create a subclass of Starlette's HTTPEndpoint which injects dependencies
        into HTTP-method-named methods on the user's `endpoint_cls` subclass of HTTPEndpoint.

        :param endpoint_cls: the user's ``HTTPEndpoint`` subclass
        :param partial_provider: the container's ``partial`` or ``magic_partial``
        :param request_singletons: passed to ``partial_provider`` as ``shared=``
        :return: a proxy class named ``HTTPEndpointProxy``
        """
        return StarletteIntegration._create_endpoint_proxy(
            HTTPEndpoint,
            endpoint_cls,
            partial_provider,
            request_singletons,
            injected_methods=OVERRIDE_HTTP_METHODS,
            proxy_name="HTTPEndpointProxy",
        )

    @staticmethod
    def create_websocket_endpoint_proxy(
        endpoint_cls: Type[WebSocketEndpoint],
        partial_provider: Callable,
        request_singletons: List[Type],
    ) -> Type[WebSocketEndpoint]:
        """Create a subclass of Starlette's WebSocketEndpoint which injects dependencies
        into relevant methods on the user's `endpoint_cls` subclass of WebSocketEndpoint.

        :param endpoint_cls: the user's ``WebSocketEndpoint`` subclass
        :param partial_provider: the container's ``partial`` or ``magic_partial``
        :param request_singletons: passed to ``partial_provider`` as ``shared=``
        :return: a proxy class named ``WebSocketEndpointProxy``
        """
        return StarletteIntegration._create_endpoint_proxy(
            WebSocketEndpoint,
            endpoint_cls,
            partial_provider,
            request_singletons,
            injected_methods=OVERRIDE_WEBSOCKET_METHODS,
            # Starlette's WebSocketEndpoint reads ``self.encoding`` when decoding
            # incoming messages; forward it so the value declared on the user's
            # class is used instead of the proxy's inherited default.
            forwarded_attributes=frozenset({"encoding"}),
            proxy_name="WebSocketEndpointProxy",
        )

    @staticmethod
    def _create_endpoint_proxy(
        base_cls: type,
        endpoint_cls: type,
        partial_provider: Callable,
        request_singletons: List[Type],
        injected_methods,
        forwarded_attributes=frozenset(),
        proxy_name: str = "EndpointProxy",
    ) -> type:
        """Create a ``base_cls`` subclass that delegates to an ``endpoint_cls`` instance.

        Each proxy instance owns one ``endpoint_cls`` instance built from the
        same ``(scope, receive, send)``. Looking up a name in
        ``injected_methods`` returns that instance's method bound through
        ``partial_provider``; a name in ``forwarded_attributes`` is read from
        that instance unchanged; every other name resolves on the proxy itself.

        :param base_cls: Starlette base class the proxy subclasses
        :param endpoint_cls: the user's subclass of ``base_cls``
        :param partial_provider: the container's ``partial`` or ``magic_partial``
        :param request_singletons: passed to ``partial_provider`` as ``shared=``
        :param injected_methods: collection of method names to inject into
        :param forwarded_attributes: collection of attribute names read straight
            from the wrapped instance
        :param proxy_name: ``__name__`` / ``__qualname__`` of the generated class
        :return: the generated proxy class
        """

        class EndpointProxy(base_cls):
            def __init__(self, scope, receive, send):
                # Build the wrapped endpoint before the base constructor runs:
                # if that constructor looks up handler attributes, the
                # __getattribute__ below needs ``self.endpoint`` to answer.
                self.endpoint = endpoint_cls(scope, receive, send)
                super().__init__(scope, receive, send)

            def __getattribute__(self, name: str):
                if name in injected_methods:
                    endpoint_instance = object.__getattribute__(self, "endpoint")
                    # AttributeError propagates for handlers the user did not
                    # define, so getattr(..., default) / hasattr keep working.
                    endpoint_method = getattr(endpoint_instance, name)
                    return partial_provider(endpoint_method, shared=request_singletons)

                if name in forwarded_attributes:
                    endpoint_instance = object.__getattribute__(self, "endpoint")
                    return getattr(endpoint_instance, name)

                return object.__getattribute__(self, name)

        EndpointProxy.__name__ = proxy_name
        EndpointProxy.__qualname__ = proxy_name
        return EndpointProxy