DiamondLightSource /
Savu
| 1 | """ |
||
| 2 | .. module:: statistics |
||
| 3 | :platform: Unix |
||
| 4 | :synopsis: Contains and processes statistics information for each plugin. |
||
| 5 | |||
| 6 | .. moduleauthor:: Jacob Williamson <[email protected]> |
||
| 7 | |||
| 8 | """ |
||
| 9 | import logging |
||
| 10 | |||
| 11 | from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils |
||
| 12 | from savu.data.stats.stats_utils import StatsUtils |
||
| 13 | from savu.core.iterate_plugin_group_utils import check_if_in_iterative_loop |
||
| 14 | import savu.core.utils as cu |
||
| 15 | |||
| 16 | import time |
||
| 17 | import h5py as h5 |
||
| 18 | import numpy as np |
||
| 19 | import os |
||
| 20 | from mpi4py import MPI |
||
| 21 | from collections import OrderedDict |
||
| 22 | |||
class Statistics(object):
    """Contains and processes statistics information for each plugin.

    One instance exists per plugin; chain-wide state (stats and timings for
    every plugin, output paths, loop bookkeeping) lives in class attributes
    initialised once by ``_setup_class``.
    """
    # Data patterns for which slice stats can be calculated.
    _pattern_list = ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    # Plugins that never have stats calculated for them.
    _no_stats_plugins = ["BasicOperations", "Mipmap", "UnetApply"]
    _possible_stats = ("max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD", "zeros", "zeros%", "range_used")  # list of possible stats
    _volume_to_slice = {"max": "max", "min": "min", "mean": "mean", "mean_std_dev": "std_dev",
                        "median_std_dev": "std_dev", "NRMSD": ("RSS", "data_points", "max", "min"),
                        "zeros": ("zeros", "data_points"), "zeros%": ("zeros", "data_points"),
                        "range_used": ("min", "max")}  # volume stat: required slice stat(s)
    #_savers = ["Hdf5Saver", "ImageSaver", "MrcSaver", "TiffSaver", "XrfSaver"]
    _has_setup = False  # flipped to True once _setup_class has run
||
| 33 | |||
| 34 | |||
| 35 | def __init__(self): |
||
| 36 | |||
| 37 | self.calc_stats = True |
||
| 38 | self.stats_before_processing = {'max': [], 'min': [], 'mean': [], 'std_dev': []} |
||
| 39 | self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []} |
||
| 40 | self._repeat_count = 0 |
||
| 41 | self.plugin = None |
||
| 42 | self.p_num = None |
||
| 43 | self.stats_key = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "RMSD"] |
||
| 44 | self.slice_stats_key = None |
||
| 45 | self.stats = None |
||
| 46 | self.GPU = False |
||
| 47 | self._iterative_group = None |
||
| 48 | |||
    def setup(self, plugin_self, pattern=None):
        """Configure this instance for the plugin that is about to run.

        :param plugin_self: The plugin object this Statistics instance belongs to.
        :param pattern: Optionally force the data pattern; if None it is derived
            from the plugin's parameters / output datasets.
        """
        if not Statistics._has_setup:
            # The first plugin in the chain initialises all class-wide state.
            self._setup_class(plugin_self.exp)
        self.plugin_name = plugin_self.name
        self.p_num = Statistics.count
        self.plugin = plugin_self
        # Normalise the requested stats and derive the slice stats they need.
        self.set_stats_key(self.stats_key)
        self.stats = {stat: [] for stat in self.slice_stats_key}
        if plugin_self.name in Statistics._no_stats_plugins:
            self.calc_stats = False
        if self.calc_stats:
            self._pad_dims = []
            self._already_called = False
            if pattern is not None:
                self.pattern = pattern
            else:
                self._set_pattern_info()  # may disable calc_stats for unknown patterns
        if self.calc_stats:
            Statistics._any_stats = True
        self._setup_4d()
        self._setup_iterative()
||
| 70 | |||
    def _setup_iterative(self):
        """Detect whether this plugin runs inside an iterative loop; if it is the
        loop's first plugin, register storage for that loop's stats."""
        self._iterative_group = check_if_in_iterative_loop(Statistics.exp)
        if self._iterative_group:
            if self._iterative_group.start_index == Statistics.count:
                # First plugin of a new loop: allocate loop-wide NRMSD storage.
                Statistics._loop_counter += 1
                Statistics.loop_stats.append({"NRMSD": np.array([])})
            self.l_num = Statistics._loop_counter - 1  # index of the loop this plugin belongs to
