Passed
Pull Request — master (#414)
by Osma
02:11
created

annif.cli.run_hyperopt()   A

Complexity

Conditions 2

Size

Total Lines 22
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 19
dl 0
loc 22
rs 9.45
c 0
b 0
f 0
cc 2
nop 5
1
"""Definitions for command-line (Click) commands for invoking Annif
2
operations and printing the results to console."""
3
4
5
import collections
6
import multiprocessing
7
import multiprocessing.dummy
8
import os.path
9
import re
10
import sys
11
import click
12
import click_log
13
from flask import current_app
14
from flask.cli import FlaskGroup, ScriptInfo
15
import annif
16
import annif.corpus
17
import annif.eval
18
import annif.project
19
import annif.registry
20
from annif.project import Access
21
from annif.suggestion import SuggestionFilter
22
from annif.exception import ConfigurationException, NotSupportedException
23
24
logger = annif.logger
25
click_log.basic_config(logger)
26
27
cli = FlaskGroup(create_app=annif.create_app)
28
29
30
def get_project(project_id):
    """Look up a project by ID and return it.

    Exits the CLI with an error message when no such project exists."""
    try:
        return annif.registry.get_project(project_id, min_access=Access.hidden)
    except ValueError:
        message = "No projects found with id '{0}'.".format(project_id)
        click.echo(message, err=True)
        sys.exit(1)
40
41
42
def open_documents(paths):
    """Open a document corpus from a list of pathnames and return it as a
    DocumentCorpus instance.

    Each path is either a TSV file or a directory of TXT files."""

    def open_doc_path(path):
        """Return a DocumentCorpus for a single path."""
        if os.path.isdir(path):
            return annif.corpus.DocumentDirectory(path, require_subjects=True)
        return annif.corpus.DocumentFile(path)

    if not paths:
        # no input given: fall back to the null device so callers still
        # receive a (empty) corpus object
        logger.warning('Reading empty file')
        return open_doc_path(os.path.devnull)
    if len(paths) == 1:
        return open_doc_path(paths[0])
    return annif.corpus.CombinedCorpus(
        [open_doc_path(path) for path in paths])
62
63
64
def parse_backend_params(backend_param, project):
    """Parse the repeatable --backend-param CLI options into a nested
    ``{backend: {key: value}}`` dict structure."""
    parsed = collections.defaultdict(dict)
    for beparam in backend_param:
        # expected syntax: <backend>.<key>=<value>
        backend, param = beparam.split('.', 1)
        key, val = param.split('=', 1)
        validate_backend_params(backend, beparam, project)
        parsed[backend][key] = val
    return parsed
74
75
76
def validate_backend_params(backend, beparam, project):
    """Reject backend parameter overrides that are unsupported or that
    target a backend other than the project's configured one."""
    # NOTE: this is a substring match over the whole "backend.key=value"
    # string, so any parameter mentioning "algorithm" is rejected
    if 'algorithm' in beparam:
        raise NotSupportedException('Algorithm overriding not supported.')
    project_backend = project.config['backend']
    if backend != project_backend:
        raise ConfigurationException(
            'The backend {} in CLI option "-b {}" not matching the project'
            ' backend {}.'
            .format(backend, beparam, project_backend))
84
85
86
def generate_filter_batches(subjects):
    """Build an ordered mapping from (limit, threshold) combinations to
    (SuggestionFilter, EvaluationBatch) pairs for the optimize grid search."""
    batches = collections.OrderedDict()
    thresholds = [step * 0.05 for step in range(20)]
    for limit in range(1, 16):
        for threshold in thresholds:
            batches[(limit, threshold)] = (
                SuggestionFilter(subjects, limit, threshold),
                annif.eval.EvaluationBatch(subjects))
    return batches
94
95
96
def set_project_config_file_path(ctx, param, value):
    """Click callback: override the default projects.cfg path (or the one
    from the environment) with the -p/--projects CLI option value."""
    app = ctx.ensure_object(ScriptInfo).load_app()
    with app.app_context():
        if value:
            current_app.config['PROJECTS_FILE'] = value
101
102
103
def common_options(f):
    """Decorator adding the options shared by all CLI commands:
    -p/--projects and the click-log verbosity option."""
    projects_option = click.option(
        '-p', '--projects', help='Set path to projects.cfg',
        type=click.Path(dir_okay=False, exists=True),
        callback=set_project_config_file_path, expose_value=False,
        is_eager=True)
    verbosity_option = click_log.simple_verbosity_option(logger)
    return verbosity_option(projects_option(f))
111
112
113
def backend_param_option(f):
    """Decorator adding the -b/--backend-param override option to a
    CLI command."""
    option = click.option(
        '--backend-param', '-b', multiple=True,
        help='Override backend parameter of the config file. ' +
        'Syntax: "-b <backend>.<parameter>=<value>".')
    return option(f)
119
120
121
@cli.command('list-projects')
@common_options
@click_log.simple_verbosity_option(logger, default='ERROR')
def run_list_projects():
    """
    List available projects.
    """

    # fixed-width columns give a simple aligned table
    template = "{0: <25}{1: <45}{2: <10}{3: <7}"
    header = template.format(
        "Project ID", "Project Name", "Language", "Trained")
    click.echo(header)
    click.echo("-" * len(header))
    projects = annif.registry.get_projects(min_access=Access.private)
    for proj in projects.values():
        click.echo(template.format(
            proj.project_id, proj.name, proj.language, str(proj.is_trained)))
138
139
140
@cli.command('show-project')
@click.argument('project_id')
@common_options
def run_show_project(project_id):
    """
    Show information about a project.
    """

    proj = get_project(project_id)
    # one echo per field; labels padded to align the values
    details = [
        f'Project ID:        {proj.project_id}',
        f'Project Name:      {proj.name}',
        f'Language:          {proj.language}',
        f'Access:            {proj.access.name}',
        f'Trained:           {proj.is_trained}',
        f'Modification time: {proj.modification_time}',
    ]
    for detail in details:
        click.echo(detail)
155
156
157
@cli.command('clear')
@click.argument('project_id')
@common_options
def run_clear_project(project_id):
    """
    Initialize the project to its original, untrained state.
    """
    # removing the model data resets the project to untrained
    get_project(project_id).remove_model_data()
166
167
168
@cli.command('loadvoc')
@click.argument('project_id')
@click.argument('subjectfile', type=click.Path(exists=True, dir_okay=False))
@common_options
def run_loadvoc(project_id, subjectfile):
    """
    Load a vocabulary for a project.
    """
    proj = get_project(project_id)
    if annif.corpus.SubjectFileSKOS.is_rdf_file(subjectfile):
        # SKOS/RDF file supported by rdflib
        subjects = annif.corpus.SubjectFileSKOS(subjectfile, proj.language)
    else:
        # anything else is assumed to be a TSV file
        subjects = annif.corpus.SubjectFileTSV(subjectfile)
    proj.vocab.load_vocabulary(subjects, proj.language)
184
185
186
@cli.command('train')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--cached/--no-cached', default=False,
              help='Reuse preprocessed training data from previous run')
