1
|
|
|
# Copyright (c) 2014, Salesforce.com, Inc. All rights reserved. |
2
|
|
|
# |
3
|
|
|
# Redistribution and use in source and binary forms, with or without |
4
|
|
|
# modification, are permitted provided that the following conditions |
5
|
|
|
# are met: |
6
|
|
|
# |
7
|
|
|
# - Redistributions of source code must retain the above copyright |
8
|
|
|
# notice, this list of conditions and the following disclaimer. |
9
|
|
|
# - Redistributions in binary form must reproduce the above copyright |
10
|
|
|
# notice, this list of conditions and the following disclaimer in the |
11
|
|
|
# documentation and/or other materials provided with the distribution. |
12
|
|
|
# - Neither the name of Salesforce.com nor the names of its contributors |
13
|
|
|
# may be used to endorse or promote products derived from this |
14
|
|
|
# software without specific prior written permission. |
15
|
|
|
# |
16
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
17
|
|
|
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
18
|
|
|
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS |
19
|
|
|
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE |
20
|
|
|
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, |
21
|
|
|
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, |
22
|
|
|
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS |
23
|
|
|
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
24
|
|
|
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR |
25
|
|
|
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE |
26
|
|
|
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
27
|
|
|
|
28
|
|
|
import math |
29
|
|
|
import numpy |
30
|
|
|
import numpy.random |
31
|
|
|
import scipy.stats |
32
|
|
|
import functools |
33
|
|
|
from collections import defaultdict |
34
|
|
|
from nose import SkipTest |
35
|
|
|
from nose.tools import assert_greater |
36
|
|
|
from nose.tools import assert_in |
37
|
|
|
from nose.tools import assert_is_instance |
38
|
|
|
from nose.tools import assert_not_equal |
39
|
|
|
from nose.tools import assert_true |
40
|
|
|
from goftests import density_goodness_of_fit |
41
|
|
|
from goftests import discrete_goodness_of_fit |
42
|
|
|
from goftests import vector_density_goodness_of_fit |
43
|
|
|
from distributions.dbg.random import sample_discrete |
44
|
|
|
from distributions.util import scores_to_probs |
45
|
|
|
from distributions.tests.util import assert_all_close |
46
|
|
|
from distributions.tests.util import assert_close |
47
|
|
|
from distributions.tests.util import assert_hasattr |
48
|
|
|
from distributions.tests.util import import_model |
49
|
|
|
from distributions.tests.util import list_models |
50
|
|
|
from distributions.tests.util import seed_all |
51
|
|
|
|
52
|
|
|
# Protobuf support is optional: record whether the schema module imports.
try:
    import distributions.io.schema_pb2
except ImportError:
    has_protobuf = False
else:
    has_protobuf = True

# Sizes used by the sampling loops in this file.
DATA_COUNT = 20
SAMPLE_COUNT = 1000
# Goodness-of-fit values are asserted to be greater than this threshold.
MIN_GOODNESS_OF_FIT = 1e-3

# Maps '<flavor>.models.<name>' to the result of import_model(spec) for each
# spec returned by list_models().
MODULES = dict(
    ('{flavor}.models.{name}'.format(**spec), import_model(spec))
    for spec in list_models()
)

# Per-flavor flag looked up by model_is_fast().
IS_FAST = dict(dbg=False, hp=True, lp=True)
68
|
|
|
|
69
|
|
|
|
70
|
|
|
def model_is_fast(model):
    '''
    Return the IS_FAST entry for the model's flavor.

    The flavor is the second dot-separated component of model.__name__.
    '''
    name_parts = model.__name__.split('.')
    return IS_FAST[name_parts[1]]
