Test Failed
Pull Request — master (#878)
created 03:48 by Daniil

savu.plugins.stats.statistics   C

Complexity

Total Complexity 57

Size/Duplication

Total Lines 284
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 203
dl 0
loc 284
rs 5.04
c 0
b 0
f 0
wmc 57

19 Methods

Rating   Name   Duplication   Size   Complexity  
A Statistics._write_stats_to_file() 0 22 3
A Statistics.get_data_stats() 0 2 1
A Statistics.get_global_stats() 0 2 1
A Statistics.get_stats_from_dataset() 0 20 4
B Statistics._set_pattern_info() 0 19 8
A Statistics._get_unpadded_slice_list() 0 14 4
A Statistics._setup() 0 17 3
A Statistics.get_stats_from_num() 0 18 3
A Statistics.get_slice_stats() 0 3 1
A Statistics.get_stats() 0 19 4
A Statistics._count() 0 3 1
A Statistics._post_chain() 0 5 1
B Statistics._unpad_slice() 0 19 7
A Statistics._link_stats_to_datasets() 0 15 3
A Statistics._de_list() 0 7 3
A Statistics.__init__() 0 9 2
A Statistics.set_slice_stats() 0 13 2
B Statistics.set_volume_stats() 0 34 5
A Statistics.get_volume_stats() 0 2 1

How to fix: Complexity

Complex classes like savu.plugins.stats.statistics often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
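
As a rough illustration of what an Extract Class step could look like for this module, the per-slice accumulation shared by set_slice_stats() and set_volume_stats() could be moved into a small collaborator. The SliceStatsAccumulator below is a hypothetical name and API, not part of the Savu codebase; it is a minimal sketch only.

import numpy as np


class SliceStatsAccumulator(object):
    """Hypothetical extracted component: gathers per-slice stats and reduces
    them to the five volume-wide values that Statistics stores per plugin."""

    def __init__(self):
        self.stats = {'max': [], 'min': [], 'mean': [], 'standard_deviation': []}

    def add(self, slice1):
        # Mirrors the body of Statistics.set_slice_stats()
        self.stats['max'].append(slice1.max())
        self.stats['min'].append(slice1.min())
        self.stats['mean'].append(np.mean(slice1))
        self.stats['standard_deviation'].append(np.std(slice1))

    def reduce(self):
        # Mirrors the reduction block that set_volume_stats() repeats for the
        # projection-space and volume-space branches
        if not self.stats['max']:
            return [None, None, None, None, None]
        return [max(self.stats['max']),
                min(self.stats['min']),
                np.mean(self.stats['mean']),
                np.mean(self.stats['standard_deviation']),
                np.median(self.stats['standard_deviation'])]

With such a component, Statistics would hold one accumulator per plugin run, and set_volume_stats() would shrink to a single reduce() call plus the bookkeeping that decides whether the result belongs in data_stats or volume_stats.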

"""
.. module:: statistics
   :platform: Unix
   :synopsis: Contains and processes statistics information for each plugin.

.. moduleauthor:: Jacob Williamson <[email protected]>

"""

from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils

import h5py as h5
import numpy as np
import os


class Statistics(object):
    index_dict = {"max": 0, "min": 1, "mean": 2, "mean_std_dev": 3, "median_std_dev": 4}
    key_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev"]
    pattern_list = ["SINOGRAM", "PROJECTION", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    no_stats_plugins = ["BasicOperations", "Mipmap"]

    def __init__(self, plugin_self):
        self.plugin = plugin_self
        self.plugin_name = plugin_self.name
        self.pad_dims = []
        self.stats = {'max': [], 'min': [], 'mean': [], 'standard_deviation': []}
        self.calc_stats = False
        self._set_pattern_info()
        if self.plugin_name in Statistics.no_stats_plugins:
            self.calc_stats = False

    @classmethod
    def _setup(cls, exp):
        """Sets up the statistics class for the whole experiment (only called once)."""
        cls.count = 2
        cls.data_stats = {}
        cls.volume_stats = {}
        cls.global_stats = {}
        n_plugins = len(exp.meta_data.plugin_list.plugin_list)
        # for n in range(n_plugins):
        #    cls.data_stats[n + 1] = [None, None, None, None, None]
        #    cls.volume_stats[n + 1] = [None, None, None, None, None]
        cls.path = exp.meta_data['out_path']
        if cls.path[-1] == '/':
            cls.path = cls.path[0:-1]
        cls.path = f"{cls.path}/stats"
        if not os.path.exists(cls.path):
            os.mkdir(cls.path)

    def set_slice_stats(self, slice1):
        """Appends the slice stats arrays with the stats parameters of the current slice.

        :param slice1: The slice whose stats are being calculated.
        """
        if slice1 is not None:
            slice_num = self.plugin.pcount
            slice1 = self._de_list(slice1)
            slice1 = self._unpad_slice(slice1)
            self.stats['max'].append(slice1.max())
            self.stats['min'].append(slice1.min())
            self.stats['mean'].append(np.mean(slice1))
            self.stats['standard_deviation'].append(np.std(slice1))

    def get_slice_stats(self, stat, slice_num):
        """Returns the array of stats associated with the processed slices of the current plugin."""
        return self.stats[stat][slice_num]

    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        p_num = Statistics.count
        name = self.plugin_name
        i = 2
        while name in list(Statistics.global_stats.keys()):
            name = self.plugin_name + str(i)
            i += 1
        Statistics.data_stats[p_num] = [None, None, None, None, None]
        Statistics.volume_stats[p_num] = [None, None, None, None, None]
        if len(self.stats['max']) != 0:
            if self.pattern in ['PROJECTION', 'SINOGRAM', 'TANGENTOGRAM', 'SINOMOVIE', '4D_SCAN']:
                Statistics.data_stats[p_num][0] = max(self.stats['max'])
                Statistics.data_stats[p_num][1] = min(self.stats['min'])
                Statistics.data_stats[p_num][2] = np.mean(self.stats['mean'])
                Statistics.data_stats[p_num][3] = np.mean(self.stats['standard_deviation'])
                Statistics.data_stats[p_num][4] = np.median(self.stats['standard_deviation'])
                Statistics.global_stats[p_num] = Statistics.data_stats[p_num]
                Statistics.global_stats[name] = Statistics.global_stats[p_num]
                self._link_stats_to_datasets(Statistics.global_stats[name])
            elif self.pattern in ['VOLUME_XZ', 'VOLUME_XY', 'VOLUME_YZ', 'VOLUME_3D']:
                Statistics.volume_stats[p_num][0] = max(self.stats['max'])
                Statistics.volume_stats[p_num][1] = min(self.stats['min'])
                Statistics.volume_stats[p_num][2] = np.mean(self.stats['mean'])
                Statistics.volume_stats[p_num][3] = np.mean(self.stats['standard_deviation'])
                Statistics.volume_stats[p_num][4] = np.median(self.stats['standard_deviation'])
                Statistics.global_stats[p_num] = Statistics.volume_stats[p_num]
                Statistics.global_stats[name] = Statistics.global_stats[p_num]
                self._link_stats_to_datasets(Statistics.global_stats[name])
        slice_stats = np.array([self.stats['max'], self.stats['min'], self.stats['mean'],
                                self.stats['standard_deviation']])
        self._write_stats_to_file(slice_stats, p_num)

    def get_stats(self, plugin_name, n=None, stat=None):
        """Returns the stats associated with a certain plugin.

        :param plugin_name: Name of the plugin whose associated stats are being fetched.
        :param n: In a case where there are multiple instances of <plugin_name> in the process list,
            specify the nth instance. Not specifying will select the first (or only) instance.
        :param stat: Specify the stat parameter you want to fetch, i.e. 'max', 'mean', 'median_std_dev'.
            If left blank, will return the whole dictionary of stats:
            {'max': ,'min': ,'mean': ,'mean_std_dev': ,'median_std_dev': }
        """
        name = plugin_name
        if n is not None and n not in (0, 1):
            name = name + str(n)
        if stat is not None:
            i = Statistics.index_dict[stat]
            return Statistics.global_stats[name][i]
        else:
            stats = dict(zip(Statistics.key_list, Statistics.global_stats[name]))
            return stats

    def get_stats_from_num(self, p_num, stat=None):
        """Returns the stats associated with a certain plugin, given the plugin number (its place in the process list).

        :param p_num: Plugin number of the plugin whose associated stats are being fetched.
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
            E.g. current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
        :param stat: Specify the stat parameter you want to fetch, i.e. 'max', 'mean', 'median_std_dev'.
            If left blank, will return the whole dictionary of stats:
            {'max': ,'min': ,'mean': ,'mean_std_dev': ,'median_std_dev': }
        """
        if p_num <= 0:
            p_num = Statistics.count + p_num
        if stat is not None:
            i = Statistics.index_dict[stat]
            return Statistics.global_stats[p_num][i]
        else:
            stats = dict(zip(Statistics.key_list, Statistics.global_stats[p_num]))
            return stats

    def get_stats_from_dataset(self, dataset, stat=None, set_num=None):
        """Returns the stats associated with a dataset.

        :param dataset: The dataset whose associated stats are being fetched.
        :param stat: Specify the stat parameter you want to fetch, i.e. 'max', 'mean', 'median_std_dev'.
            If left blank, will return the whole dictionary of stats:
            {'max': ,'min': ,'mean': ,'mean_std_dev': ,'median_std_dev': }
        :param set_num: In the (rare) case that there are multiple sets of stats associated with the dataset,
            specify which set to return.
        """
        key = "stats"
        stats = {}
        if set_num is not None:
            key = key + str(set_num)
        if key in list(dataset.meta_data.dict.keys()):
            stats = dataset.meta_data.get(key)
        if stat is not None:
            return stats[stat]
        else:
            return stats

    def get_data_stats(self):
        return Statistics.data_stats

    def get_volume_stats(self):
        return Statistics.volume_stats

    def get_global_stats(self):
        return Statistics.global_stats

    def _set_pattern_info(self):
        """Gathers information about the pattern of the data in the current plugin."""
        in_datasets, out_datasets = self.plugin.get_datasets()
        try:
            self.pattern = self.plugin.parameters['pattern']
            if self.pattern == None:
                raise KeyError
        except KeyError:
            if not out_datasets:
                self.pattern = None
            else:
                patterns = out_datasets[0].get_data_patterns()
                for pattern in patterns:
                    if 1 in patterns.get(pattern)["slice_dims"]:
                        self.pattern = pattern
                        break
        for dataset in out_datasets:
            if bool(set(Statistics.pattern_list) & set(dataset.data_info.get("data_patterns"))):
                self.calc_stats = True

    def _link_stats_to_datasets(self, stats):
        """Links the volume-wide statistics to the output dataset(s)."""
        out_datasets = self.plugin.get_out_datasets()
        n_datasets = self.plugin.nOutput_datasets()
        i = 1
        group_name = "stats"
        if n_datasets == 1:
            while group_name in list(out_datasets[0].meta_data.get_dictionary().keys()):
                group_name = f"stats{i}"
                i += 1
            out_datasets[0].data_info.set([group_name, "max"], stats[0])
            out_datasets[0].data_info.set([group_name, "min"], stats[1])
            out_datasets[0].data_info.set([group_name, "mean"], stats[2])
            out_datasets[0].data_info.set([group_name, "mean_std_dev"], stats[3])
            out_datasets[0].data_info.set([group_name, "median_std_dev"], stats[4])

    def _write_stats_to_file(self, slice_stats, p_num):
        """Writes the slice statistics to an h5 file."""
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{self.plugin_name}.h5"
        slice_stats_dim = (slice_stats.shape[1],)
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        with h5.File(filename, "a") as h5file:
            i = 1
            group_name = "/stats"
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            max_ds = self.hdf5.create_dataset_nofill(group, "max", slice_stats_dim, slice_stats.dtype)
            min_ds = self.hdf5.create_dataset_nofill(group, "min", slice_stats_dim, slice_stats.dtype)
            mean_ds = self.hdf5.create_dataset_nofill(group, "mean", slice_stats_dim, slice_stats.dtype)
            standard_deviation_ds = self.hdf5.create_dataset_nofill(group, "standard_deviation",
                                                                    slice_stats_dim, slice_stats.dtype)
            max_ds[::] = slice_stats[0]
            min_ds[::] = slice_stats[1]
            mean_ds[::] = slice_stats[2]
            standard_deviation_ds[::] = slice_stats[3]

    def _unpad_slice(self, slice1):
        """If data is padded in the slice dimension, removes this pad."""
        out_datasets = self.plugin.get_out_datasets()
        if len(out_datasets) == 1:
            out_dataset = out_datasets[0]
        else:
            for dataset in out_datasets:
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
                    out_dataset = dataset
                    break
        # Note: out_dataset is not defined for all execution paths (it is never
        # assigned if none of the output datasets matches self.pattern).
        slice_dims = out_dataset.get_slice_dimensions()
        if self.plugin.pcount == 0:
            self.slice_list, self.pad = self._get_unpadded_slice_list(slice1, slice_dims)
        if self.pad:
            for slice_dim in slice_dims:
                temp_slice = np.swapaxes(slice1, 0, slice_dim)
                temp_slice = temp_slice[self.slice_list[slice_dim]]
                slice1 = np.swapaxes(temp_slice, 0, slice_dim)
        return slice1

    def _get_unpadded_slice_list(self, slice1, slice_dims):
        """Creates slice object(s) to un-pad slices in the slice dimension(s)."""
        slice_list = list(self.plugin.slice_list[0])
        pad = False
        if len(slice_list) == len(slice1.shape):
            for i in slice_dims:
                slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
                if slice_width != slice1.shape[i]:
                    pad = True
                    pad_width = (slice1.shape[i] - slice_width) // 2  # Assuming symmetrical padding
                    slice_list[i] = slice(pad_width, pad_width + 1, 1)
            return tuple(slice_list), pad
        else:
            return self.plugin.slice_list[0], pad

    def _de_list(self, slice1):
        """If the slice is in a list, remove it from that list."""
        if type(slice1) == list:
            if len(slice1) != 0:
                slice1 = slice1[0]
                slice1 = self._de_list(slice1)
        return slice1

    @classmethod
    def _count(cls):
        cls.count += 1

    @classmethod
    def _post_chain(cls):
        print(cls.data_stats)
        print(cls.volume_stats)
        print(cls.global_stats)
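
For orientation, here is a minimal sketch of how a plugin further down the process list might query these statistics, assuming it holds a reference to its Statistics instance (called stats_obj below; the attribute and plugin names are illustrative, and the calls follow the docstrings of get_stats() and get_stats_from_num() above).

# Whole stats dictionary for the first (or only) instance of a plugin:
median_stats = stats_obj.get_stats("MedianFilter")
# -> {'max': ..., 'min': ..., 'mean': ..., 'mean_std_dev': ..., 'median_std_dev': ...}

# A single value from the second instance of the same plugin in the process list:
second_max = stats_obj.get_stats("MedianFilter", n=2, stat="max")

# Stats addressed relative to the currently running plugin:
# p_num = -1 refers to the plugin immediately before the current one.
previous_mean = stats_obj.get_stats_from_num(-1, stat="mean")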