Test Failed
Pull Request — master (#878)
by
unknown
03:29
created

Statistics._write_stats_to_file3()   C

Complexity

Conditions 9

Size

Total Lines 28
Code Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 28
nop 2
dl 0
loc 28
rs 6.6666
c 0
b 0
f 0
1
"""
2
.. module:: statistics
3
   :platform: Unix
4
   :synopsis: Contains and processes statistics information for each plugin.
5
6
.. moduleauthor:: Jacob Williamson <[email protected]>
7
8
"""
9
10
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
11
from savu.plugins.stats.stats_utils import StatsUtils
12
from savu.core.iterate_plugin_group_utils import check_if_in_iterative_loop
13
14
import h5py as h5
15
import numpy as np
16
import os
17
from mpi4py import MPI
18
19
20
class Statistics(object):
    """Contains and processes statistics information for each plugin."""

    # Data patterns for which statistics can be calculated.
    _pattern_list = ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    # Plugins that never have statistics calculated for them.
    no_stats_plugins = ["BasicOperations", "Mipmap"]
    # Positional order of values in each volume-stats array/row.
    _key_list = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD"]
    def __init__(self):
27
        self.calc_stats = True
28
        self.stats = {'max': [], 'min': [], 'mean': [], 'std_dev': [], 'RSS': [], 'data_points': []}
29
        self.stats_before_processing = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
30
        self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
31
        self._repeat_count = 0
32
        self.p_num = None
33
34
    def setup(self, plugin_self):
        """Hook the statistics object up to the plugin about to run.

        Stats are skipped entirely for plugins in ``no_stats_plugins``;
        otherwise the plugin is recorded and its data pattern inspected
        (which can itself disable stats for unsuitable datasets).

        :param plugin_self: The plugin instance statistics are gathered for.
        """
        if plugin_self.name in Statistics.no_stats_plugins:
            self.calc_stats = False
        elif self.calc_stats:
            self.plugin = plugin_self
            self.plugin_name = plugin_self.name
            self._pad_dims = []
            self._already_called = False
            self._set_pattern_info()  # may switch calc_stats off again
        if self.calc_stats:
            # Record that at least one plugin in the chain produces stats.
            Statistics._any_stats = True
        self._setup_iterative()
46
47
    def _setup_iterative(self):
        """Register this plugin with the current iterative loop, if any.

        ``check_if_in_iterative_loop`` returns the iterative plugin group the
        experiment is currently inside (falsy when not in a loop).
        """
        self._iterative_group = check_if_in_iterative_loop(Statistics.exp)
        if self._iterative_group:
            if self._iterative_group.start_index == Statistics.count:
                # First plugin of a new loop: open a fresh RMSD record.
                Statistics._loop_counter += 1
                Statistics.loop_stats.append({"RMSD": np.array([])})
            self.l_num = Statistics._loop_counter - 1  # index into loop_stats
54
55
    @classmethod
    def _setup_class(cls, exp):
        """Sets up the statistics class for the whole plugin chain (only called once).

        :param exp: The experiment object for the current run.
        """
        cls.stats_flag = True
        cls._any_stats = False
        cls.count = 2
        cls.loop_stats = []
        cls.exp = exp
        cls.n_plugins = len(exp.meta_data.plugin_list.plugin_list)
        # One (initially empty) stats array per plugin in the process list.
        cls.global_stats = {n: np.array([]) for n in range(1, cls.n_plugins + 1)}
        cls.global_residuals = {}
        cls.plugin_numbers = {}
        cls.plugin_names = {}
        cls._loop_counter = 0
        out_path = exp.meta_data['out_path']
        if out_path.endswith('/'):
            out_path = out_path[:-1]
        cls.path = out_path + "/stats"
        # Only the rank-0 process creates the stats directory.
        if MPI.COMM_WORLD.rank == 0 and not os.path.exists(cls.path):
            os.mkdir(cls.path)
78
79
    def set_slice_stats(self, slice, base_slice):
        """Append stats for one slice, both before and after processing.

        :param slice: The processed slice.
        :param base_slice: The same slice before processing.
        """
        before = self.calc_slice_stats(base_slice)
        after = self.calc_slice_stats(slice, base_slice)
        for key, values in self.stats_before_processing.items():
            values.append(before[key])
        for key, values in self.stats.items():
            values.append(after[key])
86
87
    def calc_slice_stats(self, my_slice, base_slice=None):
        """Calculates and returns slice stats for the current slice.

        :param my_slice: The slice whose stats are being calculated.
        :param base_slice: The same slice before processing; when given, the
            residual sum of squares (RSS) against it is included.
        :returns: Dict of slice stats, or None when my_slice is None.
        """
        if my_slice is not None:
            slice_num = self.plugin.pcount  # NOTE(review): assigned but unused here
            my_slice = self._de_list(my_slice)
            my_slice = self._unpad_slice(my_slice)
            slice_stats = {'max': np.amax(my_slice).astype('float64'), 'min': np.amin(my_slice).astype('float64'),
                           'mean': np.mean(my_slice), 'std_dev': np.std(my_slice), 'data_points': my_slice.size}
            if base_slice is not None:
                # Compare against the unprocessed slice for RSS.
                base_slice = self._de_list(base_slice)
                base_slice = self._unpad_slice(base_slice)
                rss = self._calc_rss(my_slice, base_slice)
            else:
                rss = None
            slice_stats['RSS'] = rss
            return slice_stats
        return None
107
108
    def _calc_rss(self, array1, array2):  # residual sum of squares
109
        if array1.shape == array2.shape:
110
            residuals = np.subtract(array1, array2)
111
            rss = sum(value**2 for value in np.nditer(residuals))
112
        else:
113
            #print("Warning: cannot calculate RSS, arrays different sizes.")  # need to make this an actual warning
114
            rss = None
115
        return rss
116
117
    def _rmsd_from_rss(self, rss, n):
118
        return np.sqrt(rss/n)
119
120
    def calc_rmsd(self, array1, array2):
121
        if array1.shape == array2.shape:
122
            rss = self._calc_rss(array1, array2)
123
            rmsd = self._rmsd_from_rss(rss, array1.size)
124
        else:
125
            print("Warning: cannot calculate RMSD, arrays different sizes.")  # need to make this an actual warning
126
            rmsd = None
127
        return rmsd
128
129
    def calc_stats_residuals(self, stats_before, stats_after):
        """Return the difference (after minus before) for each headline stat.

        :param stats_before: Stats dict for the unprocessed data.
        :param stats_after: Stats dict for the processed data.
        """
        return {key: stats_after[key] - stats_before[key]
                for key in ('max', 'min', 'mean', 'std_dev')}
134
135
    def set_stats_residuals(self, residuals):
136
        self.residuals['max'].append(residuals['max'])
137
        self.residuals['min'].append(residuals['min'])
138
        self.residuals['mean'].append(residuals['mean'])
139
        self.residuals['std_dev'].append(residuals['std_dev'])
140
141
    def calc_volume_stats(self, slice_stats):
142
        volume_stats = np.array([max(slice_stats['max']), min(slice_stats['min']), np.mean(slice_stats['mean']),
143
                                np.mean(slice_stats['std_dev']), np.median(slice_stats['std_dev'])])
144
        if None not in slice_stats['RSS']:
145
            total_rss = sum(slice_stats['RSS'])
146
            n = sum(slice_stats['data_points'])
147
            RMSD = self._rmsd_from_rss(total_rss, n)
148
            NRMSD = RMSD / abs(volume_stats[2])  # normalised RMSD (dividing by mean)
149
            volume_stats = np.append(volume_stats, NRMSD)
150
        else:
151
            #volume_stats = np.append(volume_stats, None)
152
            pass
153
        return volume_stats
154
155
    def _set_loop_stats(self):
        """Record the RMSD between the two datasets being iterated over.

        The iterative group's data dict maps the 'before' dataset to the
        'after' dataset for the current loop iteration; their RMSD is
        appended to this loop's running record.
        """
        data_obj1 = list(self._iterative_group._ip_data_dict["iterating"].keys())[0]
        data_obj2 = self._iterative_group._ip_data_dict["iterating"][data_obj1]
        RMSD = self.calc_rmsd(data_obj1.data, data_obj2.data)
        Statistics.loop_stats[self.l_num]["RMSD"] = np.append(Statistics.loop_stats[self.l_num]["RMSD"], RMSD)
160
161
    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        stats = self.stats
        combined_stats = self._combine_mpi_stats(stats)
        if not self.p_num:
            self.p_num = Statistics.count
        p_num = self.p_num
        name = self.plugin_name
        i = 2
        # Find a unique key for this plugin instance (Name, Name2, Name3, ...)
        # unless we are on a later iteration of an iterative loop.
        if not self._iterative_group:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1
        elif self._iterative_group._ip_iteration == 0:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1
        elif self._iterative_group.end_index == p_num:
            # Later loop iterations: record loop-level RMSD instead.
            self._set_loop_stats()
        if p_num not in list(Statistics.plugin_names.keys()):
            Statistics.plugin_names[p_num] = name
        if len(self.stats['max']) != 0:
            stats_array = self.calc_volume_stats(combined_stats)
            Statistics.global_residuals[p_num] = {}
            #before_processing = self.calc_volume_stats(self.stats_before_processing)
            #for key in list(before_processing.keys()):
            #    Statistics.global_residuals[p_num][key] = Statistics.global_stats[p_num][key] - before_processing[key]

            # First set of stats is stored as a 1D row; later sets are stacked
            # into a 2D array (one row per plugin run).
            if len(Statistics.global_stats[p_num]) == 0:
                Statistics.global_stats[p_num] = stats_array
            else:
                Statistics.global_stats[p_num] = np.vstack([Statistics.global_stats[p_num], stats_array])
            Statistics.plugin_numbers[name] = p_num
            self._link_stats_to_datasets(Statistics.global_stats[Statistics.plugin_numbers[name]])

        self._write_stats_to_file3(p_num)
        self._already_called = True
        self._repeat_count += 1
201
202
    def get_stats(self, p_num, stat=None, instance=1):
203
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).
204
205
        :param p_num: Plugin  number of the plugin whose associated stats are being fetched.
