1
|
|
|
from itertools import chain |
2
|
|
|
from typing import Dict, Iterable, Optional, Tuple |
3
|
|
|
|
4
|
|
|
from networkx import Graph |
5
|
|
|
|
6
|
|
|
from .query import prepare_query, search_edges |
7
|
|
|
from .utils import get_first_item, get_second_item |
8
|
|
|
|
9
|
|
|
__all__ = ["search_direct_relationships", "join_relationship"] |
10
|
|
|
|
11
|
|
|
|
12
|
|
|
def search_direct_relationships(
    graph: Graph, source: Optional[Dict] = None, edge: Optional[Dict] = None, target: Optional[Dict] = None
) -> Iterable[Tuple]:
    """Search the edges of ``graph`` that satisfy optional source, edge and target constraints.

    Arguments:
        graph (Graph): graph instance
        source (Optional[Dict]): optional query constraint on the attributes of each edge's first node
        edge (Optional[Dict]): optional query constraint on the edges themselves
        target (Optional[Dict]): optional query constraint on the attributes of each edge's second node

    Returns:
        (Iterable[Tuple]): iterable of matching edges; node constraints are applied lazily via ``filter``
    """
    # A falsy edge query (None or an empty dict) means "start from every edge of the graph".
    if edge:
        candidates = search_edges(graph=graph, query=edge)
    else:
        candidates = graph.edges()

    # Each node constraint is compiled once, then checked per edge against the node's attributes.
    if source:
        source_matches = prepare_query(source)
        candidates = filter(lambda pair: source_matches(graph.nodes[pair[0]]), candidates)

    if target:
        target_matches = prepare_query(target)
        candidates = filter(lambda pair: target_matches(graph.nodes[pair[1]]), candidates)

    return candidates
38
|
|
|
|
39
|
|
|
|
40
|
|
|
def join_relationship(
    graph: Graph, source: Iterable[Tuple], target: Iterable[Tuple], join_on_source_origin: Optional[bool] = True
) -> Iterable[Tuple]:  # pragma: no cover
    """Join two relationships (edge collections) on a shared node.

    With source edges ``(a, b)`` and target edges ``(c, d)``:

    * if ``join_on_source_origin`` is truthy, a join node ``e`` is the origin of both a
      source edge ``(e, _)`` and a target edge ``(e, _)``;
    * otherwise ``e`` is the destination of a source edge ``(_, e)`` and the origin of a
      target edge ``(e, _)``.

    Arguments:
        graph (Graph): graph instance (not used by the join itself)
        source (Iterable[Tuple]): source edges
        target (Iterable[Tuple]): target edges
        join_on_source_origin (Optional[bool]): join on the origin (truthy) or the
            destination (falsy) of the source edges

    Returns:
        (Iterable[Tuple]): the source edges whose join endpoint is a join node, followed by
        the target edges whose origin is a join node. Duplicate edges are dropped and each
        input keeps its first-seen order.
    """
    # dict.fromkeys de-duplicates like a set but preserves insertion order, so the output
    # order is deterministic instead of depending on hash ordering. Edges must be hashable.
    _source = list(dict.fromkeys(source))
    _target = list(dict.fromkeys(target))

    # NOTE(review): assumes get_first_item / get_second_item return edge[0] / edge[1],
    # as the (e, _) / (_, e) notation above implies.
    _source_key = get_first_item if join_on_source_origin else get_second_item

    # The join keys are *nodes*: project every edge onto its join endpoint with map().
    # Using filter() here kept whole edges instead, so the node membership tests below
    # compared nodes against edges and effectively never matched.
    _nodes = set(map(_source_key, _source)).intersection(map(get_first_item, _target))

    return chain(
        (edge for edge in _source if _source_key(edge) in _nodes),
        (edge for edge in _target if get_first_item(edge) in _nodes),
    )
61
|
|
|
|