@backend_param_option
@common_options
def run_train(project_id, paths, cached, backend_param):
    """
    Train a project on a collection of documents.
    """
    proj = get_project(project_id)
    backend_params = parse_backend_params(backend_param, proj)
    if not cached:
        documents = open_documents(paths)
    else:
        if len(paths) > 0:
            raise click.UsageError(
                "Corpus paths cannot be given when using --cached option.")
        # the sentinel string 'cached' tells the project to reuse the
        # preprocessed training data from the previous run
        documents = 'cached'
    proj.train(documents, backend_params)
207
208
209
@cli.command('learn')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@backend_param_option
@common_options
def run_learn(project_id, paths, backend_param):
    """
    Further train an existing project on a collection of documents.
    """
    proj = get_project(project_id)
    params = parse_backend_params(backend_param, proj)
    corpus = open_documents(paths)
    proj.learn(corpus, params)
222
223
224
@cli.command('suggest')
@click.argument('project_id')
@click.option('--limit', default=10, help='Maximum number of subjects')
@click.option('--threshold', default=0.0, help='Minimum score threshold')
@backend_param_option
@common_options
def run_suggest(project_id, limit, threshold, backend_param):
    """
    Suggest subjects for a single document from standard input.
    """
    project = get_project(project_id)
    text = sys.stdin.read()
    backend_params = parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)
    hits = hit_filter(project.suggest(text, backend_params))
    for hit in hits.as_list(project.subjects):
        # notation may be missing, so drop empty fields before joining
        labels = '\t'.join(filter(None, (hit.label, hit.notation)))
        click.echo("<{}>\t{}\t{}".format(hit.uri, labels, hit.score))
