|
1
|
|
|
# Copyright (c) 2014, Salesforce.com, Inc. All rights reserved. |
|
2
|
|
|
# |
|
3
|
|
|
# Redistribution and use in source and binary forms, with or without |
|
4
|
|
|
# modification, are permitted provided that the following conditions |
|
5
|
|
|
# are met: |
|
6
|
|
|
# |
|
7
|
|
|
# - Redistributions of source code must retain the above copyright |
|
8
|
|
|
# notice, this list of conditions and the following disclaimer. |
|
9
|
|
|
# - Redistributions in binary form must reproduce the above copyright |
|
10
|
|
|
# notice, this list of conditions and the following disclaimer in the |
|
11
|
|
|
# documentation and/or other materials provided with the distribution. |
|
12
|
|
|
# - Neither the name of Salesforce.com nor the names of its contributors |
|
13
|
|
|
# may be used to endorse or promote products derived from this |
|
14
|
|
|
# software without specific prior written permission. |
|
15
|
|
|
# |
|
16
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
|
17
|
|
|
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
|
18
|
|
|
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS |
|
19
|
|
|
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE |
|
20
|
|
|
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, |
|
21
|
|
|
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, |
|
22
|
|
|
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS |
|
23
|
|
|
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
|
24
|
|
|
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR |
|
25
|
|
|
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE |
|
26
|
|
|
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
27
|
|
|
|
|
28
|
|
|
import math |
|
29
|
|
|
import numpy |
|
30
|
|
|
import numpy.random |
|
31
|
|
|
import scipy.stats |
|
32
|
|
|
import functools |
|
33
|
|
|
from collections import defaultdict |
|
34
|
|
|
from nose import SkipTest |
|
35
|
|
|
from nose.tools import assert_greater |
|
36
|
|
|
from nose.tools import assert_in |
|
37
|
|
|
from nose.tools import assert_is_instance |
|
38
|
|
|
from nose.tools import assert_not_equal |
|
39
|
|
|
from nose.tools import assert_true |
|
40
|
|
|
from goftests import density_goodness_of_fit |
|
41
|
|
|
from goftests import discrete_goodness_of_fit |
|
42
|
|
|
from goftests import vector_density_goodness_of_fit |
|
43
|
|
|
from distributions.dbg.random import sample_discrete |
|
44
|
|
|
from distributions.util import scores_to_probs |
|
45
|
|
|
from distributions.tests.util import assert_all_close |
|
46
|
|
|
from distributions.tests.util import assert_close |
|
47
|
|
|
from distributions.tests.util import assert_hasattr |
|
48
|
|
|
from distributions.tests.util import import_model |
|
49
|
|
|
from distributions.tests.util import list_models |
|
50
|
|
|
from distributions.tests.util import seed_all |
|
51
|
|
|
|
|
52
|
|
|
# The generated protobuf schema is optional; test_protobuf checks this flag
# and skips itself when the schema module cannot be imported.
try:
    import distributions.io.schema_pb2
except ImportError:
    has_protobuf = False
else:
    has_protobuf = True
|
57
|
|
|
|
|
58
|
|
|
# Number of values sampled or added per group in the add/remove/seed tests.
DATA_COUNT = 20
# Number of draws used by the statistical goodness-of-fit tests.
SAMPLE_COUNT = 1000
# Goodness-of-fit scores must exceed this threshold for a test to pass.
MIN_GOODNESS_OF_FIT = 1.0e-3
|
61
|
|
|
|
|
62
|
|
|
# Map '<flavor>.models.<name>' -> imported model module, one entry for each
# spec returned by list_models().
MODULES = dict(
    ('{flavor}.models.{name}'.format(**spec), import_model(spec))
    for spec in list_models()
)
|
66
|
|
|
|
|
67
|
|
|
# Per-flavor speed flag read by model_is_fast(); only flavors mapped to True
# pass that filter.
IS_FAST = dict(dbg=False, hp=True, lp=True)
|
68
|
|
|
|
|
69
|
|
|
|
|
70
|
|
|
def model_is_fast(model):
    '''
    Return the IS_FAST flag for the flavor of a model module.

    The flavor is the second dotted component of the module name, e.g.
    'lp' for 'distributions.lp.models.niw'. An unknown flavor raises
    KeyError.
    '''
    flavor = model.__name__.split('.', 2)[1]
    return IS_FAST[flavor]
