boromir674 /
topic-modeling-toolkit
| 1 | import os |
||
| 2 | import pytest |
||
| 3 | from topic_modeling_toolkit.reporting import PsiReporter |
||
| 4 | from topic_modeling_toolkit.patm import Experiment, TrainerFactory |
||
| 5 | |||
| 6 | |||
@pytest.fixture(scope='session')
def megadata_dir(unittests_data_dir):
    """Return the path of the 'megadata' dataset directory.

    The directory is resolved as the 'megadata' child of ``unittests_data_dir``.
    That is a fixture supplied from outside this module (presumably a conftest); verify there.
    """
    dataset_path = os.path.join(unittests_data_dir, 'megadata')
    return dataset_path
||
| 10 | |||
| 11 | |||
@pytest.fixture(scope='session')
def regression_model_path(megadata_dir):
    """Train a 'candidate' model on the megadata dataset and persist it.

    The model is created from ``regression.cfg``, with regularizers taken from
    ``test-regularizers.cfg``. Both files are looked up in the directory that
    contains this test module. The trainer is created with
    ``exploit_ideology_labels=True`` and reuses existing batches
    (``force_new_batches=False``). After training, the experiment is saved
    together with its phi matrix.

    Returns:
        str: the label of the persisted phi matrix, ``'candidate.phi'``.
    """
    # MODULE_DIR used to be referenced here, but it is never defined or imported
    # in this module. Names defined in a conftest are not visible to test modules,
    # so that reference raised NameError. Resolve the test-module directory locally.
    module_dir = os.path.dirname(os.path.realpath(__file__))

    trainer = TrainerFactory().create_trainer(megadata_dir, exploit_ideology_labels=True, force_new_batches=False)
    experiment = Experiment(megadata_dir)
    topic_model = trainer.model_factory.create_model(
        'candidate',
        os.path.join(module_dir, 'regression.cfg'),
        reg_cfg=os.path.join(module_dir, 'test-regularizers.cfg'),
        show_progress_bars=False,
    )
    train_specs = trainer.model_factory.create_train_specs()
    # The experiment must be registered before training so it can track the model.
    trainer.register(experiment)
    experiment.init_empty_trackables(topic_model)
    trainer.train(topic_model, train_specs, effects=False, cache_theta=True)
    experiment.save_experiment(save_phi=True)
    # TODO: change this so that the previously persisted 'candidate' phi matrix
    # object and results json are not overwritten.
    return 'candidate.phi'
||
| 25 | |||
| 26 | |||
@pytest.fixture(scope='module')
def pr(megadata_dir):
    """Return a PsiReporter whose ``dataset`` attribute points at the megadata directory."""
    reporter = PsiReporter()
    reporter.dataset = megadata_dir
    return reporter
||
| 32 | |||
| 33 | |||
@pytest.fixture(scope='module', params=[3])
def datasets(megadata_dir, pr, request):
    """Provide the dataset and its trained models, selected by the number of document classes.

    ``request.param`` picks the entry; the only one defined is 3.

    Each model dict carries its phi label and the expected string and matrix.
    Two keys are filled in from the reporter:

    * ``'reported-distances'``: ``pr.values(..., topics_set='domain')[0]``, with
      every value rounded to one decimal.
    * ``'reported-string'``: the output of ``pr.pformat(...)`` for the same label.
    """
    three_class_models = [
        {
            'label': 'sgo_1.phi',
            'expected-string': 'liberal_Class 70.8 71.9\n'
                               'centre_Class 70.8 65.1\n'
                               'conservative_Class 71.9 65.1 \n',
            'expected-matrix': [[0, 70.8, 71.9],
                                [70.8, 0, 65.1],
                                [71.9, 65.1, 0]],
        },
        {
            'label': 'candidate.phi',
            'expected-string': 'liberal_Class 34.9 35.0\n'
                               'centre_Class 34.9 29.2\n'
                               'conservative_Class 35.0 29.2 \n',
            'expected-matrix': [[0, 34.9, 35],
                                [34.9, 0, 29.2],
                                [35, 29.2, 0]],
        },
    ]
    by_class_count = {3: {'dataset': megadata_dir, 'models': three_class_models}}

    selected = by_class_count[request.param]
    for entry in selected['models']:
        phi_label = entry['label']
        raw_matrix = pr.values([phi_label], topics_set='domain')[0]
        # Round through a '.1f' string so results compare exactly with the literals above.
        entry['reported-distances'] = [
            [float('{:.1f}'.format(value)) for value in raw_row]
            for raw_row in raw_matrix
        ]
        entry['reported-string'] = pr.pformat(
            [phi_label], topics_set='domain', show_model_name=False, show_class_names=True)
    return selected
