Passed
Branch master (9c78f3), by Stefan, created 06:59

sciapy.regress.__main__ (rating: F)

Complexity

Total Complexity 72

Size/Duplication

Total Lines 573
Duplicated Lines 90.4%

Test Coverage

Coverage 9.09%

Importance

Changes 0
Metric   Value
eloc     425
dl       518      (duplicated lines, cf. 90.4% above)
loc      573      (total lines)
ccs      27
cts      297
cp       0.0909   (test coverage, 9.09%)
rs       2.64
c        0
b        0
f        0
wmc      72       (total complexity)

4 Functions

Rating   Name   Duplication   Size   Complexity  
A _r_sun_earth() 24 24 1
A _train_test_split() 29 29 3
F main() 424 424 61
B save_samples_netcdf() 41 41 7

How to fix

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

Common duplication problems have corresponding refactoring solutions, the most basic being to extract the repeated code into a single shared function, class, or module that all call sites use.
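
For the duplicated helpers flagged above (for example _train_test_split), a minimal sketch of that fix is to keep one copy in a shared module and import it from every entry point. The module name _helpers below is an assumption for illustration, not part of sciapy, and the fallback to the full data set for an empty test split is omitted for brevity:

# sciapy/regress/_helpers.py  (hypothetical shared module)
import numpy as np


def train_test_split(times, data, errs, train_frac, test_frac, randomize):
	"""Single shared copy of the train/test splitting logic."""
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# optionally shuffle the indices before splitting
	permut_idx = np.random.permutation(ndata) if randomize else np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	return (times[train_idx], data[train_idx], errs[train_idx],
			times[test_idx], data[test_idx], errs[test_idx])


# Each command line module then imports the shared copy instead of
# defining its own, e.g.:
# from ._helpers import train_test_split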

Complexity

Tip: Before tackling complexity, make sure that you eliminate any duplication first. This can often reduce the size of classes significantly.

Complex classes like sciapy.regress.__main__ often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
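
In sciapy.regress.__main__ the complexity is concentrated in the main() function rather than a class, so the analogous step is to extract cohesive blocks into helper functions. A minimal sketch, grounded in the harmonic-model loop shown in main() below (the helper name setup_harmonic_models is an assumption, not part of sciapy):

import numpy as np

from .models_cel import HarmonicModelCosineSine, HarmonicModelAmpPhase


def setup_harmonic_models(freqs, max_amp, fit_phase):
	"""Build the (name, model) pairs for the seasonal harmonics.

	This mirrors the corresponding loop in main(); extracting it removes
	one responsibility from main() and makes it testable on its own.
	"""
	models = []
	for freq in freqs:
		name = "f{0:.0f}".format(freq)
		if not fit_phase:
			model = HarmonicModelCosineSine(freq=freq, cos=0, sin=0,
					bounds={"cos": [-max_amp, max_amp],
							"sin": [-max_amp, max_amp]})
		else:
			model = HarmonicModelAmpPhase(freq=freq, amp=0, phase=0,
					bounds={"amp": [0, max_amp],
							"phase": [-np.pi, np.pi]})
		models.append((name, model))
	return models


# main() would then call
# harmonic_models = setup_harmonic_models(freqs, max_amp, args.fit_phase)

The full module source as analysed follows.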

# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2018 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY data regression command line interface

Command line main program for regression analysis of SCIAMACHY
daily zonal mean time series (NO for now).
"""

import ctypes
import logging

import numpy as np
import scipy.optimize as op
from scipy.interpolate import interp1d

import george
import celerite

import matplotlib as mpl
# switch off X11 rendering
mpl.use("Agg")

from .load_data import load_solar_gm_table, load_scia_dzm
from .models_cel import CeleriteModelSet as NOModel
from .models_cel import ConstantModel, ProxyModel
from .models_cel import HarmonicModelCosineSine, HarmonicModelAmpPhase
from .mcmc import mcmc_sample_model
from .statistics import mcmc_statistics

from ._gpkernels import (george_solvers,
		setup_george_kernel, setup_celerite_terms)
from ._plot import (plot_single_sample_and_residuals,
		plot_residual, plot_random_samples)
from ._options import parser


def save_samples_netcdf(filename, model, alt, lat, samples,
		scale=1e-6,
		lnpost=None, compressed=False):
	from xarray import Dataset
	smpl_ds = Dataset(dict([(pname, (["lat", "alt", "sample"],
				samples[..., i].reshape(1, 1, -1)))
				for i, pname in enumerate(model.get_parameter_names())]
				# + [("lpost", (["lat", "alt", "sample"], lnp.reshape(1, 1, -1)))]
			),
			coords={"lat": [lat], "alt": [alt]})

	for modn in model.mean.models:
		modl = model.mean.models[modn]
		if hasattr(modl, "mean"):
			smpl_ds.attrs[modn + ":mean"] = modl.mean

	units = {"kernel": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale))},
			"mean": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale)),
				"val": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"amp": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"tau": "d"}}
	for pname in smpl_ds.data_vars:
		_pp = pname.split(':')
		for _n, _u in units[_pp[0]].items():
			if _pp[-1].startswith(_n):
				logging.debug("units for %s: %s", pname, _u)
				smpl_ds[pname].attrs["units"] = _u

	smpl_ds["alt"].attrs = {"long_name": "altitude", "units": "km"}
	smpl_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}

	_encoding = None
	if compressed:
		_encoding = {var: {"zlib": True, "complevel": 1}
					for var in smpl_ds.data_vars}
	smpl_ds.to_netcdf(filename, encoding=_encoding)
	smpl_ds.close()


def _train_test_split(times, data, errs, train_frac,
		test_frac, randomize):
	# split the data into training and test subsets according to the
	# fraction given (default is 1, i.e. no splitting)
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# randomize if requested
	if randomize:
		permut_idx = np.random.permutation(np.arange(ndata))
	else:
		permut_idx = np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	times_train = times[train_idx]
	data_train = data[train_idx]
	errs_train = errs[train_idx]
	if test_size > 0:
		times_test = times[test_idx]
		data_test = data[test_idx]
		errs_test = errs[test_idx]
	else:
		times_test = times
		data_test = data
		errs_test = errs
	logging.info("using %s of %s samples for training.", len(times_train), ndata)
	logging.info("using %s of %s samples for testing.", len(times_test), ndata)
	return (times_train, data_train, errs_train,
			times_test, data_test, errs_test)


def _r_sun_earth(time, tfmt="jyear"):
	"""First order approximation of the Sun-Earth distance

	The Sun-to-Earth distance can be used to (un-)normalize proxies
	to the actual distance to the Sun instead of 1 AU.

	Parameters
	----------
	time : float
		Time value in the units given by 'tfmt'.
	tfmt : str, optional
		The units of 'time' as supported by the
		astropy.time time formats. Default: 'jyear'.

	Returns
	-------
	dist : float
		The Sun-Earth distance at the given day of year in AU.
	"""
	from astropy.time import Time
	tdoy = Time(time, format=tfmt)
	tdoy.format = "yday"
	doy = int(tdoy.value.split(':')[1])
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))


def main():
	logging.basicConfig(level=logging.WARNING,
			format="[%(levelname)-8s] (%(asctime)s) "
				"%(filename)s:%(lineno)d %(message)s",
			datefmt="%Y-%m-%d %H:%M:%S %z")

	args = parser.parse_args()

	logging.info("command line arguments: %s", args)
	if args.quiet:
		logging.getLogger().setLevel(logging.ERROR)
	elif args.verbose:
		logging.getLogger().setLevel(logging.INFO)
	else:
		logging.getLogger().setLevel(args.loglevel)

	from numpy.distutils.system_info import get_info
	for oblas_path in get_info("openblas")["library_dirs"]:
		oblas_name = "{0}/libopenblas.so".format(oblas_path)
		logging.info("Trying %s", oblas_name)
		try:
			oblas_lib = ctypes.cdll.LoadLibrary(oblas_name)
			oblas_cores = oblas_lib.openblas_get_num_threads()
			oblas_lib.openblas_set_num_threads(args.openblas_threads)
			logging.info("Using %s/%s Openblas thread(s).",
					oblas_lib.openblas_get_num_threads(), oblas_cores)
		except:
			logging.info("Setting number of openblas threads failed.")

	if args.random_seed is not None:
		np.random.seed(args.random_seed)

	if args.proxies:
		proxies = args.proxies.split(',')
		proxy_dict = dict(_p.split(':') for _p in proxies)
	else:
		proxy_dict = {}
	lag_dict = {pn: 0 for pn in proxy_dict.keys()}

	# Post-processing of arguments...
	# List of proxy lag fits from csv
	fit_lags = args.fit_lags.split(',')
	# List of proxy lifetime fits from csv
	fit_lifetimes = args.fit_lifetimes.split(',')
	fit_annlifetimes = args.fit_annlifetimes.split(',')
	# List of proxy lag times from csv
	lag_dict.update(dict(_ls.split(':') for _ls in args.lag_times.split(',')))
	# List of cycles (frequencies in 1/year) from argument list (csv)
	try:
		freqs = list(map(float, args.freqs.split(',')))
	except ValueError:
		freqs = []
	# List of initial parameter values
	initial = None
	if args.initial is not None:
		try:
			initial = list(map(float, args.initial.split(',')))
		except ValueError:
			pass
	# List of GP kernels from argument list (csv)
	kernls = args.kernels.split(',')

	lat = args.latitude
	alt = args.altitude
	logging.info("location: %.0f°N %.0f km", lat, alt)

	no_ys, no_dens, no_errs, no_szas = load_scia_dzm(args.file, alt, lat,
			tfmt=args.time_format,
			scale=args.scale,
			#subsample_factor=args.random_subsample,
			#subsample_method="random",
			akd_threshold=args.akd_threshold,
			cnt_threshold=args.cnt_threshold,
			center=args.center_data,
			season=args.season,
			SPEs=args.exclude_spe)

	(no_ys_train, no_dens_train, no_errs_train,
		no_ys_test, no_dens_test, no_errs_test) = _train_test_split(
				no_ys, no_dens, no_errs, args.train_fraction,
				args.test_fraction, args.random_train_test)

	sza_intp = interp1d(no_ys, no_szas, fill_value="extrapolate")

	max_amp = 1e10 * args.scale
	max_days = 100

	harmonic_models = []
	for freq in freqs:
		if not args.fit_phase:
			harmonic_models.append(("f{0:.0f}".format(freq),
				HarmonicModelCosineSine(freq=freq,
					cos=0, sin=0,
					bounds=dict([
						("cos", [-max_amp, max_amp]),
						("sin", [-max_amp, max_amp])])
			)))
		else:
			harmonic_models.append(("f{0:.0f}".format(freq),
				HarmonicModelAmpPhase(freq=freq,
					amp=0, phase=0,
					bounds=dict([
						# ("amp", [-max_amp, max_amp]),
						("amp", [0, max_amp]),
						("phase", [-np.pi, np.pi])])
			)))
	proxy_models = []
	for pn, pf in proxy_dict.items():
		pt, pp = load_solar_gm_table(pf, cols=[0, 1], names=["time", pn], tfmt=args.time_format)
		pv = np.log(pp[pn]) if pn in args.log_proxies.split(',') else pp[pn]
		if pn in args.norm_proxies_distSEsq:
			rad_sun_earth = np.vectorize(_r_sun_earth)(pt, tfmt=args.time_format)
			pv /= rad_sun_earth**2
		if pn in args.norm_proxies_SZA:
			pv *= np.cos(np.radians(sza_intp(pt)))
		proxy_models.append((pn,
			ProxyModel(pt, pv,
				center=pn in args.center_proxies.split(','),
				sza_intp=sza_intp if args.use_sza else None,
				fit_phase=args.fit_phase,
				lifetime_prior=args.lifetime_prior,
				lifetime_metric=args.lifetime_metric,
				days_per_time_unit=1 if args.time_format.endswith("d") else 365.25,
				amp=0.,
				lag=float(lag_dict[pn]),
				tau0=0,
				taucos1=0, tausin1=0,
				taucos2=0, tausin2=0,
				ltscan=args.lifetime_scan,
				bounds=dict([
					("amp",
						[0, max_amp] if pn in args.positive_proxies.split(',')
						else [-max_amp, max_amp]),
					("lag", [0, max_days]),
					("tau0", [0, max_days]),
					("taucos1", [0, max_days] if args.fit_phase else [-max_days, max_days]),
					("tausin1", [-np.pi, np.pi] if args.fit_phase else [-max_days, max_days]),
					# semi-annual cycles for the life time
					("taucos2", [0, max_days] if args.fit_phase else [-max_days, max_days]),
					("tausin2", [-np.pi, np.pi] if args.fit_phase else [-max_days, max_days]),
					("ltscan", [0, 200])])
			)))
		logging.info("%s mean: %s", pn, proxy_models[-1][1].mean)
	offset_model = [("offset",
			ConstantModel(value=0.,
					bounds={"value": [-max_amp, max_amp]}))]

	model = NOModel(offset_model + harmonic_models + proxy_models)

	logging.debug("model dict: %s", model.get_parameter_dict())
295
	model.freeze_all_parameters()
296
	# thaw parameters according to requested fits
297
	for pn in proxy_dict.keys():
298
		model.thaw_parameter("{0}:amp".format(pn))
299
		if pn in fit_lags:
300
			model.thaw_parameter("{0}:lag".format(pn))
301
		if pn in fit_lifetimes:
302
			model.set_parameter("{0}:tau0".format(pn), 1e-3)
303
			model.thaw_parameter("{0}:tau0".format(pn))
304
			if pn in fit_annlifetimes:
305
				model.thaw_parameter("{0}:taucos1".format(pn))
306
				model.thaw_parameter("{0}:tausin1".format(pn))
307
	for freq in freqs:
308
		if not args.fit_phase:
309
			model.thaw_parameter("f{0:.0f}:cos".format(freq))
310
			model.thaw_parameter("f{0:.0f}:sin".format(freq))
311
		else:
312
			model.thaw_parameter("f{0:.0f}:amp".format(freq))
313
			model.thaw_parameter("f{0:.0f}:phase".format(freq))
314
	if args.fit_offset:
315
		#model.set_parameter("offset:value", -100.)
316
		#model.set_parameter("offset:value", 0)
317
		model.thaw_parameter("offset:value")
318
319
	if initial is not None:
320
		model.set_parameter_vector(initial)
321
	# model.thaw_parameter("GM:ltscan")
322
	logging.debug("params: %s", model.get_parameter_dict())
323
	logging.debug("param names: %s", model.get_parameter_names())
324
	logging.debug("param vector: %s", model.get_parameter_vector())
325
	logging.debug("param bounds: %s", model.get_parameter_bounds())
326
	#logging.debug("model value: %s", model.get_value(no_ys))
327
	#logging.debug("default log likelihood: %s", model.log_likelihood(model.vector))
328
329
	# setup the Gaussian Process kernel
330
	kernel_base = (1e7 * args.scale)**2
331
	ksub = args.name_suffix
332
333
	solver = "basic"
334
	skwargs = {}
335
	if args.HODLR_Solver:
336
		solver = "HODLR"
337
		#skwargs = {"tol": 1e-3}
338
339
	if args.george:
340
		gpname, kernel = setup_george_kernel(kernls,
341
				kernel_base=kernel_base, fit_bias=args.fit_bias)
342
		gpmodel = george.GP(kernel, mean=model,
343
			white_noise=1.e-25, fit_white_noise=args.fit_white,
344
			solver=george_solvers[solver], **skwargs)
345
		# the george interface does not allow setting the bounds in
346
		# the kernel initialization so we prepare simple default bounds
347
		kernel_bounds = [(-0.3 * max_amp, 0.3 * max_amp)
348
				for _ in gpmodel.kernel.get_parameter_names()]
349
		bounds = gpmodel.mean.get_parameter_bounds() + kernel_bounds
350
	else:
351
		gpname, cel_terms = setup_celerite_terms(kernls,
352
				fit_bias=args.fit_bias, fit_white=args.fit_white)
353
		gpmodel = celerite.GP(cel_terms, mean=model,
354
			fit_white_noise=args.fit_white,
355
			fit_mean=True)
356
		bounds = gpmodel.get_parameter_bounds()
357
	gpmodel.compute(no_ys_train, no_errs_train)
358
	logging.debug("gpmodel params: %s", gpmodel.get_parameter_dict())
359
	logging.debug("gpmodel bounds: %s", bounds)
360
	logging.debug("initial log likelihood: %s", gpmodel.log_likelihood(no_dens_train))
361
	if isinstance(gpmodel, celerite.GP):
362
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)
363
	model_name = "_".join(gpmodel.mean.get_parameter_names()).replace(':', '')
364
	gpmodel_name = model_name + gpname
365
	logging.info("GP model name: %s", gpmodel_name)
366
367
	pre_opt = False
	if args.optimize > 0:
		def gpmodel_mean(x, *p):
			gpmodel.set_parameter_vector(p)
			return gpmodel.mean.get_value(x)

		def gpmodel_res(x, *p):
			gpmodel.set_parameter_vector(p)
			return (gpmodel.mean.get_value(x) - no_dens_train) / no_errs_train

		def lpost(p, y, gp):
			gp.set_parameter_vector(p)
			return gp.log_likelihood(y, quiet=True) + gp.log_prior()

		def nlpost(p, y, gp):
			lp = lpost(p, y, gp)
			return -lp if np.isfinite(lp) else 1e25

		def grad_nlpost(p, y, gp):
			gp.set_parameter_vector(p)
			grad_ll = gp.grad_log_likelihood(y)
			if isinstance(grad_ll, tuple):
				# celerite
				return -grad_ll[1]
			# george
			return -grad_ll

		if args.optimize == 1:
			resop_gp = op.minimize(
				nlpost,
				gpmodel.get_parameter_vector(),
				args=(no_dens_train, gpmodel),
				bounds=bounds,
				# method="l-bfgs-b", options=dict(disp=True, maxcor=100, eps=1e-9, ftol=2e-15, gtol=1e-8))
				method="l-bfgs-b", jac=grad_nlpost)
				# method="tnc", options=dict(disp=True, maxiter=500, xtol=1e-12))
				# method="nelder-mead", options=dict(disp=True, maxfev=100000, fatol=1.49012e-8, xatol=1.49012e-8))
				# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08))
		if args.optimize == 2:
			resop_gp = op.differential_evolution(
				nlpost,
				bounds=bounds,
				args=(no_dens_train, gpmodel),
				popsize=2 * args.walkers, tol=0.01)
		if args.optimize == 3:
			resop_bh = op.basinhopping(
				nlpost,
				gpmodel.get_parameter_vector(),
				niter=200,
				minimizer_kwargs=dict(
					args=(no_dens_train, gpmodel),
					bounds=bounds,
					# method="tnc"))
					# method="l-bfgs-b", options=dict(maxcor=100)))
					method="l-bfgs-b", jac=grad_nlpost))
					# method="Nelder-Mead"))
					# method="BFGS"))
					# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08)))
			logging.debug("optimization result: %s", resop_bh)
			resop_gp = resop_bh.lowest_optimization_result
		if args.optimize == 4:
			resop_gp, cov_gp = op.curve_fit(
				gpmodel_mean,
				no_ys_train, no_dens_train, gpmodel.get_parameter_vector(),
				bounds=tuple(np.array(bounds).T),
				# method='lm',
				# absolute_sigma=True,
				sigma=no_errs_train)
			print(resop_gp, np.sqrt(np.diag(cov_gp)))
		logging.info("%s", resop_gp.message)
		logging.debug("optimization result: %s", resop_gp)
		logging.info("gpmodel dict: %s", gpmodel.get_parameter_dict())
		logging.info("log posterior trained: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.info("log posterior test: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.info("log posterior all: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		# cross check to make sure that the gpmodel parameter vector is really
		# set to the fitted parameters
		logging.info("opt. model vector: %s", resop_gp.x)
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 1: %s", lpost(resop_gp.x, no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 1: %s", lpost(resop_gp.x, no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 1: %s", lpost(resop_gp.x, no_dens, gpmodel))
		logging.debug("opt. model vector: %s", gpmodel.get_parameter_vector())
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		pre_opt = resop_gp.success
	try:
		logging.info("GM lt: %s", gpmodel.get_parameter("mean:GM:tau0"))
	except ValueError:
		pass
	logging.info("(GP) model: %s", gpmodel.kernel)
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)

	bestfit = gpmodel.get_parameter_vector()
	filename_base = ("NO_regress_fit_{0}_{1:.0f}_{2:.0f}_{{0}}_{3}"
					.format(gpmodel_name, lat * 10, alt, ksub))

	if args.mcmc:
		gpmodel.compute(no_ys_train, no_errs_train)
		samples, lnp = mcmc_sample_model(gpmodel,
				no_dens_train,
				beta=1.0,
				nwalkers=args.walkers, nburnin=args.burn_in,
				nprod=args.production, nthreads=args.threads,
				show_progress=args.progress,
				optimized=pre_opt, bounds=bounds, return_logpost=True)

		if args.train_fraction < 1. or args.test_fraction < 1.:
			logging.info("Statistics for the test samples")
			mcmc_statistics(gpmodel,
					no_ys_test, no_dens_test, no_errs_test,
					no_ys_train, no_dens_train, no_errs_train,
					samples, lnp,
			)
		logging.info("Statistics for all samples")
		mcmc_statistics(gpmodel,
				no_ys, no_dens, no_errs,
				no_ys_train, no_dens_train, no_errs_train,
				samples, lnp,
		)

		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
		if args.plot_corner:
			import corner
			# Corner plot of the sampled parameters
			fig = corner.corner(samples,
					quantiles=[0.025, 0.5, 0.975],
					show_titles=True,
					labels=gpmodel.get_parameter_names(),
					truths=bestfit,
					hist_args=dict(normed=True))
			fig.savefig(filename_base.format("corner") + ".pdf", transparent=True)

		if args.save_samples:
			if args.samples_format in ["npz"]:
				# save the samples compressed to save space.
				np.savez_compressed(filename_base.format("sampls") + ".npz",
					samples=samples)
			if args.samples_format in ["nc", "netcdf4"]:
				save_samples_netcdf(filename_base.format("sampls") + ".nc",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
			if args.samples_format in ["h5", "hdf5"]:
				save_samples_netcdf(filename_base.format("sampls") + ".h5",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
		# MCMC finished here

	# set the model times and errors to use the full data set for plotting
	gpmodel.compute(no_ys, no_errs)
	if args.save_model:
		try:
			# python 2
			import cPickle as pickle
		except ImportError:
			# python 3
			import pickle
		# pickle and save the model
		with open(filename_base.format("model") + ".pkl", "wb") as f:
			pickle.dump((gpmodel), f, -1)

	if args.plot_samples and args.mcmc:
		plot_random_samples(gpmodel, no_ys, no_dens, no_errs,
				samples, args.scale,
				filename_base.format("sampls") + ".pdf",
				size=4, extra_years=[4, 2])

	if args.plot_median:
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1],
				filename_base.format("median") + ".pdf")
	if args.plot_residuals:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1], args.scale,
				filename_base.format("medres") + ".pdf")
	if args.plot_maxlnp:
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)],
				filename_base.format("maxlnp") + ".pdf")
	if args.plot_maxlnpres:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)], args.scale,
				filename_base.format("mlpres") + ".pdf")

	labels = gpmodel.get_parameter_names()
	logging.info("param percentiles [2.5, 50, 97.5]:")
	for pc, label in zip(sampl_percs.T, labels):
		median = pc[1]
		pc_minus = median - pc[0]
		pc_plus = pc[2] - median
		logging.debug("%s: %s", label, pc)
		logging.info("%s: %.6f (- %.6f) (+ %.6f)", label,
				median, pc_minus, pc_plus)

	logging.info("Finished successfully.")


if __name__ == "__main__":
	main()