#!/usr/bin/env python
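"""Train an ARTM topic model on a preprocessed collection and store 'evaluation' scores.

Requires the COLLECTIONS_DIR environment variable to point to the directory that
contains the collections/datasets. An Experiment object registered with the trainer
tracks evaluation metrics while the model is fitted; with --save the model state and
the tracked results are persisted once the training iterations finish.
"""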

import argparse
import os
import sys

from topic_modeling_toolkit.patm import TrainerFactory, Experiment


def get_cl_arguments():
    parser = argparse.ArgumentParser(prog='train.py', description="Trains an artm topic model and stores 'evaluation' scores", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('collection', help='the name of the collection to train on')
    parser.add_argument('config', help='the .cfg file to use for constructing and training the topic_model')
    parser.add_argument('label', metavar='id', default='def', help='a unique identifier used for a newly created model')
    parser.add_argument('--reg-config', '--r-c', dest='reg_config', help='the .cfg file containing initialization parameters for the active regularizers')
    # Note: with action='store_true' and default=True, args.save is always True, so saving happens whether or not the flag is passed.
    parser.add_argument('--save', default=True, action='store_true', help='save the state of the model and the experimental results once the training iterations finish')
    # parser.add_argument('--load', default=False, action='store_true', help='restores the model state and progress of tracked entities from disk')
    parser.add_argument('--new-batches', '--n-b', default=False, dest='new_batches', action='store_true', help='force the creation of new batches even if existing batches are found on disk')
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
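

# Example invocation (the collection name and the .cfg paths below are hypothetical;
# substitute your own dataset and configuration files):
#
#   COLLECTIONS_DIR=/data/collections python train.py my_collection train.cfg my_model_1 --reg-config regularizers.cfg --new-batches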


def main():
    args = get_cl_arguments()
    collections_dir = os.getenv('COLLECTIONS_DIR')
    if not collections_dir:
        raise RuntimeError(
            "Please set the COLLECTIONS_DIR environment variable to the directory containing collections/datasets")
    root_dir = os.path.join(collections_dir, args.collection)
    model_trainer = TrainerFactory().create_trainer(root_dir, exploit_ideology_labels=True,
                                                    force_new_batches=args.new_batches)
    experiment = Experiment(root_dir)
    # While the model_trainer trains, the registered experiment object keeps track of evaluation metrics.
    model_trainer.register(experiment)

    # if args.load:
    #     topic_model = experiment.load_experiment(args.label)
    #     print('\nLoaded experiment and model state')
    #     settings = cfg2model_settings(args.config)
    #     train_specs = TrainSpecs(15, [], [])
    # else:
    # Build the topic model and the training specifications from the .cfg files.
    topic_model = model_trainer.model_factory.create_model(args.label, args.config, reg_cfg=args.reg_config,
                                                           show_progress_bars=False)
    train_specs = model_trainer.model_factory.create_train_specs()
    experiment.init_empty_trackables(topic_model)
    # static_reg_specs = {}  # regularizers' parameters that should be kept constant during data fitting (model training)
    # import pprint
    # pprint.pprint({k: dict(v, **{setting_name: setting_value for setting_name, setting_value in {'target topics': (lambda x: 'all' if len(x) == 0 else '[{}]'.format(', '.join(x)))(topic_model.get_reg_obj(topic_model.get_reg_name(k)).topic_names), 'mods': getattr(topic_model.get_reg_obj(topic_model.get_reg_name(k)), 'class_ids', None)}.items()}) for k, v in self.static_regularization_specs.items()})
    # pprint.pprint(tm.modalities_dictionary)
    print("Initialized Model:")
    print(topic_model.pformat_regularizers)
    print(topic_model.pformat_modalities)
    model_trainer.train(topic_model, train_specs, effects=True, cache_theta=True)
    print('Iterated {} times through the collection and {} times over each document: total phi updates = {}'.format(
        train_specs.collection_passes, topic_model.document_passes,
        train_specs.collection_passes * topic_model.document_passes))

    if args.save:
        experiment.save_experiment(save_phi=True)
        print("Saved results and model '{}'".format(args.label))


if __name__ == '__main__':
    main()