savu.data.plugin_list.PluginList.__init__()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 12
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 12
nop 1
dl 0
loc 12
rs 9.8
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: plugin_list
   :platform: Unix
   :synopsis: Contains the PluginList class, which deals with loading and \
   saving the plugin list, and the CitationInformation class. An instance is \
   held by the MetaData class.

.. moduleauthor:: Nicola Wadeson <[email protected]>
23
24
"""
25
import ast
26
import copy
27
import inspect
28
import json
29
import logging
30
import os
31
import re
32
from collections import defaultdict
33
34
import h5py
35
import numpy as np
36
37
import savu.data.framework_citations as fc
38
import savu.plugins.docstring_parser as doc
39
import savu.plugins.loaders.utils.yaml_utils as yu
40
import savu.plugins.utils as pu
41
from savu.data.meta_data import MetaData
42
43
# Name of the HDF5 attribute that records a group's NeXus class
# (e.g. NXprocess, NXnote); set by _overwrite_group and friends.
NX_CLASS = "NX_class"
44
45
46
class PluginList(object):
47
    """
48
    The PluginList class handles the plugin list - loading, saving and adding
49
    citation information for the plugin
50
    """
51
52
    def __init__(self):
53
        self.plugin_list = []
54
        self.n_plugins = None
55
        self.n_loaders = 0
56
        self.n_savers = 0
57
        self.loader_idx = None
58
        self.saver_idx = None
59
        self.datasets_list = []
60
        self.saver_plugin_status = True
61
        self._template = None
62
        self.version = None
63
        self.iterate_plugin_groups = []
64
65
    def add_template(self, create=False):
        """Attach a fresh Template helper to this plugin list.

        :param create: if True, mark the template as being created, which
            makes _save_plugin_list also write a ``.savu`` template file.
        """
        template = Template(self)
        if create:
            template.creating = True
        self._template = template
69
70
    def _get_plugin_entry_template(self):
71
        template = {"active": True, "name": None, "id": None, "data": None}
72
        return template
73
74
    def __get_json_keys(self):
75
        return ["data"]
76
77
    def _populate_plugin_list(
        self, filename, active_pass=False, template=False
    ):
        """ Populate the plugin list from a nexus file.

        Plugin entries are read from ``entry/plugin`` in the order given by
        pu.sort_alphanum on their group names. Inactive entries are skipped
        unless *active_pass* is True. Iterative loop definitions are then
        read from ``entry/iterate_plugin_groups`` when that path exists.

        :param filename: path of the process list file to read
        :param active_pass: if True, also load entries whose 'active' flag
            is False
        :param template: if truthy, passed to the Template helper's
            update_process_list after loading
        :raises OSError: if a plugin class has no tools file
        :raises ValueError: if a stored parameter cannot be parsed as JSON
        """
        with h5py.File(filename, "r") as plugin_file:
            if "entry/savu_notes/version" in plugin_file:
                self.version = plugin_file["entry/savu_notes/version"][()]
                self._show_process_list_version()

            plugin_group = plugin_file["entry/plugin"]
            self.plugin_list = []
            single_val = ["name", "id", "pos", "active"]
            exclude = ["citation"]
            ordered_pl_keys = pu.sort_alphanum(list(plugin_group.keys()))
            for group in ordered_pl_keys:
                plugin = self._get_plugin_entry_template()
                entry_keys = plugin_group[group].keys()
                # Parameters are the keys that are neither single-valued
                # fields nor excluded (citation) groups. all() keeps each key
                # exactly once even if more exclusions are added.
                parameters = [
                    k for k in entry_keys
                    if k not in single_val
                    and all(e not in k for e in exclude)
                ]

                if "active" in entry_keys:
                    plugin["active"] = plugin_group[group]["active"][0]

                if plugin['active'] or active_pass:
                    plugin['name'] = plugin_group[group]['name'][0].decode("utf-8")
                    plugin['id'] = plugin_group[group]['id'][0].decode("utf-8")
                    plugin_tools = None
                    try:
                        plugin_class = pu.load_class(plugin["id"])()
                        # Populate the parameters (including those from its
                        # base classes)
                        plugin_tools = plugin_class.get_plugin_tools()
                        if not plugin_tools:
                            raise OSError(f"Tools file not found for {plugin['name']}")
                        plugin_tools._populate_default_parameters()
                    except ImportError:
                        # No plugin class found: the entry is kept, but
                        # without tools, docs or parameter definitions.
                        logging.error(f"No class found for {plugin['name']}")

                    plugin['doc'] = plugin_tools.docstring_info if plugin_tools else ""
                    plugin['tools'] = plugin_tools if plugin_tools else {}
                    plugin['param'] = plugin_tools.get_param_definitions() if \
                        plugin_tools else {}
                    plugin['pos'] = group.strip()

                    for param in parameters:
                        try:
                            plugin[param] = json.loads(plugin_group[group][param][0])
                        except ValueError as e:
                            raise ValueError(
                                f"Error: {e}\n Could not parse key '{param}' from group '{group}' as JSON"
                            ) from e
                    self.plugin_list.append(plugin)

            # add info about groups of plugins to iterate over into
            # self.iterate_plugin_groups
            self._read_iterate_plugin_groups(plugin_file, filename)

            if template:
                self.add_template()
                self._template.update_process_list(template)

    def _read_iterate_plugin_groups(self, plugin_file, filename):
        """Reset self.iterate_plugin_groups and refill it from the
        ``entry/iterate_plugin_groups`` group of an open process list file.

        A missing path is reported, and no loops are loaded. Errors inside an
        existing group are no longer masked as "missing path".

        :param plugin_file: open h5py file of the process list
        :param filename: path of that file (used in the notice only)
        """
        self.clear_iterate_plugin_group_dicts()
        path = 'entry/iterate_plugin_groups'
        if path not in plugin_file:
            print(f"Process list file {filename} doesn't have the "
                  f"iterate_plugin_groups internal hdf5 path")
            return
        iterate_groups = plugin_file[path]
        # Loops are saved as groups "0", "1", ... and h5py lists members by
        # name, so "10" would precede "2"; sort them like the plugin groups
        # so loop numbering survives a save/load round trip.
        for key in pu.sort_alphanum(list(iterate_groups.keys())):
            loop = iterate_groups[key]
            self.iterate_plugin_groups.append({
                'start_index': loop['start'][()],
                'end_index': loop['end'][()],
                'iterations': loop['iterations'][()],
            })
154
155
    def _show_process_list_version(self):
        """Print a warning banner if the loaded process list was created with
        a different version of Savu than the one running.

        h5py >= 3 returns stored strings as ``bytes``. The stored version is
        therefore decoded before comparing; otherwise the comparison is always
        unequal, and the banner shows ``b'...'``.
        """
        from savu.version import __version__
        version = self.version
        if isinstance(version, bytes):
            version = version.decode("utf-8")
        if __version__ != version:
            separator = "*" * 53
            print(separator)
            print(f"*** This process list was created using Savu "
                  f"{version}  ***")
            print(separator)
165
166
    def _save_plugin_list(self, out_filename):
        """Write the plugin list to ``entry`` of the nexus file *out_filename*.

        The file is opened in append mode. The method (re)creates the
        framework_citations, savu_notes, plugin and iterate_plugin_groups
        groups. If a template is being created, a ``.savu`` file with the same
        base name is written afterwards.

        :param out_filename: path of the nexus file to write to
        """
        with h5py.File(out_filename, "a") as nxs_file:
            entry = nxs_file.require_group("entry")

            self._save_framework_citations(self._overwrite_group(
                entry, 'framework_citations', 'NXcollection'))

            self.__save_savu_notes(self._overwrite_group(
                entry, 'savu_notes', 'NXnote'))

            plugins_group = self._overwrite_group(entry, 'plugin', 'NXprocess')

            # count is the 1-based fallback group name used for entries
            # without a "pos" key, so it must advance for every plugin or
            # two such entries would collide on the same group name.
            for count, plugin in enumerate(self.plugin_list, 1):
                plugin_group = self._get_plugin_group(
                    plugins_group, plugin, count
                )
                self.__populate_plugins_group(plugin_group, plugin)

            self.__save_iterate_plugin_groups(self._overwrite_group(
                entry, 'iterate_plugin_groups', 'NXnote'))

        if self._template and self._template.creating:
            fname = os.path.splitext(out_filename)[0] + ".savu"
            self._template._output_template(fname, out_filename)
192
193
    def _overwrite_group(self, entry, name, nxclass):
        """Replace the child group *name* of *entry* with a new empty group.

        :param entry: parent hdf5 group
        :param name: name of the child group
        :param nxclass: NeXus class written to the new group's NX_class
            attribute
        :return: the newly created group
        """
        if name in entry:
            del entry[name]
        new_group = entry.create_group(name.encode("ascii"))
        new_group.attrs[NX_CLASS] = nxclass.encode("ascii")
        return new_group
199
200
    def add_iterate_plugin_group(self, start, end, iterations):
201
        """Add an element to self.iterate_plugin_groups"""
202
        group_new = {
203
            'start_index': start,
204
            'end_index': end,
205
            'iterations': iterations
206
        }
207
208
        if iterations <= 0:
209
            print("The number of iterations should be larger than zero and nonnegative")
210
            return
211
        elif start <= 0 or start > len(self.plugin_list) or end <= 0 or end > len(self.plugin_list):
212
            print("The given plugin indices are not within the range of existing plugin indices")
213
            return
214
215
        are_crosschecks_ok = \
216
            self._crosscheck_existing_loops(start, end, iterations)
217
218
        if are_crosschecks_ok:
219
            self.iterate_plugin_groups.append(group_new)
220
            info_str = f"The following loop has been added: start plugin " \
221
                       f"index {group_new['start_index']}, end plugin index " \
222
                       f"{group_new['end_index']}, iterations " \
223
                       f"{group_new['iterations']}"
224
            print(info_str)
225
226
    def _crosscheck_existing_loops(self, start, end, iterations):
227
        """
