savu.data.stats.statistics.Statistics._count()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""
2
.. module:: statistics
3
   :platform: Unix
4
   :synopsis: Contains and processes statistics information for each plugin.
5
6
.. moduleauthor:: Jacob Williamson <[email protected]>
7
8
"""
9
import logging
10
11
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
12
from savu.data.stats.stats_utils import StatsUtils
13
from savu.core.iterate_plugin_group_utils import check_if_in_iterative_loop
14
import savu.core.utils as cu
15
16
import time
17
import h5py as h5
18
import numpy as np
19
import os
20
from mpi4py import MPI
21
from collections import OrderedDict
22
23
class Statistics(object):
    """Contains and processes statistics information for each plugin."""

    # Data patterns for which slice/volume statistics can be gathered.
    _pattern_list = ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    # Plugins for which no statistics are calculated.
    _no_stats_plugins = ["BasicOperations", "Mipmap", "UnetApply"]
    _possible_stats = ("max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD", "zeros", "zeros%", "range_used")  # list of possible stats
    _volume_to_slice = {"max": "max", "min": "min", "mean": "mean", "mean_std_dev": "std_dev",
                        "median_std_dev": "std_dev", "NRMSD": ("RSS", "data_points", "max", "min"),
                        "zeros": ("zeros", "data_points"), "zeros%": ("zeros", "data_points"),
                        "range_used": ("min", "max")}  # volume stat: required slice stat(s)
    #_savers = ["Hdf5Saver", "ImageSaver", "MrcSaver", "TiffSaver", "XrfSaver"]
    # Set to True once _setup_class has run (class-wide setup happens once per chain).
    _has_setup = False
34
35
    def __init__(self):
36
37
        self.calc_stats = True
38
        self.stats_before_processing = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
39
        self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
40
        self._repeat_count = 0
41
        self.plugin = None
42
        self.p_num = None
43
        self.stats_key = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "RMSD"]
44
        self.slice_stats_key = None
45
        self.stats = None
46
        self.GPU = False
47
        self._iterative_group = None
48
49
    def setup(self, plugin_self, pattern=None):
        """Set up statistics gathering for the plugin that is about to run.

        :param plugin_self: The plugin object whose stats will be gathered.
        :param pattern: Optionally override the data pattern used for slice stats.
        """
        if not Statistics._has_setup:
            # Class-wide setup happens once, triggered by the first plugin.
            self._setup_class(plugin_self.exp)
        self.plugin_name = plugin_self.name
        self.p_num = Statistics.count
        self.plugin = plugin_self
        self.set_stats_key(self.stats_key)
        self.stats = {stat: [] for stat in self.slice_stats_key}
        if plugin_self.name in Statistics._no_stats_plugins:
            self.calc_stats = False
        if self.calc_stats:
            self._pad_dims = []
            self._already_called = False
            if pattern is not None:
                self.pattern = pattern
            else:
                self._set_pattern_info()
        if self.calc_stats:
            # Re-check: _set_pattern_info may have switched calc_stats off.
            Statistics._any_stats = True
        self._setup_4d()
        self._setup_iterative()
    def _setup_iterative(self):
        """Set up loop-level stats tracking if this plugin is inside an iterative loop."""
        self._iterative_group = check_if_in_iterative_loop(Statistics.exp)
        if self._iterative_group:
            if self._iterative_group.start_index == Statistics.count:
                # First plugin of a new loop: open a fresh set of loop stats.
                Statistics._loop_counter += 1
                Statistics.loop_stats.append({"NRMSD": np.array([])})
            self.l_num = Statistics._loop_counter - 1  # index of this loop's stats entry
    def _setup_4d(self):
        """Detect 4D input data and record the size of one output volume.

        For 4D data, slice stats are rolled up into volume stats each time a
        full volume's worth of data points has been accumulated
        (see set_slice_stats).
        """
        try:
            in_dataset, out_dataset = self.plugin.get_datasets()
            if in_dataset[0].data_info["nDims"] == 4 and len(out_dataset) != 0:
                self._4d = True
                shape = out_dataset[0].data_info["shape"]
                # Data points per output volume: product of all dims but the last.
                self._volume_total_points = 1
                for i in shape[:-1]:
                    self._volume_total_points *= i
            else:
                self._4d = False
        except KeyError:
            # Missing nDims/shape metadata: treat as non-4D.
            self._4d = False
    @classmethod
    def _setup_class(cls, exp):
        """Sets up the statistics class for the whole plugin chain (only called once).

        :param exp: The current experiment object (source of metadata and the output path).
        """
        if exp.meta_data.get("stats") == "on":
            cls._stats_flag = True
        elif exp.meta_data.get("stats") == "off":
            cls._stats_flag = False
        cls._any_stats = False  # becomes True once any plugin produces stats
        cls.exp = exp
        # Plugin numbering starts at 2 — presumably number 1 is the loader; TODO confirm.
        cls.count = 2
        cls.global_stats = {}   # plugin number -> list of volume-stats dicts
        cls.global_times = {}   # plugin number -> accumulated run time (seconds)
        cls.loop_stats = []
        cls.n_plugins = len(exp.meta_data.plugin_list.plugin_list)
        for i in range(1, cls.n_plugins + 1):
            cls.global_stats[i] = {}
            cls.global_times[i] = 0
        cls.global_residuals = {}
        cls.plugin_numbers = {}  # plugin name -> plugin number
        cls.plugin_names = {}    # plugin number -> plugin name
        cls._loop_counter = 0
        # Stats output lives in <out_path>/stats.
        cls.path = exp.meta_data['out_path']
        if cls.path[-1] == '/':
            cls.path = cls.path[0:-1]
        cls.path = f"{cls.path}/stats"
        if MPI.COMM_WORLD.rank == 0:
            # Only one process creates the directory.
            if not os.path.exists(cls.path):
                os.mkdir(cls.path)
        cls._has_setup = True
    def get_stats(self, p_num=None, stat=None, instance=-1):
124
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).
125
126
        :param p_num: Plugin  number of the plugin whose associated stats are being fetched.
127
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
128
            E.g current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
129
            By default will gather stats for the current plugin.
130
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
131
            If left blank will return the whole dictionary of stats:
132
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
133
        :param instance: In cases where there are multiple set of stats associated with a plugin
134
            due to iterative loops or multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
135
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
136
            By default will retrieve the most recent set.
137
        """
