generate_sankey_diagram_data.parse_markdown_list()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 13
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 10
nop 1
dl 0
loc 13
rs 9.9
c 0
b 0
f 0
1
from ast import literal_eval
2
from collections import defaultdict
3
from typing import Any
4
5
import pandas as pd
6
7
# future: use plotly.graph_objects instead of Flourish
8
9
10
def parse_markdown_list(
    markdown: str,
) -> defaultdict[Any, dict[str, int | defaultdict]]:
    """Build a nested category tree from the text of a markdown list.

    The text is stripped, split on newlines, passed through
    ``skip_frontmatter`` and then fed line by line to ``parse_line``, which
    mutates the tree and the shared nesting stack.

    Args:
        markdown: Full text of the markdown document.

    Returns:
        A ``defaultdict`` mapping each top-level item to a node of the form
        ``{"count": int, "children": mapping}``.
    """

    def _empty_node() -> dict[str, int | defaultdict]:
        # Factory used by the root mapping for keys that are accessed
        # before being explicitly inserted.
        return {"count": 0, "children": defaultdict()}

    tree: defaultdict[Any, dict[str, int | defaultdict]] = defaultdict(_empty_node)
    # Nodes on the path from the root to the most recently parsed item.
    open_nodes: list[dict[str, int] | defaultdict] = []

    for raw_line in skip_frontmatter(markdown.strip().split("\n")):
        parse_line(raw_line, tree, open_nodes)
    return tree
23
24
25
def parse_line(
26
    line: str,
27
    root: defaultdict[Any, dict[str, int | defaultdict]],
28
    stack: list[dict[str, int] | defaultdict],
29
) -> None:
30
    indent_level = (len(line) - len(line.lstrip())) // 4
31
    item = line.strip("- ").strip()
32
    while len(stack) > indent_level:
33
        stack.pop()
34
    if stack:
35
        current = stack[-1]["children"]
36
    else:
37
        current = root
38
    if item not in current:
39
        current[item] = {"count": 0, "children": defaultdict()}
40
    current[item]["count"] += 1
41
    stack.append(current[item])
42
43
44
def skip_frontmatter(lines: list[str]) -> list[str]:
    """Drop a leading ``---`` delimited front matter section, if present.

    Args:
        lines: Lines of the document.

    Returns:
        The lines following the closing ``---`` line when the first line
        starts with ``---`` and a later line does too. If the first line
        does not start with ``---`` (or ``lines`` is empty) the input list
        is returned unchanged; if no closing delimiter is found, a copy of
        the full list is returned.
    """
    # Guard against an empty list before looking at the first line.
    if not lines or not lines[0].startswith("---"):
        return lines

    for index in range(1, len(lines)):
        if lines[index].startswith("---"):
            return lines[index + 1 :]

    # Opening delimiter without a closing one: keep every line.
    return lines[:]
53
54
55
def update_counts(category: str, structure: dict, path: list) -> bool:
    """Increment the count of ``category`` and of every node above it.

    The tree is searched depth-first in mapping iteration order; only the
    first node whose key equals ``category`` is updated.

    Args:
        category: Key of the node to look for.
        structure: Mapping of names to ``{"count", "children"}`` nodes.
        path: Nodes already traversed to reach ``structure``; each one has
            its count incremented when the category is found.

    Returns:
        ``True`` if a matching node was found and updated, else ``False``.
    """
    if category not in structure:
        # Not at this level: descend into each child, stopping at the
        # first subtree that reports a match.
        return any(
            update_counts(category, child["children"], [*path, child])
            for child in structure.values()
        )

    structure[category]["count"] += 1
    for ancestor in path:
        ancestor["count"] += 1
    return True
65
66
67
def update_values_from_csv(structure: dict, csv_data: pd.DataFrame) -> None:
    """Add category occurrences from a DataFrame to the tree counts.

    The first column of every row is parsed with ``ast.literal_eval`` and
    the result is iterated; each element is passed to ``update_counts``.

    Args:
        structure: Category tree to update in place.
        csv_data: Frame whose first column holds Python-literal text.
            # NOTE(review): presumably a list literal of category names
            # per row - confirm against the CSV producer.
    """
    for _, row in csv_data.iterrows():
        # Use positional access: a plain ``row[0]`` is a label lookup and
        # only means "first column" when the labels are integers from 0.
        category_list = literal_eval(row.iloc[0])
        for category in category_list:
            update_counts(category, structure, [])
72
73
74
def generate_sankey_data(
75
    structure: dict, step: int = 0, parent: str | None = None
76
) -> list:
77
    data = []
78
    for key, value in structure.items():
79
        if parent is not None:
80
            data.append(
81
                {
82
                    "Source": parent,
83
                    "Dest": key,
84
                    "Value": value["count"],
85
                    "Step from": step,
86
                    "Step to": step + 1,
87
                }
88
            )
89
        if value["children"]:
90
            data.extend(generate_sankey_data(value["children"], step + 1, key))
91
    return data
92
93
94
def prune_sankey_data(data: list, max_depth: int) -> list:
    """Return the rows whose ``"Step to"`` value does not exceed ``max_depth``.

    Args:
        data: Link rows, each a mapping containing a ``"Step to"`` key.
        max_depth: Largest ``"Step to"`` value to keep (inclusive).

    Returns:
        A new list with the retained rows in their original order.
    """
    kept: list = []
    for link in data:
        if link["Step to"] <= max_depth:
            kept.append(link)
    return kept
96
97
98
def main(
    markdown_file_path: str,
    csv_file_path: str,
    max_depth: int,
    output_file_path: str = "/Users/gymate1/Downloads/sankey_data.csv",
) -> None:
    """Build Sankey link data from a category list and a CSV, and save it.

    Steps: parse the markdown list into a tree, add counts from the CSV,
    flatten the tree into link rows, drop rows deeper than ``max_depth``
    and write the result as CSV.

    Args:
        markdown_file_path: Path of the markdown file with the nested list.
        csv_file_path: Path of the header-less CSV read with pandas.
        max_depth: Largest ``"Step to"`` value kept in the output.
        output_file_path: Destination of the generated CSV. Defaults to the
            location that was previously hard-coded.
    """
    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(markdown_file_path, "r", encoding="utf-8") as file:
        markdown_list = file.read()
    parsed_structure = parse_markdown_list(markdown_list)

    csv_df = pd.read_csv(csv_file_path, header=None)
    update_values_from_csv(parsed_structure, csv_df)
    sankey_data = generate_sankey_data(parsed_structure)
    pruned_sankey_data = prune_sankey_data(sankey_data, max_depth)
    df = pd.DataFrame(
        pruned_sankey_data, columns=["Source", "Dest", "Value", "Step from", "Step to"]
    )
    # newline="" lets pandas control line endings, avoiding doubled
    # carriage returns on platforms that translate "\n" on write.
    with open(output_file_path, "w", encoding="utf-8", newline="") as file:
        df.to_csv(file, index=False)
113
114
115
if __name__ == "__main__":
    # Run the pipeline with fixed input locations when executed as a script.
    main(
        "/mindmap/SR_cause_categories.md",  # markdown_file_path
        "/Users/gymate1/Downloads/kalauz_speed_restrictions.csv",  # csv_file_path
        4,  # max_depth
    )
121