Total Complexity | 64 |
Total Lines | 656 |
Duplicated Lines | 70.73 % |
Changes | 0 |
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems, and corresponding solutions are:
Complex modules like annif.cli often do a lot of different things. To break such a module down, we need to identify a cohesive component within that module. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | """Definitions for command-line (Click) commands for invoking Annif |
||
2 | operations and printing the results to console.""" |
||
3 | |||
4 | |||
5 | import collections |
||
6 | import os.path |
||
7 | import re |
||
8 | import sys |
||
9 | import json |
||
10 | import click |
||
11 | import click_log |
||
12 | from flask import current_app |
||
13 | from flask.cli import FlaskGroup, ScriptInfo |
||
14 | import annif |
||
15 | import annif.corpus |
||
16 | import annif.parallel |
||
17 | import annif.project |
||
18 | import annif.registry |
||
19 | from annif.project import Access |
||
20 | from annif.suggestion import SuggestionFilter, ListSuggestionResult |
||
21 | from annif.exception import ConfigurationException, NotSupportedException |
||
22 | from annif.exception import NotInitializedException |
||
23 | from annif.util import metric_code |
||
24 | |||
# Module-wide logger shared with the rest of Annif; click_log wires it to
# the CLI verbosity options added by common_options below.
logger = annif.logger
click_log.basic_config(logger)

