Passed: Push — master ( 54142a...75fbae ) by Stefan, created (05:53)

sciapy.regress.__main__.main()   (rating: F)

Complexity

Conditions 56

Size

Total Lines 401
Code Lines 312

Duplication

Lines 401
Ratio 100 %

Code Coverage

Tests 181
CRAP Score 84.8768

Importance

Changes 0

Metric   Value
cc       56
eloc     312
nop      0
dl       401
loc      401
ccs      181
cts      229
cp       0.7904
crap     84.8768
rs       0
c        0
b        0
f        0
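
The CRAP score combines complexity and coverage. As a plausibility check (not part of the report itself), the value above is consistent with the usual CRAP formula, assuming cc is the cyclomatic complexity and cp = ccs / cts is the fraction of covered statements:

	# assumed formula: crap = cc**2 * (1 - coverage)**3 + cc
	cc = 56            # cyclomatic complexity ("Conditions" above)
	coverage = 0.7904  # cp, presumably ccs / cts = 181 / 229
	crap = cc**2 * (1 - coverage)**3 + cc
	print(crap)        # ~84.877, in line with the reported 84.8768

Since full coverage can only bring the score down to cc itself (56 here), lowering the cyclomatic complexity of main() is what ultimately reduces it.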

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include the Extract Method refactoring, sketched below.
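
As a small illustrative sketch (the function and names are made up, not taken from sciapy), a commented block inside a long method becomes a well-named helper:

	# Before: the comment explains what the block does
	def load_series(records):
		# keep only records that carry a timestamp
		valid = [r for r in records if r.get("time") is not None]
		return sorted(valid, key=lambda r: r["time"])

	# After "Extract Method": the comment turns into the helper's name
	def _drop_missing_timestamps(records):
		return [r for r in records if r.get("time") is not None]

	def load_series(records):
		return sorted(_drop_missing_timestamps(records), key=lambda r: r["time"])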

Complexity

Complex code such as sciapy.regress.__main__.main() often does a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
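
A minimal sketch of Extract Class (illustrative names only, not from sciapy): members that share a prefix move into a class of their own:

	# Before: the "cache_" prefix hints at a hidden component
	class Report:
		def __init__(self):
			self.cache_entries = {}

		def cache_get(self, key):
			return self.cache_entries.get(key)

		def cache_put(self, key, value):
			self.cache_entries[key] = value

	# After "Extract Class": the shared prefix becomes its own class
	class Cache:
		def __init__(self):
			self.entries = {}

		def get(self, key):
			return self.entries.get(key)

		def put(self, key, value):
			self.entries[key] = value

	class Report:
		def __init__(self):
			self.cache = Cache()

For a free function like main(), the analogous step is to extract the cohesive stages visible in the listing below (argument post-processing, data loading and train/test splitting, proxy and model setup, optimization, MCMC sampling, saving and plotting) into functions of their own.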

# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2018 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY data regression command line interface

Command line main program for regression analysis of SCIAMACHY
daily zonal mean time series (NO for now).
"""

import ctypes
import logging
from os import path

import numpy as np
import scipy.optimize as op
from scipy.interpolate import interp1d

import george
import celerite

import matplotlib as mpl
# switch off X11 rendering
mpl.use("Agg")

from .load_data import load_solar_gm_table, load_scia_dzm
from .models_cel import trace_gas_model
from .mcmc import mcmc_sample_model
from .statistics import mcmc_statistics

from ._gpkernels import (george_solvers,
		setup_george_kernel, setup_celerite_terms)
from ._plot import (plot_single_sample_and_residuals,
		plot_residual, plot_random_samples)
from ._options import parser


def save_samples_netcdf(filename, model, alt, lat, samples,
		scale=1e-6,
		lnpost=None, compressed=False):
	from xarray import Dataset
	smpl_ds = Dataset(dict([(pname, (["lat", "alt", "sample"],
				samples[..., i].reshape(1, 1, -1)))
				for i, pname in enumerate(model.get_parameter_names())]
				# + [("lpost", (["lat", "alt", "sample"], lnp.reshape(1, 1, -1)))]
			),
			coords={"lat": [lat], "alt": [alt]})

	for modn in model.mean.models:
		modl = model.mean.models[modn]
		if hasattr(modl, "mean"):
			smpl_ds.attrs[modn + ":mean"] = modl.mean

	units = {"kernel": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale))},
			"mean": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale)),
				"cos": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"sin": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"val": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"amp": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"tau": "d"}}
	for pname in smpl_ds.data_vars:
		_pp = pname.split(':')
		for _n, _u in units[_pp[0]].items():
			if _pp[-1].startswith(_n):
				logging.debug("units for %s: %s", pname, _u)
				smpl_ds[pname].attrs["units"] = _u

	smpl_ds["alt"].attrs = {"long_name": "altitude", "units": "km"}
	smpl_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}

	_encoding = None
	if compressed:
		_encoding = {var: {"zlib": True, "complevel": 1}
					for var in smpl_ds.data_vars}
	try:
		smpl_ds.to_netcdf(filename, encoding=_encoding)
	except ValueError:
		smpl_ds.to_netcdf(filename)
	smpl_ds.close()


def _train_test_split(times, data, errs, train_frac,
		test_frac, randomize):
	# split the data into training and test subsets according to the
	# fraction given (default is 1, i.e. no splitting)
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# randomize if requested
	if randomize:
		permut_idx = np.random.permutation(np.arange(ndata))
	else:
		permut_idx = np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	times_train = times[train_idx]
	data_train = data[train_idx]
	errs_train = errs[train_idx]
	if test_size > 0:
		times_test = times[test_idx]
		data_test = data[test_idx]
		errs_test = errs[test_idx]
	else:
		times_test = times
		data_test = data
		errs_test = errs
	logging.info("using %s of %s samples for training.", len(times_train), ndata)
	logging.info("using %s of %s samples for testing.", len(times_test), ndata)
	return (times_train, data_train, errs_train,
			times_test, data_test, errs_test)


def _r_sun_earth(time, tfmt="jyear"):
	"""First order approximation of the Sun-Earth distance

	The Sun-to-Earth distance can be used to (un-)normalize proxies
	to the actual distance to the Sun instead of 1 AU.

	Parameters
	----------
	time : float
		Time value in the units given by 'tfmt'.
	tfmt : str, optional
		The units of 'time' as supported by the
		astropy.time time formats. Default: 'jyear'.

	Returns
	-------
	dist : float
		The Sun-Earth distance at the given day of year in AU.
	"""
	from astropy.time import Time
	tdoy = Time(time, format=tfmt)
	tdoy.format = "yday"
	doy = int(tdoy.value.split(':')[1])
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))


def main():
	logging.basicConfig(level=logging.WARNING,
			format="[%(levelname)-8s] (%(asctime)s) "
				"%(filename)s:%(lineno)d %(message)s",
			datefmt="%Y-%m-%d %H:%M:%S %z")

	args = parser.parse_args()

	logging.info("command line arguments: %s", args)
	if args.quiet:
		logging.getLogger().setLevel(logging.ERROR)
	elif args.verbose:
		logging.getLogger().setLevel(logging.INFO)
	else:
		logging.getLogger().setLevel(args.loglevel)

	from numpy.distutils.system_info import get_info
	try:
		ob_lib_dirs = get_info("openblas")["library_dirs"]
	except KeyError:
		ob_lib_dirs = []
	for oblas_path in ob_lib_dirs:
		oblas_name = "{0}/libopenblas.so".format(oblas_path)
		logging.info("Trying %s", oblas_name)
		try:
			oblas_lib = ctypes.cdll.LoadLibrary(oblas_name)
			oblas_cores = oblas_lib.openblas_get_num_threads()
			oblas_lib.openblas_set_num_threads(args.openblas_threads)
			logging.info("Using %s/%s Openblas thread(s).",
					oblas_lib.openblas_get_num_threads(), oblas_cores)
		except:
			logging.info("Setting number of openblas threads failed.")

	if args.random_seed is not None:
		np.random.seed(args.random_seed)

	if args.proxies:
		proxies = args.proxies.split(',')
		proxy_dict = dict(_p.split(':') for _p in proxies)
	else:
		proxy_dict = {}
	lag_dict = {pn: 0 for pn in proxy_dict.keys()}

	# Post-processing of arguments...
	# List of proxy lag fits from csv
	fit_lags = args.fit_lags.split(',')
	# List of proxy lifetime fits from csv
	fit_lifetimes = args.fit_lifetimes.split(',')
	fit_annlifetimes = args.fit_annlifetimes.split(',')
	# List of proxy lag times from csv
	lag_dict.update(dict(_ls.split(':') for _ls in args.lag_times.split(',')))
	# List of cycles (frequencies in 1/year) from argument list (csv)
	try:
		freqs = list(map(float, args.freqs.split(',')))
	except ValueError:
		freqs = []
	args.freqs = freqs
	# List of initial parameter values
	initial = None
	if args.initial is not None:
		try:
			initial = list(map(float, args.initial.split(',')))
		except ValueError:
			pass
	# List of GP kernels from argument list (csv)
	kernls = args.kernels.split(',')

	lat = args.latitude
	alt = args.altitude
	logging.info("location: %.0f°N %.0f km", lat, alt)

	no_ys, no_dens, no_errs, no_szas = load_scia_dzm(args.file, alt, lat,
			tfmt=args.time_format,
			scale=args.scale,
			#subsample_factor=args.random_subsample,
			#subsample_method="random",
			akd_threshold=args.akd_threshold,
			cnt_threshold=args.cnt_threshold,
			center=args.center_data,
			season=args.season,
			SPEs=args.exclude_spe)

	(no_ys_train, no_dens_train, no_errs_train,
		no_ys_test, no_dens_test, no_errs_test) = _train_test_split(
				no_ys, no_dens, no_errs, args.train_fraction,
				args.test_fraction, args.random_train_test)

	sza_intp = interp1d(no_ys, no_szas, bounds_error=False)

	max_amp = 1e10 * args.scale
	max_days = 100

	proxy_config = {}
	for pn, pf in proxy_dict.items():
		pt, pp = load_solar_gm_table(path.expanduser(pf),
				cols=[0, 1], names=["time", pn], tfmt=args.time_format)
		pv = pp[pn]
		# use log of proxy values if desired
		if pn in args.log_proxies.split(','):
			pv = np.log(pv)
		# normalize to sun--earth distance squared
		if pn in args.norm_proxies_distSEsq.split(','):
			rad_sun_earth = np.vectorize(_r_sun_earth)(pt, tfmt=args.time_format)
			pv /= rad_sun_earth**2
		# normalize by cos(SZA)
		if pn in args.norm_proxies_SZA.split(',') and sza_intp is not None:
			pv *= np.cos(np.radians(sza_intp(pt)))
		proxy_config.update({pn:
			dict(times=pt, values=pv,
				center=pn in args.center_proxies.split(','),
				positive=pn in args.positive_proxies.split(','),
				lag=float(lag_dict[pn]),
				max_amp=max_amp, max_days=max_days,
				sza_intp=sza_intp if args.use_sza else None,
			)}
		)

	model = trace_gas_model(constant=args.fit_offset,
			proxy_config=proxy_config, **vars(args))

	logging.debug("model dict: %s", model.get_parameter_dict())
	model.freeze_all_parameters()
	# thaw parameters according to requested fits
	for pn in proxy_dict.keys():
		model.thaw_parameter("{0}:amp".format(pn))
		if pn in fit_lags:
			model.thaw_parameter("{0}:lag".format(pn))
		if pn in fit_lifetimes:
			model.set_parameter("{0}:tau0".format(pn), 1e-3)
			model.thaw_parameter("{0}:tau0".format(pn))
			if pn in fit_annlifetimes:
				model.thaw_parameter("{0}:taucos1".format(pn))
				model.thaw_parameter("{0}:tausin1".format(pn))
		else:
			model.set_parameter("{0}:ltscan".format(pn), 0)
	for freq in freqs:
		if not args.fit_phase:
			model.thaw_parameter("f{0:.0f}:cos".format(freq))
			model.thaw_parameter("f{0:.0f}:sin".format(freq))
		else:
			model.thaw_parameter("f{0:.0f}:amp".format(freq))
			model.thaw_parameter("f{0:.0f}:phase".format(freq))
	if args.fit_offset:
		#model.set_parameter("offset:value", -100.)
		#model.set_parameter("offset:value", 0)
		model.thaw_parameter("offset:value")

	if initial is not None:
		model.set_parameter_vector(initial)
	# model.thaw_parameter("GM:ltscan")
	logging.debug("params: %s", model.get_parameter_dict())
	logging.debug("param names: %s", model.get_parameter_names())
	logging.debug("param vector: %s", model.get_parameter_vector())
	logging.debug("param bounds: %s", model.get_parameter_bounds())
	#logging.debug("model value: %s", model.get_value(no_ys))
	#logging.debug("default log likelihood: %s", model.log_likelihood(model.vector))

	# setup the Gaussian Process kernel
	kernel_base = (1e7 * args.scale)**2
	ksub = args.name_suffix

	solver = "basic"
	skwargs = {}
	if args.HODLR_Solver:
		solver = "HODLR"
		#skwargs = {"tol": 1e-3}

	if args.george:
		gpname, kernel = setup_george_kernel(kernls,
				kernel_base=kernel_base, fit_bias=args.fit_bias)
		gpmodel = george.GP(kernel, mean=model,
			white_noise=1.e-25, fit_white_noise=args.fit_white,
			solver=george_solvers[solver], **skwargs)
		# the george interface does not allow setting the bounds in
		# the kernel initialization so we prepare simple default bounds
		kernel_bounds = [(-0.3 * max_amp, 0.3 * max_amp)
				for _ in gpmodel.kernel.get_parameter_names()]
		bounds = gpmodel.mean.get_parameter_bounds() + kernel_bounds
	else:
		gpname, cel_terms = setup_celerite_terms(kernls,
				fit_bias=args.fit_bias, fit_white=args.fit_white)
		gpmodel = celerite.GP(cel_terms, mean=model,
			fit_white_noise=args.fit_white,
			fit_mean=True)
		bounds = gpmodel.get_parameter_bounds()
	gpmodel.compute(no_ys_train, no_errs_train)
	logging.debug("gpmodel params: %s", gpmodel.get_parameter_dict())
	logging.debug("gpmodel bounds: %s", bounds)
	logging.debug("initial log likelihood: %s", gpmodel.log_likelihood(no_dens_train))
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)
	model_name = "_".join(gpmodel.mean.get_parameter_names()).replace(':', '')
	gpmodel_name = model_name + gpname
	logging.info("GP model name: %s", gpmodel_name)

	pre_opt = False
	if args.optimize > 0:
		def gpmodel_mean(x, *p):
			gpmodel.set_parameter_vector(p)
			return gpmodel.mean.get_value(x)

		def gpmodel_res(x, *p):
			gpmodel.set_parameter_vector(p)
			return (gpmodel.mean.get_value(x) - no_dens_train) / no_errs_train

		def lpost(p, y, gp):
			gp.set_parameter_vector(p)
			return gp.log_likelihood(y, quiet=True) + gp.log_prior()

		def nlpost(p, y, gp):
			lp = lpost(p, y, gp)
			return -lp if np.isfinite(lp) else 1e25

		def grad_nlpost(p, y, gp):
			gp.set_parameter_vector(p)
			grad_ll = gp.grad_log_likelihood(y)
			if isinstance(grad_ll, tuple):
				# celerite
				return -grad_ll[1]
			# george
			return -grad_ll

		jacobian = grad_nlpost if gpmodel.kernel.vector_size else None
		if args.optimize == 1:
			resop_gp = op.minimize(
				nlpost,
				gpmodel.get_parameter_vector(),
				args=(no_dens_train, gpmodel),
				bounds=bounds,
				# method="l-bfgs-b", options=dict(disp=True, maxcor=100, eps=1e-9, ftol=2e-15, gtol=1e-8))
				method="l-bfgs-b", jac=jacobian)
				# method="tnc", options=dict(disp=True, maxiter=500, xtol=1e-12))
				# method="nelder-mead", options=dict(disp=True, maxfev=100000, fatol=1.49012e-8, xatol=1.49012e-8))
				# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08))
		if args.optimize == 2:
			resop_gp = op.differential_evolution(
				nlpost,
				bounds=bounds,
				args=(no_dens_train, gpmodel),
				popsize=2 * args.walkers, tol=0.01)
		if args.optimize == 3:
			resop_bh = op.basinhopping(
				nlpost,
				gpmodel.get_parameter_vector(),
				niter=200,
				minimizer_kwargs=dict(
					args=(no_dens_train, gpmodel),
					bounds=bounds,
					# method="tnc"))
					# method="l-bfgs-b", options=dict(maxcor=100)))
					method="l-bfgs-b", jac=jacobian))
					# method="Nelder-Mead"))
					# method="BFGS"))
					# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08)))
			logging.debug("optimization result: %s", resop_bh)
			resop_gp = resop_bh.lowest_optimization_result
		if args.optimize == 4:
			resop_gp, cov_gp = op.curve_fit(
				gpmodel_mean,
				no_ys_train, no_dens_train, gpmodel.get_parameter_vector(),
				bounds=tuple(np.array(bounds).T),
				# method='lm',
				# absolute_sigma=True,
				sigma=no_errs_train)
			print(resop_gp, np.sqrt(np.diag(cov_gp)))
		logging.info("%s", resop_gp.message)
		logging.debug("optimization result: %s", resop_gp)
		logging.info("gpmodel dict: %s", gpmodel.get_parameter_dict())
		logging.info("log posterior trained: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.info("log posterior test: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.info("log posterior all: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		# cross check to make sure that the gpmodel parameter vector is really
		# set to the fitted parameters
		logging.info("opt. model vector: %s", resop_gp.x)
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 1: %s", lpost(resop_gp.x, no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 1: %s", lpost(resop_gp.x, no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 1: %s", lpost(resop_gp.x, no_dens, gpmodel))
		logging.debug("opt. model vector: %s", gpmodel.get_parameter_vector())
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		pre_opt = resop_gp.success
	try:
		logging.info("GM lt: %s", gpmodel.get_parameter("mean:GM:tau0"))
	except ValueError:
		pass
	logging.info("(GP) model: %s", gpmodel.kernel)
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)

	bestfit = gpmodel.get_parameter_vector()
	filename_base = path.join(
		args.output_path,
		"NO_regress_fit_{0}_{1:.0f}_{2:.0f}_{{0}}_{3}"
		.format(gpmodel_name, lat * 10, alt, ksub),
	)

	if args.mcmc:
		gpmodel.compute(no_ys_train, no_errs_train)
		samples, lnp = mcmc_sample_model(gpmodel,
				no_dens_train,
				beta=1.0,
				nwalkers=args.walkers, nburnin=args.burn_in,
				nprod=args.production, nthreads=args.threads,
				show_progress=args.progress,
				optimized=pre_opt, bounds=bounds, return_logpost=True)

		if args.train_fraction < 1. or args.test_fraction < 1.:
			logging.info("Statistics for the test samples")
			mcmc_statistics(gpmodel,
					no_ys_test, no_dens_test, no_errs_test,
					no_ys_train, no_dens_train, no_errs_train,
					samples, lnp,
			)
		logging.info("Statistics for all samples")
		mcmc_statistics(gpmodel,
				no_ys, no_dens, no_errs,
				no_ys_train, no_dens_train, no_errs_train,
				samples, lnp,
		)

		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
		if args.plot_corner:
			import corner
			# Corner plot of the sampled parameters
			fig = corner.corner(samples,
					quantiles=[0.025, 0.5, 0.975],
					show_titles=True,
					labels=gpmodel.get_parameter_names(),
					truths=bestfit,
					hist_args=dict(normed=True))
			fig.savefig(filename_base.format("corner") + ".pdf", transparent=True)

		if args.save_samples:
			if args.samples_format in ["npz"]:
				# save the samples compressed to save space.
				np.savez_compressed(filename_base.format("sampls") + ".npz",
					samples=samples)
			if args.samples_format in ["nc", "netcdf4"]:
				save_samples_netcdf(filename_base.format("sampls") + ".nc",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
			if args.samples_format in ["h5", "hdf5"]:
				save_samples_netcdf(filename_base.format("sampls") + ".h5",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
		# MCMC finished here

	# set the model times and errors to use the full data set for plotting
	gpmodel.compute(no_ys, no_errs)
	if args.save_model:
		try:
			# python 2
			import cPickle as pickle
		except ImportError:
			# python 3
			import pickle
		# pickle and save the model
		with open(filename_base.format("model") + ".pkl", "wb") as f:
			pickle.dump((gpmodel), f, -1)

	if args.plot_samples and args.mcmc:
		plot_random_samples(gpmodel, no_ys, no_dens, no_errs,
				samples, args.scale,
				filename_base.format("sampls") + ".pdf",
				size=4, extra_years=[4, 2])

	if args.plot_median:
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1],
				filename_base.format("median") + ".pdf")
	if args.plot_residuals:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1], args.scale,
				filename_base.format("medres") + ".pdf")
	if args.plot_maxlnp:
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)],
				filename_base.format("maxlnp") + ".pdf")
	if args.plot_maxlnpres:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)], args.scale,
				filename_base.format("mlpres") + ".pdf")

	labels = gpmodel.get_parameter_names()
	logging.info("param percentiles [2.5, 50, 97.5]:")
	for pc, label in zip(sampl_percs.T, labels):
		median = pc[1]
		pc_minus = median - pc[0]
		pc_plus = pc[2] - median
		logging.debug("%s: %s", label, pc)
		logging.info("%s: %.6f (- %.6f) (+ %.6f)", label,
				median, pc_minus, pc_plus)

	logging.info("Finished successfully.")


if __name__ == "__main__":
	main()