Passed
Push — master ( a53d46...b5a25f )
by Juho
03:01
created

annif.util.metric_code()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""Utility functions for Annif"""
2
3
import glob
4
import logging
5
import os
6
import os.path
7
import tempfile
8
9
import numpy as np
10
11
from annif import logger
12
from annif.suggestion import VectorSuggestionResult
13
14
15
class DuplicateFilter(logging.Filter):
    """Filter out log messages that have already been displayed.

    A record is considered a repeat when its module, level, unformatted
    message and arguments match those of a record previously let through.
    """

    def __init__(self):
        super().__init__()
        # Hashes of the (module, level, msg, args) keys already passed through.
        self.logged = set()

    def filter(self, record):
        """Return True the first time a record is seen, False for repeats."""
        try:
            current_log = hash((record.module, record.levelno, record.msg, record.args))
        except TypeError:
            # record.args becomes a dict when a single mapping is passed for
            # %-style formatting, and args/msg may contain lists; those are
            # unhashable, so fingerprint their repr instead of letting the
            # TypeError escape into the caller's logging call.
            current_log = hash(
                (record.module, record.levelno, repr(record.msg), repr(record.args))
            )
        if current_log in self.logged:
            return False
        self.logged.add(current_log)
        return True
28
29
30
def atomic_save(obj, dirname, filename, method=None):
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory with the given
    filename, using a temporary file and renaming the temporary file to the
    final name.

    The save routine may create several files whose names start with the
    temporary file name; each of them is renamed by swapping that temporary
    prefix for the final path. If saving raises, any temporary files are
    removed and the exception is re-raised.
    """

    prefix, suffix = os.path.splitext(filename)
    tempfd, tempfilename = tempfile.mkstemp(prefix=prefix, suffix=suffix, dir=dirname)
    os.close(tempfd)
    # Escape the temporary path so glob metacharacters in dirname/filename
    # (such as "[" or "?") are matched literally.
    temp_pattern = glob.escape(tempfilename) + "*"
    finalname = os.path.join(dirname, filename)
    logger.debug("saving %s to temporary file %s", str(obj)[:90], tempfilename)
    try:
        if method is not None:
            method(obj, tempfilename)
        else:
            obj.save(tempfilename)
    except BaseException:
        # Do not leave partially written temporary files behind.
        for fn in glob.glob(temp_pattern):
            try:
                os.remove(fn)
            except OSError:
                pass
        raise
    for fn in glob.glob(temp_pattern):
        # Replace only the leading temporary prefix, not every occurrence.
        newname = finalname + fn[len(tempfilename):]
        logger.debug("renaming temporary file %s to %s", fn, newname)
        # os.replace overwrites an existing target on all platforms, whereas
        # os.rename fails on Windows when the destination exists.
        os.replace(fn, newname)
48
49
50
def cleanup_uri(uri):
    """Strip one pair of enclosing angle brackets from a URI.

    A URI written as ``<http://example.org/x>`` is returned without the
    brackets; any other value is returned as given.
    """
    bracketed = uri.startswith("<") and uri.endswith(">")
    return uri[1:-1] if bracketed else uri
55
56
57
def merge_hits(weighted_hits_batches, size):
    """Merge hit sets coming from several sources.

    ``weighted_hits_batches`` is a sequence of WeightedSuggestionsBatch
    objects, each exposing a ``weight`` and a list of ``hit_sets``. Every hit
    set is expanded to a subject vector of length ``size``; vectors at the
    same position in each batch are combined with a weighted average, and a
    list with one VectorSuggestionResult per position is returned.
    """

    batch_weights = [batch.weight for batch in weighted_hits_batches]
    per_batch_vectors = [
        [hit_set.as_vector(size) for hit_set in batch.hit_sets]
        for batch in weighted_hits_batches
    ]
    # The outer list indexes the source batches, so averaging over axis 0
    # merges the sources while keeping positions separate.
    merged = np.average(np.array(per_batch_vectors), axis=0, weights=batch_weights)
    return list(map(VectorSuggestionResult, merged))
71
72
73
def parse_sources(sourcedef):
    """Turn a source definition like 'src1:1.0,src2' into (src_id, weight) pairs.

    Entries are comma separated; an entry without an explicit ':weight' gets
    weight 1.0. The returned weights are divided by their total so that they
    sum to one.
    """

    pairs = []
    total = 0.0
    for entry in sourcedef.strip().split(","):
        fields = entry.strip().split(":")
        weight = float(fields[1]) if len(fields) > 1 else 1.0
        total += weight
        pairs.append((fields[0], weight))
    return [(src, w / total) for src, w in pairs]
89
90
91
def parse_args(param_string):
    """Parse a string of comma separated arguments such as '42,43,key=abc' into
    a list of positional args ['42', '43'] and a dict of keyword args
    {'key': 'abc'}.

    All values are returned as strings. An argument is a keyword argument if
    it contains '='; it is split at the first '=' only, so values may
    themselves contain '=' (e.g. 'expr=a=b' gives {'expr': 'a=b'}). An empty
    or None param_string yields ([], {}).
    """

    if not param_string:
        return [], {}
    posargs = []
    kwargs = {}
    for p_string in param_string.split(","):
        key, sep, value = p_string.partition("=")
        if sep:
            kwargs[key] = value
        else:
            posargs.append(p_string)
    return posargs, kwargs
107
108
109
def boolean(val):
    """Interpret a value as True/False.

    The value is converted to a string and lower-cased; it is True only if
    the result is one of '1', 'yes', 'true' or 'on', and False otherwise.
    """

    truthy_words = {"1", "yes", "true", "on"}
    normalized = str(val).lower()
    return normalized in truthy_words
115
116
117
def identity(x):
    """Return the argument itself, unmodified."""
    unchanged = x
    return unchanged
120
121
122
def metric_code(metric):
    """Convert a human-readable metric name into a code string.

    Spaces become underscores and parentheses are removed; all other
    characters (including symbols such as '@') are kept as they are.
    """
    table = str.maketrans({" ": "_", "(": None, ")": None})
    return metric.translate(table)
125