Passed
Push — master (d1ac5b...af9f8c) by Stefan, created 03:57

sciapy.regress.__main__._r_sun_earth()   A

Complexity

Conditions 1

Size

Total Lines 24
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 1.5786

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 2
dl 0
loc 24
ccs 1
cts 6
cp 0.1666
crap 1.5786
rs 10
c 0
b 0
f 0
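
The reported CRAP value is consistent with the usual CRAP formula, crap = cc^2 * (1 - coverage)^3 + cc. A minimal check (a sketch, assuming the cp value above is this function's coverage fraction):

	cc, cp = 1, 0.1666
	print(cc**2 * (1 - cp)**3 + cc)  # ~1.5788; the reported 1.5786 up to rounding of cp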
# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2018 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY data regression command line interface

Command line main program for regression analysis of SCIAMACHY
daily zonal mean time series (NO for now).
"""
import ctypes
import logging
from os import path

import numpy as np
import scipy.optimize as op
from scipy.interpolate import interp1d

import george
import celerite

import matplotlib as mpl
# switch off X11 rendering
mpl.use("Agg")

from .load_data import load_solar_gm_table, load_scia_dzm
from .models_cel import trace_gas_model
from .mcmc import mcmc_sample_model
from .statistics import mcmc_statistics

from ._gpkernels import (george_solvers,
		setup_george_kernel, setup_celerite_terms)
from ._plot import (plot_single_sample_and_residuals,
		plot_residual, plot_random_samples)
from ._options import parser

def save_samples_netcdf(filename, model, alt, lat, samples,
		scale=1e-6,
		lnpost=None, compressed=False):
	from xarray import Dataset
	smpl_ds = Dataset(dict([(pname, (["lat", "alt", "sample"],
				samples[..., i].reshape(1, 1, -1)))
				for i, pname in enumerate(model.get_parameter_names())]
				# + [("lpost", (["lat", "alt", "sample"], lnp.reshape(1, 1, -1)))]
			),
			coords={"lat": [lat], "alt": [alt]})

	for modn in model.mean.models:
		modl = model.mean.models[modn]
		if hasattr(modl, "mean"):
			smpl_ds.attrs[modn + ":mean"] = modl.mean

	units = {"kernel": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale))},
			"mean": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale)),
				"cos": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"sin": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"val": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"amp": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"tau": "d"}}
	for pname in smpl_ds.data_vars:
		_pp = pname.split(':')
		for _n, _u in units.get(_pp[0], {}).items():
			if _pp[-1].startswith(_n):
				logging.debug("units for %s: %s", pname, _u)
				smpl_ds[pname].attrs["units"] = _u

	smpl_ds["alt"].attrs = {"long_name": "altitude", "units": "km"}
	smpl_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}

	_encoding = None
	if compressed:
		_encoding = {var: {"zlib": True, "complevel": 1}
					for var in smpl_ds.data_vars}
	try:
		smpl_ds.to_netcdf(filename, encoding=_encoding)
	except ValueError:
		smpl_ds.to_netcdf(filename)
	smpl_ds.close()
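
# A minimal read-back sketch (assumes xarray and a hypothetical file
# "samples.nc" written by save_samples_netcdf above):
# >>> import xarray as xr
# >>> ds = xr.open_dataset("samples.nc")
# >>> ds.data_vars      # one ("lat", "alt", "sample") variable per parameter
# >>> ds["alt"].attrs   # {"long_name": "altitude", "units": "km"}
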
def _train_test_split(times, data, errs, train_frac,
		test_frac, randomize):
	# split the data into training and test subsets according to the
	# fraction given (default is 1, i.e. no splitting)
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# randomize if requested
	if randomize:
		permut_idx = np.random.permutation(np.arange(ndata))
	else:
		permut_idx = np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	times_train = times[train_idx]
	data_train = data[train_idx]
	errs_train = errs[train_idx]
	if test_size > 0:
		times_test = times[test_idx]
		data_test = data[test_idx]
		errs_test = errs[test_idx]
	else:
		times_test = times
		data_test = data
		errs_test = errs
	logging.info("using %s of %s samples for training.", len(times_train), ndata)
	logging.info("using %s of %s samples for testing.", len(times_test), ndata)
	return (times_train, data_train, errs_train,
			times_test, data_test, errs_test)
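
# A minimal usage sketch (synthetic arrays, not part of the module):
# >>> t = np.arange(10.)
# >>> res = _train_test_split(t, t, t, 0.8, 0.2, False)
# keeps 8 of 10 points for training and the remaining 2 for testing, both
# sorted in time; with test_frac = 0, the full arrays are returned as the
# "test" set.
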
def _r_sun_earth(time, tfmt="jyear"):
	"""First order approximation of the Sun-Earth distance

	The Sun-to-Earth distance can be used to (un-)normalize proxies
	to the actual distance to the Sun instead of 1 AU.

	Parameters
	----------
	time : float
		Time value in the units given by 'tfmt'.
	tfmt : str, optional
		The units of 'time' as supported by the
		astropy.time time formats. Default: 'jyear'.

	Returns
	-------
	dist : float
		The Sun-Earth distance at the given day of year in AU.
	"""
	from astropy.time import Time
	tdoy = Time(time, format=tfmt)
	tdoy.format = "yday"
	doy = int(tdoy.value.split(':')[1])
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))
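
# A quick sanity check of the approximation (the values follow directly
# from the formula above): the distance bottoms out at perihelion
# (doy = 4) and peaks about half a year later:
#   1 - 0.01672 * cos(0)  = 0.98328 AU   (doy = 4)
#   1 - 0.01672 * cos(pi) = 1.01672 AU   (doy ~ 187)
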
def main():
	logging.basicConfig(level=logging.WARNING,
			format="[%(levelname)-8s] (%(asctime)s) "
				"%(filename)s:%(lineno)d %(message)s",
			datefmt="%Y-%m-%d %H:%M:%S %z")

	args = parser.parse_args()

	logging.info("command line arguments: %s", args)
	if args.quiet:
		logging.getLogger().setLevel(logging.ERROR)
	elif args.verbose:
		logging.getLogger().setLevel(logging.INFO)
	else:
		logging.getLogger().setLevel(args.loglevel)

	from numpy.distutils.system_info import get_info
	try:
		ob_lib_dirs = get_info("openblas")["library_dirs"]
	except KeyError:
		ob_lib_dirs = []
	for oblas_path in ob_lib_dirs:
		oblas_name = "{0}/libopenblas.so".format(oblas_path)
		logging.info("Trying %s", oblas_name)
		try:
			oblas_lib = ctypes.cdll.LoadLibrary(oblas_name)
			oblas_cores = oblas_lib.openblas_get_num_threads()
			oblas_lib.openblas_set_num_threads(args.openblas_threads)
			logging.info("Using %s/%s Openblas thread(s).",
					oblas_lib.openblas_get_num_threads(), oblas_cores)
		except Exception:
			logging.info("Setting number of openblas threads failed.")
	if args.random_seed is not None:
		np.random.seed(args.random_seed)

	if args.proxies:
		proxies = args.proxies.split(',')
		proxy_dict = dict(_p.split(':') for _p in proxies)
	else:
		proxy_dict = {}
	lag_dict = {pn: 0 for pn in proxy_dict.keys()}
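	# Note on the expected format (inferred from the parsing above; the
	# file names are hypothetical): --proxies is a comma-separated list
	# of "name:file" pairs, e.g.
	#   --proxies GM:gm.dat,Lya:lya.dat  ->  {"GM": "gm.dat", "Lya": "lya.dat"}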

	# Post-processing of arguments...
	# List of proxy lag fits from csv
	fit_lags = args.fit_lags.split(',')
	# List of proxy lifetime fits from csv
	fit_lifetimes = args.fit_lifetimes.split(',')
	fit_annlifetimes = args.fit_annlifetimes.split(',')
	# List of proxy lag times from csv
	lag_dict.update(dict(_ls.split(':') for _ls in args.lag_times.split(',')))
	# List of cycles (frequencies in 1/year) from argument list (csv)
	try:
		freqs = list(map(float, args.freqs.split(',')))
	except ValueError:
		freqs = []
	args.freqs = freqs
	# List of initial parameter values
	initial = None
	if args.initial is not None:
		try:
			initial = list(map(float, args.initial.split(',')))
		except ValueError:
			pass
	# List of GP kernels from argument list (csv)
	kernls = args.kernels.split(',')

	lat = args.latitude
	alt = args.altitude
	logging.info("location: %.0f°N %.0f km", lat, alt)
	no_ys, no_dens, no_errs, no_szas = load_scia_dzm(args.file, alt, lat,
			tfmt=args.time_format,
			scale=args.scale,
			#subsample_factor=args.random_subsample,
			#subsample_method="random",
			akd_threshold=args.akd_threshold,
			cnt_threshold=args.cnt_threshold,
			center=args.center_data,
			season=args.season,
			SPEs=args.exclude_spe)

	(no_ys_train, no_dens_train, no_errs_train,
		no_ys_test, no_dens_test, no_errs_test) = _train_test_split(
				no_ys, no_dens, no_errs, args.train_fraction,
				args.test_fraction, args.random_train_test)

	sza_intp = interp1d(no_ys, no_szas, bounds_error=False)

	max_amp = 1e10 * args.scale
	max_days = 100

	proxy_config = {}
	for pn, pf in proxy_dict.items():
		pt, pp = load_solar_gm_table(path.expanduser(pf),
				cols=[0, 1], names=["time", pn], tfmt=args.time_format)
		pv = pp[pn]
		# use log of proxy values if desired
		if pn in args.log_proxies.split(','):
			pv = np.log(pv)
		# normalize to sun--earth distance squared
		if pn in args.norm_proxies_distSEsq.split(','):
			rad_sun_earth = np.vectorize(_r_sun_earth)(pt, tfmt=args.time_format)
			pv /= rad_sun_earth**2
		# normalize by cos(SZA)
		if pn in args.norm_proxies_SZA.split(',') and sza_intp is not None:
			pv *= np.cos(np.radians(sza_intp(pt)))
		proxy_config.update({pn:
			dict(times=pt, values=pv,
				center=pn in args.center_proxies.split(','),
				positive=pn in args.positive_proxies.split(','),
				lag=float(lag_dict[pn]),
				max_amp=max_amp, max_days=max_days,
				sza_intp=sza_intp if args.use_sza else None,
			)}
		)
	model = trace_gas_model(constant=args.fit_offset,
			proxy_config=proxy_config, **vars(args))

	logging.debug("model dict: %s", model.get_parameter_dict())
	model.freeze_all_parameters()
	# thaw parameters according to requested fits
	for pn in proxy_dict.keys():
		model.thaw_parameter("{0}:amp".format(pn))
		if pn in fit_lags:
			model.thaw_parameter("{0}:lag".format(pn))
		if pn in fit_lifetimes:
			model.set_parameter("{0}:tau0".format(pn), 1e-3)
			model.thaw_parameter("{0}:tau0".format(pn))
			if pn in fit_annlifetimes:
				model.thaw_parameter("{0}:taucos1".format(pn))
				model.thaw_parameter("{0}:tausin1".format(pn))
		else:
			model.set_parameter("{0}:ltscan".format(pn), 0)
	for freq in freqs:
		if not args.fit_phase:
			model.thaw_parameter("f{0:.0f}:cos".format(freq))
			model.thaw_parameter("f{0:.0f}:sin".format(freq))
		else:
			model.thaw_parameter("f{0:.0f}:amp".format(freq))
			model.thaw_parameter("f{0:.0f}:phase".format(freq))
	if args.fit_offset:
		#model.set_parameter("offset:value", -100.)
		#model.set_parameter("offset:value", 0)
		model.thaw_parameter("offset:value")

	if initial is not None:
		model.set_parameter_vector(initial)
	# model.thaw_parameter("GM:ltscan")
	logging.debug("params: %s", model.get_parameter_dict())
	logging.debug("param names: %s", model.get_parameter_names())
	logging.debug("param vector: %s", model.get_parameter_vector())
	logging.debug("param bounds: %s", model.get_parameter_bounds())
	#logging.debug("model value: %s", model.get_value(no_ys))
	#logging.debug("default log likelihood: %s", model.log_likelihood(model.vector))
	# setup the Gaussian Process kernel
	kernel_base = (1e7 * args.scale)**2
	ksub = args.name_suffix

	solver = "basic"
	skwargs = {}
	if args.HODLR_Solver:
		solver = "HODLR"
		#skwargs = {"tol": 1e-3}

	if args.george:
		gpname, kernel = setup_george_kernel(kernls,
				kernel_base=kernel_base, fit_bias=args.fit_bias)
		gpmodel = george.GP(kernel, mean=model,
			white_noise=1.e-25, fit_white_noise=args.fit_white,
			solver=george_solvers[solver], **skwargs)
		# the george interface does not allow setting the bounds in
		# the kernel initialization so we prepare simple default bounds
		kernel_bounds = [(-0.3 * max_amp, 0.3 * max_amp)] * args.fit_white + [
			(-0.3 * max_amp, 0.3 * max_amp)
			for _ in gpmodel.kernel.get_parameter_names()
		]
		bounds = gpmodel.mean.get_parameter_bounds() + kernel_bounds
	else:
		gpname, cel_terms = setup_celerite_terms(kernls,
				fit_bias=args.fit_bias, fit_white=args.fit_white)
		gpmodel = celerite.GP(cel_terms, mean=model,
			fit_white_noise=args.fit_white,
			fit_mean=True)
		bounds = gpmodel.get_parameter_bounds()
	gpmodel.compute(no_ys_train, no_errs_train)
	logging.debug("gpmodel params: %s", gpmodel.get_parameter_dict())
	logging.debug("gpmodel bounds: %s", bounds)
	logging.debug("initial log likelihood: %s", gpmodel.log_likelihood(no_dens_train))
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)
	model_name = "_".join(gpmodel.mean.get_parameter_names()).replace(':', '')
	gpmodel_name = model_name + gpname
	logging.info("GP model name: %s", gpmodel_name)
	pre_opt = False
	if args.optimize > 0:
		def gpmodel_mean(x, *p):
			gpmodel.set_parameter_vector(p)
			return gpmodel.mean.get_value(x)

		def gpmodel_res(x, *p):
			gpmodel.set_parameter_vector(p)
			return (gpmodel.mean.get_value(x) - no_dens_train) / no_errs_train

		def lpost(p, y, gp):
			gp.set_parameter_vector(p)
			return gp.log_likelihood(y, quiet=True) + gp.log_prior()

		def nlpost(p, y, gp):
			# checker note: the variable lpost does not seem to be defined
			# when "args.optimize > 0" above is False; are you sure this can
			# never be the case?
			lp = lpost(p, y, gp)
			return -lp if np.isfinite(lp) else 1e25

		def grad_nlpost(p, y, gp):
			gp.set_parameter_vector(p)
			grad_ll = gp.grad_log_likelihood(y)
			if isinstance(grad_ll, tuple):
				# celerite
				return -grad_ll[1]
			# george
			return -grad_ll

		jacobian = grad_nlpost if gpmodel.kernel.vector_size else None
		if args.optimize == 1:
			resop_gp = op.minimize(
				nlpost,
				gpmodel.get_parameter_vector(),
				args=(no_dens_train, gpmodel),
				bounds=bounds,
				# method="l-bfgs-b", options=dict(disp=True, maxcor=100, eps=1e-9, ftol=2e-15, gtol=1e-8))
				method="l-bfgs-b", jac=jacobian)
				# method="tnc", options=dict(disp=True, maxiter=500, xtol=1e-12))
				# method="nelder-mead", options=dict(disp=True, maxfev=100000, fatol=1.49012e-8, xatol=1.49012e-8))
				# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08))
		elif args.optimize == 2:
			resop_gp = op.differential_evolution(
				nlpost,
				bounds=bounds,
				args=(no_dens_train, gpmodel),
				popsize=2 * args.walkers, tol=0.01)
		elif args.optimize == 3:
			resop_bh = op.basinhopping(
				nlpost,
				gpmodel.get_parameter_vector(),
				niter=200,
				minimizer_kwargs=dict(
					args=(no_dens_train, gpmodel),
					bounds=bounds,
					# method="tnc"))
					# method="l-bfgs-b", options=dict(maxcor=100)))
					method="l-bfgs-b", jac=jacobian))
					# method="Nelder-Mead"))
					# method="BFGS"))
					# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08)))
			logging.debug("optimization result: %s", resop_bh)
			resop_gp = resop_bh.lowest_optimization_result
		elif args.optimize == 4:
			resop, cov_gp = op.curve_fit(
				gpmodel_mean,
				no_ys_train, no_dens_train, gpmodel.get_parameter_vector(),
				bounds=tuple(np.array(bounds).T),
				# method='lm',
				# absolute_sigma=True,
				sigma=no_errs_train)
			resop_gp = op.OptimizeResult(dict(
				x=resop,
				success=True,
				message="Curve fit successful.",
			))
			logging.debug("curve fit %s, std %s:", resop, np.sqrt(np.diag(cov_gp)))
		else:
			logging.warning("unsupported optimization method: %s", args.optimize)
			resop_gp = op.OptimizeResult(dict(
				x=gpmodel.get_parameter_vector(),
				success=False,
				message="unsupported optimization method: {0}".format(args.optimize),
			))
		logging.info("%s", resop_gp.message)
		logging.debug("optimization result: %s", resop_gp)
		logging.info("gpmodel dict: %s", gpmodel.get_parameter_dict())
		logging.info("log posterior trained: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.info("log posterior test: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.info("log posterior all: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		# cross check to make sure that the gpmodel parameter vector is really
		# set to the fitted parameters
		logging.info("opt. model vector: %s", resop_gp.x)
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 1: %s", lpost(resop_gp.x, no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 1: %s", lpost(resop_gp.x, no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 1: %s", lpost(resop_gp.x, no_dens, gpmodel))
		logging.debug("opt. model vector: %s", gpmodel.get_parameter_vector())
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		pre_opt = resop_gp.success
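	# Note (a summary of the branches above, not part of the module): the
	# --optimize modes map to scipy.optimize as
	#   1 -> op.minimize(..., method="l-bfgs-b"),
	#   2 -> op.differential_evolution,
	#   3 -> op.basinhopping (inner l-bfgs-b),
	#   4 -> op.curve_fit on the mean model only.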
	try:
		logging.info("GM lt: %s", gpmodel.get_parameter("mean:GM:tau0"))
	except ValueError:
		pass
	logging.info("(GP) model: %s", gpmodel.kernel)
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)

	bestfit = gpmodel.get_parameter_vector()
	filename_base = path.join(
		args.output_path,
		"NO_regress_fit_{0}_{1:.0f}_{2:.0f}_{{0}}_{3}"
		.format(gpmodel_name, lat * 10, alt, ksub),
	)

	if args.mcmc:
		gpmodel.compute(no_ys_train, no_errs_train)
		samples, lnp = mcmc_sample_model(gpmodel,
				no_dens_train,
				beta=1.0,
				nwalkers=args.walkers, nburnin=args.burn_in,
				nprod=args.production, nthreads=args.threads,
				show_progress=args.progress,
				optimized=pre_opt, bounds=bounds, return_logpost=True)

		if args.train_fraction < 1. or args.test_fraction < 1.:
			logging.info("Statistics for the test samples")
			mcmc_statistics(gpmodel,
					no_ys_test, no_dens_test, no_errs_test,
					no_ys_train, no_dens_train, no_errs_train,
					samples, lnp,
			)
		logging.info("Statistics for all samples")
		mcmc_statistics(gpmodel,
				no_ys, no_dens, no_errs,
				no_ys_train, no_dens_train, no_errs_train,
				samples, lnp,
		)

		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
		if args.plot_corner:
			import corner
			# Corner plot of the sampled parameters
			fig = corner.corner(samples,
					quantiles=[0.025, 0.5, 0.975],
					show_titles=True,
					labels=gpmodel.get_parameter_names(),
					truths=bestfit,
					hist_args=dict(normed=True))
			fig.savefig(filename_base.format("corner") + ".pdf", transparent=True)

		if args.save_samples:
			if args.samples_format in ["npz"]:
				# save the samples compressed to save space.
				np.savez_compressed(filename_base.format("sampls") + ".npz",
					samples=samples)
			if args.samples_format in ["nc", "netcdf4"]:
				save_samples_netcdf(filename_base.format("sampls") + ".nc",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
			if args.samples_format in ["h5", "hdf5"]:
				save_samples_netcdf(filename_base.format("sampls") + ".h5",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
		# MCMC finished here

	# set the model times and errors to use the full data set for plotting
	gpmodel.compute(no_ys, no_errs)
	if args.save_model:
		try:
			# python 2
			import cPickle as pickle
		except ImportError:
			# python 3
			import pickle
		# pickle and save the model
		with open(filename_base.format("model") + ".pkl", "wb") as f:
			pickle.dump(gpmodel, f, -1)

	if args.plot_samples and args.mcmc:
		# checker note: the variable samples does not seem to be defined
		# when "args.mcmc" above is False; are you sure this can never be
		# the case?
		plot_random_samples(gpmodel, no_ys, no_dens, no_errs,
				samples, args.scale,
				filename_base.format("sampls") + ".pdf",
				size=4, extra_years=[4, 2])

	if args.plot_median:
		# checker note: the variable sampl_percs does not seem to be defined
		# when "args.mcmc" above is False; are you sure this can never be
		# the case?
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1],
				filename_base.format("median") + ".pdf")
	if args.plot_residuals:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1], args.scale,
				filename_base.format("medres") + ".pdf")
	if args.plot_maxlnp:
		# checker note: the variable lnp does not seem to be defined when
		# "args.mcmc" above is False; are you sure this can never be the case?
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)],
				filename_base.format("maxlnp") + ".pdf")
	if args.plot_maxlnpres:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)], args.scale,
				filename_base.format("mlpres") + ".pdf")
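	# A possible explicit guard for the three checker notes above (a sketch,
	# not the author's fix): the median and max-posterior plots could be tied
	# to args.mcmc the same way plot_samples already is, e.g.
	#   if args.mcmc and args.plot_median:
	#       plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
	#               sampl_percs[1],
	#               filename_base.format("median") + ".pdf")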

	labels = gpmodel.get_parameter_names()
	logging.info("param percentiles [2.5, 50, 97.5]:")
	for pc, label in zip(sampl_percs.T, labels):
		median = pc[1]
		pc_minus = median - pc[0]
		pc_plus = pc[2] - median
		logging.debug("%s: %s", label, pc)
		logging.info("%s: %.6f (- %.6f) (+ %.6f)", label,
				median, pc_minus, pc_plus)

	logging.info("Finished successfully.")


if __name__ == "__main__":
	main()