1 | # Copyright (c) 2014, Salesforce.com, Inc. All rights reserved. |
||
2 | # |
||
3 | # Redistribution and use in source and binary forms, with or without |
||
4 | # modification, are permitted provided that the following conditions |
||
5 | # are met: |
||
6 | # |
||
7 | # - Redistributions of source code must retain the above copyright |
||
8 | # notice, this list of conditions and the following disclaimer. |
||
9 | # - Redistributions in binary form must reproduce the above copyright |
||
10 | # notice, this list of conditions and the following disclaimer in the |
||
11 | # documentation and/or other materials provided with the distribution. |
||
12 | # - Neither the name of Salesforce.com nor the names of its contributors |
||
13 | # may be used to endorse or promote products derived from this |
||
14 | # software without specific prior written permission. |
||
15 | # |
||
16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
||
17 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
||
18 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS |
||
19 | # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE |
||
20 | # COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, |
||
21 | # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, |
||
22 | # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS |
||
23 | # OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND |
||
24 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR |
||
25 | # TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE |
||
26 | # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
||
27 | |||
28 | from math import log, pi, sqrt, factorial |
||
29 | import numpy as np |
||
30 | import numpy.random |
||
31 | from numpy.random.mtrand import dirichlet as sample_dirichlet |
||
32 | from numpy import dot, inner |
||
33 | from numpy.linalg import cholesky, det, inv |
||
34 | from numpy.random import multivariate_normal |
||
35 | from numpy.random import beta as sample_beta |
||
36 | from numpy.random import poisson as sample_poisson |
||
37 | from numpy.random import gamma as sample_gamma |
||
38 | from scipy.stats import norm, chi2, bernoulli, nbinom |
||
39 | from scipy.special import gammaln |
||
40 | from distributions.util import scores_to_probs |
||
41 | import logging |
||
42 | |||
43 | from distributions.vendor.stats import ( |
||
44 | sample_invwishart as _sample_inverse_wishart |
||
45 | ) |
||
46 | |||
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)


# pacify pyflakes: referencing these imported names in an expression keeps
# them from being reported as unused imports.
assert sample_dirichlet and factorial and sample_poisson and sample_gamma
||
52 | |||
53 | |||
def seed(x):
    """
    Seed the random number generators used by this module.

    numpy's global generator is always seeded with `x`.  The generator in
    distributions.cRandom is seeded as well, but only when that module can
    be imported; an ImportError is silently ignored.
    """
    numpy.random.seed(x)
    try:
        from distributions import cRandom as c_random
        c_random.seed(x)
    except ImportError:
        # Best effort: the optional cRandom module is not available.
        pass
||
61 | |||
62 | |||
def sample_discrete_log(scores):
    """
    Sample an outcome index from the distribution obtained by passing
    `scores` through scores_to_probs.

    The converted probabilities are handed to sample_discrete with
    total=1.0, i.e. they are taken to already sum to one.
    """
    return sample_discrete(scores_to_probs(scores), total=1.0)
||
66 | |||
67 | |||
def sample_bernoulli(prob):
    """
    Draw a single Bernoulli trial with success probability `prob`.

    Returns the outcome as a Python bool.
    """
    outcome = bernoulli.rvs(prob)
    return bool(outcome)
||
70 | |||
71 | |||
def sample_discrete(probs, total=None):
    """
    Draws from a discrete distribution with the given (possibly unnormalized)
    probabilities for each outcome.

    `total` is the sum of `probs`; it is computed here when not supplied.
    A uniform "dart" in [0, total) is thrown and the probabilities are
    subtracted from it in order until it drops to zero or below.  If
    rounding error leaves the dart positive after every entry has been
    subtracted, the draw is logged and retried, up to 10 times in all.

    Returns an int between 0 and len(probs)-1, inclusive

    Raises ValueError if all attempts fail.
    """
    if total is None:
        total = float(sum(probs))
    # range (not xrange) so the function runs on both Python 2 and Python 3.
    for _ in range(10):
        dart = numpy.random.rand() * total
        for i, prob in enumerate(probs):
            dart -= prob
            if dart <= 0:
                return i
        # Use lazy %-formatting so the diagnostic values actually appear in
        # the log record; a dict passed to a placeholder-free message is
        # silently dropped by the logging module.
        LOG.error(
            'imprecision in sample_discrete: total = %s, dart = %s, probs = %s',
            total, dart, probs)
    raise ValueError('\n '.join([
        'imprecision in sample_discrete:',
        'total = {}'.format(total),
        'dart = {}'.format(dart),
        'probs = {}'.format(probs),
    ]))
