graphinate.color   A
last analyzed

Complexity

Total Complexity 11

Size/Duplication

Total Lines 86
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 11
eloc 35
dl 0
loc 86
rs 10
c 0
b 0
f 0

3 Functions

Rating   Name   Duplication   Size   Complexity  
B color_hex() 0 27 6
A node_color_mapping() 0 30 4
A convert_colors_to_hex() 0 16 1
1
import functools
2
from collections.abc import Mapping, Sequence
3
from typing import Union
4
5
import matplotlib as mpl
6
import networkx as nx
7
8
9
@functools.lru_cache
def node_color_mapping(graph: nx.Graph, cmap: Union[str, mpl.colors.Colormap] = "tab20") -> Mapping:
    """Map each node of a graph to an RGBA color chosen by the node's type.

    Args:
        graph: nx.Graph - The input graph whose nodes need colors. Type names are
            read from ``graph.graph['node_types']`` (optional) and each node's
            type from its ``'type'`` attribute.
        cmap: Union[str, mpl.colors.Colormap], optional - The colormap used to map
            type indices to RGBA colors. Default is "tab20".

    Returns:
        Mapping - A dictionary mapping every node to its RGBA color.

    .. note::
        Each key of ``graph.graph['node_types']`` gets an index in key order. The
        generic ``'node'`` key is skipped when other types exist. Nodes whose type
        is unknown (or when ``node_types`` is missing) get index 0. The indices are
        normalized to the range between the smallest and largest index in use, then
        passed through ``cmap``.

        The graph's ``node_types`` metadata is not modified.

        Results are memoized by ``functools.lru_cache`` keyed on ``(graph, cmap)``.
        Both arguments must therefore be hashable. A graph that is modified after
        the first call keeps getting the cached mapping. The returned dict is shared
        between cache hits, so do not mutate it.
    """
    node_types = graph.graph.get('node_types', {})

    # Filter a local copy of the type names rather than popping from the
    # graph's own metadata, so calling this function has no side effects.
    type_names = list(node_types)
    if len(type_names) > 1 and 'node' in type_names:
        type_names.remove('node')

    type_lookup = {t: i for i, t in enumerate(type_names)}
    color_lookup = {node: type_lookup.get(data.get('type'), 0) for node, data in graph.nodes.data()}

    if len(color_lookup) > 1:
        indices = color_lookup.values()
        low, high = min(indices), max(indices)
    else:
        # With zero or one node there is no range to normalize over.
        low = high = 0

    norm = mpl.colors.Normalize(vmin=low, vmax=high, clip=True)
    mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    return {n: mapper.to_rgba(i) for n, i in color_lookup.items()}
39
40
41
def color_hex(color: Union[str, Sequence[Union[float, int]]]) -> Union[str, Sequence[Union[float, int]]]:
    """Get the HEX color code for an RGB(A) tuple or list.

    Args:
        color: Input color. A tuple or list of color components is converted.
            Any other value (e.g. an already-formatted color string) is
            returned unchanged.

    Returns:
        A ``'#rrggbb'`` string when ``color`` is a tuple or list; otherwise
        ``color`` itself.

    Raises:
        ValueError: If a tuple or list has fewer than three components, or if its
            first three components are not all floats in [0, 1] or all ints in
            [0, 255].

    .. note::
        Only the first three components are used, so an alpha channel is ignored.

        Float components are scaled by 255 and rounded to the nearest integer (as
        matplotlib's ``to_hex`` does), not truncated. This means small
        floating-point errors after scaling cannot shift a channel down by one.
    """
    if isinstance(color, (tuple, list)):  # noqa: UP038
        rgb = color[:3]

        # Without this check a short sequence fails later with an IndexError
        # inside str.format instead of the documented ValueError.
        if len(rgb) < 3:
            msg = "Input should contain at least three color components (red, green, blue)"
            raise ValueError(msg)

        if all(isinstance(c, float) and 0 <= c <= 1 for c in rgb):
            rgb = tuple(round(c * 255) for c in rgb)
        elif all(isinstance(c, int) and 0 <= c <= 255 for c in rgb):
            rgb = tuple(rgb)
        else:
            msg = "Input values should either be a float between 0 and 1 or an int between 0 and 255"
            raise ValueError(msg)

        return '#{:02x}{:02x}{:02x}'.format(*rgb)

    return color
68
69
70
def convert_colors_to_hex(graph: nx.Graph, color: str = 'color') -> None:
    """Rewrite a node color attribute in place as HEX strings.

    Every node that carries the ``color`` attribute has its value passed
    through ``color_hex``. Nodes without the attribute are left untouched.

    Args:
        graph (nx.Graph): Graph whose node attributes are updated in place.
        color (str): Name of the node attribute holding the color. Default is 'color'.

    Returns:
        None: The graph is modified in place.

    Raises:
        ValueError: Propagated from ``color_hex`` when a tuple/list color has
            invalid components.
    """
    hex_by_node = {}
    for node, attributes in graph.nodes(data=True):
        if color not in attributes:
            continue
        hex_by_node[node] = color_hex(attributes[color])

    nx.set_node_attributes(graph, values=hex_by_node, name=color)
86