Passed
Branch dev-release (a75e90)
by Konstantinos
02:13
created

train.main()   A

Complexity

Conditions 3

Size

Total Lines 38
Code Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 26
nop 0
dl 0
loc 38
rs 9.256
c 0
b 0
f 0
1
#!/usr/bin/env python
2
3
import os
4
import sys
5
import argparse
6
7
from topic_modeling_toolkit.patm import TrainerFactory, Experiment
8
9
10
def get_cl_arguments():
    """Parse the command-line arguments of the training script.

    Positional arguments are the collection name, the model ``.cfg`` file and an
    optional model label (``'def'`` when omitted). Saving results is enabled by
    default; pass ``--no-save`` to disable it (``--save`` is still accepted for
    backwards compatibility).

    If the script is invoked without any arguments at all, the help text is
    printed and the process exits with status 1.

    Returns:
        argparse.Namespace with attributes ``collection``, ``config``, ``label``,
        ``reg_config``, ``save`` and ``new_batches``.
    """
    parser = argparse.ArgumentParser(prog='train.py', description='Trains an artm topic model and stores \'evaluation\' scores', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('collection', help='the name for the collection to train on')
    parser.add_argument('config', help='the .cfg file to use for constructing and training the topic_model')
    # nargs='?' is required for default='def' to ever take effect: without it argparse
    # treats the positional as mandatory and silently ignores the default.
    parser.add_argument('label', metavar='id', nargs='?', default='def', help='a unique identifier used for a newly created model')
    parser.add_argument('--reg-config', '--r-c', dest='reg_config', help='the .cfg file containing initialization parameters for the active regularizers')
    # Saving is on by default, so a store_true '--save' alone could never turn it off.
    # '--save' is kept (as an explicit no-op) for existing invocations; '--no-save' disables saving.
    save_group = parser.add_mutually_exclusive_group()
    save_group.add_argument('--save', dest='save', default=True, action='store_true', help='saves the state of the model and experimental results after the training iterations finish')
    save_group.add_argument('--no-save', dest='save', action='store_false', help='do not save the state of the model and experimental results after training')
    parser.add_argument('--new-batches', '--n-b', default=False, dest='new_batches', action='store_true', help='whether to force the creation of new batches, regardless of finding batches already existing')
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
23
24
def main():
    """Train a topic model on one collection and optionally persist the outcome.

    Flow:
      1. parse the command line,
      2. resolve the collection directory under ``$COLLECTIONS_DIR``,
      3. build a trainer and an ``Experiment`` registered on it (the experiment
         keeps track of evaluation metrics while the trainer trains),
      4. create the model and training specs from the given ``.cfg`` file(s),
      5. train, print a summary, and save the experiment (including phi) when
         ``save`` is set on the parsed arguments.

    Raises:
        RuntimeError: if ``COLLECTIONS_DIR`` is unset or empty.
    """
    cli = get_cl_arguments()

    base_dir = os.getenv('COLLECTIONS_DIR')
    if not base_dir:
        raise RuntimeError(
            "Please set the COLLECTIONS_DIR environment variable to the directory containing collections/datasets")
    dataset_dir = os.path.join(base_dir, cli.collection)

    trainer = TrainerFactory().create_trainer(
        dataset_dir,
        exploit_ideology_labels=True,
        force_new_batches=cli.new_batches,
    )
    experiment = Experiment(dataset_dir)
    # Observer hookup: metrics are recorded on `experiment` as training proceeds.
    trainer.register(experiment)

    model = trainer.model_factory.create_model(
        cli.label,
        cli.config,
        reg_cfg=cli.reg_config,
        show_progress_bars=False,
    )
    specs = trainer.model_factory.create_train_specs()
    experiment.init_empty_trackables(model)

    print("Initialized Model:")
    print(model.pformat_regularizers)
    print(model.pformat_modalities)

    trainer.train(model, specs, effects=True, cache_theta=True)

    summary = 'Iterated {} times through the collection and {} times over each document: total phi updates = {}'
    print(summary.format(specs.collection_passes, model.document_passes,
                         specs.collection_passes * model.document_passes))

    if not cli.save:
        return
    experiment.save_experiment(save_phi=True)
    print("Saved results and model '{}'".format(cli.label))
62
63
64
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()