# CLI entry point: a FlaskGroup so every command runs inside the Flask app
# context. The version option is attached separately (add_version_option is
# disabled) so the message format can be customized.
cli = FlaskGroup(create_app=annif.create_app, add_version_option=False)
cli = click.version_option(message='%(version)s')(cli)
30 | |||
31 | |||
def get_project(project_id):
    """Look up a project by its identifier and return it.

    Prints an error and terminates the process with exit status 1 when no
    project with the given id exists (registry raises ValueError)."""
    try:
        project = annif.registry.get_project(
            project_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No projects found with id '{project_id}'.", err=True)
        sys.exit(1)
    return project
||
43 | |||
44 | |||
def get_vocab(vocab_id):
    """Look up a vocabulary by its identifier and return it.

    Prints an error and terminates the process with exit status 1 when no
    vocabulary with the given id exists (registry raises ValueError)."""
    try:
        vocab = annif.registry.get_vocab(vocab_id, min_access=Access.private)
    except ValueError:
        click.echo(f"No vocabularies found with the id '{vocab_id}'.",
                   err=True)
        sys.exit(1)
    return vocab
||
57 | |||
58 | |||
def open_documents(paths, subject_index, vocab_lang, docs_limit):
    """Open a document corpus from a list of pathnames, each of which is
    either a TSV file or a directory of TXT files. For directories with
    subjects in TSV files, the given vocabulary language will be used to
    convert subject labels into URIs. The corpus is returned as a
    DocumentCorpus, wrapped in a LimitingDocumentCorpus when docs_limit
    is given."""

    def _as_corpus(path):
        """Turn a single pathname into a DocumentCorpus."""
        if os.path.isdir(path):
            # a directory of TXT documents with per-document subject files
            return annif.corpus.DocumentDirectory(
                path, subject_index, vocab_lang, require_subjects=True)
        # otherwise assume a (possibly gzipped) TSV document file
        return annif.corpus.DocumentFile(path, subject_index)

    if not paths:
        # no input given: read from the null device so callers still get
        # a valid (empty) corpus object
        logger.warning('Reading empty file')
        corpus = _as_corpus(os.path.devnull)
    elif len(paths) == 1:
        corpus = _as_corpus(paths[0])
    else:
        corpus = annif.corpus.CombinedCorpus(
            [_as_corpus(path) for path in paths])
    if docs_limit is not None:
        corpus = annif.corpus.LimitingDocumentCorpus(corpus, docs_limit)
    return corpus
||
85 | |||
86 | |||
def parse_backend_params(backend_param, project):
    """Parse the values of repeated --backend-param options (each of the
    form ``<backend>.<key>=<value>``) into a nested dict keyed first by
    backend name and then by parameter name."""
    parsed = collections.defaultdict(dict)
    for beparam in backend_param:
        backend, keyval = beparam.split('.', 1)
        key, value = keyval.split('=', 1)
        # bail out early if the option targets a different backend than
        # the one configured for this project
        validate_backend_params(backend, beparam, project)
        parsed[backend][key] = value
    return parsed
||
97 | |||
98 | |||
def validate_backend_params(backend, beparam, project):
    """Raise ConfigurationException if the backend named in a
    --backend-param option does not match the project's configured
    backend; return None otherwise."""
    project_backend = project.config['backend']
    if backend == project_backend:
        return
    raise ConfigurationException(
        'The backend {} in CLI option "-b {}" not matching the project'
        ' backend {}.'.format(backend, beparam, project_backend))
||
105 | |||
106 | |||
# Highest suggestion-count limit tried by the optimize command; each limit
# from 1 to this value is combined with every threshold step.
BATCH_MAX_LIMIT = 15
||
108 | |||
109 | |||
def generate_filter_batches(subjects):
    """Build an ordered mapping from every (limit, threshold) parameter
    combination tried by the optimize command to a fresh
    (SuggestionFilter, EvaluationBatch) pair."""
    # imported lazily because annif.eval pulls in heavy dependencies
    import annif.eval
    batches = collections.OrderedDict()
    thresholds = [step * 0.05 for step in range(20)]
    for limit in range(1, BATCH_MAX_LIMIT + 1):
        for threshold in thresholds:
            batches[(limit, threshold)] = (
                SuggestionFilter(subjects, limit, threshold),
                annif.eval.EvaluationBatch(subjects))
    return batches
||
119 | |||
120 | |||
def set_project_config_file_path(ctx, param, value):
    """Click callback: override the default projects-config path (or the
    one given via the environment) with the value of the CLI option."""
    # load the Flask app so its config object exists before we mutate it
    app = ctx.ensure_object(ScriptInfo).load_app()
    with app.app_context():
        if value:
            current_app.config['PROJECTS_CONFIG_PATH'] = value
||
126 | |||
127 | |||
def common_options(f):
    """Decorator that attaches the options shared by all CLI commands:
    the projects-config path option and the log verbosity option."""
    projects_option = click.option(
        '-p', '--projects',
        help='Set path to project configuration file or directory',
        type=click.Path(dir_okay=True, exists=True),
        callback=set_project_config_file_path,
        expose_value=False,
        # eager so the config path is applied before other options load
        is_eager=True)
    verbosity_option = click_log.simple_verbosity_option(logger)
    return verbosity_option(projects_option(f))
||
137 | |||
138 | |||
def backend_param_option(f):
    """Decorator that attaches the repeatable --backend-param option used
    to override backend configuration values from the command line."""
    decorator = click.option(
        '--backend-param', '-b', multiple=True,
        help='Override backend parameter of the config file. '
             'Syntax: `-b <backend>.<parameter>=<value>`.')
    return decorator(f)
||
145 | |||
146 | |||
@cli.command('list-projects')
@common_options
@click_log.simple_verbosity_option(logger, default='ERROR')
def run_list_projects():
    """
    List available projects.
    \f
    Show a list of currently defined projects. Projects are defined in a
    configuration file, normally called ``projects.cfg``. See `Project
    configuration
    <https://github.com/NatLibFi/Annif/wiki/Project-configuration>`_
    for details.
    """

    # fixed-width columns so the output lines up as a table
    row_format = "{0: <25}{1: <45}{2: <10}{3: <7}"
    header = row_format.format(
        "Project ID", "Project Name", "Language", "Trained")
    click.echo(header)
    click.echo("-" * len(header))
    projects = annif.registry.get_projects(min_access=Access.private)
    for project in projects.values():
        click.echo(row_format.format(
            project.project_id, project.name, project.language,
            str(project.is_trained)))
||
170 | |||
171 | |||
@cli.command('show-project')
@click.argument('project_id')
@common_options
def run_show_project(project_id):
    """
    Show information about a project.
    """

    project = get_project(project_id)
    # one "Label: value" line per field, in a fixed order
    fields = (
        ('Project ID', project.project_id),
        ('Project Name', project.name),
        ('Language', project.language),
        ('Vocabulary', project.vocab.vocab_id),
        ('Vocab language', project.vocab_lang),
        ('Access', project.access.name),
        ('Trained', project.is_trained),
        ('Modification time', project.modification_time),
    )
    for label, value in fields:
        click.echo(f'{label}: {value}')
||
189 | |||
190 | |||
@cli.command('clear')
@click.argument('project_id')
@common_options
def run_clear_project(project_id):
    """
    Initialize the project to its original, untrained state.
    """
    # removing the model data is all that "clearing" a project means here
    get_project(project_id).remove_model_data()
||
200 | |||
201 | |||
@cli.command('list-vocabs')
@common_options
@click_log.simple_verbosity_option(logger, default='ERROR')
def run_list_vocabs():
    """
    List available vocabularies.
    """

    # fixed-width columns so the output lines up as a table
    row_format = "{0: <20}{1: <20}{2: >10} {3: <6}"
    header = row_format.format(
        "Vocabulary ID", "Languages", "Size", "Loaded")
    click.echo(header)
    click.echo("-" * len(header))
    vocabs = annif.registry.get_vocabs(min_access=Access.private)
    for vocab in vocabs.values():
        try:
            languages = ','.join(sorted(vocab.languages))
            size = len(vocab)
            loaded = True
        except NotInitializedException:
            # a vocabulary that has not been loaded yet has no languages
            # or size to report
            languages, size, loaded = '-', '-', False
        click.echo(row_format.format(
            vocab.vocab_id, languages, size, str(loaded)))
||
227 | |||
228 | |||
@cli.command('loadvoc', deprecated=True)
@click.argument('project_id')
@click.argument('subjectfile', type=click.Path(exists=True, dir_okay=False))
@click.option('--force', '-f', default=False, is_flag=True,
              help='Replace existing vocabulary completely ' +
                   'instead of updating it')
@common_options
def run_loadvoc(project_id, force, subjectfile):
    """
    Load a vocabulary for a project.
    \f
    This will load the vocabulary to be used in subject indexing. Note that
    although ``PROJECT_ID`` is a parameter of the command, the vocabulary is
    shared by all the projects with the same vocab identifier in the project
    configuration, and the vocabulary only needs to be loaded for one of those
    projects.

    If a vocabulary has already been loaded, reinvoking loadvoc with a new
    subject file will update the Annif’s internal vocabulary: label names are
    updated and any subject not appearing in the new subject file is removed.
    Note that new subjects will not be suggested before the project is
    retrained with the updated vocabulary. The update behavior can be
    overridden with the ``--force`` option.
    """
    project = get_project(project_id)
    # detect the subject file format: SKOS/RDF, then CSV, else TSV
    if annif.corpus.SubjectFileSKOS.is_rdf_file(subjectfile):
        # SKOS/RDF file supported by rdflib
        subjects = annif.corpus.SubjectFileSKOS(subjectfile)
    elif annif.corpus.SubjectFileCSV.is_csv_file(subjectfile):
        # CSV file
        subjects = annif.corpus.SubjectFileCSV(subjectfile)
    else:
        # probably a TSV file; labels are in the project's vocab language
        subjects = annif.corpus.SubjectFileTSV(subjectfile,
                                               project.vocab_lang)
    project.vocab.load_vocabulary(subjects, force=force)
||
264 | |||
265 | |||
@cli.command('load-vocab')
@click.argument('vocab_id')
@click.argument('subjectfile', type=click.Path(exists=True, dir_okay=False))
@click.option('--language', '-L', help='Language of subject file')
@click.option('--force', '-f', default=False, is_flag=True,
              help='Replace existing vocabulary completely ' +
                   'instead of updating it')
@common_options
def run_load_vocab(vocab_id, language, force, subjectfile):
    """
    Load a vocabulary from a subject file.
    """
    vocab = get_vocab(vocab_id)
    # detect the subject file format: SKOS/RDF, then CSV, else TSV
    if annif.corpus.SubjectFileSKOS.is_rdf_file(subjectfile):
        # SKOS/RDF file supported by rdflib
        subjects = annif.corpus.SubjectFileSKOS(subjectfile)
    elif annif.corpus.SubjectFileCSV.is_csv_file(subjectfile):
        # CSV file
        subjects = annif.corpus.SubjectFileCSV(subjectfile)
    else:
        # probably a TSV file - the language cannot be inferred from the
        # file itself, so it must be given explicitly
        if not language:
            click.echo("Please use --language option to set the language "
                       "of a TSV vocabulary.", err=True)
            sys.exit(1)
        subjects = annif.corpus.SubjectFileTSV(subjectfile, language)
    vocab.load_vocabulary(subjects, force=force)
||
293 | |||
294 | |||
@cli.command('train')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--cached/--no-cached', '-c/-C', default=False,
              help='Reuse preprocessed training data from previous run')