138
        if p_num is None:
139
            p_num = self.p_num
140
        if p_num <= 0:
141
            try:
142
                p_num = self.p_num + p_num
143
            except TypeError:
144
                p_num = Statistics.count + p_num
145
        if instance == "all":
146
            stats_list = [self.get_stats(p_num, stat=stat, instance=1)]
147
            n = 2
148
            while n <= len(Statistics.global_stats[p_num]):
149
                stats_list.append(self.get_stats(p_num, stat=stat, instance=n))
150
                n += 1
151
            return stats_list
152
        if instance > 0:
153
            instance -= 1
154
        stats_dict = Statistics.global_stats[p_num][instance]
155
        if stat is not None:
156
            return stats_dict[stat]
157
        else:
158
            return stats_dict
159
160
    def get_stats_from_name(self, plugin_name, n=None, stat=None, instance=-1):
161
        """Returns stats associated with a certain plugin.
162
163
        :param plugin_name: name of the plugin whose associated stats are being fetched.
164
        :param n: In a case where there are multiple instances of **plugin_name** in the process list,
165
            specify the nth instance. Not specifying will select the first (or only) instance.
166
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
167
            If left blank will return the whole dictionary of stats:
168
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
169
        :param instance: In cases where there are multiple set of stats associated with a plugin
170
            due to iterative loops or multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
171
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
172
            By default will retrieve the most recent set.
173
        """
174
        name = plugin_name
175
        if n not in (None, 0, 1):
176
            name = name + str(n)
177
        p_num = Statistics.plugin_numbers[name]
178
        return self.get_stats(p_num, stat, instance)
