Total Complexity | 5 |
Total Lines | 54 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | #!/usr/bin/env python |
||
2 | |||
import argparse
import os

from topic_modeling_toolkit.reporting import ModelReporter
from topic_modeling_toolkit.reporting.reporter import InvalidMetricException
||
6 | |||
7 | |||
def get_cli_arguments():
    """Build the CLI parser for the report script and return the parsed arguments.

    Returns:
        argparse.Namespace with `dataset` (positional collection name) and
        `sort` (metric used to order the reported models, default 'perplexity').
    """
    arg_parser = argparse.ArgumentParser(
        prog='report_models.py',
        description='Reports on the trained models for the specified collection (dataset)',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument(
        'dataset',
        metavar='collection_name',
        help='the collection to report models trained on',
    )
    # parser.add_argument('--details', '-d', default=False, action='store_true', help='Switch to show details about the models')
    arg_parser.add_argument(
        '--sort',
        '-s',
        default='perplexity',
        help='Whether to sort the found experiments by checking the desired metric against the corresponding models',
    )
    return arg_parser.parse_args()
||
14 | |||
15 | |||
def main():
    """Report on models trained on the CLI-specified collection.

    Reads the collections root from the COLLECTIONS_DIR environment variable,
    then prints a formatted report table; if the requested sort metric is
    rejected by the reporter, re-prompts the user until a valid one is given.

    Raises:
        RuntimeError: if the COLLECTIONS_DIR environment variable is not set.
    """
    # Metric/attribute columns to show for each trained model in the report.
    COLUMNS = [
        'nb-topics',
        'collection-passes',
        'document-passes',
        # 'total-phi-updates',
        'perplexity',
        'kernel-size',
        'kernel-coherence',
        'kernel-contrast',
        'kernel-purity',
        'top-tokens-coherence',
        'sparsity-phi',
        'sparsity-theta',
        'background-tokens-ratio',
        'regularizers'
    ]

    cli_args = get_cli_arguments()
    sort_metric = cli_args.sort

    collections_dir = os.getenv('COLLECTIONS_DIR')
    if not collections_dir:
        raise RuntimeError(
            "Please set the COLLECTIONS_DIR environment variable with the path to a directory containing collections/datasets")
    model_reporter = ModelReporter(collections_dir)
    # Keep asking for a sort metric until the reporter accepts it.
    while True:
        try:
            s = model_reporter.get_formatted_string(cli_args.dataset, columns=COLUMNS, metric=sort_metric, verbose=True)
            print('\n{}'.format(s))
            break
        except InvalidMetricException as e:
            print(e)
            # Fix: blank input now actually falls back to 'perplexity',
            # as the prompt text promises (previously '' was retried as-is).
            sort_metric = input("Please input another metric to sort (blank for 'perplexity'): ") or 'perplexity'
||
50 | |||
51 | |||
# Run the report CLI only when executed as a script (not on import).
if __name__ == '__main__':
    main()
||
54 |