Passed
Push — issue784-fix-nn-ensemle-model-... ( 7a1cfe )
by Juho
03:50
created

annif.util.get_keras_model_metadata()   A

Complexity

Conditions 4

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 9
nop 1
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
1
"""Utility functions for Annif"""
2
3
from __future__ import annotations
4
5
import glob
6
import json
7
import logging
8
import os
9
import os.path
10
import tempfile
11
import zipfile
12
from typing import Any, Callable
13
14
from annif import logger
15
16
17
class DuplicateFilter(logging.Filter):
    """Filter out log messages that have already been displayed.

    A record counts as a duplicate when its module, level, message and
    arguments all match those of a record this filter has already passed.
    """

    def __init__(self) -> None:
        super().__init__()
        # hashes identifying the records that have been let through so far
        self.logged = set()

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True the first time a record is seen and False for repeats."""
        try:
            current_log = hash(
                (record.module, record.levelno, record.msg, record.args)
            )
        except TypeError:
            # msg can be an arbitrary object and args can be a dict or contain
            # lists, none of which are hashable; use their textual form instead
            # of letting the logging call fail
            current_log = hash(
                (record.module, record.levelno, repr(record.msg), repr(record.args))
            )
        if current_log in self.logged:
            return False
        self.logged.add(current_log)
        return True
30
31
32
def atomic_save(
    obj: Any, dirname: str, filename: str, method: Callable | None = None
) -> None:
    """Save the given object (which must have a .save() method, unless the
    method parameter is given) into the given directory with the given
    filename, using a temporary file and renaming the temporary file to the
    final name.

    If saving produces additional files whose names start with the temporary
    file name, they are renamed too, keeping their extra suffix. When given,
    method is called as method(obj, temporary_path)."""

    prefix, suffix = os.path.splitext(filename)
    prefix = "tmp-" + prefix
    tempfd, tempfilename = tempfile.mkstemp(prefix=prefix, suffix=suffix, dir=dirname)
    # only the reserved name is needed; the saver opens the path itself
    os.close(tempfd)
    logger.debug("saving %s to temporary file %s", str(obj)[:90], tempfilename)
    if method is not None:
        method(obj, tempfilename)
    else:
        obj.save(tempfilename)
    temp_basename = os.path.basename(tempfilename)
    # escape the path so glob metacharacters in it (e.g. "[" in a directory
    # name) are matched literally instead of being treated as wildcards
    for fn in glob.glob(glob.escape(tempfilename) + "*"):
        # whatever follows the temporary base name is carried over as-is
        extra = os.path.basename(fn)[len(temp_basename):]
        newname = os.path.join(dirname, filename + extra)
        logger.debug("renaming temporary file %s to %s", fn, newname)
        # unlike os.rename, os.replace overwrites an existing target file on
        # every platform
        os.replace(fn, newname)
53
54
55
def get_keras_model_metadata(model_file_path: str) -> dict:
    """Return the parsed metadata.json stored inside the Keras model file
    at the given path, or an empty dict if it cannot be read."""
    try:
        with zipfile.ZipFile(model_file_path, "r") as archive:
            raw_metadata = archive.read("metadata.json")
        return json.loads(raw_metadata.decode("utf-8"))
    except Exception:
        # best effort: any failure to open, read or parse yields no metadata
        return {}
65
66
67
def cleanup_uri(uri: str) -> str:
    """strip one pair of enclosing angle brackets from a URI; a URI that is
    not wrapped in them is returned as is"""
    is_bracketed = uri.startswith("<") and uri.endswith(">")
    return uri[1:-1] if is_bracketed else uri
72
73
74
def parse_sources(sourcedef: str) -> list[tuple[str, float]]:
    """split a source definition such as 'src1:1.0,src2' into a list of
    (src_id, weight) tuples; a source without an explicit weight gets 1.0
    and all weights are divided by their total"""

    parsed = []
    weight_sum = 0.0
    for item in sourcedef.strip().split(","):
        # the first field is the id, the second (if any) the weight;
        # anything after a further colon is ignored
        src_id, *rest = item.strip().split(":")
        weight = float(rest[0]) if rest else 1.0
        parsed.append((src_id, weight))
        weight_sum += weight
    return [(src_id, weight / weight_sum) for src_id, weight in parsed]
90
91
92
def parse_args(param_string: str) -> tuple[list, dict]:
    """Parse a string of comma separated arguments such as '42,43,key=abc' into
    a list of positional args ['42', '43'] and a dict of keyword args
    {'key': 'abc'}. All values are kept as strings and are not stripped of
    whitespace. Only the first '=' of an item separates the key from its
    value, so a value may itself contain '='."""

    if not param_string:
        return [], {}
    posargs = []
    kwargs = {}
    for p_string in param_string.split(","):
        # partition splits on the first "=" only, so "k=a=b" keeps "a=b" as
        # the value instead of being dropped
        key, separator, value = p_string.partition("=")
        if separator:
            kwargs[key] = value
        else:
            posargs.append(p_string)
    return posargs, kwargs
108
109
110
def boolean(val: Any) -> bool:
    """Interpret the given value as a boolean. Its string form, compared case
    insensitively, must be one of '1', 'yes', 'true' or 'on' to give True;
    any other value gives False."""

    true_values = ("1", "yes", "true", "on")
    normalized = str(val).lower()
    return normalized in true_values
116
117
118
def identity(x: Any) -> Any:
    """Return the argument itself, without copying or modifying it"""
    return x
121
122
123
def metric_code(metric: str) -> str:
    """Convert a human-readable metric name into an alphanumeric string:
    spaces are replaced by underscores and parentheses are removed"""
    return metric.translate(str.maketrans(" ", "_", "()"))
126