Passed
Push — master ( eb437a...a53d46 )
by Juho
03:20 queued 15s
created

annif.util.atomic_save_folder()   A

Complexity

Conditions 4

Size

Total Lines 23
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 16
nop 3
dl 0
loc 23
rs 9.6
c 0
b 0
f 0
1
"""Utility functions for Annif"""
2
3
import glob
4
import logging
5
import os
6
import os.path
7
import tempfile
8
from shutil import rmtree
9
10
import numpy as np
11
12
from annif import logger
13
from annif.suggestion import VectorSuggestionResult
14
15
16
class DuplicateFilter(logging.Filter):
    """Filter out log messages that have already been displayed.

    A record counts as a duplicate when its module, level, unformatted message
    and arguments match those of a record this filter has already let through.
    Only the hash of that key is stored, so (very unlikely) hash collisions
    could suppress a distinct message."""

    def __init__(self):
        super().__init__()
        # Hashes of the (module, levelno, msg, args) keys seen so far.
        self.logged = set()

    @staticmethod
    def _record_key(record):
        """Return a hashable fingerprint of the record's identity."""
        try:
            return hash((record.module, record.levelno, record.msg, record.args))
        except TypeError:
            # msg or args may hold unhashable objects. A typical case is a
            # single mapping argument, which LogRecord unwraps into a dict.
            # A filter must not raise, so fall back to their repr.
            return hash(
                (record.module, record.levelno, repr(record.msg), repr(record.args))
            )

    def filter(self, record):
        """Return True the first time a message is seen, False afterwards."""
        current_log = self._record_key(record)
        if current_log in self.logged:
            return False
        self.logged.add(current_log)
        return True
29
30
31
def atomic_save(obj, dirname, filename, method=None):
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory with the given
    filename, using a temporary file and renaming the temporary file to the
    final name. The .save() method or the function provided in the method argument
    will be called with the path to the temporary file.

    Extra files that the save creates next to the temporary file are renamed
    the same way and keep whatever they append after the temporary name; these
    are paths that start with the temporary file name. Existing files at the
    destination are overwritten."""

    prefix, suffix = os.path.splitext(filename)
    tempfd, tempfilename = tempfile.mkstemp(prefix=prefix, suffix=suffix, dir=dirname)
    os.close(tempfd)
    logger.debug("saving %s to temporary file %s", str(obj)[:90], tempfilename)
    if method is not None:
        method(obj, tempfilename)
    else:
        obj.save(tempfilename)
    final_path = os.path.join(dirname, filename)
    temp_base = os.path.basename(tempfilename)
    # Escape the temporary path so that glob metacharacters in dirname or
    # filename (e.g. "[") are matched literally.
    for fn in glob.glob(glob.escape(tempfilename) + "*"):
        # Keep whatever the save appended after the temporary name. Working on
        # the basename avoids depending on how glob normalizes the directory.
        newname = final_path + os.path.basename(fn)[len(temp_base):]
        logger.debug("renaming temporary file %s to %s", fn, newname)
        # os.replace overwrites an existing destination on every platform;
        # os.rename fails on Windows when the target already exists.
        os.replace(fn, newname)
50
51
52
def atomic_save_folder(obj, dirname, method=None):
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory,
    using a temporary directory and renaming the temporary directory to the
    final name. The .save() method or the function provided in the method argument
    will be called with the path to the temporary directory.

    An existing directory at ``dirname`` is removed and replaced. If saving
    raises, the temporary directory is removed and the exception re-raised."""

    tldir = os.path.dirname(dirname.rstrip("/"))
    # Make sure dirname, and therefore its parent where the temporary directory
    # is created, exists. exist_ok must be a real boolean: the previous value
    # (tldir) is "" for single-component paths such as "out", which made a
    # re-save raise FileExistsError.
    os.makedirs(dirname, exist_ok=True)
    # mkdtemp rather than TemporaryDirectory: on success the directory is
    # renamed away, so an implicit cleanup finalizer has nothing left to do.
    temp_dir_name = tempfile.mkdtemp(dir=tldir)
    target_pth = dirname
    logger.debug("saving %s to temporary file %s", str(obj)[:90], temp_dir_name)
    try:
        if method is not None:
            method(obj, temp_dir_name)
        else:
            obj.save(temp_dir_name)
    except BaseException:
        rmtree(temp_dir_name, ignore_errors=True)
        raise
    temp_base = os.path.basename(temp_dir_name)
    # Escape the temporary path so glob metacharacters in it match literally.
    for fn in glob.glob(glob.escape(temp_dir_name) + "*"):
        newname = target_pth + os.path.basename(fn)[len(temp_base):]
        logger.debug("renaming temporary file %s to %s", fn, newname)
        if os.path.isdir(newname):
            rmtree(newname)
        os.replace(fn, newname)
