Passed
Pull Request — main (#798)
by
unknown
03:25
created

annif.util.atomic_save()   A

Complexity

Conditions 3

Size

Total Lines 25
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 17
nop 4
dl 0
loc 25
rs 9.55
c 0
b 0
f 0
1
"""Utility functions for Annif"""
2
3
from __future__ import annotations
4
5
import glob
6
import logging
7
import os
8
import os.path
9
import tempfile
10
from shutil import rmtree
11
from typing import Any, Callable
12
from typing import TYPE_CHECKING, Any, Callable
13
14
if TYPE_CHECKING:
15
    from annif.corpus.subject import SubjectIndex
16
    from annif.suggestion import SubjectSuggestion, SuggestionResults
17
18
from annif import logger
19
20
21
class DuplicateFilter(logging.Filter):
    """Filter out log messages that have already been displayed.

    A record is treated as a duplicate when its module, level, unformatted
    message and arguments match a record previously seen by this filter
    instance. Only a hash of that key is remembered, so memory stays small
    (at the cost that a hash collision could suppress a distinct message).
    """

    def __init__(self) -> None:
        super().__init__()
        # hashes of the (module, levelno, msg, args) keys seen so far
        self.logged = set()

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True the first time a record key is seen, False afterwards."""
        try:
            current_log = hash((record.module, record.levelno, record.msg, record.args))
        except TypeError:
            # msg or args may contain unhashable values (e.g. a list passed as
            # a logging argument); hashing them would raise from inside the
            # logging call, so fall back to their repr instead.
            current_log = hash(
                (record.module, record.levelno, repr(record.msg), repr(record.args))
            )
        if current_log in self.logged:
            return False
        self.logged.add(current_log)
        return True
34
35
36
def atomic_save(
    obj: Any, dirname: str, filename: str, method: Callable | None = None
) -> None:
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory with the given
    filename, using a temporary file and renaming the temporary file to the
    final name. The .save() method or the function provided in the method
    argument will be called with the path to the temporary file.

    Any additional files the saver creates whose names start with the
    temporary file name are renamed as well; the leading temporary name is
    replaced by the final name and the rest of the file name is kept.
    Renamed files get permissions 0o666 masked by the process umask
    (mkstemp creates the temporary file readable by its owner only).

    If saving raises, the temporary files are removed on a best-effort basis
    and the exception is propagated.

    :param obj: object to save
    :param dirname: directory in which the file is created
    :param filename: final file name inside ``dirname``
    :param method: optional callable ``method(obj, path)`` used instead of
        ``obj.save(path)``
    """

    prefix, suffix = os.path.splitext(filename)
    prefix = "tmp-" + prefix
    tempfd, tempfilename = tempfile.mkstemp(prefix=prefix, suffix=suffix, dir=dirname)
    os.close(tempfd)
    # Escape the path so glob metacharacters in dirname/filename are literal.
    temp_pattern = glob.escape(tempfilename) + "*"
    logger.debug("saving %s to temporary file %s", str(obj)[:90], tempfilename)
    try:
        if method is not None:
            method(obj, tempfilename)
        else:
            obj.save(tempfilename)
    except BaseException:
        for fn in glob.glob(temp_pattern):
            try:
                os.remove(fn)
            except OSError:
                pass  # best-effort cleanup; keep the original exception
        raise

    target = os.path.join(dirname, filename)
    temp_base = os.path.basename(tempfilename)
    # os.umask can only be read by setting it; do it once, not per file.
    umask = os.umask(0o777)
    os.umask(umask)
    for fn in glob.glob(temp_pattern):
        # Swap only the leading temporary name; str.replace would also rewrite
        # any later occurrence of it in the path.
        newname = target + os.path.basename(fn)[len(temp_base):]
        logger.debug("renaming temporary file %s to %s", fn, newname)
        # os.replace also overwrites an existing target on Windows.
        os.replace(fn, newname)
        os.chmod(newname, 0o666 & ~umask)
61
62
63
def atomic_save_folder(obj, dirname, method=None):
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory,
    using a temporary directory and renaming the temporary directory to the
    final name. The .save() method or the function provided in the method argument
    will be called with the path to the temporary directory.

    The temporary directory is created next to ``dirname`` (in its parent).
    An existing directory at the final path is removed before the temporary
    one is moved into place. If saving raises, the temporary directory is
    removed and the exception is propagated.

    :param obj: object to save
    :param dirname: final directory path
    :param method: optional callable ``method(obj, path)`` used instead of
        ``obj.save(path)``
    """

    tldir = os.path.dirname(dirname.rstrip("/"))
    # exist_ok must be a real boolean: tldir is "" (falsy) when dirname has no
    # parent component, which used to raise FileExistsError for an existing dir.
    os.makedirs(dirname, exist_ok=True)
    # mkdtemp (unlike TemporaryDirectory) has no finalizer that would later try
    # to delete the directory after it has been moved to its final place.
    temp_dir_name = tempfile.mkdtemp(dir=tldir)
    target_pth = dirname
    temp_pattern = glob.escape(temp_dir_name) + "*"
    logger.debug("saving %s to temporary file %s", str(obj)[:90], temp_dir_name)
    try:
        if method is not None:
            method(obj, temp_dir_name)
        else:
            obj.save(temp_dir_name)
    except BaseException:
        rmtree(temp_dir_name, ignore_errors=True)
        raise

    temp_base = os.path.basename(temp_dir_name)
    for fn in glob.glob(temp_pattern):
        # Swap only the leading temporary name for the target path.
        newname = target_pth + os.path.basename(fn)[len(temp_base):]
        logger.debug("renaming temporary file %s to %s", fn, newname)
        if os.path.isdir(newname):
            rmtree(newname)
        os.replace(fn, newname)
86
87
88
def cleanup_uri(uri: str) -> str:
    """Strip one pair of enclosing angle brackets from a URI.

    The brackets are removed only when the string both starts with "<" and
    ends with ">"; any other string is returned unchanged."""
    wrapped = uri[:1] == "<" and uri[-1:] == ">"
    return uri[1:-1] if wrapped else uri
93
94
95
def parse_sources(sourcedef: str) -> list[tuple[str, float]]:
    """Parse a source definition such as 'src1:1.0,src2' into a list of
    tuples (src_id, weight), with the weights normalized to sum to 1.

    A source without an explicit weight gets weight 1.0. Whitespace around
    the commas and around the colon is ignored. Raises ZeroDivisionError if
    all weights sum to zero, and ValueError if a weight is not a number."""

    sources = []
    for srcdef in sourcedef.strip().split(","):
        srcval = srcdef.strip().split(":")
        # strip so that "src1 : 0.5" yields the id "src1", not "src1 "
        src_id = srcval[0].strip()
        weight = float(srcval[1]) if len(srcval) > 1 else 1.0
        sources.append((src_id, weight))
    totalweight = sum(weight for _, weight in sources)
    return [(srcid, weight / totalweight) for srcid, weight in sources]
111
112
113
def parse_args(param_string: str) -> tuple[list, dict]:
    """Parse a string of comma separated arguments such as '42,43,key=abc' into
    a list of positional args [42, 43] and a dict of keyword args {key: abc}.

    All values are returned as strings. A keyword argument is split on its
    first "=" only, so 'key=a=b' yields {'key': 'a=b'} (previously such an
    argument was silently dropped). An empty or None input gives ([], {})."""

    if not param_string:
        return [], {}
    posargs = []
    kwargs = {}
    for p_string in param_string.split(","):
        key, sep, value = p_string.partition("=")
        if sep:
            kwargs[key] = value
        else:
            posargs.append(p_string)
    return posargs, kwargs
129
130
131
def apply_param_parse_config(configs, params):
    """Convert parameter values using per-parameter parser functions.

    Returns a new dict containing only the parameters that have an entry in
    ``configs`` and whose value is not None; each such value is passed
    through ``configs[name]``. Other parameters are omitted."""
    parsed = {}
    for name, raw in params.items():
        if name not in configs or raw is None:
            continue
        parsed[name] = configs[name](raw)
    return parsed
138
139
140
def boolean(val: Any) -> bool:
141
    """Convert the given value to a boolean True/False value, if it isn't already.
142
    True values are '1', 'yes', 'true', and 'on' (case insensitive), everything
143
    else is False."""
144
145
    return str(val).lower() in ("1", "yes", "true", "on")
146
147
148
def identity(x: Any) -> Any:
    """Return the argument itself, without any conversion."""
    unchanged = x
    return unchanged
151
152
153
def metric_code(metric):
    """Convert a human-readable metric name into a compact code string.

    Spaces become underscores and parentheses are removed; every other
    character (e.g. "@" in "F1@5") is kept as is."""
    table = str.maketrans({" ": "_", "(": None, ")": None})
    return metric.translate(table)
156
157
158
def suggestion_to_dict(
    suggestion: SubjectSuggestion, subject_index: SubjectIndex, language: str
) -> dict[str, str | float | None]:
    """Convert one subject suggestion into a plain dict.

    The subject is looked up in ``subject_index`` by the suggestion's
    subject_id; the result holds its uri, its label in ``language``, its
    notation and the suggestion score, under the keys "uri", "label",
    "notation" and "score"."""
    subj = subject_index[suggestion.subject_id]
    result = dict(
        uri=subj.uri,
        label=subj.labels[language],
        notation=subj.notation,
        score=suggestion.score,
    )
    return result
168
169
170
def suggestion_results_to_list(
    suggestion_results: SuggestionResults, subjects: SubjectIndex, lang: str
) -> list[dict[str, list]]:
    """Convert a batch of suggestion results into a list of plain dicts.

    Each element of ``suggestion_results`` (an iterable of suggestions)
    becomes ``{"results": [...]}``, where every suggestion is converted with
    suggestion_to_dict using ``subjects`` and ``lang``."""
    output = []
    for suggestions in suggestion_results:
        converted = [suggestion_to_dict(item, subjects, lang) for item in suggestions]
        output.append({"results": converted})
    return output
182