StatsUtils._remove_arrays()   A
last analyzed

Complexity

Conditions 4

Size

Total Lines 9
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 9
nop 2
dl 0
loc 9
rs 9.95
c 0
b 0
f 0
1
import matplotlib.pyplot as plt
2
from matplotlib.ticker import MaxNLocator
3
import pandas as pd
4
import h5py as h5
5
import numpy as np
6
7
class StatsUtils(object):
    """Helpers that turn plugin statistics stored in an HDF5 file into tables and figures."""

    # Dataset patterns grouped by the space ("projection" or "reconstruction")
    # that a plugin's statistics are reported under.
    _pattern_dict = {"projection": ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "4D_SCAN", "SINOMOVIE"],
                     "reconstruction": ["VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D"]}
    # Statistics drawn by make_stats_graphs(), one per subplot, in subplot order.
    _stats_key = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD"]

    # Executed once, when the class body is evaluated at import time.
    plt.set_loglevel('WARNING')
14
15
    def generate_figures(self, filepath, savepath):
        """Read the stats stored in an HDF5 file and write the summary table and figures.

        Writes ``stats_table.html`` and ``times_chart.png`` into *savepath*, plus
        ``projection_stats.png`` / ``reconstruction_stats.png`` when the
        corresponding space has at least one plugin.

        :param filepath: path of the HDF5 file to read.
        :param savepath: directory the outputs are written into (it is not
            created here).
        """
        # The context manager closes the file even if reading the stats raises.
        with h5.File(filepath, 'r') as f:
            stats_dict, index_list, times_dict = self._get_dicts_for_graphs(f)
            loop_stats, loop_plugins = self._get_dicts_for_loops(f)

        # NOTE: loop graphs are currently disabled; to re-enable call
        # self.make_loop_graphs(loop_stats, loop_plugins, savepath)

        # Tag every plugin that belongs to an iterative loop.  The labels are
        # edited in place, so the tag appears in the table and in every figure.
        for loop_number, plugins_in_loop in enumerate(loop_plugins):
            if plugins_in_loop is None:  # loop stored without a "loop_plugins" attribute
                continue
            for labels in index_list.values():
                for position, label in enumerate(labels):
                    # Labels look like "<number>: <name>".  Split on the separator
                    # instead of a fixed offset so multi-digit numbers still match.
                    plugin_name = label.partition(": ")[2]
                    for loop_plugin in plugins_in_loop:
                        if loop_plugin == plugin_name:
                            labels[position] = f"{labels[position]} (loop{loop_number})"

        self.make_stats_table(stats_dict, index_list, f"{savepath}/stats_table.html")

        # Re-key the timings from plugin number to full plugin label.  The whole
        # number (text before the first ":") is compared, not just its first
        # character, so plugin "1" can no longer be confused with "10".
        label_by_number = {}
        for label in index_list["projection"] + index_list["reconstruction"]:
            label_by_number.setdefault(label.split(":", 1)[0], label)
        for p_num in list(times_dict):
            if p_num in label_by_number:
                times_dict[label_by_number[p_num]] = times_dict.pop(p_num)

        self.make_times_figure(times_dict, f"{savepath}/times_chart.png")

        if stats_dict["projection"]["max"]:
            self.make_stats_graphs(stats_dict["projection"], index_list["projection"], "Projection Stats",
                                   f"{savepath}/projection_stats.png")
        if stats_dict["reconstruction"]["max"]:
            self.make_stats_graphs(stats_dict["reconstruction"], index_list["reconstruction"], "Reconstruction Stats",
                                   f"{savepath}/reconstruction_stats.png")
46
47
48
49
50
    @staticmethod
51
    def make_stats_table(stats_dict, index_list, savepath):
52
        stats_dict_copy = {}
53
        for space, value in stats_dict.items():
54
            stats_dict_copy[space] = value.copy()
55
        for stat in list(stats_dict["projection"].keys()):
56
            if all(value is None for value in stats_dict["projection"][stat]) and all(value is None for value in stats_dict["reconstruction"][stat]):
57
                del stats_dict_copy["projection"][stat]
58
                del stats_dict_copy["reconstruction"][stat]
59
        p_stats = pd.DataFrame(stats_dict_copy["projection"], index_list["projection"])
60
        r_stats = pd.DataFrame(stats_dict_copy["reconstruction"], index_list["reconstruction"])
61
        all_stats = pd.concat([p_stats, r_stats], keys=["Projection", "Reconstruction"])
