Test Failed
Pull Request — master (#888)
by Daniil
03:57
created

PluginList.add_iterate_plugin_group()   B

Complexity

Conditions 7

Size

Total Lines 25
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
eloc 17
nop 4
dl 0
loc 25
rs 8
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: plugin_list
17
   :platform: Unix
18
   :synopsis: Contains the PluginList class, which deals with loading and \
19
   saving the plugin list, and the CitationInformation class. An instance is \
20
   held by the MetaData class.
21
22
.. moduleauthor:: Nicola Wadeson <[email protected]>
23
24
"""
25
import ast
26
import copy
27
import inspect
28
import json
29
import logging
30
import os
31
import re
32
from collections import defaultdict
33
34
import h5py
35
import numpy as np
36
37
import savu.data.framework_citations as fc
38
import savu.plugins.docstring_parser as doc
39
import savu.plugins.loaders.utils.yaml_utils as yu
40
import savu.plugins.utils as pu
41
from savu.data.meta_data import MetaData
42
43
# Name of the HDF5 attribute used to record a group's NeXus class (e.g.
# "NXnote", "NXprocess", "NXcollection") when the plugin list is saved.
NX_CLASS = "NX_class"
44
45
46
class PluginList(object):
    """
    The PluginList class handles the plugin list - loading, saving and adding
    citation information for the plugin.

    Iterative loops over groups of plugins are kept in
    ``self.iterate_plugin_groups`` as a list of dicts with the keys
    ``start_index`` and ``end_index`` (one-based plugin indices, the end is
    inclusive) and ``iterations``.
    """

    def __init__(self):
        self.plugin_list = []
        self.n_plugins = None
        self.n_loaders = 0
        self.n_savers = 0
        self.loader_idx = None
        self.saver_idx = None
        self.datasets_list = []
        self.saver_plugin_status = True
        self._template = None
        self.version = None
        self.iterate_plugin_groups = []

    def add_template(self, create=False):
        """Attach a Template to this plugin list.

        :param create: if True, mark the template as being created so that
            _save_plugin_list also writes a ``.savu`` template file.
        """
        self._template = Template(self)
        if create:
            self._template.creating = True

    def _get_plugin_entry_template(self):
        """Return a fresh dict holding the keys every plugin entry must have."""
        template = {"active": True, "name": None, "id": None, "data": None}
        return template

    def __get_json_keys(self):
        """Return the plugin entry keys whose values are saved as JSON."""
        return ["data"]

    def _populate_plugin_list(
        self, filename, active_pass=False, template=False
    ):
        """Populate the plugin list from a nexus file.

        Plugin groups under ``entry/plugin`` are read in the order returned
        by ``pu.sort_alphanum``. Inactive plugins are skipped unless
        ``active_pass`` is True. Iterative loops are then read from
        ``entry/iterate_plugin_groups``.

        :param filename: path of the nexus process list file
        :param active_pass: if True, also load plugins flagged as inactive
        :param template: if truthy, a Template is attached and its
            ``update_process_list`` is called with this value
        :raises ValueError: if a parameter entry cannot be parsed as JSON
        """
        with h5py.File(filename, "r") as plugin_file:
            if "entry/savu_notes/version" in plugin_file:
                self.version = plugin_file["entry/savu_notes/version"][()]
                self._show_process_list_version()

            plugin_group = plugin_file["entry/plugin"]
            self.plugin_list = []
            single_val = ["name", "id", "pos", "active"]
            exclude = ["citation"]
            ordered_pl_keys = pu.sort_alphanum(list(plugin_group.keys()))
            for group in ordered_pl_keys:
                plugin = self._get_plugin_entry_template()
                entry_keys = plugin_group[group].keys()
                # Every key that is neither a single-valued entry nor a
                # citation group holds a JSON-encoded parameter value. Each
                # key is listed once, however many exclusions there are.
                parameters = [
                    k for k in entry_keys
                    if k not in single_val
                    and not any(e in k for e in exclude)
                ]

                if "active" in entry_keys:
                    plugin["active"] = plugin_group[group]["active"][0]

                if plugin['active'] or active_pass:
                    plugin['name'] = plugin_group[group]['name'][0].decode("utf-8")
                    plugin['id'] = plugin_group[group]['id'][0].decode("utf-8")
                    plugin_tools = None
                    try:
                        plugin_class = pu.load_class(plugin["id"])()
                        # Populate the parameters (including those from its
                        # base classes)
                        plugin_tools = plugin_class.get_plugin_tools()
                        if not plugin_tools:
                            raise OSError(
                                f"Tools file not found for {plugin['name']}")
                        plugin_tools._populate_default_parameters()
                    except ImportError:
                        # No plugin class found
                        logging.error(f"No class found for {plugin['name']}")

                    # When the tools could not be loaded, empty placeholders
                    # are stored instead ({} for 'tools' and 'param').
                    plugin['doc'] = plugin_tools.docstring_info if plugin_tools else ""
                    plugin['tools'] = plugin_tools if plugin_tools else {}
                    plugin['param'] = plugin_tools.get_param_definitions() if \
                        plugin_tools else {}
                    plugin['pos'] = group.strip()

                    for param in parameters:
                        try:
                            plugin[param] = json.loads(plugin_group[group][param][0])
                        except ValueError as e:
                            raise ValueError(
                                f"Error: {e}\n Could not parse key '{param}' "
                                f"from group '{group}' as JSON"
                            ) from e
                    self.plugin_list.append(plugin)

            # add info about groups of plugins to iterate over into
            # self.iterate_plugin_groups
            self.clear_iterate_plugin_group_dicts()
            self._load_iterate_plugin_groups(plugin_file, filename)

            if template:
                self.add_template()
                self._template.update_process_list(template)

    def _load_iterate_plugin_groups(self, plugin_file, filename):
        """Append the iterative loops stored in an open process list file to
        self.iterate_plugin_groups.

        Loop groups are saved under the names "0", "1", ... by
        __save_iterate_plugin_groups. They are read back in numeric order so
        that loop numbering survives a save/load round trip with more than
        ten loops; plain name order would put "10" before "2".

        :param plugin_file: open h5py file of the process list
        :param filename: file name, used in the message printed when the
            file holds no iterative loop information
        """
        try:
            iterate_groups = plugin_file['entry/iterate_plugin_groups']
        except KeyError:
            # Older process lists have no iterative loop information.
            err_str = f"Process list file {filename} doesn't have the " \
                      f"iterate_plugin_groups internal hdf5 path"
            print(err_str)
            return

        for key in sorted(iterate_groups.keys(), key=int):
            loop = iterate_groups[key]
            self.iterate_plugin_groups.append({
                'start_index': int(loop['start'][()]),
                'end_index': int(loop['end'][()]),
                'iterations': int(loop['iterations'][()]),
            })

    def _show_process_list_version(self):
        """If the input process list was created using an older version
        of Savu, then alert the user."""
        from savu.version import __version__
        pl_version = float(self.version)
        if float(__version__) > pl_version:
            separator = "*" * 53
            print(separator)
            print(f"*** This process list was created using Savu "
                  f"{pl_version}  ***")
            print(separator)
            print(f"The process list has been updated, the incorrect \n"
                  f"parameter values have been reverted to default. \n"
                  f"Any warnings below point to the problematic parameters.\n"
                  f"Save this process list to save the updated values.")
            print(separator)

    def _save_plugin_list(self, out_filename):
        """Write citations, version notes, plugins and iterative loops to the
        ``entry`` group of a nexus file.

        Existing groups with the same names are overwritten. If a template
        is being created, a ``.savu`` template file is written next to the
        output file.

        :param out_filename: path of the nexus file (opened in append mode)
        """
        with h5py.File(out_filename, "a") as nxs_file:

            entry = nxs_file.require_group("entry")

            self._save_framework_citations(self._overwrite_group(
                entry, 'framework_citations', 'NXcollection'))

            self.__save_savu_notes(self._overwrite_group(
                entry, 'savu_notes', 'NXnote'))

            plugins_group = self._overwrite_group(entry, 'plugin', 'NXprocess')

            # count names the group of any plugin without a 'pos' entry, so
            # it must advance per plugin to keep those names unique
            for count, plugin in enumerate(self.plugin_list, 1):
                plugin_group = self._get_plugin_group(
                    plugins_group, plugin, count
                )
                self.__populate_plugins_group(plugin_group, plugin)

            self.__save_iterate_plugin_groups(self._overwrite_group(
                entry, 'iterate_plugin_groups', 'NXnote'))

        if self._template and self._template.creating:
            fname = os.path.splitext(out_filename)[0] + ".savu"
            self._template._output_template(fname, out_filename)

    def _overwrite_group(self, entry, name, nxclass):
        """(Re)create the group ``name`` inside ``entry``.

        Any existing group of that name is removed first. The new group's
        NX_CLASS attribute is set to ``nxclass``.

        :return: the newly created group
        """
        if name in entry:
            entry.pop(name)
        group = entry.create_group(name.encode("ascii"))
        group.attrs[NX_CLASS] = nxclass.encode("ascii")
        return group

    def add_iterate_plugin_group(self, start, end, iterations):
        """Add an iterative loop to self.iterate_plugin_groups.

        If the requested range exactly matches an existing loop, only that
        loop's number of iterations is updated (see
        _crosscheck_existing_loops). Invalid requests print a message and
        change nothing.

        :param start: one-based index of the first plugin in the loop
        :param end: one-based index of the last plugin in the loop (inclusive)
        :param iterations: number of iterations; must be positive
        """
        if iterations <= 0:
            print("The number of iterations should be larger than zero and nonnegative")
            return

        n_plugins = len(self.plugin_list)
        if not (1 <= start <= n_plugins and 1 <= end <= n_plugins):
            print("The given plugin indices are not within the range of existing plugin indices")
            return
        if start > end:
            # an inverted range would describe a loop containing no plugins
            print("The start plugin index should not be larger than the end "
                  "plugin index")
            return

        if not self._crosscheck_existing_loops(start, end, iterations):
            return

        self.iterate_plugin_groups.append({
            'start_index': start,
            'end_index': end,
            'iterations': iterations
        })
        print(f"The following loop has been added: start plugin "
              f"index {start}, end plugin index {end}, iterations "
              f"{iterations}")

    def _crosscheck_existing_loops(self, start, end, iterations):
        """
        Check requested loop to be added against existing loops for potential
        clashes.

        A loop covering exactly the same plugins has its number of
        iterations set to ``iterations``. Any partial overlap is reported.

        :return: True only if the requested loop overlaps no existing loop,
            i.e. it may be appended as a new loop
        """
        new_indices = list(range(start, end + 1))
        new_set = set(new_indices)
        are_crosschecks_ok = True
        for count, group in enumerate(self.iterate_plugin_groups, 1):
            existing_indices = list(range(int(group['start_index']),
                                          int(group['end_index']) + 1))
            overlap = new_set.intersection(existing_indices)
            if not overlap:
                continue
            are_crosschecks_ok = False
            if new_indices == existing_indices:
                # exact match: only the number of iterations changes
                print(f"The number of iterations of loop group no. {count}, {new_indices} has been set to: {iterations}")
                group["iterations"] = iterations
            else:
                print(f"The plugins of group no. {count} are already set to be iterative: {overlap}")
        return are_crosschecks_ok

    def remove_iterate_plugin_groups(self, indices):
        """Remove iterative loops from self.iterate_plugin_groups.

        :param indices: one-based loop numbers to remove. An empty sequence
            asks the user (via input()) whether all loops should be removed.
            Repeated numbers are removed only once.
        """
        if not indices:
            # remove all iterative loops in process list
            prompt_str = 'Are you sure you want to remove all iterative ' \
                'loops? [y/N]'
            if input(prompt_str).lower() == 'y':
                self.clear_iterate_plugin_group_dicts()
                print('All iterative loops have been removed')
            else:
                print('No iterative loops have been removed')
            return

        # De-duplicate so that e.g. [2, 2] does not also remove the loop
        # that moves into position 2 after the first removal.
        sorted_indices = sorted(set(indices))
        if sorted_indices[0] <= 0:
            print('The iterative loops are indexed starting from 1')
            return

        # Remove from the highest number down so earlier pops do not shift
        # the positions of loops still to be removed.
        for i in reversed(sorted_indices):
            try:
                # convert the one-based index to a zero-based index
                iterate_group = self.iterate_plugin_groups.pop(i - 1)
            except IndexError:
                print(f"There doesn't exist an iterative loop with "
                      f"number {i}")
                continue
            print(f"The following loop has been removed: start "
                  f"plugin index {iterate_group['start_index']}, "
                  f"end plugin index {iterate_group['end_index']}, "
                  f"iterations {iterate_group['iterations']}")

    def clear_iterate_plugin_group_dicts(self):
        """
        Reset the list of dicts representing groups of plugins to iterate over
        """
        self.iterate_plugin_groups = []

    def get_iterate_plugin_group_dicts(self):
        """
        Return the list of dicts representing groups of plugins to iterate over
        """
        return self.iterate_plugin_groups

    def print_iterative_loops(self):
        """Print a numbered summary of every iterative loop."""
        if not self.iterate_plugin_groups:
            print('There are no iterative loops in the current process list')
            return
        print('Iterative loops in the current process list are:')
        for count, group in enumerate(self.iterate_plugin_groups, 1):
            print(f"({count}) start plugin index: {group['start_index']}, "
                  f"end index: {group['end_index']}, "
                  f"iterations number: {group['iterations']}")

    def remove_associated_iterate_group_dict(self, pos, direction):
        """
        Remove the iterative loop affected by adding or removing a plugin.

        :param pos: one-based plugin position being added or removed
        :param direction: 1 when a plugin is added; any other value is
            treated as a removal
        """
        for i, iterate_group in enumerate(self.iterate_plugin_groups):
            start = iterate_group['start_index']
            end = iterate_group['end_index']
            if direction == 1:
                # Adding: a loop is broken only if the new plugin lands after
                # its start and no later than its end. Single-plugin loops
                # are never broken by an insertion.
                breaks_loop = start != end and start < pos <= end
            else:
                # Removing any plugin inside a loop invalidates the loop.
                breaks_loop = start <= pos <= end
            if breaks_loop:
                del self.iterate_plugin_groups[i]
                break

    def check_pos_in_iterative_loop(self, pos):
        """
        Return True if the given plugin position lies in an iterative loop.
        """
        return any(
            group['start_index'] <= pos <= group['end_index']
            for group in self.iterate_plugin_groups
        )

    def __save_iterate_plugin_groups(self, group):
        '''
        Save information regarding the groups of plugins to iterate over.

        Each loop is written to a subgroup named by its zero-based position
        ("0", "1", ...) holding scalar integer datasets 'start', 'end' and
        'iterations'.
        '''
        for count, iterate_group in enumerate(self.iterate_plugin_groups):
            grp_name = str(count)
            grp = group.create_group(grp_name.encode('ascii'))
            shape = ()  # scalar data
            grp.create_dataset('start'.encode('ascii'), shape, 'i',
                iterate_group['start_index'])
            grp.create_dataset('end'.encode('ascii'), shape, 'i',
                iterate_group['end_index'])
            grp.create_dataset('iterations'.encode('ascii'), shape, 'i',
                iterate_group['iterations'])

    def shift_subsequent_iterative_loops(self, pos, direction):
        """
        Shift all iterative loops occurring after a given plugin position
        """
        # if removing a plugin that is positioned before a loop, the loop should
        # be shifted down by 1; but if removing a plugin that is positioned at
        # the start of the loop, it will be removed instead of shifted (ie, both
        # < or <= work for this case)
        #
        # if adding a plugin that will be positioned before a loop, the loop
        # should be shifted up by 1; also, if adding a plugin to be positioned
        # where the start of a loop currently exists, this should shift the loop
        # up by 1 as well (ie, only <= works for this case, hence the use of <=)
        for iterate_group in self.iterate_plugin_groups:
            if pos <= iterate_group['start_index']:
                self.shift_iterative_loop(iterate_group, direction)

    def shift_range_iterative_loops(self, positions, direction):
        """
        Shift all iterative loops lying entirely within the plugin index
        range ``positions[0]``..``positions[1]`` (inclusive).
        """
        for iterate_group in self.iterate_plugin_groups:
            if positions[0] <= iterate_group['start_index'] and \
                iterate_group['end_index'] <= positions[1]:
                self.shift_iterative_loop(iterate_group, direction)

    def shift_iterative_loop(self, iterate_group, direction):
        """
        Shift an iterative loop up or down in the process list, based on if a
        plugin is added or removed.

        :param direction: 1 to shift up, -1 to shift down
        :raises ValueError: for any other direction
        """
        if direction == 1:
            iterate_group['start_index'] += 1
            iterate_group['end_index'] += 1
        elif direction == -1:
            iterate_group['start_index'] -= 1
            iterate_group['end_index'] -= 1
        else:
            err_str = f"Bad direction value given to shift iterative loop: " \
                      f"{direction}"
            raise ValueError(err_str)

    def __save_savu_notes(self, notes):
        """ Save the version number

        :param notes: hdf5 group to save data to
        """
        from savu.version import __version__

        notes["version"] = __version__

    def __populate_plugins_group(self, plugin_group, plugin):
        """Populate the plugin group information which will be saved

        The required keys (active, name, id, data) are each written as a
        one-element dataset. The 'data' parameter dict has every value passed
        through pu._dumps and is then JSON encoded. Strings are ASCII-encoded
        so that hdf5 stores bytes.

        :param plugin_group: Plugin group to save to
        :param plugin: Plugin to be saved
        """
        plugin_group.attrs[NX_CLASS] = "NXnote".encode("ascii")
        required_keys = self._get_plugin_entry_template().keys()
        json_keys = self.__get_json_keys()

        self._save_citations(plugin, plugin_group)

        for key in required_keys:
            if key == "data":
                #  Replace any missing quotes around variables.
                value = {k: pu._dumps(v) for k, v in plugin[key].items()}
            else:
                value = plugin[key]

            if key in json_keys:
                value = json.dumps(value)
            # if the data is string it has to be encoded to ascii so that
            # hdf5 can save out the bytes
            if isinstance(value, str):
                value = value.encode("ascii")
            data = np.array([value])
            plugin_group.create_dataset(
                key.encode("ascii"), data.shape, data.dtype, data
            )

    def _get_plugin_group(self, plugins_group, plugin, count):
        """Return the plugin_group, into which the plugin information
         will be saved

        The group is named from the plugin's 'pos' (leading number plus an
        optional lower-case letter, e.g. "3a"). If the plugin has no 'pos',
        the group is named after ``count``.

        :param plugins_group: Current group to save inside
        :param plugin: Plugin to be saved
        :param count: Order number of the plugin in the process list
        :return: plugin group
        """
        if "pos" in plugin:
            pos = str(plugin["pos"])
            num = int(re.findall(r"\d+", pos)[0])
            letter = re.findall("[a-z]", pos)
            group_name = "%i%s" % (num, letter[0] if letter else "")
        else:
            # str() is required: the name is encoded to ascii below
            group_name = str(count)
        return plugins_group.create_group(group_name.encode("ascii"))

    def _add(self, idx, entry):
        """Insert ``entry`` at ``idx`` and refresh loader/saver bookkeeping."""
        self.plugin_list.insert(idx, entry)
        self.__set_loaders_and_savers()

    def _remove(self, idx):
        """Delete the entry at ``idx`` and refresh loader/saver bookkeeping."""
        del self.plugin_list[idx]
        self.__set_loaders_and_savers()

    def _save_citations(self, plugin, group):
        """Save all the citations in the plugin

        Citations with a dependency are saved only if the dependency is met
        (see _dependent_citation_used). Saved citations are labelled
        citation1, citation2, ...

        :param plugin: dictionary of plugin information
        :param group: Group to save to
        """
        tools = plugin.get("tools")
        # _populate_plugin_list stores {} when a plugin's tools could not be
        # loaded; there are no citations to save in that case.
        if not tools:
            return
        citation_plugin = tools.get_citations()
        if not citation_plugin:
            return
        count = 1
        for citation in citation_plugin.values():
            if citation.dependency and \
                    not self._dependent_citation_used(plugin, citation):
                continue
            self._save_citation_group(
                citation, citation.__dict__, group, f"citation{count}"
            )
            count += 1

    def _dependent_citation_used(self, plugin, citation):
        """Check if any Plugin parameter values match the citation
        dependency requirement.

        :param plugin: dictionary of plugin information
        :param citation: A plugin citation
        :return: bool True if the citation is for a parameter value
            being used inside this plugin
        """
        parameters = plugin["data"]
        return any(
            parameters[dependent_param] == dependent_value
            for dependent_param, dependent_value
            in citation.dependency.items()
        )

    def _exec_citations(self, cite, citation):
        """Copy every entry of ``cite`` onto ``citation`` as an attribute.

        :param cite: citation dictionary
        :param citation: object receiving the attributes
        """
        for key, value in cite.items():
            # setattr replaces the former exec("citation." + key + "= value"),
            # avoiding code execution built from dictionary keys
            setattr(citation, key, value)

    def _save_framework_citations(self, group):
        """Save all the framework citations to ``group``.

        Each citation is saved under its ``short_name_article``, which is
        deleted from the citation object before it is written.

        :param group: Group for nxs file
        """
        framework_cites = fc.get_framework_citations()
        for cite in framework_cites.values():
            label = cite.short_name_article
            del cite.short_name_article
            self._save_citation_group(cite, cite.__dict__, group, label)

    def _save_citation_group(self, citation, cite_dict, group, group_label):
        """Save the citations to the provided group label

        :param citation: Citation object
        :param cite_dict: Citation dictionary
        :param group: Group
        :param group_label: Group label
        """
        citation_group = group.require_group(group_label.encode("ascii"))
        self._exec_citations(cite_dict, citation)
        citation.write(citation_group)

    def _get_docstring_info(self, plugin):
        """Return the docstring information of the named plugin.

        NOTE(review): _populate_plugin_list reads ``docstring_info`` from the
        tools object, while this reads it from the plugin instance; confirm
        which is intended.
        """
        plugin_inst = pu.plugins[plugin]()
        tools = plugin_inst.get_plugin_tools()
        tools._populate_default_parameters()
        return plugin_inst.docstring_info

    def _set_datasets_list(self, plugin):
        """Record the in/out datasets (names and patterns) of ``plugin``."""
        in_pData, out_pData = plugin.get_plugin_datasets()
        in_data_list = self._populate_datasets_list(in_pData)
        out_data_list = self._populate_datasets_list(out_pData)
        self.datasets_list.append({'in_datasets': in_data_list,
                                   'out_datasets': out_data_list})

    def _populate_datasets_list(self, data):
        """Return a list of {'name', 'pattern'} dicts, one per plugin dataset.

        The pattern is deep-copied. Its first entry is annotated with the
        dataset's 'max_frames_transfer' and 'transfer_shape' meta data.
        """
        data_list = []
        for d in data:
            name = d.data_obj.get_name()
            pattern = copy.deepcopy(d.get_pattern())
            first_pattern = pattern[next(iter(pattern))]
            first_pattern['max_frames_transfer'] = \
                d.meta_data.get('max_frames_transfer')
            first_pattern['transfer_shape'] = \
                d.meta_data.get('transfer_shape')
            data_list.append({'name': name, 'pattern': pattern})
        return data_list

    def _get_datasets_list(self):
        return self.datasets_list

    def _reset_datasets_list(self):
        self.datasets_list = []

    def _get_n_loaders(self):
        return self.n_loaders

    def _get_n_savers(self):
        return self.n_savers

    def _get_loaders_index(self):
        return self.loader_idx

    def _get_savers_index(self):
        return self.saver_idx

    def _get_n_processing_plugins(self):
        """Return the number of plugins that are not loaders."""
        return len(self.plugin_list) - self._get_n_loaders()

    def __set_loaders_and_savers(self):
        """Find the loader and saver positions within the plugin list.

        Sets self.n_plugins, self.loader_idx, self.saver_idx,
        self.n_loaders and self.n_savers. A plugin is a loader (saver) if
        BaseLoader (BaseSaver) appears in its class MRO.
        """
        from savu.plugins.loaders.base_loader import BaseLoader
        from savu.plugins.savers.base_saver import BaseSaver

        loader_idx = []
        saver_idx = []
        self.n_plugins = len(self.plugin_list)

        for i, plugin in enumerate(self.plugin_list):
            bases = inspect.getmro(pu.load_class(plugin["id"]))
            if BaseLoader in bases:
                loader_idx.append(i)
            if BaseSaver in bases:
                saver_idx.append(i)
        self.loader_idx = loader_idx
        self.saver_idx = saver_idx
        self.n_loaders = len(loader_idx)
        self.n_savers = len(saver_idx)

    def _check_loaders(self):
        """Check plugin list starts with a loader.

        :raises Exception: if there is no loader, or if the loaders are not
            all contiguous at the start of the list
        """
        self.__set_loaders_and_savers()
        loaders = self._get_loaders_index()

        if loaders:
            if loaders[0] != 0 or loaders[-1] + 1 != len(loaders):
                raise Exception("All loader plugins must be at the beginning "
                                "of the process list")
        else:
            raise Exception("The first plugin in the process list must be a "
                            "loader plugin.")

    def _add_missing_savers(self, exp):
        """ Add savers for missing datasets.

        An hdf5_saver entry is inserted for every dataset in
        ``exp.index["in_data"]`` that is not an input dataset of an existing
        saver. The experiment's "nPlugin" meta data is incremented for each
        saver added.
        """
        data_names = exp.index["in_data"].keys()
        saved_data = {
            dname
            for i in self._get_savers_index()
            for dname in self.plugin_list[i]["data"]["in_datasets"]
        }

        for name in [data for data in data_names if data not in saved_data]:
            pos = exp.meta_data.get("nPlugin") + 1
            exp.meta_data.set("nPlugin", pos)
            process = {}
            plugin = pu.load_class("savu.plugins.savers.hdf5_saver")()
            ptools = plugin.get_plugin_tools()
            plugin.parameters["in_datasets"] = [name]
            process["name"] = plugin.name
            process["id"] = plugin.__module__
            process["pos"] = str(pos + 1)
            process["data"] = plugin.parameters
            process["active"] = True
            process["param"] = ptools.get_param_definitions()
            process["doc"] = ptools.docstring_info
            process["tools"] = ptools
            self._add(pos + 1, process)

    def _update_datasets(self, plugin_no, data_dict):
        """Update the 'data' of processing plugin ``plugin_no``.

        ``plugin_no`` is counted from the first non-loader plugin.
        """
        idx = self._get_n_loaders() + plugin_no
        self.plugin_list[idx]["data"].update(data_dict)

    def _get_dataset_flow(self):
        """Return the 'out_datasets' of each processing (non-loader) plugin."""
        n_loaders = self._get_n_loaders()
        n_plugins = self._get_n_processing_plugins()
        return [
            self.plugin_list[i]["data"]["out_datasets"]
            for i in range(n_loaders, n_loaders + n_plugins)
        ]

    def _contains_gpu_processes(self):
        """ Returns True if gpu processes exist in the process list.

        An ImportError mentioning pynvml is logged and treated as "no GPU
        processes"; any other ImportError is re-raised.
        """
        try:
            from savu.plugins.driver.gpu_plugin import GpuPlugin

            # iterate the list itself: n_plugins is None until
            # __set_loaders_and_savers has run
            for plugin in self.plugin_list:
                bases = inspect.getmro(pu.load_class(plugin["id"]))
                if GpuPlugin in bases:
                    return True
        except ImportError as ex:
            # Python 3 exceptions have no .message attribute; use str(ex)
            if "pynvml" in str(ex):
                logging.error(
                    "Error while importing GPU dependencies: %s", ex
                )
            else:
                raise

        return False
716
717
718
class Template(object):
719
    """A class to read and write templates for plugin lists."""
720
721
    def __init__(self, plist):
722
        super(Template, self).__init__()
723
        self.plist = plist
724
        self.creating = False
725
726
    def _output_template(self, fname, process_fname):
727
        plist = self.plist.plugin_list
728
        index = [i for i in range(len(plist)) if plist[i]["active"]]
729
730
        local_dict = MetaData(ordered=True)
731
        global_dict = MetaData(ordered=True)
732
        local_dict.set(["process_list"], os.path.abspath(process_fname))
733
734
        for i in index:
735
            params = self.__get_template_params(plist[i]["data"], [])
736
            name = plist[i]["name"]
737
            for p in params:
738
                ptype, isyaml, key, value = p
739
                if isyaml:
740
                    data_name = isyaml if ptype == "local" else "all"
741
                    local_dict.set([i + 1, name, data_name, key], value)
742
                elif ptype == "local":
743
                    local_dict.set([i + 1, name, key], value)
744
                else:
745
                    global_dict.set(["all", name, key], value)
746
747
        with open(fname, "w") as stream:
748
            local_dict.get_dictionary().update(global_dict.get_dictionary())
749
            yu.dump_yaml(local_dict.get_dictionary(), stream)
750
751
    def __get_template_params(self, params, tlist, yaml=False):
752
        for key, value in params.items():
753
            if key == "yaml_file":
754
                yaml_dict = self._get_yaml_dict(value)
755
                for entry in list(yaml_dict.keys()):
756
                    self.__get_template_params(
757
                        yaml_dict[entry]["params"], tlist, yaml=entry
758
                    )
759
            value = pu.is_template_param(value)
760
            if value is not False:
761
                ptype, value = value
762
                isyaml = yaml if yaml else False
763
                tlist.append([ptype, isyaml, key, value])
764
        return tlist
765
766
    def _get_yaml_dict(self, yfile):
767
        from savu.plugins.loaders.yaml_converter import YamlConverter
768
769
        yaml_c = YamlConverter()
770
        template_check = pu.is_template_param(yfile)
771
        yfile = template_check[1] if template_check is not False else yfile
772
        yaml_c.parameters = {"yaml_file": yfile}
773
        return yaml_c.setup(template=True)
774
775
    def update_process_list(self, template):
776
        tdict = yu.read_yaml(template)
777
        del tdict["process_list"]
778
779
        for plugin_no, entry in tdict.items():
780
            plugin = list(entry.keys())[0]
781
            for key, value in list(entry.values())[0].items():
782
                depth = self.dict_depth(value)
783
                if depth == 1:
784
                    self._set_param_for_template_loader_plugin(
785
                        plugin_no, key, value
786
                    )
787
                elif depth == 0:
788
                    if plugin_no == "all":
789
                        self._set_param_for_all_instances_of_a_plugin(
790
                            plugin, key, value
791
                        )
792
                    else:
793
                        data = self._get_plugin_data_dict(str(plugin_no))
794
                        data[key] = value
795
                else:
796
                    raise Exception("Template key not recognised.")
797
798
    def dict_depth(self, d, depth=0):
799
        if not isinstance(d, dict) or not d:
800
            return depth
801
        return max(self.dict_depth(v, depth + 1) for k, v in d.items())
802
803
    def _set_param_for_all_instances_of_a_plugin(self, plugin, param, value):
804
        # find all plugins with this name and replace the param
805
        for p in self.plist.plugin_list:
806
            if p["name"] == plugin:
807
                p["data"][param] = value
808
809
    def _set_param_for_template_loader_plugin(self, plugin_no, data, value):
810
        param_key = list(value.keys())[0]
811
        param_val = list(value.values())[0]
812
        pdict = self._get_plugin_data_dict(str(plugin_no))["template_param"]
813
        pdict = defaultdict(dict) if not pdict else pdict
814
        pdict[data][param_key] = param_val
815
816
    def _get_plugin_data_dict(self, plugin_no):
817
        """ input plugin_no as a string """
818
        plist = self.plist.plugin_list
819
        index = [plist[i]["pos"] for i in range(len(plist))]
820
        return plist[index.index(plugin_no)]["data"]
821
822
823
class CitationInformation(object):
824
    """
