Test Failed
Pull Request — master (#878)
by
unknown
04:38
created

StatsUtils.make_loop_graphs()   A

Complexity

Conditions 4

Size

Total Lines 24
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 20
nop 4
dl 0
loc 24
rs 9.4
c 0
b 0
f 0
1
import matplotlib.pyplot as plt
2
from matplotlib.ticker import MaxNLocator
3
import pandas as pd
4
import h5py as h5
5
import numpy as np
6
7
class StatsUtils(object):
    """Turn the statistics stored in an HDF5 file into an HTML table and PNG graphs.

    The file is expected to contain a ``stats`` group with one dataset per
    plugin (carrying ``pattern`` and ``plugin_name`` attributes) and,
    optionally, an ``iterative`` group with one dataset of NRMSD values per
    loop (carrying a ``loop_plugins`` attribute).
    """

    # Data patterns grouped by the "space" their statistics are reported under.
    _pattern_dict = {"projection": ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "4D_SCAN", "SINOMOVIE"],
                     "reconstruction": ["VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D"]}
    # Position of each statistic inside a stats dataset; also the subplot order.
    _stats_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD"]

    plt.set_loglevel('WARNING')

    @staticmethod
    def _split_label(label):
        """Split a ``"<key>: <plugin name>"`` label into ``(key, plugin name)``.

        Splitting on the separator (rather than slicing a fixed number of
        characters) keeps keys with more than one character intact.
        """
        key, _, name = label.partition(": ")
        return key, name

    def generate_figures(self, filepath, savepath):
        """Read the stats file and write the stats table and graphs.

        :param filepath: path of the HDF5 file to read.
        :param savepath: directory that receives ``stats_table.html``,
            ``projection_stats.png`` and ``reconstruction_stats.png``.
        :raises KeyError: if the file has no ``stats`` group.
        """
        # The context manager closes the file even if reading fails.
        with h5.File(filepath, 'r') as f:
            stats_dict, index_list = self._get_dicts_for_graphs(f)
            loop_stats, loop_plugins = self._get_dicts_for_loops(f)

        # NOTE: loop graphs are currently disabled.
        #self.make_loop_graphs(loop_stats, loop_plugins, savepath)

        # NOTE(review): this is an alias, not a copy, so the "(loopN)" suffixes
        # added below also show up in the graph legends - confirm that is wanted.
        table_index_list = index_list
        for i, members in enumerate(loop_plugins):
            if members is None:  # dataset without a "loop_plugins" attribute
                continue
            for labels in table_index_list.values():
                for j, label in enumerate(labels):
                    name = self._split_label(label)[1]
                    if any(member == name for member in members):
                        labels[j] = f"{label} (loop{i})"

        self.make_stats_table(stats_dict, table_index_list, f"{savepath}/stats_table.html")

        if stats_dict["projection"]["max"]:
            self.make_stats_graphs(stats_dict["projection"], index_list["projection"], "Projection Stats",
                                   f"{savepath}/projection_stats.png")
        if stats_dict["reconstruction"]["max"]:
            self.make_stats_graphs(stats_dict["reconstruction"], index_list["reconstruction"],
                                   "Reconstruction Stats", f"{savepath}/reconstruction_stats.png")

    @staticmethod
    def make_stats_table(stats_dict, index_list, savepath):
        """Write one HTML table holding the projection and reconstruction stats.

        :param stats_dict: ``{"projection": {...}, "reconstruction": {...}}``,
            each inner dict mapping a statistic name to a list of values.
        :param index_list: row labels for each of the two spaces.
        :param savepath: path of the HTML file to write.
        """
        p_stats = pd.DataFrame(stats_dict["projection"], index_list["projection"])
        r_stats = pd.DataFrame(stats_dict["reconstruction"], index_list["reconstruction"])
        all_stats = pd.concat([p_stats, r_stats], keys=["Projection", "Reconstruction"])
        all_stats.to_html(savepath)  # create table of stats for all plugins

    def make_loop_graphs(self, loop_stats, loop_plugins, savepath):
        """Save one NRMSD-per-iteration graph for every loop.

        :param loop_stats: list of dicts, each with an ``"NRMSD"`` sequence.
        :param loop_plugins: for each loop, the plugins it iterates over.
        :param savepath: directory that receives ``loop_stats<i>.png``.
        """
        for i, stats in enumerate(loop_stats):
            y = stats["NRMSD"]
            if not len(y):
                # Nothing to plot: there is no last point and max() would fail.
                continue

            # Each value is labelled with the pair of iterations it refers to.
            x = [f"{j}-{j + 1}" for j in range(len(y))]
            fig = plt.figure(figsize=(11, 9), dpi=320)
            ax = fig.gca()
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.grid(True)
            ax.plot(x, y)
            maxx = len(y) - 1
            maxy = max(y)
            ax.set_title(f"NRMSD over loop {i}")
            text = f"Loop {i} iterates {maxx + 2} times over:\n"
            text += "".join(f"{plugin}\n" for plugin in loop_plugins[i])
            ax.set_xlabel("Iteration")
            ax.set_ylabel("NRMSD")

            ax.text(maxx, maxy, text, ha="right", va="top",
                    bbox=dict(boxstyle="round", facecolor="red", alpha=0.4))
            fig.savefig(f"{savepath}/loop_stats{i}.png", bbox_inches="tight")
            plt.close(fig)  # pyplot keeps every figure alive until it is closed

    def make_stats_graphs(self, stats_dict, index_list, title, savepath):
        """Save a 3x2 grid of graphs, one per statistic in ``_stats_list``.

        Plugins whose statistics are arrays are plotted at their first value,
        with a bar spanning the minimum to the maximum of the array.

        :param stats_dict: maps each statistic name to a list with one entry
            per plugin (a scalar, an array or None).
        :param index_list: ``"<key>: <plugin name>"`` label of each plugin.
        :param title: title of the whole figure.
        :param savepath: path of the PNG file to write.
        """
        stats_df = pd.DataFrame(stats_dict, index_list)  # still holds the full arrays
        scalar_stats, array_plugins = self._remove_arrays(stats_dict, index_list)
        stats_df_new = pd.DataFrame(scalar_stats, index_list)

        colours = ["red", "blue", "green", "black", "purple", "brown"]  # max, min, mean, mean std dev, median std dev, NRMSD

        # x ticks only show the plugin key (for space); the legend is the key
        # showing the plugin names corresponding to those keys.
        tick_labels = [self._split_label(ind)[0] for ind in stats_df_new.index]
        legend = "".join(f"{ind}\n" for ind in stats_df_new.index)
        stats_df_new.index = tick_labels

        fig, ax = plt.subplots(3, 2, figsize=(11, 9), dpi=320, facecolor="lavender")
        for i, (axis, stat, colour) in enumerate(zip(ax.flat, self._stats_list, colours)):
            axis.plot(stats_df_new[stat], "x-", color=colour)
            for plugin in array_plugins:  # adding 'error' bars for plugins with multiple values
                values = stats_df[stat][plugin]
                if values is None or np.isnan(values).any():
                    continue
                my_max = np.max(values)
                my_min = np.min(values)
                middle = (my_max + my_min) / 2
                my_range = my_max - my_min
                position = tick_labels.index(self._split_label(plugin)[0])
                axis.errorbar(position, middle, yerr=[my_range / 2], capsize=5)
            if i == 1:
                # The legend box is drawn once, to the right of the second graph.
                maxx = len(stats_df_new[stat]) * 1.08 - 1
                maxy = max(stats_df_new[stat])
                axis.text(maxx, maxy, legend, ha="left", va="top",
                          bbox=dict(boxstyle="round", facecolor="red", alpha=0.4))
            axis.set_title(stat.replace("_", " "))
            axis.grid(True)
        fig.suptitle(title, fontsize="x-large")
        fig.savefig(savepath, bbox_inches="tight")
        plt.close(fig)  # pyplot keeps every figure alive until it is closed

    @staticmethod
    def _get_dicts_for_graphs(file):
        """Collect the per-plugin statistics of the ``stats`` group.

        :param file: open HDF5 file (or any mapping with the same layout).
        :returns: ``(stats_dict, index_list)``. ``stats_dict[space][stat]``
            holds one entry per plugin of that space: a scalar for a 1-D
            dataset, a 1-D array for a 2-D dataset, or None when the dataset
            has no value for that statistic. ``index_list[space]`` holds the
            matching ``"<key>: <plugin name>"`` labels.
        :raises KeyError: if ``file`` has no ``stats`` group.
        """
        spaces = ("projection", "reconstruction")
        stats_dict = {space: {stat: [] for stat in StatsUtils._stats_list} for space in spaces}
        index_list = {space: [] for space in spaces}

        group = file["stats"]
        for key in group.keys():
            dataset = group[key]
            pattern = dataset.attrs.get("pattern")
            space = next((s for s in spaces if pattern in StatsUtils._pattern_dict[s]), None)
            if space is None:  # pattern belongs to neither space
                continue
            index_list[space].append(f"{key}: {dataset.attrs.get('plugin_name')}")

            data = np.asarray(dataset[()])  # read the dataset once, not once per statistic
            for index, stat in enumerate(StatsUtils._stats_list):
                if data.ndim == 1 and data.shape[0] > index:
                    value = data[index]
                elif data.ndim == 2 and data.shape[0] and data.shape[1] > index:
                    value = data[:, index]
                else:
                    # Always append something so every column stays aligned
                    # with index_list.
                    value = None
                stats_dict[space][stat].append(value)
        return stats_dict, index_list

    @staticmethod
    def _get_dicts_for_loops(file):
        """Collect the NRMSD values and plugin lists of the ``iterative`` group.

        :param file: open HDF5 file (or any mapping with the same layout).
        :returns: ``(loop_stats, loop_plugins)`` with one entry per dataset of
            the group, or two empty lists when the group is absent.
        """
        if "iterative" not in file.keys():
            return [], []

        group = file["iterative"]
        loop_stats = []
        loop_plugins = []
        for key in group.keys():
            dataset = group[key]
            loop_stats.append({"NRMSD": list(dataset)})
            loop_plugins.append(dataset.attrs.get("loop_plugins"))
        return loop_stats, loop_plugins

    @staticmethod
    def _remove_arrays(stats_dict, index_list):
        """Replace every array entry by its first element.

        The input dict is left untouched; a new dict is returned.

        :param stats_dict: maps each statistic name to a list of values.
        :param index_list: label of the plugin at each list position.
        :returns: ``(scalar_stats, array_plugins)`` where ``array_plugins`` is
            the set of labels that had at least one array entry.
        """
        array_plugins = set()
        scalar_stats = {}
        for stat, values in stats_dict.items():
            reduced = []
            for index, value in enumerate(values):
                if isinstance(value, np.ndarray):
                    array_plugins.add(index_list[index])
                    value = value[0]
                reduced.append(value)
            scalar_stats[stat] = reduced
        return scalar_stats, array_plugins
167
168
169