GitHub Access Token became invalid

It seems like the GitHub access token used for retrieving details about this repository from GitHub became invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.

test_mixture_runs()   F
last analyzed

Complexity

Conditions 9

Size

Total Lines 39

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 9
dl 0
loc 39
rs 3
c 2
b 0
f 0
1
# Copyright (c) 2014, Salesforce.com, Inc.  All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or without
4
# modification, are permitted provided that the following conditions
5
# are met:
6
#
7
# - Redistributions of source code must retain the above copyright
8
#   notice, this list of conditions and the following disclaimer.
9
# - Redistributions in binary form must reproduce the above copyright
10
#   notice, this list of conditions and the following disclaimer in the
11
#   documentation and/or other materials provided with the distribution.
12
# - Neither the name of Salesforce.com nor the names of its contributors
13
#   may be used to endorse or promote products derived from this
14
#   software without specific prior written permission.
15
#
16
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
19
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
20
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
21
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
22
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
23
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
25
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
26
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27
28
import math
29
import numpy
30
import numpy.random
31
import scipy.stats
32
import functools
33
from collections import defaultdict
34
from nose import SkipTest
35
from nose.tools import assert_greater
36
from nose.tools import assert_in
37
from nose.tools import assert_is_instance
38
from nose.tools import assert_not_equal
39
from nose.tools import assert_true
40
from goftests import density_goodness_of_fit
41
from goftests import discrete_goodness_of_fit
42
from goftests import vector_density_goodness_of_fit
43
from distributions.dbg.random import sample_discrete
44
from distributions.util import scores_to_probs
45
from distributions.tests.util import assert_all_close
46
from distributions.tests.util import assert_close
47
from distributions.tests.util import assert_hasattr
48
from distributions.tests.util import import_model
49
from distributions.tests.util import list_models
50
from distributions.tests.util import seed_all
51
52
# Protobuf support is optional: record whether the schema module imported so
# that test_protobuf can raise SkipTest instead of failing on a missing module.
try:
    import distributions.io.schema_pb2
    has_protobuf = True
except ImportError:
    has_protobuf = False
57
58
# Number of values sampled/added per group in the add/remove style tests.
DATA_COUNT = 20
# Number of samples drawn for each goodness-of-fit check.
SAMPLE_COUNT = 1000
# Goodness-of-fit values must be strictly greater than this to pass.
MIN_GOODNESS_OF_FIT = 1e-3

# Every model reported by list_models(), keyed by '<flavor>.models.<name>'.
MODULES = {
    '{flavor}.models.{name}'.format(**spec): import_model(spec)
    for spec in list_models()
}

# Per-flavor flag looked up by model_is_fast().
IS_FAST = {'dbg': False, 'hp': True, 'lp': True}
68
69
70
def model_is_fast(model):
    '''
    Return the IS_FAST flag for the flavor of a model module.

    The flavor is taken to be the second dotted component of
    model.__name__ (e.g. 'lp' in 'distributions.lp.models.niw').
    Raises KeyError if that component is not a key of IS_FAST.
    '''
    flavor = model.__name__.split('.')[1]
    return IS_FAST[flavor]
73
74
75
def iter_examples(module):
    '''
    Validate and yield each entry of module.EXAMPLES.

    Asserts that module.EXAMPLES is a non-empty list and that every entry
    has 'shared' and 'values' keys, with 'values' being a list of at least
    7 items.  A progress line is printed before each example is checked.
    '''
    assert_hasattr(module, 'EXAMPLES')
    EXAMPLES = module.EXAMPLES
    assert_is_instance(EXAMPLES, list)
    assert_true(EXAMPLES, 'no examples provided')
    for i, EXAMPLE in enumerate(EXAMPLES):
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('example {}/{}'.format(1 + i, len(EXAMPLES)))
        assert_in('shared', EXAMPLE)
        assert_in('values', EXAMPLE)
        values = EXAMPLE['values']
        assert_is_instance(values, list)
        count = len(values)
        assert_true(
            count >= 7,
            'Add more example values (expected >= 7, found {})'.format(count))
        yield EXAMPLE
