| Total Complexity | 5 |
| Total Lines | 54 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | #!/usr/bin/env python |
||
| 2 | |||
| 3 | import argparse |
||
| 4 | from topic_modeling_toolkit.reporting import ModelReporter |
||
| 5 | from topic_modeling_toolkit.reporting.reporter import InvalidMetricException |
||
| 6 | |||
| 7 | |||
def get_cli_arguments():
    """Build the CLI parser and return the parsed command-line arguments.

    Expects one positional argument (the collection/dataset name) and an
    optional ``--sort``/``-s`` metric defaulting to ``'perplexity'``.
    """
    arg_parser = argparse.ArgumentParser(
        prog='report_models.py',
        description='Reports on the trained models for the specified collection (dataset)',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument(
        'dataset',
        metavar='collection_name',
        help='the collection to report models trained on',
    )
    # parser.add_argument('--details', '-d', default=False, action='store_true', help='Switch to show details about the models')
    arg_parser.add_argument(
        '--sort', '-s',
        default='perplexity',
        help='Whether to sort the found experiments by checking the desired metric against the corresponding models',
    )
    return arg_parser.parse_args()
||
| 14 | |||
| 15 | |||
def main():
    """Report on the models trained on the requested collection.

    Reads the collection name and sort metric from the command line, locates
    the collections root via the COLLECTIONS_DIR environment variable, and
    prints a formatted report. If the requested sort metric is invalid, keeps
    prompting the user for another metric until one is accepted.

    Raises:
        RuntimeError: if the COLLECTIONS_DIR environment variable is unset/empty.
    """
    # Columns rendered for every discovered model/experiment.
    COLUMNS = [
        'nb-topics',
        'collection-passes',
        'document-passes',
        # 'total-phi-updates',
        'perplexity',
        'kernel-size',
        'kernel-coherence',
        'kernel-contrast',
        'kernel-purity',
        'top-tokens-coherence',
        'sparsity-phi',
        'sparsity-theta',
        'background-tokens-ratio',
        'regularizers'
    ]

    cli_args = get_cli_arguments()
    sort_metric = cli_args.sort

    collections_dir = os.getenv('COLLECTIONS_DIR')
    if not collections_dir:
        raise RuntimeError(
            "Please set the COLLECTIONS_DIR environment variable with the path to a directory containing collections/datasets")
    model_reporter = ModelReporter(collections_dir)
    while True:
        try:
            s = model_reporter.get_formatted_string(cli_args.dataset, columns=COLUMNS, metric=sort_metric, verbose=True)
            print('\n{}'.format(s))
            break
        except InvalidMetricException as e:
            print(e)
            # Blank input falls back to 'perplexity', as the prompt promises.
            sort_metric = input("Please input another metric to sort (blank for 'perplexity'): ") or 'perplexity'
||
| 50 | |||
| 51 | |||
# Only run the report when executed as a script, not when imported.
if __name__ == '__main__':
    main()
||
| 54 |