73
|
|
|
|
74
|
|
|
|
75
|
|
|
def iter_examples(module):
    '''
    Validate module.EXAMPLES and yield each example in order.

    Asserts that EXAMPLES exists, is a non-empty list, and that every example
    has 'shared' and 'values' entries with at least 7 values in a list.
    A progress line is printed before each example is checked.
    '''
    assert_hasattr(module, 'EXAMPLES')
    EXAMPLES = module.EXAMPLES
    assert_is_instance(EXAMPLES, list)
    assert_true(EXAMPLES, 'no examples provided')
    total = len(EXAMPLES)
    for i, EXAMPLE in enumerate(EXAMPLES):
        # Single-argument call form prints the same on Python 2 and 3.
        print('example {}/{}'.format(1 + i, total))
        assert_in('shared', EXAMPLE)
        assert_in('values', EXAMPLE)
        values = EXAMPLE['values']
        assert_is_instance(values, list)
        count = len(values)
        assert_true(
            count >= 7,
            'Add more example values (expected >= 7, found {})'.format(count))
        yield EXAMPLE
91
|
|
|
|
92
|
|
|
|
93
|
|
|
def for_each_model(*filters):
    '''
    Decorator factory for per-model tests.

    The decorated function takes (module, EXAMPLE). The result is a generator
    test that yields one (runner, name) pair for every entry of MODULES whose
    module satisfies all of the given filters; the runner calls the decorated
    function once per example of that module.
    '''
    def accepts(module):
        return all(check(module) for check in filters)

    def filtered(test_fun):

        @functools.wraps(test_fun)
        def run_on_model(name):
            module = MODULES[name]
            assert_hasattr(module, 'Shared')
            for example in iter_examples(module):
                test_fun(module, example)

        @functools.wraps(test_fun)
        def generate_cases():
            for name in MODULES:
                if accepts(MODULES[name]):
                    yield run_on_model, name

        return generate_cases

    return filtered
115
|
|
|
|
116
|
|
|
|
117
|
|
|
@for_each_model()
def test_value(module, EXAMPLE):
    '''
    Check that module.Value is a type and every example value is an instance.
    '''
    assert_hasattr(module, 'Value')
    value_type = module.Value
    assert_is_instance(value_type, type)

    for item in EXAMPLE['values']:
        assert_is_instance(item, value_type)
125
|
|
|
|
126
|
|
|
|
127
|
|
|
@for_each_model()
def test_shared(module, EXAMPLE):
    '''
    Check Shared round-trips through from_dict/dump and add/remove_value.

    Two Shared objects built from the same dict must dump identically after
    the same values are added under the same seed, and removing the values
    again must restore the original dump.
    '''
    assert_hasattr(module, 'Shared')
    assert_is_instance(module.Shared, type)

    raw_shared = EXAMPLE['shared']
    first = module.Shared.from_dict(raw_shared)
    second = module.Shared.from_dict(raw_shared)
    assert_close(first.dump(), raw_shared)

    values = EXAMPLE['values']
    # Reseed before each pass so both objects are fed under the same seed.
    for shared in (first, second):
        seed_all(0)
        for value in values:
            shared.add_value(value)
    assert_close(first.dump(), second.dump())

    for value in values:
        first.remove_value(value)
    assert_close(first.dump(), raw_shared)
148
|
|
|
|
149
|
|
|
|
150
|
|
|
@for_each_model()
def test_group(module, EXAMPLE):
    '''
    Smoke-test the Group interface of a model.

    Checks that incremental add_value matches Group.from_values, that
    dump/init/load round-trips, and that remove_value, merge, score_value,
    sample_value, sample_group and score_data can all be called.
    '''
    assert_hasattr(module, 'Group')
    assert_is_instance(module.Group, type)

    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']
    for value in values:
        shared.add_value(value)

    # Building a group value-by-value must agree with from_values.
    group1 = module.Group()
    group1.init(shared)
    for value in values:
        group1.add_value(shared, value)
    group2 = module.Group.from_values(shared, values)
    assert_close(group1.dump(), group2.dump())

    # Re-initialising and loading a dump must reproduce that dump.
    group = module.Group.from_values(shared, values)
    dumped = group.dump()
    group.init(shared)
    group.load(dumped)
    assert_close(group.dump(), dumped)

    for value in values:
        group2.remove_value(shared, value)
    assert_not_equal(group1, group2)
    group2.merge(shared, group1)

    for value in values:
        group1.score_value(shared, value)
    # range works on both Python 2 and 3 (xrange is Python 2 only).
    for _ in range(10):
        value = group1.sample_value(shared)
        group1.score_value(shared, value)
        module.sample_group(shared, 10)
    group1.score_data(shared)
    group2.score_data(shared)