179
180
    def get_stats_from_dataset(self, dataset, stat=None, instance=-1):
181
        """Returns stats associated with a dataset.
182
183
        :param dataset: The dataset whose associated stats are being fetched.
184
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
185
            If left blank will return the whole dictionary of stats:
186
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
187
        :param instance: In cases where there are multiple set of stats associated with a dataset
188
            due to iterative loops or multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
189
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
190
            By default will retrieve the most recent set.
191
        """
192
        stats_list = [dataset.meta_data.get("stats")]
193
        n = 2
194
        while ("stats" + str(n)) in list(dataset.meta_data.get_dictionary().keys()):
195
            stats_list.append(dataset.meta_data.get("stats" + str(n)))
196
            n += 1
197
        if stat:
198
            for i in range(len(stats_list)):
199
                stats_list[i] = stats_list[i][stat]
200
        if instance in (None, 0, 1):
201
            stats = stats_list[0]
202
        elif instance == "all":
203
            stats = stats_list
204
        else:
205
            if instance >= 2:
206
                instance -= 1
207
            stats = stats_list[instance]
208
        return stats
209
210
    def set_stats_key(self, stats_key):
211
        """Changes which stats are to be calculated for the current plugin.
212
213
        :param stats_key: List of stats to be calculated.
214
        """
215
        valid = Statistics._possible_stats
216
        stats_key = sorted(set(valid).intersection(stats_key), key=lambda stat: valid.index(stat))
217
        self.stats_key = stats_key
218
        self.slice_stats_key = list(set(self._flatten(list(Statistics._volume_to_slice[stat] for stat in stats_key))))
219
        if "data_points" not in self.slice_stats_key:
220
            self.slice_stats_key.append("data_points")  # Data points is essential
221
222
    def set_slice_stats(self, my_slice, base_slice=None, pad=True):
223
        """Sets slice stats for the current slice.
224
225
        :param my_slice: The slice whose stats are being set.
226
        :param base_slice: Provide a base slice to calculate residuals from, to calculate RMSD.
227
        :param pad: Specify whether slice is padded or not (usually can leave as True even if slice is not padded).
228
        """
229
        my_slice = self._de_list(my_slice)
230
        if 0 not in my_slice.shape:
231
            try:
232
                slice_stats = self.calc_slice_stats(my_slice, base_slice=base_slice, pad=pad)
233
            except:
234
                pass
235
            if slice_stats is not None:
236
                for key, value in slice_stats.items():
237
                    self.stats[key].append(value)
238
                if self._4d:
239
                    if sum(self.stats["data_points"]) >= self._volume_total_points:
240
                        self.set_volume_stats()
241
            else:
242
                self.calc_stats = False
243
        else:
244
            self.calc_stats = False
245
246
    def calc_slice_stats(self, my_slice, base_slice=None, pad=True):
247
        """Calculates and returns slice stats for the current slice.
248
249
        :param my_slice: The slice whose stats are being calculated.
250
        :param base_slice: Provide a base slice to calculate residuals from, to calculate RMSD.
251
        :param pad: Specify whether slice is padded or not (usually can leave as True even if slice is not padded).
252
        """
253
        if my_slice is not None:
254
            my_slice = self._de_list(my_slice)
255
            if pad:
256
                my_slice = self._unpad_slice(my_slice)
257
            slice_stats = {}
258
            if "max" in self.slice_stats_key:
259
                slice_stats["max"] = np.amax(my_slice).astype('float64')
260
            if "min" in self.slice_stats_key:
261
                slice_stats["min"] = np.amin(my_slice).astype('float64')
262
            if "mean" in self.slice_stats_key:
263
                slice_stats["mean"] = np.mean(my_slice)
264
            if "std_dev" in self.slice_stats_key:
265
                slice_stats["std_dev"] = np.std(my_slice)
266
            if "zeros" in self.slice_stats_key:
267
                slice_stats["zeros"] = self.calc_zeros(my_slice)
268
            if "data_points" in self.slice_stats_key:
269
                slice_stats["data_points"] = my_slice.size
