goblin.session.Session.remove_vertex()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 17
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 11
nop 2
dl 0
loc 17
rs 9.85
c 0
b 0
f 0
1
"""Main OGM API classes and constructors"""
2
3
import asyncio
4
import collections
5
import logging
6
import weakref
7
import uuid
8
from typing import Callable, Awaitable, Any, Optional
9
10
import aiogremlin # type: ignore
11
from aiogremlin.driver.protocol import Message # type: ignore
12
from aiogremlin.driver.resultset import ResultSet # type: ignore
13
from gremlin_python.process.graph_traversal import __, GraphTraversal # type: ignore
14
from gremlin_python.driver.remote_connection import RemoteTraversal # type: ignore
15
from gremlin_python.process.traversal import Binding, Cardinality, Traverser # type: ignore
16
from gremlin_python.structure.graph import Edge, Vertex # type: ignore
17
18
from goblin import exception, mapper
19
from goblin.element import GenericEdge, GenericVertex, VertexProperty, ImmutableMode, LockingMode
20
from goblin.manager import VertexPropertyManager
21
import traceback
22
23
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


def bindprop(element_class, ogm_name, val, *, binding=None):
    """
    Translate an OGM property name/value pair into its db counterpart so it
    can be used in a traversal.

    :param goblin.element.Element element_class: User defined element class
    :param str ogm_name: Name of the property as defined in the OGM
    :param val: The property value; converted with the property's
        ``data_type.to_db``
    :param str binding: Optional binding name for the converted value; only
        used when truthy

    :returns: ``(db_property_name, value)``, or
        ``(db_property_name, (binding, value))`` when ``binding`` is given
    """
    # The class attribute named after the OGM property holds the db name;
    # fall back to the OGM name itself when there is no such attribute.
    db_name = getattr(element_class, ogm_name, ogm_name)
    _db_key, data_type = element_class.__mapping__.ogm_properties[ogm_name]
    converted = data_type.to_db(val)
    bound = (binding, converted) if binding else converted
    return db_name, bound