186
|
|
|
|
187
|
|
|
|
188
|
|
|
@for_each_model(lambda module: hasattr(module.Shared, 'protobuf_load'))
def test_protobuf(module, EXAMPLE):
    '''
    Check protobuf round-trips for Shared and, when supported, Group.

    Skipped when the protobuf schema module could not be imported.
    '''
    if not has_protobuf:
        raise SkipTest('protobuf not available')
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']
    Message = getattr(distributions.io.schema_pb2, module.NAME)

    # Instance-level dump/load of Shared.
    shared_message = Message.Shared()
    shared.protobuf_dump(shared_message)
    shared_copy = module.Shared()
    shared_copy.protobuf_load(shared_message)
    assert_close(shared_copy.dump(), shared.dump())

    # Dict-level to_protobuf/from_protobuf of Shared.
    shared_message.Clear()
    shared_dict = shared.dump()
    module.Shared.to_protobuf(shared_dict, shared_message)
    assert_close(module.Shared.from_protobuf(shared_message), shared_dict)

    if not hasattr(module.Group, 'protobuf_load'):
        return

    for value in values:
        shared.add_value(value)
    group = module.Group.from_values(shared, values)

    # Instance-level dump/load of Group.
    group_message = Message.Group()
    group.protobuf_dump(group_message)
    group_copy = module.Group()
    group_copy.protobuf_load(group_message)
    assert_close(group_copy.dump(), group.dump())

    # Dict-level to_protobuf/from_protobuf of Group.
    group_message.Clear()
    group_dict = group.dump()
    module.Group.to_protobuf(group_dict, group_message)
    assert_close(module.Group.from_protobuf(group_message), group_dict)
222
|
|
|
|
223
|
|
|
|
224
|
|
|
@for_each_model()
def test_add_remove(module, EXAMPLE):
    '''
    Exercise Group add_value, remove_value, score_value and score_data.

    Samples DATA_COUNT values into an empty group while summing score_value,
    compares the sum with score_data, then checks that removing and re-adding
    the values (in shuffled order) restores the empty and full dumps.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()

    group = module.Group.from_values(shared)
    running_score = 0.0
    assert_close(
        group.score_data(shared),
        running_score,
        err_msg='p(empty) != 1')

    sampled = []
    for _ in range(DATA_COUNT):
        draw = group.sample_value(shared)
        sampled.append(draw)
        running_score += group.score_value(shared, draw)
        group.add_value(shared, draw)

    # Snapshot of the fully populated group for the final comparison.
    full_group = module.Group.from_dict(group.dump())
    assert_close(
        running_score,
        group.score_data(shared),
        err_msg='p(x1,...,xn) != p(x1) p(x2|x1) p(xn|...)')

    numpy.random.shuffle(sampled)
    for draw in sampled:
        group.remove_value(shared, draw)

    empty_group = module.Group.from_values(shared)
    assert_close(
        group.dump(),
        empty_group.dump(),
        err_msg='group + values - values != group')

    numpy.random.shuffle(sampled)
    for draw in sampled:
        group.add_value(shared, draw)
    assert_close(
        group.dump(),
        full_group.dump(),
        err_msg='group - values + values != group')
266
|
|
|
|
267
|
|
|
|
268
|
|
|
@for_each_model()
def test_add_repeated(module, EXAMPLE):
    '''
    Check add_repeated_value(count=n) matches n separate add_value calls.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    for value in EXAMPLE['values']:
        one_by_one = module.Group.from_values(shared)
        for _ in range(DATA_COUNT):
            one_by_one.add_value(shared, value)

        all_at_once = module.Group.from_values(shared)
        all_at_once.add_repeated_value(shared, value, count=DATA_COUNT)

        assert_close(
            one_by_one.dump(),
            all_at_once.dump(),
            err_msg='n * add_value != add_repeated_value n')