270
            if "RSS" in self.slice_stats_key and base_slice is not None:
271
                base_slice = self._de_list(base_slice)
272
                base_slice = self._unpad_slice(base_slice)
273
                slice_stats["RSS"] = self.calc_rss(my_slice, base_slice)
274
            if "dtype" not in self.stats:
275
                self.stats["dtype"] = my_slice.dtype
276
            return slice_stats
277
        return None
278
279
    @staticmethod
280
    def calc_zeros(my_slice):
281
        return my_slice.size - np.count_nonzero(my_slice)
282
283
    @staticmethod
284
    def calc_rss(array1, array2):  # residual sum of squares
285
        if array1.shape == array2.shape:
286
            residuals = np.subtract(array1, array2)
287
            rss = np.sum(residuals.flatten() ** 2)
288
        else:
289
            logging.debug("Cannot calculate RSS, arrays different sizes.")
290
            rss = None
291
        return rss
292
293
    @staticmethod
294
    def rmsd_from_rss(rss, n):
295
        return np.sqrt(rss/n)
296
297
    def calc_rmsd(self, array1, array2):
298
        if array1.shape == array2.shape:
299
            rss = self.calc_rss(array1, array2)
300
            rmsd = self.rmsd_from_rss(rss, array1.size)
301
        else:
302
            logging.error("Cannot calculate RMSD, arrays different sizes.")
303
            rmsd = None
304
        return rmsd
305
306
    def calc_stats_residuals(self, stats_before, stats_after):  # unused
307
        residuals = {'max': None, 'min': None, 'mean': None, 'std_dev': None}
308
        for key in list(residuals.keys()):
309
            residuals[key] = stats_after[key] - stats_before[key]
310
        return residuals
311
312
    def set_stats_residuals(self, residuals):  # unused
313
        self.residuals['max'].append(residuals['max'])
314
        self.residuals['min'].append(residuals['min'])
315
        self.residuals['mean'].append(residuals['mean'])
316
        self.residuals['std_dev'].append(residuals['std_dev'])
317
318
    def calc_volume_stats(self, slice_stats):
319
        """Calculates and returns volume-wide stats from slice-wide stats.
320
321
        :param slice_stats: The slice-wide stats that the volume-wide stats are calculated from.
322
        """
323
        slice_stats = slice_stats
324
        volume_stats = {}
325
        if "max" in self.stats_key:
326
            volume_stats["max"] = max(slice_stats["max"])
327
        if "min" in self.stats_key:
328
            volume_stats["min"] = min(slice_stats["min"])
329
        if "mean" in self.stats_key:
330
            volume_stats["mean"] = np.mean(slice_stats["mean"])
331
        if "mean_std_dev" in self.stats_key:
332
            volume_stats["mean_std_dev"] = np.mean(slice_stats["std_dev"])
333
        if "median_std_dev" in self.stats_key:
334
            volume_stats["median_std_dev"] = np.median(slice_stats["std_dev"])
335
        if "NRMSD" in self.stats_key and None not in slice_stats["RSS"]:
336
            total_rss = sum(slice_stats["RSS"])
337
            n = sum(slice_stats["data_points"])
338
            RMSD = self.rmsd_from_rss(total_rss, n)
339
            the_range = volume_stats["max"] - volume_stats["min"]
340
            NRMSD = RMSD / the_range  # normalised RMSD (dividing by the range)
341
            volume_stats["NRMSD"] = NRMSD
342
        if "zeros" in self.stats_key:
343
            volume_stats["zeros"] = sum(slice_stats["zeros"])
344
        if "zeros%" in self.stats_key:
345
            volume_stats["zeros%"] = (volume_stats["zeros"] / sum(slice_stats["data_points"])) * 100
346
        if "range_used" in self.stats_key:
347
            my_range = volume_stats["max"] - volume_stats["min"]
348
            if "int" in str(self.stats["dtype"]):
349
                possible_max = np.iinfo(self.stats["dtype"]).max
350
                possible_min = np.iinfo(self.stats["dtype"]).min
