Passed
Pull Request — master (#640)
by Juho
03:12
created

annif.cli.run_train()   A

Complexity

Conditions 3

Size

Total Lines 46
Code Lines 31

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 31
nop 6
dl 0
loc 46
rs 9.1359
c 0
b 0
f 0
1
"""Definitions for command-line (Click) commands for invoking Annif
2
operations and printing the results to console."""
3
4
5
import collections
6
import os.path
7
import re
8
import sys
9
import json
10
import click
11
import click_log
12
from flask import current_app
13
from flask.cli import FlaskGroup, ScriptInfo
14
import annif
15
import annif.corpus
16
import annif.parallel
17
import annif.project
18
import annif.registry
19
from annif.project import Access
20
from annif.suggestion import SuggestionFilter, ListSuggestionResult
21
from annif.exception import ConfigurationException, NotSupportedException
22
from annif.exception import NotInitializedException
23
from annif.util import metric_code
24
25
logger = annif.logger
click_log.basic_config(logger)

# FlaskGroup's own version option is switched off so that the Click version
# option below can print nothing but the bare version string.
cli = click.version_option(message="%(version)s")(
    FlaskGroup(create_app=annif.create_app, add_version_option=False)
)
30
31
32
def get_project(project_id):
    """Return the project registered under ``project_id``.

    Projects of any access level down to ``Access.private`` are considered.
    If the registry lookup raises ValueError, an error message is written to
    stderr and the process exits with status 1.
    """
    try:
        project = annif.registry.get_project(project_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No projects found with id '{project_id}'.", err=True)
        sys.exit(1)
    return project
40
41
42
def get_vocab(vocab_id):
    """Return the vocabulary registered under ``vocab_id``.

    Vocabularies of any access level down to ``Access.private`` are
    considered. If the registry lookup raises ValueError, an error message is
    written to stderr and the process exits with status 1.
    """
    try:
        vocabulary = annif.registry.get_vocab(vocab_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No vocabularies found with the id '{vocab_id}'.", err=True)
        sys.exit(1)
    return vocabulary
51
52
53
def open_documents(paths, subject_index, vocab_lang, docs_limit):
    """Open a document corpus from a sequence of path names.

    Each path is either a TSV file or a directory of TXT files. For
    directories with subjects in TSV files, ``vocab_lang`` is the vocabulary
    language used to convert subject labels into URIs.

    With no paths, the null device is read (and a warning is logged). With
    one path, that single corpus is returned. With several, they are wrapped
    in a CombinedCorpus. If ``docs_limit`` is not None, the result is further
    wrapped in a LimitingDocumentCorpus.
    """

    def open_one(path):
        """Open a single path as a DocumentFile or a DocumentDirectory."""
        if not os.path.isdir(path):
            return annif.corpus.DocumentFile(path, subject_index)
        return annif.corpus.DocumentDirectory(
            path, subject_index, vocab_lang, require_subjects=True
        )

    if not paths:
        logger.warning("Reading empty file")
        corpus = open_one(os.path.devnull)
    elif len(paths) == 1:
        corpus = open_one(paths[0])
    else:
        corpus = annif.corpus.CombinedCorpus([open_one(p) for p in paths])

    if docs_limit is None:
        return corpus
    return annif.corpus.LimitingDocumentCorpus(corpus, docs_limit)
79
80
81
def parse_backend_params(backend_param, project):
    """Parse the values of repeated ``--backend-param`` options.

    Each value must have the form ``<backend>.<parameter>=<value>``. The
    backend part is split at the first ".", and the parameter part at the
    first "=".

    Returns a ``defaultdict(dict)`` mapping backend name to a dict of
    parameter name -> string value.

    Raises ConfigurationException if a value does not follow the syntax, or
    (via validate_backend_params) if it names a backend other than the
    project's configured one.
    """
    backend_params = collections.defaultdict(dict)
    for beparam in backend_param:
        # partition() never raises, unlike tuple-unpacking str.split(); a
        # missing separator is detected explicitly so the user gets a
        # meaningful error instead of "not enough values to unpack".
        backend, dot, param = beparam.partition(".")
        key, equals, val = param.partition("=")
        if not dot or not equals:
            raise ConfigurationException(
                f'Malformed backend parameter "-b {beparam}"; expected syntax '
                "-b <backend>.<parameter>=<value>."
            )
        validate_backend_params(backend, beparam, project)
        backend_params[backend][key] = val
    return backend_params
91
92
93
def validate_backend_params(backend, beparam, project):
    """Check that a ``-b`` override targets the project's configured backend.

    ``backend`` is the backend name parsed from the option value ``beparam``.
    Returns None when it equals ``project.config["backend"]``. Otherwise it
    raises ConfigurationException naming the option and both backends.
    """
    configured_backend = project.config["backend"]
    if backend == configured_backend:
        return
    raise ConfigurationException(
        f'The backend {backend} in CLI option "-b {beparam}" not matching the project'
        f" backend {configured_backend}."
    )
99
100
101
BATCH_MAX_LIMIT = 15


def generate_filter_batches(subjects):
    """Create a filter/evaluation pair for every parameter combination
    tried by the ``optimize`` command.

    Limits run from 1 to BATCH_MAX_LIMIT inclusive. Thresholds run from 0.00
    to 0.95 in steps of 0.05.

    Returns an OrderedDict mapping ``(limit, threshold)`` to
    ``(SuggestionFilter, EvaluationBatch)``, ordered by limit and then by
    threshold.
    """
    import annif.eval

    # i / 20 yields the double closest to each intended step value, whereas
    # i * 0.05 accumulates error (3 * 0.05 == 0.15000000000000002). With
    # that error, a threshold reported as "0.15" is really slightly higher
    # and could exclude a score of exactly 0.15.
    thresholds = [i / 20 for i in range(20)]

    filter_batches = collections.OrderedDict()
    for limit in range(1, BATCH_MAX_LIMIT + 1):
        for threshold in thresholds:
            hit_filter = SuggestionFilter(subjects, limit, threshold)
            batch = annif.eval.EvaluationBatch(subjects)
            filter_batches[(limit, threshold)] = (hit_filter, batch)
    return filter_batches
114
115
116
def set_project_config_file_path(ctx, param, value):
    """Click callback for ``--projects``.

    Loads the Flask app and, inside its app context, overrides
    ``PROJECTS_CONFIG_PATH`` with the option value when one was given. The
    app is loaded even when the option is absent.
    """
    app = ctx.ensure_object(ScriptInfo).load_app()
    with app.app_context():
        if not value:
            return
        current_app.config["PROJECTS_CONFIG_PATH"] = value
121
122
123
def common_options(f):
    """Decorate a CLI command with the options every command shares.

    These are ``--projects``/``-p`` (project configuration path) and the
    click-log verbosity option.
    """
    projects_option = click.option(
        "-p",
        "--projects",
        help="Set path to project configuration file or directory",
        type=click.Path(dir_okay=True, exists=True),
        callback=set_project_config_file_path,
        # The callback applies the value itself; the command never sees it,
        # and being eager makes it run before non-eager parameters.
        expose_value=False,
        is_eager=True,
    )
    verbosity_option = click_log.simple_verbosity_option(logger)
    return verbosity_option(projects_option(f))
135
136
137
def backend_param_option(f):
    """Decorate a CLI command with a repeatable ``--backend-param``/``-b``
    option for overriding backend parameters of the project configuration."""
    backend_option = click.option(
        "--backend-param",
        "-b",
        multiple=True,
        help=(
            "Override backend parameter of the config file. "
            "Syntax: `-b <backend>.<parameter>=<value>`."
        ),
    )
    return backend_option(f)
146
147
148
@cli.command("list-projects")
@common_options
@click_log.simple_verbosity_option(logger, default="ERROR")
def run_list_projects():
    """
    List available projects.
    \f
    Show a list of currently defined projects. Projects are defined in a
    configuration file, normally called ``projects.cfg``. See `Project
    configuration
    <https://github.com/NatLibFi/Annif/wiki/Project-configuration>`_
    for details.
    """
    # Fixed-width columns: ID, name, language, trained flag.
    row_format = "{0: <25}{1: <45}{2: <10}{3: <7}"
    header = row_format.format("Project ID", "Project Name", "Language", "Trained")
    click.echo(header)
    click.echo("-" * len(header))
    projects = annif.registry.get_projects(min_access=Access.private)
    for project in projects.values():
        row = row_format.format(
            project.project_id,
            project.name,
            project.language,
            str(project.is_trained),
        )
        click.echo(row)
172
173
174
@cli.command("show-project")
@click.argument("project_id")
@common_options
def run_show_project(project_id):
    """
    Show information about a project.
    """

    proj = get_project(project_id)
    # Each value is read lazily, right before its line is printed, so the
    # attribute accesses interleave with the output exactly as before.
    fields = [
        ("Project ID", lambda: proj.project_id),
        ("Project Name", lambda: proj.name),
        ("Language", lambda: proj.language),
        ("Vocabulary", lambda: proj.vocab.vocab_id),
        ("Vocab language", lambda: proj.vocab_lang),
        ("Access", lambda: proj.access.name),
        ("Trained", lambda: proj.is_trained),
        ("Modification time", lambda: proj.modification_time),
    ]
    for label, read_value in fields:
        # Labels plus colon are left-aligned in a 19-character column.
        click.echo(f"{label + ':':<19}{read_value()}")
191
192
193
@cli.command("clear")
@click.argument("project_id")
@common_options
def run_clear_project(project_id):
    """
    Initialize the project to its original, untrained state.
    """
    # Resolving the project exits the CLI if the ID is unknown.
    get_project(project_id).remove_model_data()
202
203
204
@cli.command("list-vocabs")
@common_options
@click_log.simple_verbosity_option(logger, default="ERROR")
def run_list_vocabs():
    """
    List available vocabularies.
    """

    row_format = "{0: <20}{1: <20}{2: >10}  {3: <6}"
    header = row_format.format("Vocabulary ID", "Languages", "Size", "Loaded")
    click.echo(header)
    click.echo("-" * len(header))
    vocabs = annif.registry.get_vocabs(min_access=Access.private)
    for vocab in vocabs.values():
        try:
            languages = ",".join(sorted(vocab.languages))
            size = len(vocab)
        except NotInitializedException:
            # Vocabulary has not been loaded yet: show placeholders.
            languages, size, loaded = "-", "-", False
        else:
            loaded = True
        click.echo(row_format.format(vocab.vocab_id, languages, size, str(loaded)))
226
227
228
@cli.command("load-vocab")
@click.argument("vocab_id")
@click.argument("subjectfile", type=click.Path(exists=True, dir_okay=False))
@click.option("--language", "-L", help="Language of subject file")
@click.option(
    "--force",
    "-f",
    default=False,
    is_flag=True,
    help="Replace existing vocabulary completely " + "instead of updating it",
)
@common_options
def run_load_vocab(vocab_id, language, force, subjectfile):
    """
    Load a vocabulary from a subject file.
    """
    vocab = get_vocab(vocab_id)
    skos_file_cls = annif.corpus.SubjectFileSKOS
    csv_file_cls = annif.corpus.SubjectFileCSV

    if skos_file_cls.is_rdf_file(subjectfile):
        # An RDF serialization that rdflib can parse.
        subjects = skos_file_cls(subjectfile)
        click.echo(f"Loading vocabulary from SKOS file {subjectfile}...")
    elif csv_file_cls.is_csv_file(subjectfile):
        subjects = csv_file_cls(subjectfile)
        click.echo(f"Loading vocabulary from CSV file {subjectfile}...")
    else:
        # Anything else is treated as TSV, which carries no language
        # information of its own, so --language is mandatory here.
        if not language:
            click.echo(
                "Please use --language option to set the language of "
                "a TSV vocabulary.",
                err=True,
            )
            sys.exit(1)
        click.echo(f"Loading vocabulary from TSV file {subjectfile}...")
        subjects = annif.corpus.SubjectFileTSV(subjectfile, language)

    vocab.load_vocabulary(subjects, force=force)
265
266
267
@cli.command("train")
@click.argument("project_id")
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@click.option(
    "--cached/--no-cached",
    "-c/-C",
    default=False,
    help="Reuse preprocessed training data from previous run",
)
@click.option(
    "--docs-limit",
    "-d",
    default=None,
    type=click.IntRange(0, None),
    help="Maximum number of documents to use",
)
@click.option(
    "--jobs",
    "-j",
    default=0,
    help="Number of parallel jobs (0 means choose automatically)",
)
@backend_param_option
@common_options
def run_train(project_id, paths, cached, docs_limit, jobs, backend_param):
    """
    Train a project on a collection of documents.
    \f
    This will train the project using the documents from ``PATHS`` (directories
    or possibly gzipped TSV files) in a single batch operation. If ``--cached``
    is set, preprocessed training data from the previous run is reused instead
    of documents input; see `Reusing preprocessed training data
    <https://github.com/NatLibFi/Annif/wiki/
    Reusing-preprocessed-training-data>`_.
    """
    project = get_project(project_id)
    overrides = parse_backend_params(backend_param, project)
    if not cached:
        corpus = open_documents(
            paths, project.subjects, project.vocab_lang, docs_limit
        )
    elif paths:
        # --cached reuses earlier preprocessed data, so corpus paths would be
        # ignored; reject them instead.
        raise click.UsageError(
            "Corpus paths cannot be given when using --cached option."
        )
    else:
        # The literal string "cached" is what project.train receives in
        # place of a corpus when reusing preprocessed data.
        corpus = "cached"
    project.train(corpus, overrides, jobs)
313
314
315
@cli.command("learn")
@click.argument("project_id")
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@click.option(
    "--docs-limit",
    "-d",
    default=None,
    type=click.IntRange(0, None),
    help="Maximum number of documents to use",
)
@backend_param_option
@common_options
def run_learn(project_id, paths, docs_limit, backend_param):
    """
    Further train an existing project on a collection of documents.
    \f
    Similar to the ``train`` command. This will continue training an already
    trained project using the documents given by ``PATHS`` in a single batch
    operation. Not supported by all backends.
    """
    project = get_project(project_id)
    overrides = parse_backend_params(backend_param, project)
    corpus = open_documents(paths, project.subjects, project.vocab_lang, docs_limit)
    project.learn(corpus, overrides)
339
340
341
@cli.command("suggest")
@click.argument("project_id")
@click.option("--limit", "-l", default=10, help="Maximum number of subjects")
@click.option("--threshold", "-t", default=0.0, help="Minimum score threshold")
@click.option("--language", "-L", help="Language of subject labels")
@backend_param_option
@common_options
def run_suggest(project_id, limit, threshold, language, backend_param):
    """
    Suggest subjects for a single document from standard input.
    \f
    This will read a text document from standard input and suggest subjects for
    it.
    """
    project = get_project(project_id)
    text = sys.stdin.read()
    label_lang = language or project.vocab_lang
    if label_lang not in project.vocab.languages:
        raise click.BadParameter(f'language "{label_lang}" not supported by vocabulary')
    overrides = parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)
    suggestions = hit_filter(project.suggest(text, overrides))
    for hit in suggestions.as_list():
        subject = project.subjects[hit.subject_id]
        # Output columns: <URI>, label and notation (empty ones omitted), score.
        label_and_notation = "\t".join(
            part for part in (subject.labels[label_lang], subject.notation) if part
        )
        click.echo(f"<{subject.uri}>\t{label_and_notation}\t{hit.score}")
372
373
374
@cli.command("index")
@click.argument("project_id")
@click.argument("directory", type=click.Path(exists=True, file_okay=False))
@click.option(
    "--suffix", "-s", default=".annif", help="File name suffix for result files"
)
@click.option(
    "--force/--no-force",
    "-f/-F",
    default=False,
    help="Force overwriting of existing result files",
)
@click.option("--limit", "-l", default=10, help="Maximum number of subjects")
@click.option("--threshold", "-t", default=0.0, help="Minimum score threshold")
@click.option("--language", "-L", help="Language of subject labels")
@backend_param_option
@common_options
def run_index(
    project_id, directory, suffix, force, limit, threshold, language, backend_param
):
    """
    Index a directory with documents, suggesting subjects for each document.
    Write the results in TSV files with the given suffix (``.annif`` by
    default).
    """
    project = get_project(project_id)
    lang = language or project.vocab_lang
    if lang not in project.vocab.languages:
        raise click.BadParameter(f'language "{lang}" not supported by vocabulary')
    backend_params = parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)

    for docfilename, dummy_subjectfn in annif.corpus.DocumentDirectory(
        directory, project.subjects, project.vocab_lang, require_subjects=False
    ):
        # A callable replacement inserts the user-supplied suffix literally.
        # As a template string, backslashes or group references such as "\1"
        # in --suffix would be interpreted by re.sub (or make it raise).
        subjectfilename = re.sub(r"\.txt$", lambda _match: suffix, docfilename)
        if os.path.exists(subjectfilename) and not force:
            click.echo(
                "Not overwriting {} (use --force to override)".format(subjectfilename)
            )
            continue

        # Read the document only once we know it will be indexed.
        with open(docfilename, encoding="utf-8-sig") as docfile:
            text = docfile.read()

        # Compute and format all result lines before opening the output file.
        # If suggestion fails, no empty or partial result file is left behind
        # that a later run would then skip unless --force is used.
        lines = []
        results = project.suggest(text, backend_params)
        for hit in hit_filter(results).as_list():
            subj = project.subjects[hit.subject_id]
            lines.append(
                "<{}>\t{}\t{}".format(
                    subj.uri,
                    "\t".join(filter(None, (subj.labels[lang], subj.notation))),
                    hit.score,
                )
            )

        with open(subjectfilename, "w", encoding="utf-8") as subjfile:
            for line in lines:
                click.echo(line, file=subjfile)
427
428
429
@cli.command("eval")
@click.argument("project_id")
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@click.option("--limit", "-l", default=10, help="Maximum number of subjects")
@click.option("--threshold", "-t", default=0.0, help="Minimum score threshold")
@click.option(
    "--docs-limit",
    "-d",
    default=None,
    type=click.IntRange(0, None),
    help="Maximum number of documents to use",
)
@click.option(
    "--metric",
    "-m",
    default=[],
    multiple=True,
    help="Metric to calculate (default: all)",
)
@click.option(
    "--metrics-file",
    "-M",
    type=click.File("w", encoding="utf-8", errors="ignore", lazy=True),
    help="""Specify file in order to write evaluation metrics in JSON format.
    File directory must exist, existing file will be overwritten.""",
)
@click.option(
    "--results-file",
    "-r",
    type=click.File("w", encoding="utf-8", errors="ignore", lazy=True),
    help="""Specify file in order to write non-aggregated results per subject.
    File directory must exist, existing file will be overwritten.""",
)
@click.option(
    "--jobs", "-j", default=1, help="Number of parallel jobs (0 means all CPUs)"
)
@backend_param_option
@common_options
def run_eval(
    project_id,
    paths,
    limit,
    threshold,
    docs_limit,
    metric,
    metrics_file,
    results_file,
    jobs,
    backend_param,
):
    """
    Suggest subjects for documents and evaluate the results by comparing
    against a gold standard.
    \f
    With this command the documents from ``PATHS`` (directories or possibly
    gzipped TSV files) will be assigned subject suggestions and then
    statistical measures are calculated that quantify how well the suggested
    subjects match the gold-standard subjects in the documents.

    Normally the output is the list of the metrics calculated across documents.
    If ``--results-file <FILENAME>`` option is given, the metrics are
    calculated separately for each subject, and written to the given file.
    """

    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)

    import annif.eval

    eval_batch = annif.eval.EvaluationBatch(project.subjects)

    if results_file:
        # The file is opened lazily; writing an empty string forces it open
        # now, so an unwritable path fails before any evaluation work is done.
        # Only that write is guarded, not the informational echo below.
        try:
            print("", end="", file=results_file)
        except Exception as e:
            raise NotSupportedException(
                "cannot open results-file for writing: " + str(e)
            ) from e
        click.echo(
            "Writing per subject evaluation results to {!s}".format(results_file.name)
        )
    docs = open_documents(paths, project.subjects, project.vocab_lang, docs_limit)

    jobs, pool_class = annif.parallel.get_pool(jobs)

    project.initialize(parallel=True)
    psmap = annif.parallel.ProjectSuggestMap(
        project.registry, [project_id], backend_params, limit, threshold
    )

    with pool_class(jobs) as pool:
        for hits, subject_set in pool.imap_unordered(psmap.suggest, docs.documents):
            eval_batch.evaluate(hits[project_id], subject_set)

    template = "{0:<30}\t{1}"
    metrics = eval_batch.results(
        metrics=metric, results_file=results_file, language=project.vocab_lang
    )
    # A distinct loop variable, so the --metric option value is not rebound.
    for metric_name, score in metrics.items():
        click.echo(template.format(metric_name + ":", score))
    if metrics_file:
        json.dump(
            {metric_code(mname): val for mname, val in metrics.items()},
            metrics_file,
            indent=2,
        )
