Passed
Pull Request — master (#418)
by Osma
01:38
created

annif.cli.run_train()   A

Complexity

Conditions 3

Size

Total Lines 21
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 21
rs 9.55
c 0
b 0
f 0
cc 3
nop 4
1
"""Definitions for command-line (Click) commands for invoking Annif
2
operations and printing the results to console."""
3
4
5
import collections
6
import multiprocessing
7
import multiprocessing.dummy
8
import os.path
9
import re
10
import sys
11
import click
12
import click_log
13
from flask import current_app
14
from flask.cli import FlaskGroup, ScriptInfo
15
import annif
16
import annif.corpus
17
import annif.eval
18
import annif.project
19
from annif.project import Access
20
from annif.suggestion import SuggestionFilter
21
from annif.exception import ConfigurationException, NotSupportedException
22
23
logger = annif.logger
24
click_log.basic_config(logger)
25
26
cli = FlaskGroup(create_app=annif.create_app)
27
28
29
def get_project(project_id):
    """Return the project identified by project_id.

    Prints an error message to stderr and terminates the process with
    exit code 1 when no such project exists."""
    try:
        project = annif.project.get_project(
            project_id, min_access=Access.hidden)
    except ValueError:
        message = "No projects found with id \'{0}\'.".format(project_id)
        click.echo(message, err=True)
        sys.exit(1)
    return project
39
40
41
def open_documents(paths):
    """Open the given pathnames as a single DocumentCorpus instance.

    Each path is either a TSV file or a directory of TXT files. An empty
    list of paths yields an (empty) corpus read from os.devnull."""

    def open_doc_path(path):
        """open a single path and return it as a DocumentCorpus"""
        if os.path.isdir(path):
            return annif.corpus.DocumentDirectory(path, require_subjects=True)
        return annif.corpus.DocumentFile(path)

    if not paths:
        logger.warning('Reading empty file')
        return open_doc_path(os.path.devnull)
    if len(paths) == 1:
        return open_doc_path(paths[0])
    return annif.corpus.CombinedCorpus(
        [open_doc_path(path) for path in paths])
61
62
63
def parse_backend_params(backend_param, project):
    """Parse raw --backend-param values ("<backend>.<key>=<value>") into
    a nested dict mapping backend name -> {key: value}."""
    parsed = collections.defaultdict(dict)
    for raw in backend_param:
        backend, keyval = raw.split('.', 1)
        key, val = keyval.split('=', 1)
        validate_backend_params(backend, raw, project)
        parsed[backend][key] = val
    return parsed
73
74
75
def validate_backend_params(backend, beparam, project):
    """Check a single --backend-param value before it is applied.

    Raises NotSupportedException when the parameter attempts to override
    the algorithm, and ConfigurationException when the backend prefix
    does not match the backend configured for the project."""
    # NOTE(review): this is a substring test over the whole
    # "<backend>.<key>=<value>" string, so any key or value merely
    # containing "algorithm" is rejected too — confirm this is intended.
    if 'algorithm' in beparam:
        raise NotSupportedException('Algorithm overriding not supported.')
    if backend != project.config['backend']:
        raise ConfigurationException(
            'The backend {} in CLI option "-b {}" not matching the project'
            ' backend {}.'
            .format(backend, beparam, project.config['backend']))
83
84
85
def generate_filter_batches(subjects):
    """Build an ordered mapping from (limit, threshold) parameter pairs
    to (SuggestionFilter, EvaluationBatch) tuples, covering limits 1..15
    and thresholds 0.00..0.95 in steps of 0.05."""
    limits = range(1, 16)
    thresholds = [step * 0.05 for step in range(20)]
    return collections.OrderedDict(
        ((limit, threshold),
         (SuggestionFilter(subjects, limit, threshold),
          annif.eval.EvaluationBatch(subjects)))
        for limit in limits
        for threshold in thresholds)
93
94
95
def set_project_config_file_path(ctx, param, value):
    """Override the default path or the path given in env by CLI option"""
    app = ctx.ensure_object(ScriptInfo).load_app()
    with app.app_context():
        if value:
            current_app.config['PROJECTS_FILE'] = value
100
101
102
def common_options(f):
    """Decorator to add common options for all CLI commands"""
    projects_option = click.option(
        '-p', '--projects', help='Set path to projects.cfg',
        type=click.Path(dir_okay=False, exists=True),
        callback=set_project_config_file_path, expose_value=False,
        is_eager=True)
    verbosity_option = click_log.simple_verbosity_option(logger)
    # Apply the projects option first, then the verbosity option on top,
    # matching the original decoration order.
    return verbosity_option(projects_option(f))
110
111
112
def backend_param_option(f):
    """Decorator to add an option for CLI commands to override BE parameters"""
    option = click.option(
        '--backend-param', '-b', multiple=True,
        help='Override backend parameter of the config file. ' +
        'Syntax: "-b <backend>.<parameter>=<value>".')
    return option(f)
118
119
120
@cli.command('list-projects')
@common_options
@click_log.simple_verbosity_option(logger, default='ERROR')
def run_list_projects():
    """
    List available projects.
    """

    row_format = "{0: <25}{1: <45}{2: <10}{3: <7}"
    header = row_format.format(
        "Project ID", "Project Name", "Language", "Trained")
    click.echo(header)
    click.echo("-" * len(header))
    projects = annif.project.get_projects(min_access=Access.private)
    for project in projects.values():
        click.echo(row_format.format(
            project.project_id, project.name, project.language,
            str(project.is_trained)))
136
137
138
@cli.command('show-project')
@click.argument('project_id')
@common_options
def run_show_project(project_id):
    """
    Show information about a project.
    """

    proj = get_project(project_id)
    # Labels (with colon) are padded to 19 characters so every value
    # starts at the same column, reproducing the original layout.
    fields = (
        ('Project ID', proj.project_id),
        ('Project Name', proj.name),
        ('Language', proj.language),
        ('Access', proj.access.name),
        ('Trained', proj.is_trained),
        ('Modification time', proj.modification_time),
    )
    for label, value in fields:
        click.echo('{:<19}{}'.format(label + ':', value))
153
154
155
@cli.command('clear')
@click.argument('project_id')
@common_options
def run_clear_project(project_id):
    """
    Initialize the project to its original, untrained state.
    """
    # Removing the stored model data resets the project to untrained.
    get_project(project_id).remove_model_data()
164
165
166
@cli.command('loadvoc')
@click.argument('project_id')
@click.argument('subjectfile', type=click.Path(exists=True, dir_okay=False))
@common_options
def run_loadvoc(project_id, subjectfile):
    """
    Load a vocabulary for a project.
    """
    proj = get_project(project_id)
    # SKOS/RDF files supported by rdflib get the SKOS loader; anything
    # else is assumed to be a TSV subject file.
    if annif.corpus.SubjectFileSKOS.is_rdf_file(subjectfile):
        vocab_source = annif.corpus.SubjectFileSKOS(subjectfile,
                                                    proj.language)
    else:
        vocab_source = annif.corpus.SubjectFileTSV(subjectfile)
    proj.vocab.load_vocabulary(vocab_source, proj.language)
182
183
184
@cli.command('train')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--cached/--no-cached', default=False,
              help='Reuse preprocessed training data from previous run')