|
73
|
|
|
|
|
74
|
|
|
|
|
75
|
|
|
def iter_examples(module):
    '''
    Yield each entry of module.EXAMPLES after validating it.

    module.EXAMPLES must be a non-empty list; every example must contain
    'shared' and 'values' keys, and 'values' must be a list of at least 7
    items. A progress line is printed before each example is checked.
    '''
    assert_hasattr(module, 'EXAMPLES')
    examples = module.EXAMPLES
    assert_is_instance(examples, list)
    assert_true(examples, 'no examples provided')
    for position, example in enumerate(examples, 1):
        print('example {}/{}'.format(position, len(examples)))
        for required_key in ('shared', 'values'):
            assert_in(required_key, example)
        example_values = example['values']
        assert_is_instance(example_values, list)
        value_count = len(example_values)
        assert_true(
            value_count >= 7,
            'Add more example values (expected >= 7, found {})'.format(
                value_count))
        yield example
|
91
|
|
|
|
|
92
|
|
|
|
|
93
|
|
|
def for_each_model(*filters):
    '''
    Run one test per Model, filtering out inappropriate Models for test.

    Returns a decorator. Applied to test_fun(module, EXAMPLE), it produces a
    nose generator test that yields (test_one_model, name) for every entry
    of MODULES whose module satisfies all of `filters`. Each generated test
    then calls test_fun once per example from iter_examples(module).
    Models are visited in sorted name order so the generated tests run in a
    reproducible order.
    '''
    def filtered(test_fun):

        @functools.wraps(test_fun)
        def test_one_model(name):
            module = MODULES[name]
            assert_hasattr(module, 'Shared')
            for EXAMPLE in iter_examples(module):
                test_fun(module, EXAMPLE)

        @functools.wraps(test_fun)
        def test_all_models():
            # Sort the names: plain dict iteration order is arbitrary on
            # Python 2, which made the generated test order irreproducible.
            for name in sorted(MODULES):
                module = MODULES[name]
                if all(f(module) for f in filters):
                    yield test_one_model, name

        return test_all_models
    return filtered
|
115
|
|
|
|
|
116
|
|
|
|
|
117
|
|
|
@for_each_model()
def test_value(module, EXAMPLE):
    '''
    Check that module.Value is a type and every example value is one.
    '''
    assert_hasattr(module, 'Value')
    value_type = module.Value
    assert_is_instance(value_type, type)

    for value in EXAMPLE['values']:
        assert_is_instance(value, value_type)
|
125
|
|
|
|
|
126
|
|
|
|
|
127
|
|
|
@for_each_model()
def test_shared(module, EXAMPLE):
    '''
    Round-trip Shared through from_dict/dump and through add/remove_value.

    Two copies fed the same values under the same seed must dump equally,
    and removing every added value must restore the original dump.
    '''
    assert_hasattr(module, 'Shared')
    assert_is_instance(module.Shared, type)

    raw = EXAMPLE['shared']
    first = module.Shared.from_dict(raw)
    second = module.Shared.from_dict(raw)
    assert_close(first.dump(), raw)

    values = EXAMPLE['values']
    # Re-seed before each pass so any randomness used by add_value is
    # replayed identically for both copies.
    for target in (first, second):
        seed_all(0)
        for value in values:
            target.add_value(value)
    assert_close(first.dump(), second.dump())

    for value in values:
        first.remove_value(value)
    assert_close(first.dump(), raw)
