Passed
Push — master ( d51613...22dea2 )
by Stefan
05:53
created

sciapy.regress.statistics._log_pred_pt()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 7

Duplication

Lines 10
Ratio 100 %

Code Coverage

Tests 1
CRAP Score 4.5185

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 6
dl 10
loc 10
ccs 1
cts 7
cp 0.1429
crap 4.5185
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# vim:fileencoding=utf-8
3
#
4
# Copyright (c) 2017-2019 Stefan Bender
5
#
6
# This module is part of sciapy.
7
# sciapy is free software: you can redistribute it or modify
8
# it under the terms of the GNU General Public License as published
9
# by the Free Software Foundation, version 2.
10
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
11 1
"""SCIAMACHY MCMC statistic tools
12
13
Statistical functions for MCMC sampled parameters.
14
"""
15
16 1
import logging
17
18 1
import numpy as np
19 1
from scipy import linalg
20
21 1
# public API of this module; the underscore-prefixed helpers are internal
__all__ = ["mcmc_statistics", "waic_loo"]
22
23
24 1
def _log_prob(resid, var):
25
	return -0.5 * (np.log(2 * np.pi * var) + resid**2 / var)
26
27
28 1 View Code Duplication
def _log_lh_pt(gp, times, data, errs, s):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
29
	gp.set_parameter_vector(s)
30
	resid = data - gp.mean.get_value(times)
31
	ker = gp.get_matrix(times, include_diagonal=True)
32
	ker[np.diag_indices_from(ker)] += errs**2
33
	ll, lower = linalg.cho_factor(
34
			ker, lower=True, check_finite=False, overwrite_a=True)
35
	linv_r = linalg.solve_triangular(
36
			ll, resid, lower=True, check_finite=False, overwrite_b=True)
37
	return -0.5 * (np.log(2. * np.pi) + linv_r**2) - np.log(np.diag(ll))
38
39
40 1 View Code Duplication
def _log_pred_pt(gp, train_data, times, data, errs, s):
	"""Point-wise log predictive density of the GP for one parameter sample.

	Sets the model parameters to `s`, predicts mean and variance at `times`
	conditioned on `train_data`, optionally adds the squared data errors to
	that variance, and evaluates the Gaussian log density of the prediction
	residuals with `_log_prob`.

	Parameters
	----------
	gp : model instance
		Provides ``set_parameter_vector()`` and
		``predict(y, t=..., return_var=True)``; modified in place.
	train_data : (N,) array_like
		Data the prediction is conditioned on.
	times : (M,) array_like
		Coordinates to predict at.
	data : (M,) array_like
		Test data compared with the prediction.
	errs : (M,) array_like, float, or None
		Data uncertainties (standard deviations); if not None, ``errs**2``
		is added to the predictive variance.
	s : (L,) array_like
		Parameter vector (one MCMC sample).

	Returns
	-------
	(M,) ndarray
		Point-wise log predictive densities.
	"""
	gp.set_parameter_vector(s)
	mu, var = gp.predict(train_data, t=times, return_var=True)
	if errs is not None:
		# the predictive variance should include the data variance
		var += errs**2
	return _log_prob(mu - data, var)