@click.option('--docs-limit', '-d', default=None,
              type=click.IntRange(0, None),
              help='Maximum number of documents to use')
@click.option('--jobs',
              '-j',
              default=0,
              help='Number of parallel jobs (0 means choose automatically)')
@backend_param_option
@common_options
def run_train(project_id, paths, cached, docs_limit, jobs, backend_param):
    """
    Train a project on a collection of documents.
    \f
    This will train the project using the documents from ``PATHS`` (directories
    or possibly gzipped TSV files) in a single batch operation. If ``--cached``
    is set, preprocessed training data from the previous run is reused instead
    of documents input; see `Reusing preprocessed training data
    <https://github.com/NatLibFi/Annif/wiki/
    Reusing-preprocessed-training-data>`_.
    """
    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)
    if cached:
        # --cached and explicit corpus paths are mutually exclusive
        if paths:
            raise click.UsageError(
                "Corpus paths cannot be given when using --cached option.")
        documents = 'cached'
    else:
        documents = open_documents(paths, project.subjects,
                                   project.vocab_lang, docs_limit)
    project.train(documents, backend_params, jobs)
||
331 | |||
332 | |||
@cli.command('learn')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--docs-limit', '-d', default=None,
              type=click.IntRange(0, None),
              help='Maximum number of documents to use')
