Passed
Push — master ( 07cced...d51613 )
by Stefan
04:04
created

sciapy.regress.statistics._log_pred_pt()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 7

Duplication

Lines 10
Ratio 100 %

Code Coverage

Tests 1
CRAP Score 4.5185

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 6
dl 10
loc 10
ccs 1
cts 7
cp 0.1429
crap 4.5185
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# vim:fileencoding=utf-8
3
#
4
# Copyright (c) 2017-2019 Stefan Bender
5
#
6
# This module is part of sciapy.
7
# sciapy is free software: you can redistribute it or modify
8
# it under the terms of the GNU General Public License as published
9
# by the Free Software Foundation, version 2.
10
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
11 1
"""SCIAMACHY MCMC statistic tools
12
13
Statistical functions for MCMC sampled parameters.
14
"""
15
16 1
import logging
17
18 1
import numpy as np
19 1
from scipy import linalg
20
21 1
__all__ = ["mcmc_statistics", "waic_loo"]
22
23
24 1
def _log_prob(resid, var):
25
	return -0.5 * (np.log(2 * np.pi * var) + resid**2 / var)
26
27
28 1 View Code Duplication
def _log_lh_pt(gp, times, data, errs, s):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
29
	gp.set_parameter_vector(s)
30
	resid = data - gp.mean.get_value(times)
31
	ker = gp.get_matrix(times, include_diagonal=True)
32
	ker[np.diag_indices_from(ker)] += errs**2
33
	ll, lower = linalg.cho_factor(
34
			ker, lower=True, check_finite=False, overwrite_a=True)
35
	linv_r = linalg.solve_triangular(
36
			ll, resid, lower=True, check_finite=False, overwrite_b=True)
37
	return -0.5 * (np.log(2. * np.pi) + linv_r**2) - np.log(np.diag(ll))
38
39
40 1 View Code Duplication
def _log_pred_pt(gp, train_data, times, data, errs, s):
	"""Point-wise log predictive density of `data` for the sample `s`

	Sets the GP parameters to `s`, predicts the mean and variance at `times`
	conditioned on `train_data` and evaluates the Gaussian log density of
	each residual (prediction - data). If `errs` is not None, ``errs**2``
	is added (in place) to the predictive variance.
	"""
	gp.set_parameter_vector(s)
	pred_mean, pred_var = gp.predict(train_data, t=times, return_var=True)
	if errs is not None:
		# account for the measurement noise of the test data
		pred_var += errs**2
	return _log_prob(pred_mean - data, pred_var)