206
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
207
            E.g current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
208
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
209
            If left blank will return the whole dictionary of stats:
210
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD' }
211
        :param instance: In cases where there are multiple set of stats associated with a plugin
212
            due to multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
213
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
214
        """
215
        if p_num <= 0:
216
            try:
217
                p_num = self.p_num + p_num
218
            except TypeError:
219
                p_num = Statistics.count + p_num
220
        if Statistics.global_stats[p_num].ndim == 1 and instance in (None, 0, 1, "all"):
221
            stats_array = Statistics.global_stats[p_num]
222
        else:
223
            if instance == "all":
224
                stats_list = [self.get_stats(p_num, stat=stat, instance=1)]
225
                n = 2
226
                while n <= Statistics.global_stats[p_num].ndim:
227
                    stats_list.append(self.get_stats(p_num, stat=stat, instance=n))
228
                    n += 1
229
                return stats_list
230
            if instance > 0:
231
                instance -= 1
232
            stats_array = Statistics.global_stats[p_num][instance]
233
        stats_dict = self._array_to_dict(stats_array)
234
        if stat is not None:
235
            return stats_dict[stat]
236
        else:
237
            return stats_dict
238
239
    def get_stats_from_name(self, plugin_name, n=None, stat=None, instance=1):
240
        """Returns stats associated with a certain plugin.