@backend_param_option
@common_options
def run_learn(project_id, paths, docs_limit, backend_param):
    """
    Further train an existing project on a collection of documents.
    \f
    Similar to the ``train`` command. This will continue training an already
    trained project using the documents given by ``PATHS`` in a single batch
    operation. Not supported by all backends.
    """
    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)
    corpus = open_documents(paths, project.subjects,
                            project.vocab_lang, docs_limit)
    project.learn(corpus, backend_params)
||
354 | |||
355 | |||
@cli.command('suggest')
@click.argument('project_id')
@click.option('--limit', '-l', default=10, help='Maximum number of subjects')
@click.option('--threshold', '-t', default=0.0, help='Minimum score threshold')
@backend_param_option
@common_options
def run_suggest(project_id, limit, threshold, backend_param):
    """
    Suggest subjects for a single document from standard input.
    \f
    This will read a text document from standard input and suggest subjects for
    it.
    """
    project = get_project(project_id)
    text = sys.stdin.read()
    backend_params = parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)
    suggestions = hit_filter(project.suggest(text, backend_params))
    for hit in suggestions.as_list():
        subject = project.subjects[hit.subject_id]
        # label and notation joined by tabs, skipping empty values
        labels = '\t'.join(filter(None, (subject.labels[project.vocab_lang],
                                         subject.notation)))
        click.echo(f"<{subject.uri}>\t{labels}\t{hit.score}")
||
383 | |||
384 | |||
@cli.command('index')
@click.argument('project_id')
@click.argument('directory', type=click.Path(exists=True, file_okay=False))
@click.option(
    '--suffix',
    '-s',
    default='.annif',
    help='File name suffix for result files')
@click.option('--force/--no-force', '-f/-F', default=False,
              help='Force overwriting of existing result files')
@click.option('--limit', '-l', default=10, help='Maximum number of subjects')
@click.option('--threshold', '-t', default=0.0, help='Minimum score threshold')
@backend_param_option
@common_options
def run_index(project_id, directory, suffix, force,
              limit, threshold, backend_param):
    """
    Index a directory with documents, suggesting subjects for each document.
    Write the results in TSV files with the given suffix (``.annif`` by
    default).
    """
    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)
    hit_filter = SuggestionFilter(project.subjects, limit, threshold)

    corpus = annif.corpus.DocumentDirectory(
        directory, project.subjects, project.vocab_lang,
        require_subjects=False)
    for docfilename, dummy_subjectfn in corpus:
        # utf-8-sig strips a possible BOM from the document text
        with open(docfilename, encoding='utf-8-sig') as docfile:
            text = docfile.read()
        subjectfilename = re.sub(r'\.txt$', suffix, docfilename)
        if os.path.exists(subjectfilename) and not force:
            click.echo(
                "Not overwriting {} (use --force to override)".format(
                    subjectfilename))
            continue
        with open(subjectfilename, 'w', encoding='utf-8') as subjfile:
            results = project.suggest(text, backend_params)
            for hit in hit_filter(results).as_list():
                subject = project.subjects[hit.subject_id]
                # label and notation joined by tabs, skipping empty values
                labels = '\t'.join(
                    filter(None, (subject.labels[project.vocab_lang],
                                  subject.notation)))
                click.echo(f"<{subject.uri}>\t{labels}\t{hit.score}",
                           file=subjfile)
