Passed
Pull Request — main (#762)
by Juho
06:47 queued 03:53
created

annif.cli_util   A

Complexity

Total Complexity 37

Size/Duplication

Total Lines 251
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 156
dl 0
loc 251
rs 9.44
c 0
b 0
f 0
wmc 37
1
"""Utility functions for Annif CLI commands"""
2
3
from __future__ import annotations
4
5
import binascii
6
import collections
7
import configparser
8
import importlib
9
import io
10
import itertools
11
import os
12
import pathlib
13
import shutil
14
import sys
15
import tempfile
16
import time
17
import zipfile
18
from fnmatch import fnmatch
19
from typing import TYPE_CHECKING
20
21
import click
22
import click_log
23
from flask import current_app
24
25
import annif
26
from annif.exception import ConfigurationException, OperationFailedException
27
from annif.project import Access
28
29
if TYPE_CHECKING:
30
    from datetime import datetime
31
    from typing import Any
32
33
    from click.core import Argument, Context, Option
34
35
    from annif.corpus.document import DocumentCorpus, DocumentList
36
    from annif.corpus.subject import SubjectIndex
37
    from annif.project import AnnifProject
38
    from annif.suggestion import SuggestionResult
39
    from annif.vocab import AnnifVocabulary
40
41
logger = annif.logger
42
43
44
def _set_project_config_file_path(
    ctx: Context, param: Option, value: str | None
) -> None:
    """Eager click callback: override the default projects config path (or the
    one from the environment) with the value of the -p/--projects option."""
    app = ctx.obj.load_app()
    with app.app_context():
        # only override when the option was actually given
        if value:
            current_app.config["PROJECTS_CONFIG_PATH"] = value
51
52
53
def common_options(f):
    """Decorator applying the options shared by every CLI command:
    -p/--projects for the config location and the verbosity option."""
    projects_option = click.option(
        "-p",
        "--projects",
        help="Set path to project configuration file or directory",
        type=click.Path(dir_okay=True, exists=True),
        callback=_set_project_config_file_path,
        expose_value=False,
        is_eager=True,
    )
    return click_log.simple_verbosity_option(logger)(projects_option(f))
65
66
67
def project_id(f):
    """Decorator adding a shell-completable project ID argument to a command."""
    decorator = click.argument("project_id", shell_complete=complete_param)
    return decorator(f)
70
71
72
def backend_param_option(f):
    """Decorator adding the repeatable -b/--backend-param option used to
    override backend parameters from the config file."""
    option = click.option(
        "--backend-param",
        "-b",
        multiple=True,
        help=(
            "Override backend parameter of the config file. "
            "Syntax: `-b <backend>.<parameter>=<value>`."
        ),
    )
    return option(f)
81
82
83
def docs_limit_option(f):
    """Decorator adding the -d/--docs-limit option that caps the number of
    documents a command will process."""
    option = click.option(
        "--docs-limit",
        "-d",
        default=None,
        type=click.IntRange(0, None),
        help="Maximum number of documents to use",
    )
    return option(f)
93
94
95
def get_project(project_id: str) -> AnnifProject:
    """Look up a project by ID; print an error and exit(1) if it doesn't
    exist."""
    try:
        project = annif.registry.get_project(project_id, min_access=Access.private)
    except ValueError:
        click.echo("No projects found with id '{0}'.".format(project_id), err=True)
        sys.exit(1)
    return project
103
104
105
def get_vocab(vocab_id: str) -> AnnifVocabulary:
    """Look up a vocabulary by ID; print an error and exit(1) if it doesn't
    exist."""
    try:
        vocab = annif.registry.get_vocab(vocab_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No vocabularies found with the id '{vocab_id}'.", err=True)
        sys.exit(1)
    return vocab
114
115
116
def make_list_template(*rows) -> str:
    """Build a str.format template for tabular output.

    Each column's width is the length of the longest item seen in that column
    across all given rows; columns are separated by two spaces."""
    widths: dict[int, int] = {}
    for row in rows:
        for col, item in enumerate(row):
            widths[col] = max(widths.get(col, 0), len(item))

    # one left-aligned, fixed-width placeholder per column
    return "  ".join(f"{{{col}: <{width}}}" for col, width in widths.items())
132
133
134
def format_datetime(dt: datetime | None) -> str:
    """Render a datetime in local time as "YYYY-MM-DD HH:MM:SS"; "-" for None."""
    if dt is None:
        return "-"
    local_dt = dt.astimezone()
    return local_dt.strftime("%Y-%m-%d %H:%M:%S")
139
140
141
def open_documents(
    paths: tuple[str, ...],
    subject_index: SubjectIndex,
    vocab_lang: str,
    docs_limit: int | None,
) -> DocumentCorpus:
    """Open a document corpus from pathnames (TSV files or TXT directories).

    For directories, subject labels in TSV files are mapped to URIs using the
    given vocabulary language. Multiple paths are combined into one corpus;
    when docs_limit is set, the corpus is wrapped in a LimitingDocumentCorpus.
    """

    def _open_single(path):
        """Open one path as a DocumentCorpus: directory or TSV file."""
        if os.path.isdir(path):
            return annif.corpus.DocumentDirectory(
                path, subject_index, vocab_lang, require_subjects=True
            )
        return annif.corpus.DocumentFile(path, subject_index)

    if not paths:
        # no input given: fall back to an empty corpus backed by the null device
        logger.warning("Reading empty file")
        corpus = _open_single(os.path.devnull)
    elif len(paths) == 1:
        corpus = _open_single(paths[0])
    else:
        corpus = annif.corpus.CombinedCorpus([_open_single(p) for p in paths])

    if docs_limit is not None:
        corpus = annif.corpus.LimitingDocumentCorpus(corpus, docs_limit)
    return corpus
172
173
174
def open_text_documents(paths: tuple[str, ...], docs_limit: int | None) -> DocumentList:
    """Read text documents (without subjects) from the given file paths.

    A path of "-" reads the document text from standard input. At most
    docs_limit documents are read when the limit is set.
    """

    def _read_one(path):
        """Produce one Document from a path (or stdin for "-")."""
        if path == "-":
            return annif.corpus.Document(text=sys.stdin.read(), subject_set=None)
        with open(path, errors="replace", encoding="utf-8-sig") as docfile:
            return annif.corpus.Document(text=docfile.read(), subject_set=None)

    # slicing with None yields all paths; documents are read lazily
    return annif.corpus.DocumentList(_read_one(p) for p in paths[:docs_limit])
192
193
194
def show_hits(
    hits: SuggestionResult,
    project: AnnifProject,
    lang: str,
    file: io.TextIOWrapper | None = None,
) -> None:
    """Print subject suggestions, one tab-separated row per hit.

    Each row holds the subject URI, label (in the given language), the
    notation when present, and the score. Output goes to the console unless a
    file object is given.
    """
    for hit in hits:
        subject = project.subjects[hit.subject_id]
        # notation may be None; filter drops it from the output in that case
        middle = "\t".join(filter(None, (subject.labels[lang], subject.notation)))
        click.echo(f"<{subject.uri}>\t{middle}\t{hit.score:.04f}", file=file)
214
215
216
def parse_backend_params(
    backend_param: tuple[str, ...] | tuple[()], project: AnnifProject
) -> collections.defaultdict[str, dict[str, str]]:
    """Parse --backend-param values ("<backend>.<key>=<value>") into a nested
    dict keyed first by backend name, then by parameter name."""
    parsed: collections.defaultdict[str, dict[str, str]] = collections.defaultdict(dict)
    for beparam in backend_param:
        # split(..., 1) raises ValueError on malformed input, as before
        backend, keyval = beparam.split(".", 1)
        key, val = keyval.split("=", 1)
        _validate_backend_params(backend, beparam, project)
        parsed[backend][key] = val
    return parsed
228
229
230
def _validate_backend_params(backend: str, beparam: str, project: AnnifProject) -> None:
231
    if backend != project.config["backend"]:
232
        raise ConfigurationException(
233
            'The backend {} in CLI option "-b {}" not matching the project'
234
            " backend {}.".format(backend, beparam, project.config["backend"])
235
        )
236
237
238
def generate_filter_params(filter_batch_max_limit: int) -> list[tuple[int, float]]:
    """Generate all (limit, threshold) combinations for filter evaluation.

    Limits run from 1 to filter_batch_max_limit inclusive; thresholds are the
    20 values 0.00, 0.05, ..., 0.95. Returns the Cartesian product as a list
    of (limit, threshold) tuples.
    """
    limits = range(1, filter_batch_max_limit + 1)
    thresholds = [i * 0.05 for i in range(20)]
    return list(itertools.product(limits, thresholds))
242
243
244
def get_matching_projects(pattern: str) -> list[AnnifProject]:
    """Return all registered projects whose ID matches the glob pattern."""
    all_projects = annif.registry.get_projects(min_access=Access.private).values()
    return [p for p in all_projects if fnmatch(p.project_id, pattern)]
253
254
255
def prepare_commits(projects: list[AnnifProject], repo_id: str) -> tuple[list, list]:
    """Prepare and pre-upload data and config commit operations for projects
    to a Hugging Face Hub repository.

    Returns the open file objects (so callers can close them after the commit)
    and the corresponding CommitOperationAdd objects.
    """
    from huggingface_hub import preupload_lfs_files

    fobjs, operations = [], []
    # sets deduplicate data dirs shared between projects and vocabularies,
    # so each directory is archived and uploaded only once
    all_dirs = {p.datadir for p in projects} | {p.vocab.datadir for p in projects}

    for data_dir in all_dirs:
        fobj, operation = _prepare_datadir_commit(data_dir)
        # large zip archives are pre-uploaded as LFS files
        preupload_lfs_files(repo_id, additions=[operation])
        fobjs.append(fobj)
        operations.append(operation)

    for project in projects:
        fobj, operation = _prepare_config_commit(project)
        fobjs.append(fobj)
        operations.append(operation)

    return fobjs, operations
277
278
279
def _prepare_datadir_commit(data_dir: str) -> tuple[io.BufferedRandom, Any]:
    """Zip a data directory and wrap the archive in a CommitOperationAdd."""
    from huggingface_hub import CommitOperationAdd

    # drop the leading path component from the repo path
    # NOTE(review): assumes data_dir contains at least one path separator
    zip_repo_path = data_dir.split(os.path.sep, 1)[1] + ".zip"
    fobj = _archive_dir(data_dir)
    operation = CommitOperationAdd(path_in_repo=zip_repo_path, path_or_fileobj=fobj)
    return fobj, operation
286
287
288
def _prepare_config_commit(project: AnnifProject) -> tuple[io.BytesIO, Any]:
    """Serialize a project's configuration and wrap it in a CommitOperationAdd."""
    from huggingface_hub import CommitOperationAdd

    fobj = _get_project_config(project)
    operation = CommitOperationAdd(
        path_in_repo=project.project_id + ".cfg", path_or_fileobj=fobj
    )
    return fobj, operation
295
296
297
def _is_train_file(fname: str) -> bool:
298
    train_file_patterns = ("-train", "tmp-")
299
    for pat in train_file_patterns:
300
        if pat in fname:
301
            return True
302
    return False
303
304
305
def _archive_dir(data_dir: str) -> io.BufferedRandom:
    """Pack a data directory into a temporary zip file, skipping training and
    temporary files.

    Returns the open temporary file seeked back to the start, ready to be
    uploaded or read."""
    # "import importlib" at the top of the file does not guarantee the
    # importlib.metadata submodule is loaded; import it explicitly here
    import importlib.metadata

    fp = tempfile.TemporaryFile()
    dirpath = pathlib.Path(data_dir)
    fpaths = [p for p in dirpath.glob("**/*") if not _is_train_file(p.name)]
    with zipfile.ZipFile(fp, mode="w") as zfile:
        zfile.comment = bytes(
            f"Archived by Annif {importlib.metadata.version('annif')}",
            encoding="utf-8",
        )
        for fpath in fpaths:
            logger.debug(f"Adding {fpath}")
            # drop the leading path component ("data/") from the archive name
            arcname = os.path.join(*fpath.parts[1:])
            zfile.write(fpath, arcname=arcname)
    fp.seek(0)
    return fp
320
321
322
def _get_project_config(project: AnnifProject) -> io.BytesIO:
323
    fp = tempfile.TemporaryFile(mode="w+t")
324
    config = configparser.ConfigParser()
325
    config[project.project_id] = project.config
326
    config.write(fp)  # This needs tempfile in text mode
327
    fp.seek(0)
328
    # But for upload fobj needs to be in binary mode
329
    return io.BytesIO(fp.read().encode("utf8"))
330
331
332
def get_matching_project_ids_from_hf_hub(
    project_ids_pattern: str, repo_id: str, token, revision: str
) -> list[str]:
    """Get project IDs of the projects in a Hugging Face Model Hub repository
    that match the given glob pattern.

    Project IDs are derived from .cfg file names in the repository.
    """
    all_repo_file_paths = _list_files_in_hf_hub(repo_id, token, revision)
    # strip only the final ".cfg" suffix; rsplit(".cfg")[0] would truncate
    # project IDs that themselves contain ".cfg"
    return [
        path[: -len(".cfg")]
        for path in all_repo_file_paths
        if fnmatch(path, f"{project_ids_pattern}.cfg")
    ]
343
344
345
def _list_files_in_hf_hub(repo_id: str, token: str, revision: str) -> list[str]:
    """List all file paths in a Hugging Face Hub repository.

    Raises OperationFailedException when the repository cannot be accessed.
    """
    from huggingface_hub import list_repo_files
    from huggingface_hub.utils import HfHubHTTPError, HFValidationError

    try:
        # list_repo_files already yields the paths; no per-item work needed
        return list(list_repo_files(repo_id=repo_id, token=token, revision=revision))
    except (HfHubHTTPError, HFValidationError) as err:
        # chain the original error for easier debugging
        raise OperationFailedException(str(err)) from err
358
359
360
def download_from_hf_hub(
    filename: str, repo_id: str, token: str, revision: str
) -> str:
    """Download a single file from a Hugging Face Hub repository.

    Returns the local filesystem path of the downloaded file. Raises
    OperationFailedException when the download fails.
    """
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import HfHubHTTPError, HFValidationError

    try:
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            token=token,
            revision=revision,
        )
    except (HfHubHTTPError, HFValidationError) as err:
        # chain the original error for easier debugging
        raise OperationFailedException(str(err)) from err
375
376
377
def unzip_archive(src_path: str, force: bool) -> None:
    """Unzip an archive of projects and vocabularies into the configured data
    directory (by default data/ under the current directory)."""
    datadir = current_app.config["DATADIR"]
    with zipfile.ZipFile(src_path, "r") as zfile:
        comment = str(zfile.comment, encoding="utf-8")
        logger.debug(f'Extracting archive {src_path}; archive comment: "{comment}"')
        for member in zfile.infolist():
            _unzip_member(zfile, member, datadir, force)
388
389
390
def _unzip_member(
    zfile: zipfile.ZipFile, member: zipfile.ZipInfo, datadir: str, force: bool
) -> None:
    """Extract one zip member under datadir, unless the destination already
    exists and --force was not given; restores the member's timestamp."""
    dest_path = os.path.join(datadir, member.filename)
    if os.path.exists(dest_path) and not force:
        _handle_existing_file(member, dest_path)
    else:
        logger.debug(f"Unzipping to {dest_path}")
        zfile.extract(member, path=datadir)
        _restore_timestamps(member, dest_path)
400
401
402
def _handle_existing_file(member: zipfile.ZipInfo, dest_path: str) -> None:
    """Report the outcome for a destination file that already exists."""
    if not _are_identical_member_and_file(member, dest_path):
        click.echo(f"Not overwriting {dest_path} (use --force to override)")
    else:
        logger.debug(f"Skipping unzip to {dest_path}; already in place")
407
408
409
def _are_identical_member_and_file(member: zipfile.ZipInfo, dest_path: str) -> bool:
    """True when the file at dest_path has the same CRC-32 as the zip member."""
    return _compute_crc32(dest_path) == member.CRC
412
413
414
def _restore_timestamps(member: zipfile.ZipInfo, dest_path: str) -> None:
415
    date_time = time.mktime(member.date_time + (0, 0, -1))
416
    os.utime(dest_path, (date_time, date_time))
417
418
419
def copy_project_config(src_path: str, force: bool) -> None:
    """Copy a project configuration file into the projects.d/ directory,
    skipping or refusing existing files unless --force was given."""
    dest_dir = "projects.d"
    os.makedirs(dest_dir, exist_ok=True)
    dest_path = os.path.join(dest_dir, os.path.basename(src_path))

    if not os.path.exists(dest_path) or force:
        logger.debug(f"Copying to {dest_path}")
        shutil.copy(src_path, dest_path)
    elif _are_identical_files(src_path, dest_path):
        logger.debug(f"Skipping copy to {dest_path}; already in place")
    else:
        click.echo(f"Not overwriting {dest_path} (use --force to override)")
433
434
435
def _are_identical_files(src_path: str, dest_path: str) -> bool:
    """True when both files have identical CRC-32 checksums."""
    return _compute_crc32(src_path) == _compute_crc32(dest_path)
439
440
441
def _compute_crc32(path: str) -> int:
442
    if os.path.isdir(path):
443
        return 0
444
445
    size = 1024 * 1024 * 10  # 10 MiB chunks
446
    with open(path, "rb") as fp:
447
        crcval = 0
448
        while chunk := fp.read(size):
449
            crcval = binascii.crc32(chunk, crcval)
450
    return crcval
451
452
453
def get_vocab_id_from_config(config_path: str) -> str:
    """Read a project configuration file and return the vocabulary ID from its
    first section."""
    parser = configparser.ConfigParser()
    parser.read(config_path)
    first_section = parser.sections()[0]
    return parser[first_section]["vocab"]
459
460
461
def _get_completion_choices(
462
    param: Argument,
463
) -> dict[str, AnnifVocabulary] | dict[str, AnnifProject] | list:
464
    if param.name in ("project_id", "project_ids_pattern"):
465
        return annif.registry.get_projects()
466
    elif param.name == "vocab_id":
467
        return annif.registry.get_vocabs()
468
    else:
469
        return []
470
471
472
def complete_param(ctx: Context, param: Argument, incomplete: str) -> list[str]:
    """Shell completion callback: return the choices for the parameter that
    start with the (possibly empty) partial word typed so far."""
    with ctx.obj.load_app().app_context():
        choices = _get_completion_choices(param)
        return [choice for choice in choices if choice.startswith(incomplete)]
479