| Conditions | 2 |
| Total Lines | 48 |
| Code Lines | 35 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 | ||
| 1 | from tqdm import tqdm |
||
def create_convergence_plots(optimizer_keys, plot_name):
    """Plot mean convergence curves with a +/- one-std band for several optimizers.

    For every key in *optimizer_keys* this reads
    ``./_data/<key>_convergence_data.csv``, plots its ``scores_mean`` column
    against the row index as a dashed line, and shades the band
    ``scores_mean - scores_std`` .. ``scores_mean + scores_std``. Two legends
    are drawn: one for the mean lines and one for the std bands. The figure is
    written to ``./_plots/<plot_name>_convergence.png`` at 300 dpi.

    Args:
        optimizer_keys: Iterable of optimizer names. Each name is used both to
            locate its CSV file and as its legend label.
        plot_name: Prefix of the output PNG file name.
    """
    # Materialize once: the keys are consumed by the loop AND reused as legend
    # labels afterwards, which would silently come out empty for a one-shot
    # iterator such as a generator.
    optimizer_keys = list(optimizer_keys)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        mean_plts = []
        std_plts = []

        for optimizer_key in optimizer_keys:
            convergence_data = pd.read_csv(
                f"./_data/{optimizer_key}_convergence_data.csv"
            )

            x_range = range(len(convergence_data))
            scores_mean = convergence_data["scores_mean"]
            scores_std = convergence_data["scores_std"]

            (mean_plt,) = ax.plot(
                x_range,
                scores_mean,
                linestyle="--",
                marker=",",
                alpha=0.9,
                label=optimizer_key,
                linewidth=1,
            )
            std_plt = ax.fill_between(
                x_range,
                scores_mean - scores_std,
                scores_mean + scores_std,
                label=optimizer_key,
                alpha=0.3,
            )

            mean_plts.append(mean_plt)
            std_plts.append(std_plt)

        fig.tight_layout()
        # A second legend() call replaces the first one on the axes, so the
        # "average score" legend is re-attached as a plain artist to keep both.
        leg1 = ax.legend(
            mean_plts, optimizer_keys, loc="lower center", title="average score"
        )
        ax.legend(
            std_plts, optimizer_keys, loc="lower right", title="standard deviation"
        )
        ax.add_artist(leg1)

        fig.savefig(f"./_plots/{plot_name}_convergence.png", dpi=300)
    finally:
        # Always release the figure, even when a CSV is missing or saving
        # fails, so repeated calls do not accumulate open pyplot figures.
        plt.close(fig)
||
| 58 | |||
| 61 |