91
92
93
def for_each_model(*filters):
    '''
    Run one test per Model, filtering out inappropriate Models for test.

    Each filter is a predicate called with a model module; a module is
    tested only if every filter returns a true value (no filters means
    every module in MODULES is tested).

    The decorated function must accept (module, EXAMPLE).  It is replaced
    by a generator that yields (test_one_model, name) pairs, one per
    selected module; test_one_model then calls the original function once
    for every example produced by iter_examples(module).
    '''
    def filtered(test_fun):

        @functools.wraps(test_fun)
        def test_one_model(name):
            module = MODULES[name]
            assert_hasattr(module, 'Shared')
            for EXAMPLE in iter_examples(module):
                test_fun(module, EXAMPLE)

        @functools.wraps(test_fun)
        def test_all_models():
            for name in MODULES:
                module = MODULES[name]
                if all(f(module) for f in filters):
                    yield test_one_model, name

        return test_all_models
    return filtered
115
116
117
@for_each_model()
def test_value(module, EXAMPLE):
    '''
    Check that module.Value is a type and every example value is one.
    '''
    assert_hasattr(module, 'Value')
    assert_is_instance(module.Value, type)

    values = EXAMPLE['values']
    for value in values:
        assert_is_instance(value, module.Value)
125
126
127
@for_each_model()
def test_shared(module, EXAMPLE):
    '''
    Check Shared load/dump round-tripping and add_value/remove_value.

    Two Shared objects built from the same dict must dump identically
    after the same values are added under the same seed, and removing
    those values again must restore the original dump.
    '''
    assert_hasattr(module, 'Shared')
    assert_is_instance(module.Shared, type)

    shared1 = module.Shared.from_dict(EXAMPLE['shared'])
    shared2 = module.Shared.from_dict(EXAMPLE['shared'])
    assert_close(shared1.dump(), EXAMPLE['shared'])

    values = EXAMPLE['values']
    # Re-seed before each pass so both objects see the same random state.
    seed_all(0)
    for value in values:
        shared1.add_value(value)
    seed_all(0)
    for value in values:
        shared2.add_value(value)
    assert_close(shared1.dump(), shared2.dump())

    for value in values:
        shared1.remove_value(value)
    assert_close(shared1.dump(), EXAMPLE['shared'])
148
149
150
@for_each_model()
def test_group(module, EXAMPLE):
    '''
    Exercise the basic Group API: init, add/remove, dump/load, merge,
    scoring and sampling.
    '''
    assert_hasattr(module, 'Group')
    assert_is_instance(module.Group, type)

    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']
    for value in values:
        shared.add_value(value)

    # Incremental construction must agree with from_values.
    group1 = module.Group()
    group1.init(shared)
    for value in values:
        group1.add_value(shared, value)
    group2 = module.Group.from_values(shared, values)
    assert_close(group1.dump(), group2.dump())

    # dump -> init -> load must round-trip.
    group = module.Group.from_values(shared, values)
    dumped = group.dump()
    group.init(shared)
    group.load(dumped)
    assert_close(group.dump(), dumped)

    for value in values:
        group2.remove_value(shared, value)
    assert_not_equal(group1, group2)
    group2.merge(shared, group1)

    # The remaining calls are smoke tests; their results are not checked.
    for value in values:
        group1.score_value(shared, value)
    for _ in range(10):
        value = group1.sample_value(shared)
        group1.score_value(shared, value)
        module.sample_group(shared, 10)
    group1.score_data(shared)
    group2.score_data(shared)
