|
1
|
|
|
import functools |
|
2
|
|
|
from collections.abc import Mapping, Sequence |
|
3
|
|
|
from typing import Union |
|
4
|
|
|
|
|
5
|
|
|
import matplotlib as mpl |
|
6
|
|
|
import networkx as nx |
|
7
|
|
|
import numpy as np |
|
8
|
|
|
|
|
9
|
|
|
|
|
10
|
|
|
def node_color_mapping(graph: nx.Graph, cmap: Union[str, mpl.colors.Colormap] = "tab20") -> Mapping:
    """Map each node of ``graph`` to an RGBA color derived from its type.

    Args:
        graph: The input graph. Type names are taken from the keys of the
            graph-level ``graph.graph['node_types']`` mapping, and each node's
            type from its ``'type'`` attribute.
        cmap: Colormap name or :class:`matplotlib.colors.Colormap` instance used
            to turn type indices into colors. Default is ``"tab20"``.

    Returns:
        A new dict mapping every node to its RGBA color, given as a list of four
        floats. An empty dict is returned for a graph without nodes.

    .. note::
        Each key of ``node_types`` gets an index in insertion order. When there
        is more than one key, the generic ``'node'`` key is dropped so that it
        does not take up a color. Nodes whose ``'type'`` is missing or not among
        the remaining keys get index 0, so they share the first type's color.
        Indices are normalised to the range actually present among the nodes
        before being passed through the colormap.

        The result is recomputed on every call. Caching keyed on the graph
        object would return stale colors after the graph is mutated, and
        ``Colormap`` instances are unhashable, so they could not be cached.
    """
    if not graph.nodes:
        return {}

    # ``or {}`` also tolerates an explicit ``node_types=None`` graph attribute.
    node_type_keys = list((graph.graph.get('node_types') or {}).keys())

    if len(node_type_keys) > 1 and 'node' in node_type_keys:
        # Keep the order, but drop the generic 'node' type so it gets no color slot.
        node_type_keys = [k for k in node_type_keys if k != 'node']

    type_lookup = {t: i for i, t in enumerate(node_type_keys)}

    # One integer per node, in graph.nodes iteration order; unknown types -> 0.
    color_values = np.fromiter(
        (type_lookup.get(data.get('type'), 0) for _, data in graph.nodes(data=True)),
        dtype=int,
        count=len(graph),
    )

    # The array is non-empty here, so min/max are defined. When all values are
    # equal (including the single-node case), vmin == vmax and Normalize maps
    # every value to 0.
    norm = mpl.colors.Normalize(vmin=color_values.min(), vmax=color_values.max(), clip=True)
    mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    colors = mapper.to_rgba(color_values).tolist()

    return dict(zip(graph.nodes, colors))
|
53
|
|
|
|
|
54
|
|
|
|
|
55
|
|
|
def color_hex(color: Union[str, Sequence[Union[float, int]]]) -> Union[str, Sequence[Union[float, int]]]:
    """Get the HEX color code ``'#rrggbb'`` for an RGB(A) tuple or list.

    Args:
        color: Either a tuple/list whose first three items are all floats in
            ``[0, 1]`` or all ints in ``[0, 255]`` (further items such as alpha
            are ignored), or any other value, which is returned unchanged.

    Returns:
        The lowercase HEX color code for tuple/list input; otherwise ``color``
        itself.

    Raises:
        ValueError: If a tuple/list has fewer than three items, or its first
            three items are not all floats in ``[0, 1]`` or all ints in
            ``[0, 255]``.

    .. note::
        Float channels are scaled with rounding (``0.5 -> 0x80``,
        ``0.6 -> 0x99``), as :func:`matplotlib.colors.to_hex` does. Truncation
        would shift many values down by one because of floating-point error
        (``0.6 * 255 == 152.99999999999997``).
        Mixed float/int channels are rejected, so ``(1, 1, 1)`` is read as the
        0-255 triple ``#010101``, not as white.
    """
    if not isinstance(color, (tuple, list)):
        return color

    rgb = color[:3]
    if len(rgb) < 3:
        msg = f"Expected at least three color channels, got {len(rgb)}"
        raise ValueError(msg)

    if all(isinstance(c, float) and 0 <= c <= 1 for c in rgb):
        # int() guards against numpy scalars whose round() may return a float.
        r, g, b = (int(round(c * 255)) for c in rgb)
    elif all(isinstance(c, int) and 0 <= c <= 255 for c in rgb):
        r, g, b = rgb
    else:
        msg = "Input values should either be a float between 0 and 1 or an int between 0 and 255"
        raise ValueError(msg)

    return f'#{r:02x}{g:02x}{b:02x}'
|
83
|
|
|
|
|
84
|
|
|
|
|
85
|
|
|
def convert_colors_to_hex(graph: nx.Graph, color: str = 'color') -> None:
    """Rewrite a node color attribute of ``graph`` as HEX codes, in place.

    Args:
        graph (nx.Graph): Graph whose node attributes are updated.
        color (str): Name of the node attribute holding the color. Default is 'color'.

    Returns:
        None: ``graph`` is modified in place.

    .. note::
        Nodes that lack the ``color`` attribute are skipped. Every existing value
        is passed through :func:`color_hex`, so tuple/list colors become HEX
        strings and other values (e.g. color names) are written back unchanged.
        All conversions are computed before any attribute is written, so if
        :func:`color_hex` raises, the graph is left untouched.
    """
    hex_by_node = {}
    for node, attrs in graph.nodes(data=True):
        if color not in attrs:
            continue
        hex_by_node[node] = color_hex(attrs[color])

    nx.set_node_attributes(graph, values=hex_by_node, name=color)
|
101
|
|
|
|