284
|
|
|
|
285
|
|
|
|
286
|
|
|
@for_each_model()
def test_add_merge(module, EXAMPLE):
    '''
    Check that merging two groups matches one group built from all values.

    For every split point i, the values are reshuffled, split into
    values[:i] and values[i:], built into two groups and merged; the merged
    dump must match the dump of a group built from all values.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    # Copy so the in-place shuffles below do not reorder EXAMPLE['values'].
    values = EXAMPLE['values'][:]
    for value in values:
        shared.add_value(value)

    numpy.random.shuffle(values)
    group = module.Group.from_values(shared, values)

    # range works on both Python 2 and 3 (xrange is Python 2 only).
    for i in range(len(values) + 1):
        numpy.random.shuffle(values)
        group1 = module.Group.from_values(shared, values[:i])
        group2 = module.Group.from_values(shared, values[i:])
        group1.merge(shared, group2)
        assert_close(group.dump(), group1.dump())
303
|
|
|
|
304
|
|
|
|
305
|
|
|
@for_each_model()
def test_group_merge(module, EXAMPLE):
    '''
    Check Group.merge against a group that received every value directly.

    Values sampled from `expected` are added to `expected` and alternately to
    group1 and group2; after each pair, loading group1's dump into `actual`
    and merging group2 must reproduce expected's dump.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    group1 = module.Group.from_values(shared)
    group2 = module.Group.from_values(shared)
    expected = module.Group.from_values(shared)
    actual = module.Group.from_values(shared)
    # range works on both Python 2 and 3 (xrange is Python 2 only).
    for _ in range(100):
        value = expected.sample_value(shared)
        expected.add_value(shared, value)
        group1.add_value(shared, value)

        value = expected.sample_value(shared)
        expected.add_value(shared, value)
        group2.add_value(shared, value)

        actual.load(group1.dump())
        actual.merge(shared, group2)
        assert_close(actual.dump(), expected.dump())
325
|
|
|
|
326
|
|
|
|
327
|
|
|
@for_each_model(lambda module: module.Value in [bool, int])
def test_group_allows_debt(module, EXAMPLE):
    '''
    Check that a group tolerates removals interleaved before matching adds.

    Each sampled value is added three times and removed twice in shuffled
    order; the result must match a group that holds each value once.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    values = []
    group1 = module.Group.from_values(shared, values)
    for _ in range(DATA_COUNT):
        draw = group1.sample_value(shared)
        values.append(draw)
        group1.add_value(shared, draw)

    group2 = module.Group.from_values(shared)
    # Net effect per value: +3 - 2 = +1, matching group1.
    signed_values = (
        [(v, +1) for v in values] * 3 +
        [(v, -1) for v in values] * 2)
    numpy.random.shuffle(signed_values)
    for value, sign in signed_values:
        update = group2.add_value if sign > 0 else group2.remove_value
        update(shared, value)

    assert_close(group1.dump(), group2.dump())
351
|
|
|
|
352
|
|
|
|
353
|
|
|
@for_each_model()
def test_sample_seed(module, EXAMPLE):
    '''
    Check that sample_value is reproducible after seed_all(0).

    Two fresh groups, each created and sampled DATA_COUNT times right after
    reseeding, must produce matching value sequences.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])

    # range works on both Python 2 and 3 (xrange is Python 2 only).
    seed_all(0)
    group1 = module.Group.from_values(shared)
    values1 = [group1.sample_value(shared) for _ in range(DATA_COUNT)]

    seed_all(0)
    group2 = module.Group.from_values(shared)
    values2 = [group2.sample_value(shared) for _ in range(DATA_COUNT)]

    assert_close(values1, values2, err_msg='values')
