Test Failed
Pull Request — master (#878)
by
unknown
05:51 queued 52s
created

savu.plugins.stats.stats_utils   A

Complexity

Total Complexity 38

Size/Duplication

Total Lines 165
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 136
dl 0
loc 165
rs 9.36
c 0
b 0
f 0
wmc 38

7 Methods

Rating   Name   Duplication   Size   Complexity  
A StatsUtils.make_stats_table() 0 6 1
A StatsUtils.make_loop_graphs() 0 24 4
B StatsUtils.generate_figures() 0 24 8
B StatsUtils.make_stats_graphs() 0 39 8
A StatsUtils._remove_arrays() 0 9 4
C StatsUtils._get_dicts_for_graphs() 0 28 10
A StatsUtils._get_dicts_for_loops() 0 12 3
1
import matplotlib.pyplot as plt
2
from matplotlib.ticker import MaxNLocator
3
import pandas as pd
4
import h5py as h5
5
import numpy as np
6
7
class StatsUtils(object):
    """Build summary tables and plots from a Savu statistics HDF5 file.

    The file is read through two groups: ``stats`` (one dataset per plugin,
    carrying ``pattern`` and ``plugin_name`` attributes) and, optionally,
    ``iterative`` (one NRMSD dataset per loop, carrying a ``loop_plugins``
    attribute).
    """

    # Data patterns that place a plugin's stats in projection space or in
    # reconstruction space.
    _pattern_dict = {"projection": ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "4D_SCAN", "SINOMOVIE"],
                     "reconstruction": ["VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D"]}
    # Order of the statistics along the last axis of each ``stats`` dataset.
    _stats_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD"]
    # Plot colour for each statistic, in ``_stats_list`` order.
    _stats_colours = ["red", "blue", "green", "black", "purple", "brown"]

    def generate_figures(self, filepath, savepath):
        """Write the stats table and the per-space stats graphs for a stats file.

        :param filepath: Path of the HDF5 file holding the ``stats`` group.
        :param savepath: Directory in which ``stats_table.html``,
            ``projection_stats.png`` and ``reconstruction_stats.png`` are saved.
        """
        # All data is copied out of the file by the helpers, so it can be
        # closed (even on error) before plotting.
        with h5.File(filepath, 'r') as f:
            stats_dict, index_list = self._get_dicts_for_graphs(f)
            loop_stats, loop_plugins = self._get_dicts_for_loops(f)

        # Loop graphs are currently disabled:
        # self.make_loop_graphs(loop_stats, loop_plugins, savepath)

        # Loop membership is marked in the table only; ``index_list`` itself
        # is left untouched so the graph legends keep the plain labels.
        table_index_list = self._label_loop_plugins(index_list, loop_plugins)
        self.make_stats_table(stats_dict, table_index_list, f"{savepath}/stats_table.html")

        for space, title in (("projection", "Projection Stats"),
                             ("reconstruction", "Reconstruction Stats")):
            if stats_dict[space]["max"]:  # only plot spaces that have plugins
                self.make_stats_graphs(stats_dict[space], index_list[space], title,
                                       f"{savepath}/{space}_stats.png")

    @staticmethod
    def make_stats_table(stats_dict, index_list, savepath):
        """Save one HTML table of the stats of every plugin, grouped by space.

        :param stats_dict: ``{space: {stat: [value per plugin]}}`` for the
            ``"projection"`` and ``"reconstruction"`` spaces.
        :param index_list: ``{space: [row label per plugin]}``.
        :param savepath: Path of the HTML file to write.
        """
        p_stats = pd.DataFrame(stats_dict["projection"], index_list["projection"])
        r_stats = pd.DataFrame(stats_dict["reconstruction"], index_list["reconstruction"])
        all_stats = pd.concat([p_stats, r_stats], keys=["Projection", "Reconstruction"])
        all_stats.to_html(savepath)  # create table of stats for all plugins

    def make_loop_graphs(self, loop_stats, loop_plugins, savepath):
        """Save one NRMSD-per-iteration plot for each iterative loop.

        :param loop_stats: One dict per loop, each holding an ``"NRMSD"`` list.
        :param loop_plugins: For each loop, the plugin names run inside it
            (may be None when the loop has no ``loop_plugins`` attribute).
        :param savepath: Directory in which ``loop_stats<N>.png`` files are saved.
        """
        for i, (stats, plugins) in enumerate(zip(loop_stats, loop_plugins)):
            y = stats["NRMSD"]
            if len(y) == 0:
                continue  # nothing to plot (and max() below needs a value)

            # x tick labels of the form "0-1", "1-2", ...
            x = [f"{j}-{j + 1}" for j in range(len(y))]
            last_x = len(y) - 1  # position of the last point on the x axis

            fig = plt.figure(figsize=(11, 9), dpi=320)
            ax = fig.gca()
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.grid(True)
            ax.plot(x, y)
            ax.set_title(f"NRMSD over loop {i}")
            ax.set_xlabel("Iteration")
            ax.set_ylabel("NRMSD")

            text = f"Loop {i} iterates {last_x + 2} times over:\n"
            if plugins is not None:
                text += "".join(f"{plugin}\n" for plugin in plugins)
            # Anchor the info box at the top-right data point.
            ax.text(last_x, max(y), text, ha="right", va="top",
                    bbox=dict(boxstyle="round", facecolor="red", alpha=0.4))
            fig.savefig(f"{savepath}/loop_stats{i}.png", bbox_inches="tight")
            plt.close(fig)  # free the (large, dpi=320) figure

    def make_stats_graphs(self, stats_dict, index_list, title, savepath):
        """Save a 3x2 grid of line plots, one subplot per statistic.

        A plugin whose statistic is an array is plotted at its first value,
        with an error bar spanning the array's minimum to maximum.

        :param stats_dict: ``{stat: [value per plugin]}`` for one space.
        :param index_list: ``"<key>: <plugin name>"`` label per plugin.
        :param title: Figure title.
        :param savepath: Path of the image file to write.
        """
        stats_df = pd.DataFrame(stats_dict, index_list)
        scalar_stats, array_plugins = self._remove_arrays(stats_dict, index_list)
        stats_df_new = pd.DataFrame(scalar_stats, index_list)

        # x ticks show only plugin numbers (to save space); the legend maps
        # each number to the full plugin label.
        tick_labels = [self._plugin_number(label) for label in index_list]
        stats_df_new.index = tick_labels
        legend = "".join(f"{label}\n" for label in index_list)

        fig, ax = plt.subplots(3, 2, figsize=(11, 9), dpi=320, facecolor="lavender")
        # ax.flat walks the grid row by row, matching _stats_list order.
        for i, (axis, stat) in enumerate(zip(ax.flat, self._stats_list)):
            axis.plot(stats_df_new[stat], "x-", color=self._stats_colours[i])
            for plugin in array_plugins:  # 'error' bars for plugins with multiple values
                values = stats_df[stat][plugin]
                if values is None:
                    continue
                values = np.asarray(values)
                if values.size == 0 or np.isnan(values).any():
                    continue
                my_max, my_min = np.max(values), np.min(values)
                # Categorical x axis: a plugin's position is its tick index.
                axis.errorbar(tick_labels.index(self._plugin_number(plugin)), (my_max + my_min) / 2,
                              yerr=[(my_max - my_min) / 2], capsize=5)
            if i == 1:
                # Put the plugin key just right of the last point of this subplot;
                # Series.max() skips NaN entries left by missing stats.
                maxx = len(stats_df_new[stat]) * 1.08 - 1
                maxy = stats_df_new[stat].max()
                axis.text(maxx, maxy, legend, ha="left", va="top",
                          bbox=dict(boxstyle="round", facecolor="red", alpha=0.4))
            axis.set_title(stat.replace("_", " "))
            axis.grid(True)
        fig.suptitle(title, fontsize="x-large")
        fig.savefig(savepath, bbox_inches="tight")
        plt.close(fig)  # free the (large, dpi=320) figure

    @staticmethod
    def _plugin_number(label):
        """Return the ``<key>`` part of a ``"<key>: <plugin name>"`` label."""
        return label.split(": ", 1)[0]

    @staticmethod
    def _plugin_name(label):
        """Return the ``<plugin name>`` part of a ``"<key>: <plugin name>"`` label."""
        return label.split(": ", 1)[-1]

    @staticmethod
    def _label_loop_plugins(index_list, loop_plugins):
        """Return a copy of ``index_list`` with `` (loopN)`` added to plugins run in loop N.

        :param index_list: ``{space: ["<key>: <plugin name>", ...]}``.
        :param loop_plugins: N-th entry lists the plugin names of loop N
            (None entries are skipped).
        :returns: A new dict of new lists; ``index_list`` is not modified.
        """
        labelled = {}
        for space, labels in index_list.items():
            new_labels = []
            for label in labels:
                name = StatsUtils._plugin_name(label)
                for loop_no, plugins in enumerate(loop_plugins):
                    if plugins is not None and any(p == name for p in plugins):
                        label += f" (loop{loop_no})"
                new_labels.append(label)
            labelled[space] = new_labels
        return labelled

    @staticmethod
    def _get_dicts_for_graphs(file):
        """Collect per-plugin statistics from the ``stats`` group, split by space.

        :param file: Open HDF5 file (or group) containing a ``stats`` group.
        :returns: ``(stats_dict, index_list)``. ``index_list[space]`` holds one
            ``"<key>: <plugin name>"`` label per matching dataset, and
            ``stats_dict[space][stat]`` holds one entry per label: a scalar for
            a 1-D dataset, that statistic's column for a 2-D dataset, or None
            when the dataset does not have that many statistics.
        """
        spaces = ("projection", "reconstruction")
        stats_dict = {space: {stat: [] for stat in StatsUtils._stats_list} for space in spaces}
        index_list = {space: [] for space in spaces}

        group = file["stats"]
        for space in spaces:
            patterns = StatsUtils._pattern_dict[space]
            for key in list(group.keys()):
                dataset = group[key]
                if dataset.attrs.get("pattern") not in patterns:
                    continue
                index_list[space].append(f"{key}: {dataset.attrs.get('plugin_name')}")
                data = dataset[()]  # read once instead of once per statistic
                # Statistics lie along the last axis; other ranks yield all-None
                # so every stat list stays aligned with index_list.
                n_stats = data.shape[-1] if data.ndim in (1, 2) else 0
                for index, stat in enumerate(StatsUtils._stats_list):
                    if index >= n_stats:
                        value = None
                    elif data.ndim == 1:
                        value = data[index]
                    else:
                        value = data[:, index]
                    stats_dict[space][stat].append(value)
        return stats_dict, index_list

    @staticmethod
    def _get_dicts_for_loops(file):
        """Collect the NRMSD values and plugin names of each iterative loop.

        :param file: Open HDF5 file (or group) that may contain an ``iterative`` group.
        :returns: ``(loop_stats, loop_plugins)``: one ``{"NRMSD": [...]}`` dict and
            one ``loop_plugins`` attribute value (None if absent) per loop, or
            ``([], [])`` when there is no ``iterative`` group.
        """
        if "iterative" not in file:
            return [], []
        group = file["iterative"]
        loop_stats = []
        loop_plugins = []
        for key in list(group.keys()):
            loop_stats.append({"NRMSD": list(group[key])})
            loop_plugins.append(group[key].attrs.get("loop_plugins"))
        return loop_stats, loop_plugins

    @staticmethod
    def _remove_arrays(stats_dict, index_list):
        """Replace array-valued statistics by their first element.

        :param stats_dict: ``{stat: [value per plugin]}``; not modified.
        :param index_list: Label per plugin, aligned with each value list.
        :returns: ``(scalar_stats, array_plugins)``: a new dict with arrays
            replaced by their first element (None for empty arrays), and the
            set of labels of plugins that had at least one non-empty array.
        """
        array_plugins = set()
        scalar_stats = {}
        for stat, values in stats_dict.items():
            new_values = []
            for index, value in enumerate(values):
                if isinstance(value, np.ndarray):
                    if value.size:
                        array_plugins.add(index_list[index])
                        value = value[0]
                    else:
                        value = None
                new_values.append(value)
            scalar_stats[stat] = new_values
        return scalar_stats, array_plugins
165
166
167