|
148
|
|
|
|
|
149
|
|
|
|
|
150
|
|
|
@for_each_model()
def test_group(module, EXAMPLE):
    '''
    Exercise Group construction, dump/load, remove/merge and scoring.
    '''
    assert_hasattr(module, 'Group')
    assert_is_instance(module.Group, type)

    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']
    for value in values:
        shared.add_value(value)

    # Incremental construction must agree with from_values.
    incremental = module.Group()
    incremental.init(shared)
    for value in values:
        incremental.add_value(shared, value)
    batch = module.Group.from_values(shared, values)
    assert_close(incremental.dump(), batch.dump())

    # dump -> init -> load must restore the dumped state.
    reloaded = module.Group.from_values(shared, values)
    snapshot = reloaded.dump()
    reloaded.init(shared)
    reloaded.load(snapshot)
    assert_close(reloaded.dump(), snapshot)

    for value in values:
        batch.remove_value(shared, value)
    # NOTE(review): this compares the Group objects, not their dumps, so
    # unless Group defines __eq__ it only checks object identity.
    assert_not_equal(incremental, batch)
    batch.merge(shared, incremental)

    # Smoke-test the scoring and sampling entry points.
    for value in values:
        incremental.score_value(shared, value)
    for _ in range(10):
        drawn = incremental.sample_value(shared)
        incremental.score_value(shared, drawn)
        module.sample_group(shared, 10)
    incremental.score_data(shared)
    batch.score_data(shared)
|
186
|
|
|
|
|
187
|
|
|
|
|
188
|
|
|
@for_each_model(lambda module: hasattr(module.Shared, 'protobuf_load'))
def test_protobuf(module, EXAMPLE):
    '''
    Round-trip Shared (and Group, when supported) through protobuf messages.

    Both the instance-level protobuf_dump/protobuf_load pair and the
    class-level to_protobuf/from_protobuf pair are checked. Skips when the
    protobuf schema module could not be imported.
    '''
    if not has_protobuf:
        raise SkipTest('protobuf not available')
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']
    Message = getattr(distributions.io.schema_pb2, module.NAME)

    msg = Message.Shared()
    shared.protobuf_dump(msg)
    restored_shared = module.Shared()
    restored_shared.protobuf_load(msg)
    assert_close(restored_shared.dump(), shared.dump())

    msg.Clear()
    shared_dump = shared.dump()
    module.Shared.to_protobuf(shared_dump, msg)
    assert_close(module.Shared.from_protobuf(msg), shared_dump)

    if not hasattr(module.Group, 'protobuf_load'):
        return

    for value in values:
        shared.add_value(value)
    group = module.Group.from_values(shared, values)

    msg = Message.Group()
    group.protobuf_dump(msg)
    restored_group = module.Group()
    restored_group.protobuf_load(msg)
    assert_close(restored_group.dump(), group.dump())

    msg.Clear()
    group_dump = group.dump()
    module.Group.to_protobuf(group_dump, msg)
    assert_close(module.Group.from_protobuf(msg), group_dump)
