|
1
|
|
|
import os |
|
2
|
|
|
import pytest |
|
3
|
|
|
from topic_modeling_toolkit.reporting import PsiReporter |
|
4
|
|
|
from topic_modeling_toolkit.patm import Experiment, TrainerFactory |
|
5
|
|
|
|
|
6
|
|
|
|
|
7
|
|
|
@pytest.fixture(scope='session')
def megadata_dir(unittests_data_dir):
    """Return the path of the 'megadata' directory located under ``unittests_data_dir``."""
    base_dir = unittests_data_dir
    return os.path.join(base_dir, 'megadata')
|
10
|
|
|
|
|
11
|
|
|
|
|
12
|
|
|
@pytest.fixture(scope='session')
def regression_model_path(megadata_dir):
    """Train the 'candidate' model on the dataset in ``megadata_dir`` and save the experiment.

    A trainer and an Experiment are created for ``megadata_dir``; the model is
    built from 'regression.cfg' and 'test-regularizers.cfg', trained with
    ``effects=False`` and ``cache_theta=True``, and then saved through
    ``experiment.save_experiment(save_phi=True)``.

    Returns:
        The literal label 'candidate.phi' (a file name, not an absolute path).
    """
    # MODULE_DIR is not defined or imported in this module, so referencing it
    # directly raises NameError. Use it when it exists, otherwise fall back to
    # the directory of this file.
    # NOTE(review): assumes the two .cfg files live next to this test module -- confirm.
    module_dir = globals().get('MODULE_DIR', os.path.dirname(os.path.abspath(__file__)))

    trainer = TrainerFactory().create_trainer(megadata_dir, exploit_ideology_labels=True, force_new_batches=False)
    experiment = Experiment(megadata_dir)
    topic_model = trainer.model_factory.create_model(
        'candidate',
        os.path.join(module_dir, 'regression.cfg'),
        reg_cfg=os.path.join(module_dir, 'test-regularizers.cfg'),
        show_progress_bars=False,
    )
    train_specs = trainer.model_factory.create_train_specs()
    trainer.register(experiment)
    experiment.init_empty_trackables(topic_model)
    trainer.train(topic_model, train_specs, effects=False, cache_theta=True)
    experiment.save_experiment(save_phi=True)
    # TODO: change this so that the previously persisted model phi matrix object
    # and results json are not overridden.
    return 'candidate.phi'
|
25
|
|
|
|
|
26
|
|
|
|
|
27
|
|
|
@pytest.fixture(scope='module')
def pr(megadata_dir):
    """Provide a PsiReporter whose ``dataset`` attribute is set to ``megadata_dir``."""
    reporter = PsiReporter()
    reporter.dataset = megadata_dir
    return reporter
|
32
|
|
|
|
|
33
|
|
|
|
|
34
|
|
|
@pytest.fixture(scope='module', params=[3])
def datasets(megadata_dir, pr, request):
    """Dataset plus the models trained on it, selected by the number of document classes.

    The returned dict has the keys 'dataset' and 'models'. Each model dict
    carries its 'label', the hard-coded 'expected-string' / 'expected-matrix',
    and two values computed here through ``pr``: 'reported-distances' (every
    value formatted to one decimal and converted back to float) and
    'reported-string' (the output of ``pr.pformat``).
    """
    # NOTE(review): the expected strings are reproduced exactly as found; if this
    # file lost whitespace on the way here, confirm the column padding against
    # the actual ``pr.pformat`` output.
    sgo_model = {
        'label': 'sgo_1.phi',
        'expected-string': 'liberal_Class 70.8 71.9\n'
                           'centre_Class 70.8 65.1\n'
                           'conservative_Class 71.9 65.1 \n',
        'expected-matrix': [[0, 70.8, 71.9],
                            [70.8, 0, 65.1],
                            [71.9, 65.1, 0]]
    }
    candidate_model = {
        'label': 'candidate.phi',
        'expected-string': 'liberal_Class 34.9 35.0\n'
                           'centre_Class 34.9 29.2\n'
                           'conservative_Class 35.0 29.2 \n',
        'expected-matrix': [[0, 34.9, 35],
                            [34.9, 0, 29.2],
                            [35, 29.2, 0]]
    }
    by_class_count = {3: {'dataset': megadata_dir,
                          'models': [sgo_model, candidate_model]}}

    selected = by_class_count[request.param]
    for entry in selected['models']:
        # First (and only) matrix reported for this single model label.
        raw_rows = pr.values([entry['label']], topics_set='domain')[0]
        entry['reported-distances'] = [
            [float('{:.1f}'.format(value)) for value in raw_row]
            for raw_row in raw_rows
        ]
        entry['reported-string'] = pr.pformat(
            [entry['label']],
            topics_set='domain',
            show_model_name=False,
            show_class_names=True,
        )
    return selected