||
| 64 | |||
| 65 | # Insert fixture 'regression_model_path', which is the path to a newly trained model, with settings in regression.cfg, to test if it places the ideology classes relatively in correct ordering on the political spectrum |
||
| 66 | def test_sanity(pr, datasets): |
||
| 67 | pr.dataset = datasets['dataset'] |
||
| 68 | for d in datasets['models']: |
||
| 69 | # array = [[float('{:.1f}'.format(k)) for k in z] for z in pr.values([d['label']], topics_set='domain')[0]] |
||
| 70 | for i, row in enumerate(d['reported-distances']): |
||
| 71 | assert row[:i] == sorted(row[:i], reverse=True) |
||
| 72 | assert row[i] == 0 |
||
| 73 | assert row[i:] == sorted(row[i:], reverse=False) |
||
| 74 | |||
| 75 | # |
||
| 76 | # @pytest.fixture(scope='module', params=[3]) |
||
| 77 | # def best_models(request): |
||
| 78 | # """Models achieving best separation between the p(c|t) distributions and their exected Psi-matrix related results""" |
||
| 79 | # return {3: [{'string': 'liberal_Class 70.8 71.9\n' |
||
| 80 | # 'centre_Class 70.8 65.1\n' |
||
| 81 | # 'conservative_Class 71.9 65.1 \n', |
||
| 82 | # 'label': 'sgo_1.phi', |
||
| 83 | # 'data': [[0, 70.8, 71.9], |
||
| 84 | # [70.8, 0, 65.1], |
||
| 85 | # [71.9, 65.1, 0]]}, |
||
| 86 | # # {'string': '', |
||
| 87 | # # 'label': 'candidate', |
||
| 88 | # # 'data': [[]]} |
||
| 89 | # ], |
||
| 90 | # }[request.param] |
||
| 91 | # # |
||
| 92 | # # @pytest.fixture(scope='module') |
||
| 93 | # # def |
||
| 94 | # |
||
| 95 | # def sane_separation(data): |
||
| 96 | # [[float('{:.1f}'.format(y)) for y in x] for x in pr.values([request.param], topics_set='domain')[0]] |
||
@pytest.fixture(scope='module')
def known_expected(datasets):
    """Return a shallow copy of ``datasets`` limited to models with known expected results.

    Only models that define both ``'expected-string'`` and ``'expected-matrix'``
    are kept. The input dict's ``'models'`` list is not modified.
    """
    required_keys = ('expected-string', 'expected-matrix')
    filtered = dict(datasets)
    filtered['models'] = [
        model for model in datasets['models']
        if all(key in model for key in required_keys)
    ]
    return filtered
||
| 102 | |||
| 103 | |||
| 104 | def test_divergence_distances_computer(pr, known_expected): # , regression_model_path): |
||
| 105 | pr.dataset = known_expected['dataset'] |
||
| 106 | for d in known_expected['models']: |
||
| 107 | assert d['reported-string'] == d['expected-string'] |
||
| 108 | assert d['reported-distances'] == d['expected-matrix'] |
||
| 109 | |||
| 110 | |||
| 111 | # def test_loading(datasets, megadata_dir): |
||
| 112 | # new_exp_obj = Experiment(megadata_dir) |
||
| 113 | # trainer = TrainerFactory().create_trainer(megadata_dir) |
||
| 114 | # trainer.register(new_exp_obj) |
||
| 115 | # for d in datasets['models']: |
||
| 116 | # loaded_model = new_exp_obj.load_experiment(d['label'].replace('.phi', '')) |
||
| 117 | # |
||
| 118 | # assert loaded_model.regularizer_wrappers |
||
| 119 | # return loaded_model, new_exp_obj |
||
| 120 | # @pytest.mark.parametrize("model_phi_path", [ |
||
| 121 | # 'sgo_1.phi', |
||
| 122 | # 'candidate.phi' |
||
| 123 | # ]) |
||
| 124 | |||
| 125 | |||
| 126 | # # T.O.D.O. implement a function that takes a KL distances matrix (see 'best_seperation_performance' fixture) and return a quantitative measure of quality |
||
| 127 | # # def fitness(data): |
||
| 128 | # # pass |
||
| 129 |