62
        all_stats.to_html(savepath)  # create table of stats for all plugins
63
64
    def make_loop_graphs(self, loop_stats, loop_plugins, savepath):
        """Plot the NRMSD values recorded for each iterative loop.

        One ``loop_stats<i>.png`` file is written into *savepath* per loop.
        Loops with no recorded values are skipped.

        :param loop_stats: list of dicts, one per loop, each holding an "NRMSD"
            sequence.
        :param loop_plugins: list, parallel to *loop_stats*, of the plugins
            that make up each loop.
        :param savepath: directory the figures are written into.
        """
        for i, stats in enumerate(loop_stats):
            y = stats["NRMSD"]
            if len(y) == 0:
                # Nothing to plot.  Without this guard the code below relied on
                # a loop variable that was never assigned and on max([]).
                continue

            # Tick labels "0-1", "1-2", ...; presumably each value compares two
            # consecutive iterations.
            x = [f"{j}-{j + 1}" for j in range(len(y))]
            maxx = len(y) - 1
            maxy = max(y)

            fig = plt.figure(figsize=(11, 9), dpi=320)
            ax = fig.gca()
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.grid(True)
            ax.plot(x, y)
            # Use the real loop number; it used to be hard-coded to 0.
            ax.set_title(f"NRMSD over loop {i}")
            ax.set_xlabel("Iteration")
            ax.set_ylabel("NRMSD")

            text = f"Loop {i} iterates {maxx + 2} times over:\n"
            text += "".join(f"{plugin}\n" for plugin in loop_plugins[i])
            ax.text(maxx, maxy, text, ha="right", va="top",
                    bbox=dict(boxstyle="round", facecolor="red", alpha=0.4))

            fig.savefig(f"{savepath}/loop_stats{i}.png", bbox_inches="tight")
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)
88
89
90
    def make_stats_graphs(self, stats_dict, index_list, title, savepath):
        """Save a 3x2 grid of line plots, one per stat in ``_stats_key``.

        Plugins whose stats are arrays (several values per plugin) are plotted
        using their first value, with a bar spanning the min-max range of the
        array drawn on top.  Note that array values in *stats_dict* are replaced
        by their first element in place (see ``_remove_arrays``).

        :param stats_dict: ``{stat: [value per plugin]}`` for a single space.
        :param index_list: plugin labels of the form ``"<number>: <name>"``,
            parallel to the value lists.
        :param title: figure title.
        :param savepath: path of the image file to write.
        """
        # Snapshot taken before _remove_arrays() collapses the array values.
        stats_df = pd.DataFrame(stats_dict, index_list)
        stats_dict, array_plugins = self._remove_arrays(stats_dict, index_list)

        stats_df_new = pd.DataFrame(stats_dict, index_list)

        colours = ["red", "blue", "green", "black", "purple", "brown"]  # max, min, mean, mean std dev, median std dev, NRMSD

        labels = list(stats_df_new.index)
        # x ticks show only the plugin number (for space).  Take everything
        # before the first ":" so multi-digit plugin numbers stay distinct.
        stats_df_new.index = [label.split(":", 1)[0] for label in labels]
        # Key mapping the plugin numbers back to the full plugin names.
        legend = "".join(f"{label}\n" for label in labels)

        fig, ax = plt.subplots(3, 2, figsize=(11, 9), dpi=320, facecolor="lavender")
        for i, axis in enumerate(ax.flat):  # row-major, matching _stats_key order
            stat = self._stats_key[i]
            axis.plot(stats_df_new[stat], "x-", color=colours[i])
            for plugin in array_plugins:  # 'error' bars for plugins with multiple values
                values = stats_df[stat][plugin]
                if values is not None and not np.isnan(values).any():
                    my_max = max(values)
                    my_min = min(values)
                    middle = (my_max + my_min) / 2
                    my_range = my_max - my_min
                    # Position of the plugin along the x axis, found from its
                    # full label rather than the first character of its number.
                    axis.errorbar(labels.index(plugin), middle, yerr=[my_range / 2], capsize=5)
            if i == 1:
                # The key is drawn just right of the second subplot.
                maxx = len(stats_df_new[stat]) * 1.08 - 1
                maxy = max(stats_df_new[stat])
                axis.text(maxx, maxy, legend, ha="left", va="top",
                          bbox=dict(boxstyle="round", facecolor="red", alpha=0.4))
            axis.set_title(stat.replace("_", " "))
            axis.grid(True)
        fig.suptitle(title, fontsize="x-large")
        fig.savefig(savepath, bbox_inches="tight")
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
129
130
131
    def make_times_figure(self, times_dict, savepath):
        """Save a pie chart showing the share of the total time taken by each plugin.

        :param times_dict: mapping of wedge label -> time; the wedge text
            reports the time in seconds.
        :param savepath: path of the image file to write.
        """
        labels = list(times_dict.keys())
        times = list(times_dict.values())
        total = sum(times)
        # One shade of blue per wedge.
        colors = plt.get_cmap('Blues')(np.linspace(0.2, 0.7, len(labels)))

        fig, ax = plt.subplots()
        ax.pie(times, labels=labels, autopct=lambda pct: self._get_times_pct(pct, total),
               counterclock=False, startangle=90, colors=colors, radius=3, center=(4, 4),
               wedgeprops={"linewidth": 1, "edgecolor": "white"})
        fig.suptitle("Plugin Times", x=0.1, y=1.4, horizontalalignment="right", fontsize="x-large")
        fig.savefig(savepath, bbox_inches="tight")
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
139
140
    @staticmethod
