GitHub Access Token became invalid

It seems that the GitHub access token used for retrieving details about this repository from GitHub has become invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.

assert_close()   F
last analyzed

Complexity

Conditions 11

Size

Total Lines 46

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 11
dl 0
loc 46
rs 3.1764
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex classes like assert_close() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Copyright (c) 2014, Salesforce.com, Inc.  All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or without
4
# modification, are permitted provided that the following conditions
5
# are met:
6
#
7
# - Redistributions of source code must retain the above copyright
8
#   notice, this list of conditions and the following disclaimer.
9
# - Redistributions in binary form must reproduce the above copyright
10
#   notice, this list of conditions and the following disclaimer in the
11
#   documentation and/or other materials provided with the distribution.
12
# - Neither the name of Salesforce.com nor the names of its contributors
13
#   may be used to endorse or promote products derived from this
14
#   software without specific prior written permission.
15
#
16
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
19
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE
20
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
21
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
22
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
23
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
25
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
26
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27
28
import os
29
import glob
30
from collections import defaultdict
31
from itertools import izip
32
import math
33
import numpy
34
from numpy.testing import assert_array_almost_equal
35
from nose import SkipTest
36
from nose.tools import assert_true, assert_less, assert_equal
37
import importlib
38
import distributions
39
from goftests import multinomial_goodness_of_fit
40
41
# Parent of the directory that contains this file; list_models() globs
# for model files beneath it.
ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
42
# Default tolerance used by assert_close().
TOL = 1e-3
43
44
45
def require_cython():
    '''
    Skip the calling test when distributions.has_cython is falsy.

    Raises SkipTest('no cython support') in that case and returns None
    otherwise.
    '''
    if distributions.has_cython:
        return
    raise SkipTest('no cython support')
48
49
50
def seed_all(s):
    '''
    Seed the random modules of the distributions package with s.

    distributions.dbg.random is always seeded.  distributions.hp.random is
    seeded as well when it can be imported; an ImportError raised while
    importing or seeding it is ignored.
    '''
    importlib.import_module('distributions.dbg.random').seed(s)
    try:
        importlib.import_module('distributions.hp.random').seed(s)
    except ImportError:
        # distributions.hp.random is not importable here; only dbg is seeded.
        pass
58
59
60
def list_models():
    '''
    Generate a spec dict for every importable model module.

    Globs ROOT/<flavor>/models/ for files matching '*.p*', ignores file
    names that start with an underscore, and yields
    {'flavor': flavor, 'name': name} in sorted (name, flavor) order.
    A module that raises ImportError is not yielded; its dotted name and
    the traceback are printed instead.
    '''
    pattern = os.path.join(ROOT, '*', 'models', '*.p*')
    found = set()
    for path in glob.glob(pattern):
        models_dir, basename = os.path.split(path)
        stem = os.path.splitext(basename)[0]
        # A leading underscore also covers dunder files such as __init__.
        if stem.startswith('_'):
            continue
        flavor_dir = os.path.dirname(models_dir)
        found.add((stem, os.path.basename(flavor_dir)))

    for stem, flavor in sorted(found):
        spec = {'flavor': flavor, 'name': stem}
        try:
            import_model(spec)
            yield spec
        except ImportError:
            dotted = 'distributions.{flavor}.models.{name}'.format(**spec)
            print('failed to import {}'.format(dotted))
            import traceback
            print(traceback.format_exc())
80
81
82
def import_model(spec):
    '''
    Import and return the module distributions.<flavor>.models.<name>.

    Inputs:
        - spec is a mapping with 'flavor' and 'name' entries.
    '''
    return importlib.import_module(
        'distributions.{flavor}.models.{name}'.format(**spec))
85
86
87
def assert_hasattr(thing, attr):
    '''
    Assert that thing has an attribute named attr.

    Inputs:
        - thing is the object to inspect.
        - attr is the attribute name to look for.

    The failure message identifies thing by its __name__ when it has one
    and by its repr otherwise, so objects without a __name__ (for example
    plain instances) can be checked without raising AttributeError.
    '''
    label = getattr(thing, '__name__', None)
    if label is None:
        label = repr(thing)
    assert_true(
        hasattr(thing, attr),
        "{} is missing attribute '{}'".format(label, attr))
91
92
93
def print_short(x, size=64):
    '''
    Return str(x), abbreviated when it is longer than size characters.

    A string longer than size is cut to its first size - 3 characters and
    '...' is appended; shorter strings are returned unchanged.  Nothing is
    printed despite the name.
    '''
    text = str(x)
    return text if len(text) <= size else text[:size - 3] + '...'