@backend_param_option
@common_options
def run_train(project_id, paths, cached, backend_param):
    """
    Train a project on a collection of documents.
    """
    proj = get_project(project_id)
    backend_params = parse_backend_params(backend_param, proj)
    if not cached:
        documents = open_documents(paths)
    elif paths:
        raise click.UsageError(
            "Corpus paths cannot be given when using --cached option.")
    else:
        # The marker string 'cached' tells the project to reuse the
        # preprocessed training data from a previous run.
        documents = 'cached'
    proj.train(documents, backend_params)
205
206
207
@cli.command('learn')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@backend_param_option
@common_options
def run_learn(project_id, paths, backend_param):
    """
    Further train an existing project on a collection of documents.
    """
    project = get_project(project_id)
    params = parse_backend_params(backend_param, project)
    project.learn(open_documents(paths), params)
220
221
222
@cli.command('suggest')
@click.argument('project_id')
@click.option('--limit', default=10, help='Maximum number of subjects')
@click.option('--threshold', default=0.0, help='Minimum score threshold')
@backend_param_option
@common_options
def run_suggest(project_id, limit, threshold, backend_param):
    """
    Suggest subjects for a single document from standard input.
    """
    project = get_project(project_id)
    text = sys.stdin.read()
    backend_params = parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)
    suggestions = hit_filter(project.suggest(text, backend_params))
    for hit in suggestions.as_list(project.subjects):
        # Label and notation are joined with a tab; empty values dropped.
        label_part = '\t'.join(filter(None, (hit.label, hit.notation)))
        click.echo("<{}>\t{}\t{}".format(hit.uri, label_part, hit.score))
