Passed
Pull Request — main (#798)
by
unknown
02:57
created

annif.util.atomic_save_folder()   A

Complexity

Conditions 4

Size

Total Lines 23
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 16
nop 3
dl 0
loc 23
rs 9.6
c 0
b 0
f 0
1
"""Utility functions for Annif"""
2
3
from __future__ import annotations
4
5
import glob
6
import logging
7
import os
8
import os.path
9
import tempfile
10
from shutil import rmtree
11
from typing import Any, Callable
12
13
from annif import logger
14
15
16
class DuplicateFilter(logging.Filter):
    """Filter out log messages that have already been displayed.

    A record is treated as a duplicate of an earlier one when its module,
    log level, message template and formatting arguments are all the same.
    Only a hash of that identity is remembered, so the filter does not keep
    references to the logged objects.
    """

    def __init__(self) -> None:
        super().__init__()
        # Hashes of the record identities that have already passed the filter.
        self.logged = set()

    @staticmethod
    def _record_key(record: logging.LogRecord) -> int:
        """Return a hash identifying the message content of ``record``."""
        try:
            return hash((record.module, record.levelno, record.msg, record.args))
        except TypeError:
            # msg or args contain unhashable values (e.g. a list argument, or
            # a dict of named arguments, which LogRecord stores as ``args``).
            # Fall back to their repr so the logging call itself never crashes.
            return hash(
                (record.module, record.levelno, repr(record.msg), repr(record.args))
            )

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True the first time a message is seen, False afterwards."""
        current_log = self._record_key(record)
        if current_log in self.logged:
            return False
        self.logged.add(current_log)
        return True
29
30
31
def atomic_save(
    obj: Any, dirname: str, filename: str, method: Callable | None = None
) -> None:
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory with the given
    filename, using a temporary file and renaming the temporary file to the
    final name. The .save() method or the function provided in the method argument
    will be called with the path to the temporary file.

    Any extra files the save routine writes next to the temporary file (i.e.
    files whose names start with the temporary file name) are renamed the
    same way, keeping their trailing part. Renamed files get permissions
    0o666 masked by the process umask. If saving fails, the temporary files
    are removed and the exception is re-raised."""

    prefix, suffix = os.path.splitext(filename)
    prefix = "tmp-" + prefix
    tempfd, tempfilename = tempfile.mkstemp(prefix=prefix, suffix=suffix, dir=dirname)
    os.close(tempfd)
    # Escape the path so glob metacharacters in dirname/filename (e.g. "[")
    # are matched literally instead of as a pattern.
    temp_pattern = glob.escape(tempfilename) + "*"
    logger.debug("saving %s to temporary file %s", str(obj)[:90], tempfilename)
    try:
        if method is not None:
            method(obj, tempfilename)
        else:
            obj.save(tempfilename)
    except BaseException:
        # Best-effort removal of the (possibly partial) temporary files; the
        # original error is what the caller needs to see, so re-raise it.
        for fn in glob.glob(temp_pattern):
            try:
                os.remove(fn)
            except OSError:
                pass
        raise
    # The umask can only be read by setting it, so set and restore it once.
    umask = os.umask(0o777)
    os.umask(umask)
    target = os.path.join(dirname, filename)
    for fn in glob.glob(temp_pattern):
        newname = fn.replace(tempfilename, target)
        logger.debug("renaming temporary file %s to %s", fn, newname)
        # os.replace (unlike os.rename) also overwrites an existing target on
        # Windows; on POSIX both are an atomic rename.
        os.replace(fn, newname)
        os.chmod(newname, 0o666 & ~umask)
56
57
58
def atomic_save_folder(obj, dirname, method=None):
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory,
    using a temporary directory and renaming the temporary directory to the
    final name. The .save() method or the function provided in the method argument
    will be called with the path to the temporary directory.

    The temporary directory is created in the parent directory of
    ``dirname``, so the final rename happens within the same directory tree.
    An existing directory at the target path is removed before the rename.
    If saving fails, the temporary directory is removed and the exception is
    re-raised."""

    tldir = os.path.dirname(dirname.rstrip("/"))
    # Ensure the target (and therefore its parent, where the temporary
    # directory goes) exists. An already existing directory is fine; note that
    # tldir is "" when dirname has no parent component, so it cannot serve as
    # the exist_ok flag.
    os.makedirs(dirname, exist_ok=True)
    temp_dir_name = tempfile.mkdtemp(dir=tldir)
    target_pth = dirname
    logger.debug("saving %s to temporary file %s", str(obj)[:90], temp_dir_name)
    try:
        if method is not None:
            method(obj, temp_dir_name)
        else:
            obj.save(temp_dir_name)
    except BaseException:
        # Clean up eagerly instead of leaving the half-written directory behind.
        rmtree(temp_dir_name, ignore_errors=True)
        raise
    # Escape the path so glob metacharacters in dirname (e.g. "[") are matched
    # literally instead of as a pattern.
    for fn in glob.glob(glob.escape(temp_dir_name) + "*"):
        newname = fn.replace(temp_dir_name, target_pth)
        logger.debug("renaming temporary file %s to %s", fn, newname)
        # os.replace cannot overwrite a non-empty directory, so the previous
        # version is removed first (this remove+rename pair is not atomic).
        if os.path.isdir(newname):
            rmtree(newname)
        os.replace(fn, newname)
81
82
83
def cleanup_uri(uri: str) -> str:
    """Return ``uri`` without surrounding angle brackets.

    The first and last characters are dropped only when the string both
    starts with "<" and ends with ">"; otherwise it is returned unchanged."""
    is_bracketed = uri.startswith("<") and uri.endswith(">")
    if not is_bracketed:
        return uri
    return uri[1:-1]
88
89
90
def parse_sources(sourcedef: str) -> list[tuple[str, float]]:
    """parse a source definition such as 'src1:1.0,src2' into a sequence of
    tuples (src_id, weight)

    Each comma-separated entry is a source ID optionally followed by
    ':weight'; a missing weight defaults to 1.0. The returned weights are
    normalized so that they sum to 1. Whitespace around each entry and
    around the source ID is ignored.

    Raises ValueError if a weight is not a number, or if the weights sum to
    zero so that they cannot be normalized."""

    sources = []
    totalweight = 0.0
    for srcdef in sourcedef.strip().split(","):
        srcval = srcdef.strip().split(":")
        # strip() again so that 'src1 : 0.5' yields the ID 'src1', not 'src1 '
        src_id = srcval[0].strip()
        weight = float(srcval[1]) if len(srcval) > 1 else 1.0
        sources.append((src_id, weight))
        totalweight += weight
    if totalweight == 0:
        # Previously surfaced as an opaque ZeroDivisionError below.
        raise ValueError(
            f"weights of sources in {sourcedef!r} sum to zero, cannot normalize"
        )
    return [(srcid, weight / totalweight) for srcid, weight in sources]
106
107
108
def parse_args(param_string: str) -> tuple[list, dict]:
    """Parse a string of comma separated arguments such as '42,43,key=abc' into
    a list of positional args [42, 43] and a dict of keyword args {key: abc}

    All values are returned as strings. A keyword argument is split at its
    first '=', so the value itself may contain '=' ('q=a=b' gives
    {'q': 'a=b'}); previously such arguments were silently dropped.
    An empty (or None) string yields ([], {})."""

    if not param_string:
        return [], {}
    posargs = []
    kwargs = {}
    for p_string in param_string.split(","):
        key, sep, value = p_string.partition("=")
        if sep:
            kwargs[key] = value
        else:
            posargs.append(p_string)
    return posargs, kwargs
124
125
126
def apply_param_parse_config(configs, params):
    """Convert parameter values using the matching parser from ``configs``.

    Parameters without an entry in ``configs``, and parameters whose value
    is None, are left out of the returned dict; every other value is passed
    through ``configs[name]`` and the result stored under the same name."""
    parsed = {}
    for name, raw in params.items():
        if name not in configs or raw is None:
            continue
        parser = configs[name]
        parsed[name] = parser(raw)
    return parsed
133
134
135
def boolean(val: Any) -> bool:
    """Interpret ``val`` as a boolean flag.

    The value is converted with ``str()`` and compared case-insensitively:
    '1', 'yes', 'true' and 'on' count as True; anything else (including
    None and the empty string) is False."""
    text = str(val).lower()
    truthy_words = {"1", "yes", "true", "on"}
    return text in truthy_words
141
142
143
def identity(x: Any) -> Any:
    """Return the argument itself, unchanged (the identity function)."""
    result = x
    return result
146
147
148
def metric_code(metric):
    """Turn a human-readable metric name into a compact code.

    Spaces become underscores and parentheses are removed, e.g.
    'F1 score (doc avg)' -> 'F1_score_doc_avg'."""
    table = str.maketrans({" ": "_", "(": None, ")": None})
    return metric.translate(table)
151