366
|
|
|
|
367
|
|
|
|
368
|
|
|
@for_each_model()
def test_sample_value(module, EXAMPLE):
    '''
    Check that sample_value agrees with score_value via goodness of fit.

    Runs once for an empty group and once for a group holding the example
    values. The goodness-of-fit routine is chosen by module.Value (bool/int,
    float, or numpy.ndarray); other value types raise SkipTest.
    '''
    seed_all(0)
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    for values in [[], EXAMPLE['values']]:
        group = module.Group.from_values(shared, values)
        sample_count = SAMPLE_COUNT
        if module.Value == numpy.ndarray:
            # Array-valued models draw ten times as many samples.
            sample_count *= 10
        # range works on both Python 2 and 3 (xrange is Python 2 only).
        samples = [group.sample_value(shared) for _ in range(sample_count)]
        if module.Value in [bool, int]:
            probs_dict = {
                value: math.exp(group.score_value(shared, value))
                for value in set(samples)
            }
            gof = discrete_goodness_of_fit(samples, probs_dict, plot=True)
        elif module.Value == float:
            probs = numpy.exp([
                group.score_value(shared, value)
                for value in samples
            ])
            gof = density_goodness_of_fit(samples, probs, plot=True)
        elif module.Value == numpy.ndarray:
            if module.__name__ == 'distributions.lp.models.niw':
                raise SkipTest('FIXME known sampling bug')
            probs = numpy.exp([
                group.score_value(shared, value)
                for value in samples
            ])
            gof = vector_density_goodness_of_fit(samples, probs, plot=True)
        else:
            raise SkipTest('Not implemented for {}'.format(module.Value))
        # Single-argument call form prints the same on Python 2 and 3.
        print('{} gof = {:0.3g}'.format(module.__name__, gof))
        assert_greater(gof, MIN_GOODNESS_OF_FIT)
403
|
|
|
|
404
|
|
|
|
405
|
|
|
@for_each_model()
def test_sample_group(module, EXAMPLE):
    '''
    Check that module.sample_group agrees with score_data via goodness of fit.

    Only implemented for bool/int valued models; others raise SkipTest.
    '''
    seed_all(0)
    SIZE = 2
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    # NOTE(review): the loop variable is never read, so both passes perform
    # the same check; the two-pass structure is kept to preserve behaviour.
    for _unused_values in [[], EXAMPLE['values']]:
        if module.Value in [bool, int]:
            samples = []
            probs_dict = {}
            # range works on both Python 2 and 3 (xrange is Python 2 only).
            for _ in range(SAMPLE_COUNT):
                # Named distinctly so the outer loop variable is not rebound.
                sampled = module.sample_group(shared, SIZE)
                sample = tuple(sampled)
                samples.append(sample)
                group = module.Group.from_values(shared, sampled)
                probs_dict[sample] = math.exp(group.score_data(shared))
            gof = discrete_goodness_of_fit(samples, probs_dict, plot=True)
        else:
            raise SkipTest('Not implemented for {}'.format(module.Value))
        # Single-argument call form prints the same on Python 2 and 3.
        print('{} gof = {:0.3g}'.format(module.__name__, gof))
        assert_greater(gof, MIN_GOODNESS_OF_FIT)