351
                self.stats["possible_max"] = possible_max
352
                self.stats["possible_min"] = possible_min
353
            elif "float" in str(self.stats["dtype"]):
354
                possible_max = np.finfo(self.stats["dtype"]).max
355
                possible_min = np.finfo(self.stats["dtype"]).min
356
                self.stats["possible_max"] = possible_max
357
                self.stats["possible_min"] = possible_min
358
            possible_range = possible_max - possible_min
0 ignored issues
show
introduced by
The variable possible_min does not seem to be defined for all execution paths.
Loading history...
introduced by
The variable possible_max does not seem to be defined for all execution paths.
Loading history...
359
            volume_stats["range_used"] = (my_range / possible_range) * 100
360
        return volume_stats
361
362
    def _set_loop_stats(self):
        """Record the NRMSD between the two datasets being iterated over.

        NEED TO CHANGE THIS - MUST USE SLICES (unused)
        """
        # The pair of datasets being cycled between loop iterations.
        data_obj1 = list(self._iterative_group._ip_data_dict["iterating"].keys())[0]
        data_obj2 = self._iterative_group._ip_data_dict["iterating"][data_obj1]
        RMSD = self.calc_rmsd(data_obj1.data, data_obj2.data)
        the_range = self.get_stats(self.p_num, stat="max", instance=self._iterative_group._ip_iteration) -\
                self.get_stats(self.p_num, stat="min", instance=self._iterative_group._ip_iteration)
        NRMSD = RMSD/the_range  # normalise by the data range of this iteration
        Statistics.loop_stats[self.l_num]["NRMSD"] = np.append(Statistics.loop_stats[self.l_num]["NRMSD"], NRMSD)
    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        stats = self.stats
        comm = self.plugin.get_communicator()
        # Gather slice stats from every MPI process before aggregating.
        combined_stats = self._combine_mpi_stats(stats, comm=comm)
        if not self.p_num:
            self.p_num = Statistics.count
        p_num = self.p_num
        name = self.plugin_name
        i = 2
        # Repeated plugin names are disambiguated as Name2, Name3, ...
        # (for iterative plugins, only on the first iteration).
        if not self._iterative_group:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1
        elif self._iterative_group._ip_iteration == 0:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1

        if p_num not in list(Statistics.plugin_names.keys()):
            Statistics.plugin_names[p_num] = name
        Statistics.plugin_numbers[name] = p_num
        if len(combined_stats['max']) != 0:
            stats_dict = self.calc_volume_stats(combined_stats)
            Statistics.global_residuals[p_num] = {}
            #before_processing = self.calc_volume_stats(self.stats_before_processing)
            #for key in list(before_processing.keys()):
            #    Statistics.global_residuals[p_num][key] = Statistics.global_stats[p_num][key] - before_processing[key]

            if len(Statistics.global_stats[p_num]) == 0:
                Statistics.global_stats[p_num] = [stats_dict]
            else:
                Statistics.global_stats[p_num].append(stats_dict)

            self._link_stats_to_datasets(stats_dict, self._iterative_group)
            self._write_stats_to_file(p_num, comm=comm)
        self._already_called = True
        self._repeat_count += 1
        if self._iterative_group or self._4d:
            # Reset slice stats ready for the next iteration / volume.
            self.stats = {stat: [] for stat in self.slice_stats_key}
    def start_time(self):
        """Called at the start of a plugin; records the wall-clock start time."""
        self.t0 = time.time()
    def stop_time(self):
        """Called at the end of a plugin; records the elapsed time for the plugin."""
        self.t1 = time.time()
        elapsed = round(self.t1 - self.t0, 1)
        if self._stats_flag and self.calc_stats:
            self.set_time(elapsed)
    def set_time(self, seconds):
        """Sets time taken for plugin to complete.

        :param seconds: Elapsed time to add to this plugin's running total.
        """
        Statistics.global_times[self.p_num] += seconds  # Gives total time for a plugin in a loop
        #print(f"{self.p_num}, {seconds}")
        comm = self.plugin.get_communicator()
        try:
            # Probe the communicator: accessing .rank raises if it is invalid.
            rank = comm.rank
        except (MPI.Exception, AttributeError):        # Sometimes get_communicator() returns an invalid communicator.
            comm = MPI.COMM_WORLD    # So using COMM_WORLD in this case.
        self._write_times_to_file(comm)
    def _combine_mpi_stats(self, slice_stats, comm=MPI.COMM_WORLD):
        """Combines slice stats from different processes, so volume stats can be calculated.

        :param slice_stats: slice stats (each process will have a different set).
        :param comm: MPI communicator being used.
        :returns: dict mapping each slice stat to the concatenated values from all processes.
        """
        gathered = comm.allgather(slice_stats)
        merged = {stat: [] for stat in self.slice_stats_key}
        for process_stats in gathered:
            for stat in self.slice_stats_key:
                merged[stat].extend(process_stats[stat])
        return merged
    def _array_to_dict(self, stats_array, key_list=None):
