"""
.. module:: statistics
   :platform: Unix
   :synopsis: Contains and processes statistics information for each plugin.

.. moduleauthor:: Jacob Williamson <[email protected]>

"""
import logging

from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
from savu.data.stats.stats_utils import StatsUtils
from savu.core.iterate_plugin_group_utils import check_if_in_iterative_loop
import savu.core.utils as cu

import time
import h5py as h5
import numpy as np
import os
from mpi4py import MPI
from collections import OrderedDict


class Statistics(object):
    _pattern_list = ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "VOLUME_YZ",
                     "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    _no_stats_plugins = ["BasicOperations", "Mipmap", "UnetApply"]
    _possible_stats = ("max", "min", "mean", "mean_std_dev", "median_std_dev",
                       "NRMSD", "zeros", "zeros%", "range_used")  # list of possible stats
    _volume_to_slice = {"max": "max", "min": "min", "mean": "mean", "mean_std_dev": "std_dev",
                        "median_std_dev": "std_dev", "NRMSD": ("RSS", "data_points", "max", "min"),
                        "zeros": ("zeros", "data_points"), "zeros%": ("zeros", "data_points"),
                        "range_used": ("min", "max")}  # volume stat: required slice stat(s)
    #_savers = ["Hdf5Saver", "ImageSaver", "MrcSaver", "TiffSaver", "XrfSaver"]
    _has_setup = False
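
    # Example (illustrative): requesting the volume stat "NRMSD" means each slice
    # must record "RSS", "data_points", "max" and "min", per _volume_to_slice
    # above; set_stats_key() derives that slice-stat list automatically.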

    def __init__(self):
        self.calc_stats = True
        self.stats_before_processing = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
        self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
        self._repeat_count = 0
        self.plugin = None
        self.p_num = None
        # "NRMSD" (not "RMSD") is the key listed in _possible_stats, so the
        # default key list uses it; an invalid key would be silently dropped
        # by set_stats_key().
        self.stats_key = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD"]
        self.slice_stats_key = None
        self.stats = None
        self.GPU = False
        self._iterative_group = None

    def setup(self, plugin_self, pattern=None):
        if not Statistics._has_setup:
            self._setup_class(plugin_self.exp)
        self.plugin_name = plugin_self.name
        self.p_num = Statistics.count
        self.plugin = plugin_self
        self.set_stats_key(self.stats_key)
        self.stats = {stat: [] for stat in self.slice_stats_key}
        if plugin_self.name in Statistics._no_stats_plugins:
            self.calc_stats = False
        if self.calc_stats:
            self._pad_dims = []
            self._already_called = False
            if pattern is not None:
                self.pattern = pattern
            else:
                self._set_pattern_info()
        if self.calc_stats:
            Statistics._any_stats = True
        self._setup_4d()
        self._setup_iterative()

    def _setup_iterative(self):
        self._iterative_group = check_if_in_iterative_loop(Statistics.exp)
        if self._iterative_group:
            if self._iterative_group.start_index == Statistics.count:
                Statistics._loop_counter += 1
                Statistics.loop_stats.append({"NRMSD": np.array([])})
            self.l_num = Statistics._loop_counter - 1

    def _setup_4d(self):
        try:
            in_dataset, out_dataset = self.plugin.get_datasets()
            if in_dataset[0].data_info["nDims"] == 4:
                self._4d = True
                shape = out_dataset[0].data_info["shape"]
                self._volume_total_points = 1
                for i in shape[:-1]:
                    self._volume_total_points *= i
            else:
                self._4d = False
        except KeyError:
            self._4d = False

    @classmethod
    def _setup_class(cls, exp):
        """Sets up the statistics class for the whole plugin chain (only called once)."""
        if exp.meta_data.get("stats") == "on":
            cls._stats_flag = True
        elif exp.meta_data.get("stats") == "off":
            cls._stats_flag = False
        cls._any_stats = False
        cls.exp = exp
        cls.count = 2
        cls.global_stats = {}
        cls.global_times = {}
        cls.loop_stats = []
        cls.n_plugins = len(exp.meta_data.plugin_list.plugin_list)
        for i in range(1, cls.n_plugins + 1):
            cls.global_stats[i] = {}
            cls.global_times[i] = 0
        cls.global_residuals = {}
        cls.plugin_numbers = {}
        cls.plugin_names = {}
        cls._loop_counter = 0
        cls.path = exp.meta_data['out_path']
        if cls.path[-1] == '/':
            cls.path = cls.path[0:-1]
        cls.path = f"{cls.path}/stats"
        if MPI.COMM_WORLD.rank == 0:
            if not os.path.exists(cls.path):
                os.mkdir(cls.path)
        cls._has_setup = True

    def get_stats(self, p_num=None, stat=None, instance=-1):
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).

        :param p_num: Plugin number of the plugin whose associated stats are being fetched.
            If p_num <= 0, it is relative to the plugin number of the current plugin being run,
            e.g. if the current plugin number is 5, p_num = -2 will return the stats of the third plugin.
            By default will gather stats for the current plugin.
        :param stat: Specify the stat parameter you want to fetch, e.g. 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
        :param instance: In cases where there are multiple sets of stats associated with a plugin
            due to iterative loops or multi-parameters, specify which set you want to retrieve,
            e.g. 3 to retrieve the stats associated with the third run of a plugin.
            Pass 'all' to get a list of all sets. By default will retrieve the most recent set.
        """
        if p_num is None:
            p_num = self.p_num
        if p_num <= 0:
            try:
                p_num = self.p_num + p_num
            except TypeError:
                p_num = Statistics.count + p_num
        if instance == "all":
            stats_list = [self.get_stats(p_num, stat=stat, instance=1)]
            n = 2
            while n <= len(Statistics.global_stats[p_num]):
                stats_list.append(self.get_stats(p_num, stat=stat, instance=n))
                n += 1
            return stats_list
        if instance > 0:
            instance -= 1
        stats_dict = Statistics.global_stats[p_num][instance]
        if stat is not None:
            return stats_dict[stat]
        else:
            return stats_dict
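
    # Example usage (illustrative; `stats_obj` is a hypothetical Statistics
    # instance attached to the plugin currently being run):
    #
    #     stats_obj.get_stats(p_num=-2, stat="max")     # max of the plugin two before this one
    #     stats_obj.get_stats(p_num=4, instance="all")  # every set of stats for plugin 4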

    def get_stats_from_name(self, plugin_name, n=None, stat=None, instance=-1):
        """Returns stats associated with a certain plugin.

        :param plugin_name: Name of the plugin whose associated stats are being fetched.
        :param n: In a case where there are multiple instances of **plugin_name** in the process list,
            specify the nth instance. Not specifying will select the first (or only) instance.
        :param stat: Specify the stat parameter you want to fetch, e.g. 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
        :param instance: In cases where there are multiple sets of stats associated with a plugin
            due to iterative loops or multi-parameters, specify which set you want to retrieve,
            e.g. 3 to retrieve the stats associated with the third run of a plugin.
            Pass 'all' to get a list of all sets. By default will retrieve the most recent set.
        """
        name = plugin_name
        if n not in (None, 0, 1):
            name = name + str(n)
        p_num = Statistics.plugin_numbers[name]
        return self.get_stats(p_num, stat, instance)

    def get_stats_from_dataset(self, dataset, stat=None, instance=-1):
        """Returns stats associated with a dataset.

        :param dataset: The dataset whose associated stats are being fetched.
        :param stat: Specify the stat parameter you want to fetch, e.g. 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
        :param instance: In cases where there are multiple sets of stats associated with a dataset
            due to iterative loops or multi-parameters, specify which set you want to retrieve,
            e.g. 3 to retrieve the stats associated with the third run of a plugin.
            Pass 'all' to get a list of all sets. By default will retrieve the most recent set.
        """
        stats_list = [dataset.meta_data.get("stats")]
        n = 2
        while ("stats" + str(n)) in list(dataset.meta_data.get_dictionary().keys()):
            stats_list.append(dataset.meta_data.get("stats" + str(n)))
            n += 1
        if stat:
            for i in range(len(stats_list)):
                stats_list[i] = stats_list[i][stat]
        if instance in (None, 0, 1):
            stats = stats_list[0]
        elif instance == "all":
            stats = stats_list
        else:
            if instance >= 2:
                instance -= 1
            stats = stats_list[instance]
        return stats
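
    # Example (illustrative; assumes `data` is a savu Dataset that has passed
    # through a plugin twice, so its metadata holds groups "stats" and "stats2"):
    #
    #     stats_obj.get_stats_from_dataset(data, stat="mean", instance=2)
    #
    # returns the mean recorded on the second pass.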

    def set_stats_key(self, stats_key):
        """Changes which stats are to be calculated for the current plugin.

        :param stats_key: List of stats to be calculated.
        """
        valid = Statistics._possible_stats
        stats_key = sorted(set(valid).intersection(stats_key), key=lambda stat: valid.index(stat))
        self.stats_key = stats_key
        self.slice_stats_key = list(set(self._flatten(list(Statistics._volume_to_slice[stat] for stat in stats_key))))
        if "data_points" not in self.slice_stats_key:
            self.slice_stats_key.append("data_points")  # data_points is essential
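
    # Example (illustrative): set_stats_key(["zeros%", "min", "max"]) stores
    # self.stats_key = ["max", "min", "zeros%"] (ordered as in _possible_stats),
    # and the per-slice stats to gather become {"max", "min", "zeros",
    # "data_points"} (unordered, as they are built through a set).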

    def set_slice_stats(self, my_slice, base_slice=None, pad=True):
        """Sets slice stats for the current slice.

        :param my_slice: The slice whose stats are being set.
        :param base_slice: Provide a base slice to calculate residuals from, to calculate RMSD.
        :param pad: Specify whether the slice is padded or not (can usually be left as True even if the slice is not padded).
        """
        my_slice = self._de_list(my_slice)
        if 0 not in my_slice.shape:
            try:
                slice_stats = self.calc_slice_stats(my_slice, base_slice=base_slice, pad=pad)
            except Exception:
                logging.exception("Could not calculate slice stats.")
                slice_stats = None  # Ensure slice_stats is bound if the calculation fails.
            if slice_stats is not None:
                for key, value in slice_stats.items():
                    self.stats[key].append(value)
                if self._4d:
                    if sum(self.stats["data_points"]) >= self._volume_total_points:
                        self.set_volume_stats()
            else:
                self.calc_stats = False
        else:
            self.calc_stats = False

    def calc_slice_stats(self, my_slice, base_slice=None, pad=True):
        """Calculates and returns slice stats for the current slice.

        :param my_slice: The slice whose stats are being calculated.
        :param base_slice: Provide a base slice to calculate residuals from, to calculate RMSD.
        :param pad: Specify whether the slice is padded or not (can usually be left as True even if the slice is not padded).
        """
        if my_slice is not None:
            my_slice = self._de_list(my_slice)
            if pad:
                my_slice = self._unpad_slice(my_slice)
            slice_stats = {}
            if "max" in self.slice_stats_key:
                slice_stats["max"] = np.amax(my_slice).astype('float64')
            if "min" in self.slice_stats_key:
                slice_stats["min"] = np.amin(my_slice).astype('float64')
            if "mean" in self.slice_stats_key:
                slice_stats["mean"] = np.mean(my_slice)
            if "std_dev" in self.slice_stats_key:
                slice_stats["std_dev"] = np.std(my_slice)
            if "zeros" in self.slice_stats_key:
                slice_stats["zeros"] = self.calc_zeros(my_slice)
            if "data_points" in self.slice_stats_key:
                slice_stats["data_points"] = my_slice.size
            if "RSS" in self.slice_stats_key and base_slice is not None:
                base_slice = self._de_list(base_slice)
                base_slice = self._unpad_slice(base_slice)
                slice_stats["RSS"] = self.calc_rss(my_slice, base_slice)
            if "dtype" not in self.stats:
                self.stats["dtype"] = my_slice.dtype
            return slice_stats
        return None

    @staticmethod
    def calc_zeros(my_slice):
        return my_slice.size - np.count_nonzero(my_slice)

    @staticmethod
    def calc_rss(array1, array2):  # residual sum of squares
        if array1.shape == array2.shape:
            residuals = np.subtract(array1, array2)
            rss = np.sum(residuals.flatten() ** 2)
        else:
            logging.debug("Cannot calculate RSS, arrays different sizes.")
            rss = None
        return rss

    @staticmethod
    def rmsd_from_rss(rss, n):
        return np.sqrt(rss / n)
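
    # Worked example (illustrative): residuals between two 2x2 slices with
    # RSS = 8.0 over n = 4 data points give rmsd_from_rss(8.0, 4) = sqrt(2) ~ 1.414.
    # calc_volume_stats() later normalises the volume-wide RMSD by the data range,
    # NRMSD = RMSD / (max - min).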

    def calc_rmsd(self, array1, array2):
        if array1.shape == array2.shape:
            rss = self.calc_rss(array1, array2)
            rmsd = self.rmsd_from_rss(rss, array1.size)
        else:
            logging.error("Cannot calculate RMSD, arrays different sizes.")
            rmsd = None
        return rmsd

    def calc_stats_residuals(self, stats_before, stats_after):  # unused
        residuals = {'max': None, 'min': None, 'mean': None, 'std_dev': None}
        for key in list(residuals.keys()):
            residuals[key] = stats_after[key] - stats_before[key]
        return residuals

    def set_stats_residuals(self, residuals):  # unused
        self.residuals['max'].append(residuals['max'])
        self.residuals['min'].append(residuals['min'])
        self.residuals['mean'].append(residuals['mean'])
        self.residuals['std_dev'].append(residuals['std_dev'])

    def calc_volume_stats(self, slice_stats):
        """Calculates and returns volume-wide stats from slice-wide stats.

        :param slice_stats: The slice-wide stats that the volume-wide stats are calculated from.
        """
        volume_stats = {}
        if "max" in self.stats_key:
            volume_stats["max"] = max(slice_stats["max"])
        if "min" in self.stats_key:
            volume_stats["min"] = min(slice_stats["min"])
        if "mean" in self.stats_key:
            volume_stats["mean"] = np.mean(slice_stats["mean"])
        if "mean_std_dev" in self.stats_key:
            volume_stats["mean_std_dev"] = np.mean(slice_stats["std_dev"])
        if "median_std_dev" in self.stats_key:
            volume_stats["median_std_dev"] = np.median(slice_stats["std_dev"])
        if "NRMSD" in self.stats_key and None not in slice_stats["RSS"]:
            total_rss = sum(slice_stats["RSS"])
            n = sum(slice_stats["data_points"])
            RMSD = self.rmsd_from_rss(total_rss, n)
            # The slice-level max/min are guaranteed by _volume_to_slice, whereas
            # the volume-level values only exist if "max"/"min" were requested.
            the_range = max(slice_stats["max"]) - min(slice_stats["min"])
            NRMSD = RMSD / the_range  # normalised RMSD (dividing by the range)
            volume_stats["NRMSD"] = NRMSD
        if "zeros" in self.stats_key:
            volume_stats["zeros"] = sum(slice_stats["zeros"])
        if "zeros%" in self.stats_key:
            volume_stats["zeros%"] = (sum(slice_stats["zeros"]) / sum(slice_stats["data_points"])) * 100
        if "range_used" in self.stats_key:
            my_range = max(slice_stats["max"]) - min(slice_stats["min"])
            if "int" in str(self.stats["dtype"]):
                possible_max = np.iinfo(self.stats["dtype"]).max
                possible_min = np.iinfo(self.stats["dtype"]).min
            elif "float" in str(self.stats["dtype"]):
                possible_max = np.finfo(self.stats["dtype"]).max
                possible_min = np.finfo(self.stats["dtype"]).min
            self.stats["possible_max"] = possible_max
            self.stats["possible_min"] = possible_min
            possible_range = possible_max - possible_min
            volume_stats["range_used"] = (my_range / possible_range) * 100
        return volume_stats
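
    # Worked example (illustrative): with slice_stats["zeros"] = [10, 30] and
    # slice_stats["data_points"] = [4000, 4000], zeros% = (40 / 8000) * 100 = 0.5.
    # For uint8 data spanning 20..220, range_used = (200 / 255) * 100 ~ 78.4.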

    def _set_loop_stats(self):
        # NEED TO CHANGE THIS - MUST USE SLICES (unused)
        data_obj1 = list(self._iterative_group._ip_data_dict["iterating"].keys())[0]
        data_obj2 = self._iterative_group._ip_data_dict["iterating"][data_obj1]
        RMSD = self.calc_rmsd(data_obj1.data, data_obj2.data)
        the_range = self.get_stats(self.p_num, stat="max", instance=self._iterative_group._ip_iteration) -\
            self.get_stats(self.p_num, stat="min", instance=self._iterative_group._ip_iteration)
        NRMSD = RMSD / the_range
        Statistics.loop_stats[self.l_num]["NRMSD"] = np.append(Statistics.loop_stats[self.l_num]["NRMSD"], NRMSD)

    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        stats = self.stats
        comm = self.plugin.get_communicator()
        combined_stats = self._combine_mpi_stats(stats, comm=comm)
        if not self.p_num:
            self.p_num = Statistics.count
        p_num = self.p_num
        name = self.plugin_name
        i = 2
        if not self._iterative_group:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1
        elif self._iterative_group._ip_iteration == 0:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1

        if p_num not in list(Statistics.plugin_names.keys()):
            Statistics.plugin_names[p_num] = name
            Statistics.plugin_numbers[name] = p_num
        if len(combined_stats['max']) != 0:
            stats_dict = self.calc_volume_stats(combined_stats)
            Statistics.global_residuals[p_num] = {}
            #before_processing = self.calc_volume_stats(self.stats_before_processing)
            #for key in list(before_processing.keys()):
            #    Statistics.global_residuals[p_num][key] = Statistics.global_stats[p_num][key] - before_processing[key]

            if len(Statistics.global_stats[p_num]) == 0:
                Statistics.global_stats[p_num] = [stats_dict]
            else:
                Statistics.global_stats[p_num].append(stats_dict)

            self._link_stats_to_datasets(stats_dict, self._iterative_group)
            self._write_stats_to_file(p_num, comm=comm)
        self._already_called = True
        self._repeat_count += 1
        if self._iterative_group or self._4d:
            self.stats = {stat: [] for stat in self.slice_stats_key}

    def start_time(self):
        """Called at the start of a plugin."""
        self.t0 = time.time()

    def stop_time(self):
        """Called at the end of a plugin."""
        self.t1 = time.time()
        elapsed = round(self.t1 - self.t0, 1)
        if self._stats_flag and self.calc_stats:
            self.set_time(elapsed)

    def set_time(self, seconds):
        """Sets time taken for plugin to complete."""
        Statistics.global_times[self.p_num] += seconds  # Gives total time for a plugin in a loop
        #print(f"{self.p_num}, {seconds}")
        comm = self.plugin.get_communicator()
        try:
            rank = comm.rank
        except MPI.Exception:  # Sometimes get_communicator() returns an invalid communicator.
            comm = MPI.COMM_WORLD  # So using COMM_WORLD in this case.
        self._write_times_to_file(comm)

    def _combine_mpi_stats(self, slice_stats, comm=MPI.COMM_WORLD):
        """Combines slice stats from different processes, so volume stats can be calculated.

        :param slice_stats: Slice stats (each process will have a different set).
        :param comm: MPI communicator being used.
        """
        combined_stats_list = comm.allgather(slice_stats)
        combined_stats = {stat: [] for stat in self.slice_stats_key}
        for single_stats in combined_stats_list:
            for key in self.slice_stats_key:
                combined_stats[key] += single_stats[key]
        return combined_stats
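
    # Example (illustrative): with two MPI processes holding {"max": [1.0, 3.0]}
    # and {"max": [2.0]}, comm.allgather() gives both ranks
    # [{"max": [1.0, 3.0]}, {"max": [2.0]}], which is concatenated to
    # {"max": [1.0, 3.0, 2.0]} on every rank.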

    def _array_to_dict(self, stats_array, key_list=None):
        """Converts an array of stats to a dictionary of stats.

        :param stats_array: Array of stats to be converted.
        :param key_list: List of keys indicating the names of the stats in the stats_array.
        """
        if key_list is None:
            key_list = self.stats_key
        stats_dict = {}
        for i, value in enumerate(stats_array):
            stats_dict[key_list[i]] = value
        return stats_dict

    def _dict_to_array(self, stats_dict):
        """Converts a stats dict into a numpy array (keys will be lost).

        :param stats_dict: Dictionary of stats.
        """
        return np.array(list(stats_dict.values()))
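
    # Example (illustrative): _dict_to_array({"max": 7.0, "min": 0.5}) gives
    # np.array([7.0, 0.5]); _array_to_dict() reverses this when given the
    # matching key_list ["max", "min"].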

    def _broadcast_gpu_stats(self, gpu_processes, process):
        """During GPU plugins, most processes are unused and do not have access to stats.
        This method shares stats between processes so that all have access to them.

        :param gpu_processes: List that determines whether a process is a GPU process.
        :param process: Process number.
        """
        p_num = self.p_num
        Statistics.global_stats[p_num] = MPI.COMM_WORLD.bcast(Statistics.global_stats[p_num], root=0)
        if not gpu_processes[process]:
            if Statistics.global_stats[p_num].ndim == 1:
                stats_dict = self._array_to_dict(Statistics.global_stats[p_num])
                self._link_stats_to_datasets(stats_dict, self._iterative_group)
            elif Statistics.global_stats[p_num].ndim > 1:
                for stats_array in Statistics.global_stats[p_num]:
                    stats_dict = self._array_to_dict(stats_array)
                    self._link_stats_to_datasets(stats_dict, self._iterative_group)

    def _set_pattern_info(self):
        """Gathers information about the pattern of the data in the current plugin."""
        out_datasets = self.plugin.get_out_datasets()
        if len(out_datasets) == 0:
            self.calc_stats = False
        try:
            self.pattern = self.plugin.parameters['pattern']
            if self.pattern is None:
                raise KeyError
        except KeyError:
            if not out_datasets:
                self.pattern = None
            else:
                # Default to None before searching, so a matched pattern is
                # not overwritten after the loop.
                self.pattern = None
                patterns = out_datasets[0].get_data_patterns()
                for pattern in patterns:
                    if 1 in patterns.get(pattern)["slice_dims"]:
                        self.pattern = pattern
                        break
        if self.pattern not in Statistics._pattern_list:
            self.calc_stats = False

    def _link_stats_to_datasets(self, stats_dict, iterative=False):
        """Links the volume-wide statistics to the output dataset(s).

        :param stats_dict: Dictionary of stats being linked.
        :param iterative: Boolean indicating if the plugin is iterative or not.
        """
        out_dataset = self.plugin.get_out_datasets()[0]
        my_dataset = out_dataset
        if iterative:
            if "itr_clone" in out_dataset.group_name:
                my_dataset = list(iterative._ip_data_dict["iterating"].keys())[0]
        n_datasets = self.plugin.nOutput_datasets()

        i = 2
        group_name = "stats"
        while group_name in list(my_dataset.meta_data.get_dictionary().keys()):
            group_name = f"stats{i}"  # If more than one set of stats for a plugin (such as an iterative
            i += 1                    # plugin), the groups will be named stats, stats2, stats3, etc.
        for key, value in stats_dict.items():
            my_dataset.meta_data.set([group_name, key], value)

    def _write_stats_to_file(self, p_num=None, plugin_name=None, comm=MPI.COMM_WORLD):
        """Writes stats to an HDF5 file. This file is used to create figures and tables from the stats.

        :param p_num: The plugin number of the plugin the stats belong to (usually left as None except
            for special cases).
        :param plugin_name: Same as above (but for the name of the plugin).
        :param comm: The MPI communicator the plugin is using.
        """
        if p_num is None:
            p_num = self.p_num
        if plugin_name is None:
            plugin_name = self.plugin_names[p_num]
        path = Statistics.path
        filename = f"{path}/stats.h5"
        stats_dict = self.get_stats(p_num, instance="all")
        stats_array = self._dict_to_array(stats_dict[0])
        stats_key = list(stats_dict[0].keys())
        for i, my_dict in enumerate(stats_dict):
            if i != 0:
                stats_array = np.vstack([stats_array, self._dict_to_array(my_dict)])
        self.hdf5 = Hdf5Utils(self.exp)
        self.exp._barrier(communicator=comm)
        if comm.rank == 0:
            with h5.File(filename, "a") as h5file:
                group = h5file.require_group("stats")
                if stats_array.shape != (0,):
                    if str(p_num) in list(group.keys()):
                        del group[str(p_num)]
                    dataset = group.create_dataset(str(p_num), shape=stats_array.shape, dtype=stats_array.dtype)
                    dataset[::] = stats_array[::]
                    dataset.attrs.create("plugin_name", plugin_name)
                    dataset.attrs.create("pattern", self.pattern)
                    dataset.attrs.create("stats_key", stats_key)
                if self._iterative_group:
                    l_stats = Statistics.loop_stats[self.l_num]
                    group1 = h5file.require_group("iterative")
                    if self._iterative_group._ip_iteration == self._iterative_group._ip_fixed_iterations - 1\
                            and self.p_num == self._iterative_group.end_index:
                        dataset1 = group1.create_dataset(str(self.l_num), shape=l_stats["NRMSD"].shape,
                                                         dtype=l_stats["NRMSD"].dtype)
                        dataset1[::] = l_stats["NRMSD"][::]
                        loop_plugins = []
                        for i in range(self._iterative_group.start_index, self._iterative_group.end_index + 1):
                            if i in list(self.plugin_names.keys()):
                                loop_plugins.append(self.plugin_names[i])
                        dataset1.attrs.create("loop_plugins", loop_plugins)
                        dataset1.attrs.create("n_loop_plugins", len(loop_plugins))

        self.exp._barrier(communicator=comm)
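
    # Example (illustrative; plugin number 4 assumed) of reading the stats file
    # back after a run with plain h5py:
    #
    #     with h5.File(f"{Statistics.path}/stats.h5", "r") as f:
    #         rows = f["stats/4"][...]                # one row per set of stats
    #         keys = f["stats/4"].attrs["stats_key"]  # column names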

    def _write_times_to_file(self, comm):
        """Writes times into the file containing all the stats."""
        p_num = self.p_num
        plugin_name = self.plugin_name
        path = Statistics.path
        filename = f"{path}/stats.h5"
        time = Statistics.global_times[p_num]
        self.hdf5 = Hdf5Utils(self.exp)
        if comm.rank == 0:
            with h5.File(filename, "a") as h5file:
                group = h5file.require_group("stats")
                dataset = group[str(p_num)]
                dataset.attrs.create("time", time)

    def write_slice_stats_to_file(self, slice_stats=None, p_num=None, comm=MPI.COMM_WORLD):
        """Writes slice statistics to an HDF5 file, placed in the stats folder in the output directory. Currently unused."""
        if not slice_stats:
            slice_stats = self.stats
        if not p_num:
            p_num = self.count
            plugin_name = self.plugin_name
        else:
            plugin_name = self.plugin_names[p_num]
        combined_stats = self._combine_mpi_stats(slice_stats)
        slice_stats_arrays = {}
        datasets = {}
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{plugin_name}.h5"
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        with h5.File(filename, "a", driver="mpio", comm=comm) as h5file:
            i = 2
            group_name = "/stats"
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            for key in list(combined_stats.keys()):
                slice_stats_arrays[key] = np.array(combined_stats[key])
                datasets[key] = self.hdf5.create_dataset_nofill(group, key, (len(slice_stats_arrays[key]),),
                                                                slice_stats_arrays[key].dtype)
                datasets[key][::] = slice_stats_arrays[key]

    def _unpad_slice(self, my_slice):
        """If data is padded in the slice dimension, removes this pad."""
        out_datasets = self.plugin.get_out_datasets()
        if len(out_datasets) == 1:
            out_dataset = out_datasets[0]
        else:
            for dataset in out_datasets:
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
                    out_dataset = dataset
                    break
        slice_dims = out_dataset.get_slice_dimensions()
        if self.plugin.pcount == 0:
            self._slice_list, self._pad = self._get_unpadded_slice_list(my_slice, slice_dims)
        if self._pad:
            #for slice_dim in slice_dims:
            slice_dim = slice_dims[0]
            temp_slice = np.swapaxes(my_slice, 0, slice_dim)
            temp_slice = temp_slice[self._slice_list[slice_dim]]
            my_slice = np.swapaxes(temp_slice, 0, slice_dim)
        return my_slice

    def _get_unpadded_slice_list(self, my_slice, slice_dims):
        """Creates slice object(s) to un-pad slices in the slice dimension(s)."""
        slice_list = list(self.plugin.slice_list[0])
        pad = False
        if len(slice_list) == len(my_slice.shape):
            i = slice_dims[0]
            slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
            if slice_width < my_slice.shape[i]:
                pad = True
                pad_width = (my_slice.shape[i] - slice_width) // 2  # Assuming symmetrical padding
                slice_list[i] = slice(pad_width, pad_width + 1, 1)
            return tuple(slice_list), pad
        else:
            return self.plugin.slice_list[0], pad
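
    # Worked example (illustrative): if the plugin's slice list covers rows
    # 0..160 (slice_width = 160) but the incoming slice has 200 rows, then
    # pad_width = (200 - 160) // 2 = 20 and the central unpadded frame is
    # selected with slice(20, 21, 1).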

    def _flatten(self, l):
        """Function to flatten nested lists."""
        out = []
        for item in l:
            if isinstance(item, (list, tuple)):
                out.extend(self._flatten(item))
            else:
                out.append(item)
        return out

    def _de_list(self, my_slice):
        """If the slice is in a list, remove it from that list (takes the 0th element)."""
        if isinstance(my_slice, list):
            if len(my_slice) != 0:
                my_slice = my_slice[0]
                my_slice = self._de_list(my_slice)
        return my_slice

    @classmethod
    def _count(cls):
        cls.count += 1

    @classmethod
    def _post_chain(cls):
        """Called after all plugins have run."""
        if cls._any_stats and cls._stats_flag:
            stats_utils = StatsUtils()
            stats_utils.generate_figures(f"{cls.path}/stats.h5", cls.path)