Passed
Push — master (7bc10c...b988a0) by Stefan, created 05:24

sciapy.regress.__main__.main() (rating: F)

Complexity
	Conditions: 56

Size
	Total Lines: 410
	Code Lines: 319

Duplication
	Lines: 410
	Ratio: 100 %

Code Coverage
	Tests: 3
	CRAP Score: 3069.4463
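Assuming the standard CRAP definition (cyclomatic complexity squared, discounted by coverage), the score above follows directly from the cc and cp metrics listed below:

	CRAP = cc^2 * (1 - cp)^3 + cc = 56^2 * (1 - 0.0132)^3 + 56 ≈ 3069.45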

Importance
	Changes: 0

Metric  Value      Meaning
cc      56         cyclomatic complexity
eloc    319        executable lines of code
nop     0          number of parameters
dl      410        duplicated lines
loc     410        lines of code
ccs     3          covered statements
cts     227        total statements
cp      0.0132     coverage (ccs / cts)
crap    3069.4463  CRAP score
rs      0
c       0
b       0
f       0

How to fix: Long Method, Complexity

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And if a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign to extract the commented part into a new method, using the comment as a starting point for the new method's name.

Commonly applied refactorings include: Extract Method.
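As a minimal sketch of Extract Method (hypothetical names, not taken from sciapy), the commented block becomes a named helper:

# Before: a long function with an inline, commented block.
def mean_report_before(values):
	# normalize values to the 0..1 range
	lo, hi = min(values), max(values)
	normalized = [(v - lo) / (hi - lo) for v in values]
	return sum(normalized) / len(normalized)

# After: the commented block is extracted; the comment suggests the name.
def normalize(values):
	"""Scale values linearly to the 0..1 range."""
	lo, hi = min(values), max(values)
	return [(v - lo) / (hi - lo) for v in values]

def mean_report(values):
	normalized = normalize(values)
	return sum(normalized) / len(normalized)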

Complexity

Complex functions or classes like sciapy.regress.__main__.main() often do a lot of different things. To break such code down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
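A minimal sketch of Extract Class (made-up names, not from this project), where a shared field prefix marks the component to extract:

# Before: "gp_"-prefixed fields hint at a cohesive GP-configuration component.
class RegressionRun:
	def __init__(self):
		self.gp_kernel = "Mat32"
		self.gp_solver = "basic"
		self.data = []

# After Extract Class: the prefixed fields live in their own class.
class GPConfig:
	def __init__(self, kernel="Mat32", solver="basic"):
		self.kernel = kernel
		self.solver = solver

class RegressionRunRefactored:
	def __init__(self):
		self.gp = GPConfig()
		self.data = []

The same grouping heuristic applies to a long function such as main() below: blocks of statements that only touch one group of variables (for example the optimizer callbacks or the plotting calls) are candidates for extraction.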

# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2018 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY data regression command line interface

Command line main program for regression analysis of SCIAMACHY
daily zonal mean time series (NO for now).
"""

import ctypes
import logging

import numpy as np
import scipy.optimize as op
from scipy.interpolate import interp1d

import george
import celerite

import matplotlib as mpl
# switch off X11 rendering
mpl.use("Agg")

from .load_data import load_solar_gm_table, load_scia_dzm
from .models_cel import CeleriteModelSet as NOModel
from .models_cel import ConstantModel, ProxyModel
from .models_cel import HarmonicModelCosineSine, HarmonicModelAmpPhase
from .mcmc import mcmc_sample_model
from .statistics import mcmc_statistics

from ._gpkernels import (george_solvers,
		setup_george_kernel, setup_celerite_terms)
from ._plot import (plot_single_sample_and_residuals,
		plot_residual, plot_random_samples)
from ._options import parser

def save_samples_netcdf(filename, model, alt, lat, samples,
		scale=1e-6,
		lnpost=None, compressed=False):
	from xarray import Dataset
	smpl_ds = Dataset(dict([(pname, (["lat", "alt", "sample"],
				samples[..., i].reshape(1, 1, -1)))
				for i, pname in enumerate(model.get_parameter_names())]
				# + [("lpost", (["lat", "alt", "sample"], lnp.reshape(1, 1, -1)))]
			),
			coords={"lat": [lat], "alt": [alt]})

	for modn in model.mean.models:
		modl = model.mean.models[modn]
		if hasattr(modl, "mean"):
			smpl_ds.attrs[modn + ":mean"] = modl.mean

	units = {"kernel": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale))},
			"mean": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale)),
				"cos": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"sin": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"val": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"amp": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"tau": "d"}}
	for pname in smpl_ds.data_vars:
		_pp = pname.split(':')
		for _n, _u in units[_pp[0]].items():
			if _pp[-1].startswith(_n):
				logging.debug("units for %s: %s", pname, _u)
				smpl_ds[pname].attrs["units"] = _u

	smpl_ds["alt"].attrs = {"long_name": "altitude", "units": "km"}
	smpl_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}

	_encoding = None
	if compressed:
		_encoding = {var: {"zlib": True, "complevel": 1}
					for var in smpl_ds.data_vars}
	smpl_ds.to_netcdf(filename, encoding=_encoding)
	smpl_ds.close()

def _train_test_split(times, data, errs, train_frac,
		test_frac, randomize):
	# split the data into training and test subsets according to the
	# fraction given (default is 1, i.e. no splitting)
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# randomize if requested
	if randomize:
		permut_idx = np.random.permutation(np.arange(ndata))
	else:
		permut_idx = np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	times_train = times[train_idx]
	data_train = data[train_idx]
	errs_train = errs[train_idx]
	if test_size > 0:
		times_test = times[test_idx]
		data_test = data[test_idx]
		errs_test = errs[test_idx]
	else:
		times_test = times
		data_test = data
		errs_test = errs
	logging.info("using %s of %s samples for training.", len(times_train), ndata)
	logging.info("using %s of %s samples for testing.", len(times_test), ndata)
	return (times_train, data_train, errs_train,
			times_test, data_test, errs_test)

def _r_sun_earth(time, tfmt="jyear"):
	"""First order approximation of the Sun-Earth distance

	The Sun-to-Earth distance can be used to (un-)normalize proxies
	to the actual distance to the Sun instead of 1 AU.

	Parameters
	----------
	time : float
		Time value in the units given by 'tfmt'.
	tfmt : str, optional
		The units of 'time' as supported by the
		astropy.time time formats. Default: 'jyear'.

	Returns
	-------
	dist : float
		The Sun-Earth distance at the given day of year in AU.
	"""
	from astropy.time import Time
	tdoy = Time(time, format=tfmt)
	tdoy.format = "yday"
	doy = int(tdoy.value.split(':')[1])
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))

def main():
	logging.basicConfig(level=logging.WARNING,
			format="[%(levelname)-8s] (%(asctime)s) "
				"%(filename)s:%(lineno)d %(message)s",
			datefmt="%Y-%m-%d %H:%M:%S %z")

	args = parser.parse_args()

	logging.info("command line arguments: %s", args)
	if args.quiet:
		logging.getLogger().setLevel(logging.ERROR)
	elif args.verbose:
		logging.getLogger().setLevel(logging.INFO)
	else:
		logging.getLogger().setLevel(args.loglevel)

	from numpy.distutils.system_info import get_info
	for oblas_path in get_info("openblas")["library_dirs"]:
		oblas_name = "{0}/libopenblas.so".format(oblas_path)
		logging.info("Trying %s", oblas_name)
		try:
			oblas_lib = ctypes.cdll.LoadLibrary(oblas_name)
			oblas_cores = oblas_lib.openblas_get_num_threads()
			oblas_lib.openblas_set_num_threads(args.openblas_threads)
			logging.info("Using %s/%s Openblas thread(s).",
					oblas_lib.openblas_get_num_threads(), oblas_cores)
		except (OSError, AttributeError):
			# instead of a bare except: the calls that can realistically
			# fail here are loading the library (OSError) and resolving
			# the openblas symbols (AttributeError)
			logging.info("Setting number of openblas threads failed.")

	if args.random_seed is not None:
		np.random.seed(args.random_seed)

	if args.proxies:
		proxies = args.proxies.split(',')
		proxy_dict = dict(_p.split(':') for _p in proxies)
	else:
		proxy_dict = {}
	lag_dict = {pn: 0 for pn in proxy_dict.keys()}

	# Post-processing of arguments...
	# List of proxy lag fits from csv
	fit_lags = args.fit_lags.split(',')
	# List of proxy lifetime fits from csv
	fit_lifetimes = args.fit_lifetimes.split(',')
	fit_annlifetimes = args.fit_annlifetimes.split(',')
	# List of proxy lag times from csv
	lag_dict.update(dict(_ls.split(':') for _ls in args.lag_times.split(',')))
	# List of cycles (frequencies in 1/year) from argument list (csv)
	try:
		freqs = list(map(float, args.freqs.split(',')))
	except ValueError:
		freqs = []
	# List of initial parameter values
	initial = None
	if args.initial is not None:
		try:
			initial = list(map(float, args.initial.split(',')))
		except ValueError:
			pass
	# List of GP kernels from argument list (csv)
	kernls = args.kernels.split(',')

	lat = args.latitude
	alt = args.altitude
	logging.info("location: %.0f°N %.0f km", lat, alt)

	no_ys, no_dens, no_errs, no_szas = load_scia_dzm(args.file, alt, lat,
			tfmt=args.time_format,
			scale=args.scale,
			#subsample_factor=args.random_subsample,
			#subsample_method="random",
			akd_threshold=args.akd_threshold,
			cnt_threshold=args.cnt_threshold,
			center=args.center_data,
			season=args.season,
			SPEs=args.exclude_spe)

	(no_ys_train, no_dens_train, no_errs_train,
		no_ys_test, no_dens_test, no_errs_test) = _train_test_split(
				no_ys, no_dens, no_errs, args.train_fraction,
				args.test_fraction, args.random_train_test)

	sza_intp = interp1d(no_ys, no_szas, bounds_error=False)

	max_amp = 1e10 * args.scale
	max_days = 100

	harmonic_models = []
	for freq in freqs:
		if not args.fit_phase:
			harmonic_models.append(("f{0:.0f}".format(freq),
				HarmonicModelCosineSine(freq=freq,
					cos=0, sin=0,
					bounds=dict([
						("cos", [-max_amp, max_amp]),
						("sin", [-max_amp, max_amp])])
			)))
		else:
			harmonic_models.append(("f{0:.0f}".format(freq),
				HarmonicModelAmpPhase(freq=freq,
					amp=0, phase=0,
					bounds=dict([
						# ("amp", [-max_amp, max_amp]),
						("amp", [0, max_amp]),
						("phase", [-np.pi, np.pi])])
			)))
	proxy_models = []
	for pn, pf in proxy_dict.items():
		pt, pp = load_solar_gm_table(pf, cols=[0, 1], names=["time", pn], tfmt=args.time_format)
		# use log of proxy values
		pv = np.log(pp[pn]) if pn in args.log_proxies.split(',') else pp[pn]
		# normalize to sun--earth distance squared
		if pn in args.norm_proxies_distSEsq.split(','):
			rad_sun_earth = np.vectorize(_r_sun_earth)(pt, tfmt=args.time_format)
			pv /= rad_sun_earth**2
		# normalize by cos(SZA)
		if pn in args.norm_proxies_SZA.split(',') and sza_intp is not None:
			pv *= np.cos(np.radians(sza_intp(pt)))
		proxy_models.append((pn,
			_setup_proxy_model_with_bounds(pt, pv,
				center=pn in args.center_proxies.split(','),
				positive=pn in args.positive_proxies.split(','),
				lag=float(lag_dict[pn]),
				max_amp=max_amp, max_days=max_days,
				sza_intp=sza_intp if args.use_sza else None,
				**vars(args),
			))
		)
		logging.info("%s mean: %s", pn, proxy_models[-1][1].mean)
	offset_model = [("offset",
			ConstantModel(value=0.,
					bounds={"value": [-max_amp, max_amp]}))]

	model = NOModel(offset_model + harmonic_models + proxy_models)

	logging.debug("model dict: %s", model.get_parameter_dict())
	model.freeze_all_parameters()
	# thaw parameters according to requested fits
	for pn in proxy_dict.keys():
		model.thaw_parameter("{0}:amp".format(pn))
		if pn in fit_lags:
			model.thaw_parameter("{0}:lag".format(pn))
		if pn in fit_lifetimes:
			model.set_parameter("{0}:tau0".format(pn), 1e-3)
			model.thaw_parameter("{0}:tau0".format(pn))
			if pn in fit_annlifetimes:
				model.thaw_parameter("{0}:taucos1".format(pn))
				model.thaw_parameter("{0}:tausin1".format(pn))
	for freq in freqs:
		if not args.fit_phase:
			model.thaw_parameter("f{0:.0f}:cos".format(freq))
			model.thaw_parameter("f{0:.0f}:sin".format(freq))
		else:
			model.thaw_parameter("f{0:.0f}:amp".format(freq))
			model.thaw_parameter("f{0:.0f}:phase".format(freq))
	if args.fit_offset:
		#model.set_parameter("offset:value", -100.)
		#model.set_parameter("offset:value", 0)
		model.thaw_parameter("offset:value")

	if initial is not None:
		model.set_parameter_vector(initial)
	# model.thaw_parameter("GM:ltscan")
	logging.debug("params: %s", model.get_parameter_dict())
	logging.debug("param names: %s", model.get_parameter_names())
	logging.debug("param vector: %s", model.get_parameter_vector())
	logging.debug("param bounds: %s", model.get_parameter_bounds())
	#logging.debug("model value: %s", model.get_value(no_ys))
	#logging.debug("default log likelihood: %s", model.log_likelihood(model.vector))

	# setup the Gaussian Process kernel
	kernel_base = (1e7 * args.scale)**2
	ksub = args.name_suffix

	solver = "basic"
	skwargs = {}
	if args.HODLR_Solver:
		solver = "HODLR"
		#skwargs = {"tol": 1e-3}

	if args.george:
		gpname, kernel = setup_george_kernel(kernls,
				kernel_base=kernel_base, fit_bias=args.fit_bias)
		gpmodel = george.GP(kernel, mean=model,
			white_noise=1.e-25, fit_white_noise=args.fit_white,
			solver=george_solvers[solver], **skwargs)
		# the george interface does not allow setting the bounds in
		# the kernel initialization so we prepare simple default bounds
		kernel_bounds = [(-0.3 * max_amp, 0.3 * max_amp)
				for _ in gpmodel.kernel.get_parameter_names()]
		bounds = gpmodel.mean.get_parameter_bounds() + kernel_bounds
	else:
		gpname, cel_terms = setup_celerite_terms(kernls,
				fit_bias=args.fit_bias, fit_white=args.fit_white)
		gpmodel = celerite.GP(cel_terms, mean=model,
			fit_white_noise=args.fit_white,
			fit_mean=True)
		bounds = gpmodel.get_parameter_bounds()
	gpmodel.compute(no_ys_train, no_errs_train)
	logging.debug("gpmodel params: %s", gpmodel.get_parameter_dict())
	logging.debug("gpmodel bounds: %s", bounds)
	logging.debug("initial log likelihood: %s", gpmodel.log_likelihood(no_dens_train))
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)
	model_name = "_".join(gpmodel.mean.get_parameter_names()).replace(':', '')
	gpmodel_name = model_name + gpname
	logging.info("GP model name: %s", gpmodel_name)

	pre_opt = False
	if args.optimize > 0:
		def gpmodel_mean(x, *p):
			gpmodel.set_parameter_vector(p)
			return gpmodel.mean.get_value(x)

		def gpmodel_res(x, *p):
			gpmodel.set_parameter_vector(p)
			return (gpmodel.mean.get_value(x) - no_dens_train) / no_errs_train

		def lpost(p, y, gp):
			gp.set_parameter_vector(p)
			return gp.log_likelihood(y, quiet=True) + gp.log_prior()

		def nlpost(p, y, gp):
			lp = lpost(p, y, gp)
			return -lp if np.isfinite(lp) else 1e25

		def grad_nlpost(p, y, gp):
			gp.set_parameter_vector(p)
			grad_ll = gp.grad_log_likelihood(y)
			if isinstance(grad_ll, tuple):
				# celerite
				return -grad_ll[1]
			# george
			return -grad_ll

		if args.optimize == 1:
			resop_gp = op.minimize(
				nlpost,
				gpmodel.get_parameter_vector(),
				args=(no_dens_train, gpmodel),
				bounds=bounds,
				# method="l-bfgs-b", options=dict(disp=True, maxcor=100, eps=1e-9, ftol=2e-15, gtol=1e-8))
				method="l-bfgs-b", jac=grad_nlpost)
				# method="tnc", options=dict(disp=True, maxiter=500, xtol=1e-12))
				# method="nelder-mead", options=dict(disp=True, maxfev=100000, fatol=1.49012e-8, xatol=1.49012e-8))
				# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08))
		if args.optimize == 2:
			resop_gp = op.differential_evolution(
				nlpost,
				bounds=bounds,
				args=(no_dens_train, gpmodel),
				popsize=2 * args.walkers, tol=0.01)
		if args.optimize == 3:
			resop_bh = op.basinhopping(
				nlpost,
				gpmodel.get_parameter_vector(),
				niter=200,
				minimizer_kwargs=dict(
					args=(no_dens_train, gpmodel),
					bounds=bounds,
					# method="tnc"))
					# method="l-bfgs-b", options=dict(maxcor=100)))
					method="l-bfgs-b", jac=grad_nlpost))
					# method="Nelder-Mead"))
					# method="BFGS"))
					# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08)))
			logging.debug("optimization result: %s", resop_bh)
			resop_gp = resop_bh.lowest_optimization_result
		if args.optimize == 4:
			resop_gp, cov_gp = op.curve_fit(
				gpmodel_mean,
				no_ys_train, no_dens_train, gpmodel.get_parameter_vector(),
				bounds=tuple(np.array(bounds).T),
				# method='lm',
				# absolute_sigma=True,
				sigma=no_errs_train)
			print(resop_gp, np.sqrt(np.diag(cov_gp)))
		logging.info("%s", resop_gp.message)
		logging.debug("optimization result: %s", resop_gp)
		logging.info("gpmodel dict: %s", gpmodel.get_parameter_dict())
		logging.info("log posterior trained: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.info("log posterior test: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.info("log posterior all: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		# cross check to make sure that the gpmodel parameter vector is really
		# set to the fitted parameters
		logging.info("opt. model vector: %s", resop_gp.x)
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 1: %s", lpost(resop_gp.x, no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 1: %s", lpost(resop_gp.x, no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 1: %s", lpost(resop_gp.x, no_dens, gpmodel))
		logging.debug("opt. model vector: %s", gpmodel.get_parameter_vector())
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		pre_opt = resop_gp.success
	try:
		logging.info("GM lt: %s", gpmodel.get_parameter("mean:GM:tau0"))
	except ValueError:
		pass
	logging.info("(GP) model: %s", gpmodel.kernel)
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)

	bestfit = gpmodel.get_parameter_vector()
	filename_base = ("NO_regress_fit_{0}_{1:.0f}_{2:.0f}_{{0}}_{3}"
					.format(gpmodel_name, lat * 10, alt, ksub))

	if args.mcmc:
		gpmodel.compute(no_ys_train, no_errs_train)
		samples, lnp = mcmc_sample_model(gpmodel,
				no_dens_train,
				beta=1.0,
				nwalkers=args.walkers, nburnin=args.burn_in,
				nprod=args.production, nthreads=args.threads,
				show_progress=args.progress,
				optimized=pre_opt, bounds=bounds, return_logpost=True)

		if args.train_fraction < 1. or args.test_fraction < 1.:
			logging.info("Statistics for the test samples")
			mcmc_statistics(gpmodel,
					no_ys_test, no_dens_test, no_errs_test,
					no_ys_train, no_dens_train, no_errs_train,
					samples, lnp,
			)
		logging.info("Statistics for all samples")
		mcmc_statistics(gpmodel,
				no_ys, no_dens, no_errs,
				no_ys_train, no_dens_train, no_errs_train,
				samples, lnp,
		)

		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
		if args.plot_corner:
			import corner
			# Corner plot of the sampled parameters
			fig = corner.corner(samples,
					quantiles=[0.025, 0.5, 0.975],
					show_titles=True,
					labels=gpmodel.get_parameter_names(),
					truths=bestfit,
					hist_args=dict(normed=True))
			fig.savefig(filename_base.format("corner") + ".pdf", transparent=True)

		if args.save_samples:
			if args.samples_format in ["npz"]:
				# save the samples compressed to save space.
				np.savez_compressed(filename_base.format("sampls") + ".npz",
					samples=samples)
			if args.samples_format in ["nc", "netcdf4"]:
				save_samples_netcdf(filename_base.format("sampls") + ".nc",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
			if args.samples_format in ["h5", "hdf5"]:
				save_samples_netcdf(filename_base.format("sampls") + ".h5",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
		# MCMC finished here

	# set the model times and errors to use the full data set for plotting
	gpmodel.compute(no_ys, no_errs)
	if args.save_model:
		try:
			# python 2
			import cPickle as pickle
		except ImportError:
			# python 3
			import pickle
		# pickle and save the model
		with open(filename_base.format("model") + ".pkl", "wb") as f:
			pickle.dump((gpmodel), f, -1)

	if args.plot_samples and args.mcmc:
		plot_random_samples(gpmodel, no_ys, no_dens, no_errs,
				samples, args.scale,
				filename_base.format("sampls") + ".pdf",
				size=4, extra_years=[4, 2])

	if args.plot_median:
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1],
				filename_base.format("median") + ".pdf")
	if args.plot_residuals:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1], args.scale,
				filename_base.format("medres") + ".pdf")
	if args.plot_maxlnp:
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)],
				filename_base.format("maxlnp") + ".pdf")
	if args.plot_maxlnpres:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)], args.scale,
				filename_base.format("mlpres") + ".pdf")

	labels = gpmodel.get_parameter_names()
	logging.info("param percentiles [2.5, 50, 97.5]:")
	for pc, label in zip(sampl_percs.T, labels):
		median = pc[1]
		pc_minus = median - pc[0]
		pc_plus = pc[2] - median
		logging.debug("%s: %s", label, pc)
		logging.info("%s: %.6f (- %.6f) (+ %.6f)", label,
				median, pc_minus, pc_plus)

	logging.info("Finished successfully.")

if __name__ == "__main__":
	main()