||
431 | |||
432 | |||
@cli.command('eval')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--limit', '-l', default=10, help='Maximum number of subjects')
@click.option('--threshold', '-t', default=0.0, help='Minimum score threshold')
@click.option('--docs-limit', '-d', default=None,
              type=click.IntRange(0, None),
              help='Maximum number of documents to use')
@click.option('--metric', '-m', default=[], multiple=True,
              help='Metric to calculate (default: all)')
@click.option(
    '--metrics-file',
    '-M',
    type=click.File(
        'w',
        encoding='utf-8',
        errors='ignore',
        lazy=True),
    help="""Specify file in order to write evaluation metrics in JSON format.
    File directory must exist, existing file will be overwritten.""")
@click.option(
    '--results-file',
    '-r',
    type=click.File(
        'w',
        encoding='utf-8',
        errors='ignore',
        lazy=True),
    help="""Specify file in order to write non-aggregated results per subject.
    File directory must exist, existing file will be overwritten.""")
@click.option('--jobs',
              '-j',
              default=1,
              help='Number of parallel jobs (0 means all CPUs)')
@backend_param_option
@common_options
def run_eval(
        project_id,
        paths,
        limit,
        threshold,
        docs_limit,
        metric,
        metrics_file,
        results_file,
        jobs,
        backend_param):
    """
    Suggest subjects for documents and evaluate the results by comparing
    against a gold standard.
    \f
    With this command the documents from ``PATHS`` (directories or possibly
    gzipped TSV files) will be assigned subject suggestions and then
    statistical measures are calculated that quantify how well the suggested
    subjects match the gold-standard subjects in the documents.

    Normally the output is the list of the metrics calculated across documents.
    If ``--results-file <FILENAME>`` option is given, the metrics are
    calculated separately for each subject, and written to the given file.
    """

    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)

    # imported lazily because annif.eval pulls in heavy dependencies
    import annif.eval
    eval_batch = annif.eval.EvaluationBatch(project.subjects)

    if results_file:
        # the file is opened lazily by click; write an empty string now to
        # force it open and fail fast if the path is not writable
        try:
            print('', end='', file=results_file)
            click.echo('Writing per subject evaluation results to {!s}'.format(
                results_file.name))
        except Exception as e:
            # deliberately broad: any failure to open/write the file is
            # reported as a NotSupportedException to the caller
            raise NotSupportedException(
                "cannot open results-file for writing: " + str(e))
    docs = open_documents(paths, project.subjects,
                          project.vocab_lang, docs_limit)

    # resolve the jobs count and matching pool implementation
    jobs, pool_class = annif.parallel.get_pool(jobs)

    # initialize in parallel mode before forking worker processes
    project.initialize(parallel=True)
    psmap = annif.parallel.ProjectSuggestMap(
        project.registry, [project_id], backend_params, limit, threshold)

    # suggestions can be evaluated in any order, so imap_unordered is fine
    with pool_class(jobs) as pool:
        for hits, subject_set in pool.imap_unordered(
                psmap.suggest, docs.documents):
            eval_batch.evaluate(hits[project_id],
                                subject_set)

    template = "{0:<30}\t{1}"
    metrics = eval_batch.results(metrics=metric,
                                 results_file=results_file,
                                 language=project.vocab_lang)
    for metric, score in metrics.items():
        click.echo(template.format(metric + ":", score))
    if metrics_file:
        # metric names are converted to code-friendly identifiers for JSON
        json.dump(
            {metric_code(mname): val for mname, val in metrics.items()},
            metrics_file, indent=2)
||
533 | |||
534 | |||
@cli.command('optimize')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--docs-limit', '-d', default=None,
              type=click.IntRange(0, None),
              help='Maximum number of documents to use')
