Test Failed
Push to master (9ba79c...17f3e3) by Yousef
01:54 (queued 19s)

PluginList.shift_iterative_loop(): rated A

Complexity: Conditions 3
Size: Total Lines 15, Code Lines 9
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0

Raw metrics (Metric, Value):
cc 3, eloc 9, nop 3, dl 0, loc 15, rs 9.95, c 0, b 0, f 0
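
For context, a minimal sketch of how the flagged method behaves; the loop dictionary below is built by hand rather than loaded from a process list, and the import path is assumed from the Savu package layout:

    from savu.data.plugin_list import PluginList  # import path assumed

    pl = PluginList()
    group = {'start_index': 2, 'end_index': 4, 'iterations': 3}
    pl.iterate_plugin_groups.append(group)

    # direction 1 increments both indices (a plugin was inserted before the
    # loop); -1 decrements them (a preceding plugin was removed); any other
    # value raises ValueError
    pl.shift_iterative_loop(group, 1)
    assert (group['start_index'], group['end_index']) == (3, 5)
    pl.shift_iterative_loop(group, -1)
    assert (group['start_index'], group['end_index']) == (2, 4)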
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: plugin_list
   :platform: Unix
   :synopsis: Contains the PluginList class, which deals with loading and \
   saving the plugin list, and the CitationInformation class. An instance is \
   held by the MetaData class.

.. moduleauthor:: Nicola Wadeson <[email protected]>

"""
import ast
import copy
import inspect
import json
import logging
import os
import re
from collections import defaultdict

import h5py
import numpy as np

import savu.data.framework_citations as fc
import savu.plugins.docstring_parser as doc
import savu.plugins.loaders.utils.yaml_utils as yu
import savu.plugins.utils as pu
from savu.data.meta_data import MetaData

NX_CLASS = "NX_class"


class PluginList(object):
    """
    The PluginList class handles the plugin list - loading, saving and adding
    citation information for the plugin
    """

    def __init__(self):
        self.plugin_list = []
        self.n_plugins = None
        self.n_loaders = 0
        self.n_savers = 0
        self.loader_idx = None
        self.saver_idx = None
        self.datasets_list = []
        self.saver_plugin_status = True
        self._template = None
        self.version = None
        self.iterate_plugin_groups = []

    def add_template(self, create=False):
        self._template = Template(self)
        if create:
            self._template.creating = True

    def _get_plugin_entry_template(self):
        template = {"active": True, "name": None, "id": None, "data": None}
        return template

    def __get_json_keys(self):
        return ["data"]

    def _populate_plugin_list(
        self, filename, active_pass=False, template=False
    ):
        """ Populate the plugin list from a nexus file. """
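        # The NeXus layout read here mirrors what _save_plugin_list() writes:
        # group names under entry/plugin are the plugin positions (e.g. "1",
        # "2", "3a"), and the remaining groups of interest are
        #   entry/savu_notes/version
        #   entry/iterate_plugin_groups/<n>/{start, end, iterations}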
        with h5py.File(filename, "r") as plugin_file:
            if "entry/savu_notes/version" in plugin_file:
                self.version = plugin_file["entry/savu_notes/version"][()]
                self._show_process_list_version()

            plugin_group = plugin_file["entry/plugin"]
            self.plugin_list = []
            single_val = ["name", "id", "pos", "active"]
            exclude = ["citation"]
            ordered_pl_keys = pu.sort_alphanum(list(plugin_group.keys()))
            for group in ordered_pl_keys:
                plugin = self._get_plugin_entry_template()
                entry_keys = plugin_group[group].keys()
                parameters = [
                    k
                    for k in entry_keys
                    for e in exclude
                    if k not in single_val and e not in k
                ]

                if "active" in entry_keys:
                    plugin["active"] = plugin_group[group]["active"][0]

                if plugin['active'] or active_pass:
                    plugin['name'] = plugin_group[group]['name'][0].decode("utf-8")
                    plugin['id'] = plugin_group[group]['id'][0].decode("utf-8")
                    plugin_tools = None
                    try:
                        plugin_class = pu.load_class(plugin["id"])()
                        # Populate the parameters (including those from its base classes)
                        plugin_tools = plugin_class.get_plugin_tools()
                        if not plugin_tools:
                            raise OSError(f"Tools file not found for {plugin['name']}")
                        plugin_tools._populate_default_parameters()
                    except ImportError:
                        # No plugin class found
                        logging.error(f"No class found for {plugin['name']}")

                    plugin['doc'] = plugin_tools.docstring_info if plugin_tools else ""
                    plugin['tools'] = plugin_tools if plugin_tools else {}
                    plugin['param'] = plugin_tools.get_param_definitions() if \
                        plugin_tools else {}
                    plugin['pos'] = group.strip()

                    for param in parameters:
                        try:
                            plugin[param] = json.loads(plugin_group[group][param][0])
                        except ValueError as e:
                            raise ValueError(
                                f"Error: {e}\n Could not parse key '{param}' from group '{group}' as JSON"
                            )
                    self.plugin_list.append(plugin)

            # add info about groups of plugins to iterate over into
            # self.iterate_plugin_groups
            self.clear_iterate_plugin_group_dicts()
            try:
                iterate_groups = plugin_file['entry/iterate_plugin_groups']
                for key in list(iterate_groups.keys()):
                    iterate_group_dict = {
                        'start_index': iterate_groups[key]['start'][()],
                        'end_index': iterate_groups[key]['end'][()],
                        'iterations': iterate_groups[key]['iterations'][()]
                    }
                    self.iterate_plugin_groups.append(iterate_group_dict)
            except Exception:
                err_str = f"Process list file {filename} doesn't have the " \
                          f"iterate_plugin_groups internal hdf5 path"
                print(err_str)

            if template:
                self.add_template()
                self._template.update_process_list(template)

    def _show_process_list_version(self):
        """If the input process list was created using an older version
        of Savu, then alert the user"""
        from savu.version import __version__
        pl_version = float(self.version)
        if float(__version__) > pl_version:
            separator = "*" * 53
            print(separator)
            print(f"*** This process list was created using Savu "
                  f"{pl_version}  ***")
            print(separator)

    def _save_plugin_list(self, out_filename):
        with h5py.File(out_filename, "a") as nxs_file:

            entry = nxs_file.require_group("entry")

            self._save_framework_citations(self._overwrite_group(
                entry, 'framework_citations', 'NXcollection'))

            self.__save_savu_notes(self._overwrite_group(
                entry, 'savu_notes', 'NXnote'))

            plugins_group = self._overwrite_group(entry, 'plugin', 'NXprocess')

            count = 1
            for plugin in self.plugin_list:
                plugin_group = self._get_plugin_group(
                    plugins_group, plugin, count
                )
                self.__populate_plugins_group(plugin_group, plugin)

            self.__save_iterate_plugin_groups(self._overwrite_group(
                entry, 'iterate_plugin_groups', 'NXnote'))

        if self._template and self._template.creating:
            fname = os.path.splitext(out_filename)[0] + ".savu"
            self._template._output_template(fname, out_filename)

    def _overwrite_group(self, entry, name, nxclass):
        if name in entry:
            entry.pop(name)
        group = entry.create_group(name.encode("ascii"))
        group.attrs[NX_CLASS] = nxclass.encode("ascii")
        return group

    def add_iterate_plugin_group(self, start, end, iterations):
        """Add an element to self.iterate_plugin_groups"""
        group_new = {
            'start_index': start,
            'end_index': end,
            'iterations': iterations
        }

        if iterations <= 0:
            print("The number of iterations should be greater than zero")
            return
        elif start <= 0 or start > len(self.plugin_list) or end <= 0 or end > len(self.plugin_list):
            print("The given plugin indices are not within the range of existing plugin indices")
            return

        are_crosschecks_ok = \
            self._crosscheck_existing_loops(start, end, iterations)

        if are_crosschecks_ok:
            self.iterate_plugin_groups.append(group_new)
            info_str = f"The following loop has been added: start plugin " \
                       f"index {group_new['start_index']}, end plugin index " \
                       f"{group_new['end_index']}, iterations " \
                       f"{group_new['iterations']}"
            print(info_str)

    def _crosscheck_existing_loops(self, start, end, iterations):
        """
        Check requested loop to be added against existing loops for potential
        clashes
        """
        are_crosschecks_ok = False
        list_new_indices = list(range(start, end+1))
        # crosscheck with the existing iterative loops
        if len(self.iterate_plugin_groups) != 0:
            noexactlist = True
            nointersection = True
            for count, group in enumerate(self.iterate_plugin_groups, 1):
                start_int = int(group['start_index'])
                end_int = int(group['end_index'])
                list_existing_indices = list(range(start_int, end_int+1))
                if bool(set(list_new_indices).intersection(list_existing_indices)):
                    # check if the intersection of lists is exact (number of iterations to change)
                    nointersection = False
                    if list_new_indices == list_existing_indices:
                        print(f"The number of iterations of loop group no. {count}, {list_new_indices} has been set to: {iterations}")
                        self.iterate_plugin_groups[count-1]["iterations"] = iterations
                        noexactlist = False
                    else:
                        print(f"The plugins of group no. {count} are already set to be iterative: {set(list_new_indices).intersection(list_existing_indices)}")
            if noexactlist and nointersection:
                are_crosschecks_ok = True
        else:
            are_crosschecks_ok = True

        return are_crosschecks_ok

    def remove_iterate_plugin_groups(self, indices):
        """ Remove elements from self.iterate_plugin_groups """
        if len(indices) == 0:
            # remove all iterative loops in process list
            prompt_str = 'Are you sure you want to remove all iterative ' \
                'loops? [y/N]'
            check = input(prompt_str)
            should_remove_all = check.lower() == 'y'
            if should_remove_all:
                self.clear_iterate_plugin_group_dicts()
                print('All iterative loops have been removed')
            else:
                print('No iterative loops have been removed')
        else:
            # remove specified iterative loops in process list
            sorted_indices = sorted(indices)

            if sorted_indices[0] <= 0:
                print('The iterative loops are indexed starting from 1')
                return

            for i in reversed(sorted_indices):
                try:
                    # convert the one-based index to a zero-based index
                    iterate_group = self.iterate_plugin_groups.pop(i - 1)
                    info_str = f"The following loop has been removed: start " \
                               f"plugin index " \
                               f"{iterate_group['start_index']}, " \
                               f"end plugin index " \
                               f"{iterate_group['end_index']}, iterations " \
                               f"{iterate_group['iterations']}"
                    print(info_str)
                except IndexError:
                    info_str = f"There is no iterative loop with number {i}"
                    print(info_str)

    def clear_iterate_plugin_group_dicts(self):
        """
        Reset the list of dicts representing groups of plugins to iterate over
        """
        self.iterate_plugin_groups = []

    def get_iterate_plugin_group_dicts(self):
        """
        Return the list of dicts representing groups of plugins to iterate over
        """
        return self.iterate_plugin_groups

    def print_iterative_loops(self):
        if len(self.iterate_plugin_groups) == 0:
            print('There are no iterative loops in the current process list')
        else:
            print('Iterative loops in the current process list are:')
            for count, group in enumerate(self.iterate_plugin_groups, 1):
                number = f"({count}) "
                start_str = f"start plugin index: {group['start_index']}"
                end_str = f"end index: {group['end_index']}"
                iterations_str = f"iterations number: {group['iterations']}"
                full_str = number + start_str + ', ' + end_str + ', ' + \
                    iterations_str
                print(full_str)

    def remove_associated_iterate_group_dict(self, pos, direction):
        """
        Remove an iterative loop associated with a plugin index
        """
        operation = 'add' if direction == 1 else 'remove'
        for i, iterate_group in enumerate(self.iterate_plugin_groups):
            if operation == 'remove':
                if iterate_group['start_index'] <= pos and \
                    pos <= iterate_group['end_index']:
                    # remove the loop if the plugin being removed is at any
                    # position within an iterative loop
                    del self.iterate_plugin_groups[i]
                    break
            elif operation == 'add':
                if iterate_group['start_index'] != iterate_group['end_index']:
                    # remove the loop only if the plugin is being added between
                    # the start and end of the loop
                    if iterate_group['start_index'] < pos and \
                        pos <= iterate_group['end_index']:
                        del self.iterate_plugin_groups[i]
                        break

    def check_pos_in_iterative_loop(self, pos):
        """
        Check if the given plugin position is in an iterative loop
        """
        is_in_loop = False
        for iterate_group in self.iterate_plugin_groups:
            if iterate_group['start_index'] <= pos and \
                pos <= iterate_group['end_index']:
                is_in_loop = True
                break

        return is_in_loop

    def __save_iterate_plugin_groups(self, group):
        '''
        Save information regarding the groups of plugins to iterate over
        '''
        for count, iterate_group in enumerate(self.iterate_plugin_groups):
            grp_name = str(count)
            grp = group.create_group(grp_name.encode('ascii'))
            shape = ()  # scalar data
            grp.create_dataset('start'.encode('ascii'), shape, 'i',
                iterate_group['start_index'])
            grp.create_dataset('end'.encode('ascii'), shape, 'i',
                iterate_group['end_index'])
            grp.create_dataset('iterations'.encode('ascii'), shape, 'i',
                iterate_group['iterations'])

    def shift_subsequent_iterative_loops(self, pos, direction):
        """
        Shift all iterative loops occurring after a given plugin position
        """
        # if removing a plugin that is positioned before a loop, the loop should
        # be shifted down by 1; but if removing a plugin that is positioned at
        # the start of the loop, it will be removed instead of shifted (ie, both
        # < or <= work for this case)
        #
        # if adding a plugin that will be positioned before a loop, the loop
        # should be shifted up by 1; also, if adding a plugin to be positioned
        # where the start of a loop currently exists, this should shift the loop
        # up by 1 as well (ie, only <= works for this case, hence the use of <=)
        for iterate_group in self.iterate_plugin_groups:
            if pos <= iterate_group['start_index']:
                self.shift_iterative_loop(iterate_group, direction)

    def shift_range_iterative_loops(self, positions, direction):
        """
        Shift all iterative loops within a range of plugin indices
        """
        for iterate_group in self.iterate_plugin_groups:
            if positions[0] <= iterate_group['start_index'] and \
                iterate_group['end_index'] <= positions[1]:
                self.shift_iterative_loop(iterate_group, direction)

    def shift_iterative_loop(self, iterate_group, direction):
        """
        Shift an iterative loop up or down in the process list, depending on
        whether a plugin was added or removed
        """
        if direction == 1:
            iterate_group['start_index'] += 1
            iterate_group['end_index'] += 1
        elif direction == -1:
            iterate_group['start_index'] -= 1
            iterate_group['end_index'] -= 1
        else:
            err_str = f"Bad direction value given to shift iterative loop: " \
                      f"{direction}"
            raise ValueError(err_str)

    def __save_savu_notes(self, notes):
        """ Save the version number

        :param notes: hdf5 group to save data to
        """
        from savu.version import __version__

        notes["version"] = __version__

    def __populate_plugins_group(self, plugin_group, plugin):
        """Populate the plugin group information which will be saved

        :param plugin_group: Plugin group to save to
        :param plugin: Plugin to be saved
        """
        plugin_group.attrs[NX_CLASS] = "NXnote".encode("ascii")
        required_keys = self._get_plugin_entry_template().keys()
        json_keys = self.__get_json_keys()

        self._save_citations(plugin, plugin_group)

        for key in required_keys:
            # only need to apply dumps if saving in configurator
            if key == "data":
                data = {}
                for k, v in plugin[key].items():
                    #  Replace any missing quotes around variables.
                    data[k] = pu._dumps(v)
            else:
                data = plugin[key]

            # get the string value
            data = json.dumps(data) if key in json_keys else plugin[key]
            # if the data is string it has to be encoded to ascii so that
            # hdf5 can save out the bytes
            if isinstance(data, str):
                data = data.encode("ascii")
            data = np.array([data])
            plugin_group.create_dataset(
                key.encode("ascii"), data.shape, data.dtype, data
            )

    def _get_plugin_group(self, plugins_group, plugin, count):
        """Return the plugin_group, into which the plugin information
        will be saved

        :param plugins_group: Current group to save inside
        :param plugin: Plugin to be saved
        :param count: Order number of the plugin in the process list
        :return: plugin group
        """
        if "pos" in plugin.keys():
            num = int(re.findall(r"\d+", str(plugin["pos"]))[0])
            letter = re.findall("[a-z]", str(plugin["pos"]))
            letter = letter[0] if letter else ""
            group_name = "%i%s" % (num, letter)
        else:
            group_name = str(count)
        return plugins_group.create_group(group_name.encode("ascii"))

    def _add(self, idx, entry):
        self.plugin_list.insert(idx, entry)
        self.__set_loaders_and_savers()

    def _remove(self, idx):
        del self.plugin_list[idx]
        self.__set_loaders_and_savers()

    def _save_citations(self, plugin, group):
        """Save all the citations in the plugin

        :param plugin: dictionary of plugin information
        :param group: Group to save to
        """
        if "tools" in plugin.keys():
            citation_plugin = plugin.get("tools").get_citations()
            if citation_plugin:
                count = 1
                for citation in citation_plugin.values():
                    group_label = f"citation{count}"
                    if (
                        not citation.dependency
                        or self._dependent_citation_used(plugin, citation)
                    ):
                        self._save_citation_group(
                            citation, citation.__dict__, group, group_label
                        )
                        count += 1

    def _dependent_citation_used(self, plugin, citation):
        """Check if any Plugin parameter values match the citation
        dependency requirement.

        :param plugin: dictionary of plugin information
        :param citation: A plugin citation
        :return: bool True if the citation is for a parameter value
            being used inside this plugin
        """
        parameters = plugin["data"]
        for (
            citation_dependent_parameter,
            citation_dependent_value,
        ) in citation.dependency.items():
            current_value = parameters[citation_dependent_parameter]
            if current_value == citation_dependent_value:
                return True
        return False

    def _exec_citations(self, cite, citation):
        """Copy citation dictionary entries onto the citation object

        :param cite: citation dictionary
        :param citation: citation object to update
        """
        for key, value in cite.items():
            setattr(citation, key, value)

    def _save_framework_citations(self, group):
        """Save all the citations in dict

        :param group: Group for nxs file
        """
        framework_cites = fc.get_framework_citations()
        for cite in framework_cites.values():
            label = cite.short_name_article
            del cite.short_name_article
            self._save_citation_group(cite, cite.__dict__, group, label)


    def _save_citation_group(self, citation, cite_dict, group, group_label):
        """Save the citations to the provided group label

        :param citation: Citation object
        :param cite_dict: Citation dictionary
        :param group: Group
        :param group_label: Group label
        :return:
        """
        citation_group = group.require_group(group_label.encode("ascii"))
        self._exec_citations(cite_dict, citation)
        citation.write(citation_group)

    def _get_docstring_info(self, plugin):
        plugin_inst = pu.plugins[plugin]()
        tools = plugin_inst.get_plugin_tools()
        tools._populate_default_parameters()
        return plugin_inst.docstring_info

    # def _byteify(self, input):
    #     if isinstance(input, dict):
    #         return {self._byteify(key): self._byteify(value)
    #                 for key, value in input.items()}
    #     elif isinstance(input, list):
    #         temp = [self._byteify(element) for element in input]
    #         return temp
    #     elif isinstance(input, str):
    #         return input.encode('utf-8')
    #     else:
    #         return input

    def _set_datasets_list(self, plugin):
        in_pData, out_pData = plugin.get_plugin_datasets()
        in_data_list = self._populate_datasets_list(in_pData)
        out_data_list = self._populate_datasets_list(out_pData)
        self.datasets_list.append({'in_datasets': in_data_list,
                                   'out_datasets': out_data_list})

    def _populate_datasets_list(self, data):
        data_list = []
        for d in data:
            name = d.data_obj.get_name()
            pattern = copy.deepcopy(d.get_pattern())
            pattern[list(pattern.keys())[0]]['max_frames_transfer'] = \
                d.meta_data.get('max_frames_transfer')
            pattern[list(pattern.keys())[0]]['transfer_shape'] = \
                d.meta_data.get('transfer_shape')
            data_list.append({'name': name, 'pattern': pattern})
        return data_list

    def _get_datasets_list(self):
        return self.datasets_list

    def _reset_datasets_list(self):
        self.datasets_list = []

    def _get_n_loaders(self):
        return self.n_loaders

    def _get_n_savers(self):
        return self.n_savers

    def _get_loaders_index(self):
        return self.loader_idx

    def _get_savers_index(self):
        return self.saver_idx

    def _get_n_processing_plugins(self):
        return len(self.plugin_list) - self._get_n_loaders()

    def __set_loaders_and_savers(self):
        """Get lists of loader and saver positions within the plugin list and
        set the number of loaders.

        :returns: loader index list and saver index list
        :rtype: list(int(loader)), list(int(saver))
        """
        from savu.plugins.loaders.base_loader import BaseLoader
        from savu.plugins.savers.base_saver import BaseSaver

        loader_idx = []
        saver_idx = []
        self.n_plugins = len(self.plugin_list)

        for i in range(self.n_plugins):
            pid = self.plugin_list[i]["id"]
            bases = inspect.getmro(pu.load_class(pid))
            loader_list = [b for b in bases if b == BaseLoader]
            saver_list = [b for b in bases if b == BaseSaver]
            if loader_list:
                loader_idx.append(i)
            if saver_list:
                saver_idx.append(i)
        self.loader_idx = loader_idx
        self.saver_idx = saver_idx
        self.n_loaders = len(loader_idx)
        self.n_savers = len(saver_idx)

    def _check_loaders(self):
        """Check plugin list starts with a loader."""
        self.__set_loaders_and_savers()
        loaders = self._get_loaders_index()

        if loaders:
            if loaders[0] != 0 or loaders[-1] + 1 != len(loaders):
                raise Exception("All loader plugins must be at the beginning "
                                "of the process list")
        else:
            raise Exception("The first plugin in the process list must be a "
                            "loader plugin.")

    def _add_missing_savers(self, exp):
        """ Add savers for missing datasets. """
        data_names = exp.index["in_data"].keys()
        saved_data = []

        for i in self._get_savers_index():
            saved_data.append(self.plugin_list[i]["data"]["in_datasets"])
        saved_data = set([s for sub_list in saved_data for s in sub_list])

        for name in [data for data in data_names if data not in saved_data]:
            pos = exp.meta_data.get("nPlugin") + 1
            exp.meta_data.set("nPlugin", pos)
            process = {}
            plugin = pu.load_class("savu.plugins.savers.hdf5_saver")()
            ptools = plugin.get_plugin_tools()
            plugin.parameters["in_datasets"] = [name]
            process["name"] = plugin.name
            process["id"] = plugin.__module__
            process["pos"] = str(pos + 1)
            process["data"] = plugin.parameters
            process["active"] = True
            process["param"] = ptools.get_param_definitions()
            process["doc"] = ptools.docstring_info
            process["tools"] = ptools
            self._add(pos + 1, process)

    def _update_datasets(self, plugin_no, data_dict):
        idx = self._get_n_loaders() + plugin_no
        self.plugin_list[idx]["data"].update(data_dict)

    def _get_dataset_flow(self):
        datasets_idx = []
        n_loaders = self._get_n_loaders()
        n_plugins = self._get_n_processing_plugins()
        for i in range(n_loaders, n_loaders + n_plugins):
            datasets_idx.append(self.plugin_list[i]["data"]["out_datasets"])
        return datasets_idx

    def _contains_gpu_processes(self):
        """ Returns True if gpu processes exist in the process list. """
        try:
            from savu.plugins.driver.gpu_plugin import GpuPlugin

            for i in range(self.n_plugins):
                bases = inspect.getmro(
                    pu.load_class(self.plugin_list[i]["id"])
                )
                if GpuPlugin in bases:
                    return True
        except ImportError as ex:
            # ImportError has no 'message' attribute in Python 3, so use str()
            if "pynvml" in str(ex):
                logging.error(
                    "Error while importing GPU dependencies: %s", str(ex)
                )
            else:
                raise

        return False


class Template(object):
    """A class to read and write templates for plugin lists."""

    def __init__(self, plist):
        super(Template, self).__init__()
        self.plist = plist
        self.creating = False

    def _output_template(self, fname, process_fname):
        plist = self.plist.plugin_list
        index = [i for i in range(len(plist)) if plist[i]["active"]]

        local_dict = MetaData(ordered=True)
        global_dict = MetaData(ordered=True)
        local_dict.set(["process_list"], os.path.abspath(process_fname))

        for i in index:
            params = self.__get_template_params(plist[i]["data"], [])
            name = plist[i]["name"]
            for p in params:
                ptype, isyaml, key, value = p
                if isyaml:
                    data_name = isyaml if ptype == "local" else "all"
                    local_dict.set([i + 1, name, data_name, key], value)
                elif ptype == "local":
                    local_dict.set([i + 1, name, key], value)
                else:
                    global_dict.set(["all", name, key], value)

        with open(fname, "w") as stream:
            local_dict.get_dictionary().update(global_dict.get_dictionary())
            yu.dump_yaml(local_dict.get_dictionary(), stream)

    def __get_template_params(self, params, tlist, yaml=False):
        for key, value in params.items():
            if key == "yaml_file":
                yaml_dict = self._get_yaml_dict(value)
                for entry in list(yaml_dict.keys()):
                    self.__get_template_params(
                        yaml_dict[entry]["params"], tlist, yaml=entry
                    )
            value = pu.is_template_param(value)
            if value is not False:
                ptype, value = value
                isyaml = yaml if yaml else False
                tlist.append([ptype, isyaml, key, value])
        return tlist

    def _get_yaml_dict(self, yfile):
        from savu.plugins.loaders.yaml_converter import YamlConverter

        yaml_c = YamlConverter()
        template_check = pu.is_template_param(yfile)
        yfile = template_check[1] if template_check is not False else yfile
        yaml_c.parameters = {"yaml_file": yfile}
        return yaml_c.setup(template=True)

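    # For orientation, a sketch of the .savu template layout consumed by
    # update_process_list() after yu.read_yaml() (plugin and parameter names
    # below are only illustrative, not taken from a real template):
    #
    #   process_list: /path/to/process_list.nxs
    #   1:
    #     NxtomoLoader:
    #       data_path: /entry1/tomo_entry/data/data
    #   all:
    #     AstraReconGpu:
    #       centre_of_rotation: 86.0
    #
    # Integer keys patch the plugin at that position, the "all" key patches
    # every plugin with the given name, and nested (depth-1) values are routed
    # into the loader plugin's "template_param" dictionary.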
    def update_process_list(self, template):
        tdict = yu.read_yaml(template)
        del tdict["process_list"]

        for plugin_no, entry in tdict.items():
            plugin = list(entry.keys())[0]
            for key, value in list(entry.values())[0].items():
                depth = self.dict_depth(value)
                if depth == 1:
                    self._set_param_for_template_loader_plugin(
                        plugin_no, key, value
                    )
                elif depth == 0:
                    if plugin_no == "all":
                        self._set_param_for_all_instances_of_a_plugin(
                            plugin, key, value
                        )
                    else:
                        data = self._get_plugin_data_dict(str(plugin_no))
                        data[key] = value
                else:
                    raise Exception("Template key not recognised.")

    def dict_depth(self, d, depth=0):
        if not isinstance(d, dict) or not d:
            return depth
        return max(self.dict_depth(v, depth + 1) for k, v in d.items())

    def _set_param_for_all_instances_of_a_plugin(self, plugin, param, value):
        # find all plugins with this name and replace the param
        for p in self.plist.plugin_list:
            if p["name"] == plugin:
                p["data"][param] = value

    def _set_param_for_template_loader_plugin(self, plugin_no, data, value):
        param_key = list(value.keys())[0]
        param_val = list(value.values())[0]
        plugin_data = self._get_plugin_data_dict(str(plugin_no))
        pdict = plugin_data["template_param"]
        pdict = defaultdict(dict) if not pdict else pdict
        pdict[data][param_key] = param_val
        # write back in case an empty template_param was replaced above
        plugin_data["template_param"] = pdict

    def _get_plugin_data_dict(self, plugin_no):
        """ input plugin_no as a string """
        plist = self.plist.plugin_list
        index = [plist[i]["pos"] for i in range(len(plist))]
        return plist[index.index(plugin_no)]["data"]


class CitationInformation(object):
    """
    Descriptor of Citation Information for plugins
    """

    def __init__(
        self,
        description,
        bibtex="",
        endnote="",
        doi="",
        short_name_article="",
        dependency="",
    ):
        self.description = description
        self.short_name_article = short_name_article
        self.bibtex = bibtex
        self.endnote = endnote
        self.doi = doi
        self.dependency = dependency
        self.name = self._set_citation_name()
        self.id = self._set_id()

    def _set_citation_name(self):
        """Create a short identifier using the short name of the article
        and the first author
        """
        cite_info = True if self.endnote or self.bibtex else False

        if cite_info and self.short_name_article:
            cite_name = (
                self.short_name_article.title()
                + " by "
                + self._get_first_author()
                + " et al."
            )
        elif cite_info:
            cite_name = (
                self._get_title()
                + " by "
                + self._get_first_author()
                + " et al."
            )
        else:
            # Using the tools class name as the citation name causes
            # overwriting, so fall back to the description
            # module_name = tool_class.__module__.split('.')[-1].replace('_', ' ')
            # cite_name = module_name.split('tools')[0].title()
            cite_name = self.description
        return cite_name

    def _set_citation_id(self, tool_class):
        """Create a short identifier using the bibtex identification"""
        # Remove blank space
        if self.bibtex:
            cite_id = self._get_id()
        else:
            # Set the tools class name as the citation name
            module_name = tool_class.__module__.split(".")[-1]
            cite_id = module_name
        return cite_id

    def _set_id(self):
        """ Retrieve the id from the bibtex """
        cite_id = self.seperate_bibtex("@article{", ",")
        return cite_id

    def _get_first_author(self):
        """ Retrieve the first author name """
        # default to an empty string so a name is always returned
        first_author = ""
        if self.endnote:
            first_author = self.seperate_endnote("%A ")
        elif self.bibtex:
            first_author = self.seperate_bibtex("author={", "}", author=True)
        return first_author

    def _get_title(self):
        """ Retrieve the title """
        # default to an empty string so a title is always returned
        title = ""
        if self.endnote:
            title = self.seperate_endnote("%T ")
        elif self.bibtex:
            title = self.seperate_bibtex("title={", "}")
        return title

    def seperate_endnote(self, seperation_char):
        """Return the string contained between start characters
        and a new line

        :param seperation_char: Character to split the string at
        :return: The string contained between start characters and a new line
        """
        item = self.endnote.partition(seperation_char)[2].split("\n")[0]
        return item

    def seperate_bibtex(self, start_char, end_char, author=False):
        """Return the string contained between provided characters

        :param start_char: Character to split the string at
        :param end_char: Character to end the split at
        :return: The string contained between both characters
        """
        plain_text = doc.remove_new_lines(self.bibtex)
        item = plain_text.partition(start_char)[2].split(end_char)[0]
        if author:
            # Return one author only
            if " and " in item:
                item = item.split(" and ")[0]

        return item

    def write(self, citation_group):
        # classes don't have to be encoded to ASCII
        citation_group.attrs[NX_CLASS] = "NXcite"
        # Valid ascii sequences will be encoded, invalid ones will be
        # preserved as escape sequences
        description_array = np.array([self.description.encode('ascii','backslashreplace')])
        citation_group.create_dataset('description'.encode('ascii'),
                                      description_array.shape,
                                      description_array.dtype,
                                      description_array)
        doi_array = np.array([self.doi.encode('ascii')])
        citation_group.create_dataset('doi'.encode('ascii'),
                                      doi_array.shape,
                                      doi_array.dtype,
                                      doi_array)
        endnote_array = np.array([self.endnote.encode('ascii','backslashreplace')])
        citation_group.create_dataset('endnote'.encode('ascii'),
                                      endnote_array.shape,
                                      endnote_array.dtype,
                                      endnote_array)
        bibtex_array = np.array([self.bibtex.encode('ascii','backslashreplace')])
        citation_group.create_dataset('bibtex'.encode('ascii'),
                                      bibtex_array.shape,
                                      bibtex_array.dtype,
                                      bibtex_array)
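
For reference, a minimal usage sketch of the CitationInformation class defined above; the citation strings are made-up examples rather than real plugin citations:

    cite = CitationInformation(
        description="Example reconstruction citation",
        bibtex="@article{smith2020, author={Smith, A. and Jones, B.}, "
               "title={An example method}}",
        endnote="%T An example method\n%A Smith, A.",
        short_name_article="an example method",
    )
    print(cite.id)    # 'smith2020', parsed from the bibtex key by _set_id()
    print(cite.name)  # 'An Example Method by Smith, A. et al.'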