228
        Check requested loop to be added against existing loops for potential
229
        clashes
230
        """
231
        are_crosschecks_ok = False
232
        list_new_indices = list(range(start, end+1))
233
        # crosscheck with the existing iterative loops
234
        if len(self.iterate_plugin_groups) != 0:
235
            noexactlist = True
236
            nointersection = True
237
            for count, group in enumerate(self.iterate_plugin_groups, 1):
238
                start_int = int(group['start_index'])
239
                end_int = int(group['end_index'])
240
                list_existing_indices = list(range(start_int, end_int+1))
241
                if bool(set(list_new_indices).intersection(list_existing_indices)):
242
                    # check if the intersection of lists is exact (number of iterations to change)
243
                    nointersection = False
244
                    if list_new_indices == list_existing_indices:
245
                        print(f"The number of iterations of loop group no. {count}, {list_new_indices} has been set to: {iterations}")
246
                        self.iterate_plugin_groups[count-1]["iterations"] = iterations
247
                        noexactlist = False
248
                    else:
249
                        print(f"The plugins of group no. {count} are already set to be iterative: {set(list_new_indices).intersection(list_existing_indices)}")
250
            if noexactlist and nointersection:
251
                are_crosschecks_ok = True
252
        else:
253
            are_crosschecks_ok = True
254
255
        return are_crosschecks_ok
256
257
    def remove_iterate_plugin_groups(self, indices):
258
        """ Remove elements from self.iterate_plugin_groups """
259
        if len(indices) == 0:
260
            # remove all iterative loops in process list
261
            prompt_str = 'Are you sure you want to remove all iterative ' \
262
                'loops? [y/N]'
263
            check = input(prompt_str)
264
            should_remove_all = check.lower() == 'y'
265
            if should_remove_all:
266
                self.clear_iterate_plugin_group_dicts()
267
                print('All iterative loops have been removed')
268
            else:
269
                print('No iterative loops have been removed')
270
        else:
271
            # remove specified iterative loops in process list
272
            sorted_indices = sorted(indices)
273
274
            if sorted_indices[0] <= 0:
275
                print('The iterative loops are indexed starting from 1')
276
                return
277
278
            for i in reversed(sorted_indices):
279
                try:
280
                    # convert the one-based index to a zero-based index
281
                    iterate_group = self.iterate_plugin_groups.pop(i - 1)
282
                    info_str = f"The following loop has been removed: start " \
283
                               f"plugin index " \
284
                               f"{iterate_group['start_index']}, " \
285
                               f"end plugin index " \
286
                               f"{iterate_group['end_index']}, iterations " \
287
                               f"{iterate_group['iterations']}"
288
                    print(info_str)
289
                except IndexError as e:
290
                    info_str = f"There doesn't exist an iterative loop with " \
291
                               f"number {i}"
292
                    print(info_str)
293
294
    def clear_iterate_plugin_group_dicts(self):
295
        """
