1
|
|
|
#!/usr/bin/env python |
2
|
|
|
|
3
|
|
|
import argparse |
4
|
|
|
import os |
5
|
|
|
import sys |
6
|
|
|
import re |
7
|
|
|
from PyInquirer import prompt |
8
|
|
|
|
9
|
|
|
from topic_modeling_toolkit.patm import PipeHandler, CoherenceFilesBuilder, political_spectrum |
10
|
|
|
|
11
|
|
|
|
12
|
|
|
def get_cl_arguments():
    """Declare the command line interface of the script and parse ``sys.argv``.

    Positional arguments: ``category``, ``config`` and ``collection``.
    Optional arguments: ``--sample``, ``--window``/``-w``, ``--min_tf``,
    ``--min_df`` and ``--exclude_class_labels_from_vocab``/``--ecliv``.

    When the script is invoked without any argument at all, the help message
    is printed and the process exits with status 1.

    :return: the namespace holding the parsed argument values
    :rtype: argparse.Namespace
    """
    cli = argparse.ArgumentParser(
        prog='transform.py',
        description='Extracts from pickled data, preprocesses text and saves all necessary files to train an artm model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Mandatory positionals, registered in the order they must be given.
    positionals = (
        ('category', 'the category of data to use'),
        ('config', 'the .cfg file to use for constructing a pipeline'),
        ('collection', 'a given name for the collection'),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, help=help_text)

    # Kept as a string because the default 'all' is not a number; the caller
    # converts any other value to int.
    cli.add_argument(
        '--sample',
        metavar='nb_docs',
        default='all',
        help='the number of documents to consider. Defaults to all documents',
    )
    cli.add_argument(
        '--window', '-w',
        default=10,
        type=int,
        help='number of tokens around specific token, which are used in calculation of cooccurrences',
    )
    cli.add_argument(
        '--min_tf',
        default=0,
        type=int,
        help='Minimal value of cooccurrences of a pair of tokens that are saved in dictionary of cooccurrences.',
    )
    cli.add_argument(
        '--min_df',
        default=0,
        type=int,
        help='Minimal value of documents in which a specific pair of tokens occurred together closely.',
    )
    cli.add_argument(
        '--exclude_class_labels_from_vocab', '--ecliv',
        action='store_true',
        default=False,
        help='Whether to ommit adding document labels in the vocabulary file generated. If set to False document labels are treated as valid registered vocabulary tokens.',
    )

    # Only the program name is present: show usage instead of an argparse error.
    invoked_without_arguments = len(sys.argv) == 1
    if invoked_without_arguments:
        cli.print_help()
        sys.exit(1)
    return cli.parse_args()
30
|
|
|
|
31
|
|
|
|
32
|
|
|
def ask_discreetization(spectrum, pipe_handler, pool_size=100, prob=0.3, max_generation=100):
    """Ask the user to pick a registered discreetization scheme or to evolve a new one.

    Up to three questions are prompted: which registered scheme to use (or
    'Create new'), which class naming to use for a new scheme, and, when
    'Create custom names' is picked, the space separated custom class names.
    When a new scheme is requested, the evolution specs are asked too, a
    population is initialized on ``spectrum`` and the evolved scheme is returned.

    :param spectrum: object that iterates as (name, scheme) pairs and provides
        ``distribution(scheme)``, ``init_population(...)`` and ``evolve(...)``
    :param pipe_handler: object whose ``outlet_ids`` are passed to
        ``spectrum.init_population``
    :param int pool_size: population size passed to ``spectrum.init_population``
    :param float prob: not used; the mutation probability is asked interactively
    :param int max_generation: not used; the number of generations is asked interactively
    :return: the selected registered scheme or the result of ``spectrum.evolve``
    :raises KeyboardInterrupt: if a required question is left unanswered
    :raises ValueError: if the selected menu entry matches no registered scheme
    """
    namings = [['liberal', 'centre', 'conservative'],
               ['liberal', 'centre_liberal', 'centre_conservative', 'conservative'],
               ['liberal', 'centre_liberal', 'centre', 'centre_conservative', 'conservative'],
               ['more_liberal', 'liberal', 'centre_liberal', 'centre', 'centre_conservative', 'conservative']]
    create_new = 'Create new'
    custom_names = 'Create custom names'

    # Build every menu label exactly once and keep it next to its scheme, so the
    # answer can be resolved by exact equality instead of a substring test
    # (a substring test can pick the wrong scheme when the class names of one
    # scheme are contained in the label of another).
    labelled_schemes = [
        ('Classes: [{}] with distribution [{}]'.format(
            ' '.join(scheme.class_names),
            ' '.join('{:.2f}'.format(x) for x in spectrum.distribution(scheme))),
         scheme)
        for _, scheme in spectrum
    ]

    questions = [
        {
            'type': 'list',  # navigate with arrows through choices
            'name': 'discreetization-scheme',
            'message': 'Use a registered discreetization scheme or create a new one.',
            'choices': [label for label, _ in labelled_schemes] + [create_new],
        },
        {
            'type': 'list',  # navigate with arrows through choices
            'name': 'naming-scheme',
            'message': 'You can pick one of the pre made class names or define your own custom',
            'choices': ['{}: {}'.format(len(names), ' '.join(names)) for names in namings] + [custom_names],
            'when': lambda x: x['discreetization-scheme'] == create_new,
        },
        {
            'type': 'input',
            'name': 'custom-class-names',
            'message': 'Give space separated class names',
            'when': lambda x: x.get('naming-scheme', None) == custom_names,
        },
    ]
    answers = prompt(questions)
    # A missing answer is reported as a KeyboardInterrupt.
    # NOTE(review): presumably prompt() returns an incomplete/empty dict when
    # the user aborts -- confirm against the PyInquirer version in use.
    if 'discreetization-scheme' not in answers:
        raise KeyboardInterrupt

    selected = answers['discreetization-scheme']
    if selected == create_new:
        evolution_specs = ask_evolution_specs()
        # Treat unanswered follow-up questions like the first one, rather
        # than failing with a KeyError below.
        if not all(key in evolution_specs for key in ('nb-generations', 'probability')):
            raise KeyboardInterrupt
        names_answer = answers.get('custom-class-names', answers.get('naming-scheme'))
        if names_answer is None:
            raise KeyboardInterrupt
        print("Evolving discreetization scheme ..")
        class_names = _class_names(names_answer)
        spectrum.init_population(class_names, pipe_handler.outlet_ids, pool_size)
        return spectrum.evolve(int(evolution_specs['nb-generations']), prob=float(evolution_specs['probability']))

    # First registered scheme whose label is exactly the selected menu entry.
    for label, scheme in labelled_schemes:
        if label == selected:
            return scheme
    raise ValueError("Selected entry '{}' does not match any registered discreetization scheme".format(selected))
77
|
|
|
|
78
|
|
|
|
79
|
|
|
def _class_names(string): |
80
|
|
|
return ['{}_Class'.format(x) for x in re.sub(r'\d+:\s+', '', string).split(' ')] |
81
|
|
|
|
82
|
|
|
def ask_persist(pol_spctrum):
    """Ask the user to confirm using the current scheme of the given spectrum.

    Both the class names and the class distribution shown in the question are
    read from ``pol_spctrum``. Previously the distribution was read from the
    module level ``political_spectrum`` regardless of the argument, so the
    message could mix data of two different objects.

    :param pol_spctrum: object providing ``class_names`` and ``class_distribution``
    :return: the answer to the confirmation question (defaults to True)
    """
    message = "Use scheme [{}] with resulting distribution [{}]?".format(
        ' '.join(pol_spctrum.class_names),
        ', '.join('{:.2f}'.format(x) for x in pol_spctrum.class_distribution))
    return prompt([{'type': 'confirm',
                    'name': 'create-dataset',
                    'message': message,
                    'default': True}])['create-dataset']
87
|
|
|
|
88
|
|
|
|
89
|
|
|
def ask_evolution_specs():
    """Prompt for the parameters of an evolution run.

    Two free-text questions are asked: the number of generations (default
    '50') and the mutation probability (default '0.35').

    :return: whatever ``prompt`` returns for the two questions; the answers
        are stored under the keys 'nb-generations' and 'probability'
    """
    generations_question = {
        'type': 'input',
        'name': 'nb-generations',
        'message': 'Give the number of generations to evolve solution (optimize)',
        'default': '50',
    }
    probability_question = {
        'type': 'input',
        'name': 'probability',
        'message': 'Give mutation probability',
        'default': '0.35',
    }
    return prompt([generations_question, probability_question])
98
|
|
|
|
99
|
|
|
def what_to_do():
    """Prompt for the next action to take on the current scheme.

    :return: the selected choice, one of 'Evolve more',
        'Use scheme to persist dataset' or 'back'
    """
    next_step_question = {
        'type': 'list',  # navigate with arrows through choices
        'name': 'to-do',
        'message': 'How to proceed?',
        'choices': ['Evolve more', 'Use scheme to persist dataset', 'back'],
    }
    reply = prompt([next_step_question])
    return reply['to-do']
106
|
|
|
|
107
|
|
|
|
108
|
|
|
def _format_bins(scheme):
    """Render every class bin of a scheme as '[item, item, ..]', space separated."""
    # The generator must feed ' '.join, with one formatted bin per element.
    # Passing it to '[{}]'.format(...) instead would format the generator
    # object itself and then space-separate the characters of its repr.
    return ' '.join('[{}]'.format(', '.join(class_bin)) for _, class_bin in scheme)


def _print_scheme_report(class_names, scheme):
    """Print the class names with the resulting class distribution, then the bins."""
    print("Scheme [{}] with resulting distribution [{}]".format(
        ' '.join(class_names),
        ', '.join('{:.2f}'.format(x) for x in political_spectrum.class_distribution)))
    print("Bins: {}".format(_format_bins(scheme)))


def main():
    """Run the interactive dataset transformation.

    Steps:
      1. parse the command line and read the COLLECTIONS_DIR environment variable
      2. process the data of the requested category through a PipeHandler
      3. repeatedly ask for a discreetization scheme and set it on
         ``political_spectrum``; the user may then evolve it further, go back
         to scheme selection, or use it to persist the dataset
      4. after persisting, build the cooccurrence files and exit with status 0

    :raises RuntimeError: if COLLECTIONS_DIR is not set (or empty)
    :raises ValueError: if ``political_spectrum`` rejects the selected scheme
    """
    args = get_cl_arguments()
    nb_docs = args.sample
    if nb_docs != 'all':
        nb_docs = int(nb_docs)
    collections_dir = os.getenv('COLLECTIONS_DIR')
    if not collections_dir:
        raise RuntimeError(
            "Please set the COLLECTIONS_DIR environment variable with the path to a directory containing collections/datasets")
    collection_path = os.path.join(collections_dir, args.collection)

    ph = PipeHandler()
    ph.process(args.config, args.category, sample=nb_docs, verbose=True)
    political_spectrum.datapoint_ids = ph.outlet_ids

    while True:
        try:
            scheme = ask_discreetization(political_spectrum, ph, pool_size=100, prob=0.3, max_generation=100)
        except KeyboardInterrupt:
            print("Exiting ..")
            sys.exit(0)
        print("Scheme with classes: [{}]".format(' '.join(x for x, _ in scheme)))
        try:
            political_spectrum.discreetization_scheme = scheme
        except ValueError as e:
            # Re-raise with the type of the offending object appended to the message.
            raise ValueError("{}. {}".format(e, type(scheme).__name__))

        _print_scheme_report(political_spectrum.class_names, scheme)

        while True:
            answer = what_to_do()
            if answer == 'back':
                # Return to the scheme selection menu of the outer loop.
                break
            if answer == 'Evolve more':
                evolution_specs = ask_evolution_specs()
                print("Evolving discreetization scheme ..")
                scheme = political_spectrum.evolve(int(evolution_specs['nb-generations']),
                                                   prob=float(evolution_specs['probability']))
                political_spectrum.discreetization_scheme = scheme
                _print_scheme_report(scheme.class_names, scheme)
            else:
                # Remaining choice: 'Use scheme to persist dataset'.
                uci_dt = ph.persist(collection_path,
                                    political_spectrum.poster_id2ideology_label, political_spectrum.class_names,
                                    add_class_labels_to_vocab=not args.exclude_class_labels_from_vocab)
                print(uci_dt)
                print("Discreetization scheme\n{}".format(political_spectrum.discreetization_scheme))

                print('\nBuilding coocurences information')
                coherence_builder = CoherenceFilesBuilder(collection_path)
                coherence_builder.create_files(cooc_window=args.window,
                                               min_tf=args.min_tf,
                                               min_df=args.min_df,
                                               apply_zero_index=False)
                sys.exit(0)
168
|
|
|
|
169
|
|
|
|
170
|
|
|
# Run the interactive transformation only when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    main()
172
|
|
|
|
173
|
|
|
# if ask_persist(political_spectrum): |
174
|
|
|
# |
175
|
|
|
# break |
176
|
|
|
# |
177
|
|
|
# print(uci_dt) |
178
|
|
|
# |
179
|
|
|
# print('\nBuilding coocurences information') |
180
|
|
|
# coherence_builder = CoherenceFilesBuilder(os.path.join(collections_dir, args.collection)) |
181
|
|
|
# coherence_builder.create_files(cooc_window=args.window, |
182
|
|
|
# min_tf=args.min_tf, |
183
|
|
|
# min_df=args.min_df, |
184
|
|
|
# apply_zero_index=False) |
185
|
|
|
|