1
|
|
|
#!/usr/bin/env python |
2
|
|
|
|
3
|
|
|
import os |
4
|
|
|
import re |
5
|
|
|
import sys |
6
|
|
|
import click |
7
|
|
|
from .reporting import GraphMaker |
8
|
|
|
# from .patm.definitions import COLLECTIONS_DIR_PATH |
9
|
|
|
|
10
|
|
|
# Metric names supported by the --sort option of the CLI below; this list is
# interpolated into that option's help text.
c = ['perplexity', 'kernel-size', 'kernel-coherence', 'kernel-contrast', 'kernel-purity', 'top-tokens-coherence', 'sparsity-phi',
     'sparsity-theta', 'background-tokens-ratio']
12
|
|
|
|
13
|
|
|
# class PythonLiteralOption(click.Option): |
14
|
|
|
# def type_cast_value(self, ctx, value): |
15
|
|
|
# try: |
16
|
|
|
# return ast.literal_eval(value) |
17
|
|
|
# except: |
18
|
|
|
# raise click.BadParameter(value) |
19
|
|
|
class ModelLabelsParser(click.Option):
    """Click option that parses a string into a list of word tokens.

    Splits the raw option value on any non-word characters, e.g.
    "model1 model2,model3" -> ['model1', 'model2', 'model3'].
    An empty default ('') yields an empty list (falsy), which the command
    uses to detect "option not given".
    """

    def type_cast_value(self, ctx, value):
        try:
            # Equivalent to re.compile(r'(\w+)').findall(value): a single
            # group returns the group contents, i.e. the word tokens.
            return re.findall(r'\w+', value)
        except TypeError:
            # Narrowed from a bare `except:`; a non-string value (e.g. None)
            # is the realistic failure mode for re.findall here.
            raise click.BadParameter(value)
26
|
|
|
|
27
|
|
|
class ScoreMetricsDefinitionParser(click.Option):
    """Click option that parses a string into metric-definition tokens.

    Tokens may contain word characters and dashes, e.g.
    "perplexity top-tokens-coherence-10" -> ['perplexity', 'top-tokens-coherence-10'].

    NOTE(review): a dot is NOT part of the token class, so a definition like
    "sparsity-phi-0.80" (mentioned in the --metrics help text) splits into
    ['sparsity-phi-0', '80'] — confirm against GraphMaker's expectations
    before widening the regex.
    """

    def type_cast_value(self, ctx, value):
        try:
            return re.findall(r'[\w\-]+', value)
        except TypeError:
            # Narrowed from a bare `except:` so real errors are not swallowed.
            raise click.BadParameter(value)
33
|
|
|
|
34
|
|
|
@click.command(context_settings=dict(help_option_names=['-h', '--help']))
@click.argument('dataset')  # the name of the dataset/collection in which to look for trained models
@click.option('--models-labels', '-ml', cls=ModelLabelsParser, default='', help='Selects model(s) solely based on the names provided and ignores all other selection and sorting options below')
@click.option('--sort', '-s', default='perplexity', show_default=True, help="How to order the results' list, from which to select; supports 'alphabetical' and rest of metrics eg: {}".format(', '.join("'{}'".format(x) for x in c)))
@click.option('--model-indices', '-mi', cls=ModelLabelsParser, default='', type=int, help='Selects model(s) based on linear indices corresponding to the list of results per model (see --sort option), residing in collection directory and ignores all other selection options below')
@click.option('--top', '-t', type=int, help='Selects the first n model results from the list and ignores all other selection options below')
@click.option('--range-tuple', '-r', cls=ModelLabelsParser, default='', help='Selects model results defined by the range')
@click.option('--allmetrics/--no-allmetrics', default=False, show_default=True, help='Whether to create plots for all possible metrics (maximal) discovered in the model results')
@click.option('--metrics', cls=ScoreMetricsDefinitionParser, default='', help='Whether to limit plotting to the selected graph types (roughly score definition) i.e. "perplexity", "sparsity-phi-0.80", "top-tokens-coherence-10", If --allmetrics flag is enabled it has no effect.')
@click.option('--tautrajectories/--no-tautrajectories', default=False, show_default=True, help="Legacy flag: turn this on explicitly in case you have used dynamic regularization coefficient trajectory feature")
@click.option('--iterations', '-i', type=int, help='Whether to limit the datapoints plotted to the specified number. Defaults to plotting information for all training iterations')
@click.option('--legend/--no-legend', default=True, show_default=True, help="Whether to include legend information in the resulting graph images")
def main(dataset, models_labels, sort, model_indices, top, range_tuple, allmetrics, metrics, tautrajectories, iterations, legend):
    """Build graphs for trained-model results found under $COLLECTIONS_DIR/<dataset>."""
    # Resolve which model results to plot. The selection options are mutually
    # exclusive with precedence: labels > indices > top > range > default of 8.
    if models_labels:
        selection = models_labels
        print(selection)
    elif model_indices:
        selection = model_indices
    elif top:
        selection = top
    elif range_tuple:
        # range_tuple is a list of numeric strings parsed by ModelLabelsParser;
        # NOTE(review): range(*...) requires ints — confirm upstream conversion.
        selection = range(*range_tuple)
    else:
        selection = 8
    if allmetrics:
        metrics = 'all'  # overrides any --metrics selection
    if tautrajectories:
        tautrajectories = 'all'
    # BUGFIX: the original tested sys.version_info[1] (the MINOR version, i.e.
    # any Python X.3); the warning clearly targets the MAJOR version (Python 3).
    if sys.version_info[0] == 3:
        print("The graph building process is probably going to result in an Attribute error ('Legend' object has no attribute 'draggable') "
              "due to a bug of the easyplot module when requesting legend information to be included in the graph image built."
              " For now please either invoke program with python2 interpreter or use the '--no-legend' flag.")
    collections_dir = os.getenv('COLLECTIONS_DIR')
    if not collections_dir:
        raise RuntimeError(
            "Please set the COLLECTIONS_DIR environment variable with the path to a directory containing collections/datasets")
    # NOTE(review): --iterations is accepted but not forwarded to GraphMaker —
    # confirm whether build_graphs_from_collection should receive it.
    graph_maker = GraphMaker(collections_dir)
    graph_maker.build_graphs_from_collection(dataset, selection,
                                             metric=sort,
                                             score_definitions=metrics,
                                             tau_trajectories=tautrajectories,
                                             showlegend=legend)

    print("\nFigures' paths:")
    for figure_path in graph_maker.saved_figures:
        print(figure_path)
76
|
|
|
|
77
|
|
|
print("\nFigures' paths:") |
78
|
|
|
for _ in graph_maker.saved_figures: |
79
|
|
|
print(_) |
80
|
|
|
|
81
|
|
|
if __name__ == '__main__':
    # Script entry point: delegate argument parsing and execution to click.
    main()
83
|
|
|
|