Passed — Push: issue684-cli-command-completio... ( 9f99ab...c704ce ), created 02:46 by Juho

annif.cli.run_learn() — rated A

Complexity:   Conditions 1
Size:         Total Lines 20, Code Lines 12
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 0

Metric  Value
cc      1
eloc    12
nop     4
dl      0
loc     20
rs      9.8
c       0
b       0
f       0
"""Definitions for command-line (Click) commands for invoking Annif
operations and printing the results to console."""


import collections
import json
import os.path
import re
import sys

import click
import click_log
from flask.cli import FlaskGroup

import annif
import annif.corpus
import annif.parallel
import annif.project
import annif.registry
from annif import cli_util
from annif.exception import NotInitializedException, NotSupportedException
from annif.project import Access
from annif.suggestion import ListSuggestionResult, SuggestionFilter
from annif.util import metric_code

logger = annif.logger
click_log.basic_config(logger)

cli = FlaskGroup(create_app=annif.create_app, add_version_option=False)
cli = click.version_option(message="%(version)s")(cli)

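The cli group defined above can be exercised in-process with click's testing utilities, which is a convenient way to try out the commands reviewed below. A minimal sketch, assuming it is run inside an Annif installation where a projects configuration is available:

# Example sketch, not part of annif/cli.py
from click.testing import CliRunner

from annif.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["list-projects"])  # any command defined below
print(result.exit_code)  # 0 on success
print(result.output)     # the table printed by run_list_projects()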
@cli.command("list-projects")
@cli_util.common_options
@click_log.simple_verbosity_option(logger, default="ERROR")
def run_list_projects():
    """
    List available projects.
    \f
    Show a list of currently defined projects. Projects are defined in a
    configuration file, normally called ``projects.cfg``. See `Project
    configuration
    <https://github.com/NatLibFi/Annif/wiki/Project-configuration>`_
    for details.
    """

    template = "{0: <25}{1: <45}{2: <10}{3: <7}"
    header = template.format("Project ID", "Project Name", "Language", "Trained")
    click.echo(header)
    click.echo("-" * len(header))
    for proj in annif.registry.get_projects(min_access=Access.private).values():
        click.echo(
            template.format(
                proj.project_id, proj.name, proj.language, str(proj.is_trained)
            )
        )

@cli.command("show-project")
@click.argument("project_id", shell_complete=cli_util.complete_project_id)
@cli_util.common_options
def run_show_project(project_id):
    """
    Show information about a project.
    """

    proj = cli_util.get_project(project_id)
    click.echo(f"Project ID:        {proj.project_id}")
    click.echo(f"Project Name:      {proj.name}")
    click.echo(f"Language:          {proj.language}")
    click.echo(f"Vocabulary:        {proj.vocab.vocab_id}")
    click.echo(f"Vocab language:    {proj.vocab_lang}")
    click.echo(f"Access:            {proj.access.name}")
    click.echo(f"Trained:           {proj.is_trained}")
    click.echo(f"Modification time: {proj.modification_time}")

@cli.command("clear")
@click.argument("project_id", shell_complete=cli_util.complete_project_id)
@cli_util.common_options
def run_clear_project(project_id):
    """
    Initialize the project to its original, untrained state.
    """
    proj = cli_util.get_project(project_id)
    proj.remove_model_data()

@cli.command("list-vocabs")
@cli_util.common_options
@click_log.simple_verbosity_option(logger, default="ERROR")
def run_list_vocabs():
    """
    List available vocabularies.
    """

    template = "{0: <20}{1: <20}{2: >10}  {3: <6}"
    header = template.format("Vocabulary ID", "Languages", "Size", "Loaded")
    click.echo(header)
    click.echo("-" * len(header))
    for vocab in annif.registry.get_vocabs(min_access=Access.private).values():
        try:
            languages = ",".join(sorted(vocab.languages))
            size = len(vocab)
            loaded = True
        except NotInitializedException:
            languages = "-"
            size = "-"
            loaded = False
        click.echo(template.format(vocab.vocab_id, languages, size, str(loaded)))

@cli.command("load-vocab")
@click.argument("vocab_id", shell_complete=cli_util.complete_vocab_id)
@click.argument("subjectfile", type=click.Path(exists=True, dir_okay=False))
@click.option("--language", "-L", help="Language of subject file")
@click.option(
    "--force",
    "-f",
    default=False,
    is_flag=True,
    help="Replace existing vocabulary completely instead of updating it",
)
@cli_util.common_options
def run_load_vocab(vocab_id, language, force, subjectfile):
    """
    Load a vocabulary from a subject file.
    """
    vocab = cli_util.get_vocab(vocab_id)
    if annif.corpus.SubjectFileSKOS.is_rdf_file(subjectfile):
        # SKOS/RDF file supported by rdflib
        subjects = annif.corpus.SubjectFileSKOS(subjectfile)
        click.echo(f"Loading vocabulary from SKOS file {subjectfile}...")
    elif annif.corpus.SubjectFileCSV.is_csv_file(subjectfile):
        # CSV file
        subjects = annif.corpus.SubjectFileCSV(subjectfile)
        click.echo(f"Loading vocabulary from CSV file {subjectfile}...")
    else:
        # probably a TSV file - we need to know its language
        if not language:
            click.echo(
                "Please use --language option to set the language of a TSV vocabulary.",
                err=True,
            )
            sys.exit(1)
        click.echo(f"Loading vocabulary from TSV file {subjectfile}...")
        subjects = annif.corpus.SubjectFileTSV(subjectfile, language)
    vocab.load_vocabulary(subjects, force=force)

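Because the TSV branch above exits with an error when no language is given, loading a TSV subject file requires the --language option. A sketch using the same in-process runner; the vocabulary id and file name are hypothetical, and the file must exist since the argument is declared with exists=True:

# Example sketch, not part of annif/cli.py; names are hypothetical
from click.testing import CliRunner

from annif.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli, ["load-vocab", "my-vocab", "subjects.tsv", "--language", "en"]
)
# Omitting --language for a TSV file prints an error and exits with status 1.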
@cli.command("train")
@click.argument("project_id", shell_complete=cli_util.complete_project_id)
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@click.option(
    "--cached/--no-cached",
    "-c/-C",
    default=False,
    help="Reuse preprocessed training data from previous run",
)
@click.option(
    "--jobs",
    "-j",
    default=0,
    help="Number of parallel jobs (0 means choose automatically)",
)
@cli_util.docs_limit_option
@cli_util.backend_param_option
@cli_util.common_options
def run_train(project_id, paths, cached, docs_limit, jobs, backend_param):
    """
    Train a project on a collection of documents.
    \f
    This will train the project using the documents from ``PATHS`` (directories
    or possibly gzipped TSV files) in a single batch operation. If ``--cached``
    is set, preprocessed training data from the previous run is reused instead
    of documents input; see `Reusing preprocessed training data
    <https://github.com/NatLibFi/Annif/wiki/
    Reusing-preprocessed-training-data>`_.
    """
    proj = cli_util.get_project(project_id)
    backend_params = cli_util.parse_backend_params(backend_param, proj)
    if cached:
        if len(paths) > 0:
            raise click.UsageError(
                "Corpus paths cannot be given when using --cached option."
            )
        documents = "cached"
    else:
        documents = cli_util.open_documents(
            paths, proj.subjects, proj.vocab_lang, docs_limit
        )
    proj.train(documents, backend_params, jobs)

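As the body above shows, --cached and explicit corpus paths are mutually exclusive: passing both raises a click.UsageError. A sketch of the two valid forms; the project id and corpus path are hypothetical, and any path given must exist:

# Example sketch, not part of annif/cli.py; names are hypothetical
from click.testing import CliRunner

from annif.cli import cli

runner = CliRunner()
# Train from a document corpus (a directory or a gzipped TSV file):
runner.invoke(cli, ["train", "my-project", "training-docs.tsv.gz"])
# Retrain from preprocessed data cached by a previous run (no paths allowed):
runner.invoke(cli, ["train", "my-project", "--cached"])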
@cli.command("learn")
@click.argument("project_id", shell_complete=cli_util.complete_project_id)
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@cli_util.docs_limit_option
@cli_util.backend_param_option
@cli_util.common_options
def run_learn(project_id, paths, docs_limit, backend_param):
    """
    Further train an existing project on a collection of documents.
    \f
    Similar to the ``train`` command. This will continue training an already
    trained project using the documents given by ``PATHS`` in a single batch
    operation. Not supported by all backends.
    """
    proj = cli_util.get_project(project_id)
    backend_params = cli_util.parse_backend_params(backend_param, proj)
    documents = cli_util.open_documents(
        paths, proj.subjects, proj.vocab_lang, docs_limit
    )
    proj.learn(documents, backend_params)

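This is the run_learn() function that the metrics panel at the top of this report refers to: it takes four parameters and has no branching, which is consistent with the nop 4 and cc 1 figures shown there. A sketch of an incremental-training invocation; the names are hypothetical and, as the docstring notes, the project's backend must support this:

# Example sketch, not part of annif/cli.py; names are hypothetical
from click.testing import CliRunner

from annif.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["learn", "my-project", "new-docs.tsv.gz"])
# Backends without incremental training support will reject this command.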
@cli.command("suggest")
@click.argument("project_id", shell_complete=cli_util.complete_project_id)
@click.argument(
    "paths", type=click.Path(dir_okay=False, exists=True, allow_dash=True), nargs=-1
)
@click.option("--limit", "-l", default=10, help="Maximum number of subjects")
@click.option("--threshold", "-t", default=0.0, help="Minimum score threshold")
@click.option("--language", "-L", help="Language of subject labels")
@cli_util.docs_limit_option
@cli_util.backend_param_option
@cli_util.common_options
def run_suggest(
    project_id, paths, limit, threshold, language, backend_param, docs_limit
):
    """
    Suggest subjects for a single document from standard input or for one or more
    document file(s) given its/their path(s).
    \f
    This will read a text document from standard input and suggest subjects for
    it, or if given path(s) to file(s), suggest subjects for it/them.
    """
    project = cli_util.get_project(project_id)
    lang = language or project.vocab_lang
    if lang not in project.vocab.languages:
        raise click.BadParameter(f'language "{lang}" not supported by vocabulary')
    backend_params = cli_util.parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)

    if paths and not (len(paths) == 1 and paths[0] == "-"):
        docs = cli_util.open_text_documents(paths, docs_limit)
        subject_sets = project.suggest_corpus(docs, backend_params)
        for (
            subjects,
            path,
        ) in zip(subject_sets, paths):
            click.echo(f"Suggestions for {path}")
            hits = hit_filter(subjects)
            cli_util.show_hits(hits, project, lang)
    else:
        text = sys.stdin.read()
        hits = hit_filter(project.suggest([text], backend_params)[0])
        cli_util.show_hits(hits, project, lang)

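When no paths (or only "-") are given, the command reads the document text from standard input; CliRunner.invoke can simulate that with its input argument. A sketch with a hypothetical project id:

# Example sketch, not part of annif/cli.py; "my-project" is hypothetical
from click.testing import CliRunner

from annif.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    ["suggest", "my-project", "--limit", "5"],
    input="Text of the document to be indexed.",
)
print(result.output)  # subject URIs, labels and scores via cli_util.show_hits()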
@cli.command("index")
@click.argument("project_id", shell_complete=cli_util.complete_project_id)
@click.argument("directory", type=click.Path(exists=True, file_okay=False))
@click.option(
    "--suffix", "-s", default=".annif", help="File name suffix for result files"
)
@click.option(
    "--force/--no-force",
    "-f/-F",
    default=False,
    help="Force overwriting of existing result files",
)
@click.option("--limit", "-l", default=10, help="Maximum number of subjects")
@click.option("--threshold", "-t", default=0.0, help="Minimum score threshold")
@click.option("--language", "-L", help="Language of subject labels")
@cli_util.backend_param_option
@cli_util.common_options
def run_index(
    project_id, directory, suffix, force, limit, threshold, language, backend_param
):
    """
    Index a directory with documents, suggesting subjects for each document.
    Write the results in TSV files with the given suffix (``.annif`` by
    default).
    """
    project = cli_util.get_project(project_id)
    lang = language or project.vocab_lang
    if lang not in project.vocab.languages:
        raise click.BadParameter(f'language "{lang}" not supported by vocabulary')
    backend_params = cli_util.parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)

    documents = annif.corpus.DocumentDirectory(
        directory, None, None, require_subjects=False
    )
    subject_sets = project.suggest_corpus(documents, backend_params)

    for (docfilename, _), subjects in zip(documents, subject_sets):
        subjectfilename = re.sub(r"\.txt$", suffix, docfilename)
        if os.path.exists(subjectfilename) and not force:
            click.echo(
                "Not overwriting {} (use --force to override)".format(subjectfilename)
            )
            continue
        hits = hit_filter(subjects)
        with open(subjectfilename, "w", encoding="utf-8") as subjfile:
            cli_util.show_hits(hits, project, lang, file=subjfile)

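A sketch of indexing a directory of .txt files; as the loop above shows, each result file is the document name with .txt replaced by the given suffix, and existing result files are skipped unless --force is used. The project id and directory are hypothetical:

# Example sketch, not part of annif/cli.py; names are hypothetical
from click.testing import CliRunner

from annif.cli import cli

runner = CliRunner()
runner.invoke(
    cli, ["index", "my-project", "documents/", "--suffix", ".keywords", "--force"]
)
# For each documents/<name>.txt this writes documents/<name>.keywords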
@cli.command("eval")
@click.argument("project_id", shell_complete=cli_util.complete_project_id)
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@click.option("--limit", "-l", default=10, help="Maximum number of subjects")
@click.option("--threshold", "-t", default=0.0, help="Minimum score threshold")
@click.option(
    "--metric",
    "-m",
    default=[],
    multiple=True,
    help="Metric to calculate (default: all)",
)
@click.option(
    "--metrics-file",
    "-M",
    type=click.File("w", encoding="utf-8", errors="ignore", lazy=True),
    help="""Specify file in order to write evaluation metrics in JSON format.
    File directory must exist, existing file will be overwritten.""",
)
@click.option(
    "--results-file",
    "-r",
    type=click.File("w", encoding="utf-8", errors="ignore", lazy=True),
    help="""Specify file in order to write non-aggregated results per subject.
    File directory must exist, existing file will be overwritten.""",
)
@click.option(
    "--jobs", "-j", default=1, help="Number of parallel jobs (0 means all CPUs)"
)
@cli_util.docs_limit_option
@cli_util.backend_param_option
@cli_util.common_options
def run_eval(
    project_id,
    paths,
    limit,
    threshold,
    docs_limit,
    metric,
    metrics_file,
    results_file,
    jobs,
    backend_param,
):
    """
    Suggest subjects for documents and evaluate the results by comparing
    against a gold standard.
    \f
    With this command the documents from ``PATHS`` (directories or possibly
    gzipped TSV files) will be assigned subject suggestions and then
    statistical measures are calculated that quantify how well the suggested
    subjects match the gold-standard subjects in the documents.

    Normally the output is the list of the metrics calculated across documents.
    If ``--results-file <FILENAME>`` option is given, the metrics are
    calculated separately for each subject, and written to the given file.
    """

    project = cli_util.get_project(project_id)
    backend_params = cli_util.parse_backend_params(backend_param, project)

    import annif.eval

    eval_batch = annif.eval.EvaluationBatch(project.subjects)

    if results_file:
        try:
            print("", end="", file=results_file)
            click.echo(
                "Writing per subject evaluation results to {!s}".format(
                    results_file.name
                )
            )
        except Exception as e:
            raise NotSupportedException(
                "cannot open results-file for writing: " + str(e)
            )
    corpus = cli_util.open_documents(
        paths, project.subjects, project.vocab_lang, docs_limit
    )
    jobs, pool_class = annif.parallel.get_pool(jobs)

    project.initialize(parallel=True)
    psmap = annif.parallel.ProjectSuggestMap(
        project.registry, [project_id], backend_params, limit, threshold
    )

    with pool_class(jobs) as pool:
        for hit_sets, subject_sets in pool.imap_unordered(
            psmap.suggest_batch, corpus.doc_batches
        ):
            eval_batch.evaluate_many(hit_sets[project_id], subject_sets)

    template = "{0:<30}\t{1}"
    metrics = eval_batch.results(
        metrics=metric, results_file=results_file, language=project.vocab_lang
    )
    for metric, score in metrics.items():
        click.echo(template.format(metric + ":", score))
    if metrics_file:
        json.dump(
            {metric_code(mname): val for mname, val in metrics.items()},
            metrics_file,
            indent=2,
        )

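With --metrics-file, the metric names are passed through metric_code() and dumped as JSON, as in the last block above. A sketch of an evaluation run; the project id and file names are hypothetical:

# Example sketch, not part of annif/cli.py; names are hypothetical
from click.testing import CliRunner

from annif.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    ["eval", "my-project", "test-docs.tsv.gz",
     "--metrics-file", "metrics.json", "--jobs", "2"],
)
# stdout lists each metric and its score; metrics.json holds the same values
# keyed by metric_code(name).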
FILTER_BATCH_MAX_LIMIT = 15


@cli.command("optimize")
@click.argument("project_id", shell_complete=cli_util.complete_project_id)
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@cli_util.docs_limit_option
@cli_util.backend_param_option
@cli_util.common_options
def run_optimize(project_id, paths, docs_limit, backend_param):
    """
    Suggest subjects for documents, testing multiple limits and thresholds.
    \f
    This command will use different limit (maximum number of subjects) and
    score threshold values when assigning subjects to each document given by
    ``PATHS`` and compare the results against the gold standard subjects in the
    documents. The output is a list of parameter combinations and their scores.
    From the output, you can determine the optimum limit and threshold
    parameters depending on which measure you want to target.
    """
    project = cli_util.get_project(project_id)
    backend_params = cli_util.parse_backend_params(backend_param, project)

    filter_batches = cli_util.generate_filter_batches(
        project.subjects, FILTER_BATCH_MAX_LIMIT
    )

    ndocs = 0
    corpus = cli_util.open_documents(
        paths, project.subjects, project.vocab_lang, docs_limit
    )
    for docs_batch in corpus.doc_batches:
        texts, subject_sets = zip(*[(doc.text, doc.subject_set) for doc in docs_batch])
        raw_hit_sets = project.suggest(texts, backend_params)
        hit_sets = [
            raw_hits.filter(project.subjects, limit=FILTER_BATCH_MAX_LIMIT)
            for raw_hits in raw_hit_sets
        ]
        assert isinstance(hit_sets[0], ListSuggestionResult), (
            "Optimize should only be done with ListSuggestionResult "
            + "as it would be very slow with VectorSuggestionResult."
        )
        for hit_filter, filter_batch in filter_batches.values():
            filtered_hits = [hit_filter(hits) for hits in hit_sets]
            filter_batch.evaluate_many(filtered_hits, subject_sets)
        ndocs += len(texts)

    click.echo("\t".join(("Limit", "Thresh.", "Prec.", "Rec.", "F1")))

    best_scores = collections.defaultdict(float)
    best_params = {}

    template = "{:d}\t{:.02f}\t{:.04f}\t{:.04f}\t{:.04f}"
    # Store the batches in a list that gets consumed along the way
    # This way GC will have a chance to reclaim the memory
    filter_batches = list(filter_batches.items())
    while filter_batches:
        params, filter_batch = filter_batches.pop(0)
        metrics = ["Precision (doc avg)", "Recall (doc avg)", "F1 score (doc avg)"]
        results = filter_batch[1].results(metrics=metrics)
        for metric, score in results.items():
            if score >= best_scores[metric]:
                best_scores[metric] = score
                best_params[metric] = params
        click.echo(
            template.format(
                params[0],
                params[1],
                results["Precision (doc avg)"],
                results["Recall (doc avg)"],
                results["F1 score (doc avg)"],
            )
        )

    click.echo()
    template2 = "Best {:>19}: {:.04f}\tLimit: {:d}\tThreshold: {:.02f}"
    for metric in metrics:
Issue (introduced by this change): the variable metrics does not seem to be defined in case the while filter_batches: loop above is never entered. Are you sure this can never be the case? One possible fix is sketched after this function.
        click.echo(
            template2.format(
                metric,
                best_scores[metric],
                best_params[metric][0],
                best_params[metric][1],
            )
        )
    click.echo("Documents evaluated:\t{}".format(ndocs))

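Regarding the issue flagged above: if the corpus yields no document batches, filter_batches ends up empty, the while loop never runs, and the final for metric in metrics: loop would hit an undefined name. One possible fix, offered here as a reviewer sketch rather than as the project's actual change, is to hoist the metric-name list out of the loop body; only the affected lines of run_optimize are shown:

    # Sketch of a possible fix (not the committed code): define the metric
    # names before the loop so they exist even when filter_batches is empty.
    metrics = ["Precision (doc avg)", "Recall (doc avg)", "F1 score (doc avg)"]
    while filter_batches:
        params, filter_batch = filter_batches.pop(0)
        results = filter_batch[1].results(metrics=metrics)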
@cli.command("hyperopt")
@click.argument("project_id", shell_complete=cli_util.complete_project_id)
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@click.option("--trials", "-T", default=10, help="Number of trials")
@click.option(
    "--jobs", "-j", default=1, help="Number of parallel runs (0 means all CPUs)"
)
@click.option(
    "--metric", "-m", default="NDCG", help="Metric to optimize (default: NDCG)"
)
@click.option(
    "--results-file",
    "-r",
    type=click.File("w", encoding="utf-8", errors="ignore", lazy=True),
    help="""Specify file path to write trial results as CSV.
    File directory must exist, existing file will be overwritten.""",
)
@cli_util.docs_limit_option
@cli_util.common_options
def run_hyperopt(project_id, paths, docs_limit, trials, jobs, metric, results_file):
    """
    Optimize the hyperparameters of a project using validation documents from
    ``PATHS``. Not supported by all backends. Output is a list of trial results
    and a report of the best performing parameters.
    """
    proj = cli_util.get_project(project_id)
    documents = cli_util.open_documents(
        paths, proj.subjects, proj.vocab_lang, docs_limit
    )
    click.echo(f"Looking for optimal hyperparameters using {trials} trials")
    rec = proj.hyperopt(documents, trials, jobs, metric, results_file)
    click.echo(f"Got best {metric} score {rec.score:.4f} with:")
    click.echo("---")
    for line in rec.lines:
        click.echo(line)
    click.echo("---")

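A sketch of a hyperparameter optimization run that writes per-trial results to a CSV file; the names are hypothetical and, as the docstring notes, the project's backend must support hyperparameter optimization:

# Example sketch, not part of annif/cli.py; names are hypothetical
from click.testing import CliRunner

from annif.cli import cli

runner = CliRunner()
runner.invoke(
    cli,
    ["hyperopt", "my-project", "validation-docs.tsv.gz",
     "--trials", "50", "--results-file", "trials.csv"],
)
# Prints the best score found (NDCG by default) and the best parameter lines.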
@cli.command("completion")
@click.option("--bash", "shell", flag_value="bash")
@click.option("--zsh", "shell", flag_value="zsh")
@click.option("--fish", "shell", flag_value="fish")
def completion(shell):
    """Generate autocompletion script for the given shell."""
    import datetime

    if shell is None:
        click.echo("Shell not specified, try --bash, --zsh or --fish")
        return

    now = datetime.datetime.now()
    click.echo(f"# Generated by Annif on {now.strftime('%Y-%m-%d %H:%M:%S')}")

    if shell == "bash":
        script = os.popen("_ANNIF_COMPLETE=bash_source annif").read()
    elif shell == "zsh":
        script = os.popen("_ANNIF_COMPLETE=zsh_source annif").read()
    elif shell == "fish":
        script = os.popen("_ANNIF_COMPLETE=fish_source annif").read()
    click.echo(script)
Issue (introduced by this change): the variable script does not seem to be defined for all execution paths. One possible fix is sketched below.
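Regarding the issue above: once the None check has passed, shell can only be one of the three flag values, so one possible fix, offered as a reviewer sketch rather than the committed code, is to derive the environment variable from the option value and drop the if/elif chain, guaranteeing that script is assigned on every remaining path:

    # Sketch of a possible fix (not the committed code)
    script = os.popen(f"_ANNIF_COMPLETE={shell}_source annif").read()
    click.echo(script)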


if __name__ == "__main__":
    cli()
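Finally, since this push (branch issue684-cli-command-completio...) concerns command completion, a sketch of generating a completion script in-process; this assumes the annif entry point is installed on the PATH, because the command shells out to it via os.popen:

# Example sketch, not part of annif/cli.py
from click.testing import CliRunner

from annif.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["completion", "--bash"])
print(result.output)  # timestamp comment followed by the bash completion script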