Passed
Push — master ( 91e384...e66e62 )
by Jeffrey
01:13
created

goblin.session.Session.__dirty_element()   A

Complexity

Conditions 4

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 7
nop 3
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
"""Main OGM API classes and constructors"""
2
3
import asyncio
4
import collections
5
import logging
6
import weakref
7
import uuid
8
from typing import Callable, Awaitable, Any, Optional
9
10
import aiogremlin # type: ignore
11
from aiogremlin.driver.protocol import Message # type: ignore
12
from aiogremlin.driver.resultset import ResultSet # type: ignore
13
from gremlin_python.process.graph_traversal import __, GraphTraversal # type: ignore
14
from gremlin_python.driver.remote_connection import RemoteTraversal # type: ignore
15
from gremlin_python.process.traversal import Binding, Cardinality, Traverser # type: ignore
16
from gremlin_python.structure.graph import Edge, Vertex # type: ignore
17
18
from goblin import exception, mapper
19
from goblin.element import GenericEdge, GenericVertex, VertexProperty, ImmutableMode, LockingMode
20
from goblin.manager import VertexPropertyManager
21
22
# Module-level logger; Session.remove_vertex/remove_edge use it to warn when
# an element being removed is not tracked by the session.
logger = logging.getLogger(__name__)
23
24
25
def bindprop(element_class, ogm_name, val, *, binding=None):
    """
    Translate an OGM property name/value pair into its db equivalent for use
    in traversals.

    :param goblin.element.Element element_class: User defined element class
    :param str ogm_name: Name of the property as defined in the OGM
    :param val: The property value; it is converted with the ``to_db`` method
        of the property's data type
    :param str binding: Optional binding name paired with the converted value

    :returns: ``(db_property_name, db_value)``, or
        ``(db_property_name, (binding, db_value))`` when ``binding`` is truthy
    """
    # Falls back to the OGM name when the class has no attribute of that name.
    db_key = getattr(element_class, ogm_name, ogm_name)
    mapping = element_class.__mapping__
    _, prop_type = mapping.ogm_properties[ogm_name]
    converted = prop_type.to_db(val)
    bound_value = (binding, converted) if binding else converted
    return db_key, bound_value