50
51
52 1 View Code Duplication
def waic_loo(model, times, data, errs,
		train_data,
		samples,
		method="likelihood",
		noisy_targets=True,
		nthreads=1,
		use_dask=False,
		):
	"""Watanabe-Akaike information criterion (WAIC) and LOO IC of the (GP) model

	Calculates the WAIC and leave-one-out (LOO) cross validation scores and
	information criteria (IC) from the MCMC samples of the posterior parameter
	distributions. Uses the posterior point-wise (per data point)
	probabilities and the formulae from [1]_ and [2]_.
	The LOO values are estimated by importance sampling, with the
	normalized importance ratios truncated at the square root of the
	number of samples.

	.. [1] Vehtari, Gelman, and Gabry, Stat Comput (2017) 27:1413–1432,
		doi: 10.1007/s11222-016-9696-4

	.. [2] Vehtari and Gelman, (unpublished)
		http://www.stat.columbia.edu/~gelman/research/unpublished/waic_stan.pdf
		http://www.stat.columbia.edu/~gelman/research/unpublished/loo_stan.pdf

	Parameters
	----------
	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
		The model instance whose parameter distribution was drawn.
	times : (M,) array_like
		The test coordinates to predict or evaluate the model on.
	data : (M,) array_like
		The test data to test the model against.
	errs : (M,) array_like or None
		The uncertainties (standard deviations) of the test data,
		they are squared before being added to the model variance.
		If None (or `noisy_targets` is False), a negligible value is used.
	train_data : (N,) array_like
		The data on which the model was trained.
		Only used with ``method="predict"``.
	samples : (K, L) array_like
		The `K` MCMC samples of the `L` parameter distributions.
	method : str ("likelihood" or "predict"), optional
		The method to "predict" the data, the default uses the (log)likelihood
		in the same way as is done when fitting (training) the model.
		"predict" uses the actual GP prediction, might be useful if the IC
		should be estimated for actual test data that was not used to train
		the model.
	noisy_targets : bool, optional
		Include the given errors when calculating the predictive probability.
	nthreads : int, optional
		Number of threads to distribute the point-wise probability
		calculations to (default: 1).
	use_dask : bool, optional
		Use `dask.distributed` to distribute the point-wise probability
		calculations to `nthreads` workers. The default is to use
		`multiprocessing.pool.Pool()`, which is also the fallback
		if `dask` is not installed.

	Returns
	-------
	waic, waic_se, p_waic, loo_ic, loo_se, p_loo : tuple
		The WAIC and its standard error as well as the
		estimated effective number of parameters, p_waic.
		The LOO IC, its standard error, and the estimated
		effective number of parameters, p_loo.

	Raises
	------
	ValueError
		If `method` is neither "likelihood" nor "predict".
	"""
	from functools import partial
	from scipy.special import logsumexp

	# the predictive covariance should include the data variance
	# set to a small value if we don't want to account for them
	if not noisy_targets or errs is None:
		errs = 1.123e-12

	# point-wise posterior/predictive probabilities
	if method == "likelihood":
		_log_p_pt = partial(_log_lh_pt, model, times, data, errs)
	elif method == "predict":
		_log_p_pt = partial(_log_pred_pt, model, train_data, times, data, errs)
	else:
		raise ValueError(
			"Unknown method '{0}', "
			"should be 'likelihood' or 'predict'.".format(method))

	# (K, M) array: one row of point-wise log probabilities per sample
	log_pred = _pointwise_log_probs(_log_p_pt, samples, nthreads, use_dask)
	n_samples = log_pred.shape[0]

	# WAIC
	lppd_i = logsumexp(log_pred, b=1. / n_samples, axis=0)
	p_waic_i = np.nanvar(log_pred, ddof=1, axis=0)
	if np.any(p_waic_i > 0.4):
		logging.warning("""For one or more samples the posterior variance of the
		log predictive densities exceeds 0.4. This could be indication of
		WAIC starting to fail see http://arxiv.org/abs/1507.04544 for details
		""")
	elpd_i = lppd_i - p_waic_i
	waic_i = -2. * elpd_i
	waic_se = np.sqrt(len(waic_i) * np.nanvar(waic_i, ddof=1))
	waic = np.nansum(waic_i)
	p_waic = np.nansum(p_waic_i)
	if 2. * p_waic > len(waic_i):
		logging.warning("""p_waic > n / 2,
		the WAIC approximation is unreliable.
		""")
	logging.info("WAIC: %s, waic_se: %s, p_w: %s", waic, waic_se, p_waic)

	# LOO by truncated importance sampling. The raw importance ratios are
	# r_s = 1 / p(y_i | theta_s) = exp(-log_pred). Shifting the exponent by
	# the per-point minimum keeps it <= 0 and avoids overflow to inf; the
	# constant factor cancels in the normalization by the mean below.
	loo_ws = np.exp(np.nanmin(log_pred, axis=0) - log_pred)
	loo_ws_n = loo_ws / np.nanmean(loo_ws, axis=0)
	# truncate the normalized ratios at sqrt(K)
	loo_ws_r = np.clip(loo_ws_n, None, np.sqrt(n_samples))
	elpd_loo_i = logsumexp(log_pred,
			b=loo_ws_r / np.nansum(loo_ws_r, axis=0),
			axis=0)
	p_loo_i = lppd_i - elpd_loo_i
	loo_ic_i = -2 * elpd_loo_i
	# sample variance (ddof=1), as for waic_se above and in [1]_
	loo_ic_se = np.sqrt(len(loo_ic_i) * np.nanvar(loo_ic_i, ddof=1))
	loo_ic = np.nansum(loo_ic_i)
	p_loo = np.nansum(p_loo_i)
	logging.info("loo IC: %s, se: %s, p_loo: %s", loo_ic, loo_ic_se, p_loo)

	# van der Linde, 2005, Statistica Neerlandica, 2005
	# https://doi.org/10.1111/j.1467-9574.2005.00278.x
	hy1 = -np.nanmean(lppd_i)
	hy2 = -np.nanmedian(lppd_i)
	logging.info("H(Y): mean %s, median: %s", hy1, hy2)

	return waic, waic_se, p_waic, loo_ic, loo_ic_se, p_loo


