Passed
Pull Request — master (#540)
by
unknown
02:12
created

annif.util.identity()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
"""Utility functions for Annif"""
2
3
import glob
4
import os
5
import os.path
6
from shutil import rmtree
7
import tempfile
8
import numpy as np
9
from annif import logger
10
from annif.suggestion import VectorSuggestionResult
11
12
13
def atomic_save(obj, dirname, filename, method=None):
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory with the given
    filename, using a temporary file and renaming the temporary file to the
    final name. To save a directory explicitly set filename=None.

    :param obj: the object to save
    :param dirname: directory to save into; when filename is None this is
        the directory that is created or replaced
    :param filename: name of the target file inside dirname, or None to save
        a whole directory at dirname
    :param method: optional callable invoked as method(obj, path) instead of
        obj.save(path)

    Every path starting with the temporary name (a saver may write several
    files sharing that prefix) is renamed to the same path with the temporary
    name replaced by the target path; an existing directory at the new name
    is removed first. If saving raises, the temporary paths are removed and
    the exception is re-raised."""

    if filename:
        prefix, suffix = os.path.splitext(filename)
        tempfd, tempfilename = tempfile.mkstemp(
            prefix=prefix, suffix=suffix, dir=dirname)
        os.close(tempfd)
        target_pth = os.path.join(dirname, filename)
    else:
        tldir = os.path.dirname(dirname.rstrip('/'))
        # exist_ok must be a real boolean: tldir is '' for a top-level
        # relative dirname, which would make re-saving raise FileExistsError
        os.makedirs(dirname, exist_ok=True)
        # create the temporary directory next to the target so the final
        # os.replace stays within the same parent directory; mkdtemp (unlike
        # TemporaryDirectory) registers no finalizer that would later try to
        # delete the already-renamed path
        tempfilename = tempfile.mkdtemp(dir=tldir)
        target_pth = dirname
    logger.debug('saving %s to temporary file %s', str(obj)[:90], tempfilename)
    try:
        if method is not None:
            method(obj, tempfilename)
        else:
            obj.save(tempfilename)
    except BaseException:
        # do not leave half-written temporary files/directories behind
        _remove_temporary_paths(tempfilename)
        raise
    for fn in _temporary_paths(tempfilename):
        newname = fn.replace(tempfilename, target_pth)
        logger.debug('renaming temporary file %s to %s', fn, newname)
        if os.path.isdir(newname):
            rmtree(newname)
        os.replace(fn, newname)


def _temporary_paths(tempfilename):
    """Return the existing paths whose name starts with tempfilename."""
    # escape so that glob metacharacters in the name (e.g. '[' or '*')
    # are matched literally
    return glob.glob(glob.escape(tempfilename) + '*')


def _remove_temporary_paths(tempfilename):
    """Best-effort removal of everything saved under the temporary name."""
    for fn in _temporary_paths(tempfilename):
        if os.path.isdir(fn):
            rmtree(fn, ignore_errors=True)
        else:
            try:
                os.remove(fn)
            except OSError:
                # cleanup is best effort; the original error is re-raised
                pass
42
43
44
def cleanup_uri(uri):
    """Strip one pair of enclosing angle brackets from a URI.

    '<http://example.org/x>' becomes 'http://example.org/x'; a string that is
    not both prefixed with '<' and suffixed with '>' is returned unchanged."""
    is_bracketed = uri.startswith('<') and uri.endswith('>')
    return uri[1:-1] if is_bracketed else uri
49
50
51
def merge_hits(weighted_hits, subject_index):
    """Merge hits from multiple sources. Input is an iterable of
    WeightedSuggestion objects. A SubjectIndex is needed to convert between
    subject IDs and URIs. Returns a VectorSuggestionResult object.

    Each element must provide ``hits`` (turned into a score vector with
    ``hits.as_vector(subject_index)``) and ``weight``; the score vectors are
    combined with a weighted average (``numpy.average`` along axis 0)."""

    # Collect weights and score vectors in a single pass so that one-shot
    # iterables such as generators work as well as lists.
    weights = []
    scores = []
    for whit in weighted_hits:
        weights.append(whit.weight)
        scores.append(whit.hits.as_vector(subject_index))
    result = np.average(scores, axis=0, weights=weights)
    return VectorSuggestionResult(result)
60
61
62
def parse_sources(sourcedef):
    """Parse a source definition such as 'src1:1.0,src2' into a list of
    (src_id, weight) tuples.

    Sources are separated by commas; an optional ':weight' suffix gives the
    weight, which defaults to 1.0. The returned weights are divided by their
    total so that they sum to 1."""

    parsed = []
    for chunk in sourcedef.strip().split(','):
        src_id, *rest = chunk.strip().split(':')
        weight = float(rest[0]) if rest else 1.0
        parsed.append((src_id, weight))
    total = sum((weight for _, weight in parsed), 0.0)
    return [(src_id, weight / total) for src_id, weight in parsed]
78
79
80
def parse_args(param_string):
    """Parse a string of comma separated arguments such as '42,43,key=abc' into
    a list of positional args ['42', '43'] and a dict of keyword args
    {'key': 'abc'}.

    All values are returned as strings; no type conversion is performed.
    Only the first '=' separates a keyword from its value, so 'key=a=b'
    yields {'key': 'a=b'}. An empty or None param_string gives ([], {})."""

    if not param_string:
        return [], {}
    posargs = []
    kwargs = {}
    for p_string in param_string.split(','):
        # partition on the first '=' only, so values may contain '='
        # (previously such arguments were silently dropped)
        key, sep, value = p_string.partition('=')
        if sep:
            kwargs[key] = value
        else:
            posargs.append(p_string)
    return posargs, kwargs
96
97
98
def apply_param_parse_config(configs, params):
    """Applies a parsing configuration to a parameter dict.

    ``configs`` maps parameter names to conversion callables. Each entry of
    ``params`` whose name has a converter in ``configs`` and whose value is
    not None is converted; all other entries are left out of the result."""
    parsed = {}
    for name, raw in params.items():
        if name not in configs or raw is None:
            continue
        parsed[name] = configs[name](raw)
    return parsed
104
105
106
def boolean(val):
    """Interpret the given value as a boolean True/False value.

    The value is converted with str() and lowercased; '1', 'yes', 'true' and
    'on' yield True, anything else (e.g. 'no', '0', None) yields False."""
    normalized = str(val).lower()
    return normalized in {'1', 'yes', 'true', 'on'}
112
113
114
def identity(x):
    """Return the argument itself, unchanged.

    Can serve as a no-op conversion callable."""
    result = x
    return result
117