186
187
188
@for_each_model(lambda module: hasattr(module.Shared, 'protobuf_load'))
def test_protobuf(module, EXAMPLE):
    '''
    Check protobuf round-tripping of Shared and, when supported, Group.

    Skipped when the protobuf schema module could not be imported.
    '''
    if not has_protobuf:
        raise SkipTest('protobuf not available')
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']
    # Message container for this model, looked up by module.NAME.
    Message = getattr(distributions.io.schema_pb2, module.NAME)

    # Instance-level dump/load.
    message = Message.Shared()
    shared.protobuf_dump(message)
    shared2 = module.Shared()
    shared2.protobuf_load(message)
    assert_close(shared2.dump(), shared.dump())

    # Dict-level to_protobuf/from_protobuf.
    message.Clear()
    dumped = shared.dump()
    module.Shared.to_protobuf(dumped, message)
    assert_close(module.Shared.from_protobuf(message), dumped)

    if hasattr(module.Group, 'protobuf_load'):
        for value in values:
            shared.add_value(value)
        group = module.Group.from_values(shared, values)

        message = Message.Group()
        group.protobuf_dump(message)
        group2 = module.Group()
        group2.protobuf_load(message)
        assert_close(group2.dump(), group.dump())

        message.Clear()
        dumped = group.dump()
        module.Group.to_protobuf(dumped, message)
        assert_close(module.Group.from_protobuf(message), dumped)
222
223
224
@for_each_model()
def test_add_remove(module, EXAMPLE):
    '''
    Test group add_value, remove_value, score_data and score_value.

    Checks that an empty group scores 0.0, that the sum of sequential
    score_value results equals score_data, and that adding and removing
    the same values in shuffled order restores the group's dump.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()

    values = []
    group = module.Group.from_values(shared)
    score = 0.0
    assert_close(group.score_data(shared), score, err_msg='p(empty) != 1')

    for _ in range(DATA_COUNT):
        value = group.sample_value(shared)
        values.append(value)
        # Score before adding: this is the predictive term p(x_n | x_<n).
        score += group.score_value(shared, value)
        group.add_value(shared, value)

    # Snapshot of the fully populated group for the final comparison.
    group_all = module.Group.from_dict(group.dump())
    assert_close(
        score,
        group.score_data(shared),
        err_msg='p(x1,...,xn) != p(x1) p(x2|x1) p(xn|...)')

    numpy.random.shuffle(values)

    for value in values:
        group.remove_value(shared, value)

    group_empty = module.Group.from_values(shared)
    assert_close(
        group.dump(),
        group_empty.dump(),
        err_msg='group + values - values != group')

    numpy.random.shuffle(values)
    for value in values:
        group.add_value(shared, value)
    assert_close(
        group.dump(),
        group_all.dump(),
        err_msg='group - values + values != group')
266
267
268
@for_each_model()
def test_add_repeated(module, EXAMPLE):
    '''
    Test that add_repeated_value(count=n) matches n calls to add_value.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    for value in EXAMPLE['values']:
        group = module.Group.from_values(shared)
        for _ in range(DATA_COUNT):
            group.add_value(shared, value)

        group_repeated = module.Group.from_values(shared)
        group_repeated.add_repeated_value(shared, value, count=DATA_COUNT)
        assert_close(
            group.dump(),
            group_repeated.dump(),
            err_msg='n * add_value != add_repeated_value n')
284
285
286
@for_each_model()
def test_add_merge(module, EXAMPLE):
    '''
    Test that merging two groups matches building one group directly.

    For every split point i, groups built from values[:i] and values[i:]
    are merged and compared against a group built from all the values.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    # Copy so the in-place shuffles below do not reorder EXAMPLE['values'].
    values = EXAMPLE['values'][:]
    for value in values:
        shared.add_value(value)

    numpy.random.shuffle(values)
    group = module.Group.from_values(shared, values)

    for i in range(len(values) + 1):
        numpy.random.shuffle(values)
        group1 = module.Group.from_values(shared, values[:i])
        group2 = module.Group.from_values(shared, values[i:])
        group1.merge(shared, group2)
        assert_close(group.dump(), group1.dump())
303
304
305
@for_each_model()
def test_group_merge(module, EXAMPLE):
    '''
    Test merge against a reference group that receives every value.

    On each of 100 rounds one sampled value goes to group1 and another to
    group2, both also being added to `expected`; merging group2 into a
    copy of group1 must then reproduce `expected`.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    group1 = module.Group.from_values(shared)
    group2 = module.Group.from_values(shared)
    expected = module.Group.from_values(shared)
    actual = module.Group.from_values(shared)
    for _ in range(100):
        value = expected.sample_value(shared)
        expected.add_value(shared, value)
        group1.add_value(shared, value)

        value = expected.sample_value(shared)
        expected.add_value(shared, value)
        group2.add_value(shared, value)

        # Reload from group1's dump so group1 itself is left untouched.
        actual.load(group1.dump())
        actual.merge(shared, group2)
        assert_close(actual.dump(), expected.dump())