|
64
|
|
|
|
|
65
|
|
|
# Insert the 'regression_model_path' fixture (the path to a newly trained model, configured by regression.cfg) to test whether that model places the ideology classes in the correct relative order on the political spectrum
|
66
|
|
|
def test_sanity(pr, datasets):
    """Check the shape of every reported distance row around its diagonal entry.

    For row ``i``: the values before position ``i`` must be in descending
    order, the value at position ``i`` must be 0, and the values from position
    ``i`` onwards must be in ascending order.
    """
    pr.dataset = datasets['dataset']
    for model in datasets['models']:
        # 'reported-distances' is filled in by the `datasets` fixture, already
        # rounded to one decimal.
        for diagonal, distances in enumerate(model['reported-distances']):
            before = distances[:diagonal]
            assert before == sorted(before, reverse=True)
            assert distances[diagonal] == 0
            onwards = distances[diagonal:]
            assert onwards == sorted(onwards)
|
74
|
|
|
|
|
75
|
|
|
# |
|
76
|
|
|
# @pytest.fixture(scope='module', params=[3]) |
|
77
|
|
|
# def best_models(request): |
|
78
|
|
|
# """Models achieving best separation between the p(c|t) distributions and their expected Psi-matrix related results"""
|
79
|
|
|
# return {3: [{'string': 'liberal_Class 70.8 71.9\n' |
|
80
|
|
|
# 'centre_Class 70.8 65.1\n' |
|
81
|
|
|
# 'conservative_Class 71.9 65.1 \n', |
|
82
|
|
|
# 'label': 'sgo_1.phi', |
|
83
|
|
|
# 'data': [[0, 70.8, 71.9], |
|
84
|
|
|
# [70.8, 0, 65.1], |
|
85
|
|
|
# [71.9, 65.1, 0]]}, |
|
86
|
|
|
# # {'string': '', |
|
87
|
|
|
# # 'label': 'candidate', |
|
88
|
|
|
# # 'data': [[]]} |
|
89
|
|
|
# ], |
|
90
|
|
|
# }[request.param] |
|
91
|
|
|
# # |
|
92
|
|
|
# # @pytest.fixture(scope='module') |
|
93
|
|
|
# # def |
|
94
|
|
|
# |
|
95
|
|
|
# def sane_separation(data): |
|
96
|
|
|
# [[float('{:.1f}'.format(y)) for y in x] for x in pr.values([request.param], topics_set='domain')[0]] |
|
97
|
|
|
@pytest.fixture(scope='module')
def known_expected(datasets):
    """Return a shallow copy of ``datasets`` keeping only models that define expectations.

    A model is kept when it has both the 'expected-string' and the
    'expected-matrix' key. The ``datasets`` dict itself is not modified.
    """
    required_keys = ('expected-string', 'expected-matrix')
    models_with_expectations = [
        model for model in datasets['models']
        if all(key in model for key in required_keys)
    ]
    return dict(datasets, models=models_with_expectations)
|
102
|
|
|
|
|
103
|
|
|
|
|
104
|
|
|
def test_divergence_distances_computer(pr, known_expected):  # , regression_model_path):
    """Compare the reported string and distance matrix of each model with the expected ones."""
    pr.dataset = known_expected['dataset']
    for model in known_expected['models']:
        reported_text, expected_text = model['reported-string'], model['expected-string']
        assert reported_text == expected_text
        reported_matrix, expected_matrix = model['reported-distances'], model['expected-matrix']
        assert reported_matrix == expected_matrix
|
109
|
|
|
|
|
110
|
|
|
|
|
111
|
|
|
# def test_loading(datasets, megadata_dir): |
|
112
|
|
|
# new_exp_obj = Experiment(megadata_dir) |
|
113
|
|
|
# trainer = TrainerFactory().create_trainer(megadata_dir) |
|
114
|
|
|
# trainer.register(new_exp_obj) |
|
115
|
|
|
# for d in datasets['models']: |
|
116
|
|
|
# loaded_model = new_exp_obj.load_experiment(d['label'].replace('.phi', '')) |
|
117
|
|
|
# |
|
118
|
|
|
# assert loaded_model.regularizer_wrappers |
|
119
|
|
|
# return loaded_model, new_exp_obj |
|
120
|
|
|
# @pytest.mark.parametrize("model_phi_path", [ |
|
121
|
|
|
# 'sgo_1.phi', |
|
122
|
|
|
# 'candidate.phi' |
|
123
|
|
|
# ]) |
|
124
|
|
|
|
|
125
|
|
|
|
|
126
|
|
|
# # T.O.D.O. implement a function that takes a KL distances matrix (see 'best_seperation_performance' fixture) and returns a quantitative measure of quality
|
127
|
|
|
# # def fitness(data): |
|
128
|
|
|
# # pass |
|
129
|
|
|
|