Test Failed
Pull Request — master (#878)
by
unknown
03:49
created

Statistics.get_stats_from_dataset()   A

Complexity

Conditions 4

Size

Total Lines 20
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 9
nop 4
dl 0
loc 20
rs 9.95
c 0
b 0
f 0
"""
.. module:: statistics
   :platform: Unix
   :synopsis: Contains and processes statistics information for each plugin.

.. moduleauthor:: Jacob Williamson <[email protected]>

"""
import os
import warnings

import h5py as h5
import numpy as np

from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
from savu.plugins.stats.stats_utils import StatsUtils
class Statistics(object):
    """Collects and processes statistics about the data produced by each plugin.

    Slice-wise stats are accumulated while a plugin runs; volume-wide stats
    are then derived from them, linked to the output dataset's metadata and
    written to a ``stats.h5`` file in the experiment's output directory.
    """

    # Data patterns for which statistics can be calculated.
    _pattern_list = ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    # Plugins that never produce stats.
    no_stats_plugins = ["BasicOperations", "Mipmap"]
    # Order of the values in a volume-stats array (see _array_to_dict).
    _key_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "RMSD"]

    def __init__(self):
        # Whether stats are calculated for the current plugin.
        self.calc_stats = True
        # Per-slice stats accumulated while the plugin runs.
        self.stats = {'max': [], 'min': [], 'mean': [], 'std_dev': [], 'RSS': [], 'data_points': []}
        self.stats_before_processing = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
        self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []}

    def setup(self, plugin_self):
        """Set up this statistics object for the plugin about to be run.

        :param plugin_self: the plugin instance whose data will be analysed.
        """
        if plugin_self.name in Statistics.no_stats_plugins:
            self.calc_stats = False
        if self.calc_stats:
            self.plugin = plugin_self
            self.plugin_name = plugin_self.name
            self.pad_dims = []
            self._already_called = False
            self._set_pattern_info()

    @classmethod
    def _setup_class(cls, exp):
        """Sets up the statistics class for the whole experiment (only called once)."""
        cls.count = 2
        cls.data_stats = {}
        cls.volume_stats = {}
        cls.global_stats = {}
        cls.exp = exp
        n_plugins = len(exp.meta_data.plugin_list.plugin_list)
        for i in range(1, n_plugins + 1):
            cls.global_stats[i] = np.array([])
        cls.global_residuals = {}
        cls.plugin_numbers = {}
        cls.plugin_names = {}

        # Create the <out_path>/stats directory if it does not already exist.
        cls.path = exp.meta_data['out_path']
        if cls.path[-1] == '/':
            cls.path = cls.path[0:-1]
        cls.path = f"{cls.path}/stats"
        if not os.path.exists(cls.path):
            os.mkdir(cls.path)

    def set_slice_stats(self, slice, base_slice):
        """Calculate and record stats for a slice, before and after processing.

        :param slice: the processed slice (name kept for interface
            compatibility although it shadows the builtin).
        :param base_slice: the corresponding unprocessed slice.
        """
        slice_stats_before = self.calc_slice_stats(base_slice)
        slice_stats_after = self.calc_slice_stats(slice, base_slice)
        for key in self.stats_before_processing:
            self.stats_before_processing[key].append(slice_stats_before[key])
        for key in self.stats:
            self.stats[key].append(slice_stats_after[key])

    def calc_slice_stats(self, my_slice, base_slice=None):
        """Calculate and return slice stats for the current slice.

        :param my_slice: The slice whose stats are being calculated.
        :param base_slice: The corresponding unprocessed slice; when given,
            the residual sum of squares (RSS) against it is also calculated.
        :returns: dict of stats, or None when my_slice is None.
        """
        if my_slice is None:
            return None
        my_slice = self._de_list(my_slice)
        my_slice = self._unpad_slice(my_slice)
        slice_stats = {'max': np.amax(my_slice).astype('float64'),
                       'min': np.amin(my_slice).astype('float64'),
                       'mean': np.mean(my_slice),
                       'std_dev': np.std(my_slice),
                       'data_points': my_slice.size}
        if base_slice is not None:
            base_slice = self._de_list(base_slice)
            base_slice = self._unpad_slice(base_slice)
            rss = self._calc_rss(my_slice, base_slice)
        else:
            rss = None
        slice_stats['RSS'] = rss
        return slice_stats

    def _calc_rss(self, array1, array2):
        """Return the residual sum of squares between two equal-shaped arrays.

        Returns None (with a warning) if the shapes differ.
        """
        if array1.shape != array2.shape:
            warnings.warn("Cannot calculate RSS, arrays are different sizes.")
            return None
        residuals = np.subtract(array1, array2)
        # Vectorised sum: far faster than iterating np.nditer in Python.
        return np.sum(residuals ** 2)

    def _rmsd_from_rss(self, rss, n):
        """Root-mean-square deviation from a residual sum of squares over n points."""
        return np.sqrt(rss / n)

    def calc_rmsd(self, array1, array2):
        """Return the root-mean-square deviation between two equal-shaped arrays.

        Returns None (with a warning) if the shapes differ.
        """
        if array1.shape != array2.shape:
            warnings.warn("Cannot calculate RMSD, arrays are different sizes.")
            return None
        rss = self._calc_rss(array1, array2)
        return self._rmsd_from_rss(rss, array1.size)

    def calc_stats_residuals(self, stats_before, stats_after):
        """Return the difference (after - before) for each basic stat."""
        return {key: stats_after[key] - stats_before[key]
                for key in ('max', 'min', 'mean', 'std_dev')}

    def set_stats_residuals(self, residuals):
        """Append a set of stats residuals to the running lists."""
        for key in ('max', 'min', 'mean', 'std_dev'):
            self.residuals[key].append(residuals[key])

    def calc_volume_stats(self, slice_stats):
        """Combine per-slice stats into one volume-wide stats array.

        :returns: np.array of [max, min, mean, mean_std_dev, median_std_dev].
        """
        return np.array([max(slice_stats['max']), min(slice_stats['min']),
                         np.mean(slice_stats['mean']),
                         np.mean(slice_stats['std_dev']),
                         np.median(slice_stats['std_dev'])])

    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        p_num = Statistics.count
        name = self.plugin_name
        # Disambiguate repeated instances of a plugin: "Name", "Name2", ...
        i = 2
        while name in Statistics.plugin_numbers:
            name = self.plugin_name + str(i)
            i += 1

        if len(self.stats['max']) != 0:
            stats_array = self.calc_volume_stats(self.stats)
            Statistics.global_residuals[p_num] = {}
            # TODO(review): computed for a planned residuals feature that is
            # not implemented yet.
            before_processing = self.calc_volume_stats(self.stats_before_processing)
            # RMSD can only be appended when every slice produced an RSS.
            if None not in self.stats['RSS']:
                total_rss = sum(self.stats['RSS'])
                n = sum(self.stats['data_points'])
                RMSD = self._rmsd_from_rss(total_rss, n)
                stats_array = np.append(stats_array, RMSD)
            if len(Statistics.global_stats[p_num]) == 0:
                Statistics.global_stats[p_num] = stats_array
            else:
                Statistics.global_stats[p_num] = np.vstack([Statistics.global_stats[p_num], stats_array])
            Statistics.plugin_numbers[name] = p_num
            if p_num not in Statistics.plugin_names:
                Statistics.plugin_names[p_num] = name
            self._link_stats_to_datasets(Statistics.global_stats[Statistics.plugin_numbers[name]])

        self._write_stats_to_file3(p_num)
        self._already_called = True

    def get_stats(self, plugin_name, n=None, stat=None):
        """Returns stats associated with a certain plugin.

        :param plugin_name: name of the plugin whose associated stats are being fetched.
        :param n: In a case where there are multiple instances of **plugin_name** in the process list,
            specify the nth instance. Not specifying will select the first (or only) instance.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': }
        """
        name = plugin_name
        if n is not None and n not in (0, 1):
            name = name + str(n)
        p_num = Statistics.plugin_numbers[name]
        return self.get_stats_from_num(p_num, stat)

    def get_stats_from_num(self, p_num, stat=None, instance=0):
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).

        :param p_num: Plugin number of the plugin whose associated stats are being fetched.
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
            E.g current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': }
        """
        if p_num <= 0:
            p_num = Statistics.count + p_num
        if Statistics.global_stats[p_num].ndim == 1:
            stats_array = Statistics.global_stats[p_num]
        else:
            # Multiple runs of the same plugin stack row-wise; pick one.
            stats_array = Statistics.global_stats[p_num][instance]
        stats_dict = self._array_to_dict(stats_array)
        if stat is not None:
            return stats_dict[stat]
        return stats_dict

    def get_stats_from_dataset(self, dataset, stat=None, instance=None):
        """Returns stats associated with a dataset.

        :param dataset: The dataset whose associated stats are being fetched.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': }
        :param instance: In the (rare) case that there are multiple sets of stats associated with the dataset,
            specify which set to return.
        """
        key = "stats"
        if instance is not None and instance not in (0, 1):
            key = key + str(instance)
        stats = dataset.meta_data.get(key)
        if stat is not None:
            return stats[stat]
        return stats

    def get_data_stats(self):
        """Return the class-wide dict of data stats."""
        return Statistics.data_stats

    def get_volume_stats(self):
        """Return the class-wide dict of volume stats."""
        return Statistics.volume_stats

    def get_global_stats(self):
        """Return the class-wide dict of global stats arrays, keyed by plugin number."""
        return Statistics.global_stats

    def _array_to_dict(self, stats_array):
        """Map a volume-stats array onto a dict keyed in _key_list order."""
        return dict(zip(Statistics._key_list, stats_array))

    def _set_pattern_info(self):
        """Gathers information about the pattern of the data in the current plugin."""
        in_datasets, out_datasets = self.plugin.get_datasets()
        try:
            self.pattern = self.plugin.parameters['pattern']
            if self.pattern is None:
                raise KeyError
        except KeyError:
            # No explicit pattern parameter: take the first output pattern
            # that has dimension 1 among its slice dimensions.
            if not out_datasets:
                self.pattern = None
            else:
                patterns = out_datasets[0].get_data_patterns()
                for pattern in patterns:
                    if 1 in patterns.get(pattern)["slice_dims"]:
                        self.pattern = pattern
                        break
        # Only calculate stats when at least one output dataset uses a
        # recognised pattern.
        self.calc_stats = False
        for dataset in out_datasets:
            if bool(set(Statistics._pattern_list) & set(dataset.data_info.get("data_patterns"))):
                self.calc_stats = True

    def _link_stats_to_datasets(self, stats):
        """Links the volume wide statistics to the output dataset(s)."""
        out_dataset = self.plugin.get_out_datasets()[0]
        n_datasets = self.plugin.nOutput_datasets()
        stats_dict = self._array_to_dict(stats)
        i = 2
        group_name = "stats"
        if n_datasets == 1:
            # Use "stats", "stats2", ... if earlier stats already exist.
            while group_name in list(out_dataset.meta_data.get_dictionary().keys()):
                group_name = f"stats{i}"
                i += 1
            for key, value in stats_dict.items():
                out_dataset.meta_data.set([group_name, key], value)

    def _write_stats_to_file2(self, p_num):
        """Write the volume stats of plugin p_num to its own group in stats.h5."""
        path = Statistics.path
        filename = f"{path}/stats.h5"
        stats = Statistics.global_stats[p_num]
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        group_name = f"{p_num}-{self.plugin_name}-stats"
        with h5.File(filename, "a") as h5file:
            # Only the first write for this plugin creates the group.
            if group_name not in h5file:
                group = h5file.create_group(group_name, track_order=None)
                dataset = self.hdf5.create_dataset_nofill(group, "stats", stats.shape, stats.dtype)
                dataset[::] = stats[::]

    @classmethod
    def _write_stats_to_file4(cls):
        """Write, per stat, the value of that stat across all plugins to stats.h5."""
        path = cls.path
        filename = f"{path}/stats.h5"
        stats = cls.global_stats
        cls.hdf5 = Hdf5Utils(cls.exp)
        # Only the first five stats: RMSD is not present for every plugin.
        for i in range(5):
            stat = cls._key_list[i]
            array = np.array([])
            for key in stats:
                if len(stats[key]) != 0:
                    if stats[key].ndim == 1:
                        array = np.append(array, stats[key][i])
                    else:
                        # Stacked runs: take the first instance's value.
                        array = np.append(array, stats[key][0][i])
            group_name = f"all-{stat}"
            with h5.File(filename, "a") as h5file:
                group = h5file.create_group(group_name, track_order=None)
                dataset = cls.hdf5.create_dataset_nofill(group, stat, array.shape, array.dtype)
                dataset[::] = array[::]

    def _write_stats_to_file3(self, p_num):
        """Write (or overwrite) plugin p_num's stats in the "stats" group of stats.h5."""
        path = Statistics.path
        filename = f"{path}/stats.h5"
        stats = self.global_stats
        self.hdf5 = Hdf5Utils(self.exp)
        with h5.File(filename, "a") as h5file:
            group = h5file.require_group("stats")
            if stats[p_num].shape != (0,):
                # Replace any previous dataset for this plugin number.
                if str(p_num) in group:
                    del group[str(p_num)]
                dataset = group.create_dataset(str(p_num), shape=stats[p_num].shape, dtype=stats[p_num].dtype)
                dataset[::] = stats[p_num][::]
                dataset.attrs.create("plugin_name", self.plugin_names[p_num])
                dataset.attrs.create("pattern", self.pattern)

    def _write_stats_to_file(self, slice_stats_array, p_num):
        """Writes slice statistics to a h5 file."""
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{self.plugin_name}.h5"
        slice_stats_dim = (slice_stats_array.shape[1],)
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        with h5.File(filename, "a") as h5file:
            # Use "/stats", "/stats2", ... if earlier groups already exist.
            i = 2
            group_name = "/stats"
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            max_ds = self.hdf5.create_dataset_nofill(group, "max", slice_stats_dim, slice_stats_array.dtype)
            min_ds = self.hdf5.create_dataset_nofill(group, "min", slice_stats_dim, slice_stats_array.dtype)
            mean_ds = self.hdf5.create_dataset_nofill(group, "mean", slice_stats_dim, slice_stats_array.dtype)
            std_dev_ds = self.hdf5.create_dataset_nofill(group, "standard_deviation",
                                                         slice_stats_dim, slice_stats_array.dtype)
            # A fifth row means RMSD was calculated for every slice.
            if slice_stats_array.shape[0] == 5:
                rmsd_ds = self.hdf5.create_dataset_nofill(group, "RMSD", slice_stats_dim, slice_stats_array.dtype)
                rmsd_ds[::] = slice_stats_array[4]
            max_ds[::] = slice_stats_array[0]
            min_ds[::] = slice_stats_array[1]
            mean_ds[::] = slice_stats_array[2]
            std_dev_ds[::] = slice_stats_array[3]

    def _unpad_slice(self, slice1):
        """If data is padded in the slice dimension, removes this pad."""
        out_datasets = self.plugin.get_out_datasets()
        if len(out_datasets) == 1:
            out_dataset = out_datasets[0]
        else:
            # Default to the first dataset so out_dataset is always bound,
            # even when no dataset matches the current pattern (previously
            # this could raise UnboundLocalError).
            out_dataset = out_datasets[0]
            for dataset in out_datasets:
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
                    out_dataset = dataset
                    break
        slice_dims = out_dataset.get_slice_dimensions()
        if self.plugin.pcount == 0:
            # The un-padding slice list is computed once, on the first slice.
            self.slice_list, self.pad = self._get_unpadded_slice_list(slice1, slice_dims)
        if self.pad:
            # Only the first slice dimension is un-padded.
            slice_dim = slice_dims[0]
            temp_slice = np.swapaxes(slice1, 0, slice_dim)
            temp_slice = temp_slice[self.slice_list[slice_dim]]
            slice1 = np.swapaxes(temp_slice, 0, slice_dim)
        return slice1

    def _get_unpadded_slice_list(self, slice1, slice_dims):
        """Creates slice object(s) to un-pad slices in the slice dimension(s)."""
        slice_list = list(self.plugin.slice_list[0])
        pad = False
        if len(slice_list) == len(slice1.shape):
            # Only the first slice dimension is considered.
            i = slice_dims[0]
            slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
            if slice_width != slice1.shape[i]:
                pad = True
                pad_width = (slice1.shape[i] - slice_width) // 2  # Assuming symmetrical padding
                # NOTE(review): this keeps a single element
                # (pad_width:pad_width+1); confirm it should not be
                # pad_width:pad_width+slice_width.
                slice_list[i] = slice(pad_width, pad_width + 1, 1)
            return tuple(slice_list), pad
        else:
            return self.plugin.slice_list[0], pad

    def _de_list(self, slice1):
        """If the slice is in a list, remove it from that list."""
        # Iterative unwrap of (possibly nested) single-slice lists.
        while isinstance(slice1, list) and len(slice1) != 0:
            slice1 = slice1[0]
        return slice1

    @classmethod
    def _count(cls):
        """Advance the current plugin number."""
        cls.count += 1

    @classmethod
    def _post_chain(cls):
        """Called once after the plugin chain: dump stats and generate figures."""
        print(cls.data_stats)
        print(cls.volume_stats)
        print(cls.global_stats)
        print(cls.global_residuals)
        stats_utils = StatsUtils()
        stats_utils.generate_figures(f"{cls.path}/stats.h5", cls.path)
415