325
326
327
@for_each_model(lambda module: module.Value in [bool, int])
def test_group_allows_debt(module, EXAMPLE):
    '''
    Test that group.add_value can safely go into data debt.

    Each sampled value is added three times and removed twice in a random
    order, so removals may come before the matching additions; the net
    result must equal a group to which each value was added once.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    values = []
    group1 = module.Group.from_values(shared, values)
    for _ in range(DATA_COUNT):
        value = group1.sample_value(shared)
        values.append(value)
        group1.add_value(shared, value)

    group2 = module.Group.from_values(shared)
    pos_values = [(v, +1) for v in values]
    neg_values = [(v, -1) for v in values]
    # Net effect per value: +3 - 2 = +1.
    signed_values = pos_values * 3 + neg_values * 2
    numpy.random.shuffle(signed_values)
    for value, sign in signed_values:
        if sign > 0:
            group2.add_value(shared, value)
        else:
            group2.remove_value(shared, value)

    assert_close(group1.dump(), group2.dump())
351
352
353
@for_each_model()
def test_sample_seed(module, EXAMPLE):
    '''
    Test that sample_value is reproducible after seed_all(0).
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])

    seed_all(0)
    group1 = module.Group.from_values(shared)
    values1 = [group1.sample_value(shared) for _ in range(DATA_COUNT)]

    seed_all(0)
    group2 = module.Group.from_values(shared)
    values2 = [group2.sample_value(shared) for _ in range(DATA_COUNT)]

    assert_close(values1, values2, err_msg='values')
366
367
368
@for_each_model()
def test_sample_value(module, EXAMPLE):
    '''
    Goodness-of-fit test of sample_value against score_value.

    Run once for an empty group and once for a group holding the example
    values.  The goodness-of-fit routine is chosen by module.Value (bool/int,
    float or numpy.ndarray); any other Value type is skipped.
    '''
    seed_all(0)
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    for values in [[], EXAMPLE['values']]:
        group = module.Group.from_values(shared, values)
        sample_count = SAMPLE_COUNT
        if module.Value == numpy.ndarray:
            # Vector-valued models get ten times as many samples.
            sample_count *= 10
        samples = [group.sample_value(shared) for _ in range(sample_count)]
        if module.Value in [bool, int]:
            probs_dict = {
                value: math.exp(group.score_value(shared, value))
                for value in set(samples)
            }
            gof = discrete_goodness_of_fit(samples, probs_dict, plot=True)
        elif module.Value == float:
            probs = numpy.exp([
                group.score_value(shared, value)
                for value in samples
            ])
            gof = density_goodness_of_fit(samples, probs, plot=True)
        elif module.Value == numpy.ndarray:
            if module.__name__ == 'distributions.lp.models.niw':
                raise SkipTest('FIXME known sampling bug')
            probs = numpy.exp([
                group.score_value(shared, value)
                for value in samples
            ])
            gof = vector_density_goodness_of_fit(samples, probs, plot=True)
        else:
            raise SkipTest('Not implemented for {}'.format(module.Value))
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('{} gof = {:0.3g}'.format(module.__name__, gof))
        assert_greater(gof, MIN_GOODNESS_OF_FIT)
403
404
405
@for_each_model()
def test_sample_group(module, EXAMPLE):
    '''
    Goodness-of-fit test of module.sample_group against score_data.

    Only implemented for models whose Value is bool or int; any other
    Value type is skipped.
    '''
    seed_all(0)
    SIZE = 2
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    # NOTE(review): the loop element is never used, so the same check simply
    # runs twice.  The original code also rebound the loop variable inside
    # the inner loop; the sampled values now have their own name.  Confirm
    # whether conditioning on EXAMPLE['values'] was intended here.
    for _unused_values in [[], EXAMPLE['values']]:
        if module.Value in [bool, int]:
            samples = []
            probs_dict = {}
            for _ in range(SAMPLE_COUNT):
                sampled = module.sample_group(shared, SIZE)
                sample = tuple(sampled)
                samples.append(sample)
                group = module.Group.from_values(shared, sampled)
                probs_dict[sample] = math.exp(group.score_data(shared))
            gof = discrete_goodness_of_fit(samples, probs_dict, plot=True)
        else:
            raise SkipTest('Not implemented for {}'.format(module.Value))
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('{} gof = {:0.3g}'.format(module.__name__, gof))
        assert_greater(gof, MIN_GOODNESS_OF_FIT)