245
246
247
@cli.command('index')
@click.argument('project_id')
@click.argument('directory', type=click.Path(exists=True, file_okay=False))
@click.option(
    '--suffix',
    default='.annif',
    help='File name suffix for result files')
@click.option('--force/--no-force', default=False,
              help='Force overwriting of existing result files')
@click.option('--limit', default=10, help='Maximum number of subjects')
@click.option('--threshold', default=0.0, help='Minimum score threshold')
@backend_param_option
@common_options
def run_index(project_id, directory, suffix, force,
              limit, threshold, backend_param):
    """
    Index a directory with documents, suggesting subjects for each document.
    Write the results in TSV files with the given suffix.
    """
    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)

    for docfilename, dummy_subjectfn in annif.corpus.DocumentDirectory(
            directory, require_subjects=False):
        subjectfilename = re.sub(r'\.txt$', suffix, docfilename)
        # check for an existing result file *before* reading the document,
        # so skipped documents cost no extra I/O
        if os.path.exists(subjectfilename) and not force:
            click.echo(
                "Not overwriting {} (use --force to override)".format(
                    subjectfilename))
            continue
        # utf-8-sig transparently strips a BOM if present
        with open(docfilename, encoding='utf-8-sig') as docfile:
            text = docfile.read()
        with open(subjectfilename, 'w', encoding='utf-8') as subjfile:
            results = project.suggest(text, backend_params)
            for hit in hit_filter(results).as_list(project.subjects):
                # notation may be missing; drop empty fields before joining
                line = "<{}>\t{}\t{}".format(
                    hit.uri,
                    '\t'.join(filter(None, (hit.label, hit.notation))),
                    hit.score)
                click.echo(line, file=subjfile)
288
289
290
@cli.command('eval')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--limit', default=10, help='Maximum number of subjects')
@click.option('--threshold', default=0.0, help='Minimum score threshold')
@click.option(
    '--results-file',
    type=click.File(
        'w',
        encoding='utf-8',
        errors='ignore',
        lazy=True),
    help="""Specify file in order to write non-aggregated results per subject.
    File directory must exist, existing file will be overwritten.""")
@click.option('--jobs',
              default=1,
              help='Number of parallel jobs (0 means all CPUs)')
@backend_param_option
@common_options
def run_eval(
        project_id,
        paths,
        limit,
        threshold,
        results_file,
        jobs,
        backend_param):
    """
    Analyze documents and evaluate the result.

    Compare the results of automated indexing against a gold standard. The
    path may be either a TSV file with short documents or a directory with
    documents in separate files.
    """

    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)

    eval_batch = annif.eval.EvaluationBatch(project.subjects)

    if results_file:
        # the file is lazy; force it open now so a bad path fails fast
        # instead of after the whole evaluation run
        try:
            print('', end='', file=results_file)
            click.echo('Writing per subject evaluation results to {!s}'.format(
                results_file.name))
        except Exception as e:
            # chain the original error so the root cause stays visible
            raise NotSupportedException(
                "cannot open results-file for writing: " + str(e)) from e
    docs = open_documents(paths)

    if jobs < 1:
        # Pool(None) uses all available CPUs
        jobs = None
        pool_class = multiprocessing.Pool
    elif jobs == 1:
        # use the dummy wrapper around threading to avoid subprocess overhead
        pool_class = multiprocessing.dummy.Pool
    else:
        pool_class = multiprocessing.Pool

    project.initialize()
    psmap = annif.project.ProjectSuggestMap(
        project, backend_params, limit, threshold)

    with pool_class(jobs) as pool:
        # result order does not affect the aggregated scores, so
        # imap_unordered is safe here
        for hits, uris, labels in pool.imap_unordered(
                psmap.suggest, docs.documents):
            eval_batch.evaluate(hits,
                                annif.corpus.SubjectSet((uris, labels)))

    template = "{0:<30}\t{1}"
    for metric, score in eval_batch.results(results_file=results_file).items():
        click.echo(template.format(metric + ":", score))
