Passed
Push — master ( cba5ed...7e35c2 )
by Simon
04:17
created

parallel_coordinates_plotly()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 4
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import os
6
import sys
7
import time
8
import numpy as np
9
import pandas as pd
10
import streamlit as st
11
import plotly.express as px
12
import matplotlib.pyplot as plt
13
14
15
# Continuous colour scale applied by default to every parallel-coordinates figure.
color_scale = px.colors.sequential.Jet


def parallel_coordinates_plotly(*args, plotly_width=1200, plotly_height=540, **kwargs):
    """Build a fixed-size plotly parallel-coordinates figure.

    Every argument except ``plotly_width`` and ``plotly_height`` is forwarded
    to ``plotly.express.parallel_coordinates``. The module-level
    ``color_scale`` is used as ``color_continuous_scale`` unless the caller
    passes its own.

    Args:
        *args: Positional arguments for ``px.parallel_coordinates``
            (in this file: the data frame to plot).
        plotly_width: Figure width in pixels.
        plotly_height: Figure height in pixels.
        **kwargs: Keyword arguments for ``px.parallel_coordinates``
            (e.g. ``dimensions``, ``color``).

    Returns:
        The configured plotly figure.
    """
    # setdefault rather than passing the scale next to **kwargs: the previous
    # form raised TypeError ("multiple values for keyword argument") whenever a
    # caller supplied its own color_continuous_scale. kwargs is a fresh dict per
    # call, so this does not mutate anything the caller owns.
    kwargs.setdefault("color_continuous_scale", color_scale)
    fig = px.parallel_coordinates(*args, **kwargs)
    # Use a fixed pixel size instead of autosizing.
    fig.update_layout(autosize=False, width=plotly_width, height=plotly_height)
    return fig
23
24
25
def _resolve_bound(bound, sentinel, col_data, reducer):
    """Return ``reducer(col_data)`` if *bound* equals *sentinel*, else ``float(bound)``."""
    if bound == sentinel:
        return reducer(col_data)
    return float(bound)


def filter_data(filter, df, columns):
    """Keep only the rows of *df* whose values lie inside per-column bounds.

    Args:
        filter: DataFrame with the columns ``"parameter"``, ``"lower bound"``
            and ``"upper bound"``. A bound cell holding the literal string
            ``"lower"`` / ``"upper"`` means "use the column's current
            minimum / maximum"; any other value is converted with ``float``.
            (The name shadows the builtin but is kept for backward
            compatibility with keyword callers.)
        df: Data to filter.
        columns: Columns of *df* to consider. Columns that have no row in
            *filter* are left unconstrained.

    Returns:
        The filtered DataFrame; both bounds are inclusive. Frames with fewer
        than two rows are returned unchanged.

    Raises:
        ValueError: If a bound is neither its sentinel string nor convertible
            to ``float``.
    """
    # NOTE(review): a frame with exactly one row is returned unfiltered even if
    # that row lies outside the bounds -- confirm this is intended.
    if len(df) <= 1:
        return df

    # Build the membership lookup once instead of re-listing it per column.
    filtered_params = set(filter["parameter"])

    for column in columns:
        if column not in filtered_params:
            continue

        row = filter[filter["parameter"] == column]
        col_data = df[column]
        lower = _resolve_bound(
            row["lower bound"].values[0], "lower", col_data, np.min
        )
        upper = _resolve_bound(
            row["upper bound"].values[0], "upper", col_data, np.max
        )

        # Filters are applied one column after another, so a "lower"/"upper"
        # sentinel refers to the extremes left after the previous columns.
        df = df[(df[column] >= lower) & (df[column] <= upper)]

    return df
