Test Failed
Pull Request — master (#878)
by Daniil
04:28
created

savu.plugins.stats.stats_utils   A

Complexity

Total Complexity 37

Size/Duplication

Total Lines 164
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 135
dl 0
loc 164
rs 9.44
c 0
b 0
f 0
wmc 37

7 Methods

Rating   Name   Duplication   Size   Complexity  
A StatsUtils.make_stats_table() 0 6 1
A StatsUtils.make_loop_graphs() 0 24 4
B StatsUtils.generate_figures() 0 24 8
B StatsUtils.make_stats_graphs() 0 38 7
A StatsUtils._remove_arrays() 0 9 4
C StatsUtils._get_dicts_for_graphs() 0 28 10
A StatsUtils._get_dicts_for_loops() 0 12 3
1
import matplotlib.pyplot as plt
2
from matplotlib.ticker import MaxNLocator
3
import pandas as pd
4
import h5py as h5
5
import numpy as np
6
7
class StatsUtils(object):
    """Build an HTML table and PNG graphs from statistics stored in an HDF5 file.

    The file is expected to contain a ``stats`` group holding one dataset per
    plugin (with ``pattern`` and ``plugin_name`` attributes) and, optionally,
    an ``iterative`` group holding one dataset of NRMSD values per loop (with
    a ``loop_plugins`` attribute).
    """

    # Data patterns belonging to each "space"; a dataset is assigned to a space
    # by looking its ``pattern`` attribute up in these lists.
    _pattern_dict = {"projection": ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "4D_SCAN", "SINOMOVIE"],
                     "reconstruction": ["VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D"]}
    # Statistic names, in the order they are stored along the last axis of
    # each dataset in the ``stats`` group.
    _stats_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD"]

    def generate_figures(self, filepath, savepath):
        """Read the stats file and write the stats table and graphs.

        :param filepath: path of the HDF5 file to read.
        :param savepath: directory the output files are written into
            (``stats_table.html``, ``projection_stats.png`` and
            ``reconstruction_stats.png``).
        """
        # The context manager closes the file even if reading fails.
        with h5.File(filepath, 'r') as f:
            stats_dict, index_list = self._get_dicts_for_graphs(f)
            loop_stats, loop_plugins = self._get_dicts_for_loops(f)

        # NOTE: the per-loop graphs are currently disabled:
        # self.make_loop_graphs(loop_stats, loop_plugins, savepath)

        # Tag every plugin that is part of a loop.  The labels are updated in
        # place, so the table and the graph legends share the tagged labels.
        for i, plugins_in_loop in enumerate(loop_plugins):
            for labels in index_list.values():
                for j, label in enumerate(labels):
                    # Labels look like "<key>: <plugin name>"; split on the
                    # separator so multi-character keys are handled too.
                    plugin_name = label.partition(": ")[2]
                    if any(loop_plugin == plugin_name for loop_plugin in plugins_in_loop):
                        labels[j] = f"{label} (loop{i})"

        self.make_stats_table(stats_dict, index_list, f"{savepath}/stats_table.html")

        if len(stats_dict["projection"]["max"]):
            self.make_stats_graphs(stats_dict["projection"], index_list["projection"], "Projection Stats",
                                   f"{savepath}/projection_stats.png")
        if len(stats_dict["reconstruction"]["max"]):
            self.make_stats_graphs(stats_dict["reconstruction"], index_list["reconstruction"], "Reconstruction Stats",
                                   f"{savepath}/reconstruction_stats.png")

    @staticmethod
    def make_stats_table(stats_dict, index_list, savepath):
        """Write one HTML table holding the stats of all plugins.

        :param stats_dict: ``{"projection": {...}, "reconstruction": {...}}``
            where each inner dict maps a stat name to a list of values.
        :param index_list: dict with the same two keys, each holding the row
            labels for that space.
        :param savepath: path of the HTML file to write.
        """
        p_stats = pd.DataFrame(stats_dict["projection"], index_list["projection"])
        r_stats = pd.DataFrame(stats_dict["reconstruction"], index_list["reconstruction"])
        all_stats = pd.concat([p_stats, r_stats], keys=["Projection", "Reconstruction"])
        all_stats.to_html(savepath)  # create table of stats for all plugins

    def make_loop_graphs(self, loop_stats, loop_plugins, savepath):
        """Plot the NRMSD values of every loop and save one PNG per loop.

        :param loop_stats: list of dicts, one per loop, each with an
            ``"NRMSD"`` sequence.
        :param loop_plugins: list, parallel to ``loop_stats``, of the plugin
            names each loop iterates over.
        :param savepath: directory the ``loop_stats<i>.png`` files are
            written into.
        """
        for i, stats in enumerate(loop_stats):
            y = stats["NRMSD"]
            if len(y) == 0:
                # Nothing to plot for this loop (and max() of an empty
                # sequence would raise).
                continue

            # Each value compares two successive iterations, hence "j-(j+1)".
            x = [f"{j}-{j + 1}" for j in range(len(y))]
            fig = plt.figure(figsize=(11, 9), dpi=320)
            try:
                ax = fig.gca()
                ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                ax.grid(True)
                ax.plot(x, y)
                maxx = len(y) - 1  # position of the last point on the x axis
                maxy = max(y)
                ax.set_title(f"NRMSD over loop {i}")
                text = f"Loop {i} iterates {maxx + 2} times over:\n"
                text += "".join(f"{plugin}\n" for plugin in loop_plugins[i])
                ax.set_xlabel("Iteration")
                ax.set_ylabel("NRMSD")

                ax.text(maxx, maxy, text, ha="right", va="top",
                        bbox=dict(boxstyle="round", facecolor="red", alpha=0.4))
                fig.savefig(f"{savepath}/loop_stats{i}.png", bbox_inches="tight")
            finally:
                # Release the figure so figures do not pile up across loops.
                plt.close(fig)

    def make_stats_graphs(self, stats_dict, index_list, title, savepath):
        """Plot each statistic against the plugins and save the figure.

        One subplot is drawn per entry of ``_stats_list`` on a 3x2 grid.  For
        plugins whose stats are arrays, the first value is plotted and an
        error bar spanning the minimum and maximum of the array is added.

        :param stats_dict: dict mapping each stat name to a list of values,
            one per plugin (scalars, arrays or None).
        :param index_list: list of plugin labels of the form
            ``"<key>: <plugin name>"``, parallel to the value lists.
        :param title: figure title.
        :param savepath: path of the PNG file to write.
        """
        plot_dict, array_plugins = self._remove_arrays(stats_dict, index_list)
        stats_df = pd.DataFrame(plot_dict, index_list)

        # x ticks show only the plugin key (to save space); the legend maps
        # the keys back to the full labels.
        stats_df.index = [str(label).split(":", 1)[0] for label in index_list]
        legend = "".join(f"{label}\n" for label in index_list)
        # The x axis is categorical, so a plugin sits at its row position.
        positions = {label: position for position, label in enumerate(index_list)}
        ordered_array_plugins = sorted(array_plugins, key=positions.get)

        colours = ("red", "blue", "green", "black", "purple", "brown")  # max, min, mean, mean std dev, median std dev, NRMSD

        fig, ax = plt.subplots(3, 2, figsize=(11, 9), dpi=320, facecolor="lavender")
        try:
            for i, (axis, stat) in enumerate(zip(ax.flat, self._stats_list)):
                axis.plot(stats_df[stat], "x-", color=colours[i])
                # adding 'error' bars for plugins with multiple values due to parameter changes
                for plugin in ordered_array_plugins:
                    values = stats_dict[stat][positions[plugin]]
                    if not isinstance(values, np.ndarray) or values.size == 0:
                        continue  # this stat has no array of values for the plugin
                    my_max = max(values)
                    my_min = min(values)
                    middle = (my_max + my_min) / 2
                    my_range = my_max - my_min
                    axis.errorbar(positions[plugin], middle, yerr=[my_range / 2], capsize=5)
                if i == 1:
                    # The key is attached to the second subplot, just right of its data.
                    maxx = len(stats_df[stat]) * 1.08 - 1
                    maxy = max(stats_df[stat])
                    axis.text(maxx, maxy, legend, ha="left", va="top",
                              bbox=dict(boxstyle="round", facecolor="red", alpha=0.4))
                axis.set_title(stat.replace("_", " "))
                axis.grid(True)
            fig.suptitle(title, fontsize="x-large")
            fig.savefig(savepath, bbox_inches="tight")
        finally:
            # Release the figure once it has been written.
            plt.close(fig)

    @staticmethod
    def _get_dicts_for_graphs(file):
        """Collect the per-plugin stats from the ``stats`` group of ``file``.

        :param file: open HDF5 file (or any mapping) with a ``stats`` group.
        :returns: ``(stats_dict, index_list)``.  ``stats_dict[space][stat]``
            is a list with one entry per dataset in that space: a scalar for
            1-D datasets, a column array for 2-D datasets, or None when the
            dataset does not hold that stat.  ``index_list[space]`` holds the
            matching ``"<key>: <plugin name>"`` labels.
        """
        spaces = StatsUtils._pattern_dict
        stats_dict = {space: {stat: [] for stat in StatsUtils._stats_list} for space in spaces}
        index_list = {space: [] for space in spaces}

        group = file["stats"]
        for key in group.keys():
            dataset = group[key]
            pattern = dataset.attrs.get("pattern")
            label = f"{key}: {dataset.attrs.get('plugin_name')}"
            # Number of stats available along the last axis; datasets that
            # are neither 1-D nor 2-D contribute None for every stat so the
            # value lists stay aligned with the labels.
            n_stats = dataset.shape[-1] if dataset.ndim in (1, 2) else 0
            for space, patterns in spaces.items():
                if pattern not in patterns:
                    continue
                index_list[space].append(label)
                for index, stat in enumerate(StatsUtils._stats_list):
                    if index >= n_stats:
                        value = None
                    elif dataset.ndim == 1:
                        value = dataset[index]
                    else:
                        value = dataset[:, index]
                    stats_dict[space][stat].append(value)
        return stats_dict, index_list

    @staticmethod
    def _get_dicts_for_loops(file):
        """Collect the NRMSD values and plugin names of every loop.

        :param file: open HDF5 file (or any mapping); the ``iterative`` group
            is optional.
        :returns: ``(loop_stats, loop_plugins)``, two parallel lists with one
            entry per dataset in the ``iterative`` group; both are empty when
            the group is missing.
        """
        if "iterative" not in file:
            return [], []

        group = file["iterative"]
        loop_stats = []
        loop_plugins = []
        for key in group.keys():
            dataset = group[key]
            loop_stats.append({"NRMSD": list(dataset)})
            plugins = dataset.attrs.get("loop_plugins")
            # A missing attribute becomes an empty list so callers can iterate.
            loop_plugins.append([] if plugins is None else plugins)
        return loop_stats, loop_plugins

    @staticmethod
    def _remove_arrays(stats_dict, index_list):
        """Replace array-valued stats by their first element.

        The input dict is left untouched; a new dict is returned.

        :param stats_dict: dict mapping each stat name to a list of values.
        :param index_list: plugin labels, parallel to the value lists.
        :returns: ``(flat_dict, array_plugins)`` where ``flat_dict`` has every
            numpy array replaced by its first element (None if the array is
            empty) and ``array_plugins`` is the set of labels that had at
            least one array value.
        """
        array_plugins = set()
        flat_dict = {}
        for stat, values in stats_dict.items():
            flat_values = []
            for index, value in enumerate(values):
                if isinstance(value, np.ndarray):
                    array_plugins.add(index_list[index])
                    value = value[0] if value.size else None
                flat_values.append(value)
            flat_dict[stat] = flat_values
        return flat_dict, array_plugins
164
165
166