sciapy.regress.statistics._log_pred_pt()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 4.5185

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 6
dl 0
loc 10
ccs 1
cts 7
cp 0.1429
crap 4.5185
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# vim:fileencoding=utf-8
3
#
4
# Copyright (c) 2017-2019 Stefan Bender
5
#
6
# This module is part of sciapy.
7
# sciapy is free software: you can redistribute it or modify
8
# it under the terms of the GNU General Public License as published
9
# by the Free Software Foundation, version 2.
10
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
11 1
"""SCIAMACHY MCMC statistic tools
12
13
Statistical functions for MCMC sampled parameters.
14
"""
15
16 1
import logging
17
18 1
import numpy as np
19 1
from scipy import linalg
20
21 1
# public API of this module
__all__ = [
	"mcmc_statistics",
	"waic_loo",
]
22
23
24 1
def _log_prob(resid, var):
25 1
	return -0.5 * (np.log(2 * np.pi * var) + resid**2 / var)
26
27
28 1
def _log_lh_pt(gp, times, data, errs, s):
29
	gp.set_parameter_vector(s)
30
	resid = data - gp.mean.get_value(times)
31
	ker = gp.get_matrix(times, include_diagonal=True)
32
	ker[np.diag_indices_from(ker)] += errs**2
33
	ll, lower = linalg.cho_factor(
34
			ker, lower=True, check_finite=False, overwrite_a=True)
35
	linv_r = linalg.solve_triangular(
36
			ll, resid, lower=True, check_finite=False, overwrite_b=True)
37
	return -0.5 * (np.log(2. * np.pi) + linv_r**2) - np.log(np.diag(ll))
38
39
40 1
def _log_pred_pt(gp, train_data, times, data, errs, s):
	"""Point-wise log predictive density of the data for one sample.

	Sets the parameters of `gp` to `s`, predicts mean and variance at
	`times` conditioned on `train_data`, and evaluates the normal log
	density of each data point under that prediction.

	Parameters
	----------
	gp : GP-like object
		Needs `set_parameter_vector()` and
		`predict(y, t=..., return_var=True)`.
	train_data : (N,) array_like
		The data the prediction is conditioned on.
	times : (M,) array_like
		The coordinates to predict at.
	data : (M,) array_like
		The data values to evaluate.
	errs : (M,) array_like, float, or None
		The data uncertainties, their squares are added to the predictive
		variance unless None.
	s : (L,) array_like
		The parameter vector (one MCMC sample).

	Returns
	-------
	(M,) ndarray
		The point-wise log predictive densities.
	"""
	gp.set_parameter_vector(s)
	mu, var = gp.predict(train_data, t=times, return_var=True)
	if errs is not None:
		# the predictive variance should include the data variance
		var += errs**2
	return _log_prob(mu - data, var)
50
51
52 1
def _log_prob_pt_samples_dask(log_p_pt, samples,
		nthreads=1, cluster=None):
	"""Evaluate `log_p_pt` for every sample using `dask.distributed`.

	Parameters
	----------
	log_p_pt : callable
		Maps one parameter sample to its point-wise log probabilities.
	samples : (K, L) array_like
		The parameter samples, one per row.
	nthreads : int, optional
		Number of single-threaded workers of the temporary local cluster
		that is started when `cluster` is None (default: 1).
	cluster : str or `dask.distributed` cluster instance, optional
		Passed to `dask.distributed.Client()`. If None, a `LocalCluster`
		is started and closed again afterwards, also on errors.

	Returns
	-------
	ndarray
		The stacked point-wise log probabilities, one row per sample
		in the order of `samples`.
	"""
	from dask.distributed import Client, LocalCluster, progress
	import dask.bag as db

	own_cluster = cluster is None
	if own_cluster:
		# start local dask cluster
		_cl = LocalCluster(n_workers=nthreads, threads_per_worker=1)
	else:
		# use provided dask cluster
		_cl = cluster

	try:
		with Client(_cl) as client:
			# persist first so that `progress()` has futures to track,
			# then gather the (ordered) results
			_log_pred = client.persist(
					db.from_sequence(samples).map(log_p_pt))
			progress(_log_pred)
			ret = np.stack(_log_pred.compute())
	finally:
		# do not leak the temporary cluster if the computation fails
		if own_cluster:
			_cl.close()

	return ret
74
75
76 1
def _log_prob_pt_samples_mt(log_p_pt, samples, nthreads=1):
77
	from multiprocessing import pool
78
	try:
79
		from tqdm.autonotebook import tqdm
80
	except ImportError:
81
		tqdm = None
82
83
	# multiprocessing.pool
84
	_p = pool.Pool(processes=nthreads)
85
86
	_mapped = _p.imap_unordered(log_p_pt, samples)
87
	if tqdm is not None:
88
		_mapped = tqdm(_mapped, total=len(samples))
