Passed
Pull Request — main (#762)
by Juho
06:07 queued 03:11
created

annif.cli_util.make_list_template()   A

Complexity

Conditions 3

Size

Total Lines 14
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 9
nop 1
dl 0
loc 14
rs 9.95
c 0
b 0
f 0
1
"""Utility functions for Annif CLI commands"""
2
3
from __future__ import annotations
4
5
import binascii
6
import collections
7
import configparser
8
import importlib
9
import io
10
import itertools
11
import os
12
import pathlib
13
import shutil
14
import sys
15
import tempfile
16
import time
17
import zipfile
18
from fnmatch import fnmatch
19
from typing import TYPE_CHECKING
20
21
import click
22
import click_log
23
from flask import current_app
24
25
import annif
26
from annif.exception import ConfigurationException, OperationFailedException
27
from annif.project import Access
28
29
if TYPE_CHECKING:
30
    from datetime import datetime
31
32
    from click.core import Argument, Context, Option
33
34
    from annif.corpus.document import DocumentCorpus, DocumentList
35
    from annif.corpus.subject import SubjectIndex
36
    from annif.project import AnnifProject
37
    from annif.suggestion import SuggestionResult
38
    from annif.vocab import AnnifVocabulary
39
40
# Module logger: the Annif package logger, to which common_options() attaches
# the click-log verbosity option.
logger = annif.logger
41
42
43
def _set_project_config_file_path(
    ctx: Context, param: Option, value: str | None
) -> None:
    """Click callback for the --projects option.

    When a path is given, it is stored as PROJECTS_CONFIG_PATH in the Flask
    app configuration, overriding the default path or the one given in the
    environment. The app context is entered even when no value was given.
    """
    app = ctx.obj.load_app()
    with app.app_context():
        if not value:
            return
        current_app.config["PROJECTS_CONFIG_PATH"] = value
50
51
52
def common_options(f):
    """Decorator that attaches the options shared by all Annif CLI commands.

    * ``-p/--projects``: path to the project configuration file or directory.
      It is processed eagerly by ``_set_project_config_file_path`` and is not
      passed on to the command function.
    * The click-log verbosity option bound to the module logger.
    """
    projects_option = click.option(
        "-p",
        "--projects",
        help="Set path to project configuration file or directory",
        type=click.Path(dir_okay=True, exists=True),
        callback=_set_project_config_file_path,
        expose_value=False,
        is_eager=True,
    )
    verbosity_option = click_log.simple_verbosity_option(logger)
    return verbosity_option(projects_option(f))
64
65
66
def project_id(f):
    """Decorator that adds a positional ``project_id`` argument to a CLI
    command, with shell completion provided by ``complete_param``."""
    add_argument = click.argument("project_id", shell_complete=complete_param)
    return add_argument(f)
69
70
71
def backend_param_option(f):
    """Decorator that adds a repeatable ``-b/--backend-param`` option used to
    override backend parameters from the project configuration."""
    help_text = (
        "Override backend parameter of the config file. "
        "Syntax: `-b <backend>.<parameter>=<value>`."
    )
    add_option = click.option(
        "--backend-param",
        "-b",
        multiple=True,
        help=help_text,
    )
    return add_option(f)
80
81
82
def docs_limit_option(f):
    """Decorator that adds a ``-d/--docs-limit`` option for limiting how many
    documents a CLI command uses. The value must be a non-negative integer;
    it defaults to None."""
    add_option = click.option(
        "--docs-limit",
        "-d",
        type=click.IntRange(0, None),
        default=None,
        help="Maximum number of documents to use",
    )
    return add_option(f)
92
93
94
def get_project(project_id: str) -> AnnifProject:
    """Return the project with the given ID, or exit the CLI with status 1.

    The registry lookup uses ``min_access=Access.private``. If the registry
    raises ValueError, an error message is printed to stderr before exiting.
    """
    try:
        project = annif.registry.get_project(project_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No projects found with id '{project_id}'.", err=True)
        sys.exit(1)
    return project
102
103
104
def get_vocab(vocab_id: str) -> AnnifVocabulary:
    """Return the vocabulary with the given ID, or exit the CLI with status 1.

    The registry lookup uses ``min_access=Access.private``. If the registry
    raises ValueError, an error message is printed to stderr before exiting.
    """
    try:
        vocab = annif.registry.get_vocab(vocab_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No vocabularies found with the id '{vocab_id}'.", err=True)
        sys.exit(1)
    return vocab
113
114
115
def make_list_template(*rows) -> str:
    """Build a ``str.format`` template for printing rows as aligned columns.

    Placeholder ``i`` is left-aligned and padded to the length of the longest
    item found at position ``i`` in any of the given rows. Rows may differ in
    length. Placeholders are separated by two spaces. With no rows (or only
    empty rows) the template is the empty string.
    """
    widths: dict[int, int] = {}
    for row in rows:
        for column, cell in enumerate(row):
            widths[column] = max(widths.get(column, 0), len(cell))

    placeholders = (f"{{{column}: <{width}}}" for column, width in widths.items())
    return "  ".join(placeholders)
131
132
133
def format_datetime(dt: datetime | None) -> str:
    """Render ``dt`` converted to the local time zone as
    ``YYYY-MM-DD HH:MM:SS``; ``None`` is rendered as ``"-"``."""
    if dt is not None:
        local_dt = dt.astimezone()
        return local_dt.strftime("%Y-%m-%d %H:%M:%S")
    return "-"
138
139
140
def open_documents(
    paths: tuple[str, ...],
    subject_index: SubjectIndex,
    vocab_lang: str,
    docs_limit: int | None,
) -> DocumentCorpus:
    """Open a document corpus from a sequence of paths.

    Each path is either a TSV file or a directory of TXT files. Directories
    are opened with ``require_subjects=True``. For directories with subjects
    in TSV files, ``vocab_lang`` is used to convert subject labels into URIs.

    With no paths, the null device is read (a warning is logged). A single
    path is opened directly. Several paths are merged into a CombinedCorpus.
    If ``docs_limit`` is not None, the corpus is wrapped in a
    LimitingDocumentCorpus.
    """

    def _open_path(path):
        """Open one path as a DocumentDirectory or a DocumentFile."""
        if os.path.isdir(path):
            return annif.corpus.DocumentDirectory(
                path, subject_index, vocab_lang, require_subjects=True
            )
        return annif.corpus.DocumentFile(path, subject_index)

    if not paths:
        logger.warning("Reading empty file")
        corpus = _open_path(os.path.devnull)
    elif len(paths) == 1:
        corpus = _open_path(paths[0])
    else:
        corpus = annif.corpus.CombinedCorpus([_open_path(path) for path in paths])

    if docs_limit is not None:
        corpus = annif.corpus.LimitingDocumentCorpus(corpus, docs_limit)
    return corpus
171
172
173
def open_text_documents(paths: tuple[str, ...], docs_limit: int | None) -> DocumentList:
    """
    Read plain-text documents from file paths into a DocumentList of subjectless
    Documents. The path "-" means standard input. Only the first ``docs_limit``
    paths are used; None means all of them. Files are read lazily as UTF-8
    (a BOM is skipped), and undecodable bytes are replaced.
    """

    def _read_text(path):
        """Return the full text behind one path ("-" is stdin)."""
        if path == "-":
            return sys.stdin.read()
        with open(path, errors="replace", encoding="utf-8-sig") as docfile:
            return docfile.read()

    selected_paths = paths[:docs_limit]
    documents = (
        annif.corpus.Document(text=_read_text(path), subject_set=None)
        for path in selected_paths
    )
    return annif.corpus.DocumentList(documents)
191
192
193
def show_hits(
    hits: SuggestionResult,
    project: AnnifProject,
    lang: str,
    file: io.TextIOWrapper | None = None,
) -> None:
    """
    Print subject suggestions, one line per hit, to the console or to ``file``.
    Each line is tab-separated and contains:
    - the subject URI in angle brackets;
    - its label in ``lang`` and its notation, either omitted when empty;
    - the score with four decimals.
    """
    for hit in hits:
        subject = project.subjects[hit.subject_id]
        uri = subject.uri
        label_fields = "\t".join(
            field for field in (subject.labels[lang], subject.notation) if field
        )
        click.echo(f"<{uri}>\t{label_fields}\t{hit.score:.04f}", file=file)
213
214
215
def parse_backend_params(
    backend_param: tuple[str, ...] | tuple[()], project: AnnifProject
) -> collections.defaultdict[str, dict[str, str]]:
    """Parse the values given with the --backend-param option into a nested
    dict structure: backend ID -> parameter name -> value.

    Each value must have the form ``<backend>.<parameter>=<value>``. The split
    happens at the first "." and then at the first "=", so the value part may
    itself contain "=".

    Raises:
        ConfigurationException: if a value does not follow that syntax, or if
            its backend is not the backend of the given project.
    """
    backend_params = collections.defaultdict(dict)
    for beparam in backend_param:
        backend, dot, param = beparam.partition(".")
        key, equals, val = param.partition("=")
        if not (dot and equals):
            # Report malformed syntax as a configuration error rather than
            # letting a bare ValueError from tuple unpacking escape to the user.
            raise ConfigurationException(
                f'Invalid syntax in CLI option "-b {beparam}"; expected '
                "`-b <backend>.<parameter>=<value>`."
            )
        _validate_backend_params(backend, beparam, project)
        backend_params[backend][key] = val
    return backend_params


def _validate_backend_params(backend: str, beparam: str, project: AnnifProject) -> None:
    """Check that a --backend-param value targets the project's own backend.

    Raises:
        ConfigurationException: if ``backend`` differs from the "backend"
            entry of the project configuration.
    """
    project_backend = project.config["backend"]
    if backend != project_backend:
        raise ConfigurationException(
            f'The backend {backend} in CLI option "-b {beparam}" not matching the '
            f"project backend {project_backend}."
        )
235
236
237
def generate_filter_params(filter_batch_max_limit: int) -> list[tuple[int, float]]:
    """Return every (limit, threshold) pair to try when optimizing suggestion
    filtering.

    Limits run from 1 to ``filter_batch_max_limit`` inclusive. Thresholds run
    from 0.0 to 0.95 in steps of 0.05. Pairs are ordered by limit first, then
    by threshold.
    """
    limits = range(1, filter_batch_max_limit + 1)
    # i / 20 yields the float closest to each decimal threshold (0.15, 0.95);
    # i * 0.05 accumulates error (0.15000000000000002, 0.9500000000000001).
    thresholds = [i / 20 for i in range(20)]
    return list(itertools.product(limits, thresholds))
241
242
243
def get_matching_projects(pattern: str) -> list[AnnifProject]:
    """
    Return the registered projects (looked up with min_access=Access.private)
    whose project ID matches the shell-style ``pattern`` (``fnmatch``).
    """
    projects = annif.registry.get_projects(min_access=Access.private)
    matching = []
    for project in projects.values():
        if fnmatch(project.project_id, pattern):
            matching.append(project)
    return matching
252
253
254
def upload_datadir(
    data_dir: str, repo_id: str, token: str, commit_message: str
) -> None:
    """
    Upload a data directory to a Hugging Face Hub repository as a zip archive.

    The archive's path in the repository is ``data_dir`` without its first
    path component, plus ".zip". On POSIX, for example,
    ``data/projects/myproj`` becomes ``projects/myproj.zip``.
    """
    relative_dir = data_dir.split(os.path.sep, 1)[1]
    zip_repo_path = f"{relative_dir}.zip"
    with _archive_dir(data_dir) as archive:
        _upload_to_hf_hub(archive, zip_repo_path, repo_id, token, commit_message)
263
264
265
def upload_config(
    project: AnnifProject, repo_id: str, token: str, commit_message: str
) -> None:
    """
    Upload a project's configuration to a Hugging Face Hub repository as
    ``<project_id>.cfg``.
    """
    repo_path = project.project_id + ".cfg"
    with _get_project_config(project) as config_file:
        _upload_to_hf_hub(config_file, repo_path, repo_id, token, commit_message)
274
275
276
def _is_train_file(fname: str) -> bool:
    """Tell whether a file name contains "-train" or "tmp-"; such files are
    left out of data-directory archives."""
    return any(marker in fname for marker in ("-train", "tmp-"))
282
283
284
def _archive_dir(data_dir: str) -> io.BufferedRandom:
    """Archive a data directory into an anonymous temporary zip file.

    Paths for which ``_is_train_file`` is true are skipped. Each member is
    stored under its path with the first component removed (``os.path.join``
    of ``parts[1:]``). For example, ``data/projects/x/f`` is stored as
    ``projects/x/f`` on POSIX. The archive comment records the installed Annif
    version.

    Returns:
        The temporary file, rewound to the start. The caller must close it.
        It is closed here if archiving fails.
    """
    # ``import importlib`` at module level does not load the ``metadata``
    # submodule, so import it explicitly before using it.
    import importlib.metadata

    fp = tempfile.TemporaryFile()
    try:
        path = pathlib.Path(data_dir)
        fpaths = [
            fpath for fpath in path.glob("**/*") if not _is_train_file(fpath.name)
        ]
        with zipfile.ZipFile(fp, mode="w") as zfile:
            zfile.comment = bytes(
                f"Archived by Annif {importlib.metadata.version('annif')}",
                encoding="utf-8",
            )
            for fpath in fpaths:
                logger.debug("Adding %s", fpath)
                arcname = os.path.join(*fpath.parts[1:])
                zfile.write(fpath, arcname=arcname)
        fp.seek(0)
    except BaseException:
        # Don't leak the temporary file when archiving fails.
        fp.close()
        raise
    return fp
299
300
301
def _get_project_config(project: AnnifProject) -> io.BytesIO:
    """Serialize a project's configuration as INI text in memory.

    The configuration is written as a single section named after the project
    ID. It is returned UTF-8 encoded in a binary buffer positioned at the
    start, ready for upload.
    """
    config = configparser.ConfigParser()
    config[project.project_id] = project.config
    # ConfigParser.write() needs a text stream. An in-memory StringIO avoids
    # an unclosed temporary file and any dependence on the locale's encoding.
    text_buffer = io.StringIO()
    config.write(text_buffer)
    # The upload needs a binary file object.
    return io.BytesIO(text_buffer.getvalue().encode("utf8"))
309
310
311
def _upload_to_hf_hub(
    fileobj: io.BytesIO | io.BufferedRandom,
    path_in_repo: str,
    repo_id: str,
    token: str,
    commit_message: str,
) -> None:
    """Upload a file object to ``path_in_repo`` in a Hugging Face Hub repository
    using ``HfApi.upload_file`` with the given commit message.

    Raises:
        OperationFailedException: on HfHubHTTPError or HFValidationError from
            huggingface_hub. The original error is chained as the cause.
    """
    # Imported here rather than at module level, so importing this module
    # does not require huggingface_hub.
    from huggingface_hub import HfApi
    from huggingface_hub.utils import HfHubHTTPError, HFValidationError

    api = HfApi()
    try:
        api.upload_file(
            path_or_fileobj=fileobj,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            token=token,
            commit_message=commit_message,
        )
    except (HfHubHTTPError, HFValidationError) as err:
        raise OperationFailedException(str(err)) from err
332
333
334
def get_matching_project_ids_from_hf_hub(
    project_ids_pattern: str, repo_id: str, token: str, revision: str
) -> list[str]:
    """Get the IDs of projects in a Hugging Face Hub repository that match the
    given shell-style pattern.

    Project archives are expected at ``projects/<project_id>.zip``. The ID is
    recovered by slicing off exactly that prefix and suffix. Splitting on the
    strings would truncate IDs that themselves contain ".zip" or "projects/".
    """
    prefix, suffix = "projects/", ".zip"
    all_repo_file_paths = _list_files_in_hf_hub(repo_id, token, revision)
    return [
        # fnmatch on the literal prefix/suffix guarantees both are present.
        path[len(prefix) : len(path) - len(suffix)]
        for path in all_repo_file_paths
        if fnmatch(path, f"{prefix}{project_ids_pattern}{suffix}")
    ]
345
346
347
def _list_files_in_hf_hub(repo_id: str, token: str, revision: str) -> list[str]:
    """Return the paths of all files in a Hugging Face Hub repository at the
    given revision.

    Raises:
        OperationFailedException: on HfHubHTTPError or HFValidationError from
            huggingface_hub. The original error is chained as the cause.
    """
    # Imported here so that importing this module does not require huggingface_hub.
    from huggingface_hub import list_repo_files
    from huggingface_hub.utils import HfHubHTTPError, HFValidationError

    try:
        return list(list_repo_files(repo_id=repo_id, token=token, revision=revision))
    except (HfHubHTTPError, HFValidationError) as err:
        raise OperationFailedException(str(err)) from err
360
361
362
def download_from_hf_hub(
    filename: str, repo_id: str, token: str, revision: str
) -> str:
    """Download ``filename`` at ``revision`` from a Hugging Face Hub repository.

    Returns:
        The result of ``huggingface_hub.hf_hub_download``: the local
        filesystem path of the downloaded file.

    Raises:
        OperationFailedException: on HfHubHTTPError or HFValidationError from
            huggingface_hub. The original error is chained as the cause.
    """
    # Imported here so that importing this module does not require huggingface_hub.
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import HfHubHTTPError, HFValidationError

    try:
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            token=token,
            revision=revision,
        )
    except (HfHubHTTPError, HFValidationError) as err:
        raise OperationFailedException(str(err)) from err
377
378
379
def unzip_archive(src_path: str, force: bool) -> None:
    """Extract every member of the zip archive of projects and vocabularies
    at ``src_path`` into the configured DATADIR (by default data/ under the
    current directory). Each member is handled by ``_unzip_member``, which
    respects ``force`` when the destination already exists."""
    datadir = current_app.config["DATADIR"]
    with zipfile.ZipFile(src_path, "r") as zfile:
        comment = zfile.comment.decode("utf-8")
        logger.debug(f'Extracting archive {src_path}; archive comment: "{comment}"')
        members = zfile.infolist()
        for member in members:
            _unzip_member(zfile, member, datadir, force)
390
391
392
def _unzip_member(
    zfile: zipfile.ZipFile, member: zipfile.ZipInfo, datadir: str, force: bool
) -> None:
    """Extract a single archive member into ``datadir``.

    If the destination already exists and ``force`` is false, the member is
    not extracted. When the existing file's CRC-32 matches the member, this is
    only logged at debug level; otherwise the user is told to use --force.
    After extraction, the member's archived timestamp is applied to the
    extracted path.
    """
    dest_path = os.path.join(datadir, member.filename)
    if os.path.exists(dest_path) and not force:
        if _are_identical_member_and_file(member, dest_path):
            logger.debug(f"Skipping unzip to {dest_path}; already in place")
        else:
            click.echo(f"Not overwriting {dest_path} (use --force to override)")
        return

    logger.debug(f"Unzipping to {dest_path}")
    # ZipFile.extract() sanitizes member names (dropping absolute prefixes and
    # ".." components) and returns the path it actually wrote. Restore the
    # timestamp on that path rather than on the naively joined dest_path,
    # which for such names can point elsewhere, even outside datadir.
    extracted_path = zfile.extract(member, path=datadir)
    _restore_timestamps(member, extracted_path)
405
406
407
def _are_identical_member_and_file(member: zipfile.ZipInfo, dest_path: str) -> bool:
    """Compare an archive member with a file on disk by CRC-32 checksum."""
    return _compute_crc32(dest_path) == member.CRC
410
411
412
def _restore_timestamps(member: zipfile.ZipInfo, dest_path: str) -> None:
    """Set the access and modification times of ``dest_path`` to the member's
    archived ``date_time``, interpreted as local time by ``time.mktime``."""
    # Pad date_time with weekday and yearday (ignored by mktime) and
    # isdst=-1, so mktime determines daylight saving itself.
    struct_fields = member.date_time + (0, 0, -1)
    timestamp = time.mktime(struct_fields)
    os.utime(dest_path, times=(timestamp, timestamp))
415
416
417
def copy_project_config(src_path: str, force: bool) -> None:
    """Copy a project configuration file into the projects.d/ directory
    (relative to the current working directory), creating that directory if
    needed.

    An existing destination is overwritten only when ``force`` is true.
    Otherwise it is left in place: this is only logged at debug level when its
    CRC-32 matches the source, and the user is told to use --force when it
    differs.
    """
    project_configs_dest_dir = "projects.d"
    # makedirs(exist_ok=True) avoids the isdir()+mkdir() check-then-create
    # race. It still raises if projects.d exists but is not a directory.
    os.makedirs(project_configs_dest_dir, exist_ok=True)

    dest_path = os.path.join(project_configs_dest_dir, os.path.basename(src_path))
    if os.path.exists(dest_path) and not force:
        if _are_identical_files(src_path, dest_path):
            logger.debug(f"Skipping copy to {dest_path}; already in place")
        else:
            click.echo(f"Not overwriting {dest_path} (use --force to override)")
    else:
        logger.debug(f"Copying to {dest_path}")
        shutil.copy(src_path, dest_path)
432
433
434
def _are_identical_files(src_path: str, dest_path: str) -> bool:
    """Compare two files by CRC-32 checksum."""
    src_checksum, dest_checksum = (
        _compute_crc32(candidate) for candidate in (src_path, dest_path)
    )
    return src_checksum == dest_checksum


def _compute_crc32(path: str) -> int:
    """Return the CRC-32 checksum of a file's contents, read in 10 MiB chunks.

    A directory yields 0. Presumably this is so that directory members of a
    zip archive, whose recorded CRC is 0, compare as identical.
    """
    if os.path.isdir(path):
        return 0

    chunk_size = 10 * 1024 * 1024
    checksum = 0
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            checksum = binascii.crc32(chunk, checksum)
    return checksum
450
451
452
def get_vocab_id_from_config(config_path: str) -> str:
    """Get the vocabulary ID from a project configuration file.

    Returns the ``vocab`` setting of the first section in the file.

    Raises:
        ConfigurationException: if the file cannot be read, contains no
            sections, or its first section has no ``vocab`` setting.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open. It returns the
    # list of files it actually parsed, so an empty result means failure.
    if not config.read(config_path):
        raise ConfigurationException(
            f"Could not read configuration file {config_path}"
        )
    sections = config.sections()
    if not sections or "vocab" not in config[sections[0]]:
        raise ConfigurationException(
            f"No vocabulary defined in configuration file {config_path}"
        )
    return config[sections[0]]["vocab"]
458
459
460
def _get_completion_choices(
    param: Argument,
) -> dict[str, AnnifVocabulary] | dict[str, AnnifProject] | list:
    """Return the shell-completion candidates for a CLI argument:
    - the registered projects for ``project_id``;
    - the registered vocabularies for ``vocab_id``;
    - an empty list for any other argument."""
    name = param.name
    if name == "project_id":
        return annif.registry.get_projects()
    if name == "vocab_id":
        return annif.registry.get_vocabs()
    return []
469
470
471
def complete_param(ctx: Context, param: Argument, incomplete: str) -> list[str]:
    """Click shell-completion callback: inside the app context, return the
    completion candidates for ``param`` that start with ``incomplete``."""
    with ctx.obj.load_app().app_context():
        candidates = _get_completion_choices(param)
        return list(filter(lambda choice: choice.startswith(incomplete), candidates))
478