Test Failed
Pull Request — master (#934)
by
unknown
04:09
created

Statistics.set_slice_stats()   B

Complexity

Conditions 7

Size

Total Lines 23
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
eloc 15
nop 4
dl 0
loc 23
rs 8
c 0
b 0
f 0
1
"""
2
.. module:: statistics
3
   :platform: Unix
4
   :synopsis: Contains and processes statistics information for each plugin.
5
6
.. moduleauthor::Jacob Williamson <[email protected]>
7
8
"""
9
import logging
10
11
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
12
from savu.data.stats.stats_utils import StatsUtils
13
from savu.core.iterate_plugin_group_utils import check_if_in_iterative_loop
14
import savu.core.utils as cu
15
16
import time
17
import h5py as h5
18
import numpy as np
19
import os
20
from mpi4py import MPI
21
from collections import OrderedDict
22
23
class Statistics(object):
    """Gathers and processes statistics information for each plugin in a
    process chain, and writes it to file.
    """
    # Data patterns for which slice statistics can be gathered.
    _pattern_list = ["SINOGRAM", "PROJECTION", "TANGENTOGRAM", "VOLUME_YZ", "VOLUME_XZ", "VOLUME_XY", "VOLUME_3D", "4D_SCAN", "SINOMOVIE"]
    # Plugins for which statistics are never calculated.
    _no_stats_plugins = ["BasicOperations", "Mipmap", "UnetApply"]
    _possible_stats = ("max", "min", "mean", "mean_std_dev", "median_std_dev", "NRMSD", "zeros", "zeros%", "range_used")  # list of possible stats
    # Maps each volume-wide stat to the slice-wide stat(s) it is derived from.
    _volume_to_slice = {"max": "max", "min": "min", "mean": "mean", "mean_std_dev": "std_dev",
                        "median_std_dev": "std_dev", "NRMSD": ("RSS", "data_points", "max", "min"),
                        "zeros": ("zeros", "data_points"), "zeros%": ("zeros", "data_points"),
                        "range_used": ("min", "max")}  # volume stat: required slice stat(s)
    #_savers = ["Hdf5Saver", "ImageSaver", "MrcSaver", "TiffSaver", "XrfSaver"]
    # True once _setup_class() has run (it must run only once per chain).
    _has_setup = False
    def __init__(self):
36
37
        self.calc_stats = True
38
        self.stats_before_processing = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
39
        self.residuals = {'max': [], 'min': [], 'mean': [], 'std_dev': []}
40
        self._repeat_count = 0
41
        self.plugin = None
42
        self.p_num = None
43
        self.stats_key = ["max", "min", "mean", "mean_std_dev", "median_std_dev", "RMSD"]
44
        self.slice_stats_key = None
45
        self.stats = None
46
        self.GPU = False
47
        self._iterative_group = None
48
49
    def setup(self, plugin_self, pattern=None):
        """Sets up the statistics object for the current plugin.

        :param plugin_self: The plugin instance stats are being gathered for.
        :param pattern: Optionally override the detected data pattern.
        """
        if not Statistics._has_setup:
            # First plugin in the chain: initialise class-wide state.
            self._setup_class(plugin_self.exp)
        self.plugin_name = plugin_self.name
        self.p_num = Statistics.count
        self.plugin = plugin_self
        self.set_stats_key(self.stats_key)
        self.stats = {stat: [] for stat in self.slice_stats_key}
        if plugin_self.name in Statistics._no_stats_plugins:
            self.calc_stats = False
        if self.calc_stats:
            self._pad_dims = []
            self._already_called = False
            if pattern is not None:
                self.pattern = pattern
            else:
                # May also clear calc_stats if no suitable pattern is found.
                self._set_pattern_info()
        if self.calc_stats:
            Statistics._any_stats = True
        self._setup_4d()
        self._setup_iterative()
    def _setup_iterative(self):
        """Registers this plugin with any enclosing iterative loop, creating a
        new set of loop stats when the loop's first plugin is reached."""
        self._iterative_group = check_if_in_iterative_loop(Statistics.exp)
        if self._iterative_group:
            if self._iterative_group.start_index == Statistics.count:
                # First plugin of a new loop: open a fresh loop-stats entry.
                Statistics._loop_counter += 1
                Statistics.loop_stats.append({"NRMSD": np.array([])})
            # Index of this loop's entry in Statistics.loop_stats.
            self.l_num = Statistics._loop_counter - 1
    def _setup_4d(self):
80
        try:
81
            in_dataset, out_dataset = self.plugin.get_datasets()
82
            if in_dataset[0].data_info["nDims"] == 4:
83
                self._4d = True
84
                shape = out_dataset[0].data_info["shape"]
85
                self._volume_total_points = 1
86
                for i in shape[:-1]:
87
                    self._volume_total_points *= i
88
            else:
89
                self._4d = False
90
        except KeyError:
91
            self._4d = False
92
93
    @classmethod
    def _setup_class(cls, exp):
        """Sets up the statistics class for the whole plugin chain (only called once)"""
        # Stats can be toggled via the "stats" experiment metadata entry.
        if exp.meta_data.get("stats") == "on":
            cls._stats_flag = True
        elif exp.meta_data.get("stats") == "off":
            cls._stats_flag = False
        cls._any_stats = False
        cls.exp = exp
        # NOTE(review): plugin numbering starts at 2 - presumably plugin 1 is
        # the loader; confirm against the process-list conventions.
        cls.count = 2
        cls.global_stats = {}
        cls.global_times = {}
        cls.loop_stats = []
        cls.n_plugins = len(exp.meta_data.plugin_list.plugin_list)
        for i in range(1, cls.n_plugins + 1):
            cls.global_stats[i] = {}
            cls.global_times[i] = 0
        cls.global_residuals = {}
        cls.plugin_numbers = {}
        cls.plugin_names = {}
        cls._loop_counter = 0
        # Stats files live in a "stats" directory inside the output path.
        cls.path = exp.meta_data['out_path']
        if cls.path[-1] == '/':
            cls.path = cls.path[0:-1]
        cls.path = f"{cls.path}/stats"
        if MPI.COMM_WORLD.rank == 0:
            # Only the root process creates the directory.
            if not os.path.exists(cls.path):
                os.mkdir(cls.path)
        cls._has_setup = True
    def get_stats(self, p_num=None, stat=None, instance=-1):
124
        """Returns stats associated with a certain plugin, given the plugin number (its place in the process list).
125
126
        :param p_num: Plugin  number of the plugin whose associated stats are being fetched.
127
            If p_num <= 0, it is relative to the plugin number of the current plugin being run.
128
            E.g current plugin number = 5, p_num = -2 --> will return stats of the third plugin.
129
            By default will gather stats for the current plugin.
130
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
131
            If left blank will return the whole dictionary of stats:
132
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
133
        :param instance: In cases where there are multiple set of stats associated with a plugin
134
            due to iterative loops or multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
135
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
136
            By default will retrieve the most recent set.
137
        """
138
        if p_num is None:
139
            p_num = self.p_num
140
        if p_num <= 0:
141
            try:
142
                p_num = self.p_num + p_num
143
            except TypeError:
144
                p_num = Statistics.count + p_num
145
        if instance == "all":
146
            stats_list = [self.get_stats(p_num, stat=stat, instance=1)]
147
            n = 2
148
            while n <= len(Statistics.global_stats[p_num]):
149
                stats_list.append(self.get_stats(p_num, stat=stat, instance=n))
150
                n += 1
151
            return stats_list
152
        if instance > 0:
153
            instance -= 1
154
        stats_dict = Statistics.global_stats[p_num][instance]
155
        if stat is not None:
156
            return stats_dict[stat]
157
        else:
158
            return stats_dict
159
160
    def get_stats_from_name(self, plugin_name, n=None, stat=None, instance=-1):
161
        """Returns stats associated with a certain plugin.
162
163
        :param plugin_name: name of the plugin whose associated stats are being fetched.
164
        :param n: In a case where there are multiple instances of **plugin_name** in the process list,
165
            specify the nth instance. Not specifying will select the first (or only) instance.
166
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
167
            If left blank will return the whole dictionary of stats:
168
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
169
        :param instance: In cases where there are multiple set of stats associated with a plugin
170
            due to iterative loops or multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
171
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
172
            By default will retrieve the most recent set.
173
        """
174
        name = plugin_name
175
        if n not in (None, 0, 1):
176
            name = name + str(n)
177
        p_num = Statistics.plugin_numbers[name]
178
        return self.get_stats(p_num, stat, instance)
179
180
    def get_stats_from_dataset(self, dataset, stat=None, instance=-1):
181
        """Returns stats associated with a dataset.
182
183
        :param dataset: The dataset whose associated stats are being fetched.
184
        :param stat: Specify the stat parameter you want to fetch, i.e 'max', 'mean', 'median_std_dev'.
185
            If left blank will return the whole dictionary of stats:
186
            {'max': , 'min': , 'mean': , 'mean_std_dev': , 'median_std_dev': , 'NRMSD': }
187
        :param instance: In cases where there are multiple set of stats associated with a dataset
188
            due to iterative loops or multi-parameters, specify which set you want to retrieve, i.e 3 to retrieve the
189
            stats associated with the third run of a plugin. Pass 'all' to get a list of all sets.
190
            By default will retrieve the most recent set.
191
        """
192
        stats_list = [dataset.meta_data.get("stats")]
193
        n = 2
194
        while ("stats" + str(n)) in list(dataset.meta_data.get_dictionary().keys()):
195
            stats_list.append(dataset.meta_data.get("stats" + str(n)))
196
            n += 1
197
        if stat:
198
            for i in range(len(stats_list)):
199
                stats_list[i] = stats_list[i][stat]
200
        if instance in (None, 0, 1):
201
            stats = stats_list[0]
202
        elif instance == "all":
203
            stats = stats_list
204
        else:
205
            if instance >= 2:
206
                instance -= 1
207
            stats = stats_list[instance]
208
        return stats
209
210
    def set_stats_key(self, stats_key):
211
        """Changes which stats are to be calculated for the current plugin.
212
213
        :param stats_key: List of stats to be calculated.
214
        """
215
        valid = Statistics._possible_stats
216
        stats_key = sorted(set(valid).intersection(stats_key), key=lambda stat: valid.index(stat))
217
        self.stats_key = stats_key
218
        self.slice_stats_key = list(set(self._flatten(list(Statistics._volume_to_slice[stat] for stat in stats_key))))
219
        if "data_points" not in self.slice_stats_key:
220
            self.slice_stats_key.append("data_points")  # Data points is essential
221
222
    def set_slice_stats(self, my_slice, base_slice=None, pad=True):
223
        """Sets slice stats for the current slice.
224
225
        :param my_slice: The slice whose stats are being set.
226
        :param base_slice: Provide a base slice to calculate residuals from, to calculate RMSD.
227
        :param pad: Specify whether slice is padded or not (usually can leave as True even if slice is not padded).
228
        """
229
        my_slice = self._de_list(my_slice)
230
        if 0 not in my_slice.shape:
231
            try:
232
                slice_stats = self.calc_slice_stats(my_slice, base_slice=base_slice, pad=pad)
233
            except:
234
                pass
235
            if slice_stats is not None:
236
                for key, value in slice_stats.items():
237
                    self.stats[key].append(value)
238
                if self._4d:
239
                    if sum(self.stats["data_points"]) >= self._volume_total_points:
240
                        self.set_volume_stats()
241
            else:
242
                self.calc_stats = False
243
        else:
244
            self.calc_stats = False
245
246
    def calc_slice_stats(self, my_slice, base_slice=None, pad=True):
247
        """Calculates and returns slice stats for the current slice.
248
249
        :param my_slice: The slice whose stats are being calculated.
250
        :param base_slice: Provide a base slice to calculate residuals from, to calculate RMSD.
251
        :param pad: Specify whether slice is padded or not (usually can leave as True even if slice is not padded).
252
        """
253
        if my_slice is not None:
254
            my_slice = self._de_list(my_slice)
255
            if pad:
256
                my_slice = self._unpad_slice(my_slice)
257
            slice_stats = {}
258
            if "max" in self.slice_stats_key:
259
                slice_stats["max"] = np.amax(my_slice).astype('float64')
260
            if "min" in self.slice_stats_key:
261
                slice_stats["min"] = np.amin(my_slice).astype('float64')
262
            if "mean" in self.slice_stats_key:
263
                slice_stats["mean"] = np.mean(my_slice)
264
            if "std_dev" in self.slice_stats_key:
265
                slice_stats["std_dev"] = np.std(my_slice)
266
            if "zeros" in self.slice_stats_key:
267
                slice_stats["zeros"] = self.calc_zeros(my_slice)
268
            if "data_points" in self.slice_stats_key:
269
                slice_stats["data_points"] = my_slice.size
270
            if "RSS" in self.slice_stats_key and base_slice is not None:
271
                base_slice = self._de_list(base_slice)
272
                base_slice = self._unpad_slice(base_slice)
273
                slice_stats["RSS"] = self.calc_rss(my_slice, base_slice)
274
            if "dtype" not in self.stats:
275
                self.stats["dtype"] = my_slice.dtype
276
            return slice_stats
277
        return None
278
279
    @staticmethod
280
    def calc_zeros(my_slice):
281
        return my_slice.size - np.count_nonzero(my_slice)
282
283
    @staticmethod
284
    def calc_rss(array1, array2):  # residual sum of squares
285
        if array1.shape == array2.shape:
286
            residuals = np.subtract(array1, array2)
287
            rss = np.sum(residuals.flatten() ** 2)
288
        else:
289
            logging.debug("Cannot calculate RSS, arrays different sizes.")
290
            rss = None
291
        return rss
292
293
    @staticmethod
294
    def rmsd_from_rss(rss, n):
295
        return np.sqrt(rss/n)
296
297
    def calc_rmsd(self, array1, array2):
298
        if array1.shape == array2.shape:
299
            rss = self.calc_rss(array1, array2)
300
            rmsd = self.rmsd_from_rss(rss, array1.size)
301
        else:
302
            logging.error("Cannot calculate RMSD, arrays different sizes.")
303
            rmsd = None
304
        return rmsd
305
306
    def calc_stats_residuals(self, stats_before, stats_after):  # unused
307
        residuals = {'max': None, 'min': None, 'mean': None, 'std_dev': None}
308
        for key in list(residuals.keys()):
309
            residuals[key] = stats_after[key] - stats_before[key]
310
        return residuals
311
312
    def set_stats_residuals(self, residuals):  # unused
313
        self.residuals['max'].append(residuals['max'])
314
        self.residuals['min'].append(residuals['min'])
315
        self.residuals['mean'].append(residuals['mean'])
316
        self.residuals['std_dev'].append(residuals['std_dev'])
317
318
    def calc_volume_stats(self, slice_stats):
319
        """Calculates and returns volume-wide stats from slice-wide stats.
320
321
        :param slice_stats: The slice-wide stats that the volume-wide stats are calculated from.
322
        """
323
        slice_stats = slice_stats
324
        volume_stats = {}
325
        if "max" in self.stats_key:
326
            volume_stats["max"] = max(slice_stats["max"])
327
        if "min" in self.stats_key:
328
            volume_stats["min"] = min(slice_stats["min"])
329
        if "mean" in self.stats_key:
330
            volume_stats["mean"] = np.mean(slice_stats["mean"])
331
        if "mean_std_dev" in self.stats_key:
332
            volume_stats["mean_std_dev"] = np.mean(slice_stats["std_dev"])
333
        if "median_std_dev" in self.stats_key:
334
            volume_stats["median_std_dev"] = np.median(slice_stats["std_dev"])
335
        if "NRMSD" in self.stats_key and None not in slice_stats["RSS"]:
336
            total_rss = sum(slice_stats["RSS"])
337
            n = sum(slice_stats["data_points"])
338
            RMSD = self.rmsd_from_rss(total_rss, n)
339
            the_range = volume_stats["max"] - volume_stats["min"]
340
            NRMSD = RMSD / the_range  # normalised RMSD (dividing by the range)
341
            volume_stats["NRMSD"] = NRMSD
342
        if "zeros" in self.stats_key:
343
            volume_stats["zeros"] = sum(slice_stats["zeros"])
344
        if "zeros%" in self.stats_key:
345
            volume_stats["zeros%"] = (volume_stats["zeros"] / sum(slice_stats["data_points"])) * 100
346
        if "range_used" in self.stats_key:
347
            my_range = volume_stats["max"] - volume_stats["min"]
348
            if "int" in str(self.stats["dtype"]):
349
                possible_max = np.iinfo(self.stats["dtype"]).max
350
                possible_min = np.iinfo(self.stats["dtype"]).min
351
                self.stats["possible_max"] = possible_max
352
                self.stats["possible_min"] = possible_min
353
            elif "float" in str(self.stats["dtype"]):
354
                possible_max = np.finfo(self.stats["dtype"]).max
355
                possible_min = np.finfo(self.stats["dtype"]).min
356
                self.stats["possible_max"] = possible_max
357
                self.stats["possible_min"] = possible_min
358
            possible_range = possible_max - possible_min
0 ignored issues
show
introduced by
The variable possible_min does not seem to be defined for all execution paths.
Loading history...
introduced by
The variable possible_max does not seem to be defined for all execution paths.
Loading history...
359
            volume_stats["range_used"] = (my_range / possible_range) * 100
360
        return volume_stats
361
362
    def _set_loop_stats(self):
        """Calculates the NRMSD between the two datasets being iterated over
        and appends it to this loop's stats. (Currently unused.)"""
        # NEED TO CHANGE THIS - MUST USE SLICES (unused)
        data_obj1 = list(self._iterative_group._ip_data_dict["iterating"].keys())[0]
        data_obj2 = self._iterative_group._ip_data_dict["iterating"][data_obj1]
        RMSD = self.calc_rmsd(data_obj1.data, data_obj2.data)
        # Normalise by the data range of the current iteration's stats.
        the_range = self.get_stats(self.p_num, stat="max", instance=self._iterative_group._ip_iteration) -\
                self.get_stats(self.p_num, stat="min", instance=self._iterative_group._ip_iteration)
        NRMSD = RMSD/the_range
        Statistics.loop_stats[self.l_num]["NRMSD"] = np.append(Statistics.loop_stats[self.l_num]["NRMSD"], NRMSD)
    def set_volume_stats(self):
        """Calculates volume-wide statistics from slice stats, and updates class-wide arrays with these values.
        Links volume stats with the output dataset and writes slice stats to file.
        """
        stats = self.stats
        comm = self.plugin.get_communicator()
        # Pool the slice stats gathered by every MPI process.
        combined_stats = self._combine_mpi_stats(stats, comm=comm)
        if not self.p_num:
            self.p_num = Statistics.count
        p_num = self.p_num
        name = self.plugin_name
        i = 2
        # Repeated plugins get distinct registry names: Name, Name2, Name3...
        # (for iterative plugins, only registered on the first iteration).
        if not self._iterative_group:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1
        elif self._iterative_group._ip_iteration == 0:
            while name in list(Statistics.plugin_numbers.keys()):
                name = self.plugin_name + str(i)
                i += 1

        if p_num not in list(Statistics.plugin_names.keys()):
            Statistics.plugin_names[p_num] = name
        Statistics.plugin_numbers[name] = p_num
        # Only record anything if some slice stats were actually gathered.
        if len(combined_stats['max']) != 0:
            stats_dict = self.calc_volume_stats(combined_stats)
            Statistics.global_residuals[p_num] = {}
            #before_processing = self.calc_volume_stats(self.stats_before_processing)
            #for key in list(before_processing.keys()):
            #    Statistics.global_residuals[p_num][key] = Statistics.global_stats[p_num][key] - before_processing[key]

            if len(Statistics.global_stats[p_num]) == 0:
                Statistics.global_stats[p_num] = [stats_dict]
            else:
                Statistics.global_stats[p_num].append(stats_dict)

            self._link_stats_to_datasets(stats_dict, self._iterative_group)
            self._write_stats_to_file(p_num, comm=comm)
        self._already_called = True
        self._repeat_count += 1
        if self._iterative_group or self._4d:
            # Reset slice stats ready for the next iteration/volume.
            self.stats = {stat: [] for stat in self.slice_stats_key}
    def start_time(self):
        """Called at the start of a plugin."""
        # Wall-clock start time; stop_time() uses it to compute elapsed time.
        self.t0 = time.time()
    def stop_time(self):
        """Called at the end of a plugin."""
        self.t1 = time.time()
        # Elapsed wall-clock time, rounded to 0.1 s.
        elapsed = round(self.t1 - self.t0, 1)
        if self._stats_flag and self.calc_stats:
            self.set_time(elapsed)
    def set_time(self, seconds):
        """Sets time taken for plugin to complete.

        :param seconds: elapsed time for this run of the plugin.
        """
        Statistics.global_times[self.p_num] += seconds  # Gives total time for a plugin in a loop
        #print(f"{self.p_num}, {seconds}")
        comm = self.plugin.get_communicator()
        try:
            # Probe the communicator - accessing .rank raises if it is invalid.
            rank = comm.rank
        except MPI.Exception:        # Sometimes get_communicator() returns an invalid communicator.
            comm = MPI.COMM_WORLD    # So using COMM_WORLD in this case.
        self._write_times_to_file(comm)
    def _combine_mpi_stats(self, slice_stats, comm=MPI.COMM_WORLD):
438
        """Combines slice stats from different processes, so volume stats can be calculated.
439
440
        :param slice_stats: slice stats (each process will have a different set).
441
        :param comm: MPI communicator being used.
442
        """
443
        combined_stats_list = comm.allgather(slice_stats)
444
        combined_stats = {stat: [] for stat in self.slice_stats_key}
445
        for single_stats in combined_stats_list:
446
            for key in self.slice_stats_key:
447
                combined_stats[key] += single_stats[key]
448
        return combined_stats
449
450
    def _array_to_dict(self, stats_array, key_list=None):
451
        """Converts an array of stats to a dictionary of stats.
452
453
        :param stats_array: Array of stats to be converted.
454
        :param key_list: List of keys indicating the names of the stats in the stats_array.
455
        """
456
        if key_list is None:
457
            key_list = self.stats_key
458
        stats_dict = {}
459
        for i, value in enumerate(stats_array):
460
            stats_dict[key_list[i]] = value
461
        return stats_dict
462
463
    def _dict_to_array(self, stats_dict):
464
        """Converts stats dict into a numpy array (keys will be lost).
465
466
        :param stats_dict: dictionary of stats.
467
        """
468
        return np.array(list(stats_dict.values()))
469
470
    def _broadcast_gpu_stats(self, gpu_processes, process):
        """During GPU plugins, most processes are unused, and don't have access to stats.
        This method shares stats between processes so all have access to stats.

        :param gpu_processes: List that determines whether a process is a GPU process.
        :param process: Process number.
        """
        p_num = self.p_num
        # Root process broadcasts its stats to every process.
        Statistics.global_stats[p_num] = MPI.COMM_WORLD.bcast(Statistics.global_stats[p_num], root=0)
        if not gpu_processes[process]:
            # Non-GPU processes re-link the received stats to their datasets.
            # NOTE(review): this treats global_stats[p_num] as a numpy array
            # (.ndim), while elsewhere in this file it is a list of dicts -
            # confirm which form is actually broadcast here.
            if Statistics.global_stats[p_num].ndim == 1:
                stats_dict = self._array_to_dict(Statistics.global_stats[p_num])
                self._link_stats_to_datasets(stats_dict, self._iterative_group)
            elif Statistics.global_stats[p_num].ndim > 1:
                # One row per set of stats (e.g. per iteration).
                for stats_array in Statistics.global_stats[p_num]:
                    stats_dict = self._array_to_dict(stats_array)
                    self._link_stats_to_datasets(stats_dict, self._iterative_group)
    def _set_pattern_info(self):
489
        """Gathers information about the pattern of the data in the current plugin."""
490
        out_datasets = self.plugin.get_out_datasets()
491
        if len(out_datasets) == 0:
492
            self.calc_stats = False
493
        try:
494
            self.pattern = self.plugin.parameters['pattern']
495
            if self.pattern == None:
496
                raise KeyError
497
        except KeyError:
498
            if not out_datasets:
499
                self.pattern = None
500
            else:
501
                patterns = out_datasets[0].get_data_patterns()
502
                for pattern in patterns:
503
                    if 1 in patterns.get(pattern)["slice_dims"]:
504
                        self.pattern = pattern
505
                        break
506
                    self.pattern = None
507
        if self.pattern not in Statistics._pattern_list:
508
            self.calc_stats = False
509
510
    def _link_stats_to_datasets(self, stats_dict, iterative=False):
511
        """Links the volume wide statistics to the output dataset(s).
512
513
        :param stats_dict: Dictionary of stats being linked.
514
        :param iterative: boolean indicating if the plugin is iterative or not.
515
        """
516
        out_dataset = self.plugin.get_out_datasets()[0]
517
        my_dataset = out_dataset
518
        if iterative:
519
            if "itr_clone" in out_dataset.group_name:
520
                my_dataset = list(iterative._ip_data_dict["iterating"].keys())[0]
521
        n_datasets = self.plugin.nOutput_datasets()
522
523
        i = 2
524
        group_name = "stats"
525
        while group_name in list(my_dataset.meta_data.get_dictionary().keys()):
526
            group_name = f"stats{i}"  # If more than one set of stats for a plugin (such as iterative plugin)
527
            i += 1                    # the groups will be named stats, stats2, stats3 etc.
528
        for key, value in stats_dict.items():
529
            my_dataset.meta_data.set([group_name, key], value)
530
531
    def _write_stats_to_file(self, p_num=None, plugin_name=None, comm=MPI.COMM_WORLD):
        """Writes stats to a h5 file. This file is used to create figures and tables from the stats.

        :param p_num: The plugin number of the plugin the stats belong to (usually left as None except
            for special cases).
        :param plugin_name: Same as above (but for the name of the plugin).
        :param comm: The MPI communicator the plugin is using.
        """
        if p_num is None:
            p_num = self.p_num
        if plugin_name is None:
            plugin_name = self.plugin_names[p_num]
        path = Statistics.path
        filename = f"{path}/stats.h5"
        # Stack one row of stats per recorded instance (iteration/parameter set).
        stats_dict = self.get_stats(p_num, instance="all")
        stats_array = self._dict_to_array(stats_dict[0])
        stats_key = list(stats_dict[0].keys())
        for i, my_dict in enumerate(stats_dict):
            if i != 0:
                stats_array = np.vstack([stats_array, self._dict_to_array(my_dict)])
        self.hdf5 = Hdf5Utils(self.exp)
        self.exp._barrier(communicator=comm)
        if comm.rank == 0:  # only the root process writes the file
            with h5.File(filename, "a") as h5file:
                group = h5file.require_group("stats")
                if stats_array.shape != (0,):
                    # Replace any existing stats dataset for this plugin.
                    if str(p_num) in list(group.keys()):
                        del group[str(p_num)]
                    dataset = group.create_dataset(str(p_num), shape=stats_array.shape, dtype=stats_array.dtype)
                    dataset[::] = stats_array[::]
                    dataset.attrs.create("plugin_name", plugin_name)
                    dataset.attrs.create("pattern", self.pattern)
                    dataset.attrs.create("stats_key", stats_key)
                if self._iterative_group:
                    l_stats = Statistics.loop_stats[self.l_num]
                    group1 = h5file.require_group("iterative")
                    # Loop stats are written once: after the final iteration of
                    # the final plugin in the loop.
                    if self._iterative_group._ip_iteration == self._iterative_group._ip_fixed_iterations - 1\
                            and self.p_num == self._iterative_group.end_index:
                        dataset1 = group1.create_dataset(str(self.l_num), shape=l_stats["NRMSD"].shape, dtype=l_stats["NRMSD"].dtype)
                        dataset1[::] = l_stats["NRMSD"][::]
                        loop_plugins = []
                        for i in range(self._iterative_group.start_index, self._iterative_group.end_index + 1):
                            if i in list(self.plugin_names.keys()):
                                loop_plugins.append(self.plugin_names[i])
                        dataset1.attrs.create("loop_plugins", loop_plugins)
                        # Fix: previously written to `dataset`, which is
                        # undefined when stats_array is empty; the loop-plugin
                        # count belongs with the other loop attrs on dataset1.
                        dataset1.attrs.create("n_loop_plugins", len(loop_plugins))
        self.exp._barrier(communicator=comm)
    def _write_times_to_file(self, comm):
        """Writes times into the file containing all the stats.

        :param comm: MPI communicator (only rank 0 writes).
        """
        p_num = self.p_num
        path = Statistics.path
        filename = f"{path}/stats.h5"
        # Renamed from `time` so the local no longer shadows the time module;
        # also dropped the unused `plugin_name` local.
        elapsed = Statistics.global_times[p_num]
        self.hdf5 = Hdf5Utils(self.exp)
        if comm.rank == 0:
            with h5.File(filename, "a") as h5file:
                group = h5file.require_group("stats")
                dataset = group[str(p_num)]
                dataset.attrs.create("time", elapsed)
    def write_slice_stats_to_file(self, slice_stats=None, p_num=None, comm=MPI.COMM_WORLD):
        """Writes slice statistics to a h5 file. Placed in the stats folder in the output directory. Currently unused.

        :param slice_stats: slice stats to write (defaults to this plugin's).
        :param p_num: plugin number the stats belong to (defaults to current).
        :param comm: MPI communicator used for the parallel write.
        """
        if not slice_stats:
            slice_stats = self.stats
        if not p_num:
            p_num = self.count
            plugin_name = self.plugin_name
        else:
            plugin_name = self.plugin_names[p_num]
        combined_stats = self._combine_mpi_stats(slice_stats)
        slice_stats_arrays = {}
        datasets = {}
        path = Statistics.path
        filename = f"{path}/stats_p{p_num}_{plugin_name}.h5"
        self.hdf5 = Hdf5Utils(self.plugin.exp)
        # Parallel (mpio) write: every process participates in dataset
        # creation, which is a collective operation.
        with h5.File(filename, "a", driver="mpio", comm=comm) as h5file:
            i = 2
            group_name = "/stats"
            # Groups are named /stats, /stats2, ... if written more than once.
            while group_name in h5file:
                group_name = f"/stats{i}"
                i += 1
            group = h5file.create_group(group_name, track_order=None)
            for key in list(combined_stats.keys()):
                slice_stats_arrays[key] = np.array(combined_stats[key])
                datasets[key] = self.hdf5.create_dataset_nofill(group, key, (len(slice_stats_arrays[key]),), slice_stats_arrays[key].dtype)
                datasets[key][::] = slice_stats_arrays[key]
    def _unpad_slice(self, my_slice):
621
        """If data is padded in the slice dimension, removes this pad."""
622
        out_datasets = self.plugin.get_out_datasets()
623
        if len(out_datasets) == 1:
624
            out_dataset = out_datasets[0]
625
        else:
626
            for dataset in out_datasets:
627
                if self.pattern in list(dataset.data_info.get(["data_patterns"]).keys()):
628
                    out_dataset = dataset
629
                    break
630
        slice_dims = out_dataset.get_slice_dimensions()
0 ignored issues
show
introduced by
The variable out_dataset does not seem to be defined for all execution paths.
Loading history...
631
        if self.plugin.pcount == 0:
632
            self._slice_list, self._pad = self._get_unpadded_slice_list(my_slice, slice_dims)
633
        if self._pad:
634
            #for slice_dim in slice_dims:
635
            slice_dim = slice_dims[0]
636
            temp_slice = np.swapaxes(my_slice, 0, slice_dim)
637
            temp_slice = temp_slice[self._slice_list[slice_dim]]
638
            my_slice = np.swapaxes(temp_slice, 0, slice_dim)
639
        return my_slice
640
641
    def _get_unpadded_slice_list(self, my_slice, slice_dims):
642
        """Creates slice object(s) to un-pad slices in the slice dimension(s)."""
643
        slice_list = list(self.plugin.slice_list[0])
644
        pad = False
645
        if len(slice_list) == len(my_slice.shape):
646
            i = slice_dims[0]
647
            slice_width = self.plugin.slice_list[0][i].stop - self.plugin.slice_list[0][i].start
648
            if slice_width < my_slice.shape[i]:
649
                pad = True
650
                pad_width = (my_slice.shape[i] - slice_width) // 2  # Assuming symmetrical padding
651
                slice_list[i] = slice(pad_width, pad_width + 1, 1)
652
            return tuple(slice_list), pad
653
        else:
654
            return self.plugin.slice_list[0], pad
655
656
    def _flatten(self, l):
657
        """Function to flatten nested lists."""
658
        out = []
659
        for item in l:
660
            if isinstance(item, (list, tuple)):
661
                out.extend(self._flatten(item))
662
            else:
663
                out.append(item)
664
        return out
665
666
    def _de_list(self, my_slice):
667
        """If the slice is in a list, remove it from that list (takes 0th element)."""
668
        if type(my_slice) == list:
669
            if len(my_slice) != 0:
670
                my_slice = my_slice[0]
671
                my_slice = self._de_list(my_slice)
672
        return my_slice
673
674
    @classmethod
675
    def _count(cls):
676
        cls.count += 1
677
678
    @classmethod
679
    def _post_chain(cls):
680
        """Called after all plugins have run."""
681
        if cls._any_stats & cls._stats_flag:
682
            stats_utils = StatsUtils()
683
            stats_utils.generate_figures(f"{cls.path}/stats.h5", cls.path)
684