Test Failed
Pull Request — master (#878)
by
unknown
03:59
created

savu.plugins.stats.statistics   F

Complexity

Total Complexity 97

Size/Duplication

Total Lines 453
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 332
dl 0
loc 453
rs 2
c 0
b 0
f 0
wmc 97

30 Methods

Rating   Name   Duplication   Size   Complexity  
A Statistics.setup() 0 11 4
A Statistics.get_data_stats() 0 2 1
A Statistics._write_stats_to_file() 0 25 4
A Statistics.get_global_stats() 0 2 1
A Statistics.get_stats_from_dataset() 0 28 5
B Statistics._write_stats_to_file4() 0 21 6
B Statistics._set_pattern_info() 0 20 8
A Statistics._get_unpadded_slice_list() 0 15 3
B Statistics.get_stats_from_num() 0 33 8
A Statistics.set_stats_residuals() 0 5 1
A Statistics.calc_rmsd() 0 8 2
A Statistics._write_stats_to_file2() 0 14 3
A Statistics._calc_rss() 0 8 2
A Statistics.get_stats() 0 18 2
A Statistics._write_stats_to_file3() 0 14 4
A Statistics._array_to_dict() 0 5 2
A Statistics.calc_volume_stats() 0 13 2
A Statistics._rmsd_from_rss() 0 2 1
A Statistics._count() 0 3 1
A Statistics._post_chain() 0 5 2
A Statistics.calc_stats_residuals() 0 5 2
B Statistics._unpad_slice() 0 20 6
A Statistics._link_stats_to_datasets() 0 17 5
A Statistics.calc_slice_stats() 0 20 3
A Statistics._de_list() 0 7 3
A Statistics.__init__() 0 6 1
A Statistics.set_slice_stats() 0 7 3
B Statistics.set_volume_stats() 0 39 7
A Statistics._setup_class() 0 20 4
A Statistics.get_volume_stats() 0 2 1

How to fix   Complexity   

Complexity

Complex classes like savu.plugins.stats.statistics often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
.. module:: statistics
3
   :platform: Unix
4
   :synopsis: Contains and processes statistics information for each plugin.
5
6
.. moduleauthor:: Jacob Williamson <[email protected]>
7
8
"""
9
10
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
11
from savu.plugins.stats.stats_utils import StatsUtils
12
13
import h5py as h5
14
import numpy as np
15
import os
16
from mpi4py import MPI
17
18
19
class Statistics(object):
    """Contains and processes statistics information for each plugin."""

    # Data patterns for which statistics can be gathered.
    _pattern_list = ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    # Plugins that never have statistics gathered for them.
    no_stats_plugins = ["BasicOperations", "Mipmap"]
    # Order of the values in a volume-stats array (see _array_to_dict).
    _key_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD"]
    def __init__(self):
26
        self.calc_stats = True
27
        self.stats = {'max': [], 'min': [], 'mean': [], 'std_dev': [], 'RSS': [], 'data_points': []}
28
        self.stats_before_processing = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
29
        self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
30
        self._repeat_count = 0
31
32
    def setup(self, plugin_self):
33
        if plugin_self.name in Statistics.no_stats_plugins:
34
            self.calc_stats = False
35
        if self.calc_stats:
36
            self.plugin = plugin_self
37
            self.plugin_name = plugin_self.name
38
            self.pad_dims = []
39
            self._already_called = False
40
            self._set_pattern_info()
41
        if self.calc_stats:
42
            Statistics._any_stats = True
43
44
45
    @classmethod
46
    def _setup_class(cls, exp):
47
        """Sets up the statistics class for the whole plugin chain (only called once)"""
48
        cls._any_stats = False
49
        cls.count = 2
50
        cls.global_stats = {}
51
        cls.exp = exp
52
        cls.n_plugins = len(exp.meta_data.plugin_list.plugin_list)
53
        for i in range(1, cls.n_plugins + 1):
54
            cls.global_stats[i] = np.array([])
55
        cls.global_residuals = {}
56
        cls.plugin_numbers = {}
57
        cls.plugin_names = {}
58
59
        cls.path = exp.meta_data['out_path']
60
        if cls.path[-1] == '/':
61
            cls.path = cls.path[0:-1]
62
        cls.path = f"{cls.path}/stats"
63
        if not os.path.exists(cls.path):
64
            os.makedirs(cls.path, exist_ok=True)
65
66
    def set_slice_stats(self, slice, base_slice):
        """Append before- and after-processing stats for the current slice."""
        before = self.calc_slice_stats(base_slice)
        after = self.calc_slice_stats(slice, base_slice)
        for key, values in self.stats_before_processing.items():
            values.append(before[key])
        for key, values in self.stats.items():
            values.append(after[key])
    def calc_slice_stats(self, my_slice, base_slice=None):
        """Calculates and returns slice stats for the current slice.

        :param my_slice: The slice whose stats are being calculated.
        :param base_slice: Optional reference slice; when given, the residual
            sum of squares between the two (unpadded) slices is included.
        :returns: dict of slice stats, or None when my_slice is None.
        """
        if my_slice is not None:
            slice_num = self.plugin.pcount  # NOTE(review): unused local — confirm before removing
            my_slice = self._de_list(my_slice)
            my_slice = self._unpad_slice(my_slice)
            slice_stats = {'max': np.amax(my_slice).astype('float64'), 'min': np.amin(my_slice).astype('float64'),
                           'mean': np.mean(my_slice), 'std_dev': np.std(my_slice), 'data_points': my_slice.size}
            if base_slice is not None:
                # Unpad the reference slice the same way before comparing.
                base_slice = self._de_list(base_slice)
                base_slice = self._unpad_slice(base_slice)
                rss = self._calc_rss(my_slice, base_slice)
            else:
                rss = None
            slice_stats['RSS'] = rss
            return slice_stats
        return None
    def _calc_rss(self, array1, array2):  # residual sum of squares
96
        if array1.shape == array2.shape:
97
            residuals = np.subtract(array1, array2)
98
            rss = sum(value**2 for value in np.nditer(residuals))
99
        else:
100
            #print("Warning: cannot calculate RSS, arrays different sizes.")  # need to make this an actual warning
101
            rss = None
102
        return rss
103
104
    def _rmsd_from_rss(self, rss, n):
105
        return np.sqrt(rss/n)
106
107
    def calc_rmsd(self, array1, array2):
108
        if array1.shape == array2.shape:
109
            rss = self._calc_rss(array1, array2)
110
            rmsd = self._rmsd_from_rss(rss, array1.size)
111
        else:
112
            print("Warning: cannot calculate RMSD, arrays different sizes.")  # need to make this an actual warning
113
            rmsd = None
114
        return rmsd
115
116
    def calc_stats_residuals(self, stats_before, stats_after):
117
        residuals = {'max': None, 'min': None, 'mean': None, 'std_dev': None}
118
        for key in list(residuals.keys()):
119
            residuals[key] = stats_after[key] - stats_before[key]
120
        return residuals
121
122
    def set_stats_residuals(self, residuals):
123
        self.residuals['max'].append(residuals['max'])
124
        self.residuals['min'].append(residuals['min'])
125
        self.residuals['mean'].append(residuals['mean'])
126
        self.residuals['std_dev'].append(residuals['std_dev'])
127
128
    def calc_volume_stats(self, slice_stats):
129
        volume_stats = np.array([max(slice_stats['max']), min(slice_stats['min']), np.mean(slice_stats['mean']),
130
                                np.mean(slice_stats['std_dev']), np.median(slice_stats['std_dev'])])
131
        if None not in slice_stats['RSS']:
132
            total_rss = sum(slice_stats['RSS'])
133
            n = sum(slice_stats['data_points'])
134
            RMSD = self._rmsd_from_rss(total_rss, n)
135
            NRMSD = RMSD / abs(volume_stats[2])  # normalised RMSD (dividing by mean)
136
            volume_stats = np.append(volume_stats, NRMSD)
137
        else:
138
            #volume_stats = np.append(volume_stats, None)
139
            pass
140
        return volume_stats
141
142
    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        comm = MPI.COMM_WORLD
        rank = comm.rank  # NOTE(review): rank and size are unused — confirm before removing
        size = comm.size
        stats = self.stats
        # Merge the per-process slice stats gathered on every MPI process.
        combined_stats_list = comm.allgather(stats)
        combined_stats = {'max': [], 'min': [], 'mean': [], 'std_dev': [], 'RSS': [], 'data_points': []}
        for single_stats in combined_stats_list:
            for key in list(single_stats.keys()):
                combined_stats[key] += single_stats[key]
        p_num = Statistics.count
        name = self.plugin_name
        # Repeat occurrences of a plugin are keyed 'Name2', 'Name3', ...;
        # the first occurrence keeps the bare plugin name.
        i = 2
        while name in list(Statistics.plugin_numbers.keys()):
            name = self.plugin_name + str(i)
            i += 1
        if len(self.stats['max']) != 0:
            stats_array = self.calc_volume_stats(combined_stats)
            Statistics.global_residuals[p_num] = {}
            #before_processing = self.calc_volume_stats(self.stats_before_processing)
            #for key in list(before_processing.keys()):
            #    Statistics.global_residuals[p_num][key] = Statistics.global_stats[p_num][key] - before_processing[key]

            # First run stores a 1-D array; repeat runs stack rows into 2-D.
            if len(Statistics.global_stats[p_num]) == 0:
                Statistics.global_stats[p_num] = stats_array
            else:
                Statistics.global_stats[p_num] = np.vstack([Statistics.global_stats[p_num], stats_array])
            Statistics.plugin_numbers[name] = p_num
            if p_num not in list(Statistics.plugin_names.keys()):
                Statistics.plugin_names[p_num] = name
            self._link_stats_to_datasets(Statistics.global_stats[Statistics.plugin_numbers[name]])

        #slice_stats_array = np.array([self.stats['max'], self.stats['min'], self.stats['mean'], self.stats['std_dev']])
        self._write_stats_to_file3(p_num)
        self._already_called = True
        self._repeat_count += 1
    def get_stats(self, plugin_name, n=None, stat=None, instance=1):
183
        """Returns stats associated with a certain plugin.
184
185
        :param plugin_name: name of the plugin whose associated stats are being fetched.
186
        :param n: In a case where there are multiple instances of **plugin_name** in the process list,
187
            specify the nth instance. Not specifying will select the first (or only) instance.
188
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
189
            If left blank will return the whole dictionary of stats:
190
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD' }
191
        :param instance: In cases where there are multiple set of stats associated with a plugin
192
            due to multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
193
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
194
        """
195
        name = plugin_name
196
        if n in (None, 0, 1):
197
            name = name + str(n)
198
        p_num = Statistics.plugin_numbers[name]
199
        return self.get_stats_from_num(p_num, stat, instance)
200
201
    def get_stats_from_num(self, p_num, stat=None, instance=1):
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).

        :param p_num: Plugin  number of the plugin whose associated stats are being fetched.
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
            E.g current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD' }
        :param instance: In cases where there are multiple set of stats associated with a plugin
            due to multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
        """
        # Zero/negative plugin numbers are relative to the current plugin.
        if p_num <= 0:
            p_num = Statistics.count + p_num
        # A 1-D array means only one set of stats exists for this plugin.
        if Statistics.global_stats[p_num].ndim == 1 and instance in (None, 0, 1, "all"):
            stats_array = Statistics.global_stats[p_num]
        else:
            if instance == "all":
                # Gather every instance recursively.
                # NOTE(review): the loop bound uses .ndim, which is 2 for any
                # stacked (2-D) stats array regardless of the number of rows —
                # confirm shape[0] wasn't intended.
                stats_list = [self.get_stats_from_num(p_num, stat=stat, instance=1)]
                n = 2
                while n <= Statistics.global_stats[p_num].ndim:
                    stats_list.append(self.get_stats_from_num(p_num, stat=stat, instance=n))
                    n += 1
                return stats_list
            # Convert the 1-based instance number to a 0-based row index.
            if instance > 0:
                instance -= 1
            stats_array = Statistics.global_stats[p_num][instance]
        stats_dict = self._array_to_dict(stats_array)
        if stat is not None:
            return stats_dict[stat]
        else:
            return stats_dict
    def get_stats_from_dataset(self, dataset, stat=None, instance=None):
        """Returns stats associated with a dataset.

        :param dataset: The dataset whose associated stats are being fetched.
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
            If left blank will return the whole dictionary of stats:
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD'}
        :param instance: In cases where there are multiple set of stats associated with a dataset
            due to multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.

        """
        key = "stats"
        stats = {}
        if instance not in (None, 0, 1):
            if instance == "all":
                # Gather 'stats', 'stats2', 'stats3', ... until one is missing.
                stats = [self.get_stats_from_dataset(dataset, stat=stat, instance=1)]
                n = 2
                while ("stats" + str(n)) in list(dataset.meta_data.get_dictionary().keys()):
                    stats.append(self.get_stats_from_dataset(dataset, stat=stat, instance=n))
                    n += 1
                return stats
            # Later instances are stored under a suffixed key, e.g. 'stats3'.
            key = key + str(instance)
        stats = dataset.meta_data.get(key)
        if stat is not None:
            return stats[stat]
        else:
            return stats
    def get_data_stats(self):
        # NOTE(review): Statistics.data_stats is never assigned anywhere in
        # this module — confirm it is set elsewhere before relying on this.
        return Statistics.data_stats
    def get_volume_stats(self):
        # NOTE(review): Statistics.volume_stats is never assigned anywhere in
        # this module — confirm it is set elsewhere before relying on this.
        return Statistics.volume_stats
    def get_global_stats(self):
        # Dict mapping plugin number -> stats array (see _setup_class).
        return Statistics.global_stats
    def _array_to_dict(self, stats_array):
274
        stats_dict = {}
275
        for i, value in enumerate(stats_array):
276
            stats_dict[Statistics._key_list[i]] = value
277
        return stats_dict
278
279
    def _set_pattern_info(self):
280
        """Gathers information about the pattern of the data in the current plugin."""
281
        in_datasets, out_datasets = self.plugin.get_datasets()
282
        try:
283
            self.pattern = self.plugin.parameters['pattern']
284
            if self.pattern == None:
285
                raise KeyError
286
        except KeyError:
287
            if not out_datasets:
288
                self.pattern = None
289
            else:
290
                patterns = out_datasets[0].get_data_patterns()
291
                for pattern in patterns:
292
                    if 1 in patterns.get(pattern)["slice_dims"]:
293
                        self.pattern = pattern
294
                        break
295
        self.calc_stats = False
296
        for dataset in out_datasets:
297
            if bool(set(Statistics._pattern_list) & set(dataset.data_info.get("data_patterns"))):
298
                self.calc_stats = True
299
300
    def _link_stats_to_datasets(self, stats):
        """Links the volume wide statistics to the output dataset(s)"""
        out_dataset = self.plugin.get_out_datasets()[0]
        n_datasets = self.plugin.nOutput_datasets()
        if self._repeat_count > 0:
            # On repeat runs, link only the stats row for this repeat.
            stats_dict = self._array_to_dict(stats[self._repeat_count])
        else:
            stats_dict = self._array_to_dict(stats)
        i = 2
        group_name = "stats"
        #out_dataset.data_info.set([group_name], stats)
        if n_datasets == 1:
            # Successive calls use 'stats', 'stats2', 'stats3', ... so earlier
            # metadata is not overwritten.
            while group_name in list(out_dataset.meta_data.get_dictionary().keys()):
                group_name = f"stats{i}"
                i += 1
            for key in list(stats_dict.keys()):
                out_dataset.meta_data.set([group_name, key], stats_dict[key])
    def _write_stats_to_file2(self, p_num):
        """Write the volume stats of plugin p_num to stats.h5 under a
        '<p_num>-<plugin_name>-stats' group (first write only).

        :param p_num: Plugin number whose global stats are written.
        """
        path = Statistics.path
        filename = f"{path}/stats.h5"
        stats = Statistics.global_stats[p_num]
        array_dim = stats.shape
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        group_name = f"{p_num}-{self.plugin_name}-stats"
        with h5.File(filename, "a") as h5file:
            # Only write once; an existing group is left untouched. The old
            # else-branch merely assigned the group to an unused local, so it
            # has been removed.
            if group_name not in h5file:
                group = h5file.create_group(group_name, track_order=None)
                dataset = self.hdf5.create_dataset_nofill(group, "stats", array_dim, stats.dtype)
                dataset[::] = stats[::]
    @classmethod
    def _write_stats_to_file4(cls):
        """Writes, for each stat key, the value of that stat across all
        plugins into stats.h5 (one 'all-<stat>' group per stat)."""
        path = cls.path
        filename = f"{path}/stats.h5"
        stats = cls.global_stats
        cls.hdf5 = Hdf5Utils(cls.exp)
        # NOTE(review): range(5) skips index 5 ('NRMSD') — confirm intentional.
        for i in range(5):
            array = np.array([])
            stat = cls._key_list[i]
            for key in list(stats.keys()):
                if len(stats[key]) != 0:
                    # 2-D arrays hold one row per repeat; only the first row
                    # is written here.
                    if stats[key].ndim == 1:
                        array = np.append(array, stats[key][i])
                    else:
                        array = np.append(array, stats[key][0][i])
            array_dim = array.shape
            group_name = f"all-{stat}"
            with h5.File(filename, "a") as h5file:
                group = h5file.create_group(group_name, track_order=None)
                dataset = cls.hdf5.create_dataset_nofill(group, stat, array_dim, array.dtype)
                dataset[::] = array[::]
    def _write_stats_to_file3(self, p_num):
        """Writes the volume stats of plugin p_num into the shared stats.h5
        file, collectively across all MPI processes.

        :param p_num: Plugin number whose global stats are written.
        """
        path = Statistics.path
        filename = f"{path}/stats.h5"
        stats = self.global_stats
        self.hdf5 = Hdf5Utils(self.exp)
        # mpio driver: every MPI process must participate in these operations.
        with h5.File(filename, "a", driver="mpio", comm=MPI.COMM_WORLD) as h5file:
            group = h5file.require_group("stats")
            if stats[p_num].shape != (0,):
                # Replace any dataset left over from a previous write.
                if str(p_num) in list(group.keys()):
                    del group[str(p_num)]
                dataset = group.create_dataset(str(p_num), shape=stats[p_num].shape, dtype=stats[p_num].dtype)
                dataset[::] = stats[p_num][::]
                dataset.attrs.create("plugin_name", self.plugin_names[p_num])
                dataset.attrs.create("pattern", self.pattern)
    def _write_stats_to_file(self, slice_stats_array, p_num):
        """Writes slice statistics to a h5 file

        :param slice_stats_array: 2-D array of slice stats, one row per
            statistic (max, min, mean, std dev and, when present, NRMSD).
        :param p_num: Plugin number, used in the file name.
        """
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{self.plugin_name}.h5"
        slice_stats_dim = (slice_stats_array.shape[1],)
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        with h5.File(filename, "a") as h5file:
            i = 2
            group_name = "/stats"
            # Successive calls write to /stats, /stats2, /stats3, ...
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            max_ds = self.hdf5.create_dataset_nofill(group, "max", slice_stats_dim, slice_stats_array.dtype)
            min_ds = self.hdf5.create_dataset_nofill(group, "min", slice_stats_dim, slice_stats_array.dtype)
            mean_ds = self.hdf5.create_dataset_nofill(group, "mean", slice_stats_dim, slice_stats_array.dtype)
            std_dev_ds = self.hdf5.create_dataset_nofill(group, "standard_deviation",
                                                         slice_stats_dim, slice_stats_array.dtype)
            # A fifth row means NRMSD values were calculated for these slices.
            if slice_stats_array.shape[0] == 5:
                nrmsd_ds = self.hdf5.create_dataset_nofill(group, "NRMSD", slice_stats_dim, slice_stats_array.dtype)
                nrmsd_ds[::] = slice_stats_array[4]
            max_ds[::] = slice_stats_array[0]
            min_ds[::] = slice_stats_array[1]
            mean_ds[::] = slice_stats_array[2]
            std_dev_ds[::] = slice_stats_array[3]
    def _unpad_slice(self, slice1):
        """If data is padded in the slice dimension, removes this pad.

        :param slice1: The (possibly padded) slice array.
        :returns: The slice with any slice-dimension padding removed.
        """
        out_datasets = self.plugin.get_out_datasets()
        if len(out_datasets) == 1:
            out_dataset = out_datasets[0]
        else:
            # NOTE(review): if no dataset matches self.pattern, out_dataset is
            # never bound and the line below raises NameError — confirm the
            # caller guarantees a match.
            for dataset in out_datasets:
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
                    out_dataset = dataset
                    break
        slice_dims = out_dataset.get_slice_dimensions()
        # The slice list and pad flag are computed once, on the first slice.
        if self.plugin.pcount == 0:
            self.slice_list, self.pad = self._get_unpadded_slice_list(slice1, slice_dims)
        if self.pad:
            #for slice_dim in slice_dims:
            slice_dim = slice_dims[0]
            # Move the slice dimension to the front, trim it, and move it back.
            temp_slice = np.swapaxes(slice1, 0, slice_dim)
            temp_slice = temp_slice[self.slice_list[slice_dim]]
            slice1 = np.swapaxes(temp_slice, 0, slice_dim)
        return slice1
    def _get_unpadded_slice_list(self, slice1, slice_dims):
        """Creates slice object(s) to un-pad slices in the slice dimension(s).

        :param slice1: A slice array, used to detect padding from its shape.
        :param slice_dims: The slice dimensions of the output dataset.
        :returns: (slice list, pad flag) — when the flag is True the list
            indexes out the pad in the first slice dimension.
        """
        slice_list = list(self.plugin.slice_list[0])
        pad = False
        if len(slice_list) == len(slice1.shape):
            #for i in slice_dims:
            i = slice_dims[0]
            slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
            if slice_width != slice1.shape[i]:
                pad = True
                pad_width = (slice1.shape[i] - slice_width) // 2  # Assuming symmetrical padding
                # NOTE(review): this keeps a single central element
                # (stop = pad_width + 1) — confirm chunks wider than one slice
                # never reach this path.
                slice_list[i] = slice(pad_width, pad_width + 1, 1)
            return tuple(slice_list), pad
        else:
            return self.plugin.slice_list[0], pad
    def _de_list(self, slice1):
436
        """If the slice is in a list, remove it from that list."""
437
        if type(slice1) == list:
438
            if len(slice1) != 0:
439
                slice1 = slice1[0]
440
                slice1 = self._de_list(slice1)
441
        return slice1
442
443
444
    @classmethod
445
    def _count(cls):
446
        cls.count += 1
447
448
    @classmethod
    def _post_chain(cls):
        """Generate the stats figures once the whole plugin chain has run.

        Does nothing when no plugin in the chain produced statistics.
        """
        if cls._any_stats:
            stats_utils = StatsUtils()
            stats_utils.generate_figures(f"{cls.path}/stats.h5", cls.path)
453