aiogremlin.driver.protocol   A
last analyzed

Complexity

Total Complexity 12

Size/Duplication

Total Lines 76
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 12
eloc 59
dl 0
loc 76
rs 10
c 0
b 0
f 0

4 Methods

Rating   Name   Duplication   Size   Complexity  
B GremlinServerWSProtocol.data_received() 0 32 7
A GremlinServerWSProtocol.write() 0 6 2
A GremlinServerWSProtocol.__init__() 0 6 2
A GremlinServerWSProtocol.connection_made() 0 2 1
1
import asyncio
import base64
import collections
import inspect
import logging

try:
    import ujson as json
except ImportError:
    import json

from gremlin_python.driver import protocol, request, serializer
12
13
14
__author__ = 'David M. Brown ([email protected])'

#: Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

#: Item queued on a result set for one response entry: the response status
#: code, the payload for that entry, and the server's status message text.
Message = collections.namedtuple('Message', 'status_code data message')
23
24
25
class GremlinServerWSProtocol(protocol.AbstractBaseProtocol):
    """Implementation of the Gremlin Server WebSocket protocol.

    Serializes outgoing request messages onto the transport and routes each
    incoming response frame to the result set registered for its request id.
    An authentication challenge (status 407) is answered with the configured
    username/password.
    """

    def __init__(self, message_serializer, username='', password=''):
        """
        :param message_serializer: Serializer instance, or a serializer class
            which is instantiated here with no arguments.
        :param str username: Username sent when the server asks for
            authentication.
        :param str password: Password sent when the server asks for
            authentication.
        """
        # Accept either a serializer class or an already-built instance.
        if isinstance(message_serializer, type):
            message_serializer = message_serializer()
        self._message_serializer = message_serializer
        self._username = username
        self._password = password

    def connection_made(self, transport):
        """Store the transport that :meth:`write` sends messages through."""
        self._transport = transport

    async def write(self, request_id, request_message):
        """Serialize a request message and write it to the transport.

        :param request_id: Id of the request this message belongs to.
        :param request_message: Request message to serialize (e.g. a
            ``request.RequestMessage``).
        """
        message = self._message_serializer.serialize_message(
            request_id, request_message)
        result = self._transport.write(message)
        # A transport may write synchronously (returning a non-awaitable) or
        # hand back a coroutine/future; await anything awaitable so that
        # write errors propagate to the caller instead of being lost.
        if asyncio.iscoroutine(result) or inspect.isawaitable(result):
            await result

    async def data_received(self, data, results_dict):
        """Decode one response frame and dispatch it to its result set.

        :param data: Response frame, either UTF-8 encoded bytes or an
            already-decoded ``str``.
        :param dict results_dict: Maps request ids to result set objects.
            Each result set receives :class:`Message` items via
            ``queue_result``, followed by ``None`` once the response for
            that request is complete.
        """
        # Binary frames arrive as bytes; text frames may already be str.
        if not isinstance(data, str):
            data = data.decode('utf-8')
        message = self._message_serializer.deserialize_message(
            json.loads(data))
        request_id = message['requestId']
        status_code = message['status']['code']
        data = message['result']['data']
        msg = message['status']['message']

        if request_id not in results_dict:
            # Previously dropped silently; log so stray responses are visible.
            logger.warning(
                "Received response for unknown request id %s", request_id)
            return

        result_set = results_dict[request_id]
        result_set.aggregate_to = message['result']['meta'].get(
            'aggregateTo', 'list')

        if status_code == 407:
            # Authentication challenge: reply with credentials. Nothing is
            # queued here; later responses for this request id are handled
            # as they arrive.
            await self._authenticate(request_id)
        elif status_code == 204:
            # No content: just terminate the result stream.
            result_set.queue_result(None)
        else:
            self._queue_results(result_set, status_code, data, msg)

    async def _authenticate(self, request_id):
        """Send the configured credentials as an authentication request."""
        # NUL-separated credentials with an empty leading identity, i.e. the
        # SASL PLAIN message layout, base64-encoded for transport.
        auth = b''.join([b'\x00', self._username.encode('utf-8'),
                         b'\x00', self._password.encode('utf-8')])
        request_message = request.RequestMessage(
            'traversal', 'authentication',
            {'sasl': base64.b64encode(auth).decode()})
        await self.write(request_id, request_message)

    def _queue_results(self, result_set, status_code, data, msg):
        """Queue one :class:`Message` per deserialized item of ``data``.

        If ``data`` is empty/falsy, a single message carrying it as-is is
        queued instead. Any status other than 206 (Gremlin Server's
        partial-content status) ends the response, so ``None`` is queued
        as a terminator.
        """
        if data:
            for item in data:
                result = self._message_serializer.deserialize_message(item)
                result_set.queue_result(Message(status_code, result, msg))
        else:
            result_set.queue_result(Message(status_code, data, msg))
        if status_code != 206:
            result_set.queue_result(None)
76