| Metric         | Value   |
| -------------- | ------- |
| Conditions     | 9       |
| Total Lines    | 126     |
| Code Lines     | 48      |
| Covered Lines  | 0       |
| Coverage Ratio | 0 %     |
| Tests          | 1       |
| CRAP Score     | 83.6145 |
| Changes        | 0       |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
A commonly applied refactoring here is *Extract Method*: moving a distinct, commented step into its own well-named method, as sketched below.
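A minimal sketch of *Extract Method* (the `report` and `apply_bulk_discount` names are hypothetical, purely for illustration): the commented step becomes a helper, and the comment supplies its name.

```python
# Before: a comment marks a distinct step inside the method.
def report(orders):
    total = sum(order["amount"] for order in orders)
    # apply a 5% bulk discount for totals over 1000
    if total > 1000:
        total *= 0.95
    return "Total: {:.2f}".format(total)


# After: the commented step is extracted; the comment became the name.
def apply_bulk_discount(total):
    """Apply a 5% discount to totals over 1000."""
    return total * 0.95 if total > 1000 else total


def report(orders):
    total = apply_bulk_discount(sum(order["amount"] for order in orders))
    return "Total: {:.2f}".format(total)
```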
If many parameters/temporary variables are present:
Methods with many parameters are not only hard to understand, but their parameter lists also often become inconsistent when you need more, or different, data. There are several approaches to avoiding long parameter lists, for example *Introduce Parameter Object* or *Preserve Whole Object*; the flagged method is shown in the listing below, followed by a parameter-object sketch.
```python
# -*- coding: utf-8 -*-
import logging

import numpy as np

# … (helper functions such as _log_lh_pt, _log_pred_pt,
# _log_prob_pt_samples_mt and _log_prob_pt_samples_dask are
# defined in the elided part of the module) …


def waic_loo(model, times, data, errs,
             samples,
             method="likelihood",
             train_data=None,
             noisy_targets=True,
             nthreads=1,
             use_dask=False,
             dask_cluster=None,
             ):
    """Watanabe-Akaike information criterion (WAIC) and LOO IC of the (GP) model

    Calculates the WAIC and leave-one-out (LOO) cross validation scores and
    information criteria (IC) from the MCMC samples of the posterior parameter
    distributions. Uses the posterior point-wise (per data point)
    probabilities and the formulae from [1]_ and [2]_.

    .. [1] Vehtari, Gelman, and Gabry, Stat Comput (2017) 27:1413–1432,
           doi: 10.1007/s11222-016-9696-4
    .. [2] Vehtari and Gelman, (unpublished)
           http://www.stat.columbia.edu/~gelman/research/unpublished/waic_stan.pdf
           http://www.stat.columbia.edu/~gelman/research/unpublished/loo_stan.pdf

    Parameters
    ----------
    model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
        The model instance whose parameter distribution was drawn.
    times : (M,) array_like
        The test coordinates to predict or evaluate the model on.
    data : (M,) array_like
        The test data to test the model against.
    errs : (M,) array_like
        The errors (variances) of the test data.
    samples : (K, L) array_like
        The `K` MCMC samples of the `L` parameter distributions.
    method : str ("likelihood" or "predict"), optional
        The method to "predict" the data; the default uses the (log-)likelihood
        in the same way as is done when fitting (training) the model.
        "predict" uses the actual GP prediction, which might be useful if the
        IC should be estimated for actual test data that was not used to
        train the model.
    train_data : (N,) array_like, optional
        The data on which the model was trained, needed if method="predict"
        is used; otherwise None is the default and the likelihood is used.
    noisy_targets : bool, optional
        Include the given errors when calculating the predictive probability.
    nthreads : int, optional
        Number of threads to distribute the point-wise probability
        calculations to (default: 1).
    use_dask : bool, optional
        Use `dask.distributed` to distribute the point-wise probability
        calculations to `nthreads` workers. The default is to use
        `multiprocessing.pool.Pool()`.
    dask_cluster : str or `dask.distributed.Cluster` instance, optional
        Will be passed to `dask.distributed.Client()`. This can be the
        address of a Scheduler server like the string '127.0.0.1:8786' or a
        cluster object like `dask.distributed.LocalCluster()`.

    Returns
    -------
    waic, waic_se, p_waic, loo_ic, loo_ic_se, p_loo : tuple
        The WAIC and its standard error as well as the estimated effective
        number of parameters, p_waic. The LOO IC, its standard error, and
        the estimated effective number of parameters, p_loo.
    """
    from functools import partial
    from scipy.special import logsumexp

    # the predictive covariance should include the data variance;
    # set it to a small value if we don't want to account for it
    if not noisy_targets or errs is None:
        errs = 1.123e-12

    # point-wise posterior/predictive probabilities
    _log_p_pt = partial(_log_lh_pt, model, times, data, errs)
    if method == "predict" and train_data is not None:
        _log_p_pt = partial(_log_pred_pt, model, train_data, times, data, errs)

    # calculate the point-wise probabilities and stack them together
    if nthreads > 1 and use_dask:
        log_pred = _log_prob_pt_samples_dask(_log_p_pt, samples,
                                             nthreads=nthreads,
                                             cluster=dask_cluster)
    else:
        log_pred = _log_prob_pt_samples_mt(_log_p_pt, samples,
                                           nthreads=nthreads)

    # WAIC: lppd_i is the log point-wise predictive density,
    # p_waic_i the per-point effective number of parameters
    lppd_i = logsumexp(log_pred, b=1. / log_pred.shape[0], axis=0)
    p_waic_i = np.nanvar(log_pred, ddof=1, axis=0)
    if np.any(p_waic_i > 0.4):
        logging.warning("For one or more samples the posterior variance of "
                        "the log predictive densities exceeds 0.4. This could "
                        "be an indication of WAIC starting to fail, see "
                        "http://arxiv.org/abs/1507.04544 for details.")
    elpd_i = lppd_i - p_waic_i
    waic_i = -2. * elpd_i
    waic_se = np.sqrt(len(waic_i) * np.nanvar(waic_i, ddof=1))
    waic = np.nansum(waic_i)
    p_waic = np.nansum(p_waic_i)
    if 2. * p_waic > len(waic_i):
        logging.warning("p_waic > n / 2, the WAIC approximation is "
                        "unreliable.")
    logging.info("WAIC: %s, waic_se: %s, p_w: %s", waic, waic_se, p_waic)

    # LOO with truncated importance weights
    loo_ws = 1. / np.exp(log_pred - np.nanmax(log_pred, axis=0))
    loo_ws_n = loo_ws / np.nanmean(loo_ws, axis=0)
    loo_ws_r = np.clip(loo_ws_n, None, np.sqrt(log_pred.shape[0]))
    elpd_loo_i = logsumexp(log_pred,
                           b=loo_ws_r / np.nansum(loo_ws_r, axis=0),
                           axis=0)
    p_loo_i = lppd_i - elpd_loo_i
    loo_ic_i = -2 * elpd_loo_i
    loo_ic_se = np.sqrt(len(loo_ic_i) * np.nanvar(loo_ic_i))
    loo_ic = np.nansum(loo_ic_i)
    p_loo = np.nansum(p_loo_i)
    logging.info("loo IC: %s, se: %s, p_loo: %s", loo_ic, loo_ic_se, p_loo)

    # entropy estimate H(Y), following
    # van der Linde, 2005, Statistica Neerlandica,
    # https://doi.org/10.1111/j.1467-9574.2005.00278.x
    hy1 = -np.nanmean(lppd_i)
    hy2 = -np.nanmedian(lppd_i)
    logging.info("H(Y): mean %s, median: %s", hy1, hy2)

    return waic, waic_se, p_waic, loo_ic, loo_ic_se, p_loo
```
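One way to shrink this signature is *Introduce Parameter Object*: the three execution-related options (`nthreads`, `use_dask`, `dask_cluster`) always travel together and could be bundled. A minimal sketch under that assumption; the `ParallelOptions` name is hypothetical and not part of the code base:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ParallelOptions:
    """Execution options that waic_loo currently takes as three parameters."""
    nthreads: int = 1
    use_dask: bool = False
    dask_cluster: Optional[Any] = None  # scheduler address or Cluster object


def waic_loo(model, times, data, errs, samples,
             method="likelihood", train_data=None, noisy_targets=True,
             parallel=None):
    parallel = parallel or ParallelOptions()
    # ... body as above, but reading parallel.nthreads, parallel.use_dask
    # and parallel.dask_cluster instead of three separate arguments ...
```

New execution options would then extend `ParallelOptions` instead of growing the signature of `waic_loo` and of every caller that passes them through.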