451
        """Converts an array of stats to a dictionary of stats.
452
453
        :param stats_array: Array of stats to be converted.
454
        :param key_list: List of keys indicating the names of the stats in the stats_array.
455
        """
456
        if key_list is None:
457
            key_list = self.stats_key
458
        stats_dict = {}
459
        for i, value in enumerate(stats_array):
460
            stats_dict[key_list[i]] = value
461
        return stats_dict
462
463
    def _dict_to_array(self, stats_dict):
464
        """Converts stats dict into a numpy array (keys will be lost).
465
466
        :param stats_dict: dictionary of stats.
467
        """
468
        return np.array(list(stats_dict.values()))
469
470
    def _broadcast_gpu_stats(self, gpu_processes, process):
        """During GPU plugins, most processes are unused, and don't have access to stats.
        This method shares stats between processes so all have access to stats.

        :param gpu_processes: List that determines whether a process is a GPU process.
        :param process: Process number.
        """
        p_num = self.p_num
        # Broadcast the stats held by rank 0 to every process.
        Statistics.global_stats[p_num] = MPI.COMM_WORLD.bcast(Statistics.global_stats[p_num], root=0)
        if not gpu_processes[process]:
            # Non-GPU processes still need the stats linked to their datasets.
            if len(Statistics.global_stats[p_num]) != 0:
                for stats_dict in Statistics.global_stats[p_num]:
                    self._link_stats_to_datasets(stats_dict, self._iterative_group)
    def _set_pattern_info(self):
485
        """Gathers information about the pattern of the data in the current plugin."""
486
        out_datasets = self.plugin.get_out_datasets()
487
        if len(out_datasets) == 0:
488
            self.calc_stats = False
489
        try:
490
            self.pattern = self.plugin.parameters['pattern']
491
            if self.pattern == None:
492
                raise KeyError
493
        except KeyError:
494
            if not out_datasets:
495
                self.pattern = None
496
            else:
497
                patterns = out_datasets[0].get_data_patterns()
498
                for pattern in patterns:
499
                    if 1 in patterns.get(pattern)["slice_dims"]:
500
                        self.pattern = pattern
501
                        break
502
                    self.pattern = None
503
        if self.pattern not in Statistics._pattern_list:
504
            self.calc_stats = False
