|
1
|
|
|
import functools |
|
2
|
|
|
from collections.abc import Mapping, Sequence |
|
3
|
|
|
from typing import Union |
|
4
|
|
|
|
|
5
|
|
|
import matplotlib as mpl |
|
6
|
|
|
import networkx as nx |
|
7
|
|
|
|
|
8
|
|
|
|
|
9
|
|
|
@functools.lru_cache
def node_color_mapping(graph: nx.Graph, cmap: Union[str, mpl.colors.Colormap] = "tab20") -> Mapping:
    """Map every node of ``graph`` to an RGBA color derived from its node type.

    Each key of the graph-level ``'node_types'`` attribute is assigned an integer
    index (in iteration order). Every node is then given the index of its
    ``'type'`` data attribute, and the indices are normalized and passed through
    the colormap.

    Args:
        graph: Graph whose nodes may carry a ``'type'`` data attribute and whose
            ``graph.graph`` dictionary may carry a ``'node_types'`` entry.
        cmap: Colormap name or colormap object handed to
            ``mpl.cm.ScalarMappable``. Defaults to ``"tab20"``.

    Returns:
        A dictionary mapping each node to the RGBA value returned by
        ``ScalarMappable.to_rgba`` for that node's type index.

    .. note::
        * A graph without a ``'node_types'`` attribute is treated as having no
          declared types; every node then falls back to index 0.
        * Nodes whose ``'type'`` is missing or not declared also use index 0.
        * When more than one type is declared, the generic ``'node'`` type is
          ignored for indexing. The graph's ``'node_types'`` attribute itself is
          left unmodified.
        * Results are memoized by ``functools.lru_cache``, so both arguments must
          be hashable.
          NOTE(review): presumably graphs hash by identity, in which case later
          changes to the same graph object are not reflected - confirm.
    """
    declared_types = graph.graph.get('node_types', {})

    # Ignore the generic 'node' entry when more specific types exist. Build a
    # filtered list instead of popping from the mapping so the caller's graph
    # metadata is not altered by a color lookup.
    skip_generic = len(declared_types) > 1 and 'node' in declared_types
    type_names = [t for t in declared_types if not (skip_generic and t == 'node')]

    type_lookup = {t: i for i, t in enumerate(type_names)}
    color_lookup = {node: type_lookup.get(data.get('type'), 0) for node, data in graph.nodes.data()}

    # With zero or one node there is no range to normalize over.
    if len(color_lookup) > 1:
        low = min(color_lookup.values())
        high = max(color_lookup.values())
    else:
        low = high = 0

    norm = mpl.colors.Normalize(vmin=low, vmax=high, clip=True)
    mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    node_colors = {n: mapper.to_rgba(i) for n, i in color_lookup.items()}
    return node_colors
|
39
|
|
|
|
|
40
|
|
|
|
|
41
|
|
|
def color_hex(color: Union[str, Sequence[Union[float, int]]]) -> Union[str, Sequence[Union[float, int]]]:
    """Get HEX color code.

    Args:
        color: Input color. A tuple or list is interpreted as RGB(A) components;
            any other value (for example a color name or an existing HEX string)
            is returned unchanged.

    Returns:
        The ``'#rrggbb'`` HEX code for a tuple/list input, otherwise ``color``
        itself.

    Raises:
        ValueError: If a tuple/list has fewer than three components, or if its
            first three components are neither all floats in [0, 1] nor all ints
            in [0, 255].

    .. note::
        Only the first three components are used; a fourth (alpha) component is
        ignored. Float components are scaled by 255 and rounded to the nearest
        integer, so values such as 0.999 map to ``ff`` rather than being
        truncated to ``fe``.
    """
    # Guard clause: only tuples and lists are converted.
    if not isinstance(color, (tuple, list)):  # noqa: UP038
        return color

    rgb = color[:3]
    # Fail early with a clear message instead of an IndexError from str.format.
    if len(rgb) < 3:
        msg = "Input color should contain at least three components (red, green, blue)"
        raise ValueError(msg)

    if all(isinstance(c, float) and 0 <= c <= 1 for c in rgb):
        # Round instead of truncating; the outer int() guarantees a plain int
        # even if a float subclass returns a non-int from round().
        rgb = tuple(int(round(c * 255)) for c in rgb)
    elif all(isinstance(c, int) and 0 <= c <= 255 for c in rgb):
        rgb = tuple(rgb)
    else:
        msg = "Input values should either be a float between 0 and 1 or an int between 0 and 255"
        raise ValueError(msg)

    return '#{:02x}{:02x}{:02x}'.format(*rgb)
|
68
|
|
|
|
|
69
|
|
|
|
|
70
|
|
|
def convert_colors_to_hex(graph: nx.Graph, color: str = 'color') -> None:
    """Rewrite the color attribute of the graph's nodes as HEX codes, in place.

    Args:
        graph (nx.Graph): Graph whose node data is inspected and updated.
        color (str): Name of the node attribute holding the color.
            Default is 'color'.

    Returns:
        None: The converted values are written back onto the graph.

    .. note::
        Nodes that do not carry the attribute are skipped and left untouched.
        Each present value is passed through ``color_hex``.
    """
    hex_by_node = {}
    for node, attrs in graph.nodes(data=True):
        if color not in attrs:
            continue
        hex_by_node[node] = color_hex(attrs[color])

    nx.set_node_attributes(graph, values=hex_by_node, name=color)
|
86
|
|
|
|