|
222
|
|
|
|
|
223
|
|
|
|
|
224
|
|
|
@for_each_model()
def test_add_remove(module, EXAMPLE):
    '''
    Check add_value, remove_value, score_data and score_value agree.

    The running sum of score_value while adding sampled values must equal
    score_data of the final group, and adding/removing those values in
    shuffled order must be reversible.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()

    group = module.Group.from_values(shared)
    assert_close(group.score_data(shared), 0.0, err_msg='p(empty) != 1')

    sampled = []
    running_score = 0.0
    for _ in range(DATA_COUNT):
        value = group.sample_value(shared)
        sampled.append(value)
        running_score += group.score_value(shared, value)
        group.add_value(shared, value)

    full_group = module.Group.from_dict(group.dump())
    assert_close(
        running_score,
        group.score_data(shared),
        err_msg='p(x1,...,xn) != p(x1) p(x2|x1) p(xn|...)')

    numpy.random.shuffle(sampled)
    for value in sampled:
        group.remove_value(shared, value)
    empty_group = module.Group.from_values(shared)
    assert_close(
        group.dump(),
        empty_group.dump(),
        err_msg='group + values - values != group')

    numpy.random.shuffle(sampled)
    for value in sampled:
        group.add_value(shared, value)
    assert_close(
        group.dump(),
        full_group.dump(),
        err_msg='group - values + values != group')
|
266
|
|
|
|
|
267
|
|
|
|
|
268
|
|
|
@for_each_model()
def test_add_repeated(module, EXAMPLE):
    '''
    Check add_repeated_value(count=n) matches n separate add_value calls.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    for value in EXAMPLE['values']:
        one_at_a_time = module.Group.from_values(shared)
        for _ in range(DATA_COUNT):
            one_at_a_time.add_value(shared, value)

        repeated = module.Group.from_values(shared)
        repeated.add_repeated_value(shared, value, count=DATA_COUNT)

        assert_close(
            one_at_a_time.dump(),
            repeated.dump(),
            err_msg='n * add_value != add_repeated_value n')
|
284
|
|
|
|
|
285
|
|
|
|
|
286
|
|
|
@for_each_model()
def test_add_merge(module, EXAMPLE):
    '''
    Check that merging groups built from any split of the values (with
    random ordering) reproduces the group built from all of them.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = list(EXAMPLE['values'])
    for value in values:
        shared.add_value(value)

    numpy.random.shuffle(values)
    whole = module.Group.from_values(shared, values)

    for split in range(len(values) + 1):
        numpy.random.shuffle(values)
        left = module.Group.from_values(shared, values[:split])
        right = module.Group.from_values(shared, values[split:])
        left.merge(shared, right)
        assert_close(whole.dump(), left.dump())
|
303
|
|
|
|
|
304
|
|
|
|
|
305
|
|
|
@for_each_model()
def test_group_merge(module, EXAMPLE):
    '''
    Check load(dump(left)) followed by merge(right) matches a group that
    received every value added to either half.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    left, right, expected, actual = [
        module.Group.from_values(shared) for _ in range(4)]
    for _ in range(100):
        # Alternate fresh samples between the halves; `expected` gets all.
        for half in (left, right):
            value = expected.sample_value(shared)
            expected.add_value(shared, value)
            half.add_value(shared, value)

        actual.load(left.dump())
        actual.merge(shared, right)
        assert_close(actual.dump(), expected.dump())
|
325
|
|
|
|
|
326
|
|
|
|
|
327
|
|
|
@for_each_model(lambda module: module.Value in [bool, int])
def test_group_allows_debt(module, EXAMPLE):
    '''
    Check that a Group tolerates temporarily going into data debt.

    Each sampled value is added three times and removed twice in shuffled
    order, so values may be removed before they were added; the end state
    must equal a group holding each value once.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    values = []
    reference = module.Group.from_values(shared, values)
    for _ in range(DATA_COUNT):
        value = reference.sample_value(shared)
        values.append(value)
        reference.add_value(shared, value)

    shuffled = module.Group.from_values(shared)
    signed_values = (
        [(v, +1) for v in values] * 3 + [(v, -1) for v in values] * 2)
    numpy.random.shuffle(signed_values)
    for value, sign in signed_values:
        update = shuffled.add_value if sign > 0 else shuffled.remove_value
        update(shared, value)

    assert_close(reference.dump(), shuffled.dump())
|
351
|
|
|
|
|
352
|
|
|
|
|
353
|
|
|
@for_each_model()
def test_sample_seed(module, EXAMPLE):
    '''
    Check that re-seeding reproduces the same sequence of sampled values.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])

    def draw_seeded():
        # Fresh empty group, fixed seed, DATA_COUNT draws.
        seed_all(0)
        group = module.Group.from_values(shared)
        return [group.sample_value(shared) for _ in range(DATA_COUNT)]

    first = draw_seeded()
    second = draw_seeded()
    assert_close(first, second, err_msg='values')
