Test Failed
Pull Request — master (#878)
by
unknown
04:00
created

savu.plugins.stats.stats_utils   A

Complexity

Total Complexity 24

Size/Duplication

Total Lines 110
Duplicated Lines 0 %

Importance

Changes 0
| Metric | Value |
|---|---|
| eloc | 94 |
| dl | 0 |
| loc | 110 |
| rs | 10 |
| c | 0 |
| b | 0 |
| f | 0 |
| wmc | 24 |

5 Methods

| Rating | Name | Duplication | Size | Complexity |
|---|---|---|---|---|
| B | StatsUtils.make_stats_graphs() | 0 | 37 | 6 |
| A | StatsUtils._remove_arrays() | 0 | 9 | 4 |
| A | StatsUtils.make_stats_table() | 0 | 6 | 1 |
| A | StatsUtils.generate_figures() | 0 | 13 | 3 |
| C | StatsUtils._get_dicts_for_graphs() | 0 | 28 | 10 |
1
import matplotlib.pyplot as plt
2
import pandas as pd
3
import h5py as h5
4
import numpy as np
5
6
class StatsUtils(object):
    """Turn the ``stats`` group of a Savu HDF5 output file into an HTML table and summary graphs.

    Each dataset in the ``stats`` group is assigned to the "projection" or the
    "reconstruction" space according to its ``pattern`` attribute (see ``_pattern_dict``).
    """

    # Data-pattern names that place a stats dataset in each space.
    _pattern_dict = {"projection": ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "4D_SCAN", "SINOMOVIE"],
                     "reconstruction": ["VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D"]}
    # Stat names; stat i is read from entry (1-D) or column (2-D) i of each stats dataset.
    _stats_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "RMSD"]
    # Plot colour for each stat, aligned with _stats_list.
    _colours = ["red", "blue", "green", "black", "purple", "brown"]

    def generate_figures(self, filepath, savepath):
        """Read the stats in *filepath* and write a table and graphs into the directory *savepath*.

        ``stats_table.html`` is always written; ``projection_stats.png`` and
        ``reconstruction_stats.png`` are written only when that space has at least one plugin.

        :param filepath: path of an HDF5 file containing a ``stats`` group.
        :param savepath: directory that receives the output files.
        """
        # The context manager closes the file even if reading the stats raises.
        with h5.File(filepath, "r") as f:
            stats_dict, index_list = self._get_dicts_for_graphs(f)

        self.make_stats_table(stats_dict, index_list, f"{savepath}/stats_table.html")

        for space, title in (("projection", "Projection Stats"), ("reconstruction", "Reconstruction Stats")):
            if stats_dict[space]["max"]:
                self.make_stats_graphs(stats_dict[space], index_list[space], title,
                                       f"{savepath}/{space}_stats.png")

    @staticmethod
    def make_stats_table(stats_dict, index_list, savepath):
        """Write one HTML table of the stats of all plugins, grouped by space.

        :param stats_dict: ``{space: {stat: [value per plugin]}}`` as built by ``_get_dicts_for_graphs``.
        :param index_list: ``{space: [plugin label]}`` aligned with the lists in *stats_dict*.
        :param savepath: path of the HTML file to write.
        """
        p_stats = pd.DataFrame(stats_dict["projection"], index_list["projection"])
        r_stats = pd.DataFrame(stats_dict["reconstruction"], index_list["reconstruction"])
        all_stats = pd.concat([p_stats, r_stats], keys=["Projection", "Reconstruction"])
        all_stats.to_html(savepath)

    def make_stats_graphs(self, stats_dict, index_list, title, savepath):
        """Plot each stat against plugin number in a 3x2 grid and save the figure to *savepath*.

        A plugin whose stat is an array (several values due to parameter changes) is plotted
        at its first value, plus a vertical bar spanning the array's min..max.

        :param stats_dict: ``{stat: [value per plugin]}`` for one space; not modified.
        :param index_list: plugin labels of the form ``"<number>: <name>"``, aligned with the lists.
        :param title: figure title.
        :param savepath: output image path.
        """
        stats_df = pd.DataFrame(stats_dict, index_list)
        scalar_dict, array_plugins = self._remove_arrays(stats_dict, index_list)
        stats_df_new = pd.DataFrame(scalar_dict, index_list)

        # X ticks show only the plugin numbers (for space); the full labels go in a legend box.
        # Split on ':' rather than taking the first character so multi-digit numbers survive.
        stats_df_new.index = [self._plugin_number(label) for label in index_list]
        legend = "".join(f"{label}\n" for label in index_list)

        # The string index gives a categorical x axis, so plugin k is drawn at x = k's position.
        positions = {label: pos for pos, label in enumerate(index_list)}

        fig, axes = plt.subplots(3, 2, figsize=(11, 9), dpi=320, facecolor="lavender")
        for i, (axis, stat, colour) in enumerate(zip(axes.flat, self._stats_list, self._colours)):
            axis.plot(stats_df_new[stat], "x-", color=colour)
            # 'Error' bars for plugins with multiple values due to parameter changes.
            for plugin in array_plugins:
                values = stats_df.loc[plugin, stat]
                if not isinstance(values, np.ndarray) or not values.size:
                    continue  # this stat was missing (None) or empty for this plugin
                my_max = max(values)
                my_min = min(values)
                axis.errorbar(positions[plugin], (my_max + my_min) / 2, yerr=[(my_max - my_min) / 2], capsize=5)
            if i == 1:
                maxx = len(stats_df_new[stat]) * 1.08 - 1
                maxy = stats_df_new[stat].max()  # Series.max skips NaN from missing stats
                axis.text(maxx, maxy, legend, ha="left", va="top",
                          bbox=dict(boxstyle="round", facecolor="red", alpha=0.4))
            axis.set_title(stat.replace("_", " "))
            axis.grid(True)
        fig.suptitle(title, fontsize="x-large")
        fig.savefig(savepath, bbox_inches="tight")
        # pyplot keeps every figure alive until closed; release it now.
        plt.close(fig)

    @staticmethod
    def _plugin_number(label):
        """Return the plugin-number prefix of a ``"<number>: <name>"`` label as a string."""
        return label.split(":", 1)[0]

    @staticmethod
    def _read_stat(dataset, index):
        """Return stat number *index* from a stats dataset, or ``None`` when it is absent.

        A 1-D dataset yields a scalar; a 2-D dataset yields column *index* as an array.
        Datasets of any other rank yield ``None``.
        """
        if dataset.ndim == 1 and dataset.shape[0] > index:
            return dataset[index]
        if dataset.ndim == 2 and dataset.shape[1] > index:
            return dataset[:, index]
        return None

    @staticmethod
    def _get_dicts_for_graphs(file):
        """Collect per-plugin stats from ``file["stats"]``, grouped by space.

        Datasets whose ``pattern`` attribute matches neither space are skipped. Every
        matching dataset contributes exactly one entry to each stat list (``None`` when that
        stat is absent), so all lists stay aligned with the labels.

        :param file: open HDF5 file (or mapping) containing a ``stats`` group.
        :returns: ``(stats_dict, index_list)``. ``stats_dict[space][stat]`` is a list aligned with
            ``index_list[space]``, whose entries are ``"<key>: <plugin_name>"`` labels.
        """
        spaces = StatsUtils._pattern_dict
        stats_dict = {space: {stat: [] for stat in StatsUtils._stats_list} for space in spaces}
        index_list = {space: [] for space in spaces}

        group = file["stats"]
        for space, patterns in spaces.items():
            for key in group.keys():
                dataset = group[key]
                if dataset.attrs.get("pattern") not in patterns:
                    continue
                # Group keys are unique, so each label is appended exactly once.
                index_list[space].append(f"{key}: {dataset.attrs.get('plugin_name')}")
                for index, stat in enumerate(StatsUtils._stats_list):
                    stats_dict[space][stat].append(StatsUtils._read_stat(dataset, index))
        return stats_dict, index_list

    @staticmethod
    def _remove_arrays(stats_dict, index_list):
        """Replace array-valued stats with their first element.

        :param stats_dict: ``{stat: [value per plugin]}``; not modified.
        :param index_list: plugin labels aligned with the lists in *stats_dict*.
        :returns: ``(scalar_dict, array_plugins)``. ``scalar_dict`` is a new dict in which each
            array has been replaced by its first element (``None`` for an empty array).
            ``array_plugins`` is the set of labels that had at least one array value.
        """
        array_plugins = set()
        scalar_dict = {}
        for stat, values in stats_dict.items():
            scalars = []
            for position, value in enumerate(values):
                if isinstance(value, np.ndarray):
                    array_plugins.add(index_list[position])
                    value = value[0] if value.size else None
                scalars.append(value)
            scalar_dict[stat] = scalars
        return scalar_dict, array_plugins
111
112