Test Failed
Pull Request — master (#878)
by
unknown
04:12
created

savu.plugins.stats.statistics   F

Complexity

Total Complexity 72

Size/Duplication

Total Lines 330
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 239
dl 0
loc 330
rs 2.64
c 0
b 0
f 0
wmc 72

26 Methods

Rating   Name   Duplication   Size   Complexity  
A Statistics.setup() 0 9 3
A Statistics.get_data_stats() 0 2 1
A Statistics._write_stats_to_file() 0 25 4
A Statistics.get_global_stats() 0 2 1
A Statistics.get_stats_from_dataset() 0 20 4
B Statistics._set_pattern_info() 0 20 8
A Statistics._get_unpadded_slice_list() 0 14 4
A Statistics.get_stats_from_num() 0 16 3
A Statistics.set_stats_residuals() 0 5 1
A Statistics.calc_rmsd() 0 8 2
A Statistics.get_stats() 0 15 3
A Statistics._calc_rss() 0 8 2
A Statistics.calc_volume_stats() 0 5 1
A Statistics._rmsd_from_rss() 0 2 1
A Statistics._count() 0 3 1
A Statistics.calc_stats_residuals() 0 5 2
B Statistics._unpad_slice() 0 19 7
A Statistics._post_chain() 0 6 1
A Statistics._link_stats_to_datasets() 0 12 4
A Statistics.calc_slice_stats() 0 20 3
A Statistics._de_list() 0 7 3
A Statistics.__init__() 0 5 1
A Statistics.set_slice_stats() 0 7 3
B Statistics.set_volume_stats() 0 30 5
A Statistics._setup_class() 0 16 3
A Statistics.get_volume_stats() 0 2 1

How to fix   Complexity   

Complexity

Complex classes like savu.plugins.stats.statistics often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
.. module:: statistics
3
   :platform: Unix
4
   :synopsis: Contains and processes statistics information for each plugin.
5
6
.. moduleauthor:: Jacob Williamson <[email protected]>
7
8
"""
9
10
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
11
12
import h5py as h5
13
import numpy as np
14
import os
15
16
17
class Statistics(object):
18
    pattern_list = ["SINOGRAM", "PROJECTION", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
19
    no_stats_plugins = ["BasicOperations", "Mipmap"]
20
21
    def __init__(self):
22
        self.calc_stats = True
23
        self.stats = {'max': [], 'min': [], 'mean': [], 'std_dev': [], 'RSS': [], 'data_points': []}
24
        self.stats_before_processing = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
25
        self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
26
27
    def setup(self, plugin_self):
        """Prepare statistics gathering for *plugin_self*.

        Plugins listed in ``no_stats_plugins`` are skipped entirely.
        """
        if plugin_self.name in Statistics.no_stats_plugins:
            self.calc_stats = False
        if not self.calc_stats:
            return
        self.plugin = plugin_self
        self.plugin_name = plugin_self.name
        self.pad_dims = []
        self._already_called = False
        self._set_pattern_info()
36
37
38
    @classmethod
39
    def _setup_class(cls, exp):
40
        """Sets up the statistics class for the whole experiment (only called once)"""
41
        cls.count = 2
42
        cls.data_stats = {}
43
        cls.volume_stats = {}
44
        cls.global_stats = {}
45
        cls.global_residuals = {}
46
        cls.plugin_numbers = {}
47
        n_plugins = len(exp.meta_data.plugin_list.plugin_list)
48
        cls.path = exp.meta_data['out_path']
49
        if cls.path[-1] == '/':
50
            cls.path = cls.path[0:-1]
51
        cls.path = f"{cls.path}/stats"
52
        if not os.path.exists(cls.path):
53
            os.mkdir(cls.path)
54
55
    def set_slice_stats(self, slice, base_slice):
        """Compute before/after stats for the current slice and append them
        to the running per-slice lists."""
        before = self.calc_slice_stats(base_slice)
        after = self.calc_slice_stats(slice, base_slice)
        for key, values in self.stats_before_processing.items():
            values.append(before[key])
        for key, values in self.stats.items():
            values.append(after[key])
62
63
    def calc_slice_stats(self, my_slice, base_slice=None):
        """Calculates and returns slice stats for the current slice.

        :param my_slice: The slice whose stats are being calculated.
        :param base_slice: If given, the corresponding unprocessed slice, used
            to compute the residual sum of squares (RSS).
        :returns: dict of slice stats, or None when my_slice is None.
        """
        if my_slice is None:
            return None
        my_slice = self._de_list(my_slice)
        my_slice = self._unpad_slice(my_slice)
        slice_stats = {'max': np.amax(my_slice).astype('float64'), 'min': np.amin(my_slice).astype('float64'),
                       'mean': np.mean(my_slice), 'std_dev': np.std(my_slice), 'data_points': my_slice.size}
        if base_slice is not None:
            base_slice = self._de_list(base_slice)
            base_slice = self._unpad_slice(base_slice)
            slice_stats['RSS'] = self._calc_rss(my_slice, base_slice)
        else:
            slice_stats['RSS'] = None
        return slice_stats
83
84
    def _calc_rss(self, array1, array2): # residual sum of squares
85
        if array1.shape == array2.shape:
86
            residuals = np.subtract(array1, array2)
87
            rss = sum(value**2 for value in np.nditer(residuals))
88
        else:
89
            print("Warning: cannot calculate RSS, arrays different sizes.")  # need to make this an actual warning
90
            rss = None
91
        return rss
92
93
    def _rmsd_from_rss(self, rss, n):
94
        return np.sqrt(rss/n)
95
96
    def calc_rmsd(self, array1, array2):
97
        if array1.shape == array2.shape:
98
            rss = self._calc_rss(array1, array2)
99
            rmsd = self._rmsd_from_rss(rss, array1.size)
100
        else:
101
            print("Warning: cannot calculate RMSD, arrays different sizes.")  # need to make this an actual warning
102
            rmsd = None
103
        return rmsd
104
105
    def calc_stats_residuals(self, stats_before, stats_after):
106
        residuals = {'max': None, 'min': None, 'mean': None, 'std_dev': None}
107
        for key in list(residuals.keys()):
108
            residuals[key] = stats_after[key] - stats_before[key]
109
        return residuals
110
111
    def set_stats_residuals(self, residuals):
112
        self.residuals['max'].append(residuals['max'])
113
        self.residuals['min'].append(residuals['min'])
114
        self.residuals['mean'].append(residuals['mean'])
115
        self.residuals['std_dev'].append(residuals['std_dev'])
116
117
    def calc_volume_stats(self, slice_stats):
118
        volume_stats = {'max': max(slice_stats['max']), 'min': min(slice_stats['min']),
119
                        'mean': np.mean(slice_stats['mean']), 'mean_std_dev': np.mean(slice_stats['std_dev']),
120
                        'median_std_dev': np.median(slice_stats['std_dev'])}
121
        return volume_stats
122
123
    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        p_num = Statistics.count
        name = self.plugin_name
        i = 2
        # Distinguish repeated instances of the same plugin: "Name", "Name2", ...
        while name in list(Statistics.plugin_numbers.keys()):
            name = self.plugin_name + str(i)
            i += 1
        Statistics.global_stats[p_num] = {}
        if len(self.stats['max']) != 0:
            Statistics.global_stats[p_num] = self.calc_volume_stats(self.stats)
            Statistics.global_residuals[p_num] = {}
            before_processing = self.calc_volume_stats(self.stats_before_processing)
            for key in list(before_processing.keys()):
                Statistics.global_residuals[p_num][key] = Statistics.global_stats[p_num][key] - before_processing[key]
            # RMSD is only meaningful if every slice produced an RSS value.
            if None not in self.stats['RSS']:
                total_rss = sum(self.stats['RSS'])
                n = sum(self.stats['data_points'])
                Statistics.global_stats[p_num]['RMSD'] = self._rmsd_from_rss(total_rss, n)
            else:
                Statistics.global_stats[p_num]['RMSD'] = None

            # NOTE(review): plugin_numbers is only updated when slice stats
            # exist, yet the unique name is computed unconditionally above —
            # confirm whether the empty-stats case should also register.
            Statistics.plugin_numbers[name] = p_num
            self._link_stats_to_datasets(Statistics.global_stats[Statistics.plugin_numbers[name]])

        # Per-slice stats are written out even when no volume stats were made.
        slice_stats_array = np.array([self.stats['max'], self.stats['min'], self.stats['mean'], self.stats['std_dev']])
        self._write_stats_to_file(slice_stats_array, p_num)
        self._already_called = True
153
154
    def get_stats(self, plugin_name, n=None, stat=None):
155
        """Returns stats associated with a certain plugin.
156
157
        :param plugin_name: name of the plugin whose associated stats are being fetched.
158
        :param n: In a case where there are multiple instances of **plugin_name** in the process list,
159
            specify the nth instance. Not specifying will select the first (or only) instance.
160
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
161
            If left blank will return the whole dictionary of stats:
162
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': }
163
        """
164
        name = plugin_name
165
        if n is not None and n not in (0, 1):
166
            name = name + str(n)
167
        p_num = Statistics.plugin_numbers[name]
168
        return self.get_stats_from_num(p_num, stat)
169
170
    def get_stats_from_num(self, p_num, stat=None):
171
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).
172
173
        :param p_num: Plugin  number of the plugin whose associated stats are being fetched.
174
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
175
            E.g current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
176
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
177
            If left blank will return the whole dictionary of stats:
178
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': }
179
        """
180
        if p_num <= 0:
181
            p_num = Statistics.count + p_num
182
        if stat is not None:
183
            return Statistics.global_stats[p_num][stat]
184
        else:
185
            return Statistics.global_stats[p_num]
186
187
    def get_stats_from_dataset(self, dataset, stat=None, set_num=None):
188
        """Returns stats associated with a dataset.
189
190
        :param dataset: The dataset whose associated stats are being fetched.
191
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
192
            If left blank will return the whole dictionary of stats:
193
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': }
194
        :param set_num: In the (rare) case that there are multiple sets of stats associated with the dataset,
195
            specify which set to return.
196
197
        """
198
        key = "stats"
199
        stats = {}
200
        if set_num is not None and set_num not in (0, 1):
201
            key = key + str(set_num)
202
        stats = dataset.meta_data.get(key)
203
        if stat is not None:
204
            return stats[stat]
205
        else:
206
            return stats
207
208
    def get_data_stats(self):
        """Return the class-wide dictionary of data stats."""
        return Statistics.data_stats
210
211
    def get_volume_stats(self):
        """Return the class-wide dictionary of volume stats."""
        return Statistics.volume_stats
213
214
    def get_global_stats(self):
        """Return the class-wide dictionary of global stats, keyed by plugin number."""
        return Statistics.global_stats
216
217
    def _set_pattern_info(self):
        """Gathers information about the pattern of the data in the current plugin.

        Sets ``self.pattern`` from the plugin parameters when available,
        otherwise from the first output dataset's patterns; also decides
        whether stats should be calculated at all for this plugin.
        """
        in_datasets, out_datasets = self.plugin.get_datasets()
        try:
            self.pattern = self.plugin.parameters['pattern']
            # A parameter explicitly set to None is treated the same as a
            # missing one: fall through to the dataset-derived pattern.
            if self.pattern is None:
                raise KeyError
        except KeyError:
            if not out_datasets:
                self.pattern = None
            else:
                patterns = out_datasets[0].get_data_patterns()
                # NOTE(review): if no pattern has 1 in its slice_dims,
                # self.pattern is left unset here — confirm intended.
                for pattern in patterns:
                    if 1 in patterns.get(pattern)["slice_dims"]:
                        self.pattern = pattern
                        break
        # Stats are only calculated when at least one output dataset uses a
        # recognised pattern.
        self.calc_stats = False
        for dataset in out_datasets:
            if bool(set(Statistics.pattern_list) & set(dataset.data_info.get("data_patterns"))):
                self.calc_stats = True
237
238
    def _link_stats_to_datasets(self, stats):
        """Links the volume wide statistics to the output dataset(s)"""
        out_dataset = self.plugin.get_out_datasets()[0]
        if self.plugin.nOutput_datasets() != 1:
            return
        # Find a free metadata group name: "stats", "stats2", "stats3", ...
        existing = out_dataset.meta_data.get_dictionary()
        group_name = "stats"
        suffix = 2
        while group_name in list(existing.keys()):
            group_name = f"stats{suffix}"
            suffix += 1
        for key, value in stats.items():
            out_dataset.meta_data.set([group_name, key], value)
250
251
    def _write_stats_to_file(self, slice_stats_array, p_num):
        """Writes slice statistics to a h5 file"""
        # One file per plugin: <out_path>/stats/stats_p<num>_<name>.h5
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{self.plugin_name}.h5"
        slice_stats_dim = (slice_stats_array.shape[1],)
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        with h5.File(filename, "a") as h5file:
            i = 2
            group_name = "/stats"
            # Pick the first free group name: /stats, /stats2, /stats3, ...
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            max_ds = self.hdf5.create_dataset_nofill(group, "max", slice_stats_dim, slice_stats_array.dtype)
            min_ds = self.hdf5.create_dataset_nofill(group, "min", slice_stats_dim, slice_stats_array.dtype)
            mean_ds = self.hdf5.create_dataset_nofill(group, "mean", slice_stats_dim, slice_stats_array.dtype)
            std_dev_ds = self.hdf5.create_dataset_nofill(group, "standard_deviation",
                                                         slice_stats_dim, slice_stats_array.dtype)
            # A fifth row, when present, holds per-slice RMSD values.
            if slice_stats_array.shape[0] == 5:
                rmsd_ds = self.hdf5.create_dataset_nofill(group, "RMSD", slice_stats_dim, slice_stats_array.dtype)
                rmsd_ds[::] = slice_stats_array[4]
            max_ds[::] = slice_stats_array[0]
            min_ds[::] = slice_stats_array[1]
            mean_ds[::] = slice_stats_array[2]
            std_dev_ds[::] = slice_stats_array[3]
276
277
    def _unpad_slice(self, slice1):
        """If data is padded in the slice dimension, removes this pad.

        :param slice1: the (possibly padded) data slice.
        :returns: the slice with any slice-dimension padding removed.
        """
        out_datasets = self.plugin.get_out_datasets()
        if len(out_datasets) == 1:
            out_dataset = out_datasets[0]
        else:
            # Fall back to the first dataset so out_dataset is always bound,
            # even when no dataset matches the current pattern (previously
            # this raised UnboundLocalError in that case).
            out_dataset = out_datasets[0]
            for dataset in out_datasets:
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
                    out_dataset = dataset
                    break
        slice_dims = out_dataset.get_slice_dimensions()
        if self.plugin.pcount == 0:
            # Determine once, on the first slice, how to un-pad all slices.
            self.slice_list, self.pad = self._get_unpadded_slice_list(slice1, slice_dims)
        if self.pad:
            for slice_dim in slice_dims:
                # Swap the padded dimension to the front, slice off the pad,
                # then swap back.
                temp_slice = np.swapaxes(slice1, 0, slice_dim)
                temp_slice = temp_slice[self.slice_list[slice_dim]]
                slice1 = np.swapaxes(temp_slice, 0, slice_dim)
        return slice1
296
297
    def _get_unpadded_slice_list(self, slice1, slice_dims):
298
        """Creates slice object(s) to un-pad slices in the slice dimension(s)."""
299
        slice_list = list(self.plugin.slice_list[0])
300
        pad = False
301
        if len(slice_list) == len(slice1.shape):
302
            for i in slice_dims:
303
                slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
304
                if slice_width != slice1.shape[i]:
305
                    pad = True
306
                    pad_width = (slice1.shape[i] - slice_width) // 2  # Assuming symmetrical padding
307
                    slice_list[i] = slice(pad_width, pad_width + 1, 1)
308
            return tuple(slice_list), pad
309
        else:
310
            return self.plugin.slice_list[0], pad
311
312
    def _de_list(self, slice1):
313
        """If the slice is in a list, remove it from that list."""
314
        if type(slice1) == list:
315
            if len(slice1) != 0:
316
                slice1 = slice1[0]
317
                slice1 = self._de_list(slice1)
318
        return slice1
319
320
    @classmethod
    def _count(cls):
        # Advance the class-wide plugin counter (one step per plugin run).
        cls.count += 1
323
324
    @classmethod
    def _post_chain(cls):
        # Debug dump of all accumulated class-wide stats once the plugin
        # chain has finished.
        print(cls.data_stats)
        print(cls.volume_stats)
        print(cls.global_stats)
        print(cls.global_residuals)