|
1
|
|
|
""" |
|
2
|
|
|
.. module:: statistics |
|
3
|
|
|
:platform: Unix |
|
4
|
|
|
:synopsis: Contains and processes statistics information for each plugin. |
|
5
|
|
|
|
|
6
|
|
|
.. moduleauthor:: Jacob Williamson <[email protected]> |
|
7
|
|
|
|
|
8
|
|
|
""" |
|
9
|
|
|
|
|
10
|
|
|
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils |
|
11
|
|
|
|
|
12
|
|
|
import h5py as h5 |
|
13
|
|
|
import numpy as np |
|
14
|
|
|
import os |
|
15
|
|
|
|
|
16
|
|
|
|
|
17
|
|
|
class Statistics(object):
    """Collects per-slice statistics for one plugin and aggregates them.

    Each instance accumulates max/min/mean/standard deviation for every slice
    passed to :meth:`set_slice_stats`.  :meth:`set_volume_stats` reduces those
    lists to five volume-wide values (see ``key_list``) and stores them in
    class-level dictionaries shared by all instances.

    The class-level state (``count``, ``data_stats``, ``volume_stats``,
    ``global_stats`` and ``path``) is created by :meth:`_setup`, so that
    classmethod has to run before any volume-wide statistics are set or read.
    """

    # Position of each volume-wide statistic inside the 5-element stats lists.
    index_dict = {"max": 0, "min": 1, "mean": 2, "mean_std_dev": 3, "median_std_dev": 4}
    # Names of the volume-wide statistics, in list order.
    key_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev"]
    # An output dataset must offer at least one of these patterns for
    # statistics to be calculated.
    pattern_list = ["SINOGRAM", "PROJECTION", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    # Plugins for which statistics are never calculated.
    no_stats_plugins = ["BasicOperations", "Mipmap"]

    def __init__(self, plugin_self):
        """Create the statistics holder for a single plugin.

        :param plugin_self: The plugin instance these statistics belong to.
        """
        self.plugin = plugin_self
        self.plugin_name = plugin_self.name
        self.pad_dims = []
        self.stats = {'max': [], 'min': [], 'mean': [], 'standard_deviation': []}
        self.calc_stats = False
        # Defaults so that later methods never hit a missing attribute.
        self.pattern = None
        self.slice_list = None
        self.pad = False
        self._set_pattern_info()
        if self.plugin_name in Statistics.no_stats_plugins:
            self.calc_stats = False

    @classmethod
    def _setup(cls, exp):
        """Sets up the statistics class for the whole experiment (only called once).

        Resets the shared counters/dictionaries and creates the ``stats``
        directory below the experiment's ``out_path``.

        :param exp: The experiment; ``exp.meta_data['out_path']`` is read.
        """
        cls.count = 2
        cls.data_stats = {}
        cls.volume_stats = {}
        cls.global_stats = {}
        out_path = exp.meta_data['out_path']
        # Drop a single trailing slash; endswith() is also safe for "".
        if out_path.endswith('/'):
            out_path = out_path[:-1]
        cls.path = f"{out_path}/stats"
        # exist_ok avoids the check-then-create race when several processes
        # reach this point together.
        os.makedirs(cls.path, exist_ok=True)

    def set_slice_stats(self, slice1):
        """Appends slice stats arrays with the stats parameters of the current slice.

        :param slice1: The slice whose stats are being calculated.  May be
            wrapped in (nested) lists.  ``None`` and empty lists are ignored.
        """
        if slice1 is None:
            return
        slice1 = self._de_list(slice1)
        if isinstance(slice1, list):
            # _de_list only leaves a list behind when it is empty.
            return
        slice1 = self._unpad_slice(slice1)
        self.stats['max'].append(slice1.max())
        self.stats['min'].append(slice1.min())
        self.stats['mean'].append(np.mean(slice1))
        self.stats['standard_deviation'].append(np.std(slice1))

    def get_slice_stats(self, stat, slice_num):
        """Returns one statistic of one processed slice of the current plugin.

        :param stat: One of 'max', 'min', 'mean', 'standard_deviation'.
        :param slice_num: Index of the slice in the order it was processed.
        """
        return self.stats[stat][slice_num]

    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.

        Links volume stats with the output dataset and writes slice stats to file.
        Nothing is calculated or written when no slice stats were collected.
        """
        p_num = Statistics.count
        # Find a name not yet registered: "Name", "Name2", "Name3", ...
        name = self.plugin_name
        suffix = 2
        while name in Statistics.global_stats:
            name = self.plugin_name + str(suffix)
            suffix += 1
        Statistics.data_stats[p_num] = [None, None, None, None, None]
        Statistics.volume_stats[p_num] = [None, None, None, None, None]
        if not self.stats['max']:
            return

        if self.pattern in ('PROJECTION', 'SINOGRAM', 'TANGENTOGRAM', 'SINOMOVIE', '4D_SCAN'):
            summary = Statistics.data_stats[p_num]
        elif self.pattern in ('VOLUME_XZ', 'VOLUME_XY', 'VOLUME_YZ', 'VOLUME_3D'):
            summary = Statistics.volume_stats[p_num]
        else:
            summary = None

        if summary is not None:
            # Filled in place so data_stats/volume_stats and global_stats
            # share the same list object.
            summary[0] = max(self.stats['max'])
            summary[1] = min(self.stats['min'])
            summary[2] = np.mean(self.stats['mean'])
            summary[3] = np.mean(self.stats['standard_deviation'])
            summary[4] = np.median(self.stats['standard_deviation'])
            Statistics.global_stats[p_num] = summary
            Statistics.global_stats[name] = summary
            self._link_stats_to_datasets(summary)

        slice_stats = np.array([self.stats['max'], self.stats['min'], self.stats['mean'],
                                self.stats['standard_deviation']])
        self._write_stats_to_file(slice_stats, p_num)

    def get_stats(self, plugin_name, n=None, stat=None):
        """Returns stats associated with a certain plugin.

        :param plugin_name: name of the plugin whose associated stats are being fetched.
        :param n: In a case where there are multiple instances of <plugin_name> in the process list,
            specify the nth instance. Not specifying (or passing 0 or 1) will select the first (or only) instance.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': ,'min': ,'mean': ,'mean_std_dev': ,'median_std_dev': }
        :raises KeyError: If no stats are registered under the resulting name,
            or ``stat`` is not a known statistic.
        """
        name = plugin_name
        if n is not None and n not in (0, 1):
            name = name + str(n)
        if stat is not None:
            return Statistics.global_stats[name][Statistics.index_dict[stat]]
        return dict(zip(Statistics.key_list, Statistics.global_stats[name]))

    def get_stats_from_num(self, p_num, stat=None):
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).

        :param p_num: Plugin number of the plugin whose associated stats are being fetched.
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
            E.g current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': ,'min': ,'mean': ,'mean_std_dev': ,'median_std_dev': }
        :raises KeyError: If no stats are registered for that plugin number,
            or ``stat`` is not a known statistic.
        """
        if p_num <= 0:
            p_num = Statistics.count + p_num
        if stat is not None:
            return Statistics.global_stats[p_num][Statistics.index_dict[stat]]
        return dict(zip(Statistics.key_list, Statistics.global_stats[p_num]))

    def get_stats_from_dataset(self, dataset, stat=None, set_num=None):
        """Returns stats associated with a dataset.

        :param dataset: The dataset whose associated stats are being fetched.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': ,'min': ,'mean': ,'mean_std_dev': ,'median_std_dev': }
        :param set_num: In the (rare) case that there are multiple sets of stats associated with the dataset,
            specify which set to return.
        :raises KeyError: If ``stat`` is given but the dataset holds no such entry.
        """
        key = "stats"
        if set_num is not None:
            key = key + str(set_num)
        stats = {}
        if key in dataset.meta_data.dict.keys():
            stats = dataset.meta_data.get(key)
        if stat is not None:
            return stats[stat]
        return stats

    def get_data_stats(self):
        """Returns the class-wide dictionary of stats for projection/sinogram-type patterns."""
        return Statistics.data_stats

    def get_volume_stats(self):
        """Returns the class-wide dictionary of stats for volume-type patterns."""
        return Statistics.volume_stats

    def get_global_stats(self):
        """Returns the class-wide dictionary of stats keyed by plugin number and plugin name."""
        return Statistics.global_stats

    def _set_pattern_info(self):
        """Gathers information about the pattern of the data in the current plugin.

        Uses the plugin's 'pattern' parameter when it is present and not None.
        Otherwise picks the first pattern of the first output dataset whose
        slice dimensions include 1, and falls back to ``None``.
        Also enables ``calc_stats`` when any output dataset offers a pattern
        from ``pattern_list``.
        """
        _, out_datasets = self.plugin.get_datasets()
        try:
            pattern = self.plugin.parameters['pattern']
        except KeyError:
            pattern = None
        if pattern is None and out_datasets:
            patterns = out_datasets[0].get_data_patterns()
            for candidate in patterns:
                if 1 in patterns.get(candidate)["slice_dims"]:
                    pattern = candidate
                    break
        # Always assigned, so self.pattern exists even when nothing matched.
        self.pattern = pattern

        known_patterns = set(Statistics.pattern_list)
        for dataset in out_datasets:
            if known_patterns & set(dataset.data_info.get("data_patterns")):
                self.calc_stats = True

    def _link_stats_to_datasets(self, stats):
        """Links the volume wide statistics to the output dataset(s).

        Only acts when the plugin has exactly one output dataset.

        :param stats: The five volume-wide values, ordered as in ``key_list``.
        """
        out_datasets = self.plugin.get_out_datasets()
        if self.plugin.nOutput_datasets() != 1:
            return
        out_dataset = out_datasets[0]
        # NOTE(review): existing group names are looked up in meta_data but
        # the values are written to data_info, and get_stats_from_dataset
        # reads meta_data - confirm which store is intended.
        existing = out_dataset.meta_data.get_dictionary().keys()
        group_name = "stats"
        i = 1
        while group_name in existing:
            group_name = f"stats{i}"
            i += 1
        for key, value in zip(Statistics.key_list, stats):
            out_dataset.data_info.set([group_name, key], value)

    def _write_stats_to_file(self, slice_stats, p_num):
        """Writes slice statistics to a h5 file.

        The file is opened in append mode and a new group ("/stats",
        "/stats1", ...) is created for every call.

        :param slice_stats: 2D array with rows max, min, mean, standard deviation.
        :param p_num: Plugin number, used in the file name.
        """
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{self.plugin_name}.h5"
        slice_stats_dim = (slice_stats.shape[1],)
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        with h5.File(filename, "a") as h5file:
            i = 1
            group_name = "/stats"
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            dataset_names = ("max", "min", "mean", "standard_deviation")
            for ds_name, row in zip(dataset_names, slice_stats):
                h5_dataset = self.hdf5.create_dataset_nofill(group, ds_name, slice_stats_dim,
                                                             slice_stats.dtype)
                h5_dataset[::] = row

    def _unpad_slice(self, slice1):
        """If data is padded in the slice dimension, removes this pad.

        :param slice1: The (possibly padded) slice array.
        :returns: The slice without padding.  The input is returned unchanged
            when no output dataset can be found to supply slice dimensions.
        """
        out_datasets = self.plugin.get_out_datasets()
        out_dataset = None
        if len(out_datasets) == 1:
            out_dataset = out_datasets[0]
        else:
            for dataset in out_datasets:
                if self.pattern in dataset.data_info.get(["data_patterns"]).keys():
                    out_dataset = dataset
                    break
        if out_dataset is None:
            # Without slice dimensions the pad cannot be located.
            return slice1

        slice_dims = out_dataset.get_slice_dimensions()
        # Pad information is worked out once and reused for later slices.
        if self.plugin.pcount == 0 or getattr(self, "slice_list", None) is None:
            self.slice_list, self.pad = self._get_unpadded_slice_list(slice1, slice_dims)
        if self.pad:
            for slice_dim in slice_dims:
                # Move the slice dimension to the front, cut, and move it back.
                temp_slice = np.swapaxes(slice1, 0, slice_dim)
                temp_slice = temp_slice[self.slice_list[slice_dim]]
                slice1 = np.swapaxes(temp_slice, 0, slice_dim)
        return slice1

    def _get_unpadded_slice_list(self, slice1, slice_dims):
        """Creates slice object(s) to un-pad slices in the slice dimension(s).

        :param slice1: The (possibly padded) slice array.
        :param slice_dims: Indices of the slice dimensions.
        :returns: ``(slice_list, pad)`` where ``pad`` is True when at least
            one slice dimension is wider than the plugin's first slice list
            says it should be.
        """
        first_slice_list = self.plugin.slice_list[0]
        pad = False
        if len(first_slice_list) != len(slice1.shape):
            return first_slice_list, pad

        slice_list = list(first_slice_list)
        for i in slice_dims:
            slice_width = first_slice_list[i].stop - first_slice_list[i].start
            if slice_width != slice1.shape[i]:
                pad = True
                pad_width = (slice1.shape[i] - slice_width) // 2  # Assuming symmetrical padding
                # Keep the whole unpadded width, not just its first element.
                slice_list[i] = slice(pad_width, pad_width + slice_width, 1)
        return tuple(slice_list), pad

    def _de_list(self, slice1):
        """If the slice is in a list, remove it from that list.

        Nested lists are unwrapped by repeatedly taking the first element.
        An empty list is returned as it is.
        """
        while isinstance(slice1, list) and slice1:
            slice1 = slice1[0]
        return slice1

    @classmethod
    def _count(cls):
        """Advances the class-wide plugin counter by one."""
        cls.count += 1

    @classmethod
    def _post_chain(cls):
        """Prints the class-wide data, volume and global stats dictionaries."""
        print(cls.data_stats)
        print(cls.volume_stats)
        print(cls.global_stats)