Passed
Pull Request — main (#798)
by
unknown
03:15
created

annif.util.DuplicateFilter.filter()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 2
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
"""Utility functions for Annif"""
2
3
from __future__ import annotations
4
5
import glob
6
import logging
7
import os
8
import os.path
9
import tempfile
10
from shutil import rmtree
11
from typing import Any, Callable
12
13
from annif import logger
14
15
16
class DuplicateFilter(logging.Filter):
    """Filter out log messages that have already been displayed.

    A record is treated as a duplicate when an earlier record passed through
    this filter instance had the same module, level, message and arguments.
    """

    def __init__(self) -> None:
        super().__init__()
        # Hashes identifying the records that have already been let through.
        self.logged = set()

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True the first time a record is seen and False for repeats."""
        try:
            current_log = hash(
                (record.module, record.levelno, record.msg, record.args)
            )
        except TypeError:
            # The message or its arguments may be unhashable (e.g. a list or a
            # dict passed for %-formatting); fall back to their repr() so that
            # filtering never makes the logging call itself fail.
            current_log = hash(
                (record.module, record.levelno, repr(record.msg), repr(record.args))
            )
        if current_log in self.logged:
            return False
        self.logged.add(current_log)
        return True
29
30
31
def atomic_save(
    obj: Any, dirname: str, filename: str, method: Callable | None = None
) -> None:
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory with the given
    filename, using a temporary file and renaming the temporary file to the
    final name. The .save() method or the function provided in the method
    argument will be called with the path to the temporary file.

    Any additional files that the save operation creates using the temporary
    file name as a prefix are renamed as well, keeping their extra suffix.
    If saving fails, the temporary files are removed and the exception is
    re-raised."""

    prefix, suffix = os.path.splitext(filename)
    prefix = "tmp-" + prefix
    tempfd, tempfilename = tempfile.mkstemp(prefix=prefix, suffix=suffix, dir=dirname)
    os.close(tempfd)
    logger.debug("saving %s to temporary file %s", str(obj)[:90], tempfilename)
    # Escape the path so that glob metacharacters in the directory or file
    # name (e.g. "[", "*", "?") are matched literally.
    temp_pattern = glob.escape(tempfilename) + "*"
    try:
        if method is not None:
            method(obj, tempfilename)
        else:
            obj.save(tempfilename)
    except BaseException:
        # Do not leave half-written temporary files behind on failure.
        for leftover in glob.glob(temp_pattern):
            try:
                os.remove(leftover)
            except OSError:
                # Best-effort cleanup; the original error is more important.
                pass
        raise
    temp_base = os.path.basename(tempfilename)
    for fn in glob.glob(temp_pattern):
        # Whatever follows the temporary name is carried over to the new name.
        extra = os.path.basename(fn)[len(temp_base) :]
        newname = os.path.join(dirname, filename + extra)
        logger.debug("renaming temporary file %s to %s", fn, newname)
        # os.replace overwrites an existing target on every platform.
        os.replace(fn, newname)
53
54
55
def atomic_save_folder(
    obj: Any, dirname: str, method: Callable | None = None
) -> None:
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory,
    using a temporary directory and renaming the temporary directory to the
    final name. The .save() method or the function provided in the method argument
    will be called with the path to the temporary directory.

    The temporary directory is created next to the target directory. An
    existing target directory is removed before the temporary one is moved
    into its place."""

    tldir = os.path.dirname(dirname.rstrip("/"))
    # exist_ok must be a boolean; an already existing target is expected here.
    os.makedirs(dirname, exist_ok=True)
    tempdir = tempfile.TemporaryDirectory(dir=tldir)
    temp_dir_name = tempdir.name
    logger.debug("saving %s to temporary file %s", str(obj)[:90], temp_dir_name)
    try:
        if method is not None:
            method(obj, temp_dir_name)
        else:
            obj.save(temp_dir_name)
        temp_base = os.path.basename(temp_dir_name)
        # Escape the path so that glob metacharacters in it match literally.
        for fn in glob.glob(glob.escape(temp_dir_name) + "*"):
            # Whatever follows the temporary name is carried over to the target.
            extra = os.path.basename(fn)[len(temp_base) :]
            newname = dirname + extra
            logger.debug("renaming temporary file %s to %s", fn, newname)
            if os.path.isdir(newname):
                rmtree(newname)
            os.replace(fn, newname)
    finally:
        # Removes the temporary directory if saving failed; after a successful
        # move the directory no longer exists and this is a no-op.
        tempdir.cleanup()
78
79
80
def cleanup_uri(uri: str) -> str:
    """Strip the enclosing angle brackets from a URI; a URI that is not
    wrapped in both '<' and '>' is returned unchanged"""
    is_bracketed = uri.startswith("<") and uri.endswith(">")
    return uri[1:-1] if is_bracketed else uri
85
86
87
def parse_sources(sourcedef: str) -> list[tuple[str, float]]:
    """Turn a source definition like 'src1:1.0,src2' into a list of
    (src_id, weight) tuples. A source without an explicit weight gets 1.0,
    and every weight is divided by the sum of all weights."""

    weighted = []
    weight_sum = 0.0
    for item in sourcedef.strip().split(","):
        fields = item.strip().split(":")
        share = float(fields[1]) if len(fields) > 1 else 1.0
        weighted.append((fields[0], share))
        # accumulate in input order, one weight at a time
        weight_sum += share
    return [(name, share / weight_sum) for name, share in weighted]
103
104
105
def parse_args(param_string: str) -> tuple[list, dict]:
    """Parse a string of comma separated arguments such as '42,43,key=abc' into
    a list of positional args ['42', '43'] and a dict of keyword args
    {'key': 'abc'}. All values are returned as strings, without conversion.

    A keyword argument is split at its first '=' only, so the value itself
    may contain further '=' characters (e.g. 'expr=a=b' gives
    {'expr': 'a=b'})."""

    if not param_string:
        return [], {}
    posargs = []
    kwargs = {}
    for p_string in param_string.split(","):
        key, separator, value = p_string.partition("=")
        if separator:
            kwargs[key] = value
        else:
            posargs.append(p_string)
    return posargs, kwargs
121
122
123
def apply_param_parse_config(configs, params):
    """Run each parameter value through the parser registered for it.

    Only parameters that have an entry in configs and whose value is not
    None end up in the returned dict; all others are left out."""
    parsed = {}
    for name, raw_value in params.items():
        if name not in configs or raw_value is None:
            continue
        parsed[name] = configs[name](raw_value)
    return parsed
130
131
132
133
def boolean(val: Any) -> bool:
    """Interpret the given value as a boolean.

    The value is converted to a string and lower-cased; the result is True
    only if that string is '1', 'yes', 'true' or 'on', and False otherwise."""

    normalized = str(val).lower()
    return normalized in {"1", "yes", "true", "on"}
139
140
141
def identity(x: Any) -> Any:
    """Return the argument exactly as it was passed in, without copying or
    converting it"""
    return x
144
145
146
def metric_code(metric):
    """Turn a human-readable metric name into an identifier-like string:
    spaces become underscores and parentheses are removed"""
    # third maketrans argument lists the characters to delete
    table = metric.maketrans(" ", "_", "()")
    return metric.translate(table)
149