241
242
        :param plugin_name: name of the plugin whose associated stats are being fetched.
243
        :param n: In a case where there are multiple instances of **plugin_name** in the process list,
244
            specify the nth instance. Not specifying will select the first (or only) instance.
245
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
246
            If left blank will return the whole dictionary of stats:
247
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD' }
248
        :param instance: In cases where there are multiple set of stats associated with a plugin
249
            due to multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
250
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
251
        """
252
        name = plugin_name
253
        if n in (None, 0, 1):
254
            name = name + str(n)
255
        p_num = Statistics.plugin_numbers[name]
256
        return self.get_stats(p_num, stat, instance)
257
258
    def get_stats_from_dataset(self, dataset, stat=None, instance=None):
259
        """Returns stats associated with a dataset.
260
261
        :param dataset: The dataset whose associated stats are being fetched.
262
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
263
            If left blank will return the whole dictionary of stats:
264
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD'}
265
        :param instance: In cases where there are multiple set of stats associated with a dataset
266
            due to multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
267
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
268
269
        """
270
        key = "stats"
271
        stats = {}
272
        if instance not in (None, 0, 1):
273
            if instance == "all":
274
                stats = [self.get_stats_from_dataset(dataset, stat=stat, instance=1)]
275
                n = 2
276
                while ("stats" + str(n)) in list(dataset.meta_data.get_dictionary().keys()):
277
                    stats.append(self.get_stats_from_dataset(dataset, stat=stat, instance=n))
278
                    n += 1
279
                return stats
280
            key = key + str(instance)
281
        stats = dataset.meta_data.get(key)
282
        if stat is not None:
283
            return stats[stat]
284
        else:
285
            return stats
286
287
    def _combine_mpi_stats(self, slice_stats):
        """Gather and merge per-slice stats from every MPI process.

        :param slice_stats: This process's dict of per-slice stat lists.
        :returns: A single dict with the lists from all processes concatenated.
        """
        comm = MPI.COMM_WORLD
        combined_stats_list = comm.allgather(slice_stats)
        combined_stats = {'max': [], 'min': [], 'mean': [], 'std_dev': [], 'RSS': [], 'data_points': []}
        for single_stats in combined_stats_list:
            for key in list(single_stats.keys()):
                combined_stats[key] += single_stats[key]
        return combined_stats
295
296
    def _array_to_dict(self, stats_array):
297
        stats_dict = {}
298
        for i, value in enumerate(stats_array):
299
            stats_dict[Statistics._key_list[i]] = value
300
        return stats_dict
301
302
    def _set_pattern_info(self):
        """Gathers information about the pattern of the data in the current plugin."""
        in_datasets, out_datasets = self.plugin.get_datasets()
        try:
            self.pattern = self.plugin.parameters['pattern']
            if self.pattern == None:
                raise KeyError  # treat an explicit None the same as a missing key
        except KeyError:
            if not out_datasets:
                self.pattern = None
            else:
                # Fall back to the first output pattern sliced over dimension 1.
                # NOTE(review): if no pattern matches, self.pattern is left
                # unset on this path - confirm that cannot happen in practice.
                patterns = out_datasets[0].get_data_patterns()
                for pattern in patterns:
                    if 1 in patterns.get(pattern)["slice_dims"]:
                        self.pattern = pattern
                        break
        # Stats are only calculated when at least one output dataset uses a
        # pattern from the supported _pattern_list.
        self.calc_stats = False
        for dataset in out_datasets:
            if bool(set(Statistics._pattern_list) & set(dataset.data_info.get("data_patterns"))):
                self.calc_stats = True
322
323
    def _link_stats_to_datasets(self, stats):
        """Links the volume wide statistics to the output dataset(s)

        :param stats: 1D stats array, or 2D stack of stats rows when the
            plugin has run more than once.
        """
        out_dataset = self.plugin.get_out_datasets()[0]
        n_datasets = self.plugin.nOutput_datasets()
        if self._repeat_count > 0:
            # Repeat run (multi-parameters): link only this run's row.
            stats_dict = self._array_to_dict(stats[self._repeat_count])
        else:
            stats_dict = self._array_to_dict(stats)
        i = 2
        group_name = "stats"
        #out_dataset.data_info.set([group_name], stats)
        if n_datasets == 1:
            # Find a free metadata group name: stats, stats2, stats3, ...
            while group_name in list(out_dataset.meta_data.get_dictionary().keys()):
                group_name = f"stats{i}"
                i += 1
            for key in list(stats_dict.keys()):
                out_dataset.meta_data.set([group_name, key], stats_dict[key])
340
341
    def _write_stats_to_file2(self, p_num):
        """Write this plugin's volume stats into its own group in stats.h5.

        Does nothing when the group already exists.

        :param p_num: Plugin number used to name the group.
        """
        filename = f"{Statistics.path}/stats.h5"
        stats = Statistics.global_stats[p_num]
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        group_name = f"{p_num}-{self.plugin_name}-stats"
        with h5.File(filename, "a") as h5file:
            if group_name not in h5file:
                group = h5file.create_group(group_name, track_order=None)
                dataset = self.hdf5.create_dataset_nofill(group, "stats", stats.shape, stats.dtype)
                dataset[::] = stats[::]
355
356
357
    @classmethod
    def _write_stats_to_file4(cls):
        """Write one "all-<stat>" group per stat key, containing that stat's
        value for every plugin that produced stats.
        """
        path = cls.path
        filename = f"{path}/stats.h5"
        stats = cls.global_stats
        cls.hdf5 = Hdf5Utils(cls.exp)
        # NOTE(review): range(5) covers only the first five keys, so NRMSD
        # (index 5 of _key_list) is skipped - confirm this is intentional
        # (NRMSD may be absent from some stats arrays).
        for i in range(5):
            array = np.array([])
            stat = cls._key_list[i]
            for key in list(stats.keys()):
                if len(stats[key]) != 0:
                    if stats[key].ndim == 1:
                        array = np.append(array, stats[key][i])
                    else:
                        # Stacked stats: take the first run's value only.
                        array = np.append(array, stats[key][0][i])
            array_dim = array.shape
            group_name = f"all-{stat}"
            with h5.File(filename, "a") as h5file:
                group = h5file.create_group(group_name, track_order=None)
                dataset = cls.hdf5.create_dataset_nofill(group, stat, array_dim, array.dtype)
                dataset[::] = array[::]
378
379
    def _write_stats_to_file3(self, p_num=None):
        """Write volume stats (and any loop RMSD stats) for a plugin to stats.h5.

        :param p_num: Plugin number to write stats for. Defaults to the
            current plugin.
        """
        if not p_num:
            p_num = self.p_num
        path = Statistics.path
        filename = f"{path}/stats.h5"
        stats = self.global_stats[p_num]
        self.hdf5 = Hdf5Utils(self.exp)
        with h5.File(filename, "a", driver="mpio", comm=MPI.COMM_WORLD) as h5file:
            group = h5file.require_group("stats")
            if stats.shape != (0,):
                if str(p_num) in list(group.keys()):
                    del group[str(p_num)]  # overwrite stats from a previous run
                dataset = group.create_dataset(str(p_num), shape=stats.shape, dtype=stats.dtype)
                dataset[::] = stats[::]
                dataset.attrs.create("plugin_name", self.plugin_names[p_num])
                dataset.attrs.create("pattern", self.pattern)
            if self._iterative_group:
                l_stats = Statistics.loop_stats[self.l_num]
                group1 = h5file.require_group("iterative")
                # Loop stats are only written after the final iteration of the
                # loop's last plugin.
                if self._iterative_group._ip_iteration == self._iterative_group._ip_fixed_iterations - 1\
                        and self.p_num == self._iterative_group.end_index:
                    dataset1 = group1.create_dataset(str(self.l_num), shape=l_stats["RMSD"].shape, dtype=l_stats["RMSD"].dtype)
                    dataset1[::] = l_stats["RMSD"][::]
                    loop_plugins = []
                    for i in range(self._iterative_group.start_index + 1, self._iterative_group.end_index + 1):
                        loop_plugins.append(self.plugin_names[i])
                    dataset1.attrs.create("loop_plugins", loop_plugins)
                    # Fix: attach to dataset1 (the loop dataset). "dataset" is
                    # undefined on this path when stats.shape == (0,), which
                    # raised UnboundLocalError; the attribute also describes
                    # the loop dataset, not the per-plugin stats dataset.
                    dataset1.attrs.create("n_loop_plugins", len(loop_plugins))
0 ignored issues
show
introduced by
The variable dataset does not seem to be defined in case stats.shape != TupleNode on line 388 is False. Are you sure this can never be the case?
Loading history...
407
408
    def write_slice_stats_to_file(self, slice_stats=None, p_num=None):
        """Writes slice statistics to a h5 file

        :param slice_stats: Dict of per-slice stat lists. Defaults to this
            plugin's own stats.
        :param p_num: Plugin number used in the file name. Defaults to the
            class-wide plugin counter.
        """
        if not slice_stats:
            slice_stats = self.stats
        if not p_num:
            p_num = self.count
            plugin_name = self.plugin_name
        else:
            plugin_name = self.plugin_names[p_num]
        combined_stats = self._combine_mpi_stats(slice_stats)
        slice_stats_arrays = {}
        datasets = {}
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{plugin_name}.h5"
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        with h5.File(filename, "a", driver="mpio", comm=MPI.COMM_WORLD) as h5file:
            i = 2
            group_name = "/stats"
            # Find a free group name: /stats, /stats2, /stats3, ...
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            for key in list(combined_stats.keys()):
                slice_stats_arrays[key] = np.array(combined_stats[key])
                datasets[key] = self.hdf5.create_dataset_nofill(group, key, (len(slice_stats_arrays[key]),), slice_stats_arrays[key].dtype)
                datasets[key][::] = slice_stats_arrays[key]
434
435
    def _unpad_slice(self, slice1):
        """If data is padded in the slice dimension, removes this pad.

        :param slice1: The (possibly padded) slice array.
        :returns: The slice with any slice-dimension padding removed.
        """
        out_datasets = self.plugin.get_out_datasets()
        if len(out_datasets) == 1:
            out_dataset = out_datasets[0]
        else:
            # Pick the dataset matching the current pattern; fall back to the
            # first one so out_dataset is always defined (previously this
            # raised UnboundLocalError when no dataset matched the pattern).
            out_dataset = out_datasets[0]
            for dataset in out_datasets:
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
                    out_dataset = dataset
                    break
        slice_dims = out_dataset.get_slice_dimensions()
        if self.plugin.pcount == 0:
            # First slice of the run: work out the un-padding slice list once.
            self._slice_list, self._pad = self._get_unpadded_slice_list(slice1, slice_dims)
        if self._pad:
            # Only the first slice dimension is un-padded at present.
            slice_dim = slice_dims[0]
            temp_slice = np.swapaxes(slice1, 0, slice_dim)
            temp_slice = temp_slice[self._slice_list[slice_dim]]
            slice1 = np.swapaxes(temp_slice, 0, slice_dim)
        return slice1
455
456
    def _get_unpadded_slice_list(self, slice1, slice_dims):
        """Creates slice object(s) to un-pad slices in the slice dimension(s).

        :param slice1: The current (possibly padded) slice array.
        :param slice_dims: The slice dimensions of the output dataset.
        :returns: Tuple of (slice list, bool flag set when padding was found).
        """
        slice_list = list(self.plugin.slice_list[0])
        pad = False
        if len(slice_list) == len(slice1.shape):
            # Only the first slice dimension is handled at present.
            i = slice_dims[0]
            slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
            if slice_width != slice1.shape[i]:
                pad = True
                pad_width = (slice1.shape[i] - slice_width) // 2  # Assuming symmetrical padding
                # NOTE(review): this keeps a single element starting at
                # pad_width rather than slice_width elements starting there -
                # confirm this is the intended un-padding.
                slice_list[i] = slice(pad_width, pad_width + 1, 1)
            return tuple(slice_list), pad
        else:
            return self.plugin.slice_list[0], pad
471
472
    def _de_list(self, slice1):
473
        """If the slice is in a list, remove it from that list."""
474
        if type(slice1) == list:
475
            if len(slice1) != 0:
476
                slice1 = slice1[0]
477
                slice1 = self._de_list(slice1)
478
        return slice1
479
480
481
    @classmethod
482
    def _count(cls):
483
        cls.count += 1
484
485
    @classmethod
    def _post_chain(cls):
        """Generate stats figures once the whole plugin chain has finished."""
        print(Statistics.loop_stats)  # NOTE(review): debug leftover - consider removing
        # Fix: logical 'and' instead of bitwise '&' for boolean flags
        # (equivalent for bools here, but '&' evaluates both operands and
        # obscures intent).
        if cls._any_stats and cls.stats_flag:
            stats_utils = StatsUtils()
            stats_utils.generate_figures(f"{cls.path}/stats.h5", cls.path)
491