825
    Descriptor of Citation Information for plugins
826
    """
827
828
    def __init__(
829
        self,
830
        description,
831
        bibtex="",
832
        endnote="",
833
        doi="",
834
        short_name_article="",
835
        dependency="",
836
    ):
837
        self.description = description
838
        self.short_name_article = short_name_article
839
        self.bibtex = bibtex
840
        self.endnote = endnote
841
        self.doi = doi
842
        self.dependency = dependency
843
        self.name = self._set_citation_name()
844
        self.id = self._set_id()
845
846
    def _set_citation_name(self):
847
        """Create a short identifier using the short name of the article
848
        and the first author
849
        """
850
        cite_info = True if self.endnote or self.bibtex else False
851
852
        if cite_info and self.short_name_article:
853
            cite_name = (
854
                self.short_name_article.title()
855
                + " by "
856
                + self._get_first_author()
857
                + " et al."
858
            )
859
        elif cite_info:
860
            cite_name = (
861
                self._get_title()
862
                + " by "
863
                + self._get_first_author()
864
                + " et al."
865
            )
866
        else:
867
            # Set the tools class name as the citation name causes overwriting
868
            # module_name = tool_class.__module__.split('.')[-1].replace('_', ' ')
869
            # cite_name = module_name.split('tools')[0].title()
870
            cite_name = self.description
871
        return cite_name
872
873
    def _set_citation_id(self, tool_class):
874
        """Create a short identifier using the bibtex identification"""
875
        # Remove blank space
876
        if self.bibtex:
877
            cite_id = self._get_id()
878
        else:
879
            # Set the tools class name as the citation name
880
            module_name = tool_class.__module__.split(".")[-1]
881
            cite_id = module_name
882
        return cite_id
883
884
    def _set_id(self):
885
        """ Retrieve the id from the bibtex """
886
        cite_id = self.seperate_bibtex("@article{", ",")
887
        return cite_id
888
889
    def _get_first_author(self):
890
        """ Retrieve the first author name """
891
        if self.endnote:
892
            first_author = self.seperate_endnote("%A ")
893
        elif self.bibtex:
894
            first_author = self.seperate_bibtex("author={", "}", author=True)
895
        return first_author
0 ignored issues
show
introduced by
The variable first_author does not seem to be defined for all execution paths.
Loading history...
896
897
    def _get_title(self):
898
        """ Retrieve the title """
899
        if self.endnote:
900
            title = self.seperate_endnote("%T ")
901
        elif self.bibtex:
902
            title = self.seperate_bibtex("title={", "}")
903
        return title
0 ignored issues
show
introduced by
The variable title does not seem to be defined for all execution paths.
Loading history...
904
905
    def seperate_endnote(self, seperation_char):
906
        """Return the string contained between start characters