|
366
|
|
|
|
|
367
|
|
|
|
|
368
|
|
|
@for_each_model()
def test_sample_value(module, EXAMPLE):
    '''
    Goodness-of-fit test of Group.sample_value against Group.score_value.

    Runs once for an empty group and once for a group built from the
    example values. Discrete (bool/int) models use a frequency test,
    float models a density test and ndarray models a vector density test
    with 10x as many samples.

    Unsupported value types and the known-broken lp niw sampler are
    skipped up front, before any (potentially expensive) sampling.
    '''
    if module.Value not in [bool, int, float, numpy.ndarray]:
        raise SkipTest('Not implemented for {}'.format(module.Value))
    if (module.Value == numpy.ndarray and
            module.__name__ == 'distributions.lp.models.niw'):
        raise SkipTest('FIXME known sampling bug')

    seed_all(0)
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    sample_count = SAMPLE_COUNT
    if module.Value == numpy.ndarray:
        sample_count *= 10
    for values in [[], EXAMPLE['values']]:
        group = module.Group.from_values(shared, values)
        samples = [group.sample_value(shared) for _ in range(sample_count)]
        if module.Value in [bool, int]:
            probs_dict = {
                value: math.exp(group.score_value(shared, value))
                for value in set(samples)
            }
            gof = discrete_goodness_of_fit(samples, probs_dict, plot=True)
        else:
            probs = numpy.exp([
                group.score_value(shared, value)
                for value in samples
            ])
            if module.Value == float:
                gof = density_goodness_of_fit(samples, probs, plot=True)
            else:
                gof = vector_density_goodness_of_fit(
                    samples, probs, plot=True)
        print('{} gof = {:0.3g}'.format(module.__name__, gof))
        assert_greater(gof, MIN_GOODNESS_OF_FIT)
|
403
|
|
|
|
|
404
|
|
|
|
|
405
|
|
|
@for_each_model()
def test_sample_group(module, EXAMPLE):
    '''
    Goodness-of-fit test of module.sample_group against Group.score_data.

    Draws SAMPLE_COUNT groups of SIZE values and compares their empirical
    frequencies with exp(score_data). Only bool/int-valued models are
    supported; other models are skipped before any work is done.
    '''
    if module.Value not in [bool, int]:
        raise SkipTest('Not implemented for {}'.format(module.Value))
    seed_all(0)
    SIZE = 2
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    # sample_group is called with only `shared` and SIZE, so the example
    # values never enter this test; a single pass is sufficient. (Looping
    # over [[], EXAMPLE['values']] only repeated the same test, because the
    # loop variable was immediately rebound below.)
    samples = []
    probs_dict = {}
    for _ in range(SAMPLE_COUNT):
        group_values = module.sample_group(shared, SIZE)
        sample = tuple(group_values)
        samples.append(sample)
        group = module.Group.from_values(shared, group_values)
        probs_dict[sample] = math.exp(group.score_data(shared))
    gof = discrete_goodness_of_fit(samples, probs_dict, plot=True)
    print('{} gof = {:0.3g}'.format(module.__name__, gof))
    assert_greater(gof, MIN_GOODNESS_OF_FIT)
