Passed
Push — master ( 7e35c2...37974c )
by Simon
03:12
created

StreamlitBackend.plotly()   A

Complexity

Conditions 2

Size

Total Lines 23
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 15
nop 3
dl 0
loc 23
rs 9.65
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
import numpy as np
6
import matplotlib.pyplot as plt
7
import plotly.express as px
8
9
try:
    from progress_io import ProgressIO
except ImportError:
    # `progress_io` is not importable as a top-level module here; fall back to
    # the package-relative import. Catching only ImportError (instead of a bare
    # `except:`) lets genuine errors raised inside progress_io surface.
    from .progress_io import ProgressIO


# Plotly's sequential "Jet" colour scale, used to colour the
# parallel-coordinates plot by score.
color_scale = px.colors.sequential.Jet
16
17
18
class StreamlitBackend:
    """Turn stored search progress into matplotlib and plotly figures.

    For every search id, the progress data and the filter table are loaded
    once through ``ProgressIO("./")`` and cached in ``search_id_dict``.
    """

    def __init__(self, search_ids):
        """Load progress data and filter tables for every id in *search_ids*.

        Parameters
        ----------
        search_ids : iterable
            Ids passed to ``ProgressIO.load_progress`` and
            ``ProgressIO.load_filter``.
        """
        self.search_ids = search_ids
        # search_id -> {"prog_d": progress data, "filt_f": filter table}.
        # Either value may be None; both are checked for None before use.
        self.search_id_dict = {}

        _io_ = ProgressIO("./")

        for search_id in search_ids:
            self.search_id_dict[search_id] = {
                "prog_d": _io_.load_progress(search_id),
                "filt_f": _io_.load_filter(search_id),
            }

    def get_progress_data(self, search_id):
        """Return the progress data for *search_id* without invalid rows.

        Rows containing NaN, +inf or -inf in any column are dropped; the
        cached frame itself is not modified (boolean indexing returns a new
        frame). Returns None when no progress data was loaded.
        """
        progress_data = self.search_id_dict[search_id]["prog_d"]
        if progress_data is None:
            return None

        invalid_rows = progress_data.isin([np.nan, np.inf, -np.inf]).any(axis=1)
        return progress_data[~invalid_rows]

    def pyplot(self, progress_data, search_id):
        """Plot ``score_best`` over ``nth_iter`` as a matplotlib figure.

        If the rows come from more than one value of ``nth_process``, one
        labelled line per process is drawn and a legend is added; otherwise a
        single line is drawn. *search_id* is currently unused.

        Returns
        -------
        matplotlib.figure.Figure
        """
        nth_iter = progress_data["nth_iter"]
        score_best = progress_data["score_best"]
        # An ndarray (not a list) is required so `==` compares element-wise;
        # `list == scalar` is always a single False.
        nth_process = np.asarray(progress_data["nth_process"])

        fig, ax = plt.subplots()
        ax.set_xlabel("nth iteration")
        ax.set_ylabel("score")

        process_ids = np.unique(nth_process)
        if len(process_ids) <= 1:
            ax.plot(nth_iter, score_best)
        else:
            for i in process_ids:
                mask = nth_process == i
                ax.plot(
                    nth_iter[mask], score_best[mask], label=str(i) + ". process"
                )
            ax.legend()

        return fig

    def filter_data(self, df, filter_df):
        """Keep only the rows of *df* whose values lie within the filter bounds.

        *filter_df* is read through its "parameter", "lower bound" and
        "upper bound" columns. For each column of *df* listed under
        "parameter", rows outside the inclusive [lower, upper] range are
        removed. A bound equal to the string "lower" / "upper" means the
        column's current minimum / maximum; any other bound is passed through
        ``float()``. Columns not listed are not filtered. Frames with at most
        one row are returned unchanged.
        """
        if len(df) <= 1:
            return df

        # Build the lookup once instead of re-listing it for every column.
        filter_params = set(filter_df["parameter"])

        for column in list(df.columns):
            if column not in filter_params:
                continue

            filter_ = filter_df[filter_df["parameter"] == column]
            lower = filter_["lower bound"].values[0]
            upper = filter_["upper bound"].values[0]

            col_data = df[column]
            lower = np.min(col_data) if lower == "lower" else float(lower)
            upper = np.max(col_data) if upper == "upper" else float(upper)

            df = df[(df[column] >= lower) & (df[column] <= upper)]

        return df

    def plotly(self, progress_data, search_id):
        """Build a plotly parallel-coordinates figure coloured by "score".

        The "nth_iter", "score_best" and "nth_process" columns are excluded
        and, if a filter table exists for *search_id*, ``filter_data`` is
        applied. The caller's *progress_data* is not modified.

        Returns
        -------
        plotly.graph_objects.Figure
        """
        filter_df = self.search_id_dict[search_id]["filt_f"]

        # Return a new frame instead of dropping in place, which would mutate
        # the caller's DataFrame (and break a later pyplot() call on it).
        progress_data = progress_data.drop(
            ["nth_iter", "score_best", "nth_process"], axis=1
        )

        if filter_df is not None:
            progress_data = self.filter_data(progress_data, filter_df)

        # "score" drives the colour scale, so it is not one of the axes.
        prog_data_columns = list(progress_data.columns)
        prog_data_columns.remove("score")

        fig = px.parallel_coordinates(
            progress_data,
            dimensions=prog_data_columns,
            color="score",
            color_continuous_scale=color_scale,
        )
        fig.update_layout(autosize=False, width=1200, height=540)

        return fig

    def create_plots(self, search_id):
        """Return ``(pyplot_fig, plotly_fig)`` for *search_id*.

        Returns ``(None, None)`` when no progress data was loaded.
        """
        progress_data = self.get_progress_data(search_id)
        if progress_data is None:
            return None, None

        pyplot_fig = self.pyplot(progress_data, search_id)
        plotly_fig = self.plotly(progress_data, search_id)

        return pyplot_fig, plotly_fig
122