| 1 |  |  | import functools | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | from collections.abc import Mapping, Sequence | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | from typing import Union | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | import matplotlib as mpl | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | import networkx as nx | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def node_color_mapping(graph: nx.Graph, cmap: Union[str, mpl.colors.Colormap] = "tab20") -> Mapping:
                    """Map every node of a graph to an RGBA color based on its node type.

                    Each type key in the graph-level ``'node_types'`` attribute gets an integer index,
                    in the dict's iteration order. Each node's ``'type'`` attribute is looked up to get
                    its index. The indices are normalized to the observed range and passed through
                    ``cmap``.

                    Args:
                        graph: nx.Graph - The input graph for which node colors need to be mapped.
                        cmap: Union[str, mpl.colors.Colormap], optional - The colormap used to map values to RGBA colors.
                              Default is "tab20".

                    Returns:
                        Mapping - A new dictionary mapping nodes to their corresponding RGBA colors based on the colormap.

                    .. note::
                        - If ``'node_types'`` is missing, every node gets index 0.
                        - A generic ``'node'`` type key is ignored when other types are present.
                        - Nodes whose ``'type'`` is missing or unknown get index 0, i.e. the same
                          color as the first type.
                        - The graph is not modified.
                        - Results are not cached. The mapping is recomputed on every call, so it
                          always reflects the current graph. This also avoids requiring ``cmap``
                          to be hashable.
                    """
                    # Work on a copy of the type keys so the graph's own attribute is never mutated.
                    type_names = list(graph.graph.get('node_types', {}))

                    # Drop the catch-all 'node' type when more specific types exist.
                    if len(type_names) > 1 and 'node' in type_names:
                        type_names.remove('node')

                    type_lookup = {t: i for i, t in enumerate(type_names)}
                    color_lookup = {node: type_lookup.get(data.get('type'), 0) for node, data in graph.nodes.data()}

                    if len(color_lookup) > 1:
                        low, high = min(color_lookup.values()), max(color_lookup.values())
                    else:
                        low = high = 0
                    norm = mpl.colors.Normalize(vmin=low, vmax=high, clip=True)
                    mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
                    return {n: mapper.to_rgba(i) for n, i in color_lookup.items()}
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 40 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                def color_hex(color: Union[str, Sequence[Union[float, int]]]) -> Union[str, Sequence[Union[float, int]]]:
                    """Get the HEX color code for an RGB(A) tuple or list.

                    Args:
                        color: Input color. A tuple or list is interpreted as RGB(A) components.
                            Any other value, such as a color name or an existing HEX string, is
                            returned unchanged.

                    Returns:
                        A ``'#rrggbb'`` string for tuple/list input, otherwise ``color`` as-is.

                    Raises:
                        ValueError: If a tuple/list has fewer than three components.
                        ValueError: If its first three components are neither all floats in
                            [0, 1] nor all ints in [0, 255].

                    .. note::
                        - Only the first three components are used, so an alpha channel is ignored.
                        - Float components are scaled by 255 and rounded to the nearest integer,
                          the convention used by ``matplotlib.colors.to_hex``. Truncating instead
                          would map 0.5 to 127 rather than 128, and would be sensitive to
                          floating-point error just below an integer.
                        - Numeric types registered with the :mod:`numbers` ABCs are accepted,
                          e.g. NumPy scalars.
                    """
                    import numbers

                    if not isinstance(color, (tuple, list)):  # noqa: UP038
                        return color

                    rgb = color[:3]
                    if len(rgb) < 3:
                        msg = "Input color should have at least three (RGB) components"
                        raise ValueError(msg)

                    # numbers.Real also matches ints (and bools), so exclude Integral to keep the
                    # float and int interpretations distinct.
                    if all(isinstance(c, numbers.Real) and not isinstance(c, numbers.Integral) and 0 <= c <= 1 for c in rgb):
                        # int() guarantees a Python int even if round() yields a float-typed scalar,
                        # which the 'x' format code would reject.
                        channels = [int(round(c * 255)) for c in rgb]
                    elif all(isinstance(c, numbers.Integral) and 0 <= c <= 255 for c in rgb):
                        channels = [int(c) for c in rgb]
                    else:
                        msg = "Input values should either be a float between 0 and 1 or an int between 0 and 255"
                        raise ValueError(msg)

                    return '#{:02x}{:02x}{:02x}'.format(*channels)
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def convert_colors_to_hex(graph: nx.Graph, color: str = 'color') -> None:
                    """Rewrite a node color attribute in the graph as hexadecimal strings, in place.

                    Args:
                        graph (nx.Graph): The graph whose node attributes are rewritten.
                        color (str): Name of the node attribute holding the color. Default is 'color'.

                    Returns:
                        None: The graph is modified in place.

                    .. note::
                        - Nodes that do not carry the ``color`` attribute are skipped and left
                          unchanged.
                        - Every present value is passed through ``color_hex``. Tuples/lists become
                          ``'#rrggbb'`` strings; other values are kept as they are.
                        - All conversions happen before the graph is written to. A ``ValueError``
                          from ``color_hex`` therefore leaves the graph untouched.
                    """
                    hex_by_node = {}
                    for node, attrs in graph.nodes(data=True):
                        if color not in attrs:
                            continue
                        hex_by_node[node] = color_hex(attrs[color])

                    nx.set_node_attributes(graph, values=hex_by_node, name=color)
            
                                                        
            
                                    
            
            
                | 86 |  |  |  |