Test Failed
Pull Request — master (#878)
by
unknown
04:01
created

Statistics._write_stats_to_file()   A

Complexity

Conditions 4

Size

Total Lines 25
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 24
nop 3
dl 0
loc 25
rs 9.304
c 0
b 0
f 0
1
"""
2
.. module:: statistics
3
   :platform: Unix
4
   :synopsis: Contains and processes statistics information for each plugin.
5
6
.. moduleauthor:: Jacob Williamson <[email protected]>
7
8
"""
9
10
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
11
12
import h5py as h5
13
import numpy as np
14
import os
15
16
17
class Statistics(object):
    """Collects, aggregates and stores statistics for the data produced by a plugin.

    Per-slice statistics are accumulated on the instance (``self.stats`` and
    ``self.residuals``). Volume-wide results are stored in class-level
    dictionaries shared by every instance. The class-level attributes
    ``count``, ``path``, ``data_stats``, ``volume_stats``, ``global_stats``
    and ``global_residuals`` only exist after :meth:`_setup_class` has run.
    """

    # Data pattern names for which statistics are calculated.
    pattern_list = ["SINOGRAM", "PROJECTION", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    # Plugins for which statistics are never calculated.
    no_stats_plugins = ["BasicOperations", "Mipmap"]

    def __init__(self):
        self.calc_stats = True
        self.stats = {'max': [], 'min': [], 'mean': [], 'std_dev': [], 'RSS': [], 'data_points': []}
        self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []}

    def setup(self, plugin_self):
        """Attach this Statistics object to a plugin and gather its pattern information.

        :param plugin_self: The plugin whose output statistics will be calculated.
        """
        if plugin_self.name in Statistics.no_stats_plugins:
            self.calc_stats = False
        if self.calc_stats:
            self.plugin = plugin_self
            self.plugin_name = plugin_self.name
            self.pad_dims = []
            self._already_called = False
            self._set_pattern_info()

    @classmethod
    def _setup_class(cls, exp):
        """Set up the statistics class for the whole experiment (only called once).

        Resets the shared statistics dictionaries and creates a ``stats``
        directory inside ``exp.meta_data['out_path']``.

        :param exp: The experiment object.
        """
        # NOTE(review): the plugin counter starts at 2; the reason is not visible in this file.
        cls.count = 2
        cls.data_stats = {}
        cls.volume_stats = {}
        cls.global_stats = {}
        cls.global_residuals = {}
        # os.path.join copes with a trailing '/' (or an empty out_path) without indexing.
        cls.path = os.path.join(exp.meta_data['out_path'], "stats")
        # exist_ok avoids the race between an exists() check and mkdir().
        os.makedirs(cls.path, exist_ok=True)

    def calc_slice_stats(self, my_slice, base_slice=None):
        """Calculate and return statistics for the current slice.

        :param my_slice: The slice whose stats are being calculated (may be wrapped in lists).
        :param base_slice: Optional reference slice; when given, the residual sum of
            squares (RSS) between ``my_slice`` and ``base_slice`` is included.
        :returns: dict with keys 'max', 'min', 'mean', 'std_dev', 'data_points' and 'RSS'
            ('RSS' is None when no base slice is given), or None if ``my_slice`` is None.
        """
        if my_slice is None:
            return None
        my_slice = self._unpad_slice(self._de_list(my_slice))
        slice_stats = {'max': np.amax(my_slice).astype('float64'),
                       'min': np.amin(my_slice).astype('float64'),
                       'mean': np.mean(my_slice),
                       'std_dev': np.std(my_slice),
                       'data_points': my_slice.size}
        rss = None
        if base_slice is not None:
            base_slice = self._unpad_slice(self._de_list(base_slice))
            rss = self._calc_rss(my_slice, base_slice)
        slice_stats['RSS'] = rss
        return slice_stats

    def _calc_rss(self, array1, array2):
        """Return the residual sum of squares between two equally shaped arrays.

        The difference is computed in float64 so unsigned/integer inputs neither
        wrap around on subtraction nor overflow when squared.
        Returns None (after printing a warning) if the shapes differ.
        """
        if array1.shape != array2.shape:
            print("Warning: cannot calculate RSS, arrays different sizes.")  # need to make this an actual warning
            return None
        residuals = np.asarray(array1, dtype=np.float64) - np.asarray(array2, dtype=np.float64)
        # Vectorised reduction instead of a per-element Python loop.
        return np.sum(np.square(residuals))

    def _rmsd_from_rss(self, rss, n):
        """Return the root-mean-square deviation for a residual sum of squares over ``n`` points."""
        return np.sqrt(rss/n)

    def calc_rmsd(self, array1, array2):
        """Return the RMSD between two equally shaped arrays, or None if their shapes differ."""
        if array1.shape != array2.shape:
            print("Warning: cannot calculate RMSD, arrays different sizes.")  # need to make this an actual warning
            return None
        rss = self._calc_rss(array1, array2)
        return self._rmsd_from_rss(rss, array1.size)

    def set_slice_stats(self, slice_stats):
        """Append the stats of the current slice to the per-slice stats lists.

        :param slice_stats: The stats for the current slice (as returned by calc_slice_stats).
        """
        for key, values in self.stats.items():
            values.append(slice_stats[key])

    def calc_stats_residuals(self, stats_before, stats_after):
        """Return ``stats_after - stats_before`` for 'max', 'min', 'mean' and 'std_dev'."""
        return {key: stats_after[key] - stats_before[key] for key in ('max', 'min', 'mean', 'std_dev')}

    def set_stats_residuals(self, residuals):
        """Append one set of stats residuals to the per-slice residual lists."""
        for key, values in self.residuals.items():
            values.append(residuals[key])

    def get_slice_stats(self, stat, slice_num):
        """Return the value of ``stat`` for processed slice number ``slice_num`` of the current plugin."""
        return self.stats[stat][slice_num]

    def set_volume_stats(self):
        """Calculate volume-wide statistics from the slice stats and store them class-wide.

        The volume stats are stored in ``global_stats`` under the current plugin
        number and under a unique plugin name (name, name2, name3, ...), and are
        linked to the output dataset. The slice stats are written to file and the
        volume residuals are recorded.
        """
        p_num = Statistics.count
        # Repeated instances of a plugin get numbered names: name, name2, name3, ...
        name = self.plugin_name
        i = 2
        while name in Statistics.global_stats:
            name = self.plugin_name + str(i)
            i += 1
        volume_stats = {}
        Statistics.global_stats[p_num] = volume_stats
        if self.stats['max']:
            volume_stats['max'] = max(self.stats['max'])
            volume_stats['min'] = min(self.stats['min'])
            volume_stats['mean'] = np.mean(self.stats['mean'])
            volume_stats['mean_std_dev'] = np.mean(self.stats['std_dev'])
            volume_stats['median_std_dev'] = np.median(self.stats['std_dev'])
            if None not in self.stats['RSS']:
                total_rss = sum(self.stats['RSS'])
                n = sum(self.stats['data_points'])
                volume_stats['RMSD'] = self._rmsd_from_rss(total_rss, n)
            else:
                volume_stats['RMSD'] = None
            Statistics.global_stats[name] = volume_stats
            self._link_stats_to_datasets(volume_stats)

        slice_stats_array = np.array([self.stats[key] for key in ('max', 'min', 'mean', 'std_dev')])
        self._write_stats_to_file(slice_stats_array, p_num)
        self.set_volume_residuals()
        self._already_called = True

    def set_volume_residuals(self):
        """Store the mean of each per-slice residual list class-wide under the current plugin number."""
        Statistics.global_residuals[Statistics.count] = {
            key: np.mean(values) for key, values in self.residuals.items()}

    def get_stats(self, plugin_name, n=None, stat=None):
        """Returns stats associated with a certain plugin.

        :param plugin_name: name of the plugin whose associated stats are being fetched.
        :param n: In a case where there are multiple instances of **plugin_name** in the process list,
            specify the nth instance. Not specifying will select the first (or only) instance.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': }
        """
        name = plugin_name
        if n is not None and n not in (0, 1):
            name = name + str(n)
        if stat is not None:
            return Statistics.global_stats[name][stat]
        return Statistics.global_stats[name]

    def get_stats_from_num(self, p_num, stat=None):
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).

        :param p_num: Plugin number of the plugin whose associated stats are being fetched.
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
            E.g current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': }
        """
        if p_num <= 0:
            p_num = Statistics.count + p_num
        if stat is not None:
            return Statistics.global_stats[p_num][stat]
        return Statistics.global_stats[p_num]

    def get_stats_from_dataset(self, dataset, stat=None, set_num=None):
        """Returns stats associated with a dataset.

        :param dataset: The dataset whose associated stats are being fetched.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': }
        :param set_num: In the (rare) case that there are multiple sets of stats associated with the dataset,
            specify which set to return.
        """
        key = "stats"
        if set_num is not None and set_num not in (0, 1):
            key = key + str(set_num)
        stats = dataset.meta_data.get(key)
        if stat is not None:
            return stats[stat]
        return stats

    def get_data_stats(self):
        return Statistics.data_stats

    def get_volume_stats(self):
        return Statistics.volume_stats

    def get_global_stats(self):
        return Statistics.global_stats

    def _set_pattern_info(self):
        """Gather information about the pattern of the data in the current plugin.

        ``self.pattern`` is the plugin's 'pattern' parameter when it is set. Otherwise it is
        the first pattern of the first output dataset with 1 among its slice dimensions.
        It is None if there is no such pattern or no output dataset.
        ``self.calc_stats`` is True only if some output dataset has a pattern in ``pattern_list``.
        """
        _, out_datasets = self.plugin.get_datasets()
        try:
            self.pattern = self.plugin.parameters['pattern']
        except KeyError:
            self.pattern = None
        if self.pattern is None and out_datasets:
            patterns = out_datasets[0].get_data_patterns()
            # Default to None so self.pattern is always assigned, even when nothing matches.
            self.pattern = next(
                (pattern for pattern in patterns if 1 in patterns.get(pattern)["slice_dims"]), None)
        wanted = set(Statistics.pattern_list)
        self.calc_stats = any(wanted & set(dataset.data_info.get("data_patterns"))
                              for dataset in out_datasets)

    def _link_stats_to_datasets(self, stats):
        """Link the volume-wide statistics to the output dataset's metadata.

        This only happens when the plugin has exactly one output dataset. The stats are
        stored under the first free group name of "stats", "stats2", "stats3", ... 'RMSD'
        is only stored when it is not None.
        """
        if self.plugin.nOutput_datasets() != 1:
            return
        out_dataset = self.plugin.get_out_datasets()[0]
        existing = out_dataset.meta_data.get_dictionary()
        group_name = "stats"
        i = 2
        while group_name in existing:
            group_name = f"stats{i}"
            i += 1
        for key in ("max", "min", "mean", "mean_std_dev", "median_std_dev"):
            out_dataset.meta_data.set([group_name, key], stats[key])
        if stats["RMSD"] is not None:
            out_dataset.meta_data.set([group_name, "RMSD"], stats["RMSD"])

    def _write_stats_to_file(self, slice_stats_array, p_num):
        """Write slice statistics to ``<path>/stats_p<p_num>_<plugin name>.h5``.

        Each call adds a new group (/stats, /stats2, ...), so earlier results in the
        same file are not overwritten.

        :param slice_stats_array: 2D array whose rows are max, min, mean, standard deviation
            and optionally RMSD (fifth row), with one column per slice.
        :param p_num: Plugin number, used in the file name.
        """
        filename = f"{Statistics.path}/stats_p{p_num}_{self.plugin_name}.h5"
        slice_stats_dim = (slice_stats_array.shape[1],)
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        with h5.File(filename, "a") as h5file:
            i = 2
            group_name = "/stats"
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            # zip stops at the array's row count: the RMSD dataset only exists for 5-row input.
            for ds_name, row in zip(("max", "min", "mean", "standard_deviation", "RMSD"), slice_stats_array):
                dataset = self.hdf5.create_dataset_nofill(group, ds_name, slice_stats_dim,
                                                          slice_stats_array.dtype)
                dataset[::] = row

    def _unpad_slice(self, slice1):
        """If data is padded in the slice dimension(s), remove this padding.

        The unpadding slices are computed on the first processing call
        (``pcount == 0``) and reused for subsequent slices.
        """
        out_datasets = self.plugin.get_out_datasets()
        # Prefer the output dataset that carries the current pattern. Fall back to the
        # first one so that out_dataset is always defined.
        out_dataset = out_datasets[0]
        if len(out_datasets) > 1:
            for dataset in out_datasets:
                if self.pattern in dataset.data_info.get(["data_patterns"]):
                    out_dataset = dataset
                    break
        slice_dims = out_dataset.get_slice_dimensions()
        if self.plugin.pcount == 0 or not hasattr(self, "pad"):
            self.slice_list, self.pad = self._get_unpadded_slice_list(slice1, slice_dims)
        if self.pad:
            for slice_dim in slice_dims:
                # Move the slice dimension to axis 0, trim it, then move it back.
                trimmed = np.swapaxes(slice1, 0, slice_dim)[self.slice_list[slice_dim]]
                slice1 = np.swapaxes(trimmed, 0, slice_dim)
        return slice1

    def _get_unpadded_slice_list(self, slice1, slice_dims):
        """Create slice objects to un-pad slices in the slice dimension(s).

        :returns: ``(slices, pad)``. ``pad`` is True if any slice dimension of ``slice1``
            differs in width from the plugin's own slice. ``slices`` holds, per dimension,
            the slice selecting the unpadded region.
        """
        original = self.plugin.slice_list[0]
        slice_list = list(original)
        pad = False
        if len(slice_list) != len(slice1.shape):
            return original, pad
        for i in slice_dims:
            slice_width = original[i].stop - original[i].start
            if slice_width != slice1.shape[i]:
                pad = True
                pad_width = (slice1.shape[i] - slice_width) // 2  # Assuming symmetrical padding
                # Keep the full unpadded width, not just a single frame.
                slice_list[i] = slice(pad_width, pad_width + slice_width, 1)
        return tuple(slice_list), pad

    def _de_list(self, slice1):
        """Unwrap ``slice1`` from (nested) lists by repeatedly taking the first element.

        An empty list is returned unchanged.
        """
        while isinstance(slice1, list) and slice1:
            slice1 = slice1[0]
        return slice1

    @classmethod
    def _count(cls):
        """Advance the class-wide plugin counter by one."""
        cls.count += 1

    @classmethod
    def _post_chain(cls):
        """Print the class-wide statistics dictionaries."""
        print(cls.data_stats)
        print(cls.volume_stats)
        print(cls.global_stats)
        print(cls.global_residuals)