243
244
245
@cli.command('index')
@click.argument('project_id')
@click.argument('directory', type=click.Path(exists=True, file_okay=False))
@click.option(
    '--suffix',
    default='.annif',
    help='File name suffix for result files')
@click.option('--force/--no-force', default=False,
              help='Force overwriting of existing result files')
@click.option('--limit', default=10, help='Maximum number of subjects')
@click.option('--threshold', default=0.0, help='Minimum score threshold')
@backend_param_option
@common_options
def run_index(project_id, directory, suffix, force,
              limit, threshold, backend_param):
    """
    Index a directory with documents, suggesting subjects for each document.
    Write the results in TSV files with the given suffix.
    """
    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)

    for docfilename, dummy_subjectfn in annif.corpus.DocumentDirectory(
            directory, require_subjects=False):
        # Result file name: replace the .txt extension with the suffix.
        subjectfilename = re.sub(r'\.txt$', suffix, docfilename)
        # Check for an existing result file *before* reading the document,
        # so skipped documents are not needlessly read from disk.
        if os.path.exists(subjectfilename) and not force:
            click.echo(
                "Not overwriting {} (use --force to override)".format(
                    subjectfilename))
            continue
        # utf-8-sig transparently strips a leading BOM if present.
        with open(docfilename, encoding='utf-8-sig') as docfile:
            text = docfile.read()
        with open(subjectfilename, 'w', encoding='utf-8') as subjfile:
            results = project.suggest(text, backend_params)
            for hit in hit_filter(results).as_list(project.subjects):
                line = "<{}>\t{}\t{}".format(
                    hit.uri,
                    '\t'.join(filter(None, (hit.label, hit.notation))),
                    hit.score)
                click.echo(line, file=subjfile)
286
287
288
@cli.command('eval')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--limit', default=10, help='Maximum number of subjects')
@click.option('--threshold', default=0.0, help='Minimum score threshold')
@click.option(
    '--results-file',
    type=click.File(
        'w',
        encoding='utf-8',
        errors='ignore',
        lazy=True),
    help="""Specify file in order to write non-aggregated results per subject.
    File directory must exist, existing file will be overwritten.""")
@click.option('--jobs',
              default=1,
              help='Number of parallel jobs (0 means all CPUs)')