50
51
52 1 View Code Duplication
def waic_loo(model, times, data, errs,
		samples,
		method="likelihood",
		train_data=None,
		noisy_targets=True,
		nthreads=1,
		use_dask=False,
		dask_cluster=None,
		):
	"""Watanabe-Akaike information criterion (WAIC) and LOO IC of the (GP) model

	Calculates the WAIC and leave-one-out (LOO) cross validation scores and
	information criteria (IC) from the MCMC samples of the posterior parameter
	distributions. Uses the posterior point-wise (per data point)
	probabilities and the formulae from [1]_ and [2]_.

	.. [1] Vehtari, Gelman, and Gabry, Stat Comput (2017) 27:1413–1432,
		doi: 10.1007/s11222-016-9696-4

	.. [2] Vehtari and Gelman, (unpublished)
		http://www.stat.columbia.edu/~gelman/research/unpublished/waic_stan.pdf
		http://www.stat.columbia.edu/~gelman/research/unpublished/loo_stan.pdf

	Parameters
	----------
	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
		The model instance whose parameter distribution was drawn.
	times : (M,) array_like
		The test coordinates to predict or evaluate the model on.
	data : (M,) array_like
		The test data to test the model against.
	errs : (M,) array_like
		The errors (standard deviations) of the test data, they are squared
		and added to the (predictive) variance.
	samples : (K, L) array_like
		The `K` MCMC samples of the `L` parameter distributions.
	method : str ("likelihood" or "predict"), optional
		The method to "predict" the data, the default uses the (log)likelihood
		in the same way as is done when fitting (training) the model.
		"predict" uses the actual GP prediction, might be useful if the IC
		should be estimated for actual test data that was not used to train
		the model.
	train_data : (N,) array_like, optional
		The data on which the model was trained, needed if method="predict" is
		used, otherwise None is the default and the likelihood is used.
	noisy_targets : bool, optional
		Include the given errors when calculating the predictive probability.
		If False (or if `errs` is None), a negligible error of 1.123e-12 is
		used instead.
	nthreads : int, optional
		Number of threads to distribute the point-wise probability
		calculations to (default: 1).
	use_dask : bool, optional
		Use `dask.distributed` to distribute the point-wise probability
		calculations to `nthreads` workers. The default is to use
		`multiprocessing.pool.Pool()`. Ignored if `dask` is not installed
		or `nthreads` is 1.
	dask_cluster: str, or `dask.distributed.Cluster` instance, optional
		Will be passed to `dask.distributed.Client()`
		This can be the address of a Scheduler server like a string
		'127.0.0.1:8786' or a cluster object like `dask.distributed.LocalCluster()`.
		If None, a temporary `LocalCluster` with `nthreads` workers is
		started and shut down again.

	Returns
	-------
	waic, waic_se, p_waic, loo_ic, loo_se, p_loo : tuple
		The WAIC and its standard error as well as the
		estimated effective number of parameters, p_waic.
		The LOO IC, its standard error, and the estimated
		effective number of parameters, p_loo.
		Both standard errors use the sample variance (ddof=1)
		of the point-wise contributions.
	"""
	from functools import partial
	from multiprocessing import pool
	from scipy.special import logsumexp
	try:
		from tqdm.autonotebook import tqdm
	except ImportError:
		tqdm = None
	try:
		from dask.distributed import Client, LocalCluster, progress
	except ImportError:
		use_dask = False

	# the predictive covariance should include the data variance
	# set to a small value if we don't want to account for them
	if not noisy_targets or errs is None:
		errs = 1.123e-12

	# point-wise posterior/predictive probabilities
	if method == "predict" and train_data is not None:
		_log_p_pt = partial(_log_pred_pt, model, train_data, times, data, errs)
	else:
		_log_p_pt = partial(_log_lh_pt, model, times, data, errs)

	# calculate the point-wise probabilities and stack them together,
	# one row per sample, one column per data point
	if nthreads > 1 and use_dask:
		if dask_cluster is None:
			# start local dask cluster
			_cl = LocalCluster(n_workers=nthreads, threads_per_worker=1)
		else:
			# use provided dask cluster
			_cl = dask_cluster
		try:
			_c = Client(_cl)
			try:
				_log_pred = _c.map(_log_p_pt, samples)
				progress(_log_pred)
				log_pred = np.stack(_c.gather(_log_pred))
			finally:
				_c.close()
		finally:
			# only shut down the cluster we started ourselves
			if dask_cluster is None:
				_cl.close()
	elif nthreads > 1:
		# multiprocessing.pool; the (unordered) result order does not matter
		# because everything below reduces over the sample axis
		_p = pool.Pool(processes=nthreads)
		try:
			_mapped = _p.imap_unordered(_log_p_pt, samples)
			if tqdm is not None:
				_mapped = tqdm(_mapped, total=len(samples))
			log_pred = np.stack(list(_mapped))
		except BaseException:
			# do not wait for outstanding work on errors or interrupts
			_p.terminate()
			raise
		else:
			_p.close()
		finally:
			_p.join()
	else:
		if tqdm is not None:
			samples = tqdm(samples, total=len(samples))
		log_pred = np.stack(list(map(_log_p_pt, samples)))

	n_samples = log_pred.shape[0]

	# WAIC
	# log point-wise predictive density: log(1/S sum_s p(y_i | theta_s))
	lppd_i = logsumexp(log_pred, b=1. / n_samples, axis=0)
	# effective number of parameters per data point
	p_waic_i = np.nanvar(log_pred, ddof=1, axis=0)
	if np.any(p_waic_i > 0.4):
		logging.warning("""For one or more samples the posterior variance of the
		log predictive densities exceeds 0.4. This could be indication of
		WAIC starting to fail see http://arxiv.org/abs/1507.04544 for details
		""")
	elpd_i = lppd_i - p_waic_i
	waic_i = -2. * elpd_i
	waic_se = np.sqrt(len(waic_i) * np.nanvar(waic_i, ddof=1))
	waic = np.nansum(waic_i)
	p_waic = np.nansum(p_waic_i)
	if 2. * p_waic > len(waic_i):
		logging.warning("""p_waic > n / 2,
		the WAIC approximation is unreliable.
		""")
	logging.info("WAIC: %s, waic_se: %s, p_w: %s", waic, waic_se, p_waic)

	# LOO via truncated importance sampling.
	# The raw weights are r_s ~ 1 / p(y_i | theta_s). Scaling them by
	# exp(min_s log p) keeps every weight <= 1 (no overflow to inf for
	# widely spread log densities); the constant cancels when the weights
	# are normalised by their mean below.
	loo_ws = np.exp(np.nanmin(log_pred, axis=0) - log_pred)
	loo_ws_n = loo_ws / np.nanmean(loo_ws, axis=0)
	# truncate the normalised weights at sqrt(S)
	loo_ws_r = np.clip(loo_ws_n, None, np.sqrt(n_samples))
	elpd_loo_i = logsumexp(log_pred,
			b=loo_ws_r / np.nansum(loo_ws_r, axis=0),
			axis=0)
	p_loo_i = lppd_i - elpd_loo_i
	loo_ic_i = -2 * elpd_loo_i
	# sample variance (ddof=1), consistent with waic_se and [1]_
	loo_ic_se = np.sqrt(len(loo_ic_i) * np.nanvar(loo_ic_i, ddof=1))
	loo_ic = np.nansum(loo_ic_i)
	p_loo = np.nansum(p_loo_i)
	logging.info("loo IC: %s, se: %s, p_loo: %s", loo_ic, loo_ic_se, p_loo)

	# van der Linde, 2005, Statistica Neerlandica, 2005
	# https://doi.org/10.1111/j.1467-9574.2005.00278.x
	hy1 = -np.nanmean(lppd_i)
	hy2 = -np.nanmedian(lppd_i)
	logging.info("H(Y): mean %s, median: %s", hy1, hy2)

	return waic, waic_se, p_waic, loo_ic, loo_ic_se, p_loo
