Test Failed
Pull Request — master (#888)
created by Daniil at 03:51

PluginList._get_plugin_group()   A

Complexity
  Conditions     3

Size
  Total Lines    17
  Code Lines     8

Duplication
  Lines          0
  Ratio          0 %

Importance
  Changes        0

Metric  Value
cc      3
eloc    8
nop     4
dl      0
loc     17
rs      10
c       0
b       0
f       0
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: plugin_list
   :platform: Unix
   :synopsis: Contains the PluginList class, which deals with loading and \
   saving the plugin list, and the CitationInformation class. An instance is \
   held by the MetaData class.

.. moduleauthor:: Nicola Wadeson <[email protected]>

"""
import ast
import copy
import inspect
import json
import logging
import os
import re
from collections import defaultdict

import h5py
import numpy as np

import savu.data.framework_citations as fc
import savu.plugins.docstring_parser as doc
import savu.plugins.loaders.utils.yaml_utils as yu
import savu.plugins.utils as pu
from savu.data.meta_data import MetaData

NX_CLASS = "NX_class"


class PluginList(object):
    """
    The PluginList class handles the plugin list - loading, saving and adding
    citation information for the plugin
    """

    def __init__(self):
        self.plugin_list = []
        self.n_plugins = None
        self.n_loaders = 0
        self.n_savers = 0
        self.loader_idx = None
        self.saver_idx = None
        self.datasets_list = []
        self.saver_plugin_status = True
        self._template = None
        self.version = None
        self.iterate_plugin_groups = []

    def add_template(self, create=False):
        self._template = Template(self)
        if create:
            self._template.creating = True

    def _get_plugin_entry_template(self):
        template = {"active": True, "name": None, "id": None, "data": None}
        return template

    def __get_json_keys(self):
        return ["data"]

    def _populate_plugin_list(
        self, filename, active_pass=False, template=False
    ):
        """ Populate the plugin list from a nexus file. """
        with h5py.File(filename, "r") as plugin_file:
            if "entry/savu_notes/version" in plugin_file:
                self.version = plugin_file["entry/savu_notes/version"][()]
                self._show_process_list_version()

            plugin_group = plugin_file["entry/plugin"]
            self.plugin_list = []
            single_val = ["name", "id", "pos", "active"]
            exclude = ["citation"]
            ordered_pl_keys = pu.sort_alphanum(list(plugin_group.keys()))
            for group in ordered_pl_keys:
                plugin = self._get_plugin_entry_template()
                entry_keys = plugin_group[group].keys()
                parameters = [
                    k
                    for k in entry_keys
                    for e in exclude
                    if k not in single_val and e not in k
                ]

                if "active" in entry_keys:
                    plugin["active"] = plugin_group[group]["active"][0]

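                # keep this entry if the plugin is active, or if inactive
                # plugins are explicitly requested via active_pass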
                if plugin['active'] or active_pass:
                    plugin['name'] = plugin_group[group]['name'][0].decode("utf-8")
                    plugin['id'] = plugin_group[group]['id'][0].decode("utf-8")
                    plugin_tools = None
                    try:
                        plugin_class = pu.load_class(plugin["id"])()
                        # Populate the parameters (including those from its base classes)
                        plugin_tools = plugin_class.get_plugin_tools()
                        if not plugin_tools:
                            raise OSError(f"Tools file not found for {plugin['name']}")
                        plugin_tools._populate_default_parameters()
                    except ImportError:
                        # No plugin class found
                        logging.error(f"No class found for {plugin['name']}")

                    plugin['doc'] = plugin_tools.docstring_info if plugin_tools else ""
                    plugin['tools'] = plugin_tools if plugin_tools else {}
                    plugin['param'] = plugin_tools.get_param_definitions() if \
                        plugin_tools else {}
                    plugin['pos'] = group.strip()

                    for param in parameters:
                        try:
                            plugin[param] = json.loads(plugin_group[group][param][0])
                        except ValueError as e:
                            raise ValueError(
                                f"Error: {e}\n Could not parse key '{param}' from group '{group}' as JSON"
                            )
                    self.plugin_list.append(plugin)

            # add info about groups of plugins to iterate over into
            # self.iterate_plugin_groups
            self.clear_iterate_plugin_group_dicts()
            try:
                iterate_groups = plugin_file['entry/iterate_plugin_groups']
                for key in list(iterate_groups.keys()):
                    iterate_group_dict = {
                        'start_index': iterate_groups[key]['start'][()],
                        'end_index': iterate_groups[key]['end'][()],
                        'iterations': iterate_groups[key]['iterations'][()]
                    }
                    self.iterate_plugin_groups.append(iterate_group_dict)
            except Exception as e:
                err_str = f"Process list file {filename} doesn't have the " \
                          f"iterate_plugin_groups internal hdf5 path"
                print(err_str)

            if template:
                self.add_template()
                self._template.update_process_list(template)

    def _show_process_list_version(self):
        """If the input process list was created using an older version
        of Savu, then alert the user"""
        from savu.version import __version__
        pl_version = float(self.version)
        if float(__version__) > pl_version:
            separator = "*" * 53
            print(separator)
            print(f"*** This process list was created using Savu "
                  f"{pl_version}  ***")
            print(separator)
            print(f"The process list has been updated, the incorrect \n"
                  f"parameter values have been reverted to default. \n"
                  f"Any warnings below point to the problematic parameters.\n"
                  f"Save this process list to save the updated values.")
            print(separator)

    def _save_plugin_list(self, out_filename):
        with h5py.File(out_filename, "a") as nxs_file:

            entry = nxs_file.require_group("entry")

            self._save_framework_citations(self._overwrite_group(
                entry, 'framework_citations', 'NXcollection'))

            self.__save_savu_notes(self._overwrite_group(
                entry, 'savu_notes', 'NXnote'))

            plugins_group = self._overwrite_group(entry, 'plugin', 'NXprocess')

            count = 1
            for plugin in self.plugin_list:
                plugin_group = self._get_plugin_group(
                    plugins_group, plugin, count
                )
                self.__populate_plugins_group(plugin_group, plugin)

            self.__save_iterate_plugin_groups(self._overwrite_group(
                entry, 'iterate_plugin_groups', 'NXnote'))

        if self._template and self._template.creating:
            fname = os.path.splitext(out_filename)[0] + ".savu"
            self._template._output_template(fname, out_filename)

    def _overwrite_group(self, entry, name, nxclass):
        if name in entry:
            entry.pop(name)
        group = entry.create_group(name.encode("ascii"))
        group.attrs[NX_CLASS] = nxclass.encode("ascii")
        return group

    def add_iterate_plugin_group(self, start, end, iterations):
        """Add an element to self.iterate_plugin_groups"""
        group_new = {
            'start_index': start,
            'end_index': end,
            'iterations': iterations
        }
        list_new_indices = list(range(start, end+1))

        if iterations <= 0:
            print("The number of iterations should be larger than zero")
            return
        elif start <= 0 or start > len(self.plugin_list) or end <= 0 or end > len(self.plugin_list):
            print("The given plugin indices are not within the range of existing plugin indices")
            return

        # crosscheck with the existing iterative loops
        if len(self.iterate_plugin_groups) != 0:
            noexactlist = True
            nointersection = True
            for count, group in enumerate(self.iterate_plugin_groups, 1):
                start_int = int(group['start_index'])
                end_int = int(group['end_index'])
                list_existing_indices = list(range(start_int, end_int+1))
                if bool(set(list_new_indices).intersection(list_existing_indices)):
                    # check if the intersection of lists is exact (number of iterations to change)
                    nointersection = False
                    if list_new_indices == list_existing_indices:
                        print(f"The number of iterations of loop group no. {count}, {list_new_indices} has been set to: {iterations}")
                        self.iterate_plugin_groups[count-1]["iterations"] = iterations
                        noexactlist = False
                    else:
                        print(f"The plugins of group no. {count} are already set to be iterative: {set(list_new_indices).intersection(list_existing_indices)}")
            if noexactlist and nointersection:
                self.iterate_plugin_groups.append(group_new)
        else:
            self.iterate_plugin_groups.append(group_new)
        self.print_iterative_loops()

    def remove_iterate_plugin_groups(self, indices):
        """ Remove elements from self.iterate_plugin_groups """
        if len(indices) == 0:
            # remove all iterative loops in process list
            prompt_str = 'Are you sure you want to remove all iterative ' \
                'loops? [y/N]'
            check = input(prompt_str)
            should_remove_all = check.lower() == 'y'
            if should_remove_all:
                self.clear_iterate_plugin_group_dicts()
        else:
            # remove specified iterative loops in process list
            sorted_indices = sorted(indices)

            if sorted_indices[0] <= 0:
                print('The iterative loops are indexed starting from 1')
                self.print_iterative_loops()
                return

            for i in reversed(sorted_indices):
                try:
                    # convert the one-based index to a zero-based index
                    del self.iterate_plugin_groups[i - 1]
                except IndexError:
                    info_str = f"There is no iterative loop with number {i}"
                    print(info_str)

        self.print_iterative_loops()

    def clear_iterate_plugin_group_dicts(self):
        """
        Reset the list of dicts representing groups of plugins to iterate over
        """
        self.iterate_plugin_groups = []

    def get_iterate_plugin_group_dicts(self):
        """
        Return the list of dicts representing groups of plugins to iterate over
        """
        return self.iterate_plugin_groups

    def print_iterative_loops(self):
        if len(self.iterate_plugin_groups) == 0:
            print('There are no iterative loops in the current process list')
        else:
            print('Iterative loops in the current process list are:')
            for count, group in enumerate(self.iterate_plugin_groups, 1):
                number = f"({count}) "
                start_str = f"start plugin index: {group['start_index']}"
                end_str = f"end index: {group['end_index']}"
                iterations_str = f"iterations number: {group['iterations']}"
                full_str = number + start_str + ', ' + end_str + ', ' + \
                    iterations_str
                print(full_str)

    def remove_associated_iterate_group_dict(self, pos, direction):
        """
        Remove an iterative loop associated with a plugin index
        """
        operation = 'add' if direction == 1 else 'remove'
        for i, iterate_group in enumerate(self.iterate_plugin_groups):
            if operation == 'remove':
                if iterate_group['start_index'] <= pos and \
                    pos <= iterate_group['end_index']:
                    # remove the loop if the plugin being removed is at any
                    # position within an iterative loop
                    del self.iterate_plugin_groups[i]
                    break
            elif operation == 'add':
                if iterate_group['start_index'] != iterate_group['end_index']:
                    # remove the loop only if the plugin is being added between
                    # the start and end of the loop
                    if iterate_group['start_index'] < pos and \
                        pos <= iterate_group['end_index']:
                        del self.iterate_plugin_groups[i]
                        break

    def check_pos_in_iterative_loop(self, pos):
        """
        Check if the given plugin position is in an iterative loop
        """
        is_in_loop = False
        for iterate_group in self.iterate_plugin_groups:
            if iterate_group['start_index'] <= pos and \
                pos <= iterate_group['end_index']:
                is_in_loop = True
                break

        return is_in_loop

    def __save_iterate_plugin_groups(self, group):
        '''
        Save information regarding the groups of plugins to iterate over
        '''
        for count, iterate_group in enumerate(self.iterate_plugin_groups):
            grp_name = str(count)
            grp = group.create_group(grp_name.encode('ascii'))
            shape = () # scalar data
            grp.create_dataset('start'.encode('ascii'), shape, 'i',
                iterate_group['start_index'])
            grp.create_dataset('end'.encode('ascii'), shape, 'i',
                iterate_group['end_index'])
            grp.create_dataset('iterations'.encode('ascii'), shape, 'i',
                iterate_group['iterations'])

    def shift_subsequent_iterative_loops(self, pos, direction):
        """
        Shift all iterative loops occurring after a given plugin position
        """
        # if removing a plugin that is positioned before a loop, the loop should
        # be shifted down by 1; but if removing a plugin that is positioned at
        # the start of the loop, it will be removed instead of shifted (ie, both
        # < or <= work for this case)
        #
        # if adding a plugin that will be positioned before a loop, the loop
        # should be shifted up by 1; also, if adding a plugin to be positioned
        # where the start of a loop currently exists, this should shift the loop
        # up by 1 as well (ie, only <= works for this case, hence the use of <=)
        for iterate_group in self.iterate_plugin_groups:
            if pos <= iterate_group['start_index']:
                self.shift_iterative_loop(iterate_group, direction)

    def shift_range_iterative_loops(self, positions, direction):
        """
        Shift all iterative loops within a range of plugin indices
        """
        for iterate_group in self.iterate_plugin_groups:
            if positions[0] <= iterate_group['start_index'] and \
                iterate_group['end_index'] <= positions[1]:
                self.shift_iterative_loop(iterate_group, direction)

    def shift_iterative_loop(self, iterate_group, direction):
        """
        Shift an iterative loop up or down in the process list, based on
        whether a plugin is added or removed
        """
        if direction == 1:
            iterate_group['start_index'] += 1
            iterate_group['end_index'] += 1
        elif direction == -1:
            iterate_group['start_index'] -= 1
            iterate_group['end_index'] -= 1
        else:
            err_str = f"Bad direction value given to shift iterative loop: " \
                      f"{direction}"
            raise ValueError(err_str)

    def __save_savu_notes(self, notes):
        """ Save the version number

        :param notes: hdf5 group to save data to
        """
        from savu.version import __version__

        notes["version"] = __version__

    def __populate_plugins_group(self, plugin_group, plugin):
        """Populate the plugin group information which will be saved

        :param plugin_group: Plugin group to save to
        :param plugin: Plugin to be saved
        """
        plugin_group.attrs[NX_CLASS] = "NXnote".encode("ascii")
        required_keys = self._get_plugin_entry_template().keys()
        json_keys = self.__get_json_keys()

        self._save_citations(plugin, plugin_group)

        for key in required_keys:
            # only need to apply dumps if saving in configurator
            if key == "data":
                data = {}
                for k, v in plugin[key].items():
                    #  Replace any missing quotes around variables.
                    data[k] = pu._dumps(v)
            else:
                data = plugin[key]

            # get the string value
            data = json.dumps(data) if key in json_keys else plugin[key]
            # if the data is a string it has to be encoded to ascii so that
            # hdf5 can save out the bytes
            if isinstance(data, str):
                data = data.encode("ascii")
            data = np.array([data])
            plugin_group.create_dataset(
                key.encode("ascii"), data.shape, data.dtype, data
            )

    def _get_plugin_group(self, plugins_group, plugin, count):
        """Return the plugin_group, into which the plugin information
         will be saved

        :param plugins_group: Current group to save inside
        :param plugin: Plugin to be saved
        :param count: Order number of the plugin in the process list
        :return: plugin group
        """
        if "pos" in plugin.keys():
            num = int(re.findall(r"\d+", str(plugin["pos"]))[0])
            letter = re.findall("[a-z]", str(plugin["pos"]))
            letter = letter[0] if letter else ""
            group_name = "%i%s" % (num, letter)
        else:
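            # no stored position for this plugin, so fall back to the
            # running count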
            group_name = str(count)
        return plugins_group.create_group(group_name.encode("ascii"))

    def _add(self, idx, entry):
        self.plugin_list.insert(idx, entry)
        self.__set_loaders_and_savers()

    def _remove(self, idx):
        del self.plugin_list[idx]
        self.__set_loaders_and_savers()

    def _save_citations(self, plugin, group):
        """Save all the citations in the plugin

        :param plugin: dictionary of plugin information
        :param group: Group to save to
        """
        if "tools" in plugin.keys():
            citation_plugin = plugin.get("tools").get_citations()
            if citation_plugin:
                count = 1
                for citation in citation_plugin.values():
                    group_label = f"citation{count}"
                    if (
                        not citation.dependency
                        or self._dependent_citation_used(plugin, citation)
                    ):
                        self._save_citation_group(
                            citation, citation.__dict__, group, group_label
                        )
                        count += 1

    def _dependent_citation_used(self, plugin, citation):
        """Check if any Plugin parameter values match the citation
        dependency requirement.

        :param plugin: dictionary of plugin information
        :param citation: A plugin citation
        :return: bool True if the citation is for a parameter value
            being used inside this plugin
        """
        parameters = plugin["data"]
        for (
            citation_dependent_parameter,
            citation_dependent_value,
        ) in citation.dependency.items():
            current_value = parameters[citation_dependent_parameter]
            if current_value == citation_dependent_value:
                return True
        return False

    def _exec_citations(self, cite, citation):
        """Copy the citation dictionary entries onto the citation object

        :param cite: citation dictionary
        :param citation: CitationInformation object to update
        """
        for key, value in cite.items():
            setattr(citation, key, value)

    def _save_framework_citations(self, group):
        """Save all the framework citations

        :param group: Group for nxs file
        """
        framework_cites = fc.get_framework_citations()
        for cite in framework_cites.values():
            label = cite.short_name_article
            del cite.short_name_article
            self._save_citation_group(cite, cite.__dict__, group, label)


    def _save_citation_group(self, citation, cite_dict, group, group_label):
        """Save the citations to the provided group label

        :param citation: Citation object
        :param cite_dict: Citation dictionary
        :param group: Group
        :param group_label: Group label
        :return:
        """
        citation_group = group.require_group(group_label.encode("ascii"))
        self._exec_citations(cite_dict, citation)
        citation.write(citation_group)

    def _get_docstring_info(self, plugin):
        plugin_inst = pu.plugins[plugin]()
        tools = plugin_inst.get_plugin_tools()
        tools._populate_default_parameters()
        return plugin_inst.docstring_info

    # def _byteify(self, input):
    #     if isinstance(input, dict):
    #         return {self._byteify(key): self._byteify(value)
    #                 for key, value in input.items()}
    #     elif isinstance(input, list):
    #         temp = [self._byteify(element) for element in input]
    #         return temp
    #     elif isinstance(input, str):
    #         return input.encode('utf-8')
    #     else:
    #         return input

    def _set_datasets_list(self, plugin):
        in_pData, out_pData = plugin.get_plugin_datasets()
        in_data_list = self._populate_datasets_list(in_pData)
        out_data_list = self._populate_datasets_list(out_pData)
        self.datasets_list.append({'in_datasets': in_data_list,
                                   'out_datasets': out_data_list})

    def _populate_datasets_list(self, data):
        data_list = []
        for d in data:
            name = d.data_obj.get_name()
            pattern = copy.deepcopy(d.get_pattern())
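            # attach the transfer metadata to the first pattern entry for
            # this data object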
            pattern[list(pattern.keys())[0]]['max_frames_transfer'] = \
                d.meta_data.get('max_frames_transfer')
            pattern[list(pattern.keys())[0]]['transfer_shape'] = \
                d.meta_data.get('transfer_shape')
            data_list.append({'name': name, 'pattern': pattern})
        return data_list

    def _get_datasets_list(self):
        return self.datasets_list

    def _reset_datasets_list(self):
        self.datasets_list = []

    def _get_n_loaders(self):
        return self.n_loaders

    def _get_n_savers(self):
        return self.n_savers

    def _get_loaders_index(self):
        return self.loader_idx

    def _get_savers_index(self):
        return self.saver_idx

    def _get_n_processing_plugins(self):
        return len(self.plugin_list) - self._get_n_loaders()

    def __set_loaders_and_savers(self):
        """Get lists of loader and saver positions within the plugin list and
        set the number of loaders.

        :returns: loader index list and saver index list
        :rtype: list(int(loader)), list(int(saver))
        """
        from savu.plugins.loaders.base_loader import BaseLoader
        from savu.plugins.savers.base_saver import BaseSaver

        loader_idx = []
        saver_idx = []
        self.n_plugins = len(self.plugin_list)

        for i in range(self.n_plugins):
            pid = self.plugin_list[i]["id"]
            bases = inspect.getmro(pu.load_class(pid))
            loader_list = [b for b in bases if b == BaseLoader]
            saver_list = [b for b in bases if b == BaseSaver]
            if loader_list:
                loader_idx.append(i)
            if saver_list:
                saver_idx.append(i)
        self.loader_idx = loader_idx
        self.saver_idx = saver_idx
        self.n_loaders = len(loader_idx)
        self.n_savers = len(saver_idx)

    def _check_loaders(self):
        """Check plugin list starts with a loader."""
        self.__set_loaders_and_savers()
        loaders = self._get_loaders_index()

        if loaders:
            if loaders[0] != 0 or loaders[-1] + 1 != len(loaders):
                raise Exception("All loader plugins must be at the beginning "
                                "of the process list")
        else:
            raise Exception("The first plugin in the process list must be a "
                            "loader plugin.")

    def _add_missing_savers(self, exp):
        """ Add savers for missing datasets. """
        data_names = exp.index["in_data"].keys()
        saved_data = []

        for i in self._get_savers_index():
            saved_data.append(self.plugin_list[i]["data"]["in_datasets"])
        saved_data = set([s for sub_list in saved_data for s in sub_list])

        for name in [data for data in data_names if data not in saved_data]:
            pos = exp.meta_data.get("nPlugin") + 1
            exp.meta_data.set("nPlugin", pos)
            process = {}
            plugin = pu.load_class("savu.plugins.savers.hdf5_saver")()
            ptools = plugin.get_plugin_tools()
            plugin.parameters["in_datasets"] = [name]
            process["name"] = plugin.name
            process["id"] = plugin.__module__
            process["pos"] = str(pos + 1)
            process["data"] = plugin.parameters
            process["active"] = True
            process["param"] = ptools.get_param_definitions()
            process["doc"] = ptools.docstring_info
            process["tools"] = ptools
            self._add(pos + 1, process)

    def _update_datasets(self, plugin_no, data_dict):
        n_loaders = self._get_n_loaders()
        idx = n_loaders + plugin_no
        self.plugin_list[idx]["data"].update(data_dict)

    def _get_dataset_flow(self):
        datasets_idx = []
        n_loaders = self._get_n_loaders()
        n_plugins = self._get_n_processing_plugins()
        for i in range(self.n_loaders, n_loaders + n_plugins):
            datasets_idx.append(self.plugin_list[i]["data"]["out_datasets"])
        return datasets_idx

    def _contains_gpu_processes(self):
        """ Returns True if gpu processes exist in the process list. """
        try:
            from savu.plugins.driver.gpu_plugin import GpuPlugin

            for i in range(self.n_plugins):
                bases = inspect.getmro(
                    pu.load_class(self.plugin_list[i]["id"])
                )
                if GpuPlugin in bases:
                    return True
        except ImportError as ex:
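            # a missing optional GPU dependency (pynvml) is logged and treated
            # as "no GPU plugins"; any other import failure is re-raised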
            if "pynvml" in str(ex):
                logging.error(
                    "Error while importing GPU dependencies: %s", str(ex)
                )
            else:
                raise

        return False


class Template(object):
    """A class to read and write templates for plugin lists."""

    def __init__(self, plist):
        super(Template, self).__init__()
        self.plist = plist
        self.creating = False

    def _output_template(self, fname, process_fname):
        plist = self.plist.plugin_list
        index = [i for i in range(len(plist)) if plist[i]["active"]]

        local_dict = MetaData(ordered=True)
        global_dict = MetaData(ordered=True)
        local_dict.set(["process_list"], os.path.abspath(process_fname))

        for i in index:
            params = self.__get_template_params(plist[i]["data"], [])
            name = plist[i]["name"]
            for p in params:
                ptype, isyaml, key, value = p
                if isyaml:
                    data_name = isyaml if ptype == "local" else "all"
                    local_dict.set([i + 1, name, data_name, key], value)
                elif ptype == "local":
                    local_dict.set([i + 1, name, key], value)
                else:
                    global_dict.set(["all", name, key], value)

        with open(fname, "w") as stream:
            local_dict.get_dictionary().update(global_dict.get_dictionary())
            yu.dump_yaml(local_dict.get_dictionary(), stream)

    def __get_template_params(self, params, tlist, yaml=False):
        for key, value in params.items():
            if key == "yaml_file":
                yaml_dict = self._get_yaml_dict(value)
                for entry in list(yaml_dict.keys()):
                    self.__get_template_params(
                        yaml_dict[entry]["params"], tlist, yaml=entry
                    )
            value = pu.is_template_param(value)
            if value is not False:
                ptype, value = value
                isyaml = yaml if yaml else False
                tlist.append([ptype, isyaml, key, value])
        return tlist

    def _get_yaml_dict(self, yfile):
        from savu.plugins.loaders.yaml_converter import YamlConverter

        yaml_c = YamlConverter()
        template_check = pu.is_template_param(yfile)
        yfile = template_check[1] if template_check is not False else yfile
        yaml_c.parameters = {"yaml_file": yfile}
        return yaml_c.setup(template=True)

    def update_process_list(self, template):
        tdict = yu.read_yaml(template)
        del tdict["process_list"]

        for plugin_no, entry in tdict.items():
            plugin = list(entry.keys())[0]
            for key, value in list(entry.values())[0].items():
                depth = self.dict_depth(value)
                if depth == 1:
                    self._set_param_for_template_loader_plugin(
                        plugin_no, key, value
                    )
                elif depth == 0:
                    if plugin_no == "all":
                        self._set_param_for_all_instances_of_a_plugin(
                            plugin, key, value
                        )
                    else:
                        data = self._get_plugin_data_dict(str(plugin_no))
                        data[key] = value
                else:
                    raise Exception("Template key not recognised.")

    def dict_depth(self, d, depth=0):
        if not isinstance(d, dict) or not d:
            return depth
        return max(self.dict_depth(v, depth + 1) for k, v in d.items())

    def _set_param_for_all_instances_of_a_plugin(self, plugin, param, value):
        # find all plugins with this name and replace the param
        for p in self.plist.plugin_list:
            if p["name"] == plugin:
                p["data"][param] = value

    def _set_param_for_template_loader_plugin(self, plugin_no, data, value):
        param_key = list(value.keys())[0]
        param_val = list(value.values())[0]
        pdict = self._get_plugin_data_dict(str(plugin_no))["template_param"]
        pdict = defaultdict(dict) if not pdict else pdict
        pdict[data][param_key] = param_val

    def _get_plugin_data_dict(self, plugin_no):
        """ input plugin_no as a string """
        plist = self.plist.plugin_list
        index = [plist[i]["pos"] for i in range(len(plist))]
        return plist[index.index(plugin_no)]["data"]


class CitationInformation(object):
    """
    Descriptor of Citation Information for plugins
    """

    def __init__(
        self,
        description,
        bibtex="",
        endnote="",
        doi="",
        short_name_article="",
        dependency="",
    ):
        self.description = description
        self.short_name_article = short_name_article
        self.bibtex = bibtex
        self.endnote = endnote
        self.doi = doi
        self.dependency = dependency
        self.name = self._set_citation_name()
        self.id = self._set_id()

    def _set_citation_name(self):
        """Create a short identifier using the short name of the article
        and the first author
        """
        cite_info = True if self.endnote or self.bibtex else False

        if cite_info and self.short_name_article:
            cite_name = (
                self.short_name_article.title()
                + " by "
                + self._get_first_author()
                + " et al."
            )
        elif cite_info:
            cite_name = (
                self._get_title()
                + " by "
                + self._get_first_author()
                + " et al."
            )
        else:
            # Setting the tools class name as the citation name causes
            # overwriting, so the description is used instead
            # module_name = tool_class.__module__.split('.')[-1].replace('_', ' ')
            # cite_name = module_name.split('tools')[0].title()
            cite_name = self.description
        return cite_name

    def _set_citation_id(self, tool_class):
        """Create a short identifier using the bibtex identification"""
        # Remove blank space
        if self.bibtex:
            cite_id = self._get_id()
        else:
            # Set the tools class name as the citation name
            module_name = tool_class.__module__.split(".")[-1]
            cite_id = module_name
        return cite_id

    def _set_id(self):
        """ Retrieve the id from the bibtex """
        cite_id = self.seperate_bibtex("@article{", ",")
        return cite_id

    def _get_first_author(self):
        """ Retrieve the first author name """
        # default to an empty string so the name is always defined
        first_author = ""
        if self.endnote:
            first_author = self.seperate_endnote("%A ")
        elif self.bibtex:
            first_author = self.seperate_bibtex("author={", "}", author=True)
        return first_author

    def _get_title(self):
        """ Retrieve the title """
        # default to an empty string so the title is always defined
        title = ""
        if self.endnote:
            title = self.seperate_endnote("%T ")
        elif self.bibtex:
            title = self.seperate_bibtex("title={", "}")
        return title

    def seperate_endnote(self, seperation_char):
        """Return the string contained between start characters
        and a new line

        :param seperation_char: Character to split the string at
        :return: The string contained between start characters and a new line
        """
        item = self.endnote.partition(seperation_char)[2].split("\n")[0]
        return item

    def seperate_bibtex(self, start_char, end_char, author=False):
        """Return the string contained between provided characters

        :param start_char: Character to split the string at
        :param end_char: Character to end the split at
        :return: The string contained between both characters
        """
        plain_text = doc.remove_new_lines(self.bibtex)
        item = plain_text.partition(start_char)[2].split(end_char)[0]
        if author:
            # Return one author only
            if " and " in item:
                item = item.split(" and ")[0]

        return item

    def write(self, citation_group):
        # classes don't have to be encoded to ASCII
        citation_group.attrs[NX_CLASS] = "NXcite"
        # Valid ascii sequences will be encoded, invalid ones will be
        # preserved as escape sequences
        description_array = np.array([self.description.encode('ascii','backslashreplace')])
        citation_group.create_dataset('description'.encode('ascii'),
                                      description_array.shape,
                                      description_array.dtype,
                                      description_array)
        doi_array = np.array([self.doi.encode('ascii')])
        citation_group.create_dataset('doi'.encode('ascii'),
                                      doi_array.shape,
                                      doi_array.dtype,
                                      doi_array)
        endnote_array = np.array([self.endnote.encode('ascii','backslashreplace')])
        citation_group.create_dataset('endnote'.encode('ascii'),
                                      endnote_array.shape,
                                      endnote_array.dtype,
                                      endnote_array)
        bibtex_array = np.array([self.bibtex.encode('ascii','backslashreplace')])
        citation_group.create_dataset('bibtex'.encode('ascii'),
                                      bibtex_array.shape,
                                      bibtex_array.dtype,
                                      bibtex_array)