1
|
|
|
#!/usr/bin/env python |
2
|
|
|
|
3
|
|
|
import os |
4
|
|
|
import sys |
5
|
|
|
import argparse |
6
|
|
|
|
7
|
|
|
from topic_modeling_toolkit.patm import TrainerFactory, Experiment |
8
|
|
|
|
9
|
|
|
|
10
|
|
|
def get_cl_arguments():
    """Parse the command-line arguments of the training script.

    When the script is invoked with no arguments at all, the help text is
    printed and the process exits with status 1.

    Returns:
        argparse.Namespace with the attributes ``collection``, ``config``,
        ``label``, ``reg_config``, ``save`` and ``new_batches``.
    """
    parser = argparse.ArgumentParser(prog='train.py', description='Trains an artm topic model and stores \'evaluation\' scores', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('collection', help='the name for the collection to train on')
    parser.add_argument('config', help='the .cfg file to use for constructing and training the topic_model')
    # A required positional never falls back to a default, so the former
    # default='def' was dead and only misleading; it has been dropped.
    parser.add_argument('label', metavar='id', help='a unique identifier used for a newly created model')
    parser.add_argument('--reg-config', '--r-c', dest='reg_config', help='the .cfg file containing initialization parameters for the active regularizers')
    # 'store_true' with default=True means '--save' alone can never turn saving
    # off; '--no-save' (same dest) is the switch that actually disables it.
    parser.add_argument('--save', default=True, action='store_true', help='saves the state of the model and experimental results after the training iterations finish')
    parser.add_argument('--no-save', dest='save', action='store_false', help='do not save the state of the model and experimental results after training')
    parser.add_argument('--new-batches', '--n-b', default=False, dest='new_batches', action='store_true', help='whether to force the creation of new batches, regardless of finding batches already existing')
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
23
|
|
|
|
24
|
|
|
def _collection_root(collection_name):
    """Return the path of *collection_name* inside the $COLLECTIONS_DIR directory.

    Raises:
        RuntimeError: if the COLLECTIONS_DIR environment variable is unset or empty.
    """
    collections_root = os.getenv('COLLECTIONS_DIR')
    if collections_root:
        return os.path.join(collections_root, collection_name)
    raise RuntimeError(
        "Please set the COLLECTIONS_DIR environment variable to the directory containing collections/datasets")


def main():
    """Train a topic model on a collection and optionally save the outcome.

    Steps:
      1. Parse the command line.
      2. Locate the collection under $COLLECTIONS_DIR.
      3. Build a trainer (ideology labels exploited; batches optionally
         regenerated) and attach an Experiment to it.
      4. Create the model from the given config / regularizer config.
      5. Train the model and print a summary of the passes performed.
      6. Save the experiment (including phi) when ``--save`` is in effect.

    Raises:
        RuntimeError: if COLLECTIONS_DIR is not set.
    """
    cli = get_cl_arguments()
    dataset_dir = _collection_root(cli.collection)

    trainer = TrainerFactory().create_trainer(
        dataset_dir,
        exploit_ideology_labels=True,
        force_new_batches=cli.new_batches,
    )
    tracker = Experiment(dataset_dir)
    # Registering the experiment lets it keep track of evaluation metrics
    # while the trainer trains.
    trainer.register(tracker)

    model = trainer.model_factory.create_model(
        cli.label,
        cli.config,
        reg_cfg=cli.reg_config,
        show_progress_bars=False,
    )
    specs = trainer.model_factory.create_train_specs()
    tracker.init_empty_trackables(model)

    print("Initialized Model:")
    print(model.pformat_regularizers)
    print(model.pformat_modalities)

    trainer.train(model, specs, effects=True, cache_theta=True)

    passes = specs.collection_passes
    doc_passes = model.document_passes
    print('Iterated {} times through the collection and {} times over each document: total phi updates = {}'.format(
        passes, doc_passes, passes * doc_passes))

    if not cli.save:
        return
    tracker.save_experiment(save_phi=True)
    print("Saved results and model '{}'".format(cli.label))
62
|
|
|
|
63
|
|
|
|
64
|
|
|
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not when imported.
    main()