296
        Reset the list of dicts representing groups of plugins to iterate over
297
        """
298
        self.iterate_plugin_groups = []
299
300
    def get_iterate_plugin_group_dicts(self):
301
        """
302
        Return the list of dicts representing groups of plugins to iterate over
303
        """
304
        return self.iterate_plugin_groups
305
306
    def print_iterative_loops(self):
307
        if len(self.iterate_plugin_groups) == 0:
308
            print('There are no iterative loops in the current process list')
309
        else:
310
            print('Iterative loops in the current process list are:')
311
            for count, group in enumerate(self.iterate_plugin_groups, 1):
312
                number = f"({count}) "
313
                start_str = f"start plugin index: {group['start_index']}"
314
                end_str = f"end index: {group['end_index']}"
315
                iterations_str = f"iterations number: {group['iterations']}"
316
                full_str = number + start_str + ', ' + end_str + ', ' + \
317
                    iterations_str
318
                print(full_str)
319
320
    def remove_associated_iterate_group_dict(self, pos, direction):
321
        """
322
        Remove an iterative loop associated to a plugin index
323
        """
324
        operation = 'add' if direction == 1 else 'remove'
325
        for i, iterate_group in enumerate(self.iterate_plugin_groups):
326
            if operation == 'remove':
327
                if iterate_group['start_index'] <= pos and \
328
                    pos <= iterate_group['end_index']:
329
                    # remove the loop if the plugin being removed is at any
330
                    # position within an iterative loop
331
                    del self.iterate_plugin_groups[i]
332
                    break
333
            elif operation == 'add':
334
                if iterate_group['start_index'] != iterate_group['end_index']:
335
                    # remove the loop only if the plugin is being added between
336
                    # the start and end of the loop
337
                    if iterate_group['start_index'] < pos and \
338
                        pos <= iterate_group['end_index']:
339
                        del self.iterate_plugin_groups[i]
340
                        break
341
342
    def check_pos_in_iterative_loop(self, pos):
343
        """