43
44
45
class Session:
    """
    Provides the main API for interacting with the database. Does not
    necessarily correspond to a database session. Don't instantiate directly,
    instead use :py:meth:`Goblin.session<goblin.app.Goblin.session>`.

    :param goblin.app.Goblin app: Owning application; its ``_loop``,
        ``vertices`` and ``edges`` attributes are read by the session.
    :param remote_connection: Connection used for internal traversals
        (see :py:attr:`_g`).
    :param get_hashable_id: Callable converting an element id into a value
        usable as a key of :py:attr:`current`.
    """

    def __init__(self, app, remote_connection, get_hashable_id):
        self._app = app
        self._remote_connection = remote_connection
        self._loop = self._app._loop
        self._use_session = False
        # Elements queued by ``add`` and written to the db by ``flush``.
        self._pending = collections.deque()
        # Hashable element id -> element object saved or deserialized by
        # this session.
        self._current = dict()
        self._get_hashable_id = get_hashable_id
        self._graph = aiogremlin.Graph()

    @property
    def graph(self):
        """The ``aiogremlin.Graph`` used to spawn traversal sources."""
        return self._graph

    @property
    def app(self):
        """The owning application (``None`` once the session is closed)."""
        return self._app

    @property
    def remote_connection(self):
        """The remote connection (``None`` once the session is closed)."""
        return self._remote_connection

    @property
    def current(self):
        """
        Dict mapping hashable element ids to the elements this session has
        saved or deserialized.
        """
        return self._current

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self.close()

    def close(self):
        """
        Drop the session's references to its app and remote connection.

        Traversals that depend on either will no longer work afterwards.
        """
        self._remote_connection = None
        self._app = None

    # Traversal API
    @property
    def g(self):
        """
        Get a simple traversal source whose results are deserialized into
        OGM elements by this session.

        :returns:
            `gremlin_python.process.GraphTraversalSource`
            object
        """
        return self.traversal()

    @property
    def _g(self):
        """
        Traversal source for internal use. Uses the underlying connection
        directly, so it neither flushes pending elements nor triggers the
        session's element deserialization.
        """
        return self.graph.traversal().withRemote(self.remote_connection)

    def traversal(self, element_class=None):
        """
        Generate a traversal using a user defined element class as a
        starting point.

        :param goblin.element.Element element_class: An optional element
            class that will dictate the element type (vertex/edge) as well as
            the label for the traversal source

        :returns: `aiogremlin.process.graph_traversal.AsyncGraphTraversal`
        """
        traversal = self.graph.traversal().withRemote(self)
        if element_class:
            label = element_class.__mapping__.label
            if element_class.__type__ == 'vertex':
                traversal = traversal.V()
            elif element_class.__type__ == 'edge':
                traversal = traversal.E()
            traversal = traversal.hasLabel(label)
        return traversal

    async def submit(self, bytecode):
        """
        Submit traversal bytecode to the Gremlin Server.

        Pending elements are flushed first. Results are deserialized into OGM
        elements by a background task as they arrive.

        :param bytecode: Bytecode of the traversal to run.

        :returns:
            `gremlin_python.driver.remote_connection.RemoteTraversal`
            object
        """
        await self.flush()
        remote_traversal = await self.remote_connection.submit(bytecode)
        traversers = remote_traversal.traversers
        side_effects = remote_traversal.side_effects
        result_set = ResultSet(traversers.request_id, traversers._timeout,
                               self._loop)
        self._loop.create_task(self._receive(traversers, result_set))
        return RemoteTraversal(result_set, side_effects)

    async def _receive(self, traversers, result_set):
        """
        Move deserialized results from ``traversers`` into ``result_set``.

        Each result is queued as a 200 message. An exception is queued as a
        500 message carrying its text. ``None`` is always queued last
        (presumably the end-of-stream marker for ``ResultSet``).
        """
        try:
            async for result in traversers:
                result = await self._deserialize_result(result)
                result_set.queue_result(Message(200, result, ''))
        except Exception as e:
            # ``e.args`` may be empty; indexing it blindly would raise inside
            # this handler and lose the original error.
            message = e.args[0] if e.args else repr(e)
            result_set.queue_result(Message(500, None, message))
        finally:
            result_set.queue_result(None)

    async def _deserialize_result(self, result):
        """
        Map raw results onto OGM elements.

        A traverser carrying a vertex or edge is re-read from the db. It is
        then mapped onto the matching user defined element class (or a
        generic one), reusing any element already held in
        :py:attr:`current`. Dicts and lists are processed recursively.
        Anything else is returned unchanged.
        """
        if isinstance(result, Traverser):
            obj = result.object
            if not isinstance(obj, (Vertex, Edge)):
                return result
            hashable_id = self._get_hashable_id(obj.id)
            current = self.current.get(hashable_id, None)
            if isinstance(obj, Vertex):
                # The label is not available on the returned vertex here, so
                # it is fetched with a separate query.
                label = await self._g.V(obj.id).label().next()
                if not current:
                    current = self.app.vertices.get(label, GenericVertex)()
                props = await self._get_vertex_properties(obj.id, label)
            else:
                props = await self._g.E(obj.id).valueMap(True).next()
                if not current:
                    current = self.app.edges.get(
                        props.get('label'), GenericEdge)()
                    current.source = GenericVertex()
                    current.target = GenericVertex()
            element = current.__mapping__.mapper_func(obj, props, current)
            self.current[hashable_id] = element
            return Traverser(element, result.bulk)
        # Recursive deserialization: each nested value must be awaited,
        # otherwise the coroutine object itself would be stored.
        if isinstance(result, dict):
            for key in result:
                result[key] = await self._deserialize_result(result[key])
            return result
        if isinstance(result, list):
            return [await self._deserialize_result(item) for item in result]
        return result

    async def _get_vertex_properties(self, vid, label):
        """
        Fetch all properties of vertex ``vid``.

        :returns: dict with ``'label'`` and ``'id'`` entries plus, for every
            property key, a list of its values. A value that has
            meta-properties is replaced by the meta-property dict extended
            with ``'key'``, ``'value'`` and ``'id'`` entries.
        """
        projection = (self._g.V(vid).properties()
                      .project('id', 'key', 'value', 'meta')
                      .by(__.id()).by(__.key()).by(__.value())
                      .by(__.valueMap()))
        props = await projection.toList()
        new_props = {'label': label, 'id': vid}
        for prop in props:
            key = prop['key']
            val = prop['value']
            meta = prop['meta']
            if meta:
                meta['key'] = key
                meta['value'] = val
                meta['id'] = prop['id']
                val = meta
            new_props.setdefault(key, []).append(val)
        return new_props

    # Creation API
    def add(self, *elements):
        """
        Add elements to session pending queue.

        :param goblin.element.Element elements: Elements to be added
        """
        self._pending.extend(elements)

    async def flush(
            self,
            conflicts_query: "Optional[GraphTraversal]" = None) -> None:
        """
        Issue creation/update queries to database for all elements in the
        session pending queue.

        Elements whose class uses optimistic locking are marked (their
        ``dirty`` attribute) with a transaction id unique to this flush
        before being saved. Once everything is saved the marks are dropped
        (commit). If anything fails, every element still carrying the mark
        is dropped from the db (rollback) and the exception is re-raised.

        :param conflicts_query: Optional traversal evaluated against each
            vertex marked by this flush. If it yields a result for a marked
            vertex, the marks of all marked elements are dropped (commit).
            Afterwards, any element still carrying the mark is removed
            (rollback).
        """
        transaction_id = str(uuid.uuid4())
        processed = []
        try:
            while self._pending:
                elem = self._pending.popleft()
                if self.__dirty_element(elem, id=transaction_id):
                    processed.append(await self.save(elem))
                else:
                    await self.save(elem)

            if not processed:
                return
            if not conflicts_query:
                await self.__commit_transaction(transaction_id)
            else:
                # Use the internal source: going through ``self.g`` would
                # re-enter ``flush`` via ``submit`` and may cache, in
                # ``current``, elements that are about to be rolled back.
                await (self._g.E().has('dirty', transaction_id)
                       .aggregate('x').fold()
                       .V().has('dirty', transaction_id)
                       .aggregate('x')
                       .choose(conflicts_query,
                               __.select('x').unfold()
                               .properties('dirty').drop())
                       .iterate())  # type: ignore
                await self.__rollback_transaction(transaction_id)
        except Exception:
            await self.__rollback_transaction(transaction_id)
            raise
        for elem in processed:
            elem.dirty = None

    async def remove_vertex(self, vertex):
        """
        Remove a vertex from the db and from :py:attr:`current`.

        :param goblin.element.Vertex vertex: Vertex to be removed

        :returns: Result of the drop traversal
        """
        traversal = self._g.V(Binding('vid', vertex.id)).drop()
        result = await self._simple_traversal(traversal, vertex)
        hashable_id = self._get_hashable_id(vertex.id)
        if hashable_id in self.current:
            del self.current[hashable_id]
        else:
            logger.warning(
                'Vertex %s does not belong to this session obj %s',
                vertex, self)
        return result

    async def remove_edge(self, edge):
        """
        Remove an edge from the db and from :py:attr:`current`.

        :param goblin.element.Edge edge: Element to be removed

        :returns: Result of the drop traversal
        """
        eid = edge.id
        if isinstance(eid, dict):
            eid = Binding('eid', edge.id)
        traversal = self._g.E(eid).drop()
        result = await self._simple_traversal(traversal, edge)
        hashable_id = self._get_hashable_id(edge.id)
        if hashable_id in self.current:
            del self.current[hashable_id]
        else:
            logger.warning(
                'Edge %s does not belong to this session obj %s',
                edge, self)
        return result

    async def save(self, elem):
        """
        Save an element to the db.

        :param goblin.element.Element elem: Vertex or Edge to be saved

        :returns: :py:class:`Element<goblin.element.Element>` object

        :raises goblin.exception.ElementError: for an unknown element type
        """
        if elem.__type__ == 'vertex':
            return await self.save_vertex(elem)
        if elem.__type__ == 'edge':
            return await self.save_edge(elem)
        raise exception.ElementError("Unknown element type: {}".format(
            elem.__type__))

    async def save_vertex(self, vertex):
        """
        Save a vertex to the db and record it in :py:attr:`current`.

        :param goblin.element.Vertex vertex: Vertex to be saved

        :returns: :py:class:`Vertex<goblin.element.Vertex>` object
        """
        result = await self._save_element(
            vertex, self._check_vertex, self._add_vertex, self._update_vertex)
        hashable_id = self._get_hashable_id(result.id)
        self.current[hashable_id] = result
        return result

    async def save_edge(self, edge):
        """
        Save an edge to the db and record it in :py:attr:`current`.

        :param goblin.element.Edge edge: Edge to be saved

        :returns: :py:class:`Edge<goblin.element.Edge>` object

        :raises goblin.exception.ElementError: if source or target is missing
        """
        if not (hasattr(edge, 'source') and hasattr(edge, 'target')):
            raise exception.ElementError(
                "Edges require both source/target vertices")
        result = await self._save_element(edge, self._check_edge,
                                          self._add_edge, self._update_edge)
        hashable_id = self._get_hashable_id(result.id)
        self.current[hashable_id] = result
        return result

    async def get_vertex(self, vertex):
        """
        Get a vertex from the db. Vertex must have id.

        :param goblin.element.Vertex vertex: Vertex to be retrieved

        :returns: :py:class:`Vertex<goblin.element.Vertex>` | None
        """
        return await self.g.V(Binding('vid', vertex.id)).next()

    async def get_edge(self, edge):
        """
        Get an edge from the db. Edge must have id.

        :param goblin.element.Edge edge: Edge to be retrieved

        :returns: :py:class:`Edge<goblin.element.Edge>` | None
        """
        eid = edge.id
        if isinstance(eid, dict):
            eid = Binding('eid', edge.id)
        return await self.g.E(eid).next()

    def __dirty_element(self, elem, id=None):
        """
        Mark ``elem`` as part of a transaction if it uses optimistic locking.

        :param elem: Element to mark.
        :param str id: Transaction id to assign. When omitted, a fresh UUID
            is generated for this call. A ``str(uuid.uuid4())`` default
            would be evaluated only once, at class creation, and shared by
            every call.

        :returns: The element's transaction id (an already-set ``dirty``
            value is kept and returned), or ``None`` if the element does not
            use optimistic locking.
        """
        if not (elem.__locking__ and
                elem.__locking__ == LockingMode.OPTIMISTIC_LOCKING):
            return None
        if elem.dirty:
            return elem.dirty
        if id is None:
            id = str(uuid.uuid4())
        elem.dirty = id
        return id

    def __marked_elements(self, id):
        """
        Internal traversal emitting, once each, every edge and vertex whose
        ``dirty`` property equals ``id``.
        """
        # ``fold()`` after each ``aggregate`` collapses the stream to a
        # single traverser. The final ``select('x').unfold()`` therefore runs
        # exactly once, even when no vertex (or no edge) carries the mark.
        return (self._g.E().has('dirty', id).aggregate('x').fold()
                .V().has('dirty', id).aggregate('x').fold()
                .select('x').unfold())

    async def __commit_transaction(self, id):
        """
        Commit transaction ``id`` by dropping the ``dirty`` property of every
        element marked with it. A falsy ``id`` is a no-op.
        """
        if not id:
            return
        await self.__marked_elements(id).properties('dirty').drop().iterate()

    async def __rollback_transaction(self, id):
        """
        Roll back transaction ``id`` by dropping every element marked with it
        from the db. A falsy ``id`` is a no-op.
        """
        logger.debug("Rolling back transaction %s", id)
        if not id:
            return
        await self.__marked_elements(id).drop().iterate()

    async def _update_vertex(self, vertex):
        """
        Update a vertex, generally to change/remove property values.

        :param goblin.element.Vertex vertex: Vertex to be updated

        :returns: :py:class:`Vertex<goblin.element.Vertex>` object
        """
        props = mapper.map_props_to_db(vertex, vertex.__mapping__)
        traversal = self._g.V(Binding('vid', vertex.id))
        return await self._update_vertex_properties(vertex, traversal, props)

    async def _update_edge(self, edge):
        """
        Update an edge, generally to change/remove property values.

        :param goblin.element.Edge edge: Edge to be updated

        :returns: :py:class:`Edge<goblin.element.Edge>` object
        """
        props = mapper.map_props_to_db(edge, edge.__mapping__)
        eid = edge.id
        if isinstance(eid, dict):
            eid = Binding('eid', edge.id)
        traversal = self._g.E(eid)
        return await self._update_edge_properties(edge, traversal, props)

    # Private helpers for the creation API

    async def _simple_traversal(self, traversal, element):
        """
        Run ``traversal`` and map its first result onto ``element``.

        :returns: The mapped element, or the raw (falsy) result when the
            traversal produced nothing.

        :raises goblin.exception.ElementError: for an unknown element type
        """
        elem = await traversal.next()
        if elem:
            if element.__type__ == 'vertex':
                label = await self._g.V(elem.id).label().next()
                props = await self._get_vertex_properties(elem.id, label)
            elif element.__type__ == 'edge':
                props = await self._g.E(elem.id).valueMap(True).next()
            else:
                raise exception.ElementError(
                    "Unknown element type: {}".format(element.__type__))
            elem = element.__mapping__.mapper_func(elem, props, element)
        return elem

    async def __handle_create_func(self, elem, create_func):
        """
        Create ``elem`` with ``create_func``. The creation is wrapped in its
        own transaction when needed.

        If ``elem`` already carries a transaction id (e.g. it is being saved
        by ``flush``) or does not use optimistic locking, it is simply
        created. Otherwise it is marked with a fresh transaction id, created,
        and the transaction is committed. On failure the transaction is
        rolled back and the exception re-raised.
        """
        if elem.dirty:
            return await create_func(elem)
        transaction_id = self.__dirty_element(elem)
        if not transaction_id:
            return await create_func(elem)
        try:
            result = await create_func(elem)
            await self.__commit_transaction(transaction_id)
            result.dirty = None
        except Exception:
            await self.__rollback_transaction(transaction_id)
            raise
        return result

    async def _save_element(self, elem, check_func, create_func, update_func):
        """
        Create ``elem`` if it has no id or does not exist in the db;
        otherwise update it.

        :raises AttributeError: when updating an element whose
            ``__immutable__`` mode is set and is not ``ImmutableMode.OFF``
        """
        if hasattr(elem, 'id') and await check_func(elem):
            if (elem.__immutable__ and
                    elem.__immutable__ != ImmutableMode.OFF):
                raise AttributeError(
                    "Trying to update an immutable element: %s" % elem)
            return await update_func(elem)
        return await self.__handle_create_func(elem, create_func)

    async def _add_vertex(self, vertex):
        """Convenience function for generating crud traversals."""
        props = mapper.map_props_to_db(vertex, vertex.__mapping__)
        traversal = self._g.addV(vertex.__mapping__.label)
        return await self._add_properties(traversal, props, vertex)

    async def _add_edge(self, edge):
        """Convenience function for generating crud traversals."""
        props = mapper.map_props_to_db(edge, edge.__mapping__)
        traversal = self._g.V(Binding('sid', edge.source.id))
        traversal = traversal.addE(edge.__mapping__._label)
        traversal = traversal.to(__.V(Binding('tid', edge.target.id)))
        return await self._add_properties(traversal, props, edge)

    async def _check_vertex(self, vertex):
        """Used to check for existence, does not update session vertex"""
        return await self._g.V(Binding('vid', vertex.id)).next()

    async def _check_edge(self, edge):
        """Used to check for existence, does not update session edge"""
        eid = edge.id
        if isinstance(eid, dict):
            eid = Binding('eid', edge.id)
        return await self._g.E(eid).next()

    async def _update_vertex_properties(self, vertex, traversal, props):
        """Drop every property of ``vertex`` in the db, then re-add ``props``."""
        await self._g.V(vertex.id).properties().drop().iterate()
        return await self._add_properties(traversal, props, vertex)

    async def _update_edge_properties(self, edge, traversal, props):
        """Drop every property of ``edge`` in the db, then re-add ``props``."""
        await self._g.E(edge.id).properties().drop().iterate()
        return await self._add_properties(traversal, props, edge)

    async def _add_properties(self, traversal, props, elem):
        """
        Append a ``property()`` step per non-``None`` value in ``props``,
        then run the traversal and map the result onto ``elem``.

        :param props: iterable of ``(cardinality, db_name, value,
            metaprops)`` tuples
        """
        for card, db_name, val, metaprops in props:
            if val is None:
                continue
            # Flatten {k1: v1, k2: v2} into [k1, v1, k2, v2] for property().
            metas = [item for pair in (metaprops or {}).items()
                     for item in pair]
            if card:
                # Anything other than list/set cardinality is sent as single.
                if card == Cardinality.list_:
                    card = Cardinality.list_
                elif card == Cardinality.set_:
                    card = Cardinality.set_
                else:
                    card = Cardinality.single
                traversal = traversal.property(card, db_name, val, *metas)
            else:
                traversal = traversal.property(db_name, val, *metas)
        return await self._simple_traversal(traversal, elem)
545