@backend_param_option
@common_options
def run_optimize(project_id, paths, docs_limit, backend_param):
    """
    Suggest subjects for documents, testing multiple limits and thresholds.
    \f
    This command will use different limit (maximum number of subjects) and
    score threshold values when assigning subjects to each document given by
    ``PATHS`` and compare the results against the gold standard subjects in the
    documents. The output is a list of parameter combinations and their scores.
    From the output, you can determine the optimum limit and threshold
    parameters depending on which measure you want to target.
    """
    project = get_project(project_id)
    backend_params = parse_backend_params(backend_param, project)

    # one (SuggestionFilter, EvaluationBatch) pair per parameter combination
    filter_batches = generate_filter_batches(project.subjects)

    ndocs = 0
    docs = open_documents(paths, project.subjects,
                          project.vocab_lang, docs_limit)
    for doc in docs.documents:
        # suggest once per document, then evaluate the same result under
        # every (limit, threshold) combination
        raw_hits = project.suggest(doc.text, backend_params)
        hits = raw_hits.filter(project.subjects, limit=BATCH_MAX_LIMIT)
        assert isinstance(hits, ListSuggestionResult), \
            "Optimize should only be done with ListSuggestionResult " + \
            "as it would be very slow with VectorSuggestionResult."
        for hit_filter, batch in filter_batches.values():
            batch.evaluate(hit_filter(hits), doc.subject_set)
        ndocs += 1

    click.echo("\t".join(('Limit', 'Thresh.', 'Prec.', 'Rec.', 'F1')))

    best_scores = collections.defaultdict(float)
    best_params = {}

    template = "{:d}\t{:.02f}\t{:.04f}\t{:.04f}\t{:.04f}"
    # Store the batches in a list that gets consumed along the way
    # This way GC will have a chance to reclaim the memory
    filter_batches = list(filter_batches.items())
    while filter_batches:
        params, filter_batch = filter_batches.pop(0)
        metrics = ['Precision (doc avg)',
                   'Recall (doc avg)',
                   'F1 score (doc avg)']
        results = filter_batch[1].results(metrics=metrics)
        # track the best score (and its parameters) seen for each metric
        for metric, score in results.items():
            if score >= best_scores[metric]:
                best_scores[metric] = score
                best_params[metric] = params
        click.echo(
            template.format(
                params[0],
                params[1],
                results['Precision (doc avg)'],
                results['Recall (doc avg)'],
                results['F1 score (doc avg)']))

    click.echo()
    template2 = "Best {:>19}: {:.04f}\tLimit: {:d}\tThreshold: {:.02f}"
    for metric in metrics:
        click.echo(
            template2.format(
                metric,
                best_scores[metric],
                best_params[metric][0],
                best_params[metric][1]))
    click.echo("Documents evaluated:\t{}".format(ndocs))
||
609 | |||
610 | |||
@cli.command('hyperopt')
@click.argument('project_id')
@click.argument('paths', type=click.Path(exists=True), nargs=-1)
@click.option('--docs-limit', '-d', default=None,
              type=click.IntRange(0, None),
              help='Maximum number of documents to use')
@click.option('--trials', '-T', default=10, help='Number of trials')
@click.option('--jobs',
              '-j',
              default=1,
              help='Number of parallel runs (0 means all CPUs)')
@click.option('--metric', '-m', default='NDCG',
              help='Metric to optimize (default: NDCG)')
@click.option(
    '--results-file',
    '-r',
    type=click.File(
        'w',
        encoding='utf-8',
        errors='ignore',
        lazy=True),
    help="""Specify file path to write trial results as CSV.
    File directory must exist, existing file will be overwritten.""")
@common_options
def run_hyperopt(project_id, paths, docs_limit, trials, jobs, metric,
                 results_file):
    """
    Optimize the hyperparameters of a project using validation documents from
    ``PATHS``. Not supported by all backends. Output is a list of trial results
    and a report of the best performing parameters.
    """
    project = get_project(project_id)
    documents = open_documents(paths, project.subjects,
                               project.vocab_lang, docs_limit)
    click.echo(f"Looking for optimal hyperparameters using {trials} trials")
    recommendation = project.hyperopt(documents, trials, jobs, metric,
                                      results_file)
    click.echo(f"Got best {metric} score {recommendation.score:.4f} with:")
    click.echo("---")
    for line in recommendation.lines:
        click.echo(line)
    click.echo("---")
||
652 | |||
653 | |||
# Allow invoking the CLI by executing this module directly.
if __name__ == '__main__':
    cli()
||
656 |