graphinate.color.color_hex()   B
last analyzed

Complexity

Conditions 6

Size

Total Lines 28
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 28
rs 8.6666
c 0
b 0
f 0
cc 6
nop 1
1
import functools
2
from collections.abc import Mapping, Sequence
3
from typing import Union
4
5
import matplotlib as mpl
6
import networkx as nx
7
import numpy as np
8
9
10
@functools.lru_cache
def node_color_mapping(graph: nx.Graph, cmap: Union[str, mpl.colors.Colormap] = "tab20") -> Mapping:
    """Assign an RGBA color to every node of ``graph`` according to its type.

    Each node's ``'type'`` attribute is looked up among the ordered keys of
    ``graph.graph['node_types']``. The resulting integer indices are
    normalized over their observed range and passed through ``cmap``.

    Args:
        graph: The graph whose nodes are colored.
        cmap: A colormap name or a matplotlib ``Colormap`` instance.
            Defaults to "tab20".

    Returns:
        A dict mapping each node to its RGBA color, given as a list of four
        floats. An empty graph yields an empty dict.

    .. note::
        - The generic ``'node'`` type key is ignored whenever other type keys
          are declared.
        - Nodes with no ``'type'`` attribute, or with a type not listed in
          ``node_types``, get index 0. They share the color of nodes of the
          first listed type.
        - Results are memoized by ``functools.lru_cache``. As a consequence:
          both arguments must be hashable; mutating ``graph`` after a call is
          not reflected in later calls with the same graph object; and the
          returned dict is shared between those cached calls.
    """
    if not graph.nodes:
        return {}

    type_names = list(graph.graph.get('node_types', {}).keys())
    # Drop the generic 'node' type, but only when more specific types exist.
    if 'node' in type_names and len(type_names) > 1:
        type_names = [name for name in type_names if name != 'node']
    index_by_type = {name: index for index, name in enumerate(type_names)}

    # One integer per node, in graph.nodes order; unknown/missing types -> 0.
    type_indices = np.fromiter(
        (
            index_by_type.get(attributes.get('type'), 0)
            for _, attributes in graph.nodes(data=True)
        ),
        dtype=int,
        count=len(graph),
    )

    # A single node gets a degenerate [0, 0] range.
    if type_indices.size > 1:
        lowest, highest = type_indices.min(), type_indices.max()
    else:
        lowest = highest = 0

    mapper = mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=lowest, vmax=highest, clip=True),
        cmap=cmap,
    )
    rgba_per_node = mapper.to_rgba(type_indices).tolist()
    return dict(zip(graph.nodes, rgba_per_node))
53
54
55
def color_hex(color: Union[str, Sequence[Union[float, int]]]) -> Union[str, Sequence[Union[float, int]]]:
    """Convert an RGB(A) tuple or list to a ``#rrggbb`` HEX color code.

    Args:
        color: A color value. Tuples and lists are treated as RGB(A)
            components. Any other value, e.g. a string that is already a color
            name or HEX code, is returned unchanged.

    Returns:
        The ``#rrggbb`` HEX code for tuple/list input, otherwise ``color``
        itself.

    Raises:
        ValueError: If a tuple/list has fewer than three components. Also
            raised if its first three components are not all floats in
            [0, 1] or all ints in [0, 255].

    .. note::
        Only the first three components are used; an alpha channel is ignored.
        Float components are scaled by 255 and rounded to the nearest integer,
        the same convention as ``matplotlib.colors.to_hex``. For example, 0.5
        maps to ``80``. Python and numpy scalar types are both accepted.
    """
    if not isinstance(color, (tuple, list)):
        return color

    rgb = color[:3]
    # Without this check, short input passes the all() test (vacuously for an
    # empty list) and only fails at tuple unpacking with an opaque message.
    if len(rgb) < 3:
        msg = "Input should contain at least three components (red, green, blue)"
        raise ValueError(msg)

    if all(isinstance(c, (float, np.floating)) and 0 <= c <= 1 for c in rgb):
        # Round rather than truncate: int() would map 0.5 to 0x7f and bias
        # every channel downward.
        r, g, b = (round(float(c) * 255) for c in rgb)
    elif all(isinstance(c, (int, np.integer)) and 0 <= c <= 255 for c in rgb):
        # Normalize numpy integers to plain ints before formatting.
        r, g, b = (int(c) for c in rgb)
    else:
        msg = "Input values should either be a float between 0 and 1 or an int between 0 and 255"
        raise ValueError(msg)

    return f'#{r:02x}{g:02x}{b:02x}'
83
84
85
def convert_colors_to_hex(graph: nx.Graph, color: str = 'color') -> None:
    """Rewrite every node's color attribute in place using ``color_hex``.

    Args:
        graph (nx.Graph): The graph whose node attributes are updated.
        color (str): Name of the node attribute holding the color.
            Default is 'color'.

    Returns:
        None: ``graph`` is modified in place.

    .. note::
        Nodes that lack the ``color`` attribute are skipped and left untouched.
        All conversions are computed before any attribute is written. If
        ``color_hex`` raises for some node, no node is modified.
    """
    hex_by_node = {}
    for node, attributes in graph.nodes(data=True):
        if color not in attributes:
            continue
        hex_by_node[node] = color_hex(attributes[color])

    nx.set_node_attributes(graph, values=hex_by_node, name=color)
101