Passed · branch master (9c78f3) · created by Stefan at 03:38

sciapy.regress.statistics.mcmc_statistics()   (rating: B)

Complexity
	Conditions: 3

Size
	Total Lines: 138
	Code Lines: 85

Duplication
	Lines: 138
	Ratio: 100 %

Code Coverage
	Tests: 1
	CRAP Score: 11.6298

Importance
	Changes: 0

Metric   Value
cc       3
eloc     85
nop      10
dl       138
loc      138
ccs      1
cts      72
cp       0.0139
crap     11.6298
rs       7.4909
c        0
b        0
f        0
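
For reference, the CRAP score above follows directly from the cyclomatic complexity (cc) and the covered fraction (cp) via the standard CRAP formula, crap = cc^2 * (1 - cp)^3 + cc. With cc = 3 and cp = 0.0139 this gives 9 * 0.9861^3 + 3 ≈ 11.6298, matching the reported value; raising test coverage is therefore the most direct way to lower the score.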

How to fix

Long Method

Small methods make your code easier to understand, in particular when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign that you should extract the commented part into a new method, using the comment as a starting point for the new method's name.

Commonly applied refactorings include Extract Method: pull the commented block out into its own method and turn the comment into its name (see the sketch below).
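
As an illustration, the "# residuals" block inside mcmc_statistics is a natural candidate, since the comment already names the concept. A minimal sketch of such an extraction (the helper name _residuals is hypothetical, not part of sciapy):

import numpy as np

def _residuals(model, gppred, times, data, train_data):
	"""Residuals of the mean model, the GP prediction, and a trivial model."""
	resid_mod = model.mean.get_value(times) - data  # GP mean model
	resid_gp = gppred - data  # GP prediction
	resid_triv = np.nanmean(train_data) - data  # trivial (constant) model
	return resid_mod, resid_gp, resid_triv

The three assignments and their comments in the body of mcmc_statistics then collapse into a single, well-named call: resid_mod, resid_gp, resid_triv = _residuals(model, gppred, times, data, train_data).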

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoid long parameter lists: introduce a parameter object, pass a whole object instead of several of its values, or replace a parameter with a method call. A parameter-object sketch for this function follows below.
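
Here, mcmc_statistics takes ten parameters (nop = 10 above), six of which form two natural triples: (times, data, errs) and (train_times, train_data, train_errs). A minimal parameter-object sketch; the Dataset class and the five-parameter signature are hypothetical, not sciapy's API:

from dataclasses import dataclass

import numpy as np

@dataclass
class Dataset:
	"""Coordinates, values, and standard deviations of one data set."""
	times: np.ndarray
	data: np.ndarray
	errs: np.ndarray

# The signature could then shrink from ten parameters to five:
# def mcmc_statistics(model, test, train, samples, lnp, median=False): ...
# called e.g. as
# mcmc_statistics(model, Dataset(t, y, dy), Dataset(t0, y0, dy0), samples, lnp)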

# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2019 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY MCMC statistic tools

Statistical functions for MCMC sampled parameters.
"""
import logging

import numpy as np

__all__ = ["mcmc_statistics"]


def _log_prob(resid, var):
	return -0.5 * (np.log(2 * np.pi * var) + resid**2 / var)


def mcmc_statistics(model, times, data, errs,
		train_times, train_data, train_errs,
		samples, lnp,
		median=False):
	"""Statistics for the (GP) model against the provided data

	Statistical information about the model and the sampled parameter
	distributions with respect to the provided data and its variance.

	Sends the calculated values to the logger; includes the mean
	standardized log loss as described in R&W, 2006, section 2.5, (2.34),
	and some slightly adapted $\\chi^2_{\\text{red}}$ and $R^2$ scores.

	Parameters
	----------
	model : `celerite.GP`, `george.GP` or `CeleriteModelSet` instance
		The model instance whose parameter distribution was drawn.
	times : (M,) array_like
		The test coordinates to predict or evaluate the model on.
	data : (M,) array_like
		The test data to test the model against.
	errs : (M,) array_like
		The standard deviations of the test data.
	train_times : (N,) array_like
		The coordinates on which the model was trained.
	train_data : (N,) array_like
		The data on which the model was trained.
	train_errs : (N,) array_like
		The standard deviations of the training data.
	samples : (K, L) array_like
		The `K` MCMC samples of the `L` parameter distributions.
	lnp : (K,) array_like
		The posterior log probabilities of the `K` MCMC samples.
	median : bool, optional
		Whether to use the median of the sampled distributions or
		the maximum posterior sample (the default) to evaluate the
		statistics.

	Returns
	-------
	nothing
	"""
	ndat = len(times)
	ndim = len(model.get_parameter_vector())
	mdim = len(model.mean.get_parameter_vector())
	samples_max_lp = np.max(lnp)
	if median:
		sample_pos = np.nanmedian(samples, axis=0)
	else:
		sample_pos = samples[np.argmax(lnp)]
	model.set_parameter_vector(sample_pos)
	# calculate the GP predicted values and covariance
	gppred, gpcov = model.predict(train_data, t=times, return_cov=True)
	# the predictive covariance should include the data variance
	gpcov[np.diag_indices_from(gpcov)] += errs**2
	# residuals
	resid_mod = model.mean.get_value(times) - data  # GP mean model
	resid_gp = gppred - data  # GP prediction
	resid_triv = np.nanmean(train_data) - data  # trivial model
	_const = ndat * np.log(2.0 * np.pi)
	# log predictive probability of the test data,
	# using the log-determinant of the predictive covariance
	test_logpred = -0.5 * (resid_gp.dot(np.linalg.solve(gpcov, resid_gp))
			+ np.linalg.slogdet(gpcov)[1]
			+ _const)
	# MSLL -- mean standardized log loss
	# as described in R&W, 2006, section 2.5, (2.34)
	var_mod = np.nanvar(resid_mod, ddof=mdim)  # mean model variance
	var_gp = np.nanvar(resid_gp, ddof=ndim)  # gp model variance
	var_triv = np.nanvar(train_data, ddof=1)  # trivial model variance
	logpred_mod = _log_prob(resid_mod, var_mod)
	logpred_gp = _log_prob(resid_gp, var_gp)
	logpred_triv = _log_prob(resid_triv, var_triv)
	logging.info("MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# predictive variances
	logpred_mod = _log_prob(resid_mod, var_mod + errs**2)
	logpred_gp = _log_prob(resid_gp, var_gp + errs**2)
	logpred_triv = _log_prob(resid_triv, var_triv + errs**2)
	logging.info("pred MSLL mean: %s", np.nanmean(-logpred_mod + logpred_triv))
	logging.info("pred MSLL gp: %s", np.nanmean(-logpred_gp + logpred_triv))
	# cost values
	cost_mod = np.sum(resid_mod**2)
	cost_triv = np.sum(resid_triv**2)
	cost_gp = np.sum(resid_gp**2)
	# chi^2 (variance corrected costs)
	chisq_mod_ye = np.sum((resid_mod / errs)**2)
	chisq_triv = np.sum((resid_triv / errs)**2)
	chisq_gpcov = resid_mod.dot(np.linalg.solve(gpcov, resid_mod))
	# adjust for degrees of freedom
	cost_gp_dof = cost_gp / (ndat - ndim)
	cost_mod_dof = cost_mod / (ndat - mdim)
	cost_triv_dof = cost_triv / (ndat - 1)
	# reduced chi^2
	chisq_red_mod_ye = chisq_mod_ye / (ndat - mdim)
	chisq_red_triv = chisq_triv / (ndat - 1)
	chisq_red_gpcov = chisq_gpcov / (ndat - ndim)
	# "generalized" R^2
	logp_triv1 = np.sum(_log_prob(resid_triv, errs**2))
	logp_triv2 = np.sum(_log_prob(resid_triv, var_triv))
	logp_triv3 = np.sum(_log_prob(resid_triv, var_triv + errs**2))
	log_lambda1 = test_logpred - logp_triv1
	log_lambda2 = test_logpred - logp_triv2
	log_lambda3 = test_logpred - logp_triv3
	gen_rsq1a = 1 - np.exp(-2 * log_lambda1 / ndat)
	gen_rsq1b = 1 - np.exp(-2 * log_lambda1 / (ndat - ndim))
	gen_rsq2a = 1 - np.exp(-2 * log_lambda2 / ndat)
	gen_rsq2b = 1 - np.exp(-2 * log_lambda2 / (ndat - ndim))
	gen_rsq3a = 1 - np.exp(-2 * log_lambda3 / ndat)
	gen_rsq3b = 1 - np.exp(-2 * log_lambda3 / (ndat - ndim))
	# sent to the logger
	logging.info("train max logpost: %s", samples_max_lp)
	logging.info("test log_pred: %s", test_logpred)
	logging.info("1a cost mean model: %s, dof adj: %s", cost_mod, cost_mod_dof)
	logging.debug("1b cost triv model: %s, dof adj: %s", cost_triv, cost_triv_dof)
	logging.debug("1c cost gp predict: %s, dof adj: %s", cost_gp, cost_gp_dof)
	logging.info("1d var resid mean model: %s, gp model: %s, triv: %s",
			var_mod, var_gp, var_triv)
	logging.info("2a adjR2 mean model: %s, adjR2 gp predict: %s",
			1 - cost_mod_dof / cost_triv_dof, 1 - cost_gp_dof / cost_triv_dof)
	logging.info("2b red chi^2 mod: %s / triv: %s = %s",
			chisq_red_mod_ye, chisq_red_triv, chisq_red_mod_ye / chisq_red_triv)
	logging.info("2c red chi^2 mod (gp cov): %s / triv: %s = %s",
			chisq_red_gpcov, chisq_red_triv, chisq_red_gpcov / chisq_red_triv)
	logging.info("3a stand. red chi^2: %s", chisq_red_gpcov / chisq_red_triv)
	logging.info("3b 1 - stand. red chi^2: %s",
			1 - chisq_red_gpcov / chisq_red_triv)
	logging.info("5a generalized R^2: 1a: %s, 1b: %s",
			gen_rsq1a, gen_rsq1b)
	logging.info("5b generalized R^2: 2a: %s, 2b: %s",
			gen_rsq2a, gen_rsq2b)
	logging.info("5c generalized R^2: 3a: %s, 3b: %s",
			gen_rsq3a, gen_rsq3b)
	try:
		# celerite exposes the log-determinant as a method
		logdet = model.solver.log_determinant()
	except TypeError:
		# george exposes it as a property
		logdet = model.solver.log_determinant
	logging.debug("5 logdet: %s, const 2: %s", logdet, _const)
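
For context, a hypothetical usage sketch; it assumes a trained celerite.GP (or george.GP) model and an emcee sampler, and all variable names are illustrative, not part of sciapy:

# assumed setup: `model` is a trained GP, `sampler` is an
# emcee.EnsembleSampler already run on the model's posterior
samples = sampler.get_chain(flat=True)  # (K, L) parameter samples
lnp = sampler.get_log_prob(flat=True)  # (K,) posterior log probabilities

mcmc_statistics(model,
		times, data, errs,  # held-out test data
		train_times, train_data, train_errs,  # training data
		samples, lnp,
		median=False)  # evaluate at the maximum-posterior sample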