Test Failed
Pull Request — master (#917)
by
unknown
04:53
created

savu.plugins.utils._user_directory_warning()   A

Complexity

Conditions 1

Size

Total Lines 14
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 11
nop 1
dl 0
loc 14
rs 9.85
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""
15
.. module:: utils
16
   :platform: Unix
17
   :synopsis: Utilities for plugin management
18
19
.. moduleauthor:: Mark Basham <[email protected]>
20
21
"""
22
23
import os
24
import re
25
import sys
26
import ast
27
import logging
28
import savu
29
import importlib
30
import inspect
31
import itertools
32
33
from collections import OrderedDict
34
from colorama import Fore, Style
35
import numpy as np
36
37
from savu.plugins.loaders.utils.my_safe_constructor import MySafeConstructor
38
39
# Module-level registries. `plugins`, `plugins_path` and `dawn_plugins` are
# filled by the decorators below; `load_tools` is not referenced elsewhere in
# the visible code.
load_tools = dict()
plugins = dict()
plugins_path = dict()
dawn_plugins = dict()
# Read (never written) by parse_array_index_as_string as an offset correction.
count = 0

# Output-type flags accepted by the dawn_compatible decorator.
(
    OUTPUT_TYPE_DATA_ONLY,
    OUTPUT_TYPE_METADATA_ONLY,
    OUTPUT_TYPE_METADATA_AND_DATA,
) = range(3)
50
51
52
def register_plugin(clazz):
    """Class decorator that records ``clazz`` in the central plugin register.

    The class is stored in ``plugins`` under its class name. Classes whose
    module lives outside the ``savu`` package additionally have their module
    path stored in ``plugins_path``.

    :param clazz: The plugin class being decorated.
    :returns: ``clazz`` unchanged.
    """
    class_name = clazz.__name__
    module_name = clazz.__module__
    plugins[class_name] = clazz
    top_level_package = module_name.split(".")[0]
    if top_level_package != "savu":
        plugins_path[class_name] = module_name
    return clazz
58
59
60
def dawn_compatible(plugin_output_type=OUTPUT_TYPE_METADATA_AND_DATA):
    """Class decorator registering a plugin as DAWN compatible.

    Usable with an argument, ``@dawn_compatible(OUTPUT_TYPE_DATA_ONLY)``, or
    (for backwards compatibility) bare, ``@dawn_compatible``, in which case
    the output type defaults to ``OUTPUT_TYPE_METADATA_AND_DATA``.

    Each decorated class gets an entry in ``dawn_plugins`` keyed by class
    name, holding the ``.py`` path of its module (``'path2plugin'``) and the
    output type (``'plugin_output_type'``).

    :param plugin_output_type: One of the OUTPUT_TYPE_* constants, or the
        decorated class itself when the decorator is used without brackets.
    :returns: The decorator, or the decorated class when used bare.
    """

    def _register(clazz, output_type):
        """Record ``clazz`` in ``dawn_plugins`` and return it unchanged."""
        entry = {}
        dawn_plugins[clazz.__name__] = entry
        try:
            plugin_path = sys.modules[clazz.__module__].__file__
            # Swap the module's final extension (e.g. .pyc) for .py.
            # splitext only touches the last extension; splitting on '.py'
            # would also cut at '.py' inside directory names such as
            # '~/.pyenv/...'.
            entry["path2plugin"] = os.path.splitext(plugin_path)[0] + ".py"
            entry["plugin_output_type"] = output_type
        except Exception as e:
            # Best effort: a registration failure must not break the import.
            print(e)
        return clazz

    if inspect.isclass(plugin_output_type):
        # Used without brackets: the "argument" is the class being decorated.
        return _register(plugin_output_type, OUTPUT_TYPE_METADATA_AND_DATA)

    def _dawn_compatible(clazz):
        return _register(clazz, plugin_output_type)

    return _dawn_compatible
83
84
85
def get_plugin(plugin_name, params, exp, check=False):
    """Instantiate the named plugin and initialise it with its parameters.

    :param plugin_name: Module name (or module file path) of the plugin.
    :type plugin_name: str.
    :param params: Parameter values forwarded to ``initialise``.
    :param exp: Experiment object forwarded to ``initialise``.
    :param check: Forwarded to ``initialise`` as ``check``.
    :returns: An initialised instance of the plugin class.
    """
    logging.debug("Importing the module %s", plugin_name)
    plugin_cls = load_class(plugin_name)
    plugin = plugin_cls()
    plugin.initialise(params, exp, check=check)
    return plugin
96
97
98
def _get_cls_name(name):
99
    return "".join(x.capitalize() for x in name.split(".")[-1].split("_"))
100
101
102
def load_class(name, cls_name=None):
    """Return the class associated with a module name or module file path.

    The class name defaults to the CamelCase form of the module's last
    component (see ``_get_cls_name``). A class already present in ``plugins``
    is returned without importing anything.

    :param name: Module name (e.g. ``savu.plugins.foo``) or path to a module
        file; anything with a directory component is treated as a path.
    :param cls_name: Optional explicit class name to fetch from the module.
    :returns: The class object itself (callers instantiate it).
    """
    path = name if os.path.dirname(name) else None
    if path:
        name = os.path.basename(os.path.splitext(name)[0])
    if "savu_plugins" in name and "tools" not in name:
        _user_directory_warning(_get_cls_name(name))
    cls_name = cls_name or _get_cls_name(name)
    if cls_name in plugins:
        return plugins[cls_name]
    if path:
        mod = _load_module_from_path(name, path)
    else:
        mod = importlib.import_module(name)
    return getattr(mod, cls_name)


def _load_module_from_path(name, path):
    """Import the source file at ``path`` as module ``name`` and return it."""
    # Import the submodules explicitly: a bare ``import importlib`` does not
    # guarantee that ``importlib.machinery`` / ``importlib.util`` are bound.
    import importlib.machinery
    import importlib.util

    # SourceFileLoader accepts any file extension, as before; exec_module
    # replaces the deprecated SourceFileLoader.load_module().
    loader = importlib.machinery.SourceFileLoader(name, path)
    spec = importlib.util.spec_from_loader(name, loader)
    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(name)
    # Register before executing (as load_module did) so code run at import
    # time can find the module in sys.modules.
    sys.modules[name] = module
    try:
        loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module cached.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module
120
121
122
def plugin_loader(exp, plugin_dict, check=False):
    """Create a plugin from a plugin-list entry.

    :param exp: Experiment object passed through to the plugin.
    :param plugin_dict: Mapping with at least ``'id'`` (plugin module name)
        and ``'data'`` (parameter values).
    :param check: If True, initialise in check mode and register the plugin's
        datasets on ``exp.meta_data.plugin_list``.
    :returns: The initialised plugin instance.
    :raises: Whatever loading raised, after logging it.
    """
    logging.debug("Running plugin loader")
    try:
        plugin = get_plugin(
            plugin_dict['id'], plugin_dict['data'], exp, check=check
        )
    except Exception as err:
        logging.error("failed to load the plugin")
        logging.error(err)
        raise  # propagate the original error unchanged

    if check:
        exp.meta_data.plugin_list._set_datasets_list(plugin)

    logging.debug("finished plugin loader")
    return plugin
140
141
142
def get_tools_class(plugin_tools_id, cls=None):
    """Load the plugin tools class.

    The legacy id ``savu.plugins.plugin_tools`` is redirected to
    ``savu.plugins.base_tools``.

    :param plugin_tools_id: Plugin tools module name.
    :param cls: Optional argument used to construct the tools class.
    :return: An instance built from ``cls`` when ``cls`` is truthy, otherwise
        the tools class itself.
    """
    legacy_id = "savu.plugins.plugin_tools"
    module_id = "savu.plugins.base_tools" if plugin_tools_id == legacy_id else plugin_tools_id
    tools_class = load_class(module_id)
    return tools_class(cls) if cls else tools_class
155
156
157
def get_plugins_paths(examples=True):
    """Return an ordered mapping of plugin directories to module prefixes.

    Savu's own plugin sub-packages are added first so that they can be
    overridden by user folders. They are followed by the directories listed
    in the colon-separated ``SAVU_PLUGINS_PATH`` environment variable, then
    ``~/savu_plugins``, then (if ``examples``) the bundled plugin templates.
    For each extra directory that exists, its parent directory is appended to
    ``sys.path`` (once) so the ``"<basename>."`` prefix is importable.

    :param examples: Include the example plugin templates directory.
    :returns: OrderedDict mapping ``{directory: "package.prefix."}``.
    """
    plugins_paths = OrderedDict()

    # Add the savu plugins paths first so it is overridden by user folders
    savu_plugins_path = os.path.join(savu.__path__[0], 'plugins')
    savu_plugins_subpaths = [
        d for d in next(os.walk(savu_plugins_path))[1] if d != "__pycache__"
    ]
    for path in savu_plugins_subpaths:
        plugins_paths[os.path.join(savu_plugins_path, path)] = \
            ''.join(['savu.plugins.', path, '.'])

    # get user, environment and example plugin paths
    user_path = [os.path.join(os.path.expanduser("~"), "savu_plugins")]
    # Strip surrounding whitespace from each entry (deleting every space would
    # corrupt paths containing spaces) and any trailing separator (which would
    # otherwise give an empty basename and a bare "." prefix).
    env_paths = [
        p.strip().rstrip(os.sep)
        for p in os.getenv("SAVU_PLUGINS_PATH", "").split(":")
    ]
    templates = "../examples/plugin_examples/plugin_templates"
    eg_path = [os.path.join(savu.__path__[0], templates)] if examples else []

    for ppath in env_paths + user_path + eg_path:
        if not os.path.exists(ppath):
            continue
        plugins_paths[ppath] = os.path.basename(ppath) + "."
        # The prefix is "<basename>.", so it is the PARENT directory that must
        # be importable; test membership of the same path that is appended,
        # otherwise it can be appended repeatedly or skipped entirely.
        parent = os.path.dirname(ppath)
        if parent not in sys.path:
            sys.path.append(parent)

    return plugins_paths
185
186
187
def _user_directory_warning(plugin_name):
    """Warn the user that their plugin may override official plugins.

    Prints a framed message naming ``~/savu_plugins``, with the warning lines
    in red, and restores the terminal's default colours afterwards.

    :param plugin_name: Name of the plugin being loaded.
    """
    user_path = os.path.join(os.path.expanduser("~"), "savu_plugins")
    print("-"*58)
    print(f" You are loading the plugin {plugin_name} from the following")
    print(f" location: {user_path}")
    warn_c = Style.RESET_ALL + Fore.RED
    warn_str1 = f" WARNING The plugin {plugin_name} will be prioritised over a "
    warn_str2 = f" {plugin_name} plugin if present in the released version."
    print(warn_c + warn_str1)
    # Reset all styling rather than switching to black text, which would
    # leave subsequent output unreadable on dark terminals.
    print(warn_str2 + Style.RESET_ALL)
    print("-"*58)
201
202
203
def is_template_param(param):
    """Identify a template placeholder parameter and extract its default.

    A placeholder is a string of the form ``<default>`` (a "local" parameter)
    or ``global<default>`` (a "global" parameter), ignoring surrounding
    whitespace. The text between the angle brackets is evaluated when
    possible (e.g. ``"<5>"`` -> ``5``); if evaluation fails the raw text is
    kept, and an empty placeholder ``"<>"`` gives ``None``.

    :param param: Candidate parameter value (non-strings are never templates).
    :returns: ``[ptype, default]`` with ptype ``"local"`` or ``"global"``,
        or ``False`` if ``param`` is not a template placeholder.
    """
    if not isinstance(param, str):
        return False
    param = param.strip()
    if param.startswith("global"):
        ptype, body = "global", param[len("global"):]
    else:
        ptype, body = "local", param
    # Length guard: "" or a bare "global" previously raised IndexError.
    if len(body) < 2 or body[0] != "<" or body[-1] != ">":
        return False
    default = body[1:-1] or None
    if default is not None:
        try:
            # NOTE(review): eval executes arbitrary code. This is only safe if
            # template strings come from trusted plugin definitions — confirm,
            # and consider ast.literal_eval if expressions are not required.
            default = eval(default)
        except Exception:
            pass  # not a Python expression: keep the raw text
    return [ptype, default]
224
225
226
def blockPrint():
    """Disable printing to stdout by redirecting it to the null device.

    Redirecting to ``os.devnull`` discards output without creating a new
    temporary directory and file on every call. Use ``enablePrint`` to
    restore the interpreter's original stdout.
    """
    sys.stdout = open(os.devnull, "w")
232
233
234
def enablePrint():
    """Restore ``sys.stdout`` to the interpreter's original output stream."""
    original_stdout = sys.__stdout__
    sys.stdout = original_stdout
237
238
239
def parse_config_string(string):
    """Convert a bracket/comma/space-delimited string into a Python object.

    Every token between delimiters (``[``, ``]``, ``,`` and spaces) is quoted
    as a string and the result is evaluated with ``ast.literal_eval``, e.g.
    ``"[a, [b, c]]"`` -> ``['a', ['b', 'c']]`` and ``"a, b"`` -> ``('a', 'b')``.
    If that evaluation raises ValueError, the string is first passed through
    ``parse_array_index_as_string`` and evaluated again.

    :param string: The configuration string.
    :returns: The evaluated (nested) list/tuple of strings.
    """
    # A capturing group keeps the delimiters in re.split's output, so tokens
    # alternate text/delimiter in their original order. Zipping separate
    # delimiter and value lists assumed the string began with a delimiter and
    # scrambled the order otherwise ("a, b" -> ", 'a''b'").
    tokens = re.split(r"([\[\]\, ]+)", string)
    pieces = []
    for i, token in enumerate(tokens):
        if i % 2:  # odd indices are captured delimiters
            pieces.append(token)
        elif token:  # even indices are (possibly empty) values
            pieces.append(repr(token.strip()))
    string = "".join(pieces)
    try:
        return ast.literal_eval(string)
    except ValueError:
        return ast.literal_eval(parse_array_index_as_string(string))
250
251
252
def parse_array_index_as_string(string):
    """Rewrite quoted "'['...']" fragments so they evaluate as plain lists.

    For every occurrence of the three characters ``'['`` in the input, the
    first ``']`` found at or after three characters past the match start
    (shifted by the module-level ``count``) is turned into ``]'``. Then every
    ``'['`` is replaced by ``[``. Used by parse_config_string as a fallback
    when ``ast.literal_eval`` raises ValueError.

    :param string: The partially quoted string.
    :returns: The rewritten string.
    :raises ValueError: (from ``str.index``) if a ``'['`` has no later ``']``.
    """
    p = re.compile(r"'\['")
    # finditer scans the ORIGINAL string; rebinding `string` below does not
    # affect it. The "']" -> "]'" swap keeps the length unchanged, so match
    # offsets into the original remain valid for the rewritten string.
    for m in p.finditer(string):
        # NOTE(review): `count` is a module-level global that is 0 and never
        # reassigned in the visible code, so this is effectively m.start() + 3.
        offset = m.start() - count + 3
        end = string[offset:].index("']") + offset
        string = string[:end] + "]'" + string[end + 2 :]
    string = string.replace("'['", "[")
    return string
260
261
262
def param_to_str(param_name, keys):
    """Resolve a parameter given by name or by 1-based number to its name.

    :param param_name: Parameter name, or a digit string giving its 1-based
        position in ``keys``.
    :param keys: Ordered sequence of valid parameter names.
    :returns: The parameter name.
    :raises ValueError: If the number is outside ``1..len(keys)``.
    :raises Exception: If the name is not in ``keys``.
    """
    if param_name.isdigit():
        index = int(param_name)
        # Numbering is 1-based. Without the lower bound, "0" would silently
        # map to keys[-1], the last parameter.
        if not 1 <= index <= len(keys):
            raise ValueError(
                "This parameter number is not valid for this plugin"
            )
        return keys[index - 1]
    if param_name not in keys:
        raise Exception("This parameter is not present in this plug in.")
    return param_name
278
279
280
def set_order_by_visibility(parameters, level=False):
    """Return the displayed parameter names ordered by visibility level.

    Only parameters whose ``"display"`` is ``"on"`` are considered. They are
    grouped by ``"visibility"`` (datasets/basic/intermediate/advanced),
    preserving their order within each group.

    :param parameters: Dict of parameter name -> parameter details.
    :param level: ``"datasets"`` or ``"basic"`` returns only that group;
        ``"intermediate"`` returns basic + intermediate + datasets; anything
        else (including the default) returns basic + intermediate + advanced
        + datasets.
    :return: An ordered list of parameter names.
    """
    datasets, basic, intermediate, advanced = [], [], [], []
    groups = (
        ("datasets", datasets),
        ("basic", basic),
        ("intermediate", intermediate),
        ("advanced", advanced),
    )
    for name, details in parameters.items():
        if details["display"] != "on":
            continue
        visibility = details["visibility"]
        for label, members in groups:
            if visibility == label:
                members.append(name)

    every_level = basic + intermediate + advanced + datasets
    if not level:
        return every_level
    if level == "datasets":
        return datasets
    if level == "basic":
        return basic
    if level == "intermediate":
        return basic + intermediate + datasets
    return every_level
317
318
319
def convert_multi_params(param_name, value):
    """Expand a ';'-separated multi-parameter string into a list of values.

    Strings containing ``;`` (except for the ``preview`` parameter) are split
    on ``;``. If the first entry has the form ``start:stop[:step]`` and no
    braces, it is expanded with ``np.arange`` and the remaining entries are
    dropped. Blank entries are removed and each item is converted with
    ``_dumps`` (int, float, str, list, ...).

    :param param_name: Name of the parameter.
    :param value: Parameter value.
    :return: Tuple ``(value, error_str)``: the list (or the unchanged value
        if it is not a multi parameter) and an error message, or ``""``.
    """
    error_str = ""
    multi_parameters = (
        isinstance(value, str) and (";" in value) and param_name != "preview"
    )
    if not multi_parameters:
        return value, error_str

    value = value.split(";")
    isdict = re.findall(r"[\{\}]+", value[0])
    if ":" in value[0] and not isdict:
        try:
            # value[0] contains ":", so seq always has at least two entries.
            seq = [ast.literal_eval(s) for s in value[0].split(":")]
            if len(seq) == 2:
                tuned = np.arange(seq[0], seq[1])
            else:
                tuned = np.arange(seq[0], seq[1], seq[2])
            value = list(tuned)
            if not value:
                # Don't allow an empty list
                error_str = (
                    f"No values for tuned parameter "
                    f"'{param_name}' ensure start:stop:step; values "
                    f"are valid"
                )
        except Exception:
            # Non-literal bounds, bad step, etc. (no longer a bare except,
            # so KeyboardInterrupt/SystemExit propagate).
            error_str = "Ensure start:stop:step; values are valid."
    # Remove blank list entries and change type to int, float or str
    value = [_dumps(val) for val in value if val]
    return value, error_str
363
364
365
def _dumps(val):
    """Convert a parameter string into a Python value.

    Non-string values are returned unchanged. Strings are tried, in order:

    1. ``ast.literal_eval`` (Python literals and containers);
    2. ``yaml.safe_load``, with colons temporarily escaped by
       ``_sexagesimal_check`` when the string has no braces;
    3. for strings with braces, quoting of ``[``/``]`` followed by
       ``yaml.safe_load`` and a recursive ``_dumps``; for strings without
       braces, ``parse_config_string``.

    If all of these fail, a string containing ``;`` is returned as-is.
    Otherwise ``Exception("Invalid string ...")`` is raised.
    """
    import yaml

    # Prevent conversion from on/off to boolean
    # NOTE(review): add_constructor appears to register on the
    # yaml.SafeLoader class itself (affecting every safe_load in the process)
    # and is re-run on every call — confirm, and consider doing it once.
    yaml.SafeLoader.add_constructor(
        "tag:yaml.org,2002:bool", MySafeConstructor.add_bool
    )
    if isinstance(val, str):
        try:
            # Safely evaluate an expression node or a string containing
            # a Python literal or container display
            value = ast.literal_eval(val)
            return value
        except Exception:
            pass
        try:
            isdict = re.findall(r"[\{\}]+", val)
            # Escape ":" -> ":?" (only when no braces) before YAML parsing,
            # then undo the escape in the parsed result.
            val = _sexagesimal_check(val, isdict, remove=False)
            value = yaml.safe_load(val)
            return _sexagesimal_check(value, isdict)
        except Exception:
            # Undo the ":?" escape before the next attempt.
            # NOTE(review): a ":?" already present in the input is also
            # turned into ":" here.
            val = _sexagesimal_check(val, isdict)
            pass
        try:
            # Matches one or more consecutive { or } characters
            isdict = re.findall(r"[\{\}]+", val)
            if isdict:
                # NOTE(review): val is always a str at this point (outer
                # isinstance check), so this dict branch looks unreachable.
                if isinstance(val, dict):
                    value_dict = {}
                    for k, v in val.items():
                        v = v.replace("[", "'[").replace("]", "]'")
                        value_dict[k] = _dumps(
                            yaml.safe_load(v)
                        )
                    return value_dict
                else:
                    # Quote list brackets so YAML keeps them as strings,
                    # then convert the parsed result recursively.
                    value = val.replace("[", "'[").replace("]", "]'")
                    return _dumps(yaml.safe_load(value))
            else:
                value = parse_config_string(val)
                return value
        except Exception:
            # Multi-parameter strings (containing ';') are passed through.
            if len(val.split(";")) > 1:
                value = val
                return value
            else:
                raise Exception("Invalid string %s" % val)
    else:
        value = val
    return value
417
418
419
def _sexagesimal_check(val, isdict, remove=True):
420
    """To avoid sexagesimal values being evaluated, replace colon
421
    values temporarily
422
423
    :param val:
424
    :param isdict: True if braces {} found
425
    :return: value
426
    """
427
    if isinstance(val, str) and not isdict:
428
        if remove:
429
            val = val.replace(":?", ":")
430
        else:
431
            val = val.replace(":", ":?")
432
    return val
433
434
435
def check_valid_dimension(dim, prev_list):
    """Check that a (1-based) dimension index is within the allowed range.

    :param dim: Dimension number; must satisfy 0 < dim < 21.
    :param prev_list: Existing preview entries; if non-empty, ``dim`` may not
        exceed its length.
    :returns: True when the dimension is valid.
    :raises Exception: If either check fails.
    """
    in_range = 0 < dim < 21
    if not in_range:
        raise Exception("Please use a dimension between 1 and 20.")
    exceeds_preview = bool(prev_list) and dim > len(prev_list)
    if exceeds_preview:
        raise Exception(
            "You have not specified enough dimensions "
            "inside the preview parameter."
        )
    return True
445
446
447
def is_slice_notation(value):
    """Return True if ``value`` is a string containing a colon (slice-like).

    Non-string values always give False.
    """
    if not isinstance(value, str):
        return False
    return ":" in value
450
451
452
def create_dir(file_path):
    """Create the parent directories of ``file_path`` if they do not exist.

    A bare filename (no directory component) needs no directories and is a
    no-op; previously it called ``os.makedirs("")``, which raises.

    :param file_path: Path to a file whose containing directory is required.
    """
    directory = os.path.dirname(file_path)
    if directory and not os.path.exists(directory):
        # exist_ok guards against another process creating the directory
        # between the existence check and this call.
        os.makedirs(directory, exist_ok=True)
459
460
461
def indent_multi_line_str(text, indent_level=1, justify=False):
    """Indent every line of ``text`` by ``4 * indent_level`` spaces.

    :param text: Multi-line string.
    :param indent_level: Number of 4-space indents to prepend.
    :param justify: Unless this is exactly ``False``, existing leading
        whitespace is stripped first so that all lines align.
    :return: The indented string.
    """
    padding = " " * 4 * indent_level
    lines = text.split("\n")
    if justify is not False:
        lines = [line.lstrip() for line in lines]
    return "\n".join(padding + line for line in lines)
470
471
472
def indent(text, indent_level=1):
    """Prefix ``text`` with ``4 * indent_level`` spaces."""
    padding = " " * 4 * indent_level
    return padding + text
475
476
477
def sort_alphanum(_list):
    """Sort strings in natural order (digit runs compared numerically).

    *The original list values and their types are kept*; only the sort key
    is derived from ``_alphanum``.

    :param _list: Input list of strings to be sorted.
    :return: New list sorted by number and letter alphabetically.
    """

    def _key(item):
        # Tag each chunk so an int is never compared with a str (which raised
        # TypeError for e.g. "a1" vs "1a"): (False, int) sorts before
        # (True, str). Where chunk types already matched, order is unchanged.
        return [(isinstance(chunk, str), chunk) for chunk in _alphanum(item)]

    return sorted(_list, key=_key)
485
486
487
def _str_to_int(_str):
488
    """Convert the input str to an int if possible
489
490
    :param _str: input string
491
    :return: integer if text is a digit, else string
492
    """
493
    return int(_str) if _str.isdigit() else _str
494
495
496
def _alphanum(_str):
    """Split a string into runs of digits and non-digits.

    Surrounding whitespace is stripped, empty pieces are dropped and each
    piece is passed through ``_str_to_int``, e.g. ``"img10b"`` ->
    ``["img", 10, "b"]``.

    :param _str: Input string.
    :return: List of numbers and letter groups.
    """
    pieces = re.split("([0-9]+)", _str.strip())
    return [_str_to_int(piece) for piece in pieces if piece]
506