class Session:
    """
    Provides the main API for interacting with the database. Does not
    necessarily correspond to a database session. Don't instantiate directly,
    instead use :py:meth:`Goblin.session<goblin.app.Goblin.session>`.

    :param goblin.app.Goblin app: Owning application. Its event loop and its
        registered ``vertices``/``edges`` classes are used by the session.
    :param remote_connection: Connection used to submit traversal bytecode
        (exposed as :py:attr:`remote_connection`).
    :param get_hashable_id: Callable mapping an element id to a hashable key
        for the :py:attr:`current` element cache.
    """

    def __init__(self, app, remote_connection, get_hashable_id):
        self._app = app
        self._remote_connection = remote_connection
        self._loop = self._app._loop
        self._use_session = False
        # Elements queued by ``add`` and written to the db by ``flush``.
        self._pending = collections.deque()
        # Elements saved or loaded through this session, keyed by hashable id.
        self._current = dict()
        self._get_hashable_id = get_hashable_id
        self._graph = aiogremlin.Graph()

    @property
    def graph(self):
        """The ``aiogremlin.Graph`` used to spawn traversal sources."""
        return self._graph

    @property
    def app(self):
        """The owning application (``None`` after :py:meth:`close`)."""
        return self._app

    @property
    def remote_connection(self):
        """The remote connection (``None`` after :py:meth:`close`)."""
        return self._remote_connection

    @property
    def current(self):
        """Dict of elements tracked by this session, keyed by hashable id."""
        return self._current

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self.close()

    def close(self):
        """
        Detach the session from its app and remote connection.

        This only drops the session's references. It does not close the
        connection, flush pending elements, or clear :py:attr:`current`.
        """
        self._remote_connection = None
        self._app = None

    # Traversal API
    @property
    def g(self):
        """
        Get a simple traversal source.

        :returns:
            `gremlin_python.process.GraphTraversalSource`
            object
        """
        return self.traversal()

    @property
    def _g(self):
        """
        Traversal source for internal use. It uses the underlying connection
        directly, so results bypass :py:meth:`submit`: no implicit flush and
        no OGM deserialization.
        """
        return self.graph.traversal().withRemote(self.remote_connection)

    def traversal(self, element_class=None):
        """
        Generate a traversal using a user defined element class as a
        starting point.

        :param goblin.element.Element element_class: An optional element
            class that will dictate the element type (vertex/edge) as well as
            the label for the traversal source

        :returns: `aiogremlin.process.graph_traversal.AsyncGraphTraversal`
        """
        # Remote is the session itself, so results go through ``submit``.
        traversal = self.graph.traversal().withRemote(self)
        if element_class:
            label = element_class.__mapping__.label
            if element_class.__type__ == 'vertex':
                traversal = traversal.V()
            elif element_class.__type__ == 'edge':
                traversal = traversal.E()
            traversal = traversal.hasLabel(label)
        return traversal

    async def submit(self, bytecode):
        """
        Submit traversal bytecode to the Gremlin Server.

        Pending elements are flushed first. A background task deserializes
        the server results (vertices/edges become OGM elements) and feeds
        them into the returned traversal's result set.

        :param bytecode: Bytecode of the traversal to submit.

        :returns:
            `gremlin_python.driver.remote_connection.RemoteTraversal`
            object
        """
        await self.flush()
        remote_traversal = await self.remote_connection.submit(bytecode)
        traversers = remote_traversal.traversers
        side_effects = remote_traversal.side_effects
        result_set = ResultSet(traversers.request_id, traversers._timeout,
                               self._loop)
        self._loop.create_task(self._receive(traversers, result_set))
        return RemoteTraversal(result_set, side_effects)

    async def _receive(self, traversers, result_set):
        """
        Deserialize each result from ``traversers`` and queue it on
        ``result_set`` as a status-200 message.

        An exception is logged and queued as a status-500 message. A ``None``
        end marker is always queued last.
        """
        try:
            async for result in traversers:
                result = await self._deserialize_result(result)
                result_set.queue_result(Message(200, result, ''))
        except Exception as e:
            # logger.exception records the traceback of ``e`` itself, unlike
            # traceback.format_stack(), which shows only this frame's callers.
            logger.exception(
                "Error while receiving traversal results; "
                "forwarding it as a status 500 message")
            # Guard against exceptions raised without arguments.
            message = e.args[0] if e.args else repr(e)
            result_set.queue_result(Message(500, None, message))
        finally:
            result_set.queue_result(None)

    async def _deserialize_result(self, result):
        """
        Map vertices/edges returned by the server onto OGM elements.

        Traversers holding a Vertex or Edge are mapped onto the session's
        cached element (or a new instance of the registered class) and
        stored in :py:attr:`current`. Dicts are deserialized in place and
        lists element-wise. Anything else is returned unchanged.
        """
        if isinstance(result, Traverser):
            obj = result.object
            if not isinstance(obj, (Vertex, Edge)):
                return result
            hashable_id = self._get_hashable_id(obj.id)
            current = self.current.get(hashable_id, None)
            if isinstance(obj, Vertex):
                # NOTE(review): the label is fetched with an extra query; the
                # original author questioned why it is not on the vertex.
                label = await self._g.V(hashable_id).label().next()
                if not current:
                    current = self.app.vertices.get(label, GenericVertex)()
                props = await self._get_vertex_properties(hashable_id, label)
            else:
                # ``obj`` is an Edge (guaranteed by the isinstance check above).
                props = await self._g.E(hashable_id).valueMap(True).next()
                if not current:
                    current = self.app.edges.get(
                        props.get('label'), GenericEdge)()
                    current.source = GenericVertex()
                    current.target = GenericVertex()
            element = current.__mapping__.mapper_func(obj, props, current)
            self.current[hashable_id] = element
            return Traverser(element, result.bulk)
        if isinstance(result, dict):
            # The helper is a coroutine, so each nested call must be awaited.
            for key in result:
                result[key] = await self._deserialize_result(result[key])
            return result
        if isinstance(result, list):
            return [await self._deserialize_result(item) for item in result]
        return result

    async def _get_vertex_properties(self, vid, label):
        """
        Fetch all properties of vertex ``vid`` in the shape used by the
        element mapper.

        :returns: dict with ``'label'`` and ``'id'`` plus one list per
            property key. Each list item is the raw value. When the vertex
            property has meta-properties, the item is instead the
            meta-property dict extended with ``'key'``, ``'value'`` and
            ``'id'``.
        """
        projection = (self._g.V(vid).properties()
                      .project('id', 'key', 'value', 'meta')
                      .by(__.id()).by(__.key()).by(__.value())
                      .by(__.valueMap()))
        props = await projection.toList()
        new_props = {'label': label, 'id': vid}
        for prop in props:
            key = prop['key']
            val = prop['value']
            meta = prop['meta']
            if meta:
                meta['key'] = key
                meta['value'] = val
                meta['id'] = prop['id']
                val = meta
            new_props.setdefault(key, []).append(val)
        return new_props

    # Creation API
    def add(self, *elements):
        """
        Add elements to session pending queue.

        :param goblin.element.Element elements: Elements to be added
        """
        self._pending.extend(elements)

    async def flush(
            self,
            conflicts_query: Optional[GraphTraversal] = None
    ) -> None:
        """
        Issue creation/update queries to database for all elements in the
        session pending queue.

        Elements using optimistic locking are tagged with one transaction id
        for this flush before being saved.

        - Without ``conflicts_query``, the ``dirty`` markers of all tagged
          elements are dropped (commit).
        - With it, the query is the condition of a ``choose`` step evaluated
          once per tagged vertex. If it yields for a vertex, the markers of
          all tagged elements are dropped. Afterwards every element still
          tagged is dropped from the db (rollback).

        Any exception rolls the transaction back and is re-raised. On success
        the local ``dirty`` attribute of the tagged elements is reset.

        :param conflicts_query: Optional anonymous traversal used as the
            commit condition (see above).
        """
        transaction_id = str(uuid.uuid4())
        processed = []
        try:
            while self._pending:
                elem = self._pending.popleft()
                locked = self.__dirty_element(
                    elem, transaction_id=transaction_id)
                result = await self.save(elem)
                if locked:
                    processed.append(result)

            if not processed:
                return
            if conflicts_query is None:
                await self.__commit_transaction(transaction_id)
            else:
                # NOTE(review): the query runs per tagged vertex, so a flush
                # that tagged only edges never reaches ``choose`` and is rolled
                # back below. Also, the name suggests the opposite polarity
                # (commit happens when the query matches) — confirm intent.
                await (self._g.E().has('dirty', transaction_id)
                       .aggregate('x').fold()
                       .V().has('dirty', transaction_id).aggregate('x')
                       .choose(conflicts_query,
                               __.select('x').unfold()
                               .properties('dirty').drop())
                       .iterate())  # type: ignore
                await self.__rollback_transaction(transaction_id)
        except Exception:
            await self.__rollback_transaction(transaction_id)
            raise
        for elem in processed:
            elem.dirty = None

    async def remove_vertex(self, vertex):
        """
        Remove a vertex from the db and from this session's element cache.

        A warning is logged when the vertex is not tracked in
        :py:attr:`current`.

        :param goblin.element.Vertex vertex: Vertex to be removed

        :returns: Result of the drop traversal (falsy when the server
            returns no element)
        """
        traversal = self._g.V(Binding('vid', vertex.id)).drop()
        result = await self._simple_traversal(traversal, vertex)
        hashable_id = self._get_hashable_id(vertex.id)
        try:
            del self.current[hashable_id]
        except KeyError:
            logger.warning('Vertex %s does not belong to this session obj %s',
                           vertex, self)
        return result

    async def remove_edge(self, edge):
        """
        Remove an edge from the db and from this session's element cache.

        A warning is logged when the edge is not tracked in
        :py:attr:`current`.

        :param goblin.element.Edge edge: Element to be removed

        :returns: Result of the drop traversal (falsy when the server
            returns no element)
        """
        traversal = self._g.E(self._edge_id(edge)).drop()
        result = await self._simple_traversal(traversal, edge)
        hashable_id = self._get_hashable_id(edge.id)
        try:
            del self.current[hashable_id]
        except KeyError:
            logger.warning('Edge %s does not belong to this session obj %s',
                           edge, self)
        return result

    async def save(self, elem):
        """
        Save an element to the db.

        :param goblin.element.Element element: Vertex or Edge to be saved

        :returns: :py:class:`Element<goblin.element.Element>` object

        :raises goblin.exception.ElementError: for an unknown element type
        """
        if elem.__type__ == 'vertex':
            result = await self.save_vertex(elem)
        elif elem.__type__ == 'edge':
            result = await self.save_edge(elem)
        else:
            raise exception.ElementError("Unknown element type: {}".format(
                elem.__type__))
        return result

    async def save_vertex(self, vertex):
        """
        Save a vertex to the db and track it in :py:attr:`current`.

        :param goblin.element.Vertex element: Vertex to be saved

        :returns: :py:class:`Vertex<goblin.element.Vertex>` object
        """
        result = await self._save_element(
            vertex, self._check_vertex, self._add_vertex, self._update_vertex)
        hashable_id = self._get_hashable_id(result.id)
        self.current[hashable_id] = result
        return result

    async def save_edge(self, edge):
        """
        Save an edge to the db and track it in :py:attr:`current`.

        :param goblin.element.Edge element: Edge to be saved

        :returns: :py:class:`Edge<goblin.element.Edge>` object

        :raises goblin.exception.ElementError: if the edge lacks a source or
            target vertex
        """
        if not (hasattr(edge, 'source') and hasattr(edge, 'target')):
            raise exception.ElementError(
                "Edges require both source/target vertices")
        result = await self._save_element(edge, self._check_edge,
                                          self._add_edge, self._update_edge)
        hashable_id = self._get_hashable_id(result.id)
        self.current[hashable_id] = result
        return result

    async def get_vertex(self, vertex):
        """
        Get a vertex from the db. Vertex must have id.

        :param goblin.element.Vertex element: Vertex to be retrieved

        :returns: :py:class:`Vertex<goblin.element.Vertex>` | None
        """
        return await self.g.V(Binding('vid', vertex.id)).next()

    async def get_edge(self, edge):
        """
        Get an edge from the db. Edge must have id.

        :param goblin.element.Edge element: Edge to be retrieved

        :returns: :py:class:`Edge<goblin.element.Edge>` | None
        """
        return await self.g.E(self._edge_id(edge)).next()

    @staticmethod
    def _edge_id(edge):
        """
        Return ``edge.id`` in the form passed to ``E()``.

        Dict-shaped ids are wrapped in ``Binding('eid', ...)``; any other id
        is passed as-is.
        """
        eid = edge.id
        if isinstance(eid, dict):
            return Binding('eid', eid)
        return eid

    def __dirty_element(self, elem, transaction_id=None):
        """
        Tag an optimistic-locking element with a transaction id.

        :param elem: Element to tag.
        :param str transaction_id: Id to assign. A fresh UUID4 string is
            generated per call when omitted.

        :returns: The element's transaction id (an already-set ``dirty``
            value is kept), or ``None`` if the element does not use
            optimistic locking.
        """
        if not (elem.__locking__
                and elem.__locking__ == LockingMode.OPTIMISTIC_LOCKING):
            return None
        if elem.dirty:
            return elem.dirty
        # Generated here rather than as a default argument: a default of
        # ``str(uuid.uuid4())`` is evaluated once and shared by every call.
        if transaction_id is None:
            transaction_id = str(uuid.uuid4())
        elem.dirty = transaction_id
        return transaction_id

    async def __commit_transaction(self, transaction_id):
        """Drop the ``dirty`` property of every element tagged with the id."""
        if not transaction_id:
            return
        # Each ``fold()`` collapses its stream to exactly one traverser. This
        # keeps the traversal alive when no edge or vertex matches, and makes
        # ``select('x')`` run once instead of once per tagged vertex.
        await (self._g.E().has('dirty', transaction_id).aggregate('x').fold()
               .V().has('dirty', transaction_id).aggregate('x').fold()
               .select('x').unfold()
               .properties('dirty').drop()
               .iterate())

    async def __rollback_transaction(self, transaction_id):
        """Drop every element still tagged with the id from the db."""
        if not transaction_id:
            return
        # Same single-traverser shape as the commit; tagged edges are
        # aggregated (and therefore dropped) before tagged vertices.
        await (self._g.E().has('dirty', transaction_id).aggregate('x').fold()
               .V().has('dirty', transaction_id).aggregate('x').fold()
               .select('x').unfold()
               .drop()
               .iterate())

    async def _update_vertex(self, vertex):
        """
        Update a vertex, generally to change/remove property values.

        :param goblin.element.Vertex vertex: Vertex to be updated

        :returns: :py:class:`Vertex<goblin.element.Vertex>` object
        """
        props = mapper.map_props_to_db(vertex, vertex.__mapping__)
        traversal = self._g.V(Binding('vid', vertex.id))
        return await self._update_vertex_properties(vertex, traversal, props)

    async def _update_edge(self, edge):
        """
        Update an edge, generally to change/remove property values.

        :param goblin.element.Edge edge: Edge to be updated

        :returns: :py:class:`Edge<goblin.element.Edge>` object
        """
        props = mapper.map_props_to_db(edge, edge.__mapping__)
        traversal = self._g.E(self._edge_id(edge))
        return await self._update_edge_properties(edge, traversal, props)

    # Private helpers for the creation API

    async def _simple_traversal(self, traversal, element):
        """
        Run ``traversal`` and map its first result onto ``element``.

        :returns: the mapped element, or the falsy result unchanged when
            the traversal produced nothing

        :raises goblin.exception.ElementError: for an unknown element type
        """
        elem = await traversal.next()
        if not elem:
            return elem
        hashable_id = self._get_hashable_id(elem.id)
        if element.__type__ == 'vertex':
            # NOTE(review): the label needs an extra round trip here as well.
            label = await self._g.V(hashable_id).label().next()
            props = await self._get_vertex_properties(hashable_id, label)
        elif element.__type__ == 'edge':
            props = await self._g.E(hashable_id).valueMap(True).next()
        else:
            raise exception.ElementError("Unknown element type: {}".format(
                element.__type__))
        return element.__mapping__.mapper_func(elem, props, element)

    async def __handle_create_func(self, elem, create_func):
        """
        Run ``create_func(elem)``.

        The call is wrapped in a commit/rollback when the element uses
        optimistic locking and is not already tagged. Already-tagged
        elements (e.g. tagged by :py:meth:`flush`, which commits itself) are
        created directly.
        """
        if elem.dirty:
            return await create_func(elem)
        transaction_id = self.__dirty_element(elem)
        if not transaction_id:
            return await create_func(elem)
        try:
            result = await create_func(elem)
            await self.__commit_transaction(transaction_id)
            result.dirty = None
        except Exception:
            await self.__rollback_transaction(transaction_id)
            raise
        return result

    async def _save_element(self, elem, check_func, create_func, update_func):
        """
        Create ``elem`` if it has no id or does not exist yet; otherwise
        update it.

        :raises AttributeError: when updating an element whose
            ``__immutable__`` mode is set and not ``ImmutableMode.OFF``
        """
        if not hasattr(elem, 'id') or not await check_func(elem):
            return await self.__handle_create_func(elem, create_func)
        if elem.__immutable__ and elem.__immutable__ != ImmutableMode.OFF:
            raise AttributeError(
                "Trying to update an immutable element: %s" % elem)
        return await update_func(elem)

    async def _add_vertex(self, vertex):
        """Convenience function for generating crud traversals."""
        props = mapper.map_props_to_db(vertex, vertex.__mapping__)
        traversal = self._g.addV(vertex.__mapping__.label)
        return await self._add_properties(traversal, props, vertex)

    async def _add_edge(self, edge):
        """Convenience function for generating crud traversals."""
        props = mapper.map_props_to_db(edge, edge.__mapping__)
        traversal = self._g.V(Binding('sid', edge.source.id))
        traversal = traversal.addE(edge.__mapping__._label)
        traversal = traversal.to(__.V(Binding('tid', edge.target.id)))
        return await self._add_properties(traversal, props, edge)

    async def _check_vertex(self, vertex):
        """Used to check for existence, does not update session vertex"""
        return await self._g.V(Binding('vid', vertex.id)).next()

    async def _check_edge(self, edge):
        """Used to check for existence, does not update session edge"""
        return await self._g.E(self._edge_id(edge)).next()

    async def _update_vertex_properties(self, vertex, traversal, props):
        """Drop all vertex properties, then re-add ``props``."""
        # NOTE(review): the drop is a separate request from the re-add.
        await self._g.V(vertex.id).properties().drop().iterate()
        return await self._add_properties(traversal, props, vertex)

    async def _update_edge_properties(self, edge, traversal, props):
        """Drop all edge properties, then re-add ``props``."""
        await self._g.E(edge.id).properties().drop().iterate()
        return await self._add_properties(traversal, props, edge)

    async def _add_properties(self, traversal, props, elem):
        """
        Append a ``property`` step for every non-None value in ``props``,
        then run the traversal and map the result onto ``elem``.

        :param props: iterable of ``(cardinality, db_name, value, metaprops)``
            tuples; ``None`` values are skipped.
        """
        for card, db_name, val, metaprops in props:
            if val is None:
                continue
            # Flatten {k1: v1, k2: v2} into [k1, v1, k2, v2] for ``property``.
            metas = [item for pair in (metaprops or {}).items()
                     for item in pair]
            if card:
                # list_ and set_ pass through; any other truthy value
                # becomes single.
                if card not in (Cardinality.list_, Cardinality.set_):
                    card = Cardinality.single
                traversal = traversal.property(card, db_name, val, *metas)
            else:
                traversal = traversal.property(db_name, val, *metas)
        return await self._simple_traversal(traversal, elem)