52
53
54
def _load_search_data(search_ids):
    """Read whichever progress/filter CSV files exist for each search id."""
    search_id_dict = {}
    for search_id in search_ids:
        entry = {}
        progress_data_path = "./progress_data_" + search_id + ".csv~"
        filter_path = "./filter_" + search_id + ".csv"

        if os.path.isfile(progress_data_path):
            entry["progress_data"] = pd.read_csv(progress_data_path)
        if os.path.isfile(filter_path):
            entry["filter"] = pd.read_csv(filter_path)

        search_id_dict[search_id] = entry
    return search_id_dict


def _plot_best_score(container, progress_data_f):
    """Draw best score per iteration (one line per process) into *container*."""
    nth_iter = progress_data_f["nth_iter"]
    score_best = progress_data_f["score_best"]
    # Keep this a Series: comparing a plain list with a scalar yields a single
    # bool, which broke both the single-process check and the per-process masks.
    nth_process = progress_data_f["nth_process"]

    fig, ax = plt.subplots()
    ax.set_xlabel("nth iteration")
    ax.set_ylabel("score")

    # nunique() also copes with an empty frame (no IndexError on [0]).
    if nth_process.nunique() <= 1:
        ax.plot(nth_iter, score_best)
    else:
        for i in np.unique(nth_process):
            mask = nth_process == i
            ax.plot(nth_iter[mask], score_best[mask], label=str(i) + ". process")
        ax.legend()

    container.pyplot(fig)
    # The board reruns every second in the same process and pyplot keeps a
    # reference to every figure until it is closed, so close it explicitly.
    plt.close(fig)


def main():
    """Render the Hyperactive progress board, then schedule a rerun.

    Search ids come from the command-line arguments. For each id the files
    ``./progress_data_<id>.csv~`` and (optionally) ``./filter_<id>.csv`` are
    read. The board then shows the best score per iteration and a
    parallel-coordinates plot of the (filtered) parameters coloured by
    ``score``. After one second the script asks streamlit to rerun itself.
    """
    try:
        st.set_page_config(page_title="Hyperactive Progress Board", layout="wide")
    except Exception:
        # Best effort only. Narrowed from a bare ``except`` so that
        # KeyboardInterrupt/SystemExit are no longer silently swallowed.
        pass

    # st.beta_columns was renamed to st.columns in newer streamlit releases;
    # prefer the new name and fall back to the old one.
    columns = getattr(st, "columns", None) or st.beta_columns

    search_id_dict = _load_search_data(sys.argv[1:])

    for search_id, data in search_id_dict.items():
        st.title(search_id)
        st.components.v1.html(
            """<hr style="height:1px;border:none;color:#333;background-color:#333;" /> """,
            height=10,
        )

        progress_data = data.get("progress_data")
        if progress_data is None:
            # Previously a missing file raised KeyError and took down the whole
            # board; show a note instead and pick the data up on a later rerun.
            st.info("No progress data found for this search yet.")
            continue

        col1, col2 = columns([1, 2])

        # Drop every row that contains NaN or +/-inf in any column.
        progress_data_f = progress_data[
            ~progress_data.isin([np.nan, np.inf, -np.inf]).any(axis=1)
        ]

        _plot_best_score(col1, progress_data_f)

        # Non-inplace drop: progress_data_f is a filtered slice of
        # progress_data, where an in-place drop can trigger SettingWithCopy.
        progress_data_f = progress_data_f.drop(
            columns=["nth_iter", "score_best", "nth_process"]
        )
        prog_data_columns = list(progress_data_f.columns)

        filter_df = data.get("filter")
        if filter_df is not None:
            progress_data_f = filter_data(filter_df, progress_data_f, prog_data_columns)

        # "score" colours the lines instead of being a dimension of its own.
        prog_data_columns.remove("score")

        fig = parallel_coordinates_plotly(
            progress_data_f, dimensions=prog_data_columns, color="score"
        )
        col2.plotly_chart(fig)

        for _ in range(3):
            st.write(" ")

    time.sleep(1)
    # st.experimental_rerun was renamed to st.rerun in newer streamlit releases.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


if __name__ == "__main__":
    main()
134