def _pointwise_log_probs(log_p_pt, samples, nthreads=1, use_dask=False):
	"""Evaluate `log_p_pt` for every sample and stack the results

	Uses `dask.distributed` (if requested and installed) or a
	`multiprocessing` pool when `nthreads > 1`, otherwise a plain loop.
	The workers are shut down also when an evaluation fails.

	Parameters
	----------
	log_p_pt : callable
		Function mapping one parameter sample to the point-wise
		log probabilities.
	samples : (K, L) array_like
		The parameter samples.
	nthreads : int, optional
		Number of workers (default: 1).
	use_dask : bool, optional
		Use `dask.distributed` instead of `multiprocessing.pool.Pool`.

	Returns
	-------
	log_pred : numpy.ndarray
		The stacked point-wise log probabilities, one row per sample.
		With the `multiprocessing` pool the row order need not match
		the order of `samples`.
	"""
	from tqdm.autonotebook import tqdm

	if nthreads <= 1:
		return np.stack([log_p_pt(s) for s in tqdm(samples, total=len(samples))])

	if use_dask:
		try:
			from dask.distributed import Client, LocalCluster, progress
		except ImportError:
			use_dask = False

	if use_dask:
		# local dask cluster
		cluster = LocalCluster(n_workers=nthreads, threads_per_worker=1)
		try:
			client = Client(cluster)
			try:
				futures = client.map(log_p_pt, samples)
				progress(futures)
				return np.stack(client.gather(futures))
			finally:
				client.close()
		finally:
			cluster.close()

	# multiprocessing.pool
	from multiprocessing import pool
	workers = pool.Pool(processes=nthreads)
	try:
		results = list(tqdm(
			workers.imap_unordered(log_p_pt, samples), total=len(samples)))
	except BaseException:
		# do not wait for the remaining tasks after a failure
		workers.terminate()
		raise
	else:
		workers.close()
	finally:
		workers.join()
	return np.stack(results)