344
        Check if the given plugin position is in an iterative loop
345
        """
346
        is_in_loop = False
347
        for iterate_group in self.iterate_plugin_groups:
348
            if iterate_group['start_index'] <= pos and \
349
                pos <= iterate_group['end_index']:
350
                is_in_loop = True
351
                break
352
353
        return is_in_loop
354
355
    def __save_iterate_plugin_groups(self, group):
        '''
        Save each iterative loop as a numbered subgroup ("0", "1", ...) of
        *group* containing scalar int datasets start, end and iterations.
        '''
        scalar_shape = ()
        fields = (('start', 'start_index'),
                  ('end', 'end_index'),
                  ('iterations', 'iterations'))
        for number, loop in enumerate(self.iterate_plugin_groups):
            subgroup = group.create_group(str(number).encode('ascii'))
            for dset_name, key in fields:
                subgroup.create_dataset(dset_name.encode('ascii'),
                                        scalar_shape, 'i', loop[key])
369
370
    def shift_subsequent_iterative_loops(self, pos, direction):
371
        """
372
        Shift all iterative loops occurring after a given plugin position
373
        """
374
        # if removing a plugin that is positioned before a loop, the loop should
375
        # be shifted down by 1; but if removing a plugin that is positioned at
376
        # the start of the loop, it will be removed instead of shifted (ie, both
377
        # < or <= work for this case)
378
        #
379
        # if adding a plugin that will be positioned before a loop, the loop
380
        # should be shifted up by 1; also, if adding a plugin to be positioned
381
        # where the start of a loop currently exists, this should shift the loop
382
        # up by 1 as well (ie, only <= works for this case, hence the use of <=)
383
        for iterate_group in self.iterate_plugin_groups:
384
            if pos <= iterate_group['start_index']:
385
                self.shift_iterative_loop(iterate_group, direction)
386
387
    def shift_range_iterative_loops(self, positions, direction):
388
        """
389
        Shift all iterative loops within a range of plugin indices
390
        """
391
        for iterate_group in self.iterate_plugin_groups:
392
            if positions[0] <= iterate_group['start_index'] and \
393
                iterate_group['end_index'] <= positions[1]:
394
                self.shift_iterative_loop(iterate_group, direction)
395
396
    def shift_iterative_loop(self, iterate_group, direction):
397
        """
398
        Shift an iterative loop up or down in the process list, based on if a
399
        plugin is added or removed
400
        """
401
        if direction == 1:
402
            iterate_group['start_index'] += 1
403
            iterate_group['end_index'] += 1
404
        elif direction == -1:
405
            iterate_group['start_index'] -= 1
406
            iterate_group['end_index'] -= 1
407
        else:
408
            err_str = f"Bad direction value given to shift iterative loop: " \
409
                      f"{direction}"
410
            raise ValueError(err_str)
411
412
    def __save_savu_notes(self, notes):
        """Record the running Savu version under the "version" key.

        :param notes: hdf5 group to save data to
        """
        from savu.version import __version__ as savu_version
        notes["version"] = savu_version
420
421
    def __populate_plugins_group(self, plugin_group, plugin):
        """Write one plugin entry's fields into *plugin_group*.

        The plugin's citations are saved first. Then every key of the entry
        template is stored as a one-element dataset. Keys listed by
        __get_json_keys are JSON-encoded ("data" values are first passed
        through pu._dumps), and str values are ASCII-encoded so hdf5 stores
        them as bytes.

        :param plugin_group: Plugin group to save to
        :param plugin: Plugin to be saved
        """
        plugin_group.attrs[NX_CLASS] = "NXnote".encode("ascii")
        template_keys = self._get_plugin_entry_template()
        json_keys = self.__get_json_keys()

        self._save_citations(plugin, plugin_group)

        for key in template_keys:
            if key in json_keys:
                raw = plugin[key]
                if key == "data":
                    # normalise each parameter value before JSON encoding
                    raw = {name: pu._dumps(val) for name, val in raw.items()}
                value = json.dumps(raw)
            else:
                value = plugin[key]
            if isinstance(value, str):
                value = value.encode("ascii")
            array = np.array([value])
            plugin_group.create_dataset(
                key.encode("ascii"), array.shape, array.dtype, array
            )
453
454
    def _get_plugin_group(self, plugins_group, plugin, count):
455
        """Return the plugin_group, into which the plugin information
456
         will be saved
457
458
        :param plugins_group: Current group to save inside
459
        :param plugin: Plugin to be saved
460
        :param count: Order number of the plugin in the process list
461
        :return: plugin group