||
96 | |||
97 | |||
def sample_normal(mu, sigmasq):
    """
    Draw one sample from a normal distribution located at `mu`.

    NOTE(review): `sigmasq` is forwarded unchanged as the `scale` argument
    of scipy's norm.rvs, which scipy interprets as the standard deviation
    rather than the variance -- confirm that callers pass what they intend.
    """
    return norm.rvs(loc=mu, scale=sigmasq)
||
100 | |||
101 | |||
def sample_chi2(nu):
    """
    Draw one sample from a chi-squared distribution with `nu` degrees of
    freedom.
    """
    return chi2.rvs(df=nu)
||
104 | |||
105 | |||
def sample_student_t(dof, mu, Sigma):
    """
    Draw one sample from a multivariate Student-t distribution.

    Uses the standard construction

        mu + z * sqrt(dof / x),   z ~ Normal(0, Sigma),  x ~ chi^2(dof)

    which is the distribution whose log density score_student_t computes.

    dof:    degrees of freedom
    mu:     location vector of length p
    Sigma:  p x p scale matrix

    Returns a length-p numpy array.
    """
    p = len(mu)
    x = numpy.random.chisquare(dof, 1)
    z = numpy.random.multivariate_normal(numpy.zeros(p), Sigma, (1,))
    # The chi-square draw has to be normalized by its degrees of freedom;
    # dividing by sqrt(x) alone shrinks every sample by a factor sqrt(dof).
    return (mu + z / numpy.sqrt(x / dof))[0]
||
111 | |||
112 | |||
def score_student_t(x, nu, mu, sigma):
    r"""
    multivariate score_student_t

    Log density of a multivariate Student-t distribution evaluated at x.

    x:      point at which to evaluate, array of length p
    nu:     degrees of freedom
    mu:     location vector of length p
    sigma:  p x p scale matrix

    Returns the log density as a float.

    \cite{murphy2007conjugate}, Eq. 313
    """
    p = len(mu)
    z = x - mu
    # Quadratic form z^T sigma^{-1} z, computed with a linear solve rather
    # than by forming the explicit inverse.
    S = inner(z, np.linalg.solve(sigma, z))
    score = (
        gammaln(0.5 * (nu + p))
        - gammaln(0.5 * nu)
        - 0.5 * (
            p * log(nu * pi)
            + log(det(sigma))
            + (nu + p) * log(1 + S / nu)
        )
    )
    return score
||
132 | |||
133 | |||
def sample_wishart_naive(nu, Lambda):
    """
    Wishart sample built directly from the definition: draw `nu` zero-mean
    normal vectors with covariance `Lambda` and sum their outer products.

    The number of normal draws equals `nu`.
    """
    dim = Lambda.shape[0]
    draws = multivariate_normal(numpy.zeros(dim), Lambda, nu)
    return numpy.dot(draws.T, draws)
||
143 | |||
144 | |||
def sample_wishart(nu, Lambda):
    """
    Sample from a Wishart distribution using the Bartlett decomposition.

    nu:      degrees of freedom (must exceed d - 1)
    Lambda:  d x d positive-definite scale matrix (numpy array)

    A lower-triangular factor z is filled with standard normals below the
    diagonal and with square roots of chi-square draws on the diagonal;
    the sample is  ch z z^T ch^T  where ch is the Cholesky factor of Lambda.

    Returns a d x d numpy array.
    """
    ch = cholesky(Lambda)
    d = Lambda.shape[0]
    z = numpy.zeros((d, d))
    # range (not xrange) so the function runs on both Python 2 and Python 3.
    for i in range(d):
        if i != 0:
            z[i, :i] = numpy.random.normal(size=(i,))
        # Row i (0-based) needs a chi-square with nu - i degrees of freedom,
        # i.e. Gamma(shape=(nu - i) / 2, scale=2).  A shape that does not
        # depend on i is only correct for d == 1.
        z[i, i] = sqrt(numpy.random.gamma(0.5 * (nu - i), 2.0))
    return dot(dot(dot(ch, z), z.T), ch.T)