505
506
    def _link_stats_to_datasets(self, stats_dict, iterative=False):
        """Links the volume wide statistics to the output dataset(s).

        :param stats_dict: Dictionary of stats being linked.
        :param iterative: boolean indicating if the plugin is iterative or not.
        """
        out_dataset = self.plugin.get_out_datasets()[0]
        my_dataset = out_dataset
        if iterative:
            # For iterative clones, attach stats to the original iterating dataset.
            if "itr_clone" in out_dataset.group_name:
                my_dataset = list(iterative._ip_data_dict["iterating"].keys())[0]
        n_datasets = self.plugin.nOutput_datasets()  # NOTE(review): unused - confirm the call has no side effects before removing

        i = 2
        group_name = "stats"
        while group_name in list(my_dataset.meta_data.get_dictionary().keys()):
            group_name = f"stats{i}"  # If more than one set of stats for a plugin (such as iterative plugin)
            i += 1                    # the groups will be named stats, stats2, stats3 etc.
        for key, value in stats_dict.items():
            my_dataset.meta_data.set([group_name, key], value)
    def _write_stats_to_file(self, p_num=None, plugin_name=None, comm=MPI.COMM_WORLD):
        """Writes stats to a h5 file. This file is used to create figures and tables from the stats.

        :param p_num: The plugin number of the plugin the stats belong to (usually left as None except
            for special cases).
        :param plugin_name: Same as above (but for the name of the plugin).
        :param comm: The MPI communicator the plugin is using.
        """
        if p_num is None:
            p_num = self.p_num
        if plugin_name is None:
            plugin_name = self.plugin_names[p_num]
        path = Statistics.path
        filename = f"{path}/stats.h5"
        # Stack every instance of this plugin's stats into one 2D array.
        stats_dict = self.get_stats(p_num, instance="all")
        stats_array = self._dict_to_array(stats_dict[0])
        stats_key = list(stats_dict[0].keys())
        for i, my_dict in enumerate(stats_dict):
            if i != 0:
                stats_array = np.vstack([stats_array, self._dict_to_array(my_dict)])
        self.hdf5 = Hdf5Utils(self.exp)
        self.exp._barrier(communicator=comm)
        if comm.rank == 0:  # only rank 0 writes the stats file
            with h5.File(filename, "a") as h5file:
                group = h5file.require_group("stats")
                # BUG FIX: 'dataset' could previously be referenced in the
                # iterative branch below without ever being defined (when
                # stats_array was empty). Initialise it and guard its use.
                dataset = None
                if stats_array.shape != (0,):
                    if str(p_num) in list(group.keys()):
                        del group[str(p_num)]  # overwrite stats from a previous run
                    dataset = group.create_dataset(str(p_num), shape=stats_array.shape, dtype=stats_array.dtype)
                    dataset[::] = stats_array[::]
                    dataset.attrs.create("plugin_name", plugin_name)
                    dataset.attrs.create("pattern", self.pattern)
                    dataset.attrs.create("stats_key", stats_key)
                if self._iterative_group:
                    l_stats = Statistics.loop_stats[self.l_num]
                    group1 = h5file.require_group("iterative")
                    # Only write loop stats once, on the final iteration of the loop's last plugin.
                    if self._iterative_group._ip_iteration == self._iterative_group._ip_fixed_iterations - 1\
                            and self.p_num == self._iterative_group.end_index:
                        dataset1 = group1.create_dataset(str(self.l_num), shape=l_stats["NRMSD"].shape, dtype=l_stats["NRMSD"].dtype)
                        dataset1[::] = l_stats["NRMSD"][::]
                        loop_plugins = []
                        for i in range(self._iterative_group.start_index, self._iterative_group.end_index + 1):
                            if i in list(self.plugin_names.keys()):
                                loop_plugins.append(self.plugin_names[i])
                        dataset1.attrs.create("loop_plugins", loop_plugins)
                        if dataset is not None:
                            dataset.attrs.create("n_loop_plugins", len(loop_plugins))
        self.exp._barrier(communicator=comm)
    def _write_times_to_file(self, comm):
        """Writes times into the file containing all the stats.

        :param comm: MPI communicator; only rank 0 performs the write.
        """
        p_num = self.p_num
        path = Statistics.path
        filename = f"{path}/stats.h5"
        # Renamed from 'time' to avoid shadowing the imported time module
        # (also removed an unused 'plugin_name' local).
        total_time = Statistics.global_times[p_num]
        self.hdf5 = Hdf5Utils(self.exp)
        if comm.rank == 0:
            with h5.File(filename, "a") as h5file:
                group = h5file.require_group("stats")
                dataset = group[str(p_num)]
                dataset.attrs.create("time", total_time)
    def write_slice_stats_to_file(self, slice_stats=None, p_num=None, comm=MPI.COMM_WORLD):
        """Writes slice statistics to a h5 file. Placed in the stats folder in the output directory. Currently unused.

        :param slice_stats: Slice stats to write (defaults to the current plugin's stats).
        :param p_num: Plugin number the stats belong to (defaults to the current plugin).
        :param comm: MPI communicator used for the parallel h5 write.
        """
        if not slice_stats:
            slice_stats = self.stats
        if not p_num:
            p_num = self.count
            plugin_name = self.plugin_name
        else:
            plugin_name = self.plugin_names[p_num]
        combined_stats = self._combine_mpi_stats(slice_stats)
        slice_stats_arrays = {}
        datasets = {}
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{plugin_name}.h5"
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        # Parallel write: every process opens the file with the mpio driver.
        with h5.File(filename, "a", driver="mpio", comm=comm) as h5file:
            i = 2
            group_name = "/stats"
            # Groups are named /stats, /stats2, /stats3 ... for repeated calls.
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            for key in list(combined_stats.keys()):
                slice_stats_arrays[key] = np.array(combined_stats[key])
                datasets[key] = self.hdf5.create_dataset_nofill(group, key, (len(slice_stats_arrays[key]),), slice_stats_arrays[key].dtype)
                datasets[key][::] = slice_stats_arrays[key]
    def _unpad_slice(self, my_slice):
