Passed
Push — issue760-hugging-face-hub-inte... ( 313511...7666de )
by Juho
07:21
created

annif.cli_util.download_from_hf_hub()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 9
nop 4
dl 0
loc 10
rs 9.95
c 0
b 0
f 0
1
"""Utility functions for Annif CLI commands"""
2
3
from __future__ import annotations
4
5
import binascii
6
import collections
7
import configparser
8
import io
9
import itertools
10
import os
11
import pathlib
12
import shutil
13
import sys
14
import tempfile
15
import zipfile
16
from fnmatch import fnmatch
17
from typing import TYPE_CHECKING
18
19
import click
20
import click_log
21
from flask import current_app
22
23
import annif
24
from annif.exception import ConfigurationException, OperationFailedException
25
from annif.project import Access
26
27
if TYPE_CHECKING:
28
    from datetime import datetime
29
    from io import TextIOWrapper
30
31
    from click.core import Argument, Context, Option
32
33
    from annif.corpus.document import DocumentCorpus, DocumentList
34
    from annif.corpus.subject import SubjectIndex
35
    from annif.project import AnnifProject
36
    from annif.suggestion import SuggestionResult
37
    from annif.vocab import AnnifVocabulary
38
39
# The package-level Annif logger, used for this module's messages and passed to
# the click_log verbosity option installed by common_options().
logger = annif.logger
40
41
42
def _set_project_config_file_path(
    ctx: Context, param: Option, value: str | None
) -> None:
    """Click callback for ``--projects``: store the given path in the app config.

    The application is loaded via ``ctx.obj`` and, inside its app context,
    ``PROJECTS_CONFIG_PATH`` is overwritten with ``value``. When the option was
    not given (``value`` is empty/None) the existing setting is left untouched.
    """
    app = ctx.obj.load_app()
    with app.app_context():
        if not value:
            return
        current_app.config["PROJECTS_CONFIG_PATH"] = value
49
50
51
def common_options(f):
    """Decorator attaching the options shared by every CLI command.

    Adds ``-p/--projects`` (an existing file or directory with project
    configuration; handled eagerly by a callback and not passed on to the
    command) followed by the click_log verbosity option bound to ``logger``.
    """
    with_projects = click.option(
        "-p",
        "--projects",
        help="Set path to project configuration file or directory",
        type=click.Path(dir_okay=True, exists=True),
        callback=_set_project_config_file_path,
        expose_value=False,
        is_eager=True,
    )(f)
    add_verbosity = click_log.simple_verbosity_option(logger)
    return add_verbosity(with_projects)
63
64
65
def project_id(f):
    """Decorator adding a positional ``project_id`` argument with shell completion."""
    add_argument = click.argument("project_id", shell_complete=complete_param)
    return add_argument(f)
68
69
70
def backend_param_option(f):
    """Decorator adding a repeatable ``-b/--backend-param`` option.

    Each occurrence overrides one backend parameter of the project
    configuration; the collected values are parsed by ``parse_backend_params``.
    """
    add_option = click.option(
        "--backend-param",
        "-b",
        multiple=True,
        help=(
            "Override backend parameter of the config file. "
            "Syntax: `-b <backend>.<parameter>=<value>`."
        ),
    )
    return add_option(f)
79
80
81
def docs_limit_option(f):
    """Decorator adding ``-d/--docs-limit``, a cap on the number of documents used.

    The value must be a non-negative integer; when omitted it is ``None``.
    """
    add_option = click.option(
        "--docs-limit",
        "-d",
        default=None,
        type=click.IntRange(0, None),
        help="Maximum number of documents to use",
    )
    return add_option(f)