141
    def _get_times_pct(pct, total):
142
        absolute = (pct/100)*total
143
        return f"{round(pct, 1)}%\n{round(absolute, 1)} (s)"
144
145
    @staticmethod
146
    def _get_dicts_for_graphs(file):
147
        stats_dict = {}
148
        stats_dict["projection"] = {"max": [], "min": [], "mean": [], "mean_std_dev": [], "median_std_dev": [],
149
                                    "NRMSD": [], "zeros": [], "zeros%": [], "time (s)": []}
150
        stats_dict["reconstruction"] = {"max": [], "min": [], "mean": [], "mean_std_dev": [], "median_std_dev": [],
151
                                        "NRMSD": [], "zeros": [], "zeros%": [], "time (s)": []}
152
153
        index_list = {"projection": [], "reconstruction": []}
154
155
        times_dict = {}
156
157
        group = file["stats"]
158
        for space in ("projection", "reconstruction"):
159
            for index, stat in enumerate(["max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD", "zeros", "zeros%"]):
160
                for p_num in list(group.keys()):
161
                    if group[p_num].attrs.get("pattern") in StatsUtils._pattern_dict[space]:
162
                        if f"{p_num}: {group[p_num].attrs.get('plugin_name')}" not in index_list[space]:
163
                            index_list[space].append(f"{p_num}: {group[p_num].attrs.get('plugin_name')}")
164
                        if group[p_num].ndim == 1:
165
                            stats_key = list(group[p_num].attrs.get("stats_key"))
166
                            if stat in stats_key:
167
                                stats_dict[space][stat].append(group[p_num][stats_key.index(stat)])
168
                            else:
169
                                stats_dict[space][stat].append(None)
170
                        elif group[p_num].ndim == 2:
171
                            stats_key = list(group[p_num].attrs.get("stats_key"))
172
                            if stat in stats_key:
173
                                stats_dict[space][stat].append(group[p_num][:, stats_key.index(stat)])
174
                            else:
175
                                stats_dict[space][stat].append(None)
176
            for p_num in list(group.keys()):
177
                if group[p_num].attrs.get("pattern") in StatsUtils._pattern_dict[space]:
178
                    stats_dict[space]["time (s)"].append(group[p_num].attrs.get("time"))
179
180
        for plugin in list(group.keys()):
181
            if group[plugin].attrs.get("time") is not None:
182
                times_dict[plugin] = group[plugin].attrs.get("time")
183
184
        return stats_dict, index_list, times_dict
185
186
    @staticmethod
187
    def _get_dicts_for_loops(file):
188
        if "iterative" in list(file.keys()):
189
            group = file["iterative"]
190
            loop_stats = []
191
            loop_plugins = []
192
            for key in list(group.keys()):
193
                loop_stats.append({"NRMSD": list(group[key])})
194
                loop_plugins.append(group[key].attrs.get("loop_plugins"))
195
            return loop_stats, loop_plugins
196
        else:
197
            return [], []
198
199
    @staticmethod
200
    def _remove_arrays(stats_dict, index_list):
201
        array_plugins = set(())
202
        for stat in list(stats_dict.keys()):
203
            for index, value in enumerate(stats_dict[stat]):
204
                if isinstance(value, np.ndarray):
205
                    stats_dict[stat][index] = stats_dict[stat][index][0]
206
                    array_plugins.add(index_list[index])
207
        return stats_dict, array_plugins
208
209
210