||
| 78 | |||
| 79 | def _setup_4d(self): |
||
| 80 | try: |
||
| 81 | in_dataset, out_dataset = self.plugin.get_datasets() |
||
| 82 | if in_dataset[0].data_info["nDims"] == 4 and len(out_dataset) != 0: |
||
| 83 | self._4d = True |
||
| 84 | shape = out_dataset[0].data_info["shape"] |
||
| 85 | self._volume_total_points = 1 |
||
| 86 | for i in shape[:-1]: |
||
| 87 | self._volume_total_points *= i |
||
| 88 | else: |
||
| 89 | self._4d = False |
||
| 90 | except KeyError: |
||
| 91 | self._4d = False |
||
| 92 | |||
    @classmethod
    def _setup_class(cls, exp):
        """Sets up the statistics class for the whole plugin chain (only called once)"""
        # The experiment-level "stats" flag switches the whole mechanism on/off.
        if exp.meta_data.get("stats") == "on":
            cls._stats_flag = True
        elif exp.meta_data.get("stats") == "off":
            cls._stats_flag = False
        cls._any_stats = False  # becomes True once any plugin produces stats
        cls.exp = exp
        cls.count = 2  # plugin numbering starts at 2 — presumably 1 is the loader; TODO confirm
        cls.global_stats = {}   # plugin number -> list of stats dicts (one per instance/run)
        cls.global_times = {}   # plugin number -> accumulated runtime in seconds
        cls.loop_stats = []     # one entry per iterative loop
        cls.n_plugins = len(exp.meta_data.plugin_list.plugin_list)
        for i in range(1, cls.n_plugins + 1):
            cls.global_stats[i] = {}
            cls.global_times[i] = 0
        cls.global_residuals = {}
        cls.plugin_numbers = {}  # plugin name -> plugin number
        cls.plugin_names = {}    # plugin number -> plugin name
        cls._loop_counter = 0
        # Stats files live in a "stats" directory inside the output path.
        cls.path = exp.meta_data['out_path']
        if cls.path[-1] == '/':
            cls.path = cls.path[0:-1]
        cls.path = f"{cls.path}/stats"
        # Only one MPI process creates the directory.
        if MPI.COMM_WORLD.rank == 0:
            if not os.path.exists(cls.path):
                os.mkdir(cls.path)
        cls._has_setup = True
