1
|
|
|
import asyncio |
2
|
|
|
import aiohttp |
3
|
|
|
|
4
|
|
|
from gremlin_python.driver import transport |
5
|
|
|
|
6
|
|
|
|
7
|
|
|
class AiohttpTransport(transport.AbstractBaseTransport):
    """gremlin_python transport that talks to the server over an aiohttp WebSocket.

    The transport holds at most one ``aiohttp.ClientSession`` and one
    WebSocket connection; :meth:`connect` closes any previous connection
    before opening a new one.
    """

    def __init__(self, loop):
        """Create an unconnected transport.

        :param loop: event loop passed to aiohttp's connector and session.
        """
        self._loop = loop
        self._connected = False
        # Initialised up front so ``close()`` and ``closed`` work before the
        # first ``connect()`` (or after a failed one) instead of raising
        # AttributeError.
        self._client_session = None
        self._ws = None

    async def connect(self, url, *, ssl_context=None):
        """Open a WebSocket connection to ``url``, replacing any existing one.

        :param url: WebSocket URL of the Gremlin server.
        :param ssl_context: optional SSL context forwarded to the aiohttp
            ``TCPConnector``.
        """
        await self.close()
        # NOTE(review): ``ssl_context=`` and ``loop=`` are deprecated in
        # aiohttp 3.x (in favour of ``ssl=`` and the implicit running loop);
        # kept unchanged for compatibility with the aiohttp version in use.
        connector = aiohttp.TCPConnector(
            ssl_context=ssl_context, loop=self._loop)
        self._client_session = aiohttp.ClientSession(
            loop=self._loop, connector=connector)
        try:
            self._ws = await self._client_session.ws_connect(url)
        except BaseException:
            # Don't leak the session (and its connector) when the handshake
            # fails or is cancelled; the transport stays unconnected.
            await self._client_session.close()
            raise
        self._connected = True

    async def write(self, message):
        """Send ``message`` as a binary WebSocket frame.

        ``send_bytes`` is awaited only if it returns a coroutine, which
        tolerates aiohttp versions where it is a plain (non-async) method.
        """
        coro = self._ws.send_bytes(message)
        if asyncio.iscoroutine(coro):
            await coro

    async def read(self):
        """Receive the next WebSocket message and return its payload as bytes.

        Text frames are stripped of surrounding whitespace and UTF-8
        encoded; any other data frame's payload is returned unchanged.

        :raises RuntimeError: if the connection was closed by the server or
            is closing/closed, or if an ERROR message carries no exception.
        :raises Exception: the exception carried by an ERROR message.
        """
        msg = await self._ws.receive()
        if msg.type == aiohttp.WSMsgType.CLOSE:
            # The server sent a close frame: release our side, then report it.
            await self.close()
            raise RuntimeError("Connection closed by server")
        if msg.type == aiohttp.WSMsgType.ERROR:
            error = msg.data
            if isinstance(error, BaseException):
                raise error
            # ``raise`` on a non-exception would itself fail with TypeError.
            raise RuntimeError("WebSocket error: {!r}".format(error))
        if msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING):
            # CLOSING carries no payload; returning it would hand ``None`` to
            # a caller expecting bytes, so treat it like CLOSED.
            raise RuntimeError("Connection closed by server")
        if msg.type == aiohttp.WSMsgType.TEXT:
            # Callers expect bytes regardless of frame type.
            return msg.data.strip().encode('utf-8')
        return msg.data

    async def close(self):
        """Close the WebSocket, then the client session, if they are open.

        Safe to call more than once, and before ``connect()``.
        """
        if self._ws is not None and not self._ws.closed:
            await self._ws.close()
        if (self._client_session is not None
                and not self._client_session.closed):
            await self._client_session.close()
        self._connected = False

    @property
    def closed(self):
        """True if there is no usable connection (never opened, or closed)."""
        if self._ws is None or self._client_session is None:
            return True
        return self._ws.closed or self._client_session.closed
55
|
|
|
|