426
|
|
|
|
427
|
|
|
|
428
|
|
|
def _append_ss(group, aggregator): |
429
|
|
|
ss = group.dump() |
430
|
|
|
for key, val in ss.iteritems(): |
431
|
|
|
if isinstance(val, list): |
432
|
|
|
for i, v in enumerate(val): |
433
|
|
|
aggregator['{}_{}'.format(key, i)].append(v) |
434
|
|
|
elif isinstance(val, dict): |
435
|
|
|
for k, v in val.iteritems(): |
436
|
|
|
aggregator['{}_{}'.format(key, k)].append(v) |
437
|
|
|
else: |
438
|
|
|
aggregator[key].append(val) |
439
|
|
|
|
440
|
|
|
|
441
|
|
|
def sample_marginal_conditional(module, shared, value_count):
    '''
    Return a Group built from value_count values from module.sample_group.
    '''
    return module.Group.from_values(
        shared,
        module.sample_group(shared, value_count))
445
|
|
|
|
446
|
|
|
|
447
|
|
|
def sample_successive_conditional(module, shared, group, value_count):
    '''
    Return a new Group built from values drawn by a Sampler on `group`.

    A module.Sampler is initialised with (shared, group), evaluated
    value_count times, and the drawn values are passed to Group.from_values.
    '''
    sampler = module.Sampler()
    sampler.init(shared, group)
    # range works on both Python 2 and 3 (xrange is Python 2 only).
    values = [sampler.eval(shared) for _ in range(value_count)]
    new_group = module.Group.from_values(shared, values)
    return new_group
453
|
|
|
|
454
|
|
|
|
455
|
|
|
@for_each_model(model_is_fast)
def test_joint(module, EXAMPLE):
    '''
    Compare marginal-conditional and successive-conditional samples.

    Dumped group statistics from both samplers are collected per key and
    compared with a two-sample t-test; each p-value must exceed
    MIN_GOODNESS_OF_FIT. Array-valued or non-finite p-values raise SkipTest.
    '''
    # \cite{geweke04getting}
    seed_all(0)
    SIZE = 10
    SKIP = 100
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    marginal_conditional_samples = defaultdict(list)
    successive_conditional_samples = defaultdict(list)
    cond_group = sample_marginal_conditional(module, shared, SIZE)
    # range works on both Python 2 and 3 (xrange is Python 2 only).
    for _ in range(SAMPLE_COUNT):
        marg_group = sample_marginal_conditional(module, shared, SIZE)
        _append_ss(marg_group, marginal_conditional_samples)

        # Advance the successive chain SKIP steps between recorded samples.
        for __ in range(SKIP):
            cond_group = sample_successive_conditional(
                module,
                shared,
                cond_group,
                SIZE)
        _append_ss(cond_group, successive_conditional_samples)
    for key in marginal_conditional_samples:
        # ttest_ind returns (statistic, pvalue); index 1 is the p-value.
        gof = scipy.stats.ttest_ind(
            marginal_conditional_samples[key],
            successive_conditional_samples[key])[1]
        if isinstance(gof, numpy.ndarray):
            raise SkipTest('XXX: handle array case, gof = {}'.format(gof))
        # Single-argument call form prints the same on Python 2 and 3.
        print('{}:{} gof = {:0.3g}'.format(module.__name__, key, gof))
        if not numpy.isfinite(gof):
            raise SkipTest('Test fails with gof = {}'.format(gof))
        assert_greater(gof, MIN_GOODNESS_OF_FIT)