209
210
211 1 View Code Duplication
def mcmc_statistics(model, times, data, errs,
		train_times, train_data, train_errs,
		samples, lnp,
		median=False):
	"""Statistics for the (GP) model against the provided data

	Statistical information about the model and the sampled parameter
	distributions with respect to the provided data and its variance.

	Sends the calculated values to the logger, includes the mean
	standardized log loss as described in R&W, 2006, section 2.5, (2.34),
	and some slightly adapted $\\chi^2_{\\text{red}}$ and $R^2$ scores.

	Parameters
	----------
	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
		The model instance whose parameter distribution was drawn.
	times : (M,) array_like
		The test coordinates to predict or evaluate the model on.
	data : (M,) array_like
		The test data to test the model against.
	errs : (M,) array_like
		The errors (standard deviations) of the test data, they are squared
		where a variance is needed.
	train_times : (N,) array_like
		The coordinates on which the model was trained.
		NOTE(review): not passed to ``model.predict()`` below; the
		prediction presumably uses the inputs the model was last
		``compute``d with -- confirm the caller did so with these.
	train_data : (N,) array_like
		The data on which the model was trained.
	train_errs : (N,) array_like
		The errors (standard deviations) of the training data.
		See the note on `train_times`.
	samples : (K, L) array_like
		The `K` MCMC samples of the `L` parameter distributions.
	lnp : (K,) array_like
		The posterior log probabilities of the `K` MCMC samples.
	median : bool, optional
		Whether to use the median of the sampled distributions or
		the maximum posterior sample (the default) to evaluate the
		statistics.

	Returns
	-------
	nothing
	"""
	ndat = len(times)
	ndim = len(model.get_parameter_vector())
	mdim = len(model.mean.get_parameter_vector())
	samples_max_lp = np.max(lnp)
	if median:
		sample_pos = np.nanmedian(samples, axis=0)
	else:
		sample_pos = samples[np.argmax(lnp)]
	model.set_parameter_vector(sample_pos)
	# calculate the GP predicted values and covariance
	gppred, gpcov = model.predict(train_data, t=times, return_cov=True)
	# the predictive covariance should include the data variance
	gpcov[np.diag_indices_from(gpcov)] += errs**2
	# residuals
	resid_mod = model.mean.get_value(times) - data  # GP mean model
	resid_gp = gppred - data  # GP prediction
	resid_triv = np.nanmean(train_data) - data  # trivial model
	_const = ndat * np.log(2.0 * np.pi)
	# multivariate normal log predictive density log N(data | gppred, gpcov).
	# The normalisation needs log det(gpcov); summing the logs of the
	# diagonal elements only agrees with it for a diagonal covariance.
	cov_sign, cov_logdet = np.linalg.slogdet(gpcov)
	if cov_sign <= 0:
		logging.warning(
			"predictive covariance is not positive definite, "
			"test log_pred is unreliable.")
	test_logpred = -0.5 * (resid_gp.dot(linalg.solve(gpcov, resid_gp))
			+ cov_logdet
			+ _const)
	# MSLL -- mean standardized log loss
	# as described in R&W, 2006, section 2.5, (2.34)
	var_mod = np.nanvar(resid_mod, ddof=mdim)  # mean model variance
	var_gp = np.nanvar(resid_gp, ddof=ndim)  # gp model variance
	var_triv = np.nanvar(train_data, ddof=1)  # trivial model variance
	logpred_mod = _log_prob(resid_mod, var_mod)
	logpred_gp = _log_prob(resid_gp, var_gp)
	logpred_triv = _log_prob(resid_triv, var_triv)
	logging.info("MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# predictive variances
	logpred_mod = _log_prob(resid_mod, var_mod + errs**2)
	logpred_gp = _log_prob(resid_gp, var_gp + errs**2)
	logpred_triv = _log_prob(resid_triv, var_triv + errs**2)
	logging.info("pred MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("pred MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# cost values
	cost_mod = np.sum(resid_mod**2)
	cost_triv = np.sum(resid_triv**2)
	cost_gp = np.sum(resid_gp**2)
	# chi^2 (variance corrected costs)
	chisq_mod_ye = np.sum((resid_mod / errs)**2)
	chisq_triv = np.sum((resid_triv / errs)**2)
	# mean-model residuals weighted with the full GP predictive covariance
	chisq_gpcov = resid_mod.dot(linalg.solve(gpcov, resid_mod))
	# adjust for degrees of freedom
	cost_gp_dof = cost_gp / (ndat - ndim)
	cost_mod_dof = cost_mod / (ndat - mdim)
	cost_triv_dof = cost_triv / (ndat - 1)
	# reduced chi^2
	chisq_red_mod_ye = chisq_mod_ye / (ndat - mdim)
	chisq_red_triv = chisq_triv / (ndat - 1)
	chisq_red_gpcov = chisq_gpcov / (ndat - ndim)
	# "generalized" R^2
	logp_triv1 = np.sum(_log_prob(resid_triv, errs**2))
	logp_triv2 = np.sum(_log_prob(resid_triv, var_triv))
	logp_triv3 = np.sum(_log_prob(resid_triv, var_triv + errs**2))
	log_lambda1 = test_logpred - logp_triv1
	log_lambda2 = test_logpred - logp_triv2
	log_lambda3 = test_logpred - logp_triv3
	gen_rsq1a = 1 - np.exp(-2 * log_lambda1 / ndat)
	gen_rsq1b = 1 - np.exp(-2 * log_lambda1 / (ndat - ndim))
	gen_rsq2a = 1 - np.exp(-2 * log_lambda2 / ndat)
	gen_rsq2b = 1 - np.exp(-2 * log_lambda2 / (ndat - ndim))
	gen_rsq3a = 1 - np.exp(-2 * log_lambda3 / ndat)
	gen_rsq3b = 1 - np.exp(-2 * log_lambda3 / (ndat - ndim))
	# sent to the logger
	logging.info("train max logpost: %s", samples_max_lp)
	logging.info("test log_pred: %s", test_logpred)
	logging.info("1a cost mean model: %s, dof adj: %s", cost_mod, cost_mod_dof)
	logging.debug("1c cost gp predict: %s, dof adj: %s", cost_gp, cost_gp_dof)
	logging.debug("1b cost triv model: %s, dof adj: %s", cost_triv, cost_triv_dof)
	logging.info("1d var resid mean model: %s, gp model: %s, triv: %s",
			var_mod, var_gp, var_triv)
	logging.info("2a adjR2 mean model: %s, adjR2 gp predict: %s",
			1 - cost_mod_dof / cost_triv_dof, 1 - cost_gp_dof / cost_triv_dof)
	logging.info("2b red chi^2 mod: %s / triv: %s = %s",
			chisq_red_mod_ye, chisq_red_triv, chisq_red_mod_ye / chisq_red_triv)
	logging.info("2c red chi^2 mod (gp cov): %s / triv: %s = %s",
			chisq_red_gpcov, chisq_red_triv, chisq_red_gpcov / chisq_red_triv)
	logging.info("3a stand. red chi^2: %s", chisq_red_gpcov / chisq_red_triv)
	logging.info("3b 1 - stand. red chi^2: %s",
			1 - chisq_red_gpcov / chisq_red_triv)
	logging.info("5a generalized R^2: 1a: %s, 1b: %s",
			gen_rsq1a, gen_rsq1b)
	logging.info("5b generalized R^2: 2a: %s, 2b: %s",
			gen_rsq2a, gen_rsq2b)
	logging.info("5c generalized R^2: 3a: %s, 3b: %s",
			gen_rsq3a, gen_rsq3b)
	try:
		# celerite: log_determinant is a method
		logdet = model.solver.log_determinant()
	except TypeError:
		# george: log_determinant is a value, calling it raises TypeError
		logdet = model.solver.log_determinant
	logging.debug("5 logdet: %s, const 2: %s", logdet, _const)
349