462
        """
463
        if "pos" in plugin.keys():
464
            num = int(re.findall(r"\d+", str(plugin["pos"]))[0])
465
            letter = re.findall("[a-z]", str(plugin["pos"]))
466
            letter = letter[0] if letter else ""
467
            group_name = "%i%s" % (num, letter)
468
        else:
469
            group_name = count
470
        return plugins_group.create_group(group_name.encode("ascii"))
471
472
    def _add(self, idx, entry):
        """Insert *entry* at list position *idx*, then recompute the
        loader/saver positions and counts."""
        plugins = self.plugin_list
        plugins.insert(idx, entry)
        self.__set_loaders_and_savers()
475
476
    def _remove(self, idx):
        """Delete the entry at list position *idx*, then recompute the
        loader/saver positions and counts."""
        plugins = self.plugin_list
        del plugins[idx]
        self.__set_loaders_and_savers()
479
480
    def _save_citations(self, plugin, group):
481
        """Save all the citations in the plugin
482
483
        :param plugin: dictionary of plugin information
484
        :param group: Group to save to
485
        """
486
        if "tools" in plugin.keys():
487
            citation_plugin = plugin.get("tools").get_citations()
488
            if citation_plugin:
489
                count = 1
490
                for citation in citation_plugin.values():
491
                    group_label = f"citation{count}"
492
                    if (
493
                        not citation.dependency
494
                        or self._dependent_citation_used(plugin, citation)
495
                    ):
496
                        self._save_citation_group(
497
                            citation, citation.__dict__, group, group_label
498
                        )
499
                        count += 1
500
501
    def _dependent_citation_used(self, plugin, citation):
502
        """Check if any Plugin parameter values match the citation
503
        dependency requirement.
504
505
        :param plugin: dictionary of plugin information
506
        :param citation: A plugin citation
507
        :return: bool True if the citation is for a parameter value
508
            being used inside this plugin
509
        """
510
        parameters = plugin["data"]
511
        for (
512
            citation_dependent_parameter,
513
            citation_dependent_value,
514
        ) in citation.dependency.items():
515
            current_value = parameters[citation_dependent_parameter]
516
            if current_value == citation_dependent_value:
517
                return True
518
        return False
519
520
    def _exec_citations(self, cite, citation):
521
        """Execute citations to variable
522
        :param cite: citation dictionary