537
538
539
@cli.command("optimize")
@click.argument("project_id")
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@click.option(
    "--docs-limit",
    "-d",
    default=None,
    type=click.IntRange(0, None),
    help="Maximum number of documents to use",
)
@backend_param_option
@common_options
def run_optimize(project_id, paths, docs_limit, backend_param):
    """
    Suggest subjects for documents, testing multiple limits and thresholds.
    \f
    This command will use different limit (maximum number of subjects) and
    score threshold values when assigning subjects to each document given by
    ``PATHS`` and compare the results against the gold standard subjects in the
    documents. The output is a list of parameter combinations and their scores.
    From the output, you can determine the optimum limit and threshold
    parameters depending on which measure you want to target.
    """
    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)

    filter_batches = generate_filter_batches(project.subjects)

    ndocs = 0
    docs = open_documents(paths, project.subjects, project.vocab_lang, docs_limit)
    for doc in docs.documents:
        raw_hits = project.suggest(doc.text, backend_params)
        # Suggest once per document, cut to the largest limit tried; every
        # (limit, threshold) filter below then narrows this same list.
        hits = raw_hits.filter(project.subjects, limit=BATCH_MAX_LIMIT)
        assert isinstance(hits, ListSuggestionResult), (
            "Optimize should only be done with ListSuggestionResult "
            + "as it would be very slow with VectorSuggestionResult."
        )
        for hit_filter, batch in filter_batches.values():
            batch.evaluate(hit_filter(hits), doc.subject_set)
        ndocs += 1

    click.echo("\t".join(("Limit", "Thresh.", "Prec.", "Rec.", "F1")))

    # Metrics reported for every parameter combination. This is defined once
    # up front: it is loop-invariant, and the summary below iterates it too.
    metrics = ["Precision (doc avg)", "Recall (doc avg)", "F1 score (doc avg)"]
    best_scores = collections.defaultdict(float)
    best_params = {}

    template = "{:d}\t{:.02f}\t{:.04f}\t{:.04f}\t{:.04f}"
    # Rebind the name to a queue that gets consumed along the way. This drops
    # the OrderedDict, so each batch becomes unreachable once it has been
    # reported and GC has a chance to reclaim its memory. deque.popleft() is
    # O(1), unlike list.pop(0).
    filter_batches = collections.deque(filter_batches.items())
    while filter_batches:
        params, (_, batch) = filter_batches.popleft()
        results = batch.results(metrics=metrics)
        for metric, score in results.items():
            if score >= best_scores[metric]:
                best_scores[metric] = score
                best_params[metric] = params
        click.echo(
            template.format(
                params[0],
                params[1],
                results["Precision (doc avg)"],
                results["Recall (doc avg)"],
                results["F1 score (doc avg)"],
            )
        )

    click.echo()
    template2 = "Best {:>19}: {:.04f}\tLimit: {:d}\tThreshold: {:.02f}"
    for metric in metrics:
        click.echo(
            template2.format(
                metric,
                best_scores[metric],
                best_params[metric][0],
                best_params[metric][1],
            )
        )
    click.echo("Documents evaluated:\t{}".format(ndocs))