@backend_param_option
@common_options
def run_eval(
        project_id,
        paths,
        limit,
        threshold,
        results_file,
        jobs,
        backend_param):
    """
    Analyze documents and evaluate the result.

    Compare the results of automated indexing against a gold standard. The
    path may be either a TSV file with short documents or a directory with
    documents in separate files.
    """

    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)

    eval_batch = annif.eval.EvaluationBatch(project.subjects)

    if results_file:
        # The file is opened lazily by Click; force a write now so a bad
        # path fails fast, before any evaluation work is done.
        try:
            print('', end='', file=results_file)
            click.echo('Writing per subject evaluation results to {!s}'.format(
                results_file.name))
        except Exception as e:
            raise NotSupportedException(
                "cannot open results-file for writing: " + str(e))
    docs = open_documents(paths)

    if jobs == 1:
        # use the dummy wrapper around threading to avoid subprocess overhead
        pool_class = multiprocessing.dummy.Pool
    else:
        if jobs < 1:
            # multiprocessing.Pool(None) uses all available CPUs
            jobs = None
        pool_class = multiprocessing.Pool

    project.initialize()
    # Renamed from "map" to avoid shadowing the builtin of the same name.
    suggest_map = annif.project.ProjectSuggestMap(
        project, backend_params, limit, threshold)

    with pool_class(jobs) as pool:
        # imap_unordered: evaluation order does not matter, so take
        # results as soon as any worker produces them.
        for hits, uris, labels in pool.imap_unordered(
                suggest_map.suggest, docs.documents):
            eval_batch.evaluate(hits,
                                annif.corpus.SubjectSet((uris, labels)))

    template = "{0:<30}\t{1}"
    for metric, score in eval_batch.results(results_file=results_file).items():
        click.echo(template.format(metric + ":", score))
360
361
362
@cli.command('optimize')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@backend_param_option
@common_options
def run_optimize(project_id, paths, backend_param):
    """
    Analyze documents, testing multiple limits and thresholds.

    Evaluate the analysis results for a directory with documents against a
    gold standard given in subject files. Test different limit/threshold
    values and report the precision, recall and F-measure of each combination
    of settings.
    """
    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)

    # One (SuggestionFilter, EvaluationBatch) pair per (limit, threshold)
    # combination; each document's suggestions are scored against all pairs.
    filter_batches = generate_filter_batches(project.subjects)

    ndocs = 0
    docs = open_documents(paths)
    for doc in docs.documents:
        hits = project.suggest(doc.text, backend_params)
        gold_subjects = annif.corpus.SubjectSet((doc.uris, doc.labels))
        for hit_filter, batch in filter_batches.values():
            batch.evaluate(hit_filter(hits), gold_subjects)
        ndocs += 1

    click.echo("\t".join(('Limit', 'Thresh.', 'Prec.', 'Rec.', 'F1')))

    # Track the best score seen so far per metric (defaults to 0.0) and
    # the (limit, threshold) parameters that achieved it.
    best_scores = collections.defaultdict(float)
    best_params = {}

    template = "{:d}\t{:.02f}\t{:.04f}\t{:.04f}\t{:.04f}"
    # Store the batches in a list that gets consumed along the way
    # This way GC will have a chance to reclaim the memory
    filter_batches = list(filter_batches.items())
    while filter_batches:
        params, filter_batch = filter_batches.pop(0)
        # filter_batch is a (SuggestionFilter, EvaluationBatch) pair;
        # only the batch is needed for the results.
        results = filter_batch[1].results(metrics='simple')
        for metric, score in results.items():
            # ">=" means later parameter combinations win ties.
            if score >= best_scores[metric]:
                best_scores[metric] = score
                best_params[metric] = params
        click.echo(
            template.format(
                params[0],
                params[1],
                results['Precision (doc avg)'],
                results['Recall (doc avg)'],
                results['F1 score (doc avg)']))

    click.echo()
    # Summary: best value and the parameters that produced it, per metric.
    template2 = "Best {:>19}: {:.04f}\tLimit: {:d}\tThreshold: {:.02f}"
    for metric in ('Precision (doc avg)',
                   'Recall (doc avg)',
                   'F1 score (doc avg)',
                   'NDCG@5',
                   'NDCG@10'):
        click.echo(
            template2.format(
                metric,
                best_scores[metric],
                best_params[metric][0],
                best_params[metric][1]))
    click.echo("Documents evaluated:\t{}".format(ndocs))
428
429
430
if __name__ == '__main__':
    # Allow running this module directly, in addition to the Flask
    # CLI entry point.
    cli()
432