Passed
Push — issue760-hugging-face-hub-inte... ( f6d2b7...9d030c )
by Juho
06:28
created

annif.cli_util.upload_to_hf_hub()   A

Complexity

Conditions 2

Size

Total Lines 12
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 11
nop 5
dl 0
loc 12
rs 9.85
c 0
b 0
f 0
1
"""Utility functions for Annif CLI commands"""
2
from __future__ import annotations
3
4
import collections
5
import configparser
6
import io
7
import itertools
8
import os
9
import pathlib
10
import sys
11
import tempfile
12
import zipfile
13
from typing import TYPE_CHECKING
14
15
import click
16
import click_log
17
from flask import current_app
18
from huggingface_hub import HfApi
19
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
20
21
import annif
22
from annif.exception import ConfigurationException, OperationFailedException
23
from annif.project import Access
24
25
if TYPE_CHECKING:
26
    from datetime import datetime
27
    from io import TextIOWrapper
28
29
    from click.core import Argument, Context, Option
30
31
    from annif.corpus.document import DocumentCorpus, DocumentList
32
    from annif.corpus.subject import SubjectIndex
33
    from annif.project import AnnifProject
34
    from annif.suggestion import SuggestionResult
35
    from annif.vocab import AnnifVocabulary
36
37
logger = annif.logger
38
39
40
def _set_project_config_file_path(
    ctx: Context, param: Option, value: str | None
) -> None:
    """Click option callback: store the project configuration path given on the
    command line in the application config, overriding the default path or the
    path given in the environment. A missing or empty value changes nothing."""
    app = ctx.obj.load_app()
    with app.app_context():
        if not value:
            return
        current_app.config["PROJECTS_CONFIG_PATH"] = value
47
48
49
def common_options(f):
    """Decorator that attaches the options shared by all CLI commands: the
    -p/--projects configuration path option and the logging verbosity option"""
    projects_option = click.option(
        "-p",
        "--projects",
        help="Set path to project configuration file or directory",
        type=click.Path(dir_okay=True, exists=True),
        callback=_set_project_config_file_path,
        # The value is consumed by the callback only; it is not passed on to
        # the command function.
        expose_value=False,
        is_eager=True,
    )
    decorated = projects_option(f)
    verbosity_option = click_log.simple_verbosity_option(logger)
    return verbosity_option(decorated)
61
62
63
def project_id(f):
    """Decorator that declares the project_id argument of a CLI command, with
    shell completion handled by complete_param"""
    add_argument = click.argument("project_id", shell_complete=complete_param)
    return add_argument(f)
66
67
68
def backend_param_option(f):
    """Decorator that adds the repeatable -b/--backend-param option, which lets
    a CLI command override backend parameters of the configuration file"""
    help_text = (
        "Override backend parameter of the config file. "
        "Syntax: `-b <backend>.<parameter>=<value>`."
    )
    add_option = click.option(
        "--backend-param",
        "-b",
        multiple=True,
        help=help_text,
    )
    return add_option(f)
77
78
79
def docs_limit_option(f):
    """Decorator that adds the -d/--docs-limit option, a non-negative integer
    capping the number of documents a CLI command uses (no limit by default)"""
    non_negative_int = click.IntRange(0, None)
    add_option = click.option(
        "--docs-limit",
        "-d",
        default=None,
        type=non_negative_int,
        help="Maximum number of documents to use",
    )
    return add_option(f)