362
363
364
@cli.command('optimize')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@backend_param_option
@common_options
def run_optimize(project_id, paths, backend_param):
    """
    Analyze documents, testing multiple limits and thresholds.

    Evaluate the analysis results for a directory with documents against a
    gold standard given in subject files. Test different limit/threshold
    values and report the precision, recall and F-measure of each combination
    of settings.
    """
    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)

    filter_batches = generate_filter_batches(project.subjects)

    ndocs = 0
    docs = open_documents(paths)
    for doc in docs.documents:
        # suggest once per document, then feed the same hits through every
        # (limit, threshold) filter combination
        hits = project.suggest(doc.text, backend_params)
        gold_subjects = annif.corpus.SubjectSet((doc.uris, doc.labels))
        for hit_filter, batch in filter_batches.values():
            batch.evaluate(hit_filter(hits), gold_subjects)
        ndocs += 1

    click.echo("\t".join(('Limit', 'Thresh.', 'Prec.', 'Rec.', 'F1')))

    best_scores = collections.defaultdict(float)
    best_params = {}

    # loop-invariant: hoisted out of the loop below (it was previously
    # rebuilt on every iteration and relied on leaking out of the loop)
    metrics = ['Precision (doc avg)',
               'Recall (doc avg)',
               'F1 score (doc avg)',
               'NDCG@5',
               'NDCG@10']

    template = "{:d}\t{:.02f}\t{:.04f}\t{:.04f}\t{:.04f}"
    # Store the batches in a list that gets consumed along the way
    # This way GC will have a chance to reclaim the memory
    filter_batches = list(filter_batches.items())
    while filter_batches:
        params, filter_batch = filter_batches.pop(0)
        results = filter_batch[1].results(metrics=metrics)
        for metric, score in results.items():
            if score >= best_scores[metric]:
                best_scores[metric] = score
                best_params[metric] = params
        click.echo(
            template.format(
                params[0],
                params[1],
                results['Precision (doc avg)'],
                results['Recall (doc avg)'],
                results['F1 score (doc avg)']))

    click.echo()
    template2 = "Best {:>19}: {:.04f}\tLimit: {:d}\tThreshold: {:.02f}"
    for metric in metrics:
        click.echo(
            template2.format(
                metric,
                best_scores[metric],
                best_params[metric][0],
                best_params[metric][1]))
    click.echo("Documents evaluated:\t{}".format(ndocs))
431
432
433
@cli.command('hyperopt')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--trials', default=10, help='Number of trials')
@click.option('--jobs',
              default=1,
              help='Number of parallel runs (-1 means all CPUs)')
@click.option('--metric', default='NDCG', help='Metric to optimize')
@common_options
def run_hyperopt(project_id, paths, trials, jobs, metric):
    """
    Optimize the hyperparameters of a project using a validation corpus.
    """
    project = get_project(project_id)
    corpus = open_documents(paths)
    click.echo(f"Looking for optimal hyperparameters using {trials} trials")
    recommendation = project.hyperopt(corpus, trials, jobs, metric)
    click.echo(f"Got best {metric} score {recommendation.score:.4f} with:")
    # print the recommended configuration lines between "---" markers
    click.echo("---")
    for config_line in recommendation.lines:
        click.echo(config_line)
    click.echo("---")
455
456
457
# Allow running this module directly (`python cli.py`).
if __name__ == '__main__':
    cli()
459