||
| 122 | |||
    def get_stats(self, p_num=None, stat=None, instance=-1):
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).

        :param p_num: Plugin number of the plugin whose associated stats are being fetched.
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
            E.g current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
            By default will gather stats for the current plugin.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
        :param instance: In cases where there are multiple set of stats associated with a plugin
            due to iterative loops or multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
            By default will retrieve the most recent set.
        """
        if p_num is None:
            p_num = self.p_num
        if p_num <= 0:
            # Relative lookup: offset from the current plugin number.
            try:
                p_num = self.p_num + p_num
            except TypeError:
                # self.p_num is None before setup(); fall back to the class counter.
                p_num = Statistics.count + p_num
        if instance == "all":
            # Recursively collect every (1-based) instance for this plugin.
            stats_list = [self.get_stats(p_num, stat=stat, instance=1)]
            n = 2
            while n <= len(Statistics.global_stats[p_num]):
                stats_list.append(self.get_stats(p_num, stat=stat, instance=n))
                n += 1
            return stats_list
        if instance > 0:
            instance -= 1  # convert 1-based instance to a 0-based list index
        stats_dict = Statistics.global_stats[p_num][instance]
        if stat is not None:
            return stats_dict[stat]
        else:
            return stats_dict
||
| 159 | |||
| 160 | def get_stats_from_name(self, plugin_name, n=None, stat=None, instance=-1): |
||
| 161 | """Returns stats associated with a certain plugin. |
||
| 162 | |||
| 163 | :param plugin_name: name of the plugin whose associated stats are being fetched. |
||
| 164 | :param n: In a case where there are multiple instances of **plugin_name** in the process list, |
||
| 165 | specify the nth instance. Not specifying will select the first (or only) instance. |
||
| 166 | :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'. |
||
| 167 | If left blank will return the whole dictionary of stats: |
||
| 168 | {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': } |
||
| 169 | :param instance: In cases where there are multiple set of stats associated with a plugin |
||
| 170 | due to iterative loops or multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the |
||
| 171 | stats associated with the third run of a plugin. Pass 'all' to get a list of all sets. |
||
| 172 | By default will retrieve the most recent set. |
||
| 173 | """ |
||
| 174 | name = plugin_name |
||
| 175 | if n not in (None, 0, 1): |
||
| 176 | name = name + str(n) |
||
| 177 | p_num = Statistics.plugin_numbers[name] |
||
| 178 | return self.get_stats(p_num, stat, instance) |
||
| 179 | |||
| 180 | def get_stats_from_dataset(self, dataset, stat=None, instance=-1): |
||
| 181 | """Returns stats associated with a dataset. |
||
| 182 | |||
| 183 | :param dataset: The dataset whose associated stats are being fetched. |
||
| 184 | :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'. |
||
| 185 | If left blank will return the whole dictionary of stats: |
||
| 186 | {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': } |
||
| 187 | :param instance: In cases where there are multiple set of stats associated with a dataset |
||
| 188 | due to iterative loops or multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the |
||
| 189 | stats associated with the third run of a plugin. Pass 'all' to get a list of all sets. |
||
| 190 | By default will retrieve the most recent set. |
||
| 191 | """ |
||
| 192 | stats_list = [dataset.meta_data.get("stats")] |
||
| 193 | n = 2 |
||
| 194 | while ("stats" + str(n)) in list(dataset.meta_data.get_dictionary().keys()): |
||
| 195 | stats_list.append(dataset.meta_data.get("stats" + str(n))) |
||
| 196 | n += 1 |
||
| 197 | if stat: |
||
| 198 | for i in range(len(stats_list)): |
||
| 199 | stats_list[i] = stats_list[i][stat] |
||
| 200 | if instance in (None, 0, 1): |
||
| 201 | stats = stats_list[0] |
||
| 202 | elif instance == "all": |
||
| 203 | stats = stats_list |
||
| 204 | else: |
||
| 205 | if instance >= 2: |
||
| 206 | instance -= 1 |
||
| 207 | stats = stats_list[instance] |
||
| 208 | return stats |
||
| 209 | |||
| 210 | def set_stats_key(self, stats_key): |
||
| 211 | """Changes which stats are to be calculated for the current plugin. |
||
| 212 | |||
| 213 | :param stats_key: List of stats to be calculated. |
||
| 214 | """ |
||
| 215 | valid = Statistics._possible_stats |
||
| 216 | stats_key = sorted(set(valid).intersection(stats_key), key=lambda stat: valid.index(stat)) |
||
| 217 | self.stats_key = stats_key |
||
| 218 | self.slice_stats_key = list(set(self._flatten(list(Statistics._volume_to_slice[stat] for stat in stats_key)))) |
||
| 219 | if "data_points" not in self.slice_stats_key: |
||
| 220 | self.slice_stats_key.append("data_points") # Data points is essential |
||
| 221 | |||
| 222 | def set_slice_stats(self, my_slice, base_slice=None, pad=True): |
||
| 223 | """Sets slice stats for the current slice. |
||
| 224 | |||
| 225 | :param my_slice: The slice whose stats are being set. |
||
| 226 | :param base_slice: Provide a base slice to calculate residuals from, to calculate RMSD. |
||
| 227 | :param pad: Specify whether slice is padded or not (usually can leave as True even if slice is not padded). |
||
| 228 | """ |
||
| 229 | my_slice = self._de_list(my_slice) |
||
| 230 | if 0 not in my_slice.shape: |
||
| 231 | try: |
||
| 232 | slice_stats = self.calc_slice_stats(my_slice, base_slice=base_slice, pad=pad) |
||
| 233 | except: |
||
| 234 | pass |
||
| 235 | if slice_stats is not None: |
||
| 236 | for key, value in slice_stats.items(): |
||
| 237 | self.stats[key].append(value) |
||
| 238 | if self._4d: |
||
| 239 | if sum(self.stats["data_points"]) >= self._volume_total_points: |
||
| 240 | self.set_volume_stats() |
||
| 241 | else: |
||
| 242 | self.calc_stats = False |
||
| 243 | else: |
||
| 244 | self.calc_stats = False |
||
| 245 | |||
| 246 | def calc_slice_stats(self, my_slice, base_slice=None, pad=True): |
||
| 247 | """Calculates and returns slice stats for the current slice. |
||
| 248 | |||
| 249 | :param my_slice: The slice whose stats are being calculated. |
||
| 250 | :param base_slice: Provide a base slice to calculate residuals from, to calculate RMSD. |
||
| 251 | :param pad: Specify whether slice is padded or not (usually can leave as True even if slice is not padded). |
||
| 252 | """ |
||
| 253 | if my_slice is not None: |
||
| 254 | my_slice = self._de_list(my_slice) |
||
| 255 | if pad: |
||
| 256 | my_slice = self._unpad_slice(my_slice) |
||
| 257 | slice_stats = {} |
||
| 258 | if "max" in self.slice_stats_key: |
||
| 259 | slice_stats["max"] = np.amax(my_slice).astype('float64') |
||
| 260 | if "min" in self.slice_stats_key: |
||
| 261 | slice_stats["min"] = np.amin(my_slice).astype('float64') |
||
| 262 | if "mean" in self.slice_stats_key: |
||
| 263 | slice_stats["mean"] = np.mean(my_slice) |
||
| 264 | if "std_dev" in self.slice_stats_key: |
||
| 265 | slice_stats["std_dev"] = np.std(my_slice) |
||
| 266 | if "zeros" in self.slice_stats_key: |
||
| 267 | slice_stats["zeros"] = self.calc_zeros(my_slice) |
||
| 268 | if "data_points" in self.slice_stats_key: |
||
| 269 | slice_stats["data_points"] = my_slice.size |
||
| 270 | if "RSS" in self.slice_stats_key and base_slice is not None: |
||
| 271 | base_slice = self._de_list(base_slice) |
||
| 272 | base_slice = self._unpad_slice(base_slice) |
||
| 273 | slice_stats["RSS"] = self.calc_rss(my_slice, base_slice) |
||
| 274 | if "dtype" not in self.stats: |
||
| 275 | self.stats["dtype"] = my_slice.dtype |
||
| 276 | return slice_stats |
||
| 277 | return None |
||
| 278 | |||
| 279 | @staticmethod |
||
| 280 | def calc_zeros(my_slice): |
||
| 281 | return my_slice.size - np.count_nonzero(my_slice) |
||
| 282 | |||
| 283 | @staticmethod |
||
| 284 | def calc_rss(array1, array2): # residual sum of squares |
||
| 285 | if array1.shape == array2.shape: |
||
| 286 | residuals = np.subtract(array1, array2) |
||
| 287 | rss = np.sum(residuals.flatten() ** 2) |
||
| 288 | else: |
||
| 289 | logging.debug("Cannot calculate RSS, arrays different sizes.") |
||
| 290 | rss = None |
||
| 291 | return rss |
||
| 292 | |||
| 293 | @staticmethod |
||
| 294 | def rmsd_from_rss(rss, n): |
||
| 295 | return np.sqrt(rss/n) |
||
| 296 | |||
| 297 | def calc_rmsd(self, array1, array2): |
||
| 298 | if array1.shape == array2.shape: |
||
| 299 | rss = self.calc_rss(array1, array2) |
||
| 300 | rmsd = self.rmsd_from_rss(rss, array1.size) |
||
| 301 | else: |
||
| 302 | logging.error("Cannot calculate RMSD, arrays different sizes.") |
||
| 303 | rmsd = None |
||
| 304 | return rmsd |
||
| 305 | |||
| 306 | def calc_stats_residuals(self, stats_before, stats_after): # unused |
||
| 307 | residuals = {'max': None, 'min': None, 'mean': None, 'std_dev': None} |
||
| 308 | for key in list(residuals.keys()): |
||
| 309 | residuals[key] = stats_after[key] - stats_before[key] |
||
| 310 | return residuals |
||
| 311 | |||
| 312 | def set_stats_residuals(self, residuals): # unused |
||
| 313 | self.residuals['max'].append(residuals['max']) |
||
| 314 | self.residuals['min'].append(residuals['min']) |
||
| 315 | self.residuals['mean'].append(residuals['mean']) |
||
| 316 | self.residuals['std_dev'].append(residuals['std_dev']) |
||
| 317 | |||
| 318 | def calc_volume_stats(self, slice_stats): |
||
| 319 | """Calculates and returns volume-wide stats from slice-wide stats. |
||
| 320 | |||
| 321 | :param slice_stats: The slice-wide stats that the volume-wide stats are calculated from. |
||
| 322 | """ |
||
| 323 | slice_stats = slice_stats |
||
| 324 | volume_stats = {} |
||
| 325 | if "max" in self.stats_key: |
||
| 326 | volume_stats["max"] = max(slice_stats["max"]) |
||
| 327 | if "min" in self.stats_key: |
||
| 328 | volume_stats["min"] = min(slice_stats["min"]) |
||
| 329 | if "mean" in self.stats_key: |
||
| 330 | volume_stats["mean"] = np.mean(slice_stats["mean"]) |
||
| 331 | if "mean_std_dev" in self.stats_key: |
||
| 332 | volume_stats["mean_std_dev"] = np.mean(slice_stats["std_dev"]) |
||
| 333 | if "median_std_dev" in self.stats_key: |
||
| 334 | volume_stats["median_std_dev"] = np.median(slice_stats["std_dev"]) |
||
| 335 | if "NRMSD" in self.stats_key and None not in slice_stats["RSS"]: |
||
| 336 | total_rss = sum(slice_stats["RSS"]) |
||
| 337 | n = sum(slice_stats["data_points"]) |
||
| 338 | RMSD = self.rmsd_from_rss(total_rss, n) |
||
| 339 | the_range = volume_stats["max"] - volume_stats["min"] |
||
| 340 | NRMSD = RMSD / the_range # normalised RMSD (dividing by the range) |
||
| 341 | volume_stats["NRMSD"] = NRMSD |
||
| 342 | if "zeros" in self.stats_key: |
||
| 343 | volume_stats["zeros"] = sum(slice_stats["zeros"]) |
||
| 344 | if "zeros%" in self.stats_key: |
||
| 345 | volume_stats["zeros%"] = (volume_stats["zeros"] / sum(slice_stats["data_points"])) * 100 |
||
| 346 | if "range_used" in self.stats_key: |
||
| 347 | my_range = volume_stats["max"] - volume_stats["min"] |
||
| 348 | if "int" in str(self.stats["dtype"]): |
||
| 349 | possible_max = np.iinfo(self.stats["dtype"]).max |
||
| 350 | possible_min = np.iinfo(self.stats["dtype"]).min |
||
| 351 | self.stats["possible_max"] = possible_max |
||
| 352 | self.stats["possible_min"] = possible_min |
||
| 353 | elif "float" in str(self.stats["dtype"]): |
||
| 354 | possible_max = np.finfo(self.stats["dtype"]).max |
||
| 355 | possible_min = np.finfo(self.stats["dtype"]).min |
||
| 356 | self.stats["possible_max"] = possible_max |
||
| 357 | self.stats["possible_min"] = possible_min |
||
| 358 | possible_range = possible_max - possible_min |
||
|
0 ignored issues
–
show
introduced
by
Loading history...
|
|||
| 359 | volume_stats["range_used"] = (my_range / possible_range) * 100 |
||
| 360 | return volume_stats |
||
| 361 | |||
    def _set_loop_stats(self):
        """Append the NRMSD between the two iterating datasets to this loop's
        stats. (Currently unused.)"""
        # NEED TO CHANGE THIS - MUST USE SLICES (unused)
        data_obj1 = list(self._iterative_group._ip_data_dict["iterating"].keys())[0]
        data_obj2 = self._iterative_group._ip_data_dict["iterating"][data_obj1]
        RMSD = self.calc_rmsd(data_obj1.data, data_obj2.data)
        # Range of this plugin's stats for the current iteration, used to
        # normalise the RMSD.
        the_range = self.get_stats(self.p_num, stat="max", instance=self._iterative_group._ip_iteration) -\
            self.get_stats(self.p_num, stat="min", instance=self._iterative_group._ip_iteration)
        NRMSD = RMSD/the_range
        Statistics.loop_stats[self.l_num]["NRMSD"] = np.append(Statistics.loop_stats[self.l_num]["NRMSD"], NRMSD)
||
| 371 | |||
    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        stats = self.stats
        comm = self.plugin.get_communicator()
        # Merge per-process slice stats so every process sees the full set.
        combined_stats = self._combine_mpi_stats(stats, comm=comm)
        if not self.p_num:
            self.p_num = Statistics.count
        p_num = self.p_num
        name = self.plugin_name
        i = 2
        # Register a unique name when the same plugin appears more than once in
        # the process list; iterative re-runs (iteration > 0) keep their name.
        if not self._iterative_group:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1
        elif self._iterative_group._ip_iteration == 0:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1

        if p_num not in list(Statistics.plugin_names.keys()):
            Statistics.plugin_names[p_num] = name
            Statistics.plugin_numbers[name] = p_num
        if len(combined_stats['max']) != 0:
            # Only produce volume stats when at least one slice contributed.
            stats_dict = self.calc_volume_stats(combined_stats)
            Statistics.global_residuals[p_num] = {}
            #before_processing = self.calc_volume_stats(self.stats_before_processing)
            #for key in list(before_processing.keys()):
            #    Statistics.global_residuals[p_num][key] = Statistics.global_stats[p_num][key] - before_processing[key]

            if len(Statistics.global_stats[p_num]) == 0:
                Statistics.global_stats[p_num] = [stats_dict]
            else:
                Statistics.global_stats[p_num].append(stats_dict)

            self._link_stats_to_datasets(stats_dict, self._iterative_group)
            self._write_stats_to_file(p_num, comm=comm)
        self._already_called = True
        self._repeat_count += 1
        if self._iterative_group or self._4d:
            # Reset the slice-stat accumulators for the next iteration/volume.
            self.stats = {stat: [] for stat in self.slice_stats_key}
||
| 414 | |||
    def start_time(self):
        """Called at the start of a plugin."""
        # Wall-clock start time; paired with stop_time().
        self.t0 = time.time()
||
| 418 | |||
    def stop_time(self):
        """Called at the end of a plugin."""
        self.t1 = time.time()
        # Elapsed wall-clock time in seconds, to one decimal place.
        elapsed = round(self.t1 - self.t0, 1)
        if self._stats_flag and self.calc_stats:
            self.set_time(elapsed)
||
| 425 | |||
    def set_time(self, seconds):
        """Sets time taken for plugin to complete."""
        Statistics.global_times[self.p_num] += seconds  # Gives total time for a plugin in a loop
        #print(f"{self.p_num}, {seconds}")
        comm = self.plugin.get_communicator()
        try:
            # Accessing .rank merely probes that the communicator is usable;
            # the value itself is not needed.
            rank = comm.rank
        except (MPI.Exception, AttributeError):  # Sometimes get_communicator() returns an invalid communicator.
            comm = MPI.COMM_WORLD  # So using COMM_WORLD in this case.
        self._write_times_to_file(comm)
||
| 436 | |||
| 437 | def _combine_mpi_stats(self, slice_stats, comm=MPI.COMM_WORLD): |
||
| 438 | """Combines slice stats from different processes, so volume stats can be calculated. |
||
| 439 | |||
| 440 | :param slice_stats: slice stats (each process will have a different set). |
||
| 441 | :param comm: MPI communicator being used. |
||
| 442 | """ |
||
| 443 | combined_stats_list = comm.allgather(slice_stats) |
||
| 444 | combined_stats = {stat: [] for stat in self.slice_stats_key} |
||
| 445 | for single_stats in combined_stats_list: |
||
| 446 | for key in self.slice_stats_key: |
||
| 447 | combined_stats[key] += single_stats[key] |
||
| 448 | return combined_stats |
||
| 449 | |||
| 450 | def _array_to_dict(self, stats_array, key_list=None): |
||
| 451 | """Converts an array of stats to a dictionary of stats. |
||
| 452 | |||
| 453 | :param stats_array: Array of stats to be converted. |
||
| 454 | :param key_list: List of keys indicating the names of the stats in the stats_array. |
||
| 455 | """ |
||
| 456 | if key_list is None: |
||
| 457 | key_list = self.stats_key |
||
| 458 | stats_dict = {} |
||
| 459 | for i, value in enumerate(stats_array): |
||
| 460 | stats_dict[key_list[i]] = value |
||
| 461 | return stats_dict |
||
| 462 | |||
| 463 | def _dict_to_array(self, stats_dict): |
||
| 464 | """Converts stats dict into a numpy array (keys will be lost). |
||
| 465 | |||
| 466 | :param stats_dict: dictionary of stats. |
||
| 467 | """ |
||
| 468 | return np.array(list(stats_dict.values())) |
||
| 469 | |||
    def _broadcast_gpu_stats(self, gpu_processes, process):
        """During GPU plugins, most processes are unused, and don't have access to stats.
        This method shares stats between processes so all have access to stats.

        :param gpu_processes: List that determines whether a process is a GPU process.
        :param process: Process number.
        """
        p_num = self.p_num
        # Rank 0 holds the authoritative stats; broadcast them to everyone.
        Statistics.global_stats[p_num] = MPI.COMM_WORLD.bcast(Statistics.global_stats[p_num], root=0)
        if not gpu_processes[process]:
            # Non-GPU processes did not run the plugin, so they link the
            # received stats to their datasets here.
            if len(Statistics.global_stats[p_num]) != 0:
                for stats_dict in Statistics.global_stats[p_num]:
                    self._link_stats_to_datasets(stats_dict, self._iterative_group)
||
| 483 | |||
| 484 | def _set_pattern_info(self): |
||
| 485 | """Gathers information about the pattern of the data in the current plugin.""" |
||
| 486 | out_datasets = self.plugin.get_out_datasets() |
||
| 487 | if len(out_datasets) == 0: |
||
| 488 | self.calc_stats = False |
||
| 489 | try: |
||
| 490 | self.pattern = self.plugin.parameters['pattern'] |
||
| 491 | if self.pattern == None: |
||
| 492 | raise KeyError |
||
| 493 | except KeyError: |
||
| 494 | if not out_datasets: |
||
| 495 | self.pattern = None |
||
| 496 | else: |
||
| 497 | patterns = out_datasets[0].get_data_patterns() |
||
| 498 | for pattern in patterns: |
||
| 499 | if 1 in patterns.get(pattern)["slice_dims"]: |
||
| 500 | self.pattern = pattern |
||
| 501 | break |
||
| 502 | self.pattern = None |
||
| 503 | if self.pattern not in Statistics._pattern_list: |
||
| 504 | self.calc_stats = False |
||
| 505 | |||
| 506 | def _link_stats_to_datasets(self, stats_dict, iterative=False): |
||
| 507 | """Links the volume wide statistics to the output dataset(s). |
||
| 508 | |||
| 509 | :param stats_dict: Dictionary of stats being linked. |
||
| 510 | :param iterative: boolean indicating if the plugin is iterative or not. |
||
| 511 | """ |
||
| 512 | out_dataset = self.plugin.get_out_datasets()[0] |
||
| 513 | my_dataset = out_dataset |
||
| 514 | if iterative: |
||
| 515 | if "itr_clone" in out_dataset.group_name: |
||
| 516 | my_dataset = list(iterative._ip_data_dict["iterating"].keys())[0] |
||
| 517 | n_datasets = self.plugin.nOutput_datasets() |
||
| 518 | |||
| 519 | i = 2 |
||
| 520 | group_name = "stats" |
||
| 521 | while group_name in list(my_dataset.meta_data.get_dictionary().keys()): |
||
| 522 | group_name = f"stats{i}" # If more than one set of stats for a plugin (such as iterative plugin) |
||
| 523 | i += 1 # the groups will be named stats, stats2, stats3 etc. |
||
| 524 | for key, value in stats_dict.items(): |
||
| 525 | my_dataset.meta_data.set([group_name, key], value) |
||
| 526 | |||
    def _write_stats_to_file(self, p_num=None, plugin_name=None, comm=MPI.COMM_WORLD):
        """Writes stats to a h5 file. This file is used to create figures and tables from the stats.

        :param p_num: The plugin number of the plugin the stats belong to (usually left as None except
            for special cases).
        :param plugin_name: Same as above (but for the name of the plugin).
        :param comm: The MPI communicator the plugin is using.
        """
        if p_num is None:
            p_num = self.p_num
        if plugin_name is None:
            plugin_name = self.plugin_names[p_num]
        path = Statistics.path
        filename = f"{path}/stats.h5"
        # Stack every stats instance for this plugin into one array (row per instance).
        stats_dict = self.get_stats(p_num, instance="all")
        stats_array = self._dict_to_array(stats_dict[0])
        stats_key = list(stats_dict[0].keys())
        for i, my_dict in enumerate(stats_dict):
            if i != 0:
                stats_array = np.vstack([stats_array, self._dict_to_array(my_dict)])
        self.hdf5 = Hdf5Utils(self.exp)
        self.exp._barrier(communicator=comm)
        # Only rank 0 writes; the surrounding barriers keep all ranks in step.
        if comm.rank == 0:
            with h5.File(filename, "a") as h5file:
                group = h5file.require_group("stats")
                if stats_array.shape != (0,):
                    if str(p_num) in list(group.keys()):
                        del group[str(p_num)]  # overwrite stats from a previous run
                    dataset = group.create_dataset(str(p_num), shape=stats_array.shape, dtype=stats_array.dtype)
                    dataset[::] = stats_array[::]
                    dataset.attrs.create("plugin_name", plugin_name)
                    dataset.attrs.create("pattern", self.pattern)
                    dataset.attrs.create("stats_key", stats_key)
                if self._iterative_group:
                    l_stats = Statistics.loop_stats[self.l_num]
                    group1 = h5file.require_group("iterative")
                    # Loop-wide stats are written once: by the loop's final
                    # plugin on the final iteration.
                    if self._iterative_group._ip_iteration == self._iterative_group._ip_fixed_iterations - 1\
                            and self.p_num == self._iterative_group.end_index:
                        dataset1 = group1.create_dataset(str(self.l_num), shape=l_stats["NRMSD"].shape, dtype=l_stats["NRMSD"].dtype)
                        dataset1[::] = l_stats["NRMSD"][::]
                        loop_plugins = []
                        for i in range(self._iterative_group.start_index, self._iterative_group.end_index + 1):
                            if i in list(self.plugin_names.keys()):
                                loop_plugins.append(self.plugin_names[i])
                        dataset1.attrs.create("loop_plugins", loop_plugins)
                        # NOTE(review): this attribute is set on "dataset" (the
                        # per-plugin stats dataset), not "dataset1" — possibly
                        # intended for dataset1; confirm.
                        dataset.attrs.create("n_loop_plugins", len(loop_plugins))
        self.exp._barrier(communicator=comm)
||
| 574 | |||
    def _write_times_to_file(self, comm):
        """Writes times into the file containing all the stats."""
        p_num = self.p_num
        plugin_name = self.plugin_name
        path = Statistics.path
        filename = f"{path}/stats.h5"
        time = Statistics.global_times[p_num]  # note: shadows the "time" module locally
        self.hdf5 = Hdf5Utils(self.exp)
        # Only rank 0 writes; the plugin's dataset must already exist
        # (created by _write_stats_to_file).
        if comm.rank == 0:
            with h5.File(filename, "a") as h5file:
                group = h5file.require_group("stats")
                dataset = group[str(p_num)]
                dataset.attrs.create("time", time)
||
| 588 | |||
    def write_slice_stats_to_file(self, slice_stats=None, p_num=None, comm=MPI.COMM_WORLD):
        """Writes slice statistics to a h5 file. Placed in the stats folder in the output directory. Currently unused."""
        if not slice_stats:
            slice_stats = self.stats
        if not p_num:
            p_num = self.count
            plugin_name = self.plugin_name
        else:
            plugin_name = self.plugin_names[p_num]
        combined_stats = self._combine_mpi_stats(slice_stats)
        slice_stats_arrays = {}
        datasets = {}
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{plugin_name}.h5"
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        # Parallel (mpio) write: every rank participates in file/dataset creation.
        with h5.File(filename, "a", driver="mpio", comm=comm) as h5file:
            i = 2
            group_name = "/stats"
            # Repeat calls create groups /stats, /stats2, /stats3, ...
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            for key in list(combined_stats.keys()):
                slice_stats_arrays[key] = np.array(combined_stats[key])
                datasets[key] = self.hdf5.create_dataset_nofill(group, key, (len(slice_stats_arrays[key]),), slice_stats_arrays[key].dtype)
                datasets[key][::] = slice_stats_arrays[key]
||
| 615 | |||
    def _unpad_slice(self, my_slice):
        """If data is padded in the slice dimension, removes this pad."""
        out_datasets = self.plugin.get_out_datasets()
        if len(out_datasets) == 1:
            out_dataset = out_datasets[0]
        else:
            # NOTE(review): if no dataset's patterns contain self.pattern,
            # out_dataset is left unbound and the line after the loop raises
            # UnboundLocalError — confirm a match is always guaranteed here.
            for dataset in out_datasets:
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
                    out_dataset = dataset
                    break
        slice_dims = out_dataset.get_slice_dimensions()
        # The un-pad slice list is computed once per plugin run (pcount == 0)
        # and cached on the instance for subsequent slices.
        if self.plugin.pcount == 0:
            self._slice_list, self._pad = self._get_unpadded_slice_list(my_slice, slice_dims)
        if self._pad:
            #for slice_dim in slice_dims:
            slice_dim = slice_dims[0]
            # Move the slice dimension to the front, crop it, then move it back.
            temp_slice = np.swapaxes(my_slice, 0, slice_dim)
            temp_slice = temp_slice[self._slice_list[slice_dim]]
            my_slice = np.swapaxes(temp_slice, 0, slice_dim)
        return my_slice
||
| 636 | |||
    def _get_unpadded_slice_list(self, my_slice, slice_dims):
        """Creates slice object(s) to un-pad slices in the slice dimension(s)."""
        slice_list = list(self.plugin.slice_list[0])
        pad = False
        if len(slice_list) == len(my_slice.shape):
            i = slice_dims[0]  # only the first slice dimension is handled
            slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
            if slice_width < my_slice.shape[i]:
                pad = True
                pad_width = (my_slice.shape[i] - slice_width) // 2  # Assuming symmetrical padding
                # NOTE(review): this keeps a single central element (width 1)
                # rather than the full unpadded width — confirm intended.
                slice_list[i] = slice(pad_width, pad_width + 1, 1)
            return tuple(slice_list), pad
        else:
            return self.plugin.slice_list[0], pad
||
| 651 | |||
| 652 | def _flatten(self, l): |
||
| 653 | """Function to flatten nested lists.""" |
||
| 654 | out = [] |
||
| 655 | for item in l: |
||
| 656 | if isinstance(item, (list, tuple)): |
||
| 657 | out.extend(self._flatten(item)) |
||
| 658 | else: |
||
| 659 | out.append(item) |
||
| 660 | return out |
||
| 661 | |||
| 662 | def _de_list(self, my_slice): |
||
| 663 | """If the slice is in a list, remove it from that list (takes 0th element).""" |
||
| 664 | if type(my_slice) == list: |
||
| 665 | if len(my_slice) != 0: |
||
| 666 | my_slice = my_slice[0] |
||
| 667 | my_slice = self._de_list(my_slice) |
||
| 668 | return my_slice |
||
| 669 | |||
| 670 | @classmethod |
||
| 671 | def _count(cls): |
||
| 672 | cls.count += 1 |
||
| 673 | |||
| 674 | @classmethod |
||
| 675 | def _post_chain(cls): |
||
| 676 | """Called after all plugins have run.""" |
||
| 677 | if cls._any_stats & cls._stats_flag: |
||
| 678 | stats_utils = StatsUtils() |
||
| 679 | stats_utils.generate_figures(f"{cls.path}/stats.h5", cls.path) |
||
| 680 |