89
90
91
def get_project(project_id: str) -> AnnifProject:
    """Fetch the project with the given ID from the registry (private access
    level is sufficient). If the lookup raises ValueError, print an error
    message to stderr and exit with status 1."""
    try:
        project = annif.registry.get_project(project_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No projects found with id '{project_id}'.", err=True)
        sys.exit(1)
    return project
99
100
101
def get_vocab(vocab_id: str) -> AnnifVocabulary:
    """Fetch the vocabulary with the given ID from the registry (private access
    level is sufficient). If the lookup raises ValueError, print an error
    message to stderr and exit with status 1."""
    try:
        vocab = annif.registry.get_vocab(vocab_id, min_access=Access.private)
    except ValueError:
        message = f"No vocabularies found with the id '{vocab_id}'."
        click.echo(message, err=True)
        sys.exit(1)
    else:
        return vocab
110
111
112
def make_list_template(*rows) -> str:
    """Build a str.format() template for printing rows as left-aligned columns.

    Each column is padded to the length of its longest item among the given
    rows, and columns are separated by two spaces. With no rows the template
    is an empty string."""

    widths = {}
    for row in rows:
        for column, item in enumerate(row):
            widths[column] = max(widths.get(column, 0), len(item))

    # One "{index: <width}" placeholder per column, in column order
    placeholders = (f"{{{column}: <{width}}}" for column, width in widths.items())
    return "  ".join(placeholders)
128
129
130
def format_datetime(dt: datetime | None) -> str:
    """Render a datetime in local time as "YYYY-MM-DD HH:MM:SS"; a missing
    datetime (None) is rendered as "-"."""
    return "-" if dt is None else dt.astimezone().strftime("%Y-%m-%d %H:%M:%S")
135
136
137
def open_documents(
    paths: tuple[str, ...],
    subject_index: SubjectIndex,
    vocab_lang: str,
    docs_limit: int | None,
) -> DocumentCorpus:
    """Open a document corpus from the given pathnames.

    A directory path is opened as a DocumentDirectory (with subjects required,
    using the given vocabulary language); any other path is opened as a
    DocumentFile. Several paths are merged into a CombinedCorpus, and no paths
    at all yields a corpus read from the null device. If docs_limit is not
    None, the result is wrapped in a LimitingDocumentCorpus."""

    def _open_single(path):
        """Open one path as a document corpus"""
        if not os.path.isdir(path):
            return annif.corpus.DocumentFile(path, subject_index)
        return annif.corpus.DocumentDirectory(
            path, subject_index, vocab_lang, require_subjects=True
        )

    if not paths:
        logger.warning("Reading empty file")
        corpus = _open_single(os.path.devnull)
    elif len(paths) == 1:
        corpus = _open_single(paths[0])
    else:
        corpus = annif.corpus.CombinedCorpus([_open_single(path) for path in paths])

    if docs_limit is None:
        return corpus
    return annif.corpus.LimitingDocumentCorpus(corpus, docs_limit)
168
169
170
def open_text_documents(paths: tuple[str, ...], docs_limit: int | None) -> DocumentList:
    """
    Read plain text documents from the given file paths and return them as a
    DocumentList whose Documents have no subjects. The path "-" stands for
    standard input. At most docs_limit paths are read (all of them if
    docs_limit is None). Files are read lazily, as the list is iterated.
    """

    def _read_text(path):
        """Return the text behind one path; "-" means standard input"""
        if path == "-":
            return sys.stdin.read()
        # utf-8-sig drops a leading byte order mark; undecodable bytes are replaced
        with open(path, errors="replace", encoding="utf-8-sig") as docfile:
            return docfile.read()

    selected_paths = paths[:docs_limit]
    return annif.corpus.DocumentList(
        annif.corpus.Document(text=_read_text(path), subject_set=None)
        for path in selected_paths
    )
188
189
190
def show_hits(
    hits: SuggestionResult,
    project: AnnifProject,
    lang: str,
    file: TextIOWrapper | None = None,
) -> None:
    """
    Write subject suggestions to the console or to the given file, one
    tab-separated row per hit: the URI in angle brackets, the label in the given
    language, the notation (when the subject has one) and the score formatted
    with four decimals.
    """
    row_format = "<{}>\t{}\t{:.04f}"
    for hit in hits:
        subject = project.subjects[hit.subject_id]
        uri = subject.uri
        # Leave out empty fields, e.g. the notation of a subject without one
        name_fields = [
            field for field in (subject.labels[lang], subject.notation) if field
        ]
        row = row_format.format(uri, "\t".join(name_fields), hit.score)
        click.echo(row, file=file)
210
211
212
def parse_backend_params(
    backend_param: tuple[str, ...] | tuple[()], project: AnnifProject
) -> collections.defaultdict[str, dict[str, str]]:
    """Turn the values of the --backend-param option, each of the form
    `<backend>.<parameter>=<value>`, into a nested mapping
    {backend: {parameter: value}}. Every backend named is validated against
    the backend of the given project."""
    parsed = collections.defaultdict(dict)
    for spec in backend_param:
        # Split only on the first "." and the first "=", so the value itself
        # may contain either character
        backend, assignment = spec.split(".", 1)
        name, value = assignment.split("=", 1)
        _validate_backend_params(backend, spec, project)
        parsed[backend][name] = value
    return parsed
224
225
226
def _validate_backend_params(backend: str, beparam: str, project: AnnifProject) -> None:
227
    if backend != project.config["backend"]:
228
        raise ConfigurationException(
229
            'The backend {} in CLI option "-b {}" not matching the project'
230
            " backend {}.".format(backend, beparam, project.config["backend"])
231
        )
232
233
234
def generate_filter_params(filter_batch_max_limit: int) -> list[tuple[int, float]]:
    """Return every (limit, threshold) combination, with limits running from 1
    to filter_batch_max_limit and thresholds from 0.0 to 0.95 in steps of 0.05.
    The list is ordered by limit first, then by threshold."""
    return [
        (limit, step * 0.05)
        for limit in range(1, filter_batch_max_limit + 1)
        for step in range(20)
    ]
238
239
240
def _is_train_file(fname):
241
    train_file_patterns = ("-train", "tmp-")
242
    for pat in train_file_patterns:
243
        if pat in fname:
244
            return True
245
    return False
246
247
248
def archive_dir(data_dir):
    """Pack the contents of a directory into a zip archive held in a temporary
    file.

    All entries found recursively under data_dir are added, except those whose
    name is recognized by _is_train_file. The returned binary file object is
    positioned at the start of the archive; closing it is up to the caller.
    If building the archive fails, the temporary file is closed before the
    exception propagates."""
    root = pathlib.Path(data_dir)
    fpaths = [fpath for fpath in root.glob("**/*") if not _is_train_file(fpath.name)]
    fp = tempfile.TemporaryFile()
    try:
        with zipfile.ZipFile(fp, mode="w") as zfile:
            for fpath in fpaths:
                logger.debug("Adding %s", fpath)
                zfile.write(fpath)
        fp.seek(0)
    except Exception:
        # Don't leak the temporary file when the archive could not be built
        fp.close()
        raise
    return fp
258
259
260
def write_config(project):
    """Serialize the configuration of a project in INI format, as a single
    section named after the project ID.

    Returns an in-memory binary file object (io.BytesIO) with the UTF-8 encoded
    configuration, positioned at the start, as needed for uploading."""
    config = configparser.ConfigParser()
    config[project.project_id] = project.config
    # ConfigParser.write() needs a text stream. An in-memory buffer is used
    # instead of a temporary file: there is no file handle left open, and the
    # text never passes through the locale-dependent default encoding.
    text_buffer = io.StringIO()
    config.write(text_buffer)
    return io.BytesIO(text_buffer.getvalue().encode("utf8"))
268
269
270
def upload_to_hf_hub(fileobj, filename, repo_id, token, commit_message):
    """Upload a file object to a Hugging Face Hub repository.

    The content of fileobj is stored in the repository repo_id under the path
    filename, using the given token and commit message.

    Raises OperationFailedException if the Hub client reports an HTTP error or
    a validation error; the original error is kept as the exception cause."""
    api = HfApi()
    try:
        api.upload_file(
            path_or_fileobj=fileobj,
            path_in_repo=filename,
            repo_id=repo_id,
            token=token,
            commit_message=commit_message,
        )
    except (HfHubHTTPError, HFValidationError) as err:
        raise OperationFailedException(str(err)) from err
282
283
284
def _get_completion_choices(
285
    param: Argument,
286
) -> dict[str, AnnifVocabulary] | dict[str, AnnifProject] | list:
287
    if param.name == "project_id":
288
        return annif.registry.get_projects()
289
    elif param.name == "vocab_id":
290
        return annif.registry.get_vocabs()
291
    else:
292
        return []
293
294
295
def complete_param(ctx: Context, param: Argument, incomplete: str) -> list[str]:
    """Shell completion callback: within the application context, return the
    completion candidates of the argument that start with the text typed so
    far"""
    app = ctx.obj.load_app()
    with app.app_context():
        candidates = _get_completion_choices(param)
        return [name for name in candidates if name.startswith(incomplete)]
302