487
|
|
|
|
488
|
|
|
|
489
|
|
|
@for_each_model(lambda module: hasattr(module.Shared, 'scorer_create'))
def test_scorer(module, EXAMPLE):
    '''
    Check that scorers agree with score_value on an empty group.

    For each example value, a scorer created without a group, a scorer
    created from an empty group, and the group's own score_value must all
    give close results.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    empty_group = module.Group.from_values(shared)
    default_scorer = shared.scorer_create()
    group_scorer = shared.scorer_create(empty_group)

    for value in EXAMPLE['values']:
        assert_all_close([
            shared.scorer_eval(default_scorer, value),
            shared.scorer_eval(group_scorer, value),
            empty_group.score_value(shared, value),
        ])
502
|
|
|
|
503
|
|
|
|
504
|
|
|
@for_each_model(lambda module: hasattr(module, 'Mixture'))
def test_mixture_runs(module, EXAMPLE):
    '''
    Smoke-test the Mixture interface of a model.

    Builds a mixture with one single-value group per example value, then
    exercises score_value, add_value, add_group, remove_value and
    remove_group, checking len(mixture) after groups are added and removed.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']

    mixture = module.Mixture()
    for value in values:
        shared.add_value(value)
        mixture.append(module.Group.from_values(shared, [value]))
    mixture.init(shared)

    # Assign every value to a group sampled from its score distribution.
    groupids = []
    for value in values:
        scores = numpy.zeros(len(mixture), dtype=numpy.float32)
        mixture.score_value(shared, value, scores)
        probs = scores_to_probs(scores)
        groupid = sample_discrete(probs)
        mixture.add_value(shared, groupid, value)
        groupids.append(groupid)

    mixture.add_group(shared)
    assert len(mixture) == len(values) + 1

    for value, groupid in zip(values, groupids):
        mixture.remove_value(shared, groupid, value)

    mixture.remove_group(shared, 0)
    if module.__name__ == 'distributions.lp.models.dpd':
        raise SkipTest('FIXME known segfault here')
    mixture.remove_group(shared, len(mixture) - 1)
    assert len(mixture) == len(values) - 1

    # The mixture must still accept values after groups were removed.
    for value in values:
        scores = numpy.zeros(len(mixture), dtype=numpy.float32)
        mixture.score_value(shared, value, scores)
        probs = scores_to_probs(scores)
        groupid = sample_discrete(probs)
        mixture.add_value(shared, groupid, value)
543
|
|
|
|
544
|
|
|
|
545
|
|
|
@for_each_model(lambda module: hasattr(module, 'Mixture'))
def test_mixture_score(module, EXAMPLE):
    '''
    Check that Mixture scoring matches scoring each Group individually.

    A mixture and a parallel list of groups receive the same add/remove
    operations; score_value, score_value_group and score_data on the mixture
    are compared against the per-group results.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']
    for value in values:
        shared.add_value(value)

    groups = [module.Group.from_values(shared, [value]) for value in values]
    mixture = module.Mixture()
    for group in groups:
        mixture.append(group)
    mixture.init(shared)

    def check_score_value(value):
        '''Compare mixture scores for value with per-group scores.'''
        expected = [group.score_value(shared, value) for group in groups]
        actual = numpy.zeros(len(mixture), dtype=numpy.float32)
        # Pre-fill the output with noise and subtract it afterwards; the
        # comparison below then requires score_value to have added the
        # scores into the buffer.
        noise = numpy.random.randn(len(actual))
        actual += noise
        mixture.score_value(shared, value, actual)
        actual -= noise
        assert_close(actual, expected, err_msg='score_value {}'.format(value))
        # range works on both Python 2 and 3 (xrange is Python 2 only).
        another = [
            mixture.score_value_group(shared, i, value)
            for i in range(len(groups))
        ]
        assert_close(
            another,
            expected,
            err_msg='score_value_group {}'.format(value))
        return actual

    def check_score_data():
        '''Compare mixture.score_data with the sum over groups.'''
        expected = sum(group.score_data(shared) for group in groups)
        actual = mixture.score_data(shared)
        assert_close(actual, expected, err_msg='score_data')

    # Single-argument call form prints the same on Python 2 and 3.
    print('init')
    for value in values:
        check_score_value(value)
    check_score_data()

    print('adding')
    groupids = []
    for value in values:
        scores = check_score_value(value)
        probs = scores_to_probs(scores)
        groupid = sample_discrete(probs)
        groups[groupid].add_value(shared, value)
        mixture.add_value(shared, groupid, value)
        groupids.append(groupid)
        check_score_data()

    print('removing')
    for value, groupid in zip(values, groupids):
        groups[groupid].remove_value(shared, value)
        mixture.remove_value(shared, groupid, value)
        check_score_value(value)
        check_score_data()
603
|
|
|
|