617
        """If data is padded in the slice dimension, removes this pad."""
618
        out_datasets = self.plugin.get_out_datasets()
619
        if len(out_datasets) == 1:
620
            out_dataset = out_datasets[0]
621
        else:
622
            for dataset in out_datasets:
623
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
624
                    out_dataset = dataset
625
                    break
626
        slice_dims = out_dataset.get_slice_dimensions()
0 ignored issues
show
introduced by
The variable out_dataset does not seem to be defined for all execution paths.
Loading history...
627
        if self.plugin.pcount == 0:
628
            self._slice_list, self._pad = self._get_unpadded_slice_list(my_slice, slice_dims)
629
        if self._pad:
630
            #for slice_dim in slice_dims:
631
            slice_dim = slice_dims[0]
632
            temp_slice = np.swapaxes(my_slice, 0, slice_dim)
633
            temp_slice = temp_slice[self._slice_list[slice_dim]]
634
            my_slice = np.swapaxes(temp_slice, 0, slice_dim)
635
        return my_slice
636
637
    def _get_unpadded_slice_list(self, my_slice, slice_dims):
638
        """Creates slice object(s) to un-pad slices in the slice dimension(s)."""
639
        slice_list = list(self.plugin.slice_list[0])
640
        pad = False
641
        if len(slice_list) == len(my_slice.shape):
642
            i = slice_dims[0]
643
            slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
644
            if slice_width < my_slice.shape[i]:
645
                pad = True
646
                pad_width = (my_slice.shape[i] - slice_width) // 2  # Assuming symmetrical padding
647
                slice_list[i] = slice(pad_width, pad_width + 1, 1)
648
            return tuple(slice_list), pad
649
        else:
650
            return self.plugin.slice_list[0], pad
651
652
    def _flatten(self, l):
653
        """Function to flatten nested lists."""
654
        out = []
655
        for item in l:
656
            if isinstance(item, (list, tuple)):
657
                out.extend(self._flatten(item))
658
            else:
659
                out.append(item)
660
        return out
661
662
    def _de_list(self, my_slice):
663
        """If the slice is in a list, remove it from that list (takes 0th element)."""
664
        if type(my_slice) == list:
665
            if len(my_slice) != 0:
666
                my_slice = my_slice[0]
667
                my_slice = self._de_list(my_slice)
668
        return my_slice
669
670
    @classmethod
671
    def _count(cls):
672
        cls.count += 1
673
674
    @classmethod
675
    def _post_chain(cls):
676
        """Called after all plugins have run."""
677
        if cls._any_stats & cls._stats_flag:
678
            stats_utils = StatsUtils()
679
            stats_utils.generate_figures(f"{cls.path}/stats.h5", cls.path)
680