98
99
100
def assert_close(lhs, rhs, tol=TOL, err_msg=None):
    '''
    Assert that two possibly nested values agree up to a tolerance.

    The comparison is chosen from the type of the arguments:
        - dict: rhs must be a dict with the same key set; values are
          compared recursively.
        - float or numpy.float64: rhs must be one too, and
          |lhs - rhs| < tol * (1 + |lhs| + |rhs|) must hold.
        - numpy.ndarray on either side: each side must be an ndarray or a
          list; compared with assert_array_almost_equal using
          round(-log10(tol)) decimals.
        - list or tuple: rhs must be a list or tuple of the same length;
          items are compared recursively.
        - anything else: compared with assert_equal.

    Inputs:
        - lhs is the actual value.
        - rhs is the expected value.
        - tol is the tolerance described above.
        - err_msg is an optional prefix for failure messages; recursive
          calls extend it with '[key]' or '[position]'.

    On any failure err_msg and shortened forms of lhs and rhs are printed
    and the exception is re-raised, once per level of nesting.
    '''
    try:
        if isinstance(lhs, dict):
            assert_true(
                isinstance(rhs, dict),
                'type mismatch: {} vs {}'.format(type(lhs), type(rhs)))
            assert_equal(set(lhs.keys()), set(rhs.keys()))
            for key, val in lhs.items():
                msg = '{}[{}]'.format(err_msg or '', key)
                assert_close(val, rhs[key], tol, msg)
        elif isinstance(lhs, (float, numpy.float64)):
            assert_true(
                isinstance(rhs, (float, numpy.float64)),
                'type mismatch: {} vs {}'.format(type(lhs), type(rhs)))
            diff = abs(lhs - rhs)
            # The 1 + ... keeps the bound meaningful when both values are
            # near zero.
            norm = 1 + abs(lhs) + abs(rhs)
            msg = '{} off by {}% = {}'.format(
                err_msg or '',
                100 * diff / norm,
                diff)
            assert_less(diff, tol * norm, msg)
        elif isinstance(lhs, numpy.ndarray) or isinstance(rhs, numpy.ndarray):
            assert_true(
                isinstance(lhs, (numpy.ndarray, list)) and
                isinstance(rhs, (numpy.ndarray, list)),
                'type mismatch: {} vs {}'.format(type(lhs), type(rhs)))
            decimal = int(round(-math.log10(tol)))
            assert_array_almost_equal(
                lhs,
                rhs,
                decimal=decimal,
                err_msg=(err_msg or ''))
        elif isinstance(lhs, (list, tuple)):
            assert_true(
                isinstance(rhs, (list, tuple)),
                'type mismatch: {} vs {}'.format(type(lhs), type(rhs)))
            # Pairwise iteration stops at the shorter sequence, so compare
            # lengths first; otherwise extra or missing trailing items
            # would go unnoticed.
            assert_equal(
                len(lhs),
                len(rhs),
                '{} length mismatch: {} vs {}'.format(
                    err_msg or '', len(lhs), len(rhs)))
            for pos, (x, y) in enumerate(zip(lhs, rhs)):
                msg = '{}[{}]'.format(err_msg or '', pos)
                assert_close(x, y, tol, msg)
        else:
            assert_equal(lhs, rhs, err_msg)
    except Exception:
        print(err_msg or '')
        print('actual = {}'.format(print_short(lhs)))
        print('expected = {}'.format(print_short(rhs)))
        raise
146
147
148
def assert_all_close(collection, **kwargs):
    '''
    Assert that every pair of items in collection passes assert_close.

    Each item is compared with every item that follows it; kwargs are
    forwarded to assert_close unchanged.  A collection with fewer than two
    items passes trivially.
    '''
    leading = collection[:-1]
    for offset, left in enumerate(leading, 1):
        for right in collection[offset:]:
            assert_close(left, right, **kwargs)
152
153
154
def collect_samples_and_scores(sampler, total_count=10000):
    '''
    Draw samples and average the probability reported for each of them.

    Inputs:
        - sampler is called with no arguments and returns a
          (sample, prob) pair.  samples must be hashable.
        - total_count is the number of times sampler is called.

    Returns:
        - counts : key -> int, how often each sample was drawn
        - probs : key -> float, mean of the probs returned with that sample

    Fails via assert_close(..., tol=1e-2) unless the averaged probs sum to
    approximately 1.0.
    '''
    counts = defaultdict(int)
    probs = defaultdict(float)
    for _ in range(total_count):
        sample, prob = sampler()
        counts[sample] += 1
        probs[sample] += prob

    # Turn the per-sample sums into per-sample means.
    for key in counts:
        probs[key] /= counts[key]

    total_prob = sum(probs.values())
    assert_close(total_prob, 1.0, tol=1e-2, err_msg='total_prob is biased')

    return counts, probs
180
181
182
def assert_counts_match_probs(counts, probs, tol=1e-3):
    '''
    Assert that observed counts fit the predicted probabilities.

    Prints a table of expected versus actual counts, ordered from the
    largest probability down, then the goodness of fit returned by
    multinomial_goodness_of_fit, and fails when that value is not greater
    than tol.

    Inputs:
        - counts : key -> int
        - probs : key -> float
        - tol is the smallest goodness of fit that is rejected.
    '''
    keys = counts.keys()
    prob_list = [probs[key] for key in keys]
    count_list = [counts[key] for key in keys]
    total_count = sum(count_list)

    print('EXPECT\tACTUAL\tVALUE')
    rows = sorted(zip(prob_list, count_list, keys), reverse=True)
    for prob, count, key in rows:
        print('{:0.1f}\t{}\t{}'.format(prob * total_count, count, key))

    gof = multinomial_goodness_of_fit(prob_list, count_list, total_count)
    print('goodness of fit = {}'.format(gof))
    assert gof > tol, 'failed with goodness of fit {}'.format(gof)
204
205
206
def assert_samples_match_scores(sampler, total_count=10000, tol=1e-3):
    '''
    Test that a discrete sampler is distributed according to its scores.

    Inputs:
        - sampler generates (sample, prob) pairs.  samples must be hashable.
          probs may be randomized, but must be unbiased and low-variance.
        - total_count samples are drawn in total.
        - tol is the minimum goodness of fit allowed to pass the test.
    '''
    counts, probs = collect_samples_and_scores(sampler, total_count)
    # Forward tol so that a caller-supplied threshold is actually applied
    # instead of the default of assert_counts_match_probs.
    assert_counts_match_probs(counts, probs, tol=tol)
218