Test Failed
Pull Request — master (#708)
by Daniil
03:33
created

savu.data.plugin_list   F

Complexity

Total Complexity 108

Size/Duplication

Total Lines 530
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 377
dl 0
loc 530
rs 2
c 0
b 0
f 0
wmc 108

40 Methods

Rating   Name   Duplication   Size   Complexity  
A Template._get_plugin_data_dict() 0 5 1
A PluginList._get_n_loaders() 0 2 1
A PluginList.add_template() 0 4 2
A PluginList._add() 0 3 1
A PluginList._check_loaders() 0 12 4
A PluginList._get_savers_index() 0 2 1
A PluginList._get_docstring_info() 0 4 1
B Template.__get_template_params() 0 13 6
A Template.__init__() 0 4 1
A PluginList._get_plugin_entry_template() 0 9 1
A PluginList._save_framework_citations() 0 11 2
A PluginList._contains_gpu_processes() 0 16 5
A PluginList._overwrite_group() 0 6 2
A PluginList._update_datasets() 0 4 1
A PluginList._set_datasets_list() 0 6 1
A Template._set_param_for_template_loader_plugin() 0 6 2
A PluginList._populate_datasets_list() 0 11 2
B Template._output_template() 0 24 7
A PluginList._get_n_savers() 0 2 1
B Template.update_process_list() 0 20 6
A PluginList.__init__() 0 11 1
A PluginList._get_dataset_flow() 0 7 2
A PluginList._reset_datasets_list() 0 2 1
C PluginList._populate_plugin_list() 0 37 9
A Template._get_yaml_dict() 0 7 2
A PluginList._add_missing_savers() 0 22 3
A Template.dict_depth() 0 4 3
A PluginList._get_loaders_index() 0 2 1
C PluginList.__populate_plugins_group() 0 31 9
A PluginList._save_plugin_list() 0 20 5
A PluginList._get_n_processing_plugins() 0 2 1
A Template._set_param_for_all_instances_of_a_plugin() 0 5 3
A PluginList._get_datasets_list() 0 2 1
A CitationInformation.write() 0 23 1
A PluginList.__set_loaders_and_savers() 0 26 4
B PluginList.__dumps() 0 30 8
A PluginList.__save_savu_notes() 0 3 1
A PluginList._remove() 0 3 1
A PluginList._output_plugin_citations() 0 6 3
A PluginList.__get_json_keys() 0 2 1

How to fix   Complexity   

Complexity

Complex modules like savu.data.plugin_list often do a lot of different things; here most of the methods (and most of the complexity) belong to the PluginList class. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: plugin_list
17
   :platform: Unix
18
   :synopsis: Contains the PluginList class, which deals with loading and \
19
   saving the plugin list, and the CitationInformation class. An instance is \
20
   held by the MetaData class.
21
22
.. moduleauthor:: Nicola Wadeson <[email protected]>
23
24
"""
25
import ast
26
import copy
27
import inspect
28
import json
29
import logging
30
import os
31
import re
32
from collections import defaultdict
33
34
import h5py
35
import numpy as np
36
import yaml
37
38
import savu.data.framework_citations as fc
39
import savu.plugins.loaders.utils.yaml_utils as yu
40
import savu.plugins.utils as pu
41
from savu.data.meta_data import MetaData
42
43
# Name of the HDF5 attribute used throughout this module to tag each group
# with its NeXus class (e.g. 'NXprocess', 'NXnote', 'NXcollection', 'NXcite').
NX_CLASS = 'NX_class'
44
45
46
class PluginList(object):
    """
    The PluginList class handles the plugin list - loading, saving and adding
    citation information for the plugin
    """

    def __init__(self):
        # List of plugin entry dicts (keys as in _get_plugin_entry_template,
        # plus 'pos' once loaded from file).
        self.plugin_list = []
        # Loader/saver bookkeeping, refreshed by __set_loaders_and_savers().
        self.n_plugins = None
        self.n_loaders = 0
        self.n_savers = 0
        self.loader_idx = None
        self.saver_idx = None
        # One {'in_datasets': ..., 'out_datasets': ...} dict is appended per
        # call to _set_datasets_list().
        self.datasets_list = []
        self.saver_plugin_status = True
        self._template = None
        # Savu version read from 'entry/savu_notes/version', if present.
        self.version = None

    def add_template(self, create=False):
        """ Attach a Template to this plugin list.

        :param bool create: if True, _save_plugin_list() also writes a
            '.savu' template file next to the saved process list.
        """
        self._template = Template(self)
        if create:
            self._template.creating = True

    def _get_plugin_entry_template(self):
        """ Return a fresh dict containing the keys every plugin entry has.
        """
        template = {'active': True,
                    'name': None,
                    'id': None,
                    'desc': None,
                    'data': None,
                    'user': [],
                    'hide': []}
        return template

    def __get_json_keys(self):
        """ Keys whose values are stored as JSON strings in the nexus file.
        """
        return ['data', 'desc', 'user', 'hide']

    def _populate_plugin_list(self, filename, active_pass=False,
                              template=False):
        """ Populate the plugin list from a nexus file.

        :param str filename: path of the nexus (HDF5) process list file.
        :param bool active_pass: if True, inactive plugins are loaded too.
        :param template: optional template file whose values are applied to
            the loaded list via Template.update_process_list().
        :raises ValueError: if a parameter entry cannot be parsed as JSON.
        """
        single_val = ['name', 'id', 'pos', 'active']
        exclude = ['citation']

        # 'with' guarantees the file is closed even if parsing fails.
        with h5py.File(filename, 'r') as plugin_file:
            if 'entry/savu_notes/version' in plugin_file:
                self.version = plugin_file['entry/savu_notes/version'][()]

            plugin_group = plugin_file['entry/plugin']
            self.plugin_list = []
            for group in plugin_group.keys():
                entry = plugin_group[group]
                plugin = self._get_plugin_entry_template()
                entry_keys = entry.keys()
                # Everything except the scalar keys and citation sub-groups
                # holds a JSON-encoded value.
                parameters = [k for k in entry_keys if k not in single_val
                              and all(e not in k for e in exclude)]

                if 'active' in entry_keys:
                    plugin['active'] = entry['active'][0]

                if not (plugin['active'] or active_pass):
                    continue

                plugin['name'] = entry['name'][0].decode("utf-8")
                plugin['id'] = entry['id'][0].decode("utf-8")
                # Group names are padded when saved (_plugin_group_name).
                plugin['pos'] = group.strip()

                for param in parameters:
                    try:
                        plugin[param] = json.loads(entry[param][0])
                    except ValueError as e:
                        raise ValueError(
                            f"Error: {e}\n Could not parse key '{param}' "
                            f"from group '{group}' as JSON") from e
                self.plugin_list.append(plugin)

        if template:
            self.add_template()
            self._template.update_process_list(template)

    def _save_plugin_list(self, out_filename):
        """ Write framework citations, savu notes and the plugin list to
        'entry' in a nexus file, replacing any previous copies.

        :param str out_filename: nexus file to write (opened in append mode).
        """
        with h5py.File(out_filename, 'a') as nxs_file:

            entry = nxs_file.require_group('entry')

            self._save_framework_citations(self._overwrite_group(
                entry, 'framework_citations', 'NXcollection'))

            self.__save_savu_notes(self._overwrite_group(
                entry, 'savu_notes', 'NXnote'))

            plugins_group = self._overwrite_group(entry, 'plugin', 'NXprocess')

            # count names the groups of plugins that have no 'pos' key; it
            # must advance per plugin, otherwise those group names collide.
            for count, plugin in enumerate(self.plugin_list, start=1):
                self.__populate_plugins_group(plugins_group, plugin, count)

        if self._template and self._template.creating:
            fname = os.path.splitext(out_filename)[0] + '.savu'
            self._template._output_template(fname, out_filename)

    def _overwrite_group(self, entry, name, nxclass):
        """ (Re)create group ``name`` under ``entry`` tagged with ``nxclass``,
        deleting any existing group of that name first. """
        if name in entry:
            entry.pop(name)
        group = entry.create_group(name.encode("ascii"))
        group.attrs[NX_CLASS] = nxclass.encode("ascii")
        return group

    def __save_savu_notes(self, notes):
        """ Record the running savu version in the notes group. """
        from savu.version import __version__
        notes['version'] = __version__

    @staticmethod
    def _plugin_group_name(plugin, count):
        """ Return the HDF5 group name used for a plugin entry.

        With a 'pos' such as '3a' the number is right-aligned in 4 characters
        followed by a 1-character letter slot ('   3a'; '   3 ' if there is
        no letter). Without 'pos', ``count`` is right-aligned in 4 characters.
        """
        if 'pos' in plugin:
            num = int(re.findall(r'\d+', plugin['pos'])[0])
            letter = re.findall('[a-z]', plugin['pos'])
            letter = letter[0] if letter else ""
            return "%*i%*s" % (4, num, 1, letter)
        return "%*i" % (4, count)

    def __populate_plugins_group(self, plugins_group, plugin, count):
        """ Write one plugin entry (and its citations) as an NXnote group. """
        group_name = self._plugin_group_name(plugin, count)
        plugin_group = plugins_group.create_group(group_name.encode("ascii"))

        plugin_group.attrs[NX_CLASS] = 'NXnote'.encode('ascii')
        required_keys = self._get_plugin_entry_template().keys()
        json_keys = self.__get_json_keys()

        if plugin.get('cite') is not None:
            self._output_plugin_citations(plugin['cite'], plugin_group)

        for key in required_keys:
            data = plugin[key]
            # only need to apply dumps if saving in configurator
            if key == 'data':
                data = self.__dumps(data)
            if key in json_keys:
                data = json.dumps(data)
            # strings must be encoded to ascii so that hdf5 can save the bytes
            if isinstance(data, str):
                data = data.encode("ascii")
            data = np.array([data])
            plugin_group.create_dataset(key.encode('ascii'), data.shape,
                                        data.dtype, data)

    def __dumps(self, data_dict):
        """ Replace any missing quotes around variables.

        Each string value is converted, in place, by the first parser that
        accepts it: ast.literal_eval, yaml (SafeLoader), then either a
        recursive dict parse or pu.parse_config_string.

        :raises Exception: if no parser accepts a value that is not a
            ';'-separated list.
        """
        for key, val in data_dict.items():
            if not isinstance(val, str):
                continue
            try:
                data_dict[key] = ast.literal_eval(val)
                continue
            except Exception:
                pass
            try:
                data_dict[key] = yaml.load(val, Loader=yaml.SafeLoader)
                continue
            except Exception:
                pass
            try:
                isdict = re.findall(r"[\{\}]+", val)
                if isdict:
                    # Quote list values so yaml keeps them as strings; the
                    # recursive call then parses each one.
                    val = val.replace("[", "'[").replace("]", "]'")
                    # An explicit Loader is required by PyYAML >= 6; without
                    # it this branch always failed.
                    data_dict[key] = self.__dumps(
                        yaml.load(val, Loader=yaml.SafeLoader))
                else:
                    data_dict[key] = pu.parse_config_string(val)
            except Exception as err:
                # ';'-separated values are left untouched (parameter tuning
                # with lists).
                if len(val.split(';')) <= 1:
                    raise Exception("Invalid string %s" % val) from err
        return data_dict

    def _add(self, idx, entry):
        """ Insert a plugin entry at ``idx`` and refresh loader/saver info. """
        self.plugin_list.insert(idx, entry)
        self.__set_loaders_and_savers()

    def _remove(self, idx):
        """ Delete the plugin entry at ``idx`` and refresh loader/saver info.
        """
        del self.plugin_list[idx]
        self.__set_loaders_and_savers()

    def _output_plugin_citations(self, citations, group):
        """ Write one citation object, or a list of them, under ``group``. """
        if not isinstance(citations, list):
            citations = [citations]
        for cite in citations:
            citation_group = group.create_group(cite.name.encode("ascii"))
            cite.write(citation_group)

    def _save_framework_citations(self, group):
        """ Write every framework citation (fc.get_framework_citations()) as
        a sub-group of ``group``. """
        for cite in fc.get_framework_citations():
            citation_group = group.require_group(cite['name'].encode("ascii"))
            citation = CitationInformation()
            citation.name = cite["name"]
            citation.description = cite["description"]
            citation.bibtex = cite["bibtex"]
            citation.endnote = cite["endnote"]

            citation.write(citation_group)

    def _get_docstring_info(self, plugin):
        """ Instantiate ``plugin`` (a key of pu.plugins) with default
        parameters and return its docstring_info. """
        plugin_inst = pu.plugins[plugin]()
        plugin_inst._populate_default_parameters()
        return plugin_inst.docstring_info

    def _set_datasets_list(self, plugin):
        """ Append the in/out dataset descriptions of ``plugin``. """
        in_pData, out_pData = plugin.get_plugin_datasets()
        in_data_list = self._populate_datasets_list(in_pData)
        out_data_list = self._populate_datasets_list(out_pData)
        self.datasets_list.append({'in_datasets': in_data_list,
                                   'out_datasets': out_data_list})

    def _populate_datasets_list(self, data):
        """ Describe each plugin dataset as {'name': ..., 'pattern': ...}.

        The pattern is a deep copy with 'max_frames_transfer' and
        'transfer_shape' (from the dataset meta data) added to its first
        entry.
        """
        data_list = []
        for d in data:
            name = d.data_obj.get_name()
            pattern = copy.deepcopy(d.get_pattern())
            first = pattern[next(iter(pattern))]
            first['max_frames_transfer'] = \
                d.meta_data.get('max_frames_transfer')
            first['transfer_shape'] = d.meta_data.get('transfer_shape')
            data_list.append({'name': name, 'pattern': pattern})
        return data_list

    def _get_datasets_list(self):
        return self.datasets_list

    def _reset_datasets_list(self):
        self.datasets_list = []

    def _get_n_loaders(self):
        return self.n_loaders

    def _get_n_savers(self):
        return self.n_savers

    def _get_loaders_index(self):
        return self.loader_idx

    def _get_savers_index(self):
        return self.saver_idx

    def _get_n_processing_plugins(self):
        """ Number of plugins that are not loaders. """
        return len(self.plugin_list) - self._get_n_loaders()

    def __set_loaders_and_savers(self):
        """ Record loader and saver positions within the plugin list.

        Sets self.n_plugins, self.loader_idx, self.saver_idx (lists of int
        positions), self.n_loaders and self.n_savers. Returns None.
        """
        from savu.plugins.loaders.base_loader import BaseLoader
        from savu.plugins.savers.base_saver import BaseSaver
        loader_idx = []
        saver_idx = []
        self.n_plugins = len(self.plugin_list)

        for i, plugin in enumerate(self.plugin_list):
            bases = inspect.getmro(pu.load_class(plugin['id']))
            if BaseLoader in bases:
                loader_idx.append(i)
            if BaseSaver in bases:
                saver_idx.append(i)
        self.loader_idx = loader_idx
        self.saver_idx = saver_idx
        self.n_loaders = len(loader_idx)
        self.n_savers = len(saver_idx)

    def _check_loaders(self):
        """ Check the plugin list starts with one or more contiguous loaders.

        (Savers are not checked here.)

        :raises Exception: if there is no loader, or a loader appears after
            a non-loader plugin.
        """
        self.__set_loaders_and_savers()
        loaders = self._get_loaders_index()

        if not loaders:
            raise Exception("The first plugin in the plugin list must be a "
                            "loader plugin.")
        # Loaders are contiguous from 0 iff the last index is len - 1.
        if loaders[0] != 0 or loaders[-1] + 1 != len(loaders):
            raise Exception("All loader plugins must be at the beginning "
                            "of the plugin list")

    def _add_missing_savers(self, exp):
        """ Add an hdf5 saver for each input dataset of ``exp`` that no
        existing saver already receives. """
        data_names = exp.index['in_data'].keys()
        saved_data = {name for i in self._get_savers_index()
                      for name in self.plugin_list[i]['data']['in_datasets']}

        for name in [d for d in data_names if d not in saved_data]:
            pos = exp.meta_data.get('nPlugin') + 1
            exp.meta_data.set('nPlugin', pos)
            plugin = pu.load_class('savu.plugins.savers.hdf5_saver')()
            plugin.parameters['in_datasets'] = [name]
            process = {'name': plugin.name,
                       'id': plugin.__module__,
                       'pos': str(pos + 1),
                       'data': plugin.parameters,
                       'active': True,
                       'desc': plugin.parameters_desc}
            self._add(pos + 1, process)

    def _update_datasets(self, plugin_no, data_dict):
        """ Merge ``data_dict`` into the 'data' of processing plugin
        ``plugin_no`` (0-based, counted after the loaders). """
        idx = self._get_n_loaders() + plugin_no
        self.plugin_list[idx]['data'].update(data_dict)

    def _get_dataset_flow(self):
        """ Return the 'out_datasets' of each processing (non-loader) plugin,
        in order. """
        n_loaders = self._get_n_loaders()
        n_plugins = self._get_n_processing_plugins()
        return [plugin['data']['out_datasets'] for plugin in
                self.plugin_list[n_loaders:n_loaders + n_plugins]]

    def _contains_gpu_processes(self):
        """ Returns True if gpu processes exist in the process list.

        An ImportError mentioning 'pynvml' is logged and treated as "no GPU
        processes"; any other ImportError is re-raised.
        """
        try:
            from savu.plugins.driver.gpu_plugin import GpuPlugin
            # Iterate the list itself: self.n_plugins is only refreshed by
            # __set_loaders_and_savers() and may be None or stale.
            for plugin in self.plugin_list:
                bases = inspect.getmro(pu.load_class(plugin['id']))
                if GpuPlugin in bases:
                    return True
        except ImportError as ex:
            # Python 3 exceptions have no '.message' attribute.
            if "pynvml" in str(ex):
                logging.error('Error while importing GPU dependencies: %s',
                              ex)
            else:
                raise

        return False
392
393
394
class Template(object):
    """ A class to read and write templates for plugin lists.
    """

    def __init__(self, plist):
        super(Template, self).__init__()
        # The owning PluginList; its plugin_list is read and modified.
        self.plist = plist
        # When True, PluginList._save_plugin_list also writes a template.
        self.creating = False

    def _output_template(self, fname, process_fname):
        """ Write the template parameters of all active plugins as yaml.

        :param str fname: path of the template file to write.
        :param str process_fname: process list the template belongs to;
            stored as an absolute path under the 'process_list' key.
        """
        local_dict = MetaData(ordered=True)
        global_dict = MetaData(ordered=True)
        local_dict.set(['process_list'], os.path.abspath(process_fname))

        for i, plugin in enumerate(self.plist.plugin_list):
            if not plugin['active']:
                continue
            name = plugin['name']
            # Local entries are keyed by the 1-based position in the list.
            for ptype, isyaml, key, value in self.__get_template_params(
                    plugin['data'], []):
                if isyaml:
                    data_name = isyaml if ptype == 'local' else 'all'
                    local_dict.set([i + 1, name, data_name, key], value)
                elif ptype == 'local':
                    local_dict.set([i + 1, name, key], value)
                else:
                    global_dict.set(['all', name, key], value)

        with open(fname, 'w') as stream:
            local_dict.get_dictionary().update(global_dict.get_dictionary())
            yu.dump_yaml(local_dict.get_dictionary(), stream)

    def __get_template_params(self, params, tlist, yaml=False):
        """ Collect template parameters from a plugin parameter dict.

        Appends [ptype, isyaml, key, value] to ``tlist`` for every value that
        pu.is_template_param() recognises. ``isyaml`` is the yaml entry name
        for parameters found in the file named by a 'yaml_file' parameter,
        otherwise False.

        :returns: ``tlist`` (extended in place).
        """
        for key, value in params.items():
            if key == 'yaml_file':
                yaml_dict = self._get_yaml_dict(value)
                for entry in list(yaml_dict.keys()):
                    self.__get_template_params(
                        yaml_dict[entry]['params'], tlist, yaml=entry)
            value = pu.is_template_param(value)
            if value is not False:
                ptype, value = value
                isyaml = yaml if yaml else False
                tlist.append([ptype, isyaml, key, value])
        return tlist

    def _get_yaml_dict(self, yfile):
        """ Load ``yfile`` (possibly given as a template parameter) through
        a YamlConverter in template mode and return the result. """
        from savu.plugins.loaders.yaml_converter import YamlConverter
        converter = YamlConverter()
        template_check = pu.is_template_param(yfile)
        yfile = template_check[1] if template_check is not False else yfile
        converter.parameters = {'yaml_file': yfile}
        return converter.setup(template=True)

    def update_process_list(self, template):
        """ Apply the values from a yaml template file to the plugin list.

        Each top-level key (other than 'process_list') is a plugin position
        or 'all' and maps to a single {plugin_name: {param: value}} entry.

        :raises Exception: if a value is nested more than one dict deep.
        """
        tdict = yu.read_yaml(template)
        # 'process_list' only names the source file; it is not a plugin.
        tdict.pop('process_list', None)

        for plugin_no, entry in tdict.items():
            plugin, params = next(iter(entry.items()))
            # dict.iteritems() does not exist in Python 3.
            for key, value in params.items():
                depth = self.dict_depth(value)
                if depth == 1:
                    self._set_param_for_template_loader_plugin(
                        plugin_no, key, value)
                elif depth == 0:
                    if plugin_no == 'all':
                        self._set_param_for_all_instances_of_a_plugin(
                            plugin, key, value)
                    else:
                        data = self._get_plugin_data_dict(str(plugin_no))
                        data[key] = value
                else:
                    raise Exception("Template key not recognised.")

    def dict_depth(self, d, depth=0):
        """ Return how many levels of non-empty dicts ``d`` nests (0 for a
        non-dict or an empty dict). """
        if not isinstance(d, dict) or not d:
            return depth
        return max(self.dict_depth(v, depth + 1) for v in d.values())

    def _set_param_for_all_instances_of_a_plugin(self, plugin, param, value):
        """ Set ``param`` to ``value`` on every plugin named ``plugin``. """
        for p in self.plist.plugin_list:
            if p['name'] == plugin:
                p['data'][param] = value

    def _set_param_for_template_loader_plugin(self, plugin_no, data, value):
        """ Set template_param[data][key] = val for the plugin at position
        ``plugin_no``, where ``value`` is a one-entry {key: val} dict.

        The template_param dict is created (and stored back on the plugin)
        when it is missing or empty; previously a new dict was built but
        never stored, so the value was lost.
        """
        param_key, param_val = next(iter(value.items()))
        plugin_data = self._get_plugin_data_dict(str(plugin_no))
        pdict = plugin_data.get('template_param') or {}
        pdict.setdefault(data, {})[param_key] = param_val
        plugin_data['template_param'] = pdict

    def _get_plugin_data_dict(self, plugin_no):
        """ Return the 'data' dict of the plugin whose 'pos' equals
        ``plugin_no`` (a string).

        :raises ValueError: if no plugin has that position.
        """
        for plugin in self.plist.plugin_list:
            if plugin.get('pos') == plugin_no:
                return plugin['data']
        raise ValueError("No plugin at position %r in the plugin list"
                         % (plugin_no,))
494
495
496
class CitationInformation(object):
    """
    Descriptor of Citation Information for plugins
    """

    # 'name' is not written by write(); callers use it as the group name.
    name: str = 'citation'
    bibtex: str = "Default Bibtex"
    description: str = "Default Description"
    doi: str = "Default DOI"
    endnote: str = "Default Endnote"

    def write(self, citation_group):
        """ Tag ``citation_group`` as NXcite and store the description, doi,
        endnote and bibtex strings (in that order) as one-element ascii
        datasets. """
        # The class attribute value is written without ascii encoding.
        citation_group.attrs[NX_CLASS] = 'NXcite'
        for field in ('description', 'doi', 'endnote', 'bibtex'):
            values = np.array([getattr(self, field).encode('ascii')])
            citation_group.create_dataset(field.encode('ascii'),
                                          values.shape,
                                          values.dtype,
                                          values)
530