907
        and a new line
908
909
        :param seperation_char: Character to split the string at
910
        :return: The string contained between start characters and a new line
911
        """
912
        item = self.endnote.partition(seperation_char)[2].split("\n")[0]
913
        return item
914
915
    def seperate_bibtex(self, start_char, end_char, author=False):
916
        """Return the string contained between provided characters
917
918
        :param start_char: Character to split the string at
919
        :param end_char: Character to end the split at
920
        :return: The string contained between both characters
921
        """
922
        plain_text = doc.remove_new_lines(self.bibtex)
923
        item = plain_text.partition(start_char)[2].split(end_char)[0]
924
        if author:
925
            # Return one author only
926
            if " and " in item:
927
                item = item.split(" and ")[0]
928
929
        return item
930
931
    def write(self, citation_group):
932
        # classes don't have to be encoded to ASCII
933
        citation_group.attrs[NX_CLASS] = "NXcite"
934
        # Valid ascii sequences will be encoded, invalid ones will be
935
        # preserved as escape sequences
936
        description_array = np.array([self.description.encode('ascii','backslashreplace')])
937
        citation_group.create_dataset('description'.encode('ascii'),
938
                                      description_array.shape,
939
                                      description_array.dtype,
940
                                      description_array)
941
        doi_array = np.array([self.doi.encode('ascii')])
942
        citation_group.create_dataset('doi'.encode('ascii'),
943
                                      doi_array.shape,
944
                                      doi_array.dtype,
945
                                      doi_array)
946
        endnote_array = np.array([self.endnote.encode('ascii','backslashreplace')])
947
        citation_group.create_dataset('endnote'.encode('ascii'),
948
                                      endnote_array.shape,
949
                                      endnote_array.dtype,
950
                                      endnote_array)
951
        bibtex_array = np.array([self.bibtex.encode('ascii','backslashreplace')])
952
        citation_group.create_dataset('bibtex'.encode('ascii'),
953
                                      bibtex_array.shape,
954
                                      bibtex_array.dtype,
955
                                      bibtex_array)
956