Test Failed
Pull Request — master (#772)
by
unknown
03:45
created

savu.plugins.utils._get_cls_name()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| cc     | 1     |
| eloc   | 2     |
| nop    | 1     |
| dl     | 0     |
| loc    | 2     |
| rs     | 10    |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""
15
.. module:: utils
16
   :platform: Unix
17
   :synopsis: Utilities for plugin management
18
19
.. moduleauthor:: Mark Basham <[email protected]>
20
21
"""
22
23
import os
24
import re
25
import sys
26
import ast
27
import logging
28
import savu
29
import importlib
30
import inspect
31
import itertools
32
33
from collections import OrderedDict
34
import numpy as np
35
36
# can I remove these from here?
37
38
load_tools = {}
39
plugins = {}
40
plugins_path = {}
41
dawn_plugins = {}
42
count = 0
43
44
OUTPUT_TYPE_DATA_ONLY = 0
45
OUTPUT_TYPE_METADATA_ONLY = 1
46
OUTPUT_TYPE_METADATA_AND_DATA = 2
47
48
49
def register_plugin(clazz):
    """Class decorator that records a plugin in the central registry.

    The class is stored in ``plugins`` under its class name. Plugins whose
    module does not live inside the ``savu`` package additionally have their
    module name stored in ``plugins_path``.

    :param clazz: The plugin class being decorated.
    :returns: The same class, unchanged.
    """
    class_name = clazz.__name__
    plugins[class_name] = clazz
    top_level_package = clazz.__module__.partition(".")[0]
    if top_level_package != "savu":
        plugins_path[class_name] = clazz.__module__
    return clazz
55
56
57
def dawn_compatible(plugin_output_type=OUTPUT_TYPE_METADATA_AND_DATA):
    """Decorator (factory) registering DAWN-compatible plugins.

    Can be used either bare (``@dawn_compatible``) or called with an output
    type (``@dawn_compatible(OUTPUT_TYPE_DATA_ONLY)``). For each decorated
    class an entry is added to ``dawn_plugins`` holding the path of the
    plugin's source file ('path2plugin') and its output type
    ('plugin_output_type').

    :param plugin_output_type: One of the OUTPUT_TYPE_* constants, or the
        decorated class itself when the decorator is used without brackets.
    :returns: The decorated class (bare usage) or the actual decorator.
    """
    # for backwards compatibility, if decorator is invoked without brackets
    # the decorated class arrives here in place of the output type
    used_bare = inspect.isclass(plugin_output_type)
    if used_bare:
        output_type = OUTPUT_TYPE_METADATA_AND_DATA
    else:
        output_type = plugin_output_type

    def _dawn_compatible(clazz):
        """
        decorator to add dawn compatible plugins and details to a central
        register
        """
        dawn_plugins[clazz.__name__] = {}
        try:
            plugin_path = sys.modules[clazz.__module__].__file__
            # Map compiled files (.pyc/.pyo) back to the .py source. splitext
            # only strips the final extension, so a '.py' elsewhere in the
            # path (e.g. ~/.pyenv/...) no longer truncates the result.
            source_path = os.path.splitext(plugin_path)[0] + '.py'
            dawn_plugins[clazz.__name__]['path2plugin'] = source_path
            dawn_plugins[clazz.__name__]['plugin_output_type'] = output_type
        except Exception as e:
            # Best effort: registration details are optional, so report and
            # keep the (possibly empty) entry rather than failing the import.
            print(e)
        return clazz

    if used_bare:
        return _dawn_compatible(plugin_output_type)
    return _dawn_compatible
80
81
82
def get_plugin(plugin_name, params, exp, check=False):
    """Create a plugin instance and initialise it with its parameters.

    :param plugin_name: Name of the plugin to import (anything accepted by
        load_class: a dotted module name or a path to a module file).
    :type plugin_name: str.
    :param params: Parameters forwarded to the instance's ``initialise``.
    :param exp: Experiment object forwarded to ``initialise``.
    :param check: Forwarded to ``initialise`` as the ``check`` keyword.
    :returns: An instance of the class described by the named plugin.
    """
    logging.debug("Importing the module %s", plugin_name)
    plugin_class = load_class(plugin_name)
    plugin = plugin_class()
    plugin.initialise(params, exp, check=check)
    return plugin
93
94
95
def _get_cls_name(name):
96
    return "".join(x.capitalize() for x in name.split(".")[-1].split("_"))
97
98
99
def load_class(name, cls_name=None):
    """Return the class associated with a module name or module file path.

    If the class name is already present in the ``plugins`` registry
    (populated by ``register_plugin``) the registered class is returned
    without importing anything.

    :param name: Dotted module name, or a path to a module file (any name
        with a directory component is treated as a path).
    :param cls_name: Name of the class to return. Defaults to the CamelCase
        form of the module's final component (see _get_cls_name), e.g.
        'median_filter' -> 'MedianFilter'.
    :returns: The class object itself (not an instance).
    :raises AttributeError: If the module does not define ``cls_name``.
    """
    import importlib.machinery
    import importlib.util

    path = name if os.path.dirname(name) else None
    if path:
        # The module name is the file name without directory or extension.
        name = os.path.basename(os.path.splitext(name)[0])
    if not cls_name:
        cls_name = _get_cls_name(name)
    if cls_name in plugins:
        return plugins[cls_name]

    if path:
        # SourceFileLoader.load_module() is deprecated; build the module from
        # a spec instead. Like load_module(), register it in sys.modules
        # before executing it, and undo that registration if execution fails.
        loader = importlib.machinery.SourceFileLoader(name, path)
        spec = importlib.util.spec_from_loader(name, loader)
        mod = importlib.util.module_from_spec(spec)
        previous = sys.modules.get(name)
        sys.modules[name] = mod
        try:
            loader.exec_module(mod)
        except BaseException:
            if previous is None:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
            raise
    else:
        mod = importlib.import_module(name)
    return getattr(mod, cls_name)
115
116
117
def plugin_loader(exp, plugin_dict, check=False):
    """Instantiate the plugin described by ``plugin_dict``.

    :param exp: Experiment object passed through to get_plugin.
    :param plugin_dict: Mapping providing at least 'id' (the plugin name
        passed to get_plugin) and 'data' (the plugin parameters).
    :param check: Forwarded to get_plugin; when true, the plugin is also
        passed to ``exp.meta_data.plugin_list._set_datasets_list``.
    :returns: The initialised plugin instance.
    :raises Exception: Any error raised while loading is logged and then
        re-raised unchanged.
    """
    logging.debug("Running plugin loader")
    try:
        loaded = get_plugin(
            plugin_dict['id'], plugin_dict['data'], exp, check=check
        )
    except Exception as err:
        for message in ("failed to load the plugin", err):
            logging.error(message)
        # propagate the original exception and traceback
        raise

    if check:
        exp.meta_data.plugin_list._set_datasets_list(loaded)

    logging.debug("finished plugin loader")
    return loaded
135
136
137
def get_tools_class(plugin_tools_id, cls=None):
    """Load the tools class for a plugin tools module inside the Savu tree.

    :param plugin_tools_id: Dotted module name of the tools module, e.g.
        'savu.plugins.filters.median_filter_tools'. The legacy id
        'savu.plugins.plugin_tools' is mapped to 'savu.plugins.base_tools'.
    :param cls: Optional object; when truthy the tools class is called with
        it and the resulting instance is returned instead of the class.
    :returns: The tools class, an instance of it (when ``cls`` is given), or
        None when the module's file is not found under the Savu source tree.
    """
    if plugin_tools_id == "savu.plugins.plugin_tools":
        plugin_tools_id = "savu.plugins.base_tools"

    # determine Savu base path: the file is looked up relative to the parent
    # directory of the installed 'savu' package
    path_name = plugin_tools_id.replace(".", "/")
    file_path = savu.__path__[0] + "/../" + path_name + ".py"
    if not os.path.isfile(file_path):
        # Previously an Exception was constructed here but never raised, so
        # callers silently received None. Keep returning None (callers may
        # rely on it) but make the missing file visible.
        logging.warning("Tools file %s not found.", path_name)
        return None

    tools_class = load_class(plugin_tools_id)
    return tools_class(cls) if cls else tools_class
151
152
153
def get_plugins_paths(examples=True):
    """
    This gets the plugin paths, but also adds any that are not on the
    pythonpath to it.

    :param examples: If True, also include the plugin template examples
        directory.
    :returns: OrderedDict mapping each plugin directory to the module prefix
        used to import from it. Savu's own plugin sub-packages come first,
        then directories from $SAVU_PLUGINS_PATH (colon separated),
        ~/savu_plugins and the examples directory, when they exist.
    """
    plugins_paths = OrderedDict()

    # Add the savu plugins paths first so it is overridden by user folders
    savu_plugins_path = os.path.join(savu.__path__[0], 'plugins')
    savu_plugins_subpaths = [
        d for d in next(os.walk(savu_plugins_path))[1] if d != "__pycache__"
    ]
    for path in savu_plugins_subpaths:
        plugins_paths[os.path.join(savu_plugins_path, path)] = \
            ''.join(['savu.plugins.', path, '.'])

    # get user, environment and example plugin paths
    user_path = [os.path.join(os.path.expanduser("~"), "savu_plugins")]
    env_paths = os.getenv("SAVU_PLUGINS_PATH", "").replace(" ", "").split(":")
    templates = "../plugin_examples/plugin_templates"
    eg_path = [os.path.join(savu.__path__[0], templates)] if examples else []

    for ppath in env_paths + user_path + eg_path:
        if os.path.exists(ppath):
            plugins_paths[ppath] = os.path.basename(ppath) + "."
            # Modules are imported as '<basename>.<module>', so it is the
            # parent directory that must be importable. Test membership of
            # that same directory so repeated calls don't keep appending it.
            parent_dir = os.path.dirname(ppath)
            if parent_dir not in sys.path:
                sys.path.append(parent_dir)

    return plugins_paths
181
182
183
def is_template_param(param):
    """Identifies if the parameter should be included in an input template
    and returns the default value of the parameter if it exists.

    A template parameter is a string of the form '<value>' (local) or
    'global<value>' (global), ignoring surrounding whitespace. The text
    between the angle brackets is evaluated when possible, otherwise it is
    kept as a string; empty brackets give None.

    :param param: Parameter value to inspect.
    :returns: [ptype, value] with ptype 'local' or 'global', or False if
        ``param`` is not a template parameter.
    """
    if not isinstance(param, str):
        return False

    param = param.strip()
    if param.startswith("global"):
        ptype, start = "global", 6
    else:
        ptype, start = "local", 0

    # '' and 'global' have no character at index ``start``; previously this
    # raised IndexError instead of reporting "not a template parameter".
    if len(param) <= start:
        return False

    if param[start] == "<" and param[-1] == ">":
        value = param[start + 1:-1] or None
        if value is not None:
            # NOTE(review): eval() executes arbitrary expressions. This is
            # only acceptable if template parameters come from trusted plugin
            # definitions - confirm before exposing it to user input.
            try:
                value = eval(value)
            except Exception:
                # Not a valid expression: keep the raw text.
                pass
        return [ptype, value]
    return False
204
205
206
def blockPrint():
    """ Disable printing to stdout

    Redirects ``sys.stdout`` to the null device; call enablePrint() to
    restore the interpreter's original stdout.
    """
    # Writing to os.devnull discards output without creating a new temporary
    # directory and file on every call (the previous implementation left
    # both behind on disk and never cleaned them up).
    sys.stdout = open(os.devnull, "w")
212
213
214
def enablePrint():
    """Restore printing by pointing stdout back at the interpreter's
    original stream (undoes blockPrint)."""
    original_stream = sys.__stdout__
    sys.stdout = original_stream
217
218
219
def parse_config_string(string):
    """Parse a bracketed, comma/space separated config string into lists.

    Every token between the separators '[', ']', ',' and ' ' is quoted, the
    separators are put back in place and the result is evaluated with
    ast.literal_eval, e.g. '[1, 2]' -> ['1', '2'] (tokens stay strings).
    If that evaluation raises ValueError, the quoted string is first passed
    through parse_array_index_as_string and evaluated again.
    """
    separator_pattern = r"[\[\]\, ]+"
    tokens = [
        repr(token.strip())
        for token in re.split(separator_pattern, string)
        if token
    ]
    separators = re.findall(separator_pattern, string)

    pieces = []
    for sep, token in itertools.zip_longest(separators, tokens):
        if sep is not None:
            pieces.append(sep)
        if token is not None:
            pieces.append(token)
    quoted = "".join(pieces)

    try:
        return ast.literal_eval(quoted)
    except ValueError:
        return ast.literal_eval(parse_array_index_as_string(quoted))
230
231
232
def parse_array_index_as_string(string):
    """Unquote array-index substrings produced by parse_config_string.

    For every "'['" in the input, the next "']" after it is rewritten as
    "]'". Finally every "'['" is replaced by "[".

    :param string: Quoted config string.
    :returns: The rewritten string.
    :raises ValueError: If an "'['" has no matching "']" after it.
    """
    opening = re.compile(r"'\['")
    # finditer scans the original string; positions stay valid because each
    # substitution replaces two characters with two characters.
    for match in opening.finditer(string):
        # Skip the 3-character "'['" token itself. (Previously this also
        # subtracted the unrelated module global ``count``, which is never
        # modified in this module and only coupled this function to it.)
        offset = match.start() + 3
        end = string[offset:].index("']") + offset
        string = string[:end] + "]'" + string[end + 2:]
    return string.replace("'['", "[")
240
241
242
def param_to_str(param_name, keys):
    """Check the parameter is within the provided list and
    return the string name.

    :param param_name: Parameter name, or its 1-based position in ``keys``
        given as a string of digits (e.g. '2').
    :param keys: Indexable sequence of valid parameter names.
    :returns: The parameter name.
    :raises Exception: If the number is out of range or the name is not in
        ``keys``.
    """
    # isdecimal(), not isdigit(): characters such as '²' pass isdigit() but
    # make int() raise ValueError.
    if param_name.isdecimal():
        position = int(param_name)
        # Positions are 1-based; without the lower bound '0' silently
        # selected the last key via keys[-1].
        if 1 <= position <= len(keys):
            return keys[position - 1]
        raise Exception(
            "This parameter number is not valid for this plugin"
        )
    if param_name not in keys:
        raise Exception("This parameter is not present in this plug in.")
    return param_name
258
259
260
def set_order_by_visibility(parameters, level=False):
    """Order displayed parameter names according to a visibility level.

    Only parameters whose 'display' entry is "on" are considered. They are
    grouped by their 'visibility' entry ("datasets", "basic", "intermediate"
    or "advanced"; other values are ignored).

    :param parameters: Dictionary mapping parameter name to its details.
    :param level: "datasets" or "basic" select just that group;
        "intermediate" gives basic + intermediate + datasets; anything else
        (including the default False) gives basic + intermediate + advanced
        + datasets.
    :return: A list of parameter names in display order.
    """
    tiers = ("datasets", "basic", "intermediate", "advanced")
    grouped = {tier: [] for tier in tiers}

    for name, details in parameters.items():
        if details["display"] != "on":
            continue
        visibility = details["visibility"]
        for tier in tiers:
            if visibility == tier:
                grouped[tier].append(name)
                break

    if level == "datasets":
        return grouped["datasets"]
    if level == "basic":
        return grouped["basic"]
    if level == "intermediate":
        return grouped["basic"] + grouped["intermediate"] + grouped["datasets"]
    return (
        grouped["basic"]
        + grouped["intermediate"]
        + grouped["advanced"]
        + grouped["datasets"]
    )
297
298
299
def convert_multi_params(param_name, value):
    """Check if value is a multi parameter and check if each item is valid.
    Change from the input multi parameter string to a list.

    A multi (tuned) parameter is a string containing ';' (the 'preview'
    parameter is never treated as one). The items are either listed
    explicitly ('1;2;3') or given by a 'start:stop[:step]' range in the
    first item, which is expanded with numpy.arange.

    :param param_name: Name of the parameter
    :param value: Parameter value
    :return: Tuple (value, error_str). ``value`` is the list of converted
        items, or the unchanged input if it is not a multi parameter.
        ``error_str`` is empty when no problem was detected.
    """
    error_str = ""
    multi_parameters = (
        isinstance(value, str) and (";" in value) and param_name != "preview"
    )
    if not multi_parameters:
        return value, error_str

    value = value.split(";")
    isdict = re.findall(r"[\{\}]+", value[0])
    if ":" in value[0] and not isdict:
        try:
            # ':' is present, so seq always has at least two entries.
            seq = [ast.literal_eval(s) for s in value[0].split(":")]
            if len(seq) == 2:
                value = list(np.arange(seq[0], seq[1]))
            else:
                value = list(np.arange(seq[0], seq[1], seq[2]))
            # Previously this check tested the ';'-split list, which is never
            # empty, so an empty range (e.g. '5:1;') was never reported.
            if not value:
                error_str = (
                    f"No values for tuned parameter "
                    f"'{param_name}' ensure start:stop:step; values "
                    f"are valid"
                )
        except Exception:
            error_str = "Ensure start:stop:step; values are valid."

    # Remove blank entries (e.g. from a trailing ';') and change each
    # remaining item's type to int, float, str, ... via _dumps. Only empty
    # *strings* are dropped: a numeric 0 from the range expansion is falsy
    # but is a valid value and was previously discarded.
    value = [_dumps(val) for val in value if not isinstance(val, str) or val]
    return value, error_str
340
341
342
def _dumps(val):
    """Replace any missing quotes around variables
    Change the string to an integer, float, tuple, list, str, dict

    Non-string values are returned unchanged. For strings, the following are
    tried in order, returning the first that succeeds:

    1. ``ast.literal_eval``.
    2. ``yaml.load`` with ``yaml.SafeLoader``; colons are protected via
       _sexagesimal_check so colon-separated numbers are not read as
       sexagesimal values.
    3. For strings containing braces, '[' / ']' are wrapped in quotes and
       the YAML result is passed recursively through _dumps. Otherwise the
       string is parsed by parse_config_string.

    :param val: Value to convert.
    :returns: The converted value.
    :raises Exception: 'Invalid string ...' if every attempt fails and the
        string contains no ';'. Strings containing ';' are returned as-is.
    """
    # PyYAML is imported inside the function rather than at module level.
    import yaml

    if isinstance(val, str):
        try:
            # Safely evaluate an expression node or a string containing
            # a Python literal or container display
            value = ast.literal_eval(val)
            return value
        except Exception:
            pass
        try:
            isdict = re.findall(r"[\{\}]+", val)
            # Temporarily turn ':' into ':?' (skipped when braces are present)
            # so YAML does not evaluate it as a sexagesimal number.
            val = _sexagesimal_check(val, isdict, remove=False)
            value = yaml.load(val, Loader=yaml.SafeLoader)
            # Undo the ':?' substitution on the loaded result.
            return _sexagesimal_check(value, isdict)
        except Exception:
            # Restore the original colons before the next attempt.
            val = _sexagesimal_check(val, isdict)
            pass
        try:
            isdict = re.findall(r"[\{\}]+", val)
            # Matches { } between one and unlimited number of times
            if isdict:
                # NOTE(review): val is always a str at this point (checked
                # above, and _sexagesimal_check returns a str for a str), so
                # this dict branch appears unreachable.
                if isinstance(val, dict):
                    value_dict = {}
                    for k, v in val.items():
                        v = v.replace("[", "'[").replace("]", "]'")
                        value_dict[k] = _dumps(
                            yaml.load(v, Loader=yaml.SafeLoader)
                        )
                    return value_dict
                else:
                    # Quote bracketed sections so YAML keeps them as strings,
                    # then convert the loaded result recursively.
                    value = val.replace("[", "'[").replace("]", "]'")
                    return _dumps(yaml.load(value, Loader=yaml.SafeLoader))
            else:
                value = parse_config_string(val)
                return value
        except Exception:
            # Multi-parameter strings ('a;b') are passed back unchanged.
            if len(val.split(";")) > 1:
                value = val
                return value
            else:
                raise Exception("Invalid string %s" % val)
    else:
        value = val
    return value
391
392
393
def _sexagesimal_check(val, isdict, remove=True):
394
    """To avoid sexagesimal values being evaluated, replace colon
395
    values temporarily
396
397
    :param val:
398
    :param isdict: True if braces {} found
399
    :return: value
400
    """
401
    if isinstance(val, str) and not isdict:
402
        if remove:
403
            val = val.replace(":?", ":")
404
        else:
405
            val = val.replace(":", ":?")
406
    return val
407
408
409
def check_valid_dimension(dim, prev_list):
    """Check the dimension is within the correct range.

    :param dim: Dimension number; must satisfy 0 < dim < 21.
    :param prev_list: Optional sequence of preview entries. When it is
        non-empty, ``dim`` may not exceed its length.
    :returns: True when both checks pass.
    :raises Exception: If either check fails.
    """
    in_range = 0 < dim < 21
    if not in_range:
        raise Exception("Please use a dimension between 1 and 20.")
    exceeds_preview = bool(prev_list) and dim > len(prev_list)
    if exceeds_preview:
        raise Exception(
            "You have not specified enough dimensions "
            "inside the preview parameter."
        )
    return True
419
420
421
def is_slice_notation(value):
    """Return True if value is a string containing ':' (slice notation such
    as 'start:stop'), otherwise False."""
    if not isinstance(value, str):
        return False
    return ":" in value
424
425
426
def create_dir(file_path):
    """Check if directories provided exist at this file path. If they don't
    create the directories.

    :param file_path: Path to a file; its parent directories are created.
        A bare file name (no directory component) needs nothing and is a
        no-op.
    """
    directory = os.path.dirname(file_path)
    # dirname('name.txt') is '', and os.makedirs('') raises
    # FileNotFoundError, so skip the empty case.
    if directory and not os.path.exists(directory):
        # exist_ok guards against another process creating the directory
        # between the exists() check and this call.
        os.makedirs(directory, exist_ok=True)
433
434
435
def indent_multi_line_str(text, indent_level=1, justify=False):
    """Indent every line of a multi-line string by four spaces per level.

    :param text: Text whose lines ('\\n' separated) are indented.
    :param indent_level: Number of four-space indentation levels.
    :param justify: Unless this is exactly False, existing leading
        whitespace is stripped from each line first so the lines align.
    :returns: The indented text.
    """
    prefix = " " * 4 * indent_level
    strip_left = justify is not False
    return "\n".join(
        prefix + (line.lstrip() if strip_left else line)
        for line in text.split("\n")
    )
444
445
446
def indent(text, indent_level=1):
    """Prefix text with four spaces per indentation level."""
    padding = " " * 4 * indent_level
    return padding + text
449
450
451
def sort_alphanum(_list):
    """Sort list numerically and alphabetically
    *While maintaining original list value types*

    Digit runs are compared as integers (so 'p2' sorts before 'p10'); the
    returned list holds the original elements, not the sort keys.

    :param _list: Input list to be sorted
    :return: New list sorted by number and letter alphabetically
    """
    ordered = list(_list)
    ordered.sort(key=_alphanum)
    return ordered
459
460
461
def _str_to_int(_str):
462
    """Convert the input str to an int if possible
463
464
    :param _str: input string
465
    :return: integer if text is a digit, else string
466
    """
467
    return int(_str) if _str.isdigit() else _str
468
469
470
def _alphanum(_str):
    """Split a value into text and number chunks for natural sorting.

    re.split with a capturing group keeps the digit runs, so the result
    alternates text chunks and digit runs (converted to int by
    _str_to_int), starting and ending with a possibly empty text chunk.

    :param _str: Value to split. Non-string values (e.g. ints) are converted
        with str() first, so lists mixing ints and strings can be sorted
        instead of raising TypeError inside re.split.
    :return: list of numbers and letters
    """
    char_list = re.split("([0-9]+)", str(_str))
    return [_str_to_int(c) for c in char_list]
478