197
198
199 1 View Code Duplication
def mcmc_statistics(model, times, data, errs,
		train_times, train_data, train_errs,
		samples, lnp,
		median=False):
	"""Statistics for the (GP) model against the provided data

	Statistical information about the model and the sampled parameter
	distributions with respect to the provided data and its variance.

	Sends the calculated values to the logger, includes the mean
	standardized log loss as described in R&W, 2006, section 2.5, (2.34),
	and some slightly adapted $\\chi^2_{\\text{red}}$ and $R^2$ scores.

	Parameters
	----------
	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
		The model instance whose parameter distribution was drawn.
	times : (M,) array_like
		The test coordinates to predict or evaluate the model on.
	data : (M,) array_like
		The test data to test the model against.
	errs : (M,) array_like
		The uncertainties (standard deviations) of the test data,
		they are squared before being added to the variances.
	train_times : (N,) array_like
		The coordinates on which the model was trained.
		Currently not used in the calculations.
	train_data : (N,) array_like
		The data on which the model was trained.
	train_errs : (N,) array_like
		The uncertainties (standard deviations) of the training data.
		Currently not used in the calculations.
	samples : (K, L) array_like
		The `K` MCMC samples of the `L` parameter distributions.
	lnp : (K,) array_like
		The posterior log probabilities of the `K` MCMC samples.
	median : bool, optional
		Whether to use the median of the sampled distributions or
		the maximum posterior sample (the default) to evaluate the
		statistics.

	Returns
	-------
	nothing
	"""
	ndat = len(times)
	ndim = len(model.get_parameter_vector())
	mdim = len(model.mean.get_parameter_vector())
	samples_max_lp = np.max(lnp)
	if median:
		sample_pos = np.nanmedian(samples, axis=0)
	else:
		sample_pos = samples[np.argmax(lnp)]
	model.set_parameter_vector(sample_pos)
	# calculate the GP predicted values and covariance
	gppred, gpcov = model.predict(train_data, t=times, return_cov=True)
	# the predictive covariance should include the data variance
	gpcov[np.diag_indices_from(gpcov)] += errs**2
	# residuals
	resid_mod = model.mean.get_value(times) - data  # GP mean model
	resid_gp = gppred - data  # GP prediction
	resid_triv = np.nanmean(train_data) - data  # trivial model
	_const = ndat * np.log(2.0 * np.pi)
	# multivariate normal log density of the test data under the GP
	# predictive distribution; this needs the log-determinant of the full
	# covariance matrix, not only the logs of its diagonal entries
	cov_sign, cov_logdet = np.linalg.slogdet(gpcov)
	if cov_sign <= 0:
		logging.warning("The predictive covariance is not positive definite.")
	test_logpred = -0.5 * (resid_gp.dot(linalg.solve(gpcov, resid_gp))
			+ cov_logdet
			+ _const)
	# MSLL -- mean standardized log loss
	# as described in R&W, 2006, section 2.5, (2.34)
	var_mod = np.nanvar(resid_mod, ddof=mdim)  # mean model variance
	var_gp = np.nanvar(resid_gp, ddof=ndim)  # gp model variance
	var_triv = np.nanvar(train_data, ddof=1)  # trivial model variance
	logpred_mod = _log_prob(resid_mod, var_mod)
	logpred_gp = _log_prob(resid_gp, var_gp)
	logpred_triv = _log_prob(resid_triv, var_triv)
	logging.info("MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# predictive variances
	logpred_mod = _log_prob(resid_mod, var_mod + errs**2)
	logpred_gp = _log_prob(resid_gp, var_gp + errs**2)
	logpred_triv = _log_prob(resid_triv, var_triv + errs**2)
	logging.info("pred MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("pred MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# cost values
	cost_mod = np.sum(resid_mod**2)
	cost_triv = np.sum(resid_triv**2)
	cost_gp = np.sum(resid_gp**2)
	# chi^2 (variance corrected costs)
	chisq_mod_ye = np.sum((resid_mod / errs)**2)
	chisq_triv = np.sum((resid_triv / errs)**2)
	chisq_gpcov = resid_mod.dot(linalg.solve(gpcov, resid_mod))
	# adjust for degrees of freedom
	cost_gp_dof = cost_gp / (ndat - ndim)
	cost_mod_dof = cost_mod / (ndat - mdim)
	cost_triv_dof = cost_triv / (ndat - 1)
	# reduced chi^2
	chisq_red_mod_ye = chisq_mod_ye / (ndat - mdim)
	chisq_red_triv = chisq_triv / (ndat - 1)
	chisq_red_gpcov = chisq_gpcov / (ndat - ndim)
	# "generalized" R^2
	logp_triv1 = np.sum(_log_prob(resid_triv, errs**2))
	logp_triv2 = np.sum(_log_prob(resid_triv, var_triv))
	logp_triv3 = np.sum(_log_prob(resid_triv, var_triv + errs**2))
	log_lambda1 = test_logpred - logp_triv1
	log_lambda2 = test_logpred - logp_triv2
	log_lambda3 = test_logpred - logp_triv3
	gen_rsq1a = 1 - np.exp(-2 * log_lambda1 / ndat)
	gen_rsq1b = 1 - np.exp(-2 * log_lambda1 / (ndat - ndim))
	gen_rsq2a = 1 - np.exp(-2 * log_lambda2 / ndat)
	gen_rsq2b = 1 - np.exp(-2 * log_lambda2 / (ndat - ndim))
	gen_rsq3a = 1 - np.exp(-2 * log_lambda3 / ndat)
	gen_rsq3b = 1 - np.exp(-2 * log_lambda3 / (ndat - ndim))
	# sent to the logger
	logging.info("train max logpost: %s", samples_max_lp)
	logging.info("test log_pred: %s", test_logpred)
	logging.info("1a cost mean model: %s, dof adj: %s", cost_mod, cost_mod_dof)
	logging.debug("1c cost gp predict: %s, dof adj: %s", cost_gp, cost_gp_dof)
	logging.debug("1b cost triv model: %s, dof adj: %s", cost_triv, cost_triv_dof)
	logging.info("1d var resid mean model: %s, gp model: %s, triv: %s",
			var_mod, var_gp, var_triv)
	logging.info("2a adjR2 mean model: %s, adjR2 gp predict: %s",
			1 - cost_mod_dof / cost_triv_dof, 1 - cost_gp_dof / cost_triv_dof)
	logging.info("2b red chi^2 mod: %s / triv: %s = %s",
			chisq_red_mod_ye, chisq_red_triv, chisq_red_mod_ye / chisq_red_triv)
	logging.info("2c red chi^2 mod (gp cov): %s / triv: %s = %s",
			chisq_red_gpcov, chisq_red_triv, chisq_red_gpcov / chisq_red_triv)
	logging.info("3a stand. red chi^2: %s", chisq_red_gpcov / chisq_red_triv)
	logging.info("3b 1 - stand. red chi^2: %s",
			1 - chisq_red_gpcov / chisq_red_triv)
	logging.info("5a generalized R^2: 1a: %s, 1b: %s",
			gen_rsq1a, gen_rsq1b)
	logging.info("5b generalized R^2: 2a: %s, 2b: %s",
			gen_rsq2a, gen_rsq2b)
	logging.info("5c generalized R^2: 3a: %s, 3b: %s",
			gen_rsq3a, gen_rsq3b)
	try:
		# celerite: `log_determinant` is a method
		logdet = model.solver.log_determinant()
	except TypeError:
		# george: `log_determinant` is a plain value, calling it fails
		logdet = model.solver.log_determinant
	logging.debug("5 logdet: %s, const 2: %s", logdet, _const)
337