523
        """
524
        for key, value in cite.items():
525
            exec("citation." + key + "= value")
526
527
    def _save_framework_citations(self, group):
        """Save every framework citation as its own subgroup of *group*.

        Each citation's ``short_name_article`` becomes the subgroup label and
        is deleted from the citation object before saving. Presumably this
        keeps it out of the written fields; confirm against citation.write.

        :param group: Group for nxs file
        """
        for framework_cite in fc.get_framework_citations().values():
            label = framework_cite.short_name_article
            del framework_cite.short_name_article
            self._save_citation_group(
                framework_cite, framework_cite.__dict__, group, label)
537
538
539
    def _save_citation_group(self, citation, cite_dict, group, group_label):
        """Write *citation* into the child group *group_label* of *group*.

        The child group is created if missing. The entries of *cite_dict* are
        copied onto *citation* before it writes itself.

        :param citation: Citation object
        :param cite_dict: Citation dictionary
        :param group: Group
        :param group_label: Group label
        """
        target = group.require_group(group_label.encode("ascii"))
        self._exec_citations(cite_dict, citation)
        citation.write(target)
551
552
    def _get_docstring_info(self, plugin):
        """Instantiate the plugin registered under *plugin* in pu.plugins,
        populate its tools' default parameters and return the instance's
        docstring_info.

        NOTE(review): elsewhere in this module docstring_info is read from
        the tools object (e.g. ``plugin_tools.docstring_info``); confirm the
        plugin instance exposes the same attribute.
        """
        instance = pu.plugins[plugin]()
        instance.get_plugin_tools()._populate_default_parameters()
        return instance.docstring_info
557
558
    # def _byteify(self, input):
559
    #     if isinstance(input, dict):
560
    #         return {self._byteify(key): self._byteify(value)
561
    #                 for key, value in input.items()}
562
    #     elif isinstance(input, list):
563
    #         temp = [self._byteify(element) for element in input]
564
    #         return temp
565
    #     elif isinstance(input, str):
566
    #         return input.encode('utf-8')
567
    #     else:
568
    #         return input
569
570
    def _set_datasets_list(self, plugin):
571
        in_pData, out_pData = plugin.get_plugin_datasets()
572
        in_data_list = self._populate_datasets_list(in_pData)
573
        out_data_list = self._populate_datasets_list(out_pData)
574
        self.datasets_list.append({'in_datasets': in_data_list,
575
                                   'out_datasets': out_data_list})
576
577
    def _populate_datasets_list(self, data):
578
        data_list = []
579
        for d in data:
580
            name = d.data_obj.get_name()
581
            pattern = copy.deepcopy(d.get_pattern())
582
            pattern[list(pattern.keys())[0]]['max_frames_transfer'] = \
0 ignored issues
show
Comprehensibility Best Practice introduced by
The variable list does not seem to be defined.
Loading history...
583
                d.meta_data.get('max_frames_transfer')
584
            pattern[list(pattern.keys())[0]]['transfer_shape'] = \
585
                d.meta_data.get('transfer_shape')
586
            data_list.append({'name': name, 'pattern': pattern})
587
        return data_list
588
589
    def _get_datasets_list(self):
590
        return self.datasets_list
591
592
    def _reset_datasets_list(self):
593
        self.datasets_list = []
594
595
    def _get_n_loaders(self):
596
        return self.n_loaders
597
598
    def _get_n_savers(self):
599
        return self.n_savers
600
601
    def _get_loaders_index(self):
602
        return self.loader_idx
603
604
    def _get_savers_index(self):
605
        return self.saver_idx
606
607
    def _get_n_processing_plugins(self):
608
        return len(self.plugin_list) - self._get_n_loaders()
609
610
    def __set_loaders_and_savers(self):
        """Recompute where loaders and savers sit in the plugin list.

        A plugin counts as a loader (saver) when BaseLoader (BaseSaver)
        appears in its class MRO. The method updates n_plugins, loader_idx,
        saver_idx, n_loaders and n_savers, and returns nothing.
        """
        from savu.plugins.loaders.base_loader import BaseLoader
        from savu.plugins.savers.base_saver import BaseSaver

        self.n_plugins = len(self.plugin_list)
        loader_positions, saver_positions = [], []
        for position, entry in enumerate(self.plugin_list):
            mro = inspect.getmro(pu.load_class(entry["id"]))
            if BaseLoader in mro:
                loader_positions.append(position)
            if BaseSaver in mro:
                saver_positions.append(position)
        self.loader_idx = loader_positions
        self.saver_idx = saver_positions
        self.n_loaders = len(loader_positions)
        self.n_savers = len(saver_positions)
637
638
    def _check_loaders(self):
        """Ensure the process list starts with one or more loaders and
        that no loader appears after a non-loader.

        :raises Exception: if there is no loader, or loaders are not
            contiguous from position 0
        """
        self.__set_loaders_and_savers()
        loaders = self._get_loaders_index()

        if not loaders:
            raise Exception("The first plugin in the process list must be a "
                            "loader plugin.")
        # sorted loader positions must be exactly 0, 1, ..., n-1
        if loaders[0] != 0 or loaders[-1] + 1 != len(loaders):
            raise Exception("All loader plugins must be at the beginning "
                            "of the process list")
650
651
    def _add_missing_savers(self, exp):
        """ Add savers for missing datasets.

        Each dataset in ``exp.index["in_data"]`` that is not an input of an
        existing saver gets an hdf5_saver entry inserted into the plugin
        list. The experiment's "nPlugin" meta data is incremented for every
        insertion.

        :param exp: experiment object providing ``index`` and ``meta_data``
        """
        data_names = exp.index["in_data"].keys()
        already_saved = {
            name
            for saver_pos in self._get_savers_index()
            for name in self.plugin_list[saver_pos]["data"]["in_datasets"]
        }

        for name in [d for d in data_names if d not in already_saved]:
            pos = exp.meta_data.get("nPlugin") + 1
            exp.meta_data.set("nPlugin", pos)
            saver = pu.load_class("savu.plugins.savers.hdf5_saver")()
            tools = saver.get_plugin_tools()
            saver.parameters["in_datasets"] = [name]
            entry = {
                "name": saver.name,
                "id": saver.__module__,
                "pos": str(pos + 1),
                "data": saver.parameters,
                "active": True,
                "param": tools.get_param_definitions(),
                "doc": tools.docstring_info,
                "tools": tools,
            }
            self._add(pos + 1, entry)
676
677
    def _update_datasets(self, plugin_no, data_dict):
678
        n_loaders = self._get_n_loaders()
679
        idx = self._get_n_loaders() + plugin_no
680
        self.plugin_list[idx]["data"].update(data_dict)
681
682
    def _get_dataset_flow(self):
683
        datasets_idx = []
684
        n_loaders = self._get_n_loaders()
685
        n_plugins = self._get_n_processing_plugins()
686
        for i in range(self.n_loaders, n_loaders + n_plugins):
687
            datasets_idx.append(self.plugin_list[i]["data"]["out_datasets"])
688
        return datasets_idx
689
690
    def _contains_gpu_processes(self):
691
        """ Returns True if gpu processes exist in the process list. """
692
        try:
693
            from savu.plugins.driver.gpu_plugin import GpuPlugin
694
695
            for i in range(self.n_plugins):
696
                bases = inspect.getmro(
697
                    pu.load_class(self.plugin_list[i]["id"])
698
                )
699
                if GpuPlugin in bases:
700
                    return True
701
        except ImportError as ex:
702
            if "pynvml" in ex.message:
703
                logging.error(
704
                    "Error while importing GPU dependencies: %s", ex.message
705
                )
706
            else:
707
                raise
708
709
        return False
710
711
712
class Template(object):
    """A class to read and write templates for plugin lists."""

    def __init__(self, plist):
        """
        :param plist: Plugin list object whose ``plugin_list`` entries are
            read when writing a template and updated when applying one.
        """
        super(Template, self).__init__()
        self.plist = plist
        self.creating = False

    def _output_template(self, fname, process_fname):
        """Write the template parameters of all active plugins to YAML.

        Local parameters are stored under ``[index + 1, plugin name, key]``
        (index into ``plugin_list``). Global parameters are stored under
        ``["all", plugin name, key]``. Parameters coming from a plugin's
        yaml file are stored under ``[index + 1, plugin name, data_name,
        key]``, where ``data_name`` is the yaml entry for local parameters
        and ``"all"`` otherwise.

        :param fname: Path of the template file to write.
        :param process_fname: Path of the process list; its absolute path is
            stored under the ``"process_list"`` key.
        """
        local_dict = MetaData(ordered=True)
        global_dict = MetaData(ordered=True)
        local_dict.set(["process_list"], os.path.abspath(process_fname))

        for i, plugin in enumerate(self.plist.plugin_list):
            if not plugin["active"]:
                continue
            name = plugin["name"]
            params = self.__get_template_params(plugin["data"], [])
            for ptype, isyaml, key, value in params:
                if isyaml:
                    data_name = isyaml if ptype == "local" else "all"
                    local_dict.set([i + 1, name, data_name, key], value)
                elif ptype == "local":
                    local_dict.set([i + 1, name, key], value)
                else:
                    global_dict.set(["all", name, key], value)

        with open(fname, "w") as stream:
            local_dict.get_dictionary().update(global_dict.get_dictionary())
            yu.dump_yaml(local_dict.get_dictionary(), stream)

    def __get_template_params(self, params, tlist, yaml=False):
        """Collect the template parameters found in ``params``.

        :param params: Plugin parameter dictionary (name -> value).
        :param tlist: List that results are appended to; it is mutated in
            place and also returned.
        :param yaml: False for ordinary plugin parameters, otherwise the name
            of the yaml entry the parameters were read from.
        :return: ``tlist`` with one ``[ptype, isyaml, key, value]`` item
            appended per template parameter.
        """
        for key, value in params.items():
            if key == "yaml_file":
                # Parameters defined inside the yaml file may be templated too.
                yaml_dict = self._get_yaml_dict(value)
                for entry in list(yaml_dict.keys()):
                    self.__get_template_params(
                        yaml_dict[entry]["params"], tlist, yaml=entry
                    )
            template = pu.is_template_param(value)
            if template is not False:
                ptype, tvalue = template
                tlist.append([ptype, yaml if yaml else False, key, tvalue])
        return tlist

    def _get_yaml_dict(self, yfile):
        """Load a yaml file (which may itself be a template parameter) via
        the YamlConverter in template mode and return its setup result."""
        from savu.plugins.loaders.yaml_converter import YamlConverter

        yaml_c = YamlConverter()
        template_check = pu.is_template_param(yfile)
        # A templated yaml_file value wraps the real path; unwrap it.
        yfile = template_check[1] if template_check is not False else yfile
        yaml_c.parameters = {"yaml_file": yfile}
        return yaml_c.setup(template=True)

    def update_process_list(self, template):
        """Apply the parameter values in a template file to the plugin list.

        :param template: Path of a YAML template file.
        :raises Exception: If an entry is nested deeper than one level.
        """
        tdict = yu.read_yaml(template)
        del tdict["process_list"]

        for plugin_no, entry in tdict.items():
            plugin = list(entry.keys())[0]
            for key, value in list(entry.values())[0].items():
                depth = self.dict_depth(value)
                if depth == 1:
                    # ``key`` names a dataset, ``value`` maps params to values.
                    self._set_param_for_template_loader_plugin(
                        plugin_no, key, value
                    )
                elif depth == 0:
                    if plugin_no == "all":
                        self._set_param_for_all_instances_of_a_plugin(
                            plugin, key, value
                        )
                    else:
                        data = self._get_plugin_data_dict(str(plugin_no))
                        data[key] = value
                else:
                    raise Exception("Template key not recognised.")

    def dict_depth(self, d, depth=0):
        """Return the nesting depth of ``d``.

        Non-dicts and empty dicts have depth ``depth`` (0 by default); every
        level of non-empty dict nesting adds one.
        """
        if not isinstance(d, dict) or not d:
            return depth
        return max(self.dict_depth(v, depth + 1) for v in d.values())

    def _set_param_for_all_instances_of_a_plugin(self, plugin, param, value):
        """Set ``param`` to ``value`` for every plugin named ``plugin``."""
        for p in self.plist.plugin_list:
            if p["name"] == plugin:
                p["data"][param] = value

    def _set_param_for_template_loader_plugin(self, plugin_no, data, value):
        """Store template parameters for dataset ``data`` of a plugin.

        :param plugin_no: Plugin position (compared as a string with "pos").
        :param data: Dataset name used as the key in ``template_param``.
        :param value: Mapping of parameter names to values; every entry is
            applied.
        """
        data_dict = self._get_plugin_data_dict(str(plugin_no))
        pdict = data_dict.get("template_param")
        if not pdict:
            # Store the new container back, otherwise the values are lost.
            pdict = {}
            data_dict["template_param"] = pdict
        pdict.setdefault(data, {}).update(value)

    def _get_plugin_data_dict(self, plugin_no):
        """Return the ``"data"`` dict of the plugin whose ``"pos"`` equals
        ``plugin_no`` (given as a string).

        :raises ValueError: If no plugin has that position.
        """
        plist = self.plist.plugin_list
        positions = [p["pos"] for p in plist]
        return plist[positions.index(plugin_no)]["data"]
815
816
817
class CitationInformation(object):
    """
    Descriptor of Citation Information for plugins
    """

    def __init__(
        self,
        description,
        bibtex="",
        endnote="",
        doi="",
        short_name_article="",
        dependency="",
    ):
        self.description = description
        self.short_name_article = short_name_article
        self.bibtex = bibtex
        self.endnote = endnote
        self.doi = doi
        self.dependency = dependency
        self.name = self._set_citation_name()
        self.id = self._set_id()

    def _set_citation_name(self):
        """Create a short identifier using the short name of the article
        (or its title) and the first author.

        Falls back to the description when no endnote or bibtex is given.
        """
        cite_info = bool(self.endnote or self.bibtex)

        if cite_info and self.short_name_article:
            cite_name = (
                self.short_name_article.title()
                + " by "
                + self._get_first_author()
                + " et al."
            )
        elif cite_info:
            cite_name = (
                self._get_title()
                + " by "
                + self._get_first_author()
                + " et al."
            )
        else:
            # Using the tools class name here caused citations to overwrite
            # each other, so the description is used instead.
            cite_name = self.description
        return cite_name

    def _set_citation_id(self, tool_class):
        """Create a short identifier from the bibtex entry key, or from the
        module name of ``tool_class`` when there is no bibtex."""
        if self.bibtex:
            cite_id = self._set_id()
        else:
            cite_id = tool_class.__module__.split(".")[-1]
        return cite_id

    def _set_id(self):
        """ Retrieve the id from the bibtex """
        return self.seperate_bibtex("@article{", ",")

    def _get_first_author(self):
        """Retrieve the first author name.

        Endnote takes precedence over bibtex; returns "" if neither is set.
        """
        if self.endnote:
            return self.seperate_endnote("%A ")
        if self.bibtex:
            return self.seperate_bibtex("author={", "}", author=True)
        return ""

    def _get_title(self):
        """Retrieve the title.

        Endnote takes precedence over bibtex; returns "" if neither is set.
        """
        if self.endnote:
            return self.seperate_endnote("%T ")
        if self.bibtex:
            return self.seperate_bibtex("title={", "}")
        return ""

    def seperate_endnote(self, seperation_char):
        """Return the string contained between start characters
        and a new line

        :param seperation_char: Character to split the string at
        :return: The string contained between start characters and a new line
        """
        return self.endnote.partition(seperation_char)[2].split("\n")[0]

    def seperate_bibtex(self, start_char, end_char, author=False):
        """Return the string contained between provided characters

        :param start_char: Character to split the string at
        :param end_char: Character to end the split at
        :param author: If True, keep only the first name of an
            " and "-separated author list
        :return: The string contained between both characters
        """
        plain_text = doc.remove_new_lines(self.bibtex)
        item = plain_text.partition(start_char)[2].split(end_char)[0]
        if author and " and " in item:
            # Return one author only
            item = item.split(" and ")[0]
        return item

    def write(self, citation_group):
        """Write the citation fields into ``citation_group``.

        The group is tagged as ``NXcite`` and receives one single-element
        dataset each for description, doi, endnote and bibtex.
        """
        # classes don't have to be encoded to ASCII
        citation_group.attrs[NX_CLASS] = "NXcite"
        # Valid ascii sequences will be encoded, invalid ones will be
        # preserved as escape sequences (applied to every field, including
        # doi, so non-ASCII input never raises UnicodeEncodeError).
        for field in ("description", "doi", "endnote", "bibtex"):
            text = getattr(self, field)
            field_array = np.array([text.encode("ascii", "backslashreplace")])
            citation_group.create_dataset(
                field.encode("ascii"),
                field_array.shape,
                field_array.dtype,
                field_array,
            )
950