||
154 | |||
155 | |||
def sample_wishart_v2(nu, Lambda):
    """
    From Sawyer, et. al. 'Wishart Distributions and Inverse-Wishart Sampling'

    Bartlett-style sampler: the number of random draws depends only on the
    dimension d of `Lambda`, not on `nu`.

    nu:      degrees of freedom (must exceed d - 1)
    Lambda:  d x d positive-definite scale matrix (numpy array)

    Returns a d x d numpy array.
    """
    d = Lambda.shape[0]
    ch = cholesky(Lambda)
    T = numpy.zeros((d, d))
    # range (not xrange) so the function runs on both Python 2 and Python 3.
    for i in range(d):
        if i != 0:
            T[i, :i] = numpy.random.normal(size=(i,))
        # The reference formula chi2(nu - i + 1) is for 1-based i; with the
        # 0-based index used here the degrees of freedom are nu - i.
        T[i, i] = sqrt(chi2.rvs(nu - i))
    return dot(dot(dot(ch, T), T.T), ch.T)
||
170 | |||
171 | |||
def sample_inverse_wishart(nu, S):
    """
    Thin wrapper around the vendored sample_invwishart that accepts its
    arguments as (nu, S).
    """
    # The vendored sampler takes its arguments in the opposite order
    # (S first, then nu).
    draw = _sample_inverse_wishart(S, nu)
    return draw
||
175 | |||
176 | |||
def sample_normal_inverse_wishart(mu0, lambda0, psi0, nu0):
    """
    Draw a (mean, covariance) pair: the covariance comes from
    sample_inverse_wishart(nu0, psi0) and the mean is then drawn from a
    multivariate normal centered at mu0 with that covariance scaled by
    1 / lambda0.

    mu0 must be one-dimensional with shape (D,), psi0 must have shape
    (D, D), lambda0 must be positive and nu0 must exceed D - 1; these are
    checked with assertions.

    Returns the tuple (mu, cov).
    """
    (dim,) = mu0.shape
    assert psi0.shape == (dim, dim)
    assert lambda0 > 0.0
    assert nu0 > dim - 1
    covariance = sample_inverse_wishart(nu0, psi0)
    mean_cov = (1. / lambda0) * covariance
    mean = np.random.multivariate_normal(mean=mu0, cov=mean_cov)
    return mean, covariance
||
185 | |||
186 | |||
def sample_partition_from_counts(items, counts):
    """
    Randomly split `items` into consecutive groups whose sizes are given by
    `counts`, returning the groups as a list of lists.

    The items are visited in a random permutation; the k-th group receives
    the next counts[k] items of that permutation.  The counts must sum to
    len(items), which is checked with an assertion.
    """
    assert sum(counts) == len(items), 'counts do not sum to item count'
    shuffled = numpy.random.permutation(len(items))
    groups = []
    cursor = 0
    for size in counts:
        group = [items[shuffled[cursor + step]] for step in range(size)]
        groups.append(group)
        cursor += len(group)
    return groups
||
202 | |||
203 | |||
def sample_stick(gamma, tol=1e-3):
    """
    Truncated sample from a dirichlet process using stick breaking

    Pieces are broken off the remaining stick, each a Beta(1, gamma)
    fraction of what is left, until the remainder is at most `tol`.

    Returns a dict mapping piece index to its weight, with the weights
    rescaled by the total broken-off length so they sum to one.
    """
    pieces = []
    used = 0.
    while 1 - used > tol:
        piece = (1 - used) * sample_beta(1., gamma)
        pieces.append(piece)
        used += piece
    return dict((index, piece / used) for index, piece in enumerate(pieces))
||
215 | |||
216 | |||
def sample_negative_binomial(p, r):
    """
    Draw one negative-binomial sample via scipy's nbinom.rvs(r, p) and
    return it as a Python int.

    Note the argument order: this function takes (p, r) while scipy's
    nbinom takes (r, p).
    """
    draw = nbinom.rvs(r, p)
    return int(draw)
||
219 |