"""
.. module:: statistics
   :platform: Unix
   :synopsis: Contains and processes statistics information for each plugin.

.. moduleauthor:: Jacob Williamson <[email protected]>

"""

from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
from savu.plugins.stats.stats_utils import StatsUtils

import h5py as h5
import numpy as np
import os
import logging

class Statistics(object):
    _pattern_list = ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY",
                     "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    no_stats_plugins = ["BasicOperations", "Mipmap"]
    _key_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "RMSD"]
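    # Note: _key_list gives the column order of the arrays stored in global_stats
    # (see _array_to_dict and calc_volume_stats). The value stored under "RMSD" is the
    # normalised RMSD (RMSD divided by the absolute mean) appended in calc_volume_stats.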

    def __init__(self):
        self.calc_stats = True
        self.stats = {'max': [], 'min': [], 'mean': [], 'std_dev': [], 'RSS': [], 'data_points': []}
        self.stats_before_processing = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
        self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
        self._repeat_count = 0

    def setup(self, plugin_self):
        if plugin_self.name in Statistics.no_stats_plugins:
            self.calc_stats = False
        if self.calc_stats:
            self.plugin = plugin_self
            self.plugin_name = plugin_self.name
            self.pad_dims = []
            self._already_called = False
            self._set_pattern_info()
        if self.calc_stats:
            Statistics._any_stats = True

    @classmethod
    def _setup_class(cls, exp):
        """Sets up the statistics class for the whole plugin chain (only called once)."""
        cls._any_stats = False
        cls.count = 2
        cls.global_stats = {}
        cls.exp = exp
        cls.n_plugins = len(exp.meta_data.plugin_list.plugin_list)
        for i in range(1, cls.n_plugins + 1):
            cls.global_stats[i] = np.array([])
        cls.global_residuals = {}
        cls.plugin_numbers = {}
        cls.plugin_names = {}

        cls.path = exp.meta_data['out_path']
        if cls.path[-1] == '/':
            cls.path = cls.path[0:-1]
        cls.path = f"{cls.path}/stats"
        if not os.path.exists(cls.path):
            os.mkdir(cls.path)

    def set_slice_stats(self, my_slice, base_slice):
        slice_stats_before = self.calc_slice_stats(base_slice)
        slice_stats_after = self.calc_slice_stats(my_slice, base_slice)
        for key in list(self.stats_before_processing.keys()):
            self.stats_before_processing[key].append(slice_stats_before[key])
        for key in list(self.stats.keys()):
            self.stats[key].append(slice_stats_after[key])

    def calc_slice_stats(self, my_slice, base_slice=None):
        """Calculates and returns slice stats for the current slice.

        :param my_slice: The slice whose stats are being calculated.
        :param base_slice: The unprocessed slice used to calculate the residual sum of squares (optional).
        """
        if my_slice is not None:
            slice_num = self.plugin.pcount
            my_slice = self._de_list(my_slice)
            my_slice = self._unpad_slice(my_slice)
            slice_stats = {'max': np.amax(my_slice).astype('float64'), 'min': np.amin(my_slice).astype('float64'),
                           'mean': np.mean(my_slice), 'std_dev': np.std(my_slice), 'data_points': my_slice.size}
            if base_slice is not None:
                base_slice = self._de_list(base_slice)
                base_slice = self._unpad_slice(base_slice)
                rss = self._calc_rss(my_slice, base_slice)
            else:
                rss = None
            slice_stats['RSS'] = rss
            return slice_stats
        return None
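    # Illustrative only (not called anywhere in this module): for a hypothetical 2x2 slice
    # np.array([[1., 2.], [3., 4.]]) with no base_slice, and assuming _unpad_slice returns
    # the slice unchanged, calc_slice_stats would return
    #   {'max': 4.0, 'min': 1.0, 'mean': 2.5, 'std_dev': ~1.118, 'data_points': 4, 'RSS': None}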

    def _calc_rss(self, array1, array2):  # residual sum of squares
        if array1.shape == array2.shape:
            residuals = np.subtract(array1, array2)
            rss = np.sum(residuals ** 2)
        else:
            logging.warning("Cannot calculate RSS: arrays have different shapes.")
            rss = None
        return rss

    def _rmsd_from_rss(self, rss, n):
        return np.sqrt(rss / n)

    def calc_rmsd(self, array1, array2):
        if array1.shape == array2.shape:
            rss = self._calc_rss(array1, array2)
            rmsd = self._rmsd_from_rss(rss, array1.size)
        else:
            logging.warning("Cannot calculate RMSD: arrays have different shapes.")
            rmsd = None
        return rmsd
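    # Worked example (illustrative, assuming two same-shaped arrays):
    #   a = np.array([1., 2.]); b = np.array([1., 4.])
    #   RSS  = (1 - 1)**2 + (2 - 4)**2 = 4.0
    #   RMSD = sqrt(RSS / n) = sqrt(4.0 / 2) ~= 1.414
    # so calc_rmsd(a, b) ~= 1.414.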

    def calc_stats_residuals(self, stats_before, stats_after):
        residuals = {'max': None, 'min': None, 'mean': None, 'std_dev': None}
        for key in list(residuals.keys()):
            residuals[key] = stats_after[key] - stats_before[key]
        return residuals

    def set_stats_residuals(self, residuals):
        self.residuals['max'].append(residuals['max'])
        self.residuals['min'].append(residuals['min'])
        self.residuals['mean'].append(residuals['mean'])
        self.residuals['std_dev'].append(residuals['std_dev'])

    def calc_volume_stats(self, slice_stats):
        volume_stats = np.array([max(slice_stats['max']), min(slice_stats['min']), np.mean(slice_stats['mean']),
                                 np.mean(slice_stats['std_dev']), np.median(slice_stats['std_dev'])])
        if None not in slice_stats['RSS']:
            total_rss = sum(slice_stats['RSS'])
            n = sum(slice_stats['data_points'])
            RMSD = self._rmsd_from_rss(total_rss, n)
            NRMSD = RMSD / abs(volume_stats[2])  # normalised RMSD (dividing by the mean)
            volume_stats = np.append(volume_stats, NRMSD)
        else:
            #volume_stats = np.append(volume_stats, None)
            pass
        return volume_stats
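    # Example of the normalisation above (illustrative numbers only): if the total RSS over
    # 100 data points is 25.0, then RMSD = sqrt(25.0 / 100) = 0.5, and with a volume mean of
    # 2.0 the appended NRMSD is 0.5 / 2.0 = 0.25.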

    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        p_num = Statistics.count
        name = self.plugin_name
        i = 2
        while name in list(Statistics.plugin_numbers.keys()):
            name = self.plugin_name + str(i)
            i += 1
        if len(self.stats['max']) != 0:
            stats_array = self.calc_volume_stats(self.stats)
            Statistics.global_residuals[p_num] = {}
            #before_processing = self.calc_volume_stats(self.stats_before_processing)
            #for key in list(before_processing.keys()):
            #    Statistics.global_residuals[p_num][key] = Statistics.global_stats[p_num][key] - before_processing[key]

            if len(Statistics.global_stats[p_num]) == 0:
                Statistics.global_stats[p_num] = stats_array
            else:
                Statistics.global_stats[p_num] = np.vstack([Statistics.global_stats[p_num], stats_array])
            Statistics.plugin_numbers[name] = p_num
            if p_num not in list(Statistics.plugin_names.keys()):
                Statistics.plugin_names[p_num] = name
            self._link_stats_to_datasets(Statistics.global_stats[Statistics.plugin_numbers[name]])

        slice_stats_array = np.array([self.stats['max'], self.stats['min'], self.stats['mean'], self.stats['std_dev']])
        self._write_stats_to_file3(p_num)
        self._already_called = True
        self._repeat_count += 1

    def get_stats(self, plugin_name, n=None, stat=None, instance=1):
        """Returns stats associated with a certain plugin.

        :param plugin_name: Name of the plugin whose associated stats are being fetched.
        :param n: In a case where there are multiple instances of **plugin_name** in the process list,
            specify the nth instance. Not specifying will select the first (or only) instance.
        :param stat: Specify the stat parameter you want to fetch, e.g. 'max', 'mean', 'median_std_dev'.
            If left blank the whole dictionary of stats is returned:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'RMSD': }
        :param instance: In cases where there are multiple sets of stats associated with a plugin
            due to multi-parameters, specify which set you want to retrieve, e.g. 3 to retrieve the
            stats associated with the third run of the plugin. Pass 'all' to get a list of all sets.
        """
        name = plugin_name
        if n not in (None, 0, 1):
            name = name + str(n)
        p_num = Statistics.plugin_numbers[name]
        return self.get_stats_from_num(p_num, stat, instance)
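    # Typical usage (illustrative only; the plugin names depend on the process list in use,
    # and `stats` stands for a Statistics instance whose plugin has already been registered):
    #   stats.get_stats("MedianFilter", stat="max")   # max of the first MedianFilter
    #   stats.get_stats("MedianFilter", n=2)          # full stats dict for the second one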

    def get_stats_from_num(self, p_num, stat=None, instance=1):
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).

        :param p_num: Plugin number of the plugin whose associated stats are being fetched.
            If p_num <= 0, it is relative to the plugin number of the current plugin being run,
            e.g. if the current plugin number is 5, p_num = -2 returns the stats of the third plugin.
        :param stat: Specify the stat parameter you want to fetch, e.g. 'max', 'mean', 'median_std_dev'.
            If left blank the whole dictionary of stats is returned:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'RMSD': }
        :param instance: In cases where there are multiple sets of stats associated with a plugin
            due to multi-parameters, specify which set you want to retrieve, e.g. 3 to retrieve the
            stats associated with the third run of the plugin. Pass 'all' to get a list of all sets.
        """
        if p_num <= 0:
            p_num = Statistics.count + p_num
        if Statistics.global_stats[p_num].ndim == 1 and instance in (None, 0, 1, "all"):
            stats_array = Statistics.global_stats[p_num]
        else:
            if instance == "all":
                stats_list = [self.get_stats_from_num(p_num, stat=stat, instance=1)]
                n = 2
                while n <= Statistics.global_stats[p_num].shape[0]:
                    stats_list.append(self.get_stats_from_num(p_num, stat=stat, instance=n))
                    n += 1
                return stats_list
            if instance > 0:
                instance -= 1
            stats_array = Statistics.global_stats[p_num][instance]
        stats_dict = self._array_to_dict(stats_array)
        if stat is not None:
            return stats_dict[stat]
        else:
            return stats_dict

    def get_stats_from_dataset(self, dataset, stat=None, instance=None):
        """Returns stats associated with a dataset.

        :param dataset: The dataset whose associated stats are being fetched.
        :param stat: Specify the stat parameter you want to fetch, e.g. 'max', 'mean', 'median_std_dev'.
            If left blank the whole dictionary of stats is returned:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'RMSD': }
        :param instance: In cases where there are multiple sets of stats associated with a dataset
            due to multi-parameters, specify which set you want to retrieve, e.g. 3 to retrieve the
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
        """
        key = "stats"
        stats = {}
        if instance not in (None, 0, 1):
            if instance == "all":
                stats = [self.get_stats_from_dataset(dataset, stat=stat, instance=1)]
                n = 2
                while ("stats" + str(n)) in list(dataset.meta_data.get_dictionary().keys()):
                    stats.append(self.get_stats_from_dataset(dataset, stat=stat, instance=n))
                    n += 1
                return stats
            key = key + str(instance)
        stats = dataset.meta_data.get(key)
        if stat is not None:
            return stats[stat]
        else:
            return stats
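    # Illustrative usage (assuming `dataset` is a Savu Data object that has had stats linked
    # to it by _link_stats_to_datasets):
    #   stats.get_stats_from_dataset(dataset, stat="mean")     # mean of the first stats set
    #   stats.get_stats_from_dataset(dataset, instance="all")  # list of every stats set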

    def get_data_stats(self):
        return Statistics.data_stats

    def get_volume_stats(self):
        return Statistics.volume_stats

    def get_global_stats(self):
        return Statistics.global_stats

    def _array_to_dict(self, stats_array):
        stats_dict = {}
        for i, value in enumerate(stats_array):
            stats_dict[Statistics._key_list[i]] = value
        return stats_dict

    def _set_pattern_info(self):
        """Gathers information about the pattern of the data in the current plugin."""
        in_datasets, out_datasets = self.plugin.get_datasets()
        try:
            self.pattern = self.plugin.parameters['pattern']
            if self.pattern is None:
                raise KeyError
        except KeyError:
            if not out_datasets:
                self.pattern = None
            else:
                patterns = out_datasets[0].get_data_patterns()
                for pattern in patterns:
                    if 1 in patterns.get(pattern)["slice_dims"]:
                        self.pattern = pattern
                        break
        self.calc_stats = False
        for dataset in out_datasets:
            if bool(set(Statistics._pattern_list) & set(dataset.data_info.get("data_patterns"))):
                self.calc_stats = True

    def _link_stats_to_datasets(self, stats):
        """Links the volume-wide statistics to the output dataset(s)."""
        out_dataset = self.plugin.get_out_datasets()[0]
        n_datasets = self.plugin.nOutput_datasets()
        if self._repeat_count > 0:
            stats_dict = self._array_to_dict(stats[self._repeat_count])
        else:
            stats_dict = self._array_to_dict(stats)
        i = 2
        group_name = "stats"
        #out_dataset.data_info.set([group_name], stats)
        if n_datasets == 1:
            while group_name in list(out_dataset.meta_data.get_dictionary().keys()):
                group_name = f"stats{i}"
                i += 1
        for key in list(stats_dict.keys()):
            out_dataset.meta_data.set([group_name, key], stats_dict[key])
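    # After this call the stats are retrievable from the output dataset's metadata under the
    # "stats" entry (or "stats2", "stats3", ... for repeated runs of the same plugin on one
    # output dataset), which is what get_stats_from_dataset reads back.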

    def _write_stats_to_file2(self, p_num):
        path = Statistics.path
        filename = f"{path}/stats.h5"
        stats = Statistics.global_stats[p_num]
        array_dim = stats.shape
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        group_name = f"{p_num}-{self.plugin_name}-stats"
        with h5.File(filename, "a") as h5file:
            if group_name not in h5file:
                group = h5file.create_group(group_name, track_order=None)
                dataset = self.hdf5.create_dataset_nofill(group, "stats", array_dim, stats.dtype)
                dataset[::] = stats[::]
            else:
                group = h5file[group_name]

    @classmethod
    def _write_stats_to_file4(cls):
        path = cls.path
        filename = f"{path}/stats.h5"
        stats = cls.global_stats
        cls.hdf5 = Hdf5Utils(cls.exp)
        for i in range(5):
            array = np.array([])
            stat = cls._key_list[i]
            for key in list(stats.keys()):
                if len(stats[key]) != 0:
                    if stats[key].ndim == 1:
                        array = np.append(array, stats[key][i])
                    else:
                        array = np.append(array, stats[key][0][i])
            array_dim = array.shape
            group_name = f"all-{stat}"
            with h5.File(filename, "a") as h5file:
                group = h5file.create_group(group_name, track_order=None)
                dataset = cls.hdf5.create_dataset_nofill(group, stat, array_dim, array.dtype)
                dataset[::] = array[::]

    def _write_stats_to_file3(self, p_num):
        path = Statistics.path
        filename = f"{path}/stats.h5"
        stats = self.global_stats
        self.hdf5 = Hdf5Utils(self.exp)
        with h5.File(filename, "a") as h5file:
            group = h5file.require_group("stats")
            if stats[p_num].shape != (0,):
                if str(p_num) in list(group.keys()):
                    del group[str(p_num)]
                dataset = group.create_dataset(str(p_num), shape=stats[p_num].shape, dtype=stats[p_num].dtype)
                dataset[::] = stats[p_num][::]
                dataset.attrs.create("plugin_name", self.plugin_names[p_num])
                dataset.attrs.create("pattern", self.pattern)


    def _write_stats_to_file(self, slice_stats_array, p_num):
        """Writes slice statistics to a h5 file."""
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{self.plugin_name}.h5"
        slice_stats_dim = (slice_stats_array.shape[1],)
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        with h5.File(filename, "a") as h5file:
            i = 2
            group_name = "/stats"
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            max_ds = self.hdf5.create_dataset_nofill(group, "max", slice_stats_dim, slice_stats_array.dtype)
            min_ds = self.hdf5.create_dataset_nofill(group, "min", slice_stats_dim, slice_stats_array.dtype)
            mean_ds = self.hdf5.create_dataset_nofill(group, "mean", slice_stats_dim, slice_stats_array.dtype)
            std_dev_ds = self.hdf5.create_dataset_nofill(group, "standard_deviation",
                                                         slice_stats_dim, slice_stats_array.dtype)
            if slice_stats_array.shape[0] == 5:
                rmsd_ds = self.hdf5.create_dataset_nofill(group, "RMSD", slice_stats_dim, slice_stats_array.dtype)
                rmsd_ds[::] = slice_stats_array[4]
            max_ds[::] = slice_stats_array[0]
            min_ds[::] = slice_stats_array[1]
            mean_ds[::] = slice_stats_array[2]
            std_dev_ds[::] = slice_stats_array[3]
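    # This writer produces one file per plugin, stats_p<p_num>_<plugin_name>.h5, holding the
    # per-slice "max", "min", "mean" and "standard_deviation" datasets (plus "RMSD" when a
    # fifth row is supplied). It is kept alongside the newer writers above but is not called
    # from set_volume_stats, which uses _write_stats_to_file3.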

    def _unpad_slice(self, slice1):
        """If data is padded in the slice dimension, removes this pad."""
        out_datasets = self.plugin.get_out_datasets()
        if len(out_datasets) == 1:
            out_dataset = out_datasets[0]
        else:
            for dataset in out_datasets:
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
                    out_dataset = dataset
                    break
        slice_dims = out_dataset.get_slice_dimensions()
        if self.plugin.pcount == 0:
            self.slice_list, self.pad = self._get_unpadded_slice_list(slice1, slice_dims)
        if self.pad:
            #for slice_dim in slice_dims:
            slice_dim = slice_dims[0]
            temp_slice = np.swapaxes(slice1, 0, slice_dim)
            temp_slice = temp_slice[self.slice_list[slice_dim]]
            slice1 = np.swapaxes(temp_slice, 0, slice_dim)
        return slice1

    def _get_unpadded_slice_list(self, slice1, slice_dims):
        """Creates slice object(s) to un-pad slices in the slice dimension(s)."""
        slice_list = list(self.plugin.slice_list[0])
        pad = False
        if len(slice_list) == len(slice1.shape):
            #for i in slice_dims:
            i = slice_dims[0]
            slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
            if slice_width != slice1.shape[i]:
                pad = True
                pad_width = (slice1.shape[i] - slice_width) // 2  # Assuming symmetrical padding
                slice_list[i] = slice(pad_width, pad_width + 1, 1)
            return tuple(slice_list), pad
        else:
            return self.plugin.slice_list[0], pad

    def _de_list(self, slice1):
        """If the slice is in a list, remove it from that list."""
        if isinstance(slice1, list):
            if len(slice1) != 0:
                slice1 = slice1[0]
                slice1 = self._de_list(slice1)
        return slice1


    @classmethod
    def _count(cls):
        cls.count += 1

    @classmethod
    def _post_chain(cls):
        if cls._any_stats:
            stats_utils = StatsUtils()
            stats_utils.generate_figures(f"{cls.path}/stats.h5", cls.path)
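
# Typical call sequence, as inferred from this module (the driving code lives elsewhere in Savu):
#   Statistics._setup_class(exp)                 # once per run: create <out_path>/stats
#   stats = Statistics(); stats.setup(plugin)    # per plugin
#   stats.set_slice_stats(processed, original)   # per processed slice
#   stats.set_volume_stats(); Statistics._count()  # after each plugin has run
#   Statistics._post_chain()                     # once at the end: generate figures from stats.h5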