91
92
93
def get_project(project_id: str) -> AnnifProject:
    """Return the project with the given ID, or exit the CLI with status 1.

    The registry lookup uses ``min_access=Access.private``. If the registry
    raises ``ValueError`` (unknown ID), an error message is printed to stderr.
    """
    try:
        project = annif.registry.get_project(project_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No projects found with id '{project_id}'.", err=True)
        sys.exit(1)
    return project
101
102
103
def get_vocab(vocab_id: str) -> AnnifVocabulary:
    """Return the vocabulary with the given ID, or exit the CLI with status 1.

    The registry lookup uses ``min_access=Access.private``. If the registry
    raises ``ValueError`` (unknown ID), an error message is printed to stderr.
    """
    try:
        vocab = annif.registry.get_vocab(vocab_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No vocabularies found with the id '{vocab_id}'.", err=True)
        sys.exit(1)
    return vocab
112
113
114
def make_list_template(*rows) -> str:
    """Build a ``str.format`` template for printing rows as aligned columns.

    Column ``i`` is left-aligned and padded to the length of the longest item
    found at position ``i`` in any of ``rows``; columns are separated by two
    spaces. Rows may have different lengths. With no rows, ``""`` is returned.
    """
    widths = {}
    for row in rows:
        for column, item in enumerate(row):
            widths[column] = max(widths.get(column, 0), len(item))
    return "  ".join("{%d: <%d}" % (column, width) for column, width in widths.items())
130
131
132
def format_datetime(dt: datetime | None) -> str:
    """Render ``dt`` in the local time zone as ``YYYY-MM-DD HH:MM:SS``.

    ``None`` is rendered as ``"-"``.
    """
    if dt is not None:
        local_dt = dt.astimezone()
        return local_dt.strftime("%Y-%m-%d %H:%M:%S")
    return "-"
137
138
139
def open_documents(
    paths: tuple[str, ...],
    subject_index: SubjectIndex,
    vocab_lang: str,
    docs_limit: int | None,
) -> DocumentCorpus:
    """Open the documents at ``paths`` as a single corpus.

    Each path is either a TSV file or a directory of TXT files. Directories are
    opened with ``require_subjects=True`` and ``vocab_lang`` is used to convert
    subject labels into URIs. With no paths, ``os.devnull`` is read (after a
    warning). With several paths, the corpora are joined in a
    ``CombinedCorpus``. If ``docs_limit`` is not None, the result is wrapped in
    a ``LimitingDocumentCorpus``.
    """

    def _open_single(path):
        # Directories become DocumentDirectory corpora, anything else a DocumentFile.
        if not os.path.isdir(path):
            return annif.corpus.DocumentFile(path, subject_index)
        return annif.corpus.DocumentDirectory(
            path, subject_index, vocab_lang, require_subjects=True
        )

    if not paths:
        logger.warning("Reading empty file")
        corpus = _open_single(os.path.devnull)
    elif len(paths) > 1:
        corpus = annif.corpus.CombinedCorpus([_open_single(p) for p in paths])
    else:
        corpus = _open_single(paths[0])

    if docs_limit is not None:
        corpus = annif.corpus.LimitingDocumentCorpus(corpus, docs_limit)
    return corpus
170
171
172
def open_text_documents(paths: tuple[str, ...], docs_limit: int | None) -> DocumentList:
    """Wrap the text of the given paths as subject-less documents.

    A path of ``"-"`` means standard input. Files are decoded as UTF-8, with a
    leading BOM dropped and undecodable bytes replaced. Only the first
    ``docs_limit`` paths are used (all of them when ``docs_limit`` is None).
    The documents are produced lazily by a generator handed to
    ``annif.corpus.DocumentList``.
    """

    def _read_text(path):
        if path == "-":
            return sys.stdin.read()
        with open(path, errors="replace", encoding="utf-8-sig") as docfile:
            return docfile.read()

    selected = paths[:docs_limit]
    return annif.corpus.DocumentList(
        annif.corpus.Document(text=_read_text(path), subject_set=None)
        for path in selected
    )
190
191
192
def show_hits(
    hits: SuggestionResult,
    project: AnnifProject,
    lang: str,
    file: TextIOWrapper | None = None,
) -> None:
    """Echo each suggestion as one tab-separated line to the console or ``file``.

    Columns: the subject URI in angle brackets, its label in ``lang``, its
    notation, and the score with four decimals. The label and notation columns
    are left out when they are empty or None.
    """
    for hit in hits:
        subject = project.subjects[hit.subject_id]
        label_and_notation = (subject.labels[lang], subject.notation)
        middle = "\t".join(part for part in label_and_notation if part)
        click.echo(f"<{subject.uri}>\t{middle}\t{hit.score:.04f}", file=file)
212
213
214
def parse_backend_params(
    backend_param: tuple[str, ...] | tuple[()], project: AnnifProject
) -> collections.defaultdict[str, dict[str, str]]:
    """Parse ``--backend-param`` values into a nested dict.

    Each value must look like ``<backend>.<parameter>=<value>``. The result maps
    backend name -> {parameter: value}. Only the first ``.`` and the first
    ``=`` (after that ``.``) act as separators, so values may contain either.

    :raises ConfigurationException: if a value does not follow the syntax above,
        or names a backend other than the one configured for ``project``
    """
    backend_params = collections.defaultdict(dict)
    for beparam in backend_param:
        backend, dot, param = beparam.partition(".")
        key, equals, val = param.partition("=")
        if not (dot and equals):
            # Report the expected syntax instead of failing later with an
            # opaque unpacking error.
            raise ConfigurationException(
                f'Invalid backend parameter in CLI option "-b {beparam}"; expected'
                " syntax: `-b <backend>.<parameter>=<value>`."
            )
        _validate_backend_params(backend, beparam, project)
        backend_params[backend][key] = val
    return backend_params
226
227
228
def _validate_backend_params(backend: str, beparam: str, project: AnnifProject) -> None:
    """Check that a ``-b`` override targets the backend configured for ``project``.

    :raises ConfigurationException: if ``backend`` differs from the project's
        ``backend`` setting
    """
    expected = project.config["backend"]
    if backend == expected:
        return
    raise ConfigurationException(
        f'The backend {backend} in CLI option "-b {beparam}" not matching the project'
        f" backend {expected}."
    )
234
235
236
def generate_filter_params(filter_batch_max_limit: int) -> list[tuple[int, float]]:
    """Return every (limit, threshold) pair to try when optimizing filters.

    Limits go from 1 to ``filter_batch_max_limit`` inclusive (outer loop);
    thresholds are ``step * 0.05`` for ``step`` in 0..19 (inner loop).
    """
    return [
        (limit, step * 0.05)
        for limit in range(1, filter_batch_max_limit + 1)
        for step in range(20)
    ]
240
241
242
def _is_train_file(fname):
    """Return True if ``fname`` contains ``"-train"`` or ``"tmp-"``.

    Such files are left out of project archives by ``archive_dir``.
    """
    return any(marker in fname for marker in ("-train", "tmp-"))
248
249
250
def archive_dir(data_dir):
    """Zip everything under ``data_dir`` into an anonymous temporary file.

    Paths (files and directories) whose final component is flagged by
    ``_is_train_file`` are skipped. Each entry's archive name is derived from
    the globbed path, so it includes ``data_dir`` itself. The returned file
    object is rewound to the start; it is not closed here.
    """
    archive = tempfile.TemporaryFile()
    root = pathlib.Path(data_dir)
    members = [p for p in root.glob("**/*") if not _is_train_file(p.name)]
    with zipfile.ZipFile(archive, mode="w") as zfile:
        for member_path in members:
            logger.debug(f"Adding {member_path}")
            zfile.write(member_path)
    archive.seek(0)
    return archive
260
261
262
def write_config(project):
    """Serialize a project's configuration as an in-memory INI file.

    The configuration is written as a single section named
    ``project.project_id`` and returned as a rewound, UTF-8 encoded
    ``io.BytesIO``, because the upload needs a binary file object.
    """
    config = configparser.ConfigParser()
    config[project.project_id] = project.config
    # ConfigParser.write() needs a text stream. Build the text in memory and
    # encode it explicitly. A text-mode temporary file would use the locale
    # encoding (which can fail on non-ASCII values) and was never closed.
    with io.StringIO() as text:
        config.write(text)
        return io.BytesIO(text.getvalue().encode("utf8"))
270
271
272
def upload_to_hf_hub(fileobj, filename, repo_id, token, commit_message):
    """Upload ``fileobj`` to ``filename`` in Hugging Face Hub repo ``repo_id``.

    :param fileobj: passed as ``path_or_fileobj`` to ``HfApi.upload_file``
    :param filename: destination path inside the repository
    :param repo_id: Hub repository ID
    :param token: Hub authentication token (passed through unchanged)
    :param commit_message: message of the commit created by the upload
    :raises OperationFailedException: if the Hub request fails or the arguments
        are rejected by huggingface_hub validation
    """
    # Imported here, presumably so huggingface_hub is only required when Hub
    # commands are actually used.
    from huggingface_hub import HfApi
    from huggingface_hub.utils import HfHubHTTPError, HFValidationError

    api = HfApi()
    try:
        api.upload_file(
            path_or_fileobj=fileobj,
            path_in_repo=filename,
            repo_id=repo_id,
            token=token,
            commit_message=commit_message,
        )
    except (HfHubHTTPError, HFValidationError) as err:
        # Keep the original error as __cause__ for debugging.
        raise OperationFailedException(str(err)) from err
287
288
289
def get_selected_project_ids_from_hf_hub(project_ids_pattern, repo_id, token, revision):
    """Return the IDs of project archives in a Hub repo matching a glob pattern.

    Archives are expected at ``projects/<project_id>.zip``. ``project_ids_pattern``
    is an fnmatch-style pattern for the ``<project_id>`` part.
    """
    prefix, suffix = "projects/", ".zip"
    all_repo_file_paths = _list_files_in_hf_hub(repo_id, token, revision)
    return [
        # Slice off exactly the matched prefix and suffix. Splitting on
        # ".zip"/"projects/" would truncate IDs that contain those substrings.
        path[len(prefix) : len(path) - len(suffix)]
        for path in all_repo_file_paths
        if fnmatch(path, f"{prefix}{project_ids_pattern}{suffix}")
    ]
296
297
298
def _list_files_in_hf_hub(repo_id, token, revision):
    """Return the paths of all files in Hub repo ``repo_id`` at ``revision``.

    :raises OperationFailedException: if the Hub request fails or the arguments
        are rejected by huggingface_hub validation. This matches how the
        upload and download helpers in this module report such errors.
    """
    from huggingface_hub import list_repo_files
    from huggingface_hub.utils import HfHubHTTPError, HFValidationError

    try:
        return list(list_repo_files(repo_id=repo_id, token=token, revision=revision))
    except (HfHubHTTPError, HFValidationError) as err:
        raise OperationFailedException(str(err)) from err
305
306
307
def download_from_hf_hub(filename, repo_id, token, revision):
    """Download ``filename`` from Hub repo ``repo_id`` at ``revision``.

    :returns: the value returned by ``hf_hub_download`` (the local path of the
        downloaded file)
    :raises OperationFailedException: if the Hub request fails or the arguments
        are rejected by huggingface_hub validation
    """
    # Imported here, presumably so huggingface_hub is only required when Hub
    # commands are actually used.
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import HfHubHTTPError, HFValidationError

    try:
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            token=token,
            revision=revision,
        )
    except (HfHubHTTPError, HFValidationError) as err:
        # Keep the original error as __cause__ for debugging.
        raise OperationFailedException(str(err)) from err
320
321
322
def unzip(src_path, force):
    """Extract the archive at ``src_path`` into the current working directory.

    A member whose target path already exists is only overwritten when
    ``force`` is true. Otherwise it is skipped: silently (debug log) when its
    CRC-32 matches the existing file, with a message to the user when it
    differs.
    """
    with zipfile.ZipFile(src_path, "r") as zfile:
        for member in zfile.infolist():
            target = member.filename
            if force or not os.path.exists(target):
                logger.debug(f"Unzipping {target}")
                zfile.extract(member)
            elif _is_existing_identical(member):
                logger.debug(f"Skipping unzip of {target}; already in place")
            else:
                click.echo(f"Not overwriting {target} (use --force to override)")
337
338
339
def move_project_config(src_path, force):
    """Copy a project configuration file into the ``projects.d`` directory.

    The source file is copied, not moved, and keeps its base name. An existing
    destination is only overwritten when ``force`` is true. Otherwise an
    identical one (same CRC-32) is skipped silently and a different one is
    reported to the user.
    """
    dst_dir = "projects.d"
    # The directory may not exist yet (e.g. in a fresh working directory), in
    # which case shutil.copy would fail with FileNotFoundError.
    os.makedirs(dst_dir, exist_ok=True)
    dst_path = os.path.join(dst_dir, os.path.basename(src_path))
    if os.path.exists(dst_path) and not force:
        if _compute_crc32(dst_path) == _compute_crc32(src_path):
            logger.debug(
                f"Skipping move of {os.path.basename(src_path)}; already in place"
            )
        else:
            click.echo(f"Not overwriting {dst_path} (use --force to override)")
    else:
        shutil.copy(src_path, dst_path)
350
351
352
def _is_existing_identical(member):
    """Tell whether the file at a zip member's path has the member's stored CRC-32."""
    return member.CRC == _compute_crc32(member.filename)
355
356
357
def _compute_crc32(path):
    """Return the CRC-32 of the file at ``path``; directories yield 0.

    The file is read in 10 MiB chunks, so large files are not loaded into
    memory at once.
    """
    if os.path.isdir(path):
        return 0

    chunk_size = 10 * 1024 * 1024
    checksum = 0
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            checksum = binascii.crc32(chunk, checksum)
    return checksum
367
368
369
def get_vocab_id(config_path):
    """Return the ``vocab`` setting of the first section of a project config file.

    The file is read as UTF-8, the same encoding ``write_config`` uses for the
    configurations it produces, rather than the locale default.

    :raises ConfigurationException: if the file cannot be read or has no
        sections, or if its first section has no ``vocab`` setting
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() silently ignores files it cannot open, so an
    # unreadable file shows up as a parser without sections.
    config.read(config_path, encoding="utf-8")
    sections = config.sections()
    if not sections:
        raise ConfigurationException(
            f"No project configuration found in {config_path}"
        )
    section = sections[0]
    try:
        return config[section]["vocab"]
    except KeyError as err:
        raise ConfigurationException(
            f"Project {section} in {config_path} has no vocab setting"
        ) from err
374
375
376
def _get_completion_choices(
    param: Argument,
) -> dict[str, AnnifVocabulary] | dict[str, AnnifProject] | list:
    """Return the candidates for shell-completing ``param``.

    ``project_id`` gets the registry's projects, ``vocab_id`` its
    vocabularies, and any other parameter an empty list.
    """
    name = param.name
    if name == "project_id":
        return annif.registry.get_projects()
    if name == "vocab_id":
        return annif.registry.get_vocabs()
    return []
385
386
387
def complete_param(ctx: Context, param: Argument, incomplete: str) -> list[str]:
    """Shell-completion callback: return the choices starting with ``incomplete``.

    The application is loaded via ``ctx.obj``, and the candidates from
    ``_get_completion_choices`` are iterated and filtered inside its app
    context. For a mapping, iteration yields its keys.
    """
    app = ctx.obj.load_app()
    with app.app_context():
        candidates = _get_completion_choices(param)
        matches = [choice for choice in candidates if choice.startswith(incomplete)]
    return matches
394