89
90
	ret = np.stack(list(_mapped))
91
92
	_p.close()
93
	_p.join()
94
95
	return ret
96
97
98 1
def waic_loo(model, times, data, errs,
		samples,
		method="likelihood",
		train_data=None,
		noisy_targets=True,
		nthreads=1,
		use_dask=False,
		dask_cluster=None,
		):
	"""Watanabe-Akaike information criterion (WAIC) and LOO IC of the (GP) model

	Calculates the WAIC and leave-one-out (LOO) cross validation scores and
	information criteria (IC) from the MCMC samples of the posterior parameter
	distributions. Uses the posterior point-wise (per data point)
	probabilities and the formulae from [1]_ and [2]_.
	The LOO values are estimated with truncated importance sampling,
	the normalized importance weights are truncated at sqrt(K).

	.. [1] Vehtari, Gelman, and Gabry, Stat Comput (2017) 27:1413–1432,
		doi: 10.1007/s11222-016-9696-4

	.. [2] Vehtari and Gelman, (unpublished)
		http://www.stat.columbia.edu/~gelman/research/unpublished/waic_stan.pdf
		http://www.stat.columbia.edu/~gelman/research/unpublished/loo_stan.pdf

	Parameters
	----------
	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
		The model instance whose parameter distribution was drawn.
	times : (M,) array_like
		The test coordinates to predict or evaluate the model on.
	data : (M,) array_like
		The test data to test the model against.
	errs : (M,) array_like
		The uncertainties (standard deviations) of the test data,
		their squares are added to the (predictive) variances.
	samples : (K, L) array_like
		The `K` MCMC samples of the `L` parameter distributions.
	method : str ("likelihood" or "predict"), optional
		The method to "predict" the data, the default uses the (log)likelihood
		in the same way as is done when fitting (training) the model.
		"predict" uses the actual GP prediction, might be useful if the IC
		should be estimated for actual test data that was not used to train
		the model.
	train_data : (N,) array_like, optional
		The data on which the model was trained, needed if method="predict" is
		used, otherwise None is the default and the likelihood is used.
		With method="predict" but no `train_data`, a warning is logged and
		the likelihood is used.
	noisy_targets : bool, optional
		Include the given errors when calculating the predictive probability.
	nthreads : int, optional
		Number of worker processes (or dask workers) to distribute the
		point-wise probability calculations to (default: 1).
	use_dask : bool, optional
		Use `dask.distributed` to distribute the point-wise probability
		calculations to `nthreads` workers. The default is to use
		`multiprocessing.pool.Pool()`.
	dask_cluster : str, or `dask.distributed.Cluster` instance, optional
		Will be passed to `dask.distributed.Client()`
		This can be the address of a Scheduler server like a string
		'127.0.0.1:8786' or a cluster object like `dask.distributed.LocalCluster()`.

	Returns
	-------
	waic, waic_se, p_waic, loo_ic, loo_se, p_loo : tuple
		The WAIC and its standard error as well as the
		estimated effective number of parameters, p_waic.
		The LOO IC, its standard error, and the estimated
		effective number of parameters, p_loo.
	"""
	from functools import partial
	from scipy.special import logsumexp

	# the predictive covariance should include the data variance
	# set to a small value if we don't want to account for them
	if not noisy_targets or errs is None:
		errs = 1.123e-12

	# point-wise posterior/predictive probabilities
	_log_p_pt = partial(_log_lh_pt, model, times, data, errs)
	if method == "predict":
		if train_data is not None:
			_log_p_pt = partial(_log_pred_pt, model, train_data, times, data, errs)
		else:
			logging.warning(
				"method='predict' requires `train_data`, "
				"falling back to the likelihood method.")

	# calculate the point-wise probabilities and stack them together
	if nthreads > 1 and use_dask:
		log_pred = _log_prob_pt_samples_dask(_log_p_pt, samples,
				nthreads=nthreads, cluster=dask_cluster)
	else:
		log_pred = _log_prob_pt_samples_mt(_log_p_pt, samples,
				nthreads=nthreads)
	nsamples = log_pred.shape[0]

	# WAIC
	lppd_i = logsumexp(log_pred, b=1. / nsamples, axis=0)
	p_waic_i = np.nanvar(log_pred, ddof=1, axis=0)
	if np.any(p_waic_i > 0.4):
		logging.warning("""For one or more samples the posterior variance of the
		log predictive densities exceeds 0.4. This could be indication of
		WAIC starting to fail see http://arxiv.org/abs/1507.04544 for details
		""")
	elpd_i = lppd_i - p_waic_i
	waic_i = -2. * elpd_i
	waic_se = np.sqrt(len(waic_i) * np.nanvar(waic_i, ddof=1))
	waic = np.nansum(waic_i)
	p_waic = np.nansum(p_waic_i)
	if 2. * p_waic > len(waic_i):
		logging.warning("""p_waic > n / 2,
		the WAIC approximation is unreliable.
		""")
	logging.info("WAIC: %s, waic_se: %s, p_w: %s", waic, waic_se, p_waic)

	# LOO
	# raw importance ratios 1 / p(y_i | theta_s), divided by their per-point
	# maximum so that they lie in (0, 1] and cannot overflow; the common
	# scale cancels in the normalization below
	loo_ws = np.exp(np.nanmin(log_pred, axis=0) - log_pred)
	loo_ws_n = loo_ws / np.nanmean(loo_ws, axis=0)
	# truncate the normalized weights at sqrt(K)
	loo_ws_r = np.clip(loo_ws_n, None, np.sqrt(nsamples))
	elpd_loo_i = logsumexp(log_pred,
			b=loo_ws_r / np.nansum(loo_ws_r, axis=0),
			axis=0)
	p_loo_i = lppd_i - elpd_loo_i
	loo_ic_i = -2 * elpd_loo_i
	# sample variance (ddof=1), consistent with `waic_se` above
	loo_ic_se = np.sqrt(len(loo_ic_i) * np.nanvar(loo_ic_i, ddof=1))
	loo_ic = np.nansum(loo_ic_i)
	p_loo = np.nansum(p_loo_i)
	logging.info("loo IC: %s, se: %s, p_loo: %s", loo_ic, loo_ic_se, p_loo)

	# van der Linde, 2005, Statistica Neerlandica, 2005
	# https://doi.org/10.1111/j.1467-9574.2005.00278.x
	hy1 = -np.nanmean(lppd_i)
	hy2 = -np.nanmedian(lppd_i)
	logging.info("H(Y): mean %s, median: %s", hy1, hy2)

	return waic, waic_se, p_waic, loo_ic, loo_ic_se, p_loo