426
427
428
def _append_ss(group, aggregator):
    '''
    Append each entry of group.dump() to the matching list in aggregator.

    Scalar entries are appended under their own key.  List entries are
    flattened to '<key>_<index>' and dict entries to '<key>_<subkey>',
    one aggregator key per element.  aggregator must create missing
    lists on lookup (e.g. a defaultdict of lists).
    '''
    ss = group.dump()
    # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
    for key, val in ss.items():
        if isinstance(val, list):
            for i, v in enumerate(val):
                aggregator['{}_{}'.format(key, i)].append(v)
        elif isinstance(val, dict):
            for k, v in val.items():
                aggregator['{}_{}'.format(key, k)].append(v)
        else:
            aggregator[key].append(val)
439
440
441
def sample_marginal_conditional(module, shared, value_count):
    '''
    Return a Group built from value_count values drawn by
    module.sample_group(shared, value_count).
    '''
    values = module.sample_group(shared, value_count)
    group = module.Group.from_values(shared, values)
    return group
445
446
447
def sample_successive_conditional(module, shared, group, value_count):
    '''
    Return a new Group of value_count values drawn from a Sampler
    initialised on `group`.

    The input group is only passed to sampler.init; the returned group is
    built from the freshly sampled values alone.
    '''
    sampler = module.Sampler()
    sampler.init(shared, group)
    values = [sampler.eval(shared) for _ in range(value_count)]
    new_group = module.Group.from_values(shared, values)
    return new_group
453
454
455
@for_each_model(model_is_fast)
def test_joint(module, EXAMPLE):
    '''
    Compare two ways of sampling group statistics with a t-test.

    One chain draws fresh groups with sample_marginal_conditional; the
    other repeatedly resamples a group with sample_successive_conditional,
    keeping every SKIP-th state.  For each dumped statistic the two sample
    sets are compared with scipy.stats.ttest_ind and the p-value must
    exceed MIN_GOODNESS_OF_FIT.  Array-valued or non-finite p-values cause
    the test to be skipped.
    '''
    # \cite{geweke04getting}
    seed_all(0)
    SIZE = 10
    SKIP = 100
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    shared.realize()
    marginal_conditional_samples = defaultdict(list)
    successive_conditional_samples = defaultdict(list)
    cond_group = sample_marginal_conditional(module, shared, SIZE)
    for _ in range(SAMPLE_COUNT):
        marg_group = sample_marginal_conditional(module, shared, SIZE)
        _append_ss(marg_group, marginal_conditional_samples)

        for __ in range(SKIP):
            cond_group = sample_successive_conditional(
                module,
                shared,
                cond_group,
                SIZE)
        _append_ss(cond_group, successive_conditional_samples)
    for key in marginal_conditional_samples:
        # Element [1] of the ttest_ind result is the p-value.
        gof = scipy.stats.ttest_ind(
            marginal_conditional_samples[key],
            successive_conditional_samples[key])[1]
        if isinstance(gof, numpy.ndarray):
            raise SkipTest('XXX: handle array case, gof = {}'.format(gof))
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('{}:{} gof = {:0.3g}'.format(module.__name__, key, gof))
        if not numpy.isfinite(gof):
            raise SkipTest('Test fails with gof = {}'.format(gof))
        assert_greater(gof, MIN_GOODNESS_OF_FIT)
