Code Duplication    Length = 30-30 lines in 2 locations

benchmarks/color_benchmark.py 2 locations

@@ 45-74 (lines=30) @@
42
43
44
# --- Alternative Implementation (No NumPy) ---
45
@functools.lru_cache
def node_color_mapping_pure(graph: nx.Graph, cmap: Union[str, mpl.colors.Colormap] = "tab20") -> Mapping:
    """Build a node -> RGBA colour mapping for ``graph`` using only plain Python lists.

    This is the NumPy-free alternative measured by the benchmark.

    How colours are assigned:
    - Declared node types are read from ``graph.graph['node_types']``.
    - Their iteration order assigns each type an integer index.
    - The catch-all ``'node'`` type is skipped whenever other types are declared
      alongside it.
    - A node whose ``'type'`` attribute is missing or undeclared gets index 0.
    - The indices are scaled with a clipping ``Normalize`` spanning the smallest
      and largest index observed. Both bounds are 0 for a single-node graph.
    - The scaled values are then looked up in ``cmap``.

    Args:
        graph: Graph whose nodes are coloured.
        cmap: Matplotlib colormap instance or colormap name.

    Returns:
        A dict keyed by node, in ``graph.nodes`` order. Each value is the RGBA
        entry produced by ``ScalarMappable.to_rgba``, converted to a list.
        Returns an empty dict when the graph has no nodes.

    NOTE(review): results are memoised by ``functools.lru_cache``:
    - Every call with an equal ``(graph, cmap)`` key receives the *same* dict
      object, so mutating the returned dict changes later results.
    - Presumably ``nx.Graph`` hashes by identity, so edits to the graph after the
      first call are not reflected (confirm).
    - ``cmap`` must be hashable (confirm for Colormap instances).
    - Repeated benchmark timings on one graph will measure cache hits unless the
      cache is cleared between runs.
    """
    node_view = graph.nodes
    if not node_view:
        return {}

    declared = graph.graph.get('node_types', {}).keys()
    skip_generic = len(declared) > 1 and 'node' in declared
    kept = [name for name in declared if not skip_generic or name != 'node']
    index_by_type = dict(zip(kept, range(len(kept))))

    type_indices = [index_by_type.get(node_view[n].get('type'), 0) for n in node_view]

    if len(type_indices) > 1:
        bounds = (min(type_indices), max(type_indices))
    else:
        bounds = (0, 0)

    scalar_map = mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=bounds[0], vmax=bounds[1], clip=True),
        cmap=cmap,
    )
    # to_rgba is given the plain list directly; its result is turned back into
    # nested Python lists with .tolist().
    rgba_rows = scalar_map.to_rgba(type_indices).tolist()
    return dict(zip(node_view, rgba_rows))
75
76
77
def run_benchmark():
@@ 12-41 (lines=30) @@
9
10
11
# --- Original Implementation ---
12
@functools.lru_cache
def node_color_mapping_numpy(graph: nx.Graph, cmap: Union[str, mpl.colors.Colormap] = "tab20") -> Mapping:
    """Build a node -> RGBA colour mapping for ``graph`` from a NumPy index array.

    This is the baseline ("original") implementation in the benchmark.

    How colours are assigned:
    - Declared node types come from ``graph.graph['node_types']``.
    - Their iteration order fixes each type's integer index.
    - The generic ``'node'`` type is dropped when more than one type is declared
      and ``'node'`` is among them.
    - Nodes with a missing or undeclared ``'type'`` fall back to index 0.
    - The index array is scaled by a clipping ``Normalize`` over its observed
      min/max. Both bounds are 0 when the graph has a single node.
    - The scaled values are mapped through ``cmap``.

    Args:
        graph: Graph whose nodes are coloured.
        cmap: Matplotlib colormap instance or colormap name.

    Returns:
        A dict keyed by node, in ``graph.nodes`` order. Each value is the RGBA
        entry produced by ``ScalarMappable.to_rgba``, converted to a list.
        Returns an empty dict when the graph has no nodes.

    NOTE(review): results are memoised by ``functools.lru_cache``:
    - Repeated calls with an equal ``(graph, cmap)`` key return the same dict
      object.
    - Presumably ``nx.Graph`` hashes by identity, so later graph mutations are
      not seen (confirm).
    - ``cmap`` must be hashable (confirm for Colormap instances).
    - Benchmark loops that reuse one graph will time cache hits unless the cache
      is cleared.
    """
    nodes = graph.nodes
    if not nodes:
        return {}

    declared_types = graph.graph.get('node_types', {}).keys()
    drop_generic = len(declared_types) > 1 and 'node' in declared_types
    ordered_types = [t for t in declared_types if not drop_generic or t != 'node']
    index_of = {name: idx for idx, name in enumerate(ordered_types)}

    # Passing count= lets fromiter allocate the output array once up front.
    type_indices = np.fromiter(
        (index_of.get(nodes[n].get('type'), 0) for n in nodes),
        dtype=int,
        count=len(graph),
    )

    low, high = (
        (type_indices.min(), type_indices.max())
        if len(type_indices) > 1
        else (0, 0)
    )
    scalar_map = mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=low, vmax=high, clip=True),
        cmap=cmap,
    )
    rgba_rows = scalar_map.to_rgba(type_indices).tolist()
    return dict(zip(nodes, rgba_rows))
42
43
44
# --- Alternative Implementation (No NumPy) ---