224
225
226 1
def mcmc_statistics(model, times, data, errs,
		train_times, train_data, train_errs,
		samples, lnp,
		median=False):
	"""Statistics for the (GP) model against the provided data

	Statistical information about the model and the sampled parameter
	distributions with respect to the provided data and its variance.

	Sends the calculated values to the logger, includes the mean
	standardized log loss as described in R&W, 2006, section 2.5, (2.34),
	and some slightly adapted $\\chi^2_{\\text{red}}$ and $R^2$ scores.

	Parameters
	----------
	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
		The model instance whose parameter distribution was drawn.
		`model.predict()` is called without recomputing the model here,
		presumably it has to be computed on the training coordinates
		beforehand (TODO confirm).
	times : (M,) array_like
		The test coordinates to predict or evaluate the model on.
	data : (M,) array_like
		The test data to test the model against.
	errs : (M,) array_like
		The uncertainties (standard deviations) of the test data,
		their squares are added to the predictive variances.
	train_times : (N,) array_like
		The coordinates on which the model was trained.
		Currently not used by this function.
	train_data : (N,) array_like
		The data on which the model was trained.
	train_errs : (N,) array_like
		The errors of the training data.
		Currently not used by this function.
	samples : (K, L) array_like
		The `K` MCMC samples of the `L` parameter distributions.
	lnp : (K,) array_like
		The posterior log probabilities of the `K` MCMC samples,
		NaN values are ignored.
	median : bool, optional
		Whether to use the median of the sampled distributions or
		the maximum posterior sample (the default) to evaluate the
		statistics.

	Returns
	-------
	nothing
	"""
	ndat = len(times)
	ndim = len(model.get_parameter_vector())
	mdim = len(model.mean.get_parameter_vector())
	# ignore NaN posterior values when picking the best sample
	samples_max_lp = np.nanmax(lnp)
	if median:
		sample_pos = np.nanmedian(samples, axis=0)
	else:
		sample_pos = samples[np.nanargmax(lnp)]
	model.set_parameter_vector(sample_pos)
	# calculate the GP predicted values and covariance
	gppred, gpcov = model.predict(train_data, t=times, return_cov=True)
	# the predictive covariance should include the data variance
	gpcov[np.diag_indices_from(gpcov)] += errs**2
	# residuals
	resid_mod = model.mean.get_value(times) - data  # GP mean model
	resid_gp = gppred - data  # GP prediction
	resid_triv = np.nanmean(train_data) - data  # trivial model
	_const = ndat * np.log(2.0 * np.pi)
	# multivariate normal log density needs the log-determinant of the
	# full covariance, the element-wise log of the diagonal is not enough
	# for correlated predictions
	_, logdet_cov = np.linalg.slogdet(gpcov)
	test_logpred = -0.5 * (resid_gp.dot(linalg.solve(gpcov, resid_gp))
			+ logdet_cov
			+ _const)
	# MSLL -- mean standardized log loss
	# as described in R&W, 2006, section 2.5, (2.34)
	var_mod = np.nanvar(resid_mod, ddof=mdim)  # mean model variance
	var_gp = np.nanvar(resid_gp, ddof=ndim)  # gp model variance
	var_triv = np.nanvar(train_data, ddof=1)  # trivial model variance
	logpred_mod = _log_prob(resid_mod, var_mod)
	logpred_gp = _log_prob(resid_gp, var_gp)
	logpred_triv = _log_prob(resid_triv, var_triv)
	logging.info("MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# predictive variances
	logpred_mod = _log_prob(resid_mod, var_mod + errs**2)
	logpred_gp = _log_prob(resid_gp, var_gp + errs**2)
	logpred_triv = _log_prob(resid_triv, var_triv + errs**2)
	logging.info("pred MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("pred MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# cost values
	cost_mod = np.sum(resid_mod**2)
	cost_triv = np.sum(resid_triv**2)
	cost_gp = np.sum(resid_gp**2)
	# chi^2 (variance corrected costs)
	chisq_mod_ye = np.sum((resid_mod / errs)**2)
	chisq_triv = np.sum((resid_triv / errs)**2)
	chisq_gpcov = resid_mod.dot(linalg.solve(gpcov, resid_mod))
	# adjust for degrees of freedom
	cost_gp_dof = cost_gp / (ndat - ndim)
	cost_mod_dof = cost_mod / (ndat - mdim)
	cost_triv_dof = cost_triv / (ndat - 1)
	# reduced chi^2
	chisq_red_mod_ye = chisq_mod_ye / (ndat - mdim)
	chisq_red_triv = chisq_triv / (ndat - 1)
	chisq_red_gpcov = chisq_gpcov / (ndat - ndim)
	# "generalized" R^2
	logp_triv1 = np.sum(_log_prob(resid_triv, errs**2))
	logp_triv2 = np.sum(_log_prob(resid_triv, var_triv))
	logp_triv3 = np.sum(_log_prob(resid_triv, var_triv + errs**2))
	log_lambda1 = test_logpred - logp_triv1
	log_lambda2 = test_logpred - logp_triv2
	log_lambda3 = test_logpred - logp_triv3
	gen_rsq1a = 1 - np.exp(-2 * log_lambda1 / ndat)
	gen_rsq1b = 1 - np.exp(-2 * log_lambda1 / (ndat - ndim))
	gen_rsq2a = 1 - np.exp(-2 * log_lambda2 / ndat)
	gen_rsq2b = 1 - np.exp(-2 * log_lambda2 / (ndat - ndim))
	gen_rsq3a = 1 - np.exp(-2 * log_lambda3 / ndat)
	gen_rsq3b = 1 - np.exp(-2 * log_lambda3 / (ndat - ndim))
	# sent to the logger
	logging.info("train max logpost: %s", samples_max_lp)
	logging.info("test log_pred: %s", test_logpred)
	logging.info("1a cost mean model: %s, dof adj: %s", cost_mod, cost_mod_dof)
	logging.debug("1c cost gp predict: %s, dof adj: %s", cost_gp, cost_gp_dof)
	logging.debug("1b cost triv model: %s, dof adj: %s", cost_triv, cost_triv_dof)
	logging.info("1d var resid mean model: %s, gp model: %s, triv: %s",
			var_mod, var_gp, var_triv)
	logging.info("2a adjR2 mean model: %s, adjR2 gp predict: %s",
			1 - cost_mod_dof / cost_triv_dof, 1 - cost_gp_dof / cost_triv_dof)
	logging.info("2b red chi^2 mod: %s / triv: %s = %s",
			chisq_red_mod_ye, chisq_red_triv, chisq_red_mod_ye / chisq_red_triv)
	logging.info("2c red chi^2 mod (gp cov): %s / triv: %s = %s",
			chisq_red_gpcov, chisq_red_triv, chisq_red_gpcov / chisq_red_triv)
	logging.info("3a stand. red chi^2: %s", chisq_red_gpcov / chisq_red_triv)
	logging.info("3b 1 - stand. red chi^2: %s",
			1 - chisq_red_gpcov / chisq_red_triv)
	logging.info("5a generalized R^2: 1a: %s, 1b: %s",
			gen_rsq1a, gen_rsq1b)
	logging.info("5b generalized R^2: 2a: %s, 2b: %s",
			gen_rsq2a, gen_rsq2b)
	logging.info("5c generalized R^2: 3a: %s, 3b: %s",
			gen_rsq3a, gen_rsq3b)
	try:
		# celerite: `log_determinant` is a method
		logdet = model.solver.log_determinant()
	except TypeError:
		# george: `log_determinant` is a plain value, calling it fails
		logdet = model.solver.log_determinant
	logging.debug("5 logdet: %s, const 2: %s", logdet, _const)
364