| 1 |  |  | import matplotlib.pyplot as plt | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | from matplotlib.ticker import MaxNLocator | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | import pandas as pd | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | import h5py as h5 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | class StatsUtils(object): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
    # Pattern names grouped under the keys "projection" and "reconstruction";
    # the same two keys are used for the stats dictionaries and index lists
    # handled by the methods of this class.
    _pattern_dict = {"projection": ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "4D_SCAN", "SINOMOVIE"],
                     "reconstruction": ["VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D"]}
    # Names of the statistics; make_stats_graphs indexes this list to choose
    # the statistic (DataFrame column) drawn in each subplot.
    _stats_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD"]

    # Runs once, when the class body is executed: restricts matplotlib's own
    # logging to WARNING and above.
    plt.set_loglevel('WARNING')
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |     def generate_figures(self, filepath, savepath): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |         f = h5.File(filepath, 'r') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |         stats_dict, index_list = self._get_dicts_for_graphs(f) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |         loop_stats, loop_plugins = self._get_dicts_for_loops(f) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |         f.close() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |         #self.make_loop_graphs(loop_stats, loop_plugins, savepath) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |         table_index_list = index_list | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |         for i in range(len(loop_plugins)): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |             for space in list(table_index_list.keys()): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |                 for j, plugin in enumerate(table_index_list[space]): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |                     for loop_plugin in loop_plugins[i]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |                         if loop_plugin == plugin[3::]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |                             table_index_list[space][j] = f"{table_index_list[space][j]} (loop{i})" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |         self.make_stats_table(stats_dict, table_index_list, f"{savepath}/stats_table.html") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |         if len(stats_dict["projection"]["max"]): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |             self.make_stats_graphs(stats_dict["projection"], index_list["projection"], "Projection Stats", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |                                    f"{savepath}/projection_stats.png") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |         if len(stats_dict["reconstruction"]["max"]): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |             self.make_stats_graphs(stats_dict["reconstruction"], index_list["reconstruction"], "Reconstruction Stats", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |                                    f"{savepath}/reconstruction_stats.png") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |     @staticmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |     def make_stats_table(stats_dict, index_list, savepath): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |         p_stats = pd.DataFrame(stats_dict["projection"], index_list["projection"]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |         r_stats = pd.DataFrame(stats_dict["reconstruction"], index_list["reconstruction"]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |         all_stats = pd.concat([p_stats, r_stats], keys=["Projection", "Reconstruction"]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |         all_stats.to_html(savepath)  # create table of stats for all plugins | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |     def make_loop_graphs(self, loop_stats, loop_plugins, savepath): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |         for i in range(len(loop_stats)): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |             y = loop_stats[i]["NRMSD"] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |             #x = list(range(1, len(loop_stats[i]["RMSD"]) + 1)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |             x = [None]*len(y) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |             for j in range(len(loop_stats[i]["NRMSD"])): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |                 x[j] = f"{j}-{j+1}" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |             ax = plt.figure(figsize=(11, 9), dpi=320).gca() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |             ax.xaxis.set_major_locator(MaxNLocator(integer=True)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |             #ax.locator_params(axis='x', nbins=j + 1) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |             ax.grid(True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |             plt.plot(x, y) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |             maxx = j | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |             maxy = max(y) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |             plt.title("NRMSD over loop 0") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |             text = f"Loop 0 iterates {maxx + 2} times over:\n" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |             for plugin in loop_plugins[i]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |                 text += f"{plugin}\n" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |             plt.xlabel("Iteration") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |             plt.ylabel("NRMSD") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |             plt.text(maxx, maxy, text, ha="right", va="top", bbox=dict(boxstyle="round", facecolor="red", alpha=0.4)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |             plt.savefig(f"{savepath}/loop_stats{i}.png", bbox_inches="tight") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |     def make_stats_graphs(self, stats_dict, index_list, title, savepath): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |         stats_df = pd.DataFrame(stats_dict, index_list) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |         stats_dict, array_plugins = self._remove_arrays(stats_dict, index_list) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |         stats_df_new = pd.DataFrame(stats_dict, index_list) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |         colours = ["red", "blue", "green", "black", "purple", "brown"]  #max, min, mean, mean std dev, median std dev, NRMSD | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |         new_index = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |         legend = "" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |         for ind in stats_df_new.index: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |             new_index.append(ind[0])  # change x ticks to only be plugin numbers rather than names (for space) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |             legend += f"{ind}\n"  # This will form a key showing the plugin names corresponding to plugin numbers | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |         stats_df_new.index = new_index | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |         fig, ax = plt.subplots(3, 2, figsize=(11, 9), dpi=320, facecolor="lavender") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |         i = 0 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |         for row in ax: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |             for axis in row: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |                 stat = self._stats_list[i] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |                 axis.plot(stats_df_new[stat], "x-", color=colours[i]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |                 for plugin in array_plugins:  # adding 'error' bars for plugins with multiple values | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |                     if stats_df[stat][plugin] is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |                         if not np.isnan(stats_df[stat][plugin]).any(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |                             my_max = max(stats_df[stat][plugin]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |                             my_min = min(stats_df[stat][plugin]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |                             middle = (my_max + my_min) / 2 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |                             my_range = my_max - my_min | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |                             axis.errorbar(list(stats_df_new.index).index(plugin[0]), middle, yerr=[my_range / 2], capsize=5) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |                 if i == 1: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |                     maxx = len(stats_df_new[stat]) * 1.08 - 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |                     maxy = max(stats_df_new[stat]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |                     axis.text(maxx, maxy, legend, ha="left", va="top", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |                               bbox=dict(boxstyle="round", facecolor="red", alpha=0.4)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |                 stat = stat.replace("_", " ") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |                 axis.set_title(stat) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |                 axis.grid(True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |                 i += 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |         fig.suptitle(title, fontsize="x-large") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |         plt.savefig(savepath, bbox_inches="tight") | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 115 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                @staticmethod
                def _get_dicts_for_graphs(file):
                    """Collect per-plugin statistics from ``file["stats"]`` for plotting.

                    Every entry of the ``stats`` group whose ``pattern`` attribute is listed
                    in ``StatsUtils._pattern_dict[space]`` contributes one label of the form
                    ``"<key>: <plugin_name>"`` and exactly one value to each statistic list
                    of that space, so labels and values always stay aligned.

                    :param file: object indexable by ``"stats"`` whose entries expose
                        ``attrs``, ``ndim`` and ``shape`` (presumably an open ``h5py.File``
                        - confirm against callers).
                    :returns: tuple ``(stats_dict, index_list)``.
                        ``stats_dict[space][stat]`` is a list with one item per matching
                        entry: a scalar for a 1-D entry, a column array for a 2-D entry,
                        or ``None`` when the entry does not hold that statistic.
                        ``index_list[space]`` is the list of labels in the same order.
                    """
                    spaces = ("projection", "reconstruction")
                    # Position in this tuple is the position of the statistic in each entry.
                    stat_names = ("max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD")

                    stats_dict = {space: {stat: [] for stat in stat_names} for space in spaces}
                    index_list = {space: [] for space in spaces}

                    group = file["stats"]
                    for space in spaces:
                        patterns = StatsUtils._pattern_dict[space]
                        for key in group.keys():
                            # Look the entry up once instead of once per statistic.
                            dataset = group[key]
                            if dataset.attrs.get("pattern") not in patterns:
                                continue

                            label = f"{key}: {dataset.attrs.get('plugin_name')}"
                            if label not in index_list[space]:
                                index_list[space].append(label)

                            ndim = dataset.ndim
                            # The statistics run along the last axis. Entries of any other
                            # rank hold no usable statistics, so they yield None for every
                            # stat; this keeps each list the same length as the label list.
                            stored = dataset.shape[-1] if ndim in (1, 2) else 0
                            for index, stat in enumerate(stat_names):
                                if index >= stored:
                                    value = None
                                elif ndim == 1:
                                    value = dataset[index]
                                else:
                                    value = dataset[:, index]
                                stats_dict[space][stat].append(value)
                    return stats_dict, index_list
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 145 |  |  |     @staticmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 146 |  |  |     def _get_dicts_for_loops(file): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |         if "iterative" in list(file.keys()): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 148 |  |  |             group = file["iterative"] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 149 |  |  |             loop_stats = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 150 |  |  |             loop_plugins = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  |             for key in list(group.keys()): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 152 |  |  |                 loop_stats.append({"NRMSD": list(group[key])}) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 153 |  |  |                 loop_plugins.append(group[key].attrs.get("loop_plugins")) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 154 |  |  |             return loop_stats, loop_plugins | 
            
                                                                                                            
                            
            
                                    
            
            
                | 155 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  |             return [], [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 157 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 158 |  |  |     @staticmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 159 |  |  |     def _remove_arrays(stats_dict, index_list): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 160 |  |  |         array_plugins = set(()) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 161 |  |  |         for stat in list(stats_dict.keys()): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 162 |  |  |             for index, value in enumerate(stats_dict[stat]): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 163 |  |  |                 if isinstance(value, np.ndarray): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 164 |  |  |                     stats_dict[stat][index] = stats_dict[stat][index][0] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |                     array_plugins.add(index_list[index]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 166 |  |  |         return stats_dict, array_plugins | 
            
                                                                                                            
                            
            
                                    
            
            
                | 167 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 168 |  |  |  | 
            
                                                        
            
                                    
            
            
                | 169 |  |  |  |