75
76
77
def cleanup_uri(uri):
    """Strip one pair of enclosing angle brackets from *uri*.

    '<http://example.org/x>' becomes 'http://example.org/x'. A string that
    does not both start with '<' and end with '>' is returned unchanged."""

    is_bracketed = uri.startswith("<") and uri.endswith(">")
    return uri[1:-1] if is_bracketed else uri
82
83
84
def merge_hits(weighted_hits_batches, size):
    """Combine suggestion batches from several sources into one list.

    Each element of ``weighted_hits_batches`` is expected to be a
    WeightedSuggestionsBatch, i.e. to provide ``.weight`` and ``.hit_sets``.
    Every hit set is converted to a score vector of length ``size``. The
    vectors are then averaged across batches (axis 0) using the batch weights,
    producing one VectorSuggestionResult per hit-set position."""

    weights = [batch.weight for batch in weighted_hits_batches]
    per_batch_vectors = []
    for batch in weighted_hits_batches:
        vectors = [hits.as_vector(size) for hits in batch.hit_sets]
        per_batch_vectors.append(vectors)
    stacked = np.array(per_batch_vectors)
    averaged = np.average(stacked, axis=0, weights=weights)
    return [VectorSuggestionResult(row) for row in averaged]
98
99
100
def parse_sources(sourcedef):
    """Turn a source definition such as 'src1:1.0,src2' into a list of
    (src_id, weight) tuples.

    A source with no ':weight' part gets weight 1.0. Each weight is then
    divided by the total of all weights."""

    entries = []
    total = 0.0
    for chunk in sourcedef.strip().split(","):
        fields = chunk.strip().split(":")
        weight = float(fields[1]) if len(fields) > 1 else 1.0
        total += weight
        entries.append((fields[0], weight))
    return [(src_id, weight / total) for src_id, weight in entries]
116
117
118
def parse_args(param_string):
    """Parse a string of comma separated arguments such as '42,43,key=abc' into
    a list of positional args ['42', '43'] and a dict of keyword args
    {'key': 'abc'}.

    All values are returned as strings; no type conversion is done. Only the
    first '=' separates a key from its value, so values may themselves contain
    '=' (e.g. 'url=http://host/?a=b'). An empty or None input yields ([], {})."""

    if not param_string:
        return [], {}
    posargs = []
    kwargs = {}
    for p_string in param_string.split(","):
        key, sep, value = p_string.partition("=")
        if sep:
            kwargs[key] = value
        else:
            posargs.append(p_string)
    return posargs, kwargs
134
135
136
def apply_param_parse_config(configs, params):
    """Convert parameter values using per-parameter parser functions.

    ``configs`` maps a parameter name to a callable. Each entry of ``params``
    that has a parser in ``configs`` and a non-None value is passed through
    that parser. All other entries are left out of the returned dict."""

    parsed = {}
    for name, value in params.items():
        if name not in configs or value is None:
            continue
        parsed[name] = configs[name](value)
    return parsed
143
144
145
def boolean(val):
    """Interpret *val* as a boolean flag.

    The value is converted to a string and lower-cased. Only '1', 'yes',
    'true' and 'on' count as True; anything else, including None, is False."""

    text = str(val).lower()
    return text in {"1", "yes", "true", "on"}
151
152
153
def identity(x):
    """Return the argument unchanged.

    Can serve as a pass-through converter function."""
    return x
156
157
158
def metric_code(metric):
    """Turn a human-readable metric name into an identifier-style code.

    Spaces become underscores and parentheses are removed. All other
    characters are kept as they are."""

    table = str.maketrans({" ": "_", "(": None, ")": None})
    return metric.translate(table)
161