|
426
|
|
|
|
|
427
|
|
|
|
|
428
|
|
|
def _append_ss(group, aggregator): |
|
429
|
|
|
ss = group.dump() |
|
430
|
|
|
for key, val in ss.iteritems(): |
|
431
|
|
|
if isinstance(val, list): |
|
432
|
|
|
for i, v in enumerate(val): |
|
433
|
|
|
aggregator['{}_{}'.format(key, i)].append(v) |
|
434
|
|
|
elif isinstance(val, dict): |
|
435
|
|
|
for k, v in val.iteritems(): |
|
436
|
|
|
aggregator['{}_{}'.format(key, k)].append(v) |
|
437
|
|
|
else: |
|
438
|
|
|
aggregator[key].append(val) |
|
439
|
|
|
|
|
440
|
|
|
|
|
441
|
|
|
def sample_marginal_conditional(module, shared, value_count):
    '''
    Draw `value_count` values with module.sample_group and return a new
    Group built from them.
    '''
    drawn = module.sample_group(shared, value_count)
    return module.Group.from_values(shared, drawn)
|
445
|
|
|
|
|
446
|
|
|
|
|
447
|
|
|
def sample_successive_conditional(module, shared, group, value_count):
    '''
    Draw `value_count` values from a module.Sampler initialised with
    (shared, group) and return a new Group built from those values.
    '''
    sampler = module.Sampler()
    sampler.init(shared, group)
    # range rather than xrange: identical iteration here, and portable to
    # Python 3 where xrange no longer exists.
    values = [sampler.eval(shared) for _ in range(value_count)]
    return module.Group.from_values(shared, values)
|
453
|
|
|
|
|
454
|
|
|
|
|
455
|
|
|
@for_each_model(model_is_fast)
def test_joint(module, EXAMPLE):
    '''
    Joint-distribution test in the style of Geweke (2004), "Getting it right".

    Sufficient statistics from independent marginal-conditional draws
    (sample_group) are compared, statistic by statistic with a two-sample
    t-test, against those from a chain of successive-conditional draws
    (Sampler), thinned by SKIP steps.

    Every statistic with a finite scalar p-value is asserted. Statistics
    whose p-value is an array or non-finite (e.g. NaN from a constant
    statistic) cannot be checked; they no longer abort the loop early, but
    are collected and reported in a single SkipTest at the end.
    '''
    # \cite{geweke04getting}
    seed_all(0)
    SIZE = 10
    SKIP = 100
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    marginal_conditional_samples = defaultdict(list)
    successive_conditional_samples = defaultdict(list)
    cond_group = sample_marginal_conditional(module, shared, SIZE)
    for _ in range(SAMPLE_COUNT):
        marg_group = sample_marginal_conditional(module, shared, SIZE)
        _append_ss(marg_group, marginal_conditional_samples)

        for __ in range(SKIP):
            cond_group = sample_successive_conditional(
                module,
                shared,
                cond_group,
                SIZE)
        _append_ss(cond_group, successive_conditional_samples)

    untestable = []
    # Sorted for a reproducible check (and report) order.
    for key in sorted(marginal_conditional_samples):
        gof = scipy.stats.ttest_ind(
            marginal_conditional_samples[key],
            successive_conditional_samples[key])[1]
        if isinstance(gof, numpy.ndarray):
            # XXX: handle array case
            untestable.append('{} (array gof = {})'.format(key, gof))
            continue
        print('{}:{} gof = {:0.3g}'.format(module.__name__, key, gof))
        if not numpy.isfinite(gof):
            untestable.append('{} (gof = {})'.format(key, gof))
            continue
        assert_greater(gof, MIN_GOODNESS_OF_FIT)
    if untestable:
        raise SkipTest('Test fails for statistics: {}'.format(
            ', '.join(untestable)))
