Passed
Push — master ( 22dea2...7bc10c )
by Stefan
04:02
created

sciapy.regress.statistics.waic_loo()   C

Complexity

Conditions 9

Size

Total Lines 126
Code Lines 48

Duplication

Lines 126
Ratio 100 %

Code Coverage

Tests 1
CRAP Score 83.6145

Importance

Changes 0
Metric Value
cc 9
eloc 48
nop 11
dl 126
loc 126
ccs 1
cts 37
cp 0.027
crap 83.6145
rs 6.3684
c 0
b 0
f 0

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

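Commonly applied refactorings include: Extract Method and, when many parameters or temporary variables get in the way, Replace Temp with Query, Introduce Parameter Object, and Preserve Whole Object.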

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists: Replace Parameter with Method Call, Introduce Parameter Object, and Preserve Whole Object.

1
# -*- coding: utf-8 -*-
2
# vim:fileencoding=utf-8
3
#
4
# Copyright (c) 2017-2019 Stefan Bender
5
#
6
# This module is part of sciapy.
7
# sciapy is free software: you can redistribute it or modify
8
# it under the terms of the GNU General Public License as published
9
# by the Free Software Foundation, version 2.
10
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
11 1
"""SCIAMACHY MCMC statistic tools
12
13
Statistical functions for MCMC sampled parameters.
14
"""
15
16 1
import logging
17
18 1
import numpy as np
19 1
from scipy import linalg
20
21 1
__all__ = ["mcmc_statistics", "waic_loo"]
22
23
24 1
def _log_prob(resid, var):
25
	return -0.5 * (np.log(2 * np.pi * var) + resid**2 / var)
26
27
28 1 View Code Duplication
def _log_lh_pt(gp, times, data, errs, s):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
29
	gp.set_parameter_vector(s)
30
	resid = data - gp.mean.get_value(times)
31
	ker = gp.get_matrix(times, include_diagonal=True)
32
	ker[np.diag_indices_from(ker)] += errs**2
33
	ll, lower = linalg.cho_factor(
34
			ker, lower=True, check_finite=False, overwrite_a=True)
35
	linv_r = linalg.solve_triangular(
36
			ll, resid, lower=True, check_finite=False, overwrite_b=True)
37
	return -0.5 * (np.log(2. * np.pi) + linv_r**2) - np.log(np.diag(ll))
38
39
40 1 View Code Duplication
def _log_pred_pt(gp, train_data, times, data, errs, s):
	"""Point-wise log predictive density for one parameter sample

	Sets the parameter vector `s` on `gp`, predicts mean and variance at
	`times` conditioned on `train_data`, and evaluates the Gaussian
	log-density of `data` under that prediction via `_log_prob()`.

	Parameters
	----------
	gp : GP model instance
		Must provide `set_parameter_vector()` and
		`predict(y, t=..., return_var=True)`. Its parameter vector is
		changed as a side effect.
	train_data : (N,) array_like
		The data the prediction is conditioned on.
	times : (M,) array_like
		The coordinates to predict at.
	data : (M,) ndarray
		The data to compare the prediction to.
	errs : float, (M,) ndarray or None
		The data uncertainties; if not None they are squared and added to
		the predictive variance.
	s : (L,) array_like
		The parameter vector (one MCMC sample).

	Returns
	-------
	log_pred : (M,) ndarray
		The log predictive density of each data point.
	"""
	gp.set_parameter_vector(s)
	gppred, pvar_gp = gp.predict(train_data, t=times, return_var=True)
	# the predictive variance should include the data variance
	if errs is not None:
		# out-of-place so that the array handed out by the GP
		# is never modified
		pvar_gp = pvar_gp + errs**2
	# GP residuals
	resid_gp = gppred - data
	return _log_prob(resid_gp, pvar_gp)
50
51
52 1 View Code Duplication
def _log_prob_pt_samples_dask(log_p_pt, samples,
		nthreads=1, cluster=None):
	"""Map `log_p_pt` over `samples` using `dask` and stack the results

	Parameters
	----------
	log_p_pt : callable
		Called with one sample, returns the point-wise log probabilities.
	samples : sequence
		The samples to evaluate `log_p_pt` for.
	nthreads : int, optional
		Number of single-threaded workers of the local cluster that is
		started if `cluster` is None (default: 1).
	cluster : optional
		Passed to `dask.distributed.Client()` as is. If None (default),
		a `LocalCluster` is started and closed again by this function.

	Returns
	-------
	log_probs : ndarray
		The results for all samples stacked along a new first axis.
	"""
	from dask.distributed import Client, LocalCluster, progress
	import dask.bag as db

	# only a cluster that was started here is shut down here
	own_cluster = cluster is None
	if own_cluster:
		# start local dask cluster
		_cl = LocalCluster(n_workers=nthreads, threads_per_worker=1)
	else:
		# use provided dask cluster
		_cl = cluster

	try:
		# calculate the point-wise probabilities and stack them together
		with Client(_cl):
			_log_pred = db.from_sequence(samples).map(log_p_pt)
			progress(_log_pred)
			ret = np.stack(_log_pred.compute())
	finally:
		# do not leave worker processes behind if the computation failed
		if own_cluster:
			_cl.close()

	return ret
74
75
76 1 View Code Duplication
def _log_prob_pt_samples_mt(log_p_pt, samples, nthreads=1):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
77
	from multiprocessing import pool
78
	try:
79
		from tqdm.autonotebook import tqdm
80
	except ImportError:
81
		tqdm = None
82
83
	# multiprocessing.pool
84
	_p = pool.Pool(processes=nthreads)
85
86
	_mapped = _p.imap_unordered(log_p_pt, samples)
87
	if tqdm is not None:
88
		_mapped = tqdm(_mapped, total=len(samples))
89
90
	ret = np.stack(list(_mapped))
91
92
	_p.close()
93
	_p.join()
94
95
	return ret
96
97
98 1 View Code Duplication
def waic_loo(model, times, data, errs,
		samples,
		method="likelihood",
		train_data=None,
		noisy_targets=True,
		nthreads=1,
		use_dask=False,
		dask_cluster=None,
		):
	"""Watanabe-Akaike information criterion (WAIC) and LOO IC of the (GP) model

	Calculates the WAIC and leave-one-out (LOO) cross validation scores and
	information criteria (IC) from the MCMC samples of the posterior parameter
	distributions. Uses the posterior point-wise (per data point)
	probabilities and the formulae from [1]_ and [2]_.

	.. [1] Vehtari, Gelman, and Gabry, Stat Comput (2017) 27:1413–1432,
		doi: 10.1007/s11222-016-9696-4

	.. [2] Vehtari and Gelman, (unpublished)
		http://www.stat.columbia.edu/~gelman/research/unpublished/waic_stan.pdf
		http://www.stat.columbia.edu/~gelman/research/unpublished/loo_stan.pdf

	Parameters
	----------
	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
		The model instance whose parameter distribution was drawn.
	times : (M,) array_like
		The test coordinates to predict or evaluate the model on.
	data : (M,) array_like
		The test data to test the model against.
	errs : (M,) array_like or None
		The errors (standard deviations) of the test data, they are
		squared before being added to the (co)variances. If None, a
		negligibly small value is used instead.
	samples : (K, L) array_like
		The `K` MCMC samples of the `L` parameter distributions.
	method : str ("likelihood" or "predict"), optional
		The method to "predict" the data, the default uses the (log)likelihood
		in the same way as is done when fitting (training) the model.
		"predict" uses the actual GP prediction, might be useful if the IC
		should be estimated for actual test data that was not used to train
		the model.
	train_data : (N,) array_like, optional
		The data on which the model was trained, needed if method="predict" is
		used, otherwise None is the default and the likelihood is used.
	noisy_targets : bool, optional
		Include the given errors when calculating the predictive probability.
	nthreads : int, optional
		Number of threads to distribute the point-wise probability
		calculations to (default: 1).
	use_dask : bool, optional
		Use `dask.distributed` to distribute the point-wise probability
		calculations to `nthreads` workers, only has an effect if
		`nthreads` > 1. The default is to use
		`multiprocessing.pool.Pool()`.
	dask_cluster: str, or `dask.distributed.Cluster` instance, optional
		Will be passed to `dask.distributed.Client()`
		This can be the address of a Scheduler server like a string
		'127.0.0.1:8786' or a cluster object like `dask.distributed.LocalCluster()`.

	Returns
	-------
	waic, waic_se, p_waic, loo_ic, loo_se, p_loo : tuple
		The WAIC and its standard error as well as the
		estimated effective number of parameters, p_waic.
		The LOO IC, its standard error, and the estimated
		effective number of parameters, p_loo.
	"""
	from functools import partial
	from scipy.special import logsumexp

	# the predictive covariance should include the data variance
	# set to a small value if we don't want to account for them
	if not noisy_targets or errs is None:
		errs = 1.123e-12

	# point-wise posterior/predictive probabilities
	if method == "predict" and train_data is not None:
		_log_p_pt = partial(_log_pred_pt, model, train_data, times, data, errs)
	else:
		_log_p_pt = partial(_log_lh_pt, model, times, data, errs)

	# calculate the point-wise probabilities and stack them together,
	# rows are samples, columns are data points
	if nthreads > 1 and use_dask:
		log_pred = _log_prob_pt_samples_dask(_log_p_pt, samples,
				nthreads=nthreads, cluster=dask_cluster)
	else:
		log_pred = _log_prob_pt_samples_mt(_log_p_pt, samples,
				nthreads=nthreads)

	# WAIC
	# log of the sample mean of the point-wise densities
	lppd_i = logsumexp(log_pred, b=1. / log_pred.shape[0], axis=0)
	p_waic_i = np.nanvar(log_pred, ddof=1, axis=0)
	if np.any(p_waic_i > 0.4):
		# `logging.warn()` is deprecated, `logging.warning()` replaces it
		logging.warning("""For one or more samples the posterior variance of the
		log predictive densities exceeds 0.4. This could be indication of
		WAIC starting to fail see http://arxiv.org/abs/1507.04544 for details
		""")
	elpd_i = lppd_i - p_waic_i
	waic_i = -2. * elpd_i
	waic_se = np.sqrt(len(waic_i) * np.nanvar(waic_i, ddof=1))
	waic = np.nansum(waic_i)
	p_waic = np.nansum(p_waic_i)
	if 2. * p_waic > len(waic_i):
		logging.warning("""p_waic > n / 2,
		the WAIC approximation is unreliable.
		""")
	logging.info("WAIC: %s, waic_se: %s, p_w: %s", waic, waic_se, p_waic)

	# LOO
	# importance weights proportional to 1 / p, shifted by the per-point
	# maximum to keep the exponential in range
	loo_ws = 1. / np.exp(log_pred - np.nanmax(log_pred, axis=0))
	loo_ws_n = loo_ws / np.nanmean(loo_ws, axis=0)
	# truncate the normalized weights at sqrt(number of samples)
	loo_ws_r = np.clip(loo_ws_n, None, np.sqrt(log_pred.shape[0]))
	elpd_loo_i = logsumexp(log_pred,
			b=loo_ws_r / np.nansum(loo_ws_r, axis=0),
			axis=0)
	p_loo_i = lppd_i - elpd_loo_i
	loo_ic_i = -2 * elpd_loo_i
	# sample variance (ddof=1) as for `waic_se` above
	loo_ic_se = np.sqrt(len(loo_ic_i) * np.nanvar(loo_ic_i, ddof=1))
	loo_ic = np.nansum(loo_ic_i)
	p_loo = np.nansum(p_loo_i)
	logging.info("loo IC: %s, se: %s, p_loo: %s", loo_ic, loo_ic_se, p_loo)

	# van der Linde, 2005, Statistica Neerlandica, 2005
	# https://doi.org/10.1111/j.1467-9574.2005.00278.x
	hy1 = -np.nanmean(lppd_i)
	hy2 = -np.nanmedian(lppd_i)
	logging.info("H(Y): mean %s, median: %s", hy1, hy2)

	return waic, waic_se, p_waic, loo_ic, loo_ic_se, p_loo
224
225
226 1 View Code Duplication
def mcmc_statistics(model, times, data, errs,
		train_times, train_data, train_errs,
		samples, lnp,
		median=False):
	"""Statistics for the (GP) model against the provided data

	Statistical information about the model and the sampled parameter
	distributions with respect to the provided data and its variance.

	Sends the calculated values to the logger, includes the mean
	standardized log loss as described in R&W, 2006, section 2.5, (2.34),
	and some slightly adapted $\\chi^2_{\\text{red}}$ and $R^2$ scores.

	Parameters
	----------
	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
		The model instance whose parameter distribution was drawn.
		Its parameter vector is set to the selected sample.
	times : (M,) array_like
		The test coordinates to predict or evaluate the model on.
	data : (M,) array_like
		The test data to test the model against.
	errs : (M,) array_like
		The errors (standard deviations) of the test data, they are
		squared before being used as variances.
	train_times : (N,) array_like
		The coordinates on which the model was trained.
		Currently not used by this function.
	train_data : (N,) array_like
		The data on which the model was trained.
	train_errs : (N,) array_like
		The errors of the training data.
		Currently not used by this function.
	samples : (K, L) array_like
		The `K` MCMC samples of the `L` parameter distributions.
	lnp : (K,) array_like
		The posterior log probabilities of the `K` MCMC samples.
	median : bool, optional
		Whether to use the median of the sampled distributions or
		the maximum posterior sample (the default) to evaluate the
		statistics.

	Returns
	-------
	nothing
	"""
	ndat = len(times)
	ndim = len(model.get_parameter_vector())
	mdim = len(model.mean.get_parameter_vector())
	samples_max_lp = np.max(lnp)
	if median:
		sample_pos = np.nanmedian(samples, axis=0)
	else:
		sample_pos = samples[np.argmax(lnp)]
	model.set_parameter_vector(sample_pos)
	# calculate the GP predicted values and covariance
	gppred, gpcov = model.predict(train_data, t=times, return_cov=True)
	# the predictive covariance should include the data variance
	gpcov[np.diag_indices_from(gpcov)] += errs**2
	# residuals
	resid_mod = model.mean.get_value(times) - data  # GP mean model
	resid_gp = gppred - data  # GP prediction
	resid_triv = np.nanmean(train_data) - data  # trivial model
	_const = ndat * np.log(2.0 * np.pi)
	# log-determinant of the predictive covariance; the trace of the
	# element-wise logarithm only agrees with it for diagonal matrices
	_, logdet_cov = np.linalg.slogdet(gpcov)
	test_logpred = -0.5 * (resid_gp.dot(linalg.solve(gpcov, resid_gp))
			+ logdet_cov
			+ _const)
	# MSLL -- mean standardized log loss
	# as described in R&W, 2006, section 2.5, (2.34)
	var_mod = np.nanvar(resid_mod, ddof=mdim)  # mean model variance
	var_gp = np.nanvar(resid_gp, ddof=ndim)  # gp model variance
	var_triv = np.nanvar(train_data, ddof=1)  # trivial model variance
	logpred_mod = _log_prob(resid_mod, var_mod)
	logpred_gp = _log_prob(resid_gp, var_gp)
	logpred_triv = _log_prob(resid_triv, var_triv)
	logging.info("MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# predictive variances
	logpred_mod = _log_prob(resid_mod, var_mod + errs**2)
	logpred_gp = _log_prob(resid_gp, var_gp + errs**2)
	logpred_triv = _log_prob(resid_triv, var_triv + errs**2)
	logging.info("pred MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("pred MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# cost values
	cost_mod = np.sum(resid_mod**2)
	cost_triv = np.sum(resid_triv**2)
	cost_gp = np.sum(resid_gp**2)
	# chi^2 (variance corrected costs)
	chisq_mod_ye = np.sum((resid_mod / errs)**2)
	chisq_triv = np.sum((resid_triv / errs)**2)
	chisq_gpcov = resid_mod.dot(linalg.solve(gpcov, resid_mod))
	# adjust for degrees of freedom
	cost_gp_dof = cost_gp / (ndat - ndim)
	cost_mod_dof = cost_mod / (ndat - mdim)
	cost_triv_dof = cost_triv / (ndat - 1)
	# reduced chi^2
	chisq_red_mod_ye = chisq_mod_ye / (ndat - mdim)
	chisq_red_triv = chisq_triv / (ndat - 1)
	chisq_red_gpcov = chisq_gpcov / (ndat - ndim)
	# "generalized" R^2
	logp_triv1 = np.sum(_log_prob(resid_triv, errs**2))
	logp_triv2 = np.sum(_log_prob(resid_triv, var_triv))
	logp_triv3 = np.sum(_log_prob(resid_triv, var_triv + errs**2))
	log_lambda1 = test_logpred - logp_triv1
	log_lambda2 = test_logpred - logp_triv2
	log_lambda3 = test_logpred - logp_triv3
	gen_rsq1a = 1 - np.exp(-2 * log_lambda1 / ndat)
	gen_rsq1b = 1 - np.exp(-2 * log_lambda1 / (ndat - ndim))
	gen_rsq2a = 1 - np.exp(-2 * log_lambda2 / ndat)
	gen_rsq2b = 1 - np.exp(-2 * log_lambda2 / (ndat - ndim))
	gen_rsq3a = 1 - np.exp(-2 * log_lambda3 / ndat)
	gen_rsq3b = 1 - np.exp(-2 * log_lambda3 / (ndat - ndim))
	# sent to the logger
	logging.info("train max logpost: %s", samples_max_lp)
	logging.info("test log_pred: %s", test_logpred)
	logging.info("1a cost mean model: %s, dof adj: %s", cost_mod, cost_mod_dof)
	logging.debug("1c cost gp predict: %s, dof adj: %s", cost_gp, cost_gp_dof)
	logging.debug("1b cost triv model: %s, dof adj: %s", cost_triv, cost_triv_dof)
	logging.info("1d var resid mean model: %s, gp model: %s, triv: %s",
			var_mod, var_gp, var_triv)
	logging.info("2a adjR2 mean model: %s, adjR2 gp predict: %s",
			1 - cost_mod_dof / cost_triv_dof, 1 - cost_gp_dof / cost_triv_dof)
	logging.info("2b red chi^2 mod: %s / triv: %s = %s",
			chisq_red_mod_ye, chisq_red_triv, chisq_red_mod_ye / chisq_red_triv)
	logging.info("2c red chi^2 mod (gp cov): %s / triv: %s = %s",
			chisq_red_gpcov, chisq_red_triv, chisq_red_gpcov / chisq_red_triv)
	logging.info("3a stand. red chi^2: %s", chisq_red_gpcov / chisq_red_triv)
	logging.info("3b 1 - stand. red chi^2: %s",
			1 - chisq_red_gpcov / chisq_red_triv)
	logging.info("5a generalized R^2: 1a: %s, 1b: %s",
			gen_rsq1a, gen_rsq1b)
	logging.info("5b generalized R^2: 2a: %s, 2b: %s",
			gen_rsq2a, gen_rsq2b)
	logging.info("5c generalized R^2: 3a: %s, 3b: %s",
			gen_rsq3a, gen_rsq3b)
	try:
		# celerite, `log_determinant` is a method
		logdet = model.solver.log_determinant()
	except TypeError:
		# george, `log_determinant` is a plain value and not callable
		logdet = model.solver.log_determinant
	logging.debug("5 logdet: %s, const 2: %s", logdet, _const)
364