619
620
621
@cli.command("hyperopt")
@click.argument("project_id")
@click.argument("paths", type=click.Path(exists=True), nargs=-1)
@click.option(
    "--docs-limit",
    "-d",
    default=None,
    type=click.IntRange(0, None),
    help="Maximum number of documents to use",
)
@click.option("--trials", "-T", default=10, help="Number of trials")
@click.option(
    "--jobs", "-j", default=1, help="Number of parallel runs (0 means all CPUs)"
)
@click.option(
    "--metric", "-m", default="NDCG", help="Metric to optimize (default: NDCG)"
)
@click.option(
    "--results-file",
    "-r",
    type=click.File("w", encoding="utf-8", errors="ignore", lazy=True),
    help="""Specify file path to write trial results as CSV.
    File directory must exist, existing file will be overwritten.""",
)
@common_options
def run_hyperopt(project_id, paths, docs_limit, trials, jobs, metric, results_file):
    """
    Optimize the hyperparameters of a project using validation documents from
    ``PATHS``. Not supported by all backends. Output is a list of trial results
    and a report of the best performing parameters.
    """
    project = get_project(project_id)
    corpus = open_documents(paths, project.subjects, project.vocab_lang, docs_limit)
    click.echo(f"Looking for optimal hyperparameters using {trials} trials")
    best = project.hyperopt(corpus, trials, jobs, metric, results_file)
    click.echo(f"Got best {metric} score {best.score:.4f} with:")
    # The recommended configuration lines are framed by "---" separators.
    separator = "---"
    click.echo(separator)
    for config_line in best.lines:
        click.echo(config_line)
    click.echo(separator)
661
662
663
if __name__ == "__main__":
    # Invoke the Click command group when this module is executed directly.
    cli()
665