|
487
|
|
|
|
|
488
|
|
|
|
|
489
|
|
|
@for_each_model(lambda module: hasattr(module.Shared, 'scorer_create'))
def test_scorer(module, EXAMPLE):
    '''
    Check scorer_eval (with and without a group) agrees with
    Group.score_value on an empty group, for every example value.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    empty_group = module.Group.from_values(shared)
    default_scorer = shared.scorer_create()
    group_scorer = shared.scorer_create(empty_group)
    for value in EXAMPLE['values']:
        assert_all_close([
            shared.scorer_eval(default_scorer, value),
            shared.scorer_eval(group_scorer, value),
            empty_group.score_value(shared, value),
        ])
|
502
|
|
|
|
|
503
|
|
|
|
|
504
|
|
|
@for_each_model(lambda module: hasattr(module, 'Mixture'))
def test_mixture_runs(module, EXAMPLE):
    '''
    Smoke-test Mixture add/remove of values and groups.

    Builds one singleton group per example value, assigns each value to a
    sampled group, adds an empty group, removes every value, removes the
    first and the trailing group, then assigns each value again.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']

    mixture = module.Mixture()
    for value in values:
        shared.add_value(value)
        mixture.append(module.Group.from_values(shared, [value]))
    mixture.init(shared)

    def assign(value):
        # Score `value` against every group, sample a group id from
        # scores_to_probs(scores), add the value there and return the id.
        scores = numpy.zeros(len(mixture), dtype=numpy.float32)
        mixture.score_value(shared, value, scores)
        groupid = sample_discrete(scores_to_probs(scores))
        mixture.add_value(shared, groupid, value)
        return groupid

    groupids = [assign(value) for value in values]

    mixture.add_group(shared)
    assert len(mixture) == len(values) + 1

    for value, groupid in zip(values, groupids):
        mixture.remove_value(shared, groupid, value)

    mixture.remove_group(shared, 0)
    if module.__name__ == 'distributions.lp.models.dpd':
        raise SkipTest('FIXME known segfault here')
    mixture.remove_group(shared, len(mixture) - 1)
    assert len(mixture) == len(values) - 1

    for value in values:
        assign(value)
|
543
|
|
|
|
|
544
|
|
|
|
|
545
|
|
|
@for_each_model(lambda module: hasattr(module, 'Mixture'))
def test_mixture_score(module, EXAMPLE):
    '''
    Check Mixture scores stay consistent with an explicit list of Groups.

    After init, while values are added to sampled groups, and while they
    are removed again, Mixture.score_value, score_value_group and
    score_data must match the corresponding per-Group scores.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']
    for value in values:
        shared.add_value(value)

    groups = [module.Group.from_values(shared, [value]) for value in values]
    mixture = module.Mixture()
    for member in groups:
        mixture.append(member)
    mixture.init(shared)

    def check_score_value(value):
        # Compare mixture.score_value and score_value_group against the
        # per-group scores; returns the array filled by score_value.
        expected = [member.score_value(shared, value) for member in groups]
        actual = numpy.zeros(len(mixture), dtype=numpy.float32)
        # Pre-fill with noise and subtract it afterwards: this only matches
        # `expected` if score_value adds into `actual` instead of
        # overwriting it.
        noise = numpy.random.randn(len(actual))
        actual += noise
        mixture.score_value(shared, value, actual)
        actual -= noise
        assert_close(actual, expected, err_msg='score_value {}'.format(value))
        per_group = [
            mixture.score_value_group(shared, index, value)
            for index in range(len(groups))
        ]
        assert_close(
            per_group,
            expected,
            err_msg='score_value_group {}'.format(value))
        return actual

    def check_score_data():
        expected = sum(member.score_data(shared) for member in groups)
        assert_close(mixture.score_data(shared), expected, err_msg='score_data')

    print('init')
    for value in values:
        check_score_value(value)
    check_score_data()

    print('adding')
    groupids = []
    for value in values:
        groupid = sample_discrete(scores_to_probs(check_score_value(value)))
        groups[groupid].add_value(shared, value)
        mixture.add_value(shared, groupid, value)
        groupids.append(groupid)
        check_score_data()

    print('removing')
    for value, groupid in zip(values, groupids):
        groups[groupid].remove_value(shared, value)
        mixture.remove_value(shared, groupid, value)
        check_score_value(value)
        check_score_data()
|
603
|
|
|
|