Passed
Pull Request — main (#798)
by
unknown
03:19
created

annif.util.atomic_save()   A

Complexity

Conditions 3

Size

Total Lines 22
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 14
nop 4
dl 0
loc 22
rs 9.7
c 0
b 0
f 0
1
"""Utility functions for Annif"""
2
3
from __future__ import annotations
4
5
import glob
6
import logging
7
import os
8
import os.path
9
import tempfile
10
from shutil import rmtree
11
from typing import Any, Callable
12
import numpy as np
13
14
from annif import logger
15
16
17
class DuplicateFilter(logging.Filter):
    """Filter out log messages that have already been displayed.

    A record counts as a duplicate when its module, level, unformatted
    message and arguments hash to the same value as a record that already
    passed this filter. Only the hash is remembered, so the internal set
    grows by one integer per distinct message and is never pruned."""

    def __init__(self) -> None:
        super().__init__()
        # hashes of the identifying keys of every record let through so far
        self.logged: set[int] = set()

    @staticmethod
    def _record_key(record: logging.LogRecord) -> int:
        """Return a hash identifying the record's origin, message and args."""
        try:
            return hash((record.module, record.levelno, record.msg, record.args))
        except TypeError:
            # record.msg / record.args may hold unhashable values: logging
            # unwraps a single mapping argument into record.args, and lists
            # are common too. Hash their repr instead of letting the filter
            # raise from inside the logging call.
            return hash(
                (record.module, record.levelno, repr(record.msg), repr(record.args))
            )

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True the first time a message is seen, False afterwards."""
        current_log = self._record_key(record)
        if current_log in self.logged:
            return False
        self.logged.add(current_log)
        return True
30
31
32
def atomic_save(
    obj: Any, dirname: str, filename: str, method: Callable | None = None
) -> None:
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory with the given
    filename, using a temporary file and renaming the temporary file to the
    final name. The .save() method or the function provided in the method
    argument will be called with the path to the temporary file.

    The save routine may create several files whose names start with the
    temporary path; each of them is renamed by swapping that temporary
    prefix for ``os.path.join(dirname, filename)``, keeping whatever the
    save routine appended after it. Existing target files are overwritten.

    If saving raises, all temporary files are removed and the exception is
    re-raised."""

    prefix, suffix = os.path.splitext(filename)
    prefix = "tmp-" + prefix
    tempfd, tempfilename = tempfile.mkstemp(prefix=prefix, suffix=suffix, dir=dirname)
    os.close(tempfd)
    # escape so wildcard characters (e.g. "[") in dirname are matched literally
    temp_pattern = glob.escape(tempfilename) + "*"
    logger.debug("saving %s to temporary file %s", str(obj)[:90], tempfilename)
    try:
        if method is not None:
            method(obj, tempfilename)
        else:
            obj.save(tempfilename)
    except BaseException:
        # don't leave half-written temporary files next to the real ones
        for fn in glob.glob(temp_pattern):
            try:
                os.remove(fn)
            except OSError:
                pass
        raise
    target = os.path.join(dirname, filename)
    for fn in glob.glob(temp_pattern):
        # swap only the leading temporary path, keep any suffix after it
        newname = target + fn[len(tempfilename) :]
        logger.debug("renaming temporary file %s to %s", fn, newname)
        # os.replace (unlike os.rename) also overwrites an existing file on Windows
        os.replace(fn, newname)
54
55
56
def atomic_save_folder(
    obj: Any, dirname: str, method: Callable | None = None
) -> None:
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory, using a temporary
    directory and renaming the temporary directory to the final name. The
    .save() method or the function provided in the method argument will be
    called with the path to the temporary directory.

    The temporary directory is created in the parent directory of
    ``dirname`` so that the final ``os.replace`` stays on one filesystem.
    A directory already present at ``dirname`` is removed right before the
    rename, so the swap is not atomic for concurrent readers. If saving
    raises, the temporary directory is removed and the exception re-raised."""

    tldir = os.path.dirname(dirname.rstrip("/"))
    # exist_ok must be a real boolean: tldir is "" when dirname has no parent
    # component, which would turn an already existing dirname into an error.
    os.makedirs(dirname, exist_ok=True)
    # mkdtemp instead of TemporaryDirectory: the directory is moved away on
    # success, so an auto-cleanup finalizer would only target a missing path.
    temp_dir_name = tempfile.mkdtemp(dir=tldir)
    target_pth = dirname
    logger.debug("saving %s to temporary file %s", str(obj)[:90], temp_dir_name)
    try:
        if method is not None:
            method(obj, temp_dir_name)
        else:
            obj.save(temp_dir_name)
    except BaseException:
        # mkdtemp has no automatic cleanup; drop the partial output here
        rmtree(temp_dir_name, ignore_errors=True)
        raise
    # escape so wildcard characters in the parent path are matched literally
    for fn in glob.glob(glob.escape(temp_dir_name) + "*"):
        newname = target_pth + fn[len(temp_dir_name) :]
        logger.debug("renaming temporary file %s to %s", fn, newname)
        # os.replace cannot overwrite a non-empty directory, so clear it first
        if os.path.isdir(newname):
            rmtree(newname)
        os.replace(fn, newname)
79
80
81
def cleanup_uri(uri: str) -> str:
    """Strip one pair of surrounding angle brackets from *uri*, if present;
    otherwise return it untouched."""
    wrapped = uri.startswith("<") and uri.endswith(">")
    return uri[1:-1] if wrapped else uri
86
87
88
def parse_sources(sourcedef: str) -> list[tuple[str, float]]:
    """Parse a source definition such as 'src1:1.0,src2' into a list of
    tuples (src_id, weight).

    A source without an explicit weight gets weight 1.0. The returned
    weights are normalized by dividing each by the sum of all weights.

    Raises:
        ValueError: if a weight is not a valid float, or if the weights sum
            to zero (normalization would otherwise divide by zero).
    """

    sources = []
    totalweight = 0.0
    for srcdef in sourcedef.strip().split(","):
        srcval = srcdef.strip().split(":")
        src_id = srcval[0]
        weight = float(srcval[1]) if len(srcval) > 1 else 1.0
        sources.append((src_id, weight))
        totalweight += weight
    if totalweight == 0:
        raise ValueError(f"source weights must not sum to zero: {sourcedef!r}")
    return [(srcid, weight / totalweight) for srcid, weight in sources]
104
105
106
def parse_args(param_string: str) -> tuple[list, dict]:
    """Parse a string of comma separated arguments such as '42,43,key=abc' into
    a list of positional args ['42', '43'] and a dict of keyword args
    {'key': 'abc'}. All values are returned as strings, unconverted.

    Only the first '=' separates a keyword from its value, so 'url=a=b'
    yields {'url': 'a=b'} rather than being silently discarded. An empty or
    None param_string yields ([], {})."""

    if not param_string:
        return [], {}
    posargs = []
    kwargs = {}
    for p_string in param_string.split(","):
        key, sep, value = p_string.partition("=")
        if sep:
            kwargs[key] = value
        else:
            posargs.append(p_string)
    return posargs, kwargs
122
123
124
def apply_param_parse_config(configs, params):
    """Build a new dict holding only the parameters that have a parser in
    ``configs`` and a non-None value, each converted by its parser."""
    parsed = {}
    for name, raw in params.items():
        if name not in configs or raw is None:
            continue
        parser = configs[name]
        parsed[name] = parser(raw)
    return parsed
131
132
133
134
def boolean(val: Any) -> bool:
    """Interpret *val* as a truth value.

    The value is stringified and lower-cased; only '1', 'yes', 'true' and
    'on' count as True, anything else (including None) is False."""
    normalized = str(val).lower()
    return normalized in {"1", "yes", "true", "on"}
140
141
142
def identity(x: Any) -> Any:
    """Return *x* itself, unmodified (the identity function)."""
    return x
145
146
147
def metric_code(metric):
    """Turn a human-readable metric name into a code-friendly string by
    replacing spaces with underscores and dropping parentheses; all other
    characters are kept as they are."""
    table = str.maketrans({" ": "_", "(": None, ")": None})
    return metric.translate(table)
150