487
488
489
@for_each_model(lambda module: hasattr(module.Shared, 'scorer_create'))
def test_scorer(module, EXAMPLE):
    '''
    Test that scorers agree with Group.score_value on an empty group.

    A scorer created without a group, a scorer created from an empty
    group, and the empty group itself must give the same score for every
    example value.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']

    group = module.Group.from_values(shared)
    scorer1 = shared.scorer_create()
    scorer2 = shared.scorer_create(group)
    for value in values:
        score1 = shared.scorer_eval(scorer1, value)
        score2 = shared.scorer_eval(scorer2, value)
        score3 = group.score_value(shared, value)
        assert_all_close([score1, score2, score3])
502
503
504
@for_each_model(lambda module: hasattr(module, 'Mixture'))
def test_mixture_runs(module, EXAMPLE):
    '''
    Smoke-test the Mixture API: append, init, score_value, add_value,
    remove_value, add_group and remove_group.

    Only the mixture length is asserted; scores are used solely to pick
    a group for each value.
    '''
    shared = module.Shared.from_dict(EXAMPLE['shared'])
    values = EXAMPLE['values']

    # Start with one single-value group per example value.
    mixture = module.Mixture()
    for value in values:
        shared.add_value(value)
        mixture.append(module.Group.from_values(shared, [value]))
    mixture.init(shared)

    groupids = []
    for value in values:
        scores = numpy.zeros(len(mixture), dtype=numpy.float32)
        mixture.score_value(shared, value, scores)
        probs = scores_to_probs(scores)
        groupid = sample_discrete(probs)
        mixture.add_value(shared, groupid, value)
        groupids.append(groupid)

    mixture.add_group(shared)
    assert len(mixture) == len(values) + 1

    for value, groupid in zip(values, groupids):
        mixture.remove_value(shared, groupid, value)

    mixture.remove_group(shared, 0)
    if module.__name__ == 'distributions.lp.models.dpd':
        raise SkipTest('FIXME known segfault here')
    mixture.remove_group(shared, len(mixture) - 1)
    # One group was added and two were removed.
    assert len(mixture) == len(values) - 1

    for value in values:
        scores = numpy.zeros(len(mixture), dtype=numpy.float32)
        mixture.score_value(shared, value, scores)
        probs = scores_to_probs(scores)
        groupid = sample_discrete(probs)
        mixture.add_value(shared, groupid, value)
544
545
@for_each_model(lambda module: hasattr(module, 'Mixture'))
546
def test_mixture_score(module, EXAMPLE):
547
    shared = module.Shared.from_dict(EXAMPLE['shared'])
548
    values = EXAMPLE['values']
549
    for value in values:
550
        shared.add_value(value)
551
552
    groups = [module.Group.from_values(shared, [value]) for value in values]
553
    mixture = module.Mixture()
554
    for group in groups:
555
        mixture.append(group)
556
    mixture.init(shared)
557
558
    def check_score_value(value):
559
        expected = [group.score_value(shared, value) for group in groups]
560
        actual = numpy.zeros(len(mixture), dtype=numpy.float32)
561
        noise = numpy.random.randn(len(actual))
562
        actual += noise
563
        mixture.score_value(shared, value, actual)
564
        actual -= noise
565
        assert_close(actual, expected, err_msg='score_value {}'.format(value))
566
        another = [
567
            mixture.score_value_group(shared, i, value)
568
            for i in xrange(len(groups))
569
        ]
570
        assert_close(
571
            another,
572
            expected,
573
            err_msg='score_value_group {}'.format(value))
574
        return actual
575
576
    def check_score_data():
577
        expected = sum(group.score_data(shared) for group in groups)
578
        actual = mixture.score_data(shared)
579
        assert_close(actual, expected, err_msg='score_data')
580
581
    print 'init'
582
    for value in values:
583
        check_score_value(value)
584
    check_score_data()
585
586
    print 'adding'
587
    groupids = []
588
    for value in values:
589
        scores = check_score_value(value)
590
        probs = scores_to_probs(scores)
591
        groupid = sample_discrete(probs)
592
        groups[groupid].add_value(shared, value)
593
        mixture.add_value(shared, groupid, value)
594
        groupids.append(groupid)
595
        check_score_data()
596
597
    print 'removing'
598
    for value, groupid in zip(values, groupids):
599
        groups[groupid].remove_value(shared, value)
600
        mixture.remove_value(shared, groupid, value)
601
        scores = check_score_value(value)
602
        check_score_data()
603