| 1 |  |  | # -*- coding: utf-8 -*- | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | # vim:fileencoding=utf-8 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | # | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | # Copyright (c) 2017-2019 Stefan Bender | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | # | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | # This module is part of sciapy. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | # sciapy is free software: you can redistribute it or modify | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | # it under the terms of the GNU General Public License as published | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | # by the Free Software Foundation, version 2. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | # See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 | 1 |  | """SCIAMACHY MCMC statistic tools | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  | Statistical functions for MCMC sampled parameters. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 | 1 |  | import logging | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 | 1 |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 | 1 |  | from scipy import linalg | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 | 1 |  | __all__ = ["mcmc_statistics", "waic_loo"] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _log_prob(resid, var):
                	"""Point-wise Gaussian log-density.

                	Parameters
                	----------
                	resid: float or array_like
                		Deviations from the mean (the sign does not matter).
                	var: float or array_like
                		Variances, broadcastable against `resid`.

                	Returns
                	-------
                	float or numpy.ndarray
                		``log N(resid | 0, var)``, element-wise.
                	"""
                	log_norm = np.log(2 * np.pi * var)
                	sq_dev = resid**2 / var
                	return -0.5 * (log_norm + sq_dev)
                | 26 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _log_lh_pt(gp, times, data, errs, s):
                	"""Point-wise log-likelihood of `data` under the GP for one sample.

                	Uses the Cholesky factor ``L`` of the full covariance
                	``K = L L^T``. The returned terms
                	``-0.5 * (log(2 pi) + (L^{-1} r)_i^2) - log(L_ii)``
                	are the sequential conditional densities
                	``log p(y_i | y_1, ..., y_{i-1})``. They therefore depend on the
                	ordering of the data points, and they sum to the joint Gaussian
                	log-likelihood ``log N(data | mean, K)``.

                	Parameters
                	----------
                	gp: GP object
                		Presumably a george-like GP (TODO confirm). It must provide
                		``set_parameter_vector()``, ``mean.get_value()`` and
                		``get_matrix(times, include_diagonal=True)``.
                		The parameters are set in place on `gp`.
                	times: array_like
                		Input coordinates of the data points.
                	data: array_like
                		Observed values at `times`.
                	errs: array_like or None
                		Measurement uncertainties (standard deviations) of `data`.
                		Their squares are added to the kernel diagonal. If None,
                		the kernel matrix is used as is (noise-free targets).
                	s: array_like
                		Parameter vector (one posterior sample) to set on `gp`.

                	Returns
                	-------
                	numpy.ndarray
                		One log-likelihood term per data point.
                	"""
                	gp.set_parameter_vector(s)
                	resid = data - gp.mean.get_value(times)
                	ker = gp.get_matrix(times, include_diagonal=True)
                	if errs is not None:
                		# measurement variances enter the covariance diagonal
                		ker[np.diag_indices_from(ker)] += errs**2
                	# lower Cholesky factor L with K = L L^T
                	ll, _ = linalg.cho_factor(
                			ker, lower=True, check_finite=False, overwrite_a=True)
                	# whitened residuals L^{-1} r
                	linv_r = linalg.solve_triangular(
                			ll, resid, lower=True, check_finite=False, overwrite_b=True)
                	# per-point terms of log N(r | 0, K); their sum is the joint value
                	return -0.5 * (np.log(2. * np.pi) + linv_r**2) - np.log(np.diag(ll))
                | 38 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _log_pred_pt(gp, train_data, times, data, errs, s):
                	"""Point-wise log predictive density of `data` for one parameter sample.

                	Parameters
                	----------
                	gp: GP object
                		Presumably a george-like GP (TODO confirm). It must provide
                		``set_parameter_vector()`` and
                		``predict(y, t=..., return_var=True)``.
                		The parameters are set in place on `gp`.
                	train_data: array_like
                		Observations the prediction is conditioned on (passed as the
                		first argument to ``gp.predict``).
                	times: array_like
                		Coordinates at which to predict.
                	data: array_like
                		Observed values at `times` to score.
                	errs: array_like or None
                		Measurement uncertainties (standard deviations) of `data`.
                		If given, their squares are added to the predictive variance.
                	s: array_like
                		Parameter vector (one posterior sample) to set on `gp`.

                	Returns
                	-------
                	numpy.ndarray
                		Gaussian log-density of each data point given the GP
                		predictive mean and variance.
                	"""
                	gp.set_parameter_vector(s)
                	pred_mean, pred_var = gp.predict(train_data, t=times, return_var=True)
                	if errs is not None:
                		# noisy targets: predictive variance includes the data variance
                		pred_var += errs**2
                	return _log_prob(pred_mean - data, pred_var)
                | 50 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _log_prob_pt_samples_dask(log_p_pt, samples,
                		nthreads=1, cluster=None):
                	"""Evaluate `log_p_pt` for every sample in parallel using dask.

                	Parameters
                	----------
                	log_p_pt: callable
                		Maps one parameter sample to an array of point-wise
                		log-probabilities.
                	samples: iterable
                		The parameter samples.
                	nthreads: int, optional
                		Number of single-threaded workers of the local cluster that is
                		started when `cluster` is None; ignored otherwise.
                	cluster: optional
                		Existing dask cluster passed to ``dask.distributed.Client``.
                		It is not closed by this function.

                	Returns
                	-------
                	numpy.ndarray
                		The results stacked along a new first axis, one row per sample.
                	"""
                	from dask.distributed import Client, LocalCluster, progress
                	import dask.bag as db

                	if cluster is None:
                		# start local dask cluster
                		_cl = LocalCluster(n_workers=nthreads, threads_per_worker=1)
                	else:
                		# use provided dask cluster
                		_cl = cluster

                	try:
                		with Client(_cl):
                			# persist() submits the work to the cluster so that `progress`
                			# has futures to track; a lazy bag has none, and compute()
                			# below would otherwise do all the work untracked.
                			_log_pred = db.from_sequence(samples).map(log_p_pt).persist()
                			progress(_log_pred)
                			ret = np.stack(_log_pred.compute())
                	finally:
                		# only shut down the cluster we started ourselves,
                		# even if the computation failed
                		if cluster is None:
                			_cl.close()

                	return ret
                | 74 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 75 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                def _log_prob_pt_samples_mt(log_p_pt, samples, nthreads=1):
                	"""Evaluate `log_p_pt` for every sample in a process pool.

                	A tqdm progress bar is shown if tqdm is installed.

                	Parameters
                	----------
                	log_p_pt: callable
                		Maps one parameter sample to an array of point-wise
                		log-probabilities. It must be picklable to be sent to the
                		worker processes.
                	samples: iterable
                		The parameter samples.
                	nthreads: int, optional
                		Number of worker processes (default: 1).

                	Returns
                	-------
                	numpy.ndarray
                		The results stacked along a new first axis, one row per
                		sample, in the order of `samples`.
                	"""
                	from multiprocessing import pool
                	try:
                		from tqdm.autonotebook import tqdm
                	except ImportError:
                		tqdm = None

                	_p = pool.Pool(processes=nthreads)
                	try:
                		# `imap` (unlike `imap_unordered`) yields results in input order,
                		# so row i stays aligned with the i-th sample and the output
                		# is deterministic.
                		_mapped = _p.imap(log_p_pt, samples)
                		if tqdm is not None:
                			try:
                				total = len(samples)
                			except TypeError:
                				# plain iterator/generator: progress without a total
                				total = None
                			_mapped = tqdm(_mapped, total=total)
                		ret = np.stack(list(_mapped))
                	except BaseException:
                		# do not leave worker processes behind on errors or interrupts
                		_p.terminate()
                		raise
                	else:
                		_p.close()
                	finally:
                		_p.join()

                	return ret
                | 96 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 | 1 |  | def waic_loo(model, times, data, errs, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  | 		samples, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  | 		method="likelihood", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  | 		train_data=None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  | 		noisy_targets=True, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  | 		nthreads=1, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  | 		use_dask=False, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  | 		dask_cluster=None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  | 		): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  | 	"""Watanabe-Akaike information criterion (WAIC) and LOO IC of the (GP) model | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  | 	Calculates the WAIC and leave-one-out (LOO) cross validation scores and | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  | 	information criteria (IC) from the MCMC samples of the posterior parameter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  | 	distributions. Uses the posterior point-wise (per data point) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  | 	probabilities and the formulae from [1]_ and [2]_. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  | 	.. [1] Vehtari, Gelman, and Gabry, Stat Comput (2017) 27:1413–1432, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  | 		doi: 10.1007/s11222-016-9696-4 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  | 	.. [2] Vehtari and Gelman, (unpublished) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  | 		http://www.stat.columbia.edu/~gelman/research/unpublished/waic_stan.pdf | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  | 		http://www.stat.columbia.edu/~gelman/research/unpublished/loo_stan.pdf | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  | 	Parameters | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  | 	---------- | 
            
                                                                                                            
                            
            
                                    
            
            
                | 123 |  |  | 	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance | 
            
                                                                                                            
                            
            
                                    
            
            
                | 124 |  |  | 		The model instance whose parameter distribution was drawn. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 125 |  |  | 	times : (M,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 126 |  |  | 		The test coordinates to predict or evaluate the model on. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  | 	data : (M,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 |  |  | 		The test data to test the model against. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  | 	errs : (M,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  | 		The errors (variances) of the test data. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  | 	samples : (K, L) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  | 		The `K` MCMC samples of the `L` parameter distributions. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  | 	method : str ("likelihood" or "predict"), optional | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  | 		The method to "predict" the data, the default uses the (log)likelihood | 
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  | 		in the same way as is done when fitting (training) the model. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  | 		"predict" uses the actual GP prediction, might be useful if the IC | 
            
                                                                                                            
                            
            
                                    
            
            
                | 137 |  |  | 		should be estimated for actual test data that was not used to train | 
            
                                                                                                            
                            
            
                                    
            
            
                | 138 |  |  | 		the model. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 139 |  |  | 	train_data : (N,) array_like, optional | 
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  | 		The data on which the model was trained, needed if method="predict" is | 
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  | 		used, otherwise None is the default and the likelihood is used. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 142 |  |  | 	noisy_targets : bool, optional | 
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  | 		Include the given errors when calculating the predictive probability. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  | 	nthreads : int, optional | 
            
                                                                                                            
                            
            
                                    
            
            
                | 145 |  |  | 		Number of threads to distribute the point-wise probability | 
            
                                                                                                            
                            
            
                                    
            
            
                | 146 |  |  | 		calculations to (default: 1). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  | 	use_dask : boot, optional | 
            
                                                                                                            
                            
            
                                    
            
            
                | 148 |  |  | 		Use `dask.distributed` to distribute the point-wise probability | 
            
                                                                                                            
                            
            
                                    
            
            
                | 149 |  |  | 		calculations to `nthreads` workers. The default is to use | 
            
                                                                                                            
                            
            
                                    
            
            
                | 150 |  |  | 		`multiprocessing.pool.Pool()`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  | 	dask_cluster: str, or `dask.distributed.Cluster` instance, optional | 
            
                                                                                                            
                            
            
                                    
            
            
                | 152 |  |  | 		Will be passed to `dask.distributed.Client()` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 153 |  |  | 		This can be the address of a Scheduler server like a string | 
            
                                                                                                            
                            
            
                                    
            
            
                | 154 |  |  | 		'127.0.0.1:8786' or a cluster object like `dask.distributed.LocalCluster()`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 155 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  | 	Returns | 
            
                                                                                                            
                            
            
                                    
            
            
                | 157 |  |  | 	------- | 
            
                                                                                                            
                            
            
                                    
            
            
                | 158 |  |  | 	waic, waic_se, p_waic, loo_ic, loo_se, p_loo : tuple | 
            
                                                                                                            
                            
            
                                    
            
            
                | 159 |  |  | 		The WAIC and its standard error as well as the | 
            
                                                                                                            
                            
            
                                    
            
            
                | 160 |  |  | 		estimated effective number of parameters, p_waic. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 161 |  |  | 		The LOO IC, its standard error, and the estimated | 
            
                                                                                                            
                            
            
                                    
            
            
                | 162 |  |  | 		effective number of parameters, p_loo. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 163 |  |  | 	""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 164 |  |  | 	from functools import partial | 
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  | 	from scipy.special import logsumexp | 
            
                                                                                                            
                            
            
                                    
            
            
                | 166 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 167 |  |  | 	# the predictive covariance should include the data variance | 
            
                                                                                                            
                            
            
                                    
            
            
                | 168 |  |  | 	# set to a small value if we don't want to account for them | 
            
                                                                                                            
                            
            
                                    
            
            
                | 169 |  |  | 	if not noisy_targets or errs is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 170 |  |  | 		errs = 1.123e-12 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 171 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 172 |  |  | 	# point-wise posterior/predictive probabilities | 
            
                                                                                                            
                            
            
                                    
            
            
                | 173 |  |  | 	_log_p_pt = partial(_log_lh_pt, model, times, data, errs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 174 |  |  | 	if method == "predict" and train_data is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 175 |  |  | 		_log_p_pt = partial(_log_pred_pt, model, train_data, times, data, errs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 176 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 177 |  |  | 	# calculate the point-wise probabilities and stack them together | 
            
                                                                                                            
                            
            
                                    
            
            
                | 178 |  |  | 	if nthreads > 1 and use_dask: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 179 |  |  | 		log_pred = _log_prob_pt_samples_dask(_log_p_pt, samples, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 180 |  |  | 				nthreads=nthreads, cluster=dask_cluster) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 181 |  |  | 	else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 182 |  |  | 		log_pred = _log_prob_pt_samples_mt(_log_p_pt, samples, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 183 |  |  | 				nthreads=nthreads) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 184 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 185 |  |  | 	lppd_i = logsumexp(log_pred, b=1. / log_pred.shape[0], axis=0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 186 |  |  | 	p_waic_i = np.nanvar(log_pred, ddof=1, axis=0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 187 |  |  | 	if np.any(p_waic_i > 0.4): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 188 |  |  | 		logging.warn("""For one or more samples the posterior variance of the | 
            
                                                                                                            
                            
            
                                    
            
            
                | 189 |  |  | 		log predictive densities exceeds 0.4. This could be indication of | 
            
                                                                                                            
                            
            
                                    
            
            
                | 190 |  |  | 		WAIC starting to fail see http://arxiv.org/abs/1507.04544 for details | 
            
                                                                                                            
                            
            
                                    
            
            
                | 191 |  |  | 		""") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 192 |  |  | 	elpd_i = lppd_i - p_waic_i | 
            
                                                                                                            
                            
            
                                    
            
            
                | 193 |  |  | 	waic_i = -2. * elpd_i | 
            
                                                                                                            
                            
            
                                    
            
            
                | 194 |  |  | 	waic_se = np.sqrt(len(waic_i) * np.nanvar(waic_i, ddof=1)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 195 |  |  | 	waic = np.nansum(waic_i) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 196 |  |  | 	p_waic = np.nansum(p_waic_i) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 197 |  |  | 	if 2. * p_waic > len(waic_i): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 198 |  |  | 		logging.warn("""p_waic > n / 2, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 199 |  |  | 		the WAIC approximation is unreliable. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 200 |  |  | 		""") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 201 |  |  | 	logging.info("WAIC: %s, waic_se: %s, p_w: %s", waic, waic_se, p_waic) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 202 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 203 |  |  | 	# LOO | 
            
                                                                                                            
                            
            
                                    
            
            
                | 204 |  |  | 	loo_ws = 1. / np.exp(log_pred - np.nanmax(log_pred, axis=0)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 205 |  |  | 	loo_ws_n = loo_ws / np.nanmean(loo_ws, axis=0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 206 |  |  | 	loo_ws_r = np.clip(loo_ws_n, None, np.sqrt(log_pred.shape[0])) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 207 |  |  | 	elpd_loo_i = logsumexp(log_pred, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 208 |  |  | 			b=loo_ws_r / np.nansum(loo_ws_r, axis=0), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 209 |  |  | 			axis=0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 210 |  |  | 	p_loo_i = lppd_i - elpd_loo_i | 
            
                                                                                                            
                            
            
                                    
            
            
                | 211 |  |  | 	loo_ic_i = -2 * elpd_loo_i | 
            
                                                                                                            
                            
            
                                    
            
            
                | 212 |  |  | 	loo_ic_se = np.sqrt(len(loo_ic_i) * np.nanvar(loo_ic_i)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  | 	loo_ic = np.nansum(loo_ic_i) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 214 |  |  | 	p_loo = np.nansum(p_loo_i) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 215 |  |  | 	logging.info("loo IC: %s, se: %s, p_loo: %s", loo_ic, loo_ic_se, p_loo) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 216 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 217 |  |  | 	# van der Linde, 2005, Statistica Neerlandica, 2005 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 218 |  |  | 	# https://doi.org/10.1111/j.1467-9574.2005.00278.x | 
            
                                                                                                            
                            
            
                                    
            
            
                | 219 |  |  | 	hy1 = -np.nanmean(lppd_i) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 220 |  |  | 	hy2 = -np.nanmedian(lppd_i) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 221 |  |  | 	logging.info("H(Y): mean %s, median: %s", hy1, hy2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 222 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 223 |  |  | 	return waic, waic_se, p_waic, loo_ic, loo_ic_se, p_loo | 
            
                                                                                                            
                            
            
                                    
            
            
                | 224 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 225 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 226 | 1 |  | def mcmc_statistics(model, times, data, errs, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 227 |  |  | 		train_times, train_data, train_errs, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 228 |  |  | 		samples, lnp, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 229 |  |  | 		median=False): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 230 |  |  | 	"""Statistics for the (GP) model against the provided data | 
            
                                                                                                            
                            
            
                                    
            
            
                | 231 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 232 |  |  | 	Statistical information about the model and the sampled parameter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 233 |  |  | 	distributions with respect to the provided data and its variance. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 234 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 235 |  |  | 	Sends the calculated values to the logger, includes the mean | 
            
                                                                                                            
                            
            
                                    
            
            
                | 236 |  |  | 	standardized log loss as described in R&W, 2006, section 2.5, (2.34), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 237 |  |  | 	and some slightly adapted $\\chi^2_{\\text{red}}$ and $R^2$ scores. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 238 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 239 |  |  | 	Parameters | 
            
                                                                                                            
                            
            
                                    
            
            
                | 240 |  |  | 	---------- | 
            
                                                                                                            
                            
            
                                    
            
            
                | 241 |  |  | 	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance | 
            
                                                                                                            
                            
            
                                    
            
            
                | 242 |  |  | 		The model instance whose parameter distribution was drawn. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 243 |  |  | 	times : (M,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 244 |  |  | 		The test coordinates to predict or evaluate the model on. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 245 |  |  | 	data : (M,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 246 |  |  | 		The test data to test the model against. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 247 |  |  | 	errs : (M,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 248 |  |  | 		The errors (variances) of the test data. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 249 |  |  | 	train_times : (N,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 250 |  |  | 		The coordinates on which the model was trained. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 251 |  |  | 	train_data : (N,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 252 |  |  | 		The data on which the model was trained. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 253 |  |  | 	train_errs : (N,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 254 |  |  | 		The errors (variances) of the training data. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 255 |  |  | 	samples : (K, L) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 256 |  |  | 		The `K` MCMC samples of the `L` parameter distributions. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 257 |  |  | 	lnp : (K,) array_like | 
            
                                                                                                            
                            
            
                                    
            
            
                | 258 |  |  | 		The posterior log probabilities of the `K` MCMC samples. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 259 |  |  | 	median : bool, optional | 
            
                                                                                                            
                            
            
                                    
            
            
                | 260 |  |  | 		Whether to use the median of the sampled distributions or | 
            
                                                                                                            
                            
            
                                    
            
            
                | 261 |  |  | 		the maximum posterior sample (the default) to evaluate the | 
            
                                                                                                            
                            
            
                                    
            
            
                | 262 |  |  | 		statistics. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 263 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 264 |  |  | 	Returns | 
            
                                                                                                            
                            
            
                                    
            
            
                | 265 |  |  | 	------- | 
            
                                                                                                            
                            
            
                                    
            
            
                | 266 |  |  | 	nothing | 
            
                                                                                                            
                            
            
                                    
            
            
                | 267 |  |  | 	""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 268 | 1 |  | 	ndat = len(times) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 269 | 1 |  | 	ndim = len(model.get_parameter_vector()) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 270 | 1 |  | 	mdim = len(model.mean.get_parameter_vector()) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 271 | 1 |  | 	samples_max_lp = np.max(lnp) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 272 | 1 |  | 	if median: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 273 |  |  | 		sample_pos = np.nanmedian(samples, axis=0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 274 |  |  | 	else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 275 | 1 |  | 		sample_pos = samples[np.argmax(lnp)] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 276 | 1 |  | 	model.set_parameter_vector(sample_pos) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 277 |  |  | 	# calculate the GP predicted values and covariance | 
            
                                                                                                            
                            
            
                                    
            
            
                | 278 | 1 |  | 	gppred, gpcov = model.predict(train_data, t=times, return_cov=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 279 |  |  | 	# the predictive covariance should include the data variance | 
            
                                                                                                            
                            
            
                                    
            
            
                | 280 | 1 |  | 	gpcov[np.diag_indices_from(gpcov)] += errs**2 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 281 |  |  | 	# residuals | 
            
                                                                                                            
                            
            
                                    
            
            
                | 282 | 1 |  | 	resid_mod = model.mean.get_value(times) - data  # GP mean model | 
            
                                                                                                            
                            
            
                                    
            
            
                | 283 | 1 |  | 	resid_gp = gppred - data  # GP prediction | 
            
                                                                                                            
                            
            
                                    
            
            
                | 284 | 1 |  | 	resid_triv = np.nanmean(train_data) - data  # trivial model | 
            
                                                                                                            
                            
            
                                    
            
            
                | 285 | 1 |  | 	_const = ndat * np.log(2.0 * np.pi) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 286 | 1 |  | 	test_logpred = -0.5 * (resid_gp.dot(linalg.solve(gpcov, resid_gp)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 287 |  |  | 			+ np.trace(np.log(gpcov)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 288 |  |  | 			+ _const) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 289 |  |  | 	# MSLL -- mean standardized log loss | 
            
                                                                                                            
                            
            
                                    
            
            
                | 290 |  |  | 	# as described in R&W, 2006, section 2.5, (2.34) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 291 | 1 |  | 	var_mod = np.nanvar(resid_mod, ddof=mdim)  # mean model variance | 
            
                                                                                                            
                            
            
                                    
            
            
                | 292 | 1 |  | 	var_gp = np.nanvar(resid_gp, ddof=ndim)  # gp model variance | 
            
                                                                                                            
                            
            
                                    
            
            
                | 293 | 1 |  | 	var_triv = np.nanvar(train_data, ddof=1)  # trivial model variance | 
            
                                                                                                            
                            
            
                                    
            
            
                | 294 | 1 |  | 	logpred_mod = _log_prob(resid_mod, var_mod) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 295 | 1 |  | 	logpred_gp = _log_prob(resid_gp, var_gp) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 296 | 1 |  | 	logpred_triv = _log_prob(resid_triv, var_triv) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 297 | 1 |  | 	logging.info("MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 298 | 1 |  | 	logging.info("MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 299 |  |  | 	# predictive variances | 
            
                                                                                                            
                            
            
                                    
            
            
                | 300 | 1 |  | 	logpred_mod = _log_prob(resid_mod, var_mod + errs**2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 301 | 1 |  | 	logpred_gp = _log_prob(resid_gp, var_gp + errs**2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 302 | 1 |  | 	logpred_triv = _log_prob(resid_triv, var_triv + errs**2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 303 | 1 |  | 	logging.info("pred MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 304 | 1 |  | 	logging.info("pred MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 305 |  |  | 	# cost values | 
            
                                                                                                            
                            
            
                                    
            
            
                | 306 | 1 |  | 	cost_mod = np.sum(resid_mod**2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 307 | 1 |  | 	cost_triv = np.sum(resid_triv**2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 308 | 1 |  | 	cost_gp = np.sum(resid_gp**2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 309 |  |  | 	# chi^2 (variance corrected costs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 310 | 1 |  | 	chisq_mod_ye = np.sum((resid_mod / errs)**2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 311 | 1 |  | 	chisq_triv = np.sum((resid_triv / errs)**2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 312 | 1 |  | 	chisq_gpcov = resid_mod.dot(linalg.solve(gpcov, resid_mod)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 313 |  |  | 	# adjust for degrees of freedom | 
            
                                                                                                            
                            
            
                                    
            
            
                | 314 | 1 |  | 	cost_gp_dof = cost_gp / (ndat - ndim) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 315 | 1 |  | 	cost_mod_dof = cost_mod / (ndat - mdim) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 316 | 1 |  | 	cost_triv_dof = cost_triv / (ndat - 1) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 317 |  |  | 	# reduced chi^2 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 318 | 1 |  | 	chisq_red_mod_ye = chisq_mod_ye / (ndat - mdim) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 319 | 1 |  | 	chisq_red_triv = chisq_triv / (ndat - 1) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 320 | 1 |  | 	chisq_red_gpcov = chisq_gpcov / (ndat - ndim) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 321 |  |  | 	# "generalized" R^2 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 322 | 1 |  | 	logp_triv1 = np.sum(_log_prob(resid_triv, errs**2)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 323 | 1 |  | 	logp_triv2 = np.sum(_log_prob(resid_triv, var_triv)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 324 | 1 |  | 	logp_triv3 = np.sum(_log_prob(resid_triv, var_triv + errs**2)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 325 | 1 |  | 	log_lambda1 = test_logpred - logp_triv1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 326 | 1 |  | 	log_lambda2 = test_logpred - logp_triv2 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 327 | 1 |  | 	log_lambda3 = test_logpred - logp_triv3 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 328 | 1 |  | 	gen_rsq1a = 1 - np.exp(-2 * log_lambda1 / ndat) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 329 | 1 |  | 	gen_rsq1b = 1 - np.exp(-2 * log_lambda1 / (ndat - ndim)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 330 | 1 |  | 	gen_rsq2a = 1 - np.exp(-2 * log_lambda2 / ndat) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 331 | 1 |  | 	gen_rsq2b = 1 - np.exp(-2 * log_lambda2 / (ndat - ndim)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 332 | 1 |  | 	gen_rsq3a = 1 - np.exp(-2 * log_lambda3 / ndat) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 333 | 1 |  | 	gen_rsq3b = 1 - np.exp(-2 * log_lambda3 / (ndat - ndim)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 334 |  |  | 	# sent to the logger | 
            
                                                                                                            
                            
            
                                    
            
            
                | 335 | 1 |  | 	logging.info("train max logpost: %s", samples_max_lp) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 336 | 1 |  | 	logging.info("test log_pred: %s", test_logpred) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 337 | 1 |  | 	logging.info("1a cost mean model: %s, dof adj: %s", cost_mod, cost_mod_dof) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 338 | 1 |  | 	logging.debug("1c cost gp predict: %s, dof adj: %s", cost_gp, cost_gp_dof) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 339 | 1 |  | 	logging.debug("1b cost triv model: %s, dof adj: %s", cost_triv, cost_triv_dof) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 340 | 1 |  | 	logging.info("1d var resid mean model: %s, gp model: %s, triv: %s", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 341 |  |  | 			var_mod, var_gp, var_triv) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 342 | 1 |  | 	logging.info("2a adjR2 mean model: %s, adjR2 gp predict: %s", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 343 |  |  | 			1 - cost_mod_dof / cost_triv_dof, 1 - cost_gp_dof / cost_triv_dof) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 344 | 1 |  | 	logging.info("2b red chi^2 mod: %s / triv: %s = %s", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 345 |  |  | 			chisq_red_mod_ye, chisq_red_triv, chisq_red_mod_ye / chisq_red_triv) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 346 | 1 |  | 	logging.info("2c red chi^2 mod (gp cov): %s / triv: %s = %s", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 347 |  |  | 			chisq_red_gpcov, chisq_red_triv, chisq_red_gpcov / chisq_red_triv) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 348 | 1 |  | 	logging.info("3a stand. red chi^2: %s", chisq_red_gpcov / chisq_red_triv) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 349 | 1 |  | 	logging.info("3b 1 - stand. red chi^2: %s", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 350 |  |  | 			1 - chisq_red_gpcov / chisq_red_triv) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 351 | 1 |  | 	logging.info("5a generalized R^2: 1a: %s, 1b: %s", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 352 |  |  | 			gen_rsq1a, gen_rsq1b) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 353 | 1 |  | 	logging.info("5b generalized R^2: 2a: %s, 2b: %s", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 354 |  |  | 			gen_rsq2a, gen_rsq2b) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 355 | 1 |  | 	logging.info("5c generalized R^2: 3a: %s, 3b: %s", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 356 |  |  | 			gen_rsq3a, gen_rsq3b) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 357 | 1 |  | 	try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 358 |  |  | 		# celerite | 
            
                                                                                                            
                            
            
                                    
            
            
                | 359 | 1 |  | 		logdet = model.solver.log_determinant() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 360 | 1 |  | 	except TypeError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 361 |  |  | 		# george | 
            
                                                                                                            
                            
            
                                    
            
            
                | 362 | 1 |  | 		logdet = model.solver.log_determinant | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 363 |  |  | 	logging.debug("5 logdet: %s, const 2: %s", logdet, _const) | 
            
                                                        
            
                                    
            
            
                | 364 |  |  |  |