Passed: push to master (c3c30b...6ed46e) by Stefan, created 04:58

sciapy.regress.__main__.main() (rated F)

Complexity

Conditions 56

Size

Total Lines 397
Code Lines 309

Duplication

Lines 397
Ratio 100 %

Code Coverage

Tests 176
CRAP Score 94.8567

Importance

Changes 0
Metric   Value
cc       56
eloc     309
nop      0
dl       397
loc      397
ccs      176
cts      229
cp       0.7685
crap     94.8567
rs       0
c        0
b        0
f        0
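
For reference, the crap value follows the standard CRAP ("Change Risk Anti-Patterns") formula, CRAP = cc^2 * (1 - cp)^3 + cc, which combines cyclomatic complexity (cc) with the covered-code fraction (cp). A minimal sketch in Python, using cc and cp from the table above (the tool's exact coverage rounding may differ slightly):

def crap_score(complexity, coverage):
	# CRAP = comp**2 * (1 - cov)**3 + comp
	return complexity**2 * (1.0 - coverage)**3 + complexity

print(crap_score(56, 0.7685))  # ~94.9, close to the reported 94.8567

At full coverage the score collapses to cc, and at zero coverage it grows to cc^2 + cc, which is why a method with 56 conditions scores this high despite roughly 77 % coverage.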

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Besides, if a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method, sketched below.
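
For illustration, a minimal Extract Method sketch with hypothetical names (not taken from sciapy): the comment that describes a step becomes the name of a new function.

# before: a commented step buried in a long function
def process(records):
	# keep only entries that have a value
	valid = [r for r in records if r.get("value") is not None]
	return valid

# after: the comment has become the function name
def keep_valued(records):
	return [r for r in records if r.get("value") is not None]

def process(records):
	return keep_valued(records)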

Complexity

Complex classes, and long functions like sciapy.regress.__main__.main(), often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding one is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
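
For illustration, a minimal Extract Class sketch with hypothetical names: fields sharing a "gp_" prefix move into a class of their own, which the original class then holds as a single field.

# before: the shared "gp_" prefix hints at a cohesive component
class Analysis:
	def __init__(self):
		self.gp_kernel = "matern"
		self.gp_solver = "basic"
		self.data = []

# after: the component is extracted into its own class
class GPConfig:
	def __init__(self, kernel="matern", solver="basic"):
		self.kernel = kernel
		self.solver = solver

class Analysis:
	def __init__(self):
		self.gp = GPConfig()
		self.data = []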

# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2018 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY data regression command line interface

Command line main program for regression analysis of SCIAMACHY
daily zonal mean time series (NO for now).
"""
import ctypes
import logging

import numpy as np
import scipy.optimize as op
from scipy.interpolate import interp1d

import george
import celerite

import matplotlib as mpl
# switch off X11 rendering
mpl.use("Agg")

from .load_data import load_solar_gm_table, load_scia_dzm
from .models_cel import trace_gas_model
from .mcmc import mcmc_sample_model
from .statistics import mcmc_statistics

from ._gpkernels import (george_solvers,
		setup_george_kernel, setup_celerite_terms)
from ._plot import (plot_single_sample_and_residuals,
		plot_residual, plot_random_samples)
from ._options import parser

# Review note: this code seems to be duplicated in the project.
def save_samples_netcdf(filename, model, alt, lat, samples,
		scale=1e-6,
		lnpost=None, compressed=False):
	from xarray import Dataset
	smpl_ds = Dataset(dict([(pname, (["lat", "alt", "sample"],
				samples[..., i].reshape(1, 1, -1)))
				for i, pname in enumerate(model.get_parameter_names())]
				# + [("lpost", (["lat", "alt", "sample"], lnp.reshape(1, 1, -1)))]
			),
			coords={"lat": [lat], "alt": [alt]})

	for modn in model.mean.models:
		modl = model.mean.models[modn]
		if hasattr(modl, "mean"):
			smpl_ds.attrs[modn + ":mean"] = modl.mean

	units = {"kernel": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale))},
			"mean": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale)),
				"cos": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"sin": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"val": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"amp": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"tau": "d"}}
	for pname in smpl_ds.data_vars:
		_pp = pname.split(':')
		for _n, _u in units[_pp[0]].items():
			if _pp[-1].startswith(_n):
				logging.debug("units for %s: %s", pname, _u)
				smpl_ds[pname].attrs["units"] = _u

	smpl_ds["alt"].attrs = {"long_name": "altitude", "units": "km"}
	smpl_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}

	_encoding = None
	if compressed:
		_encoding = {var: {"zlib": True, "complevel": 1}
					for var in smpl_ds.data_vars}
	smpl_ds.to_netcdf(filename, encoding=_encoding)
	smpl_ds.close()

# Review note: this code seems to be duplicated in the project.
def _train_test_split(times, data, errs, train_frac,
		test_frac, randomize):
	# split the data into training and test subsets according to the
	# fraction given (default is 1, i.e. no splitting)
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# randomize if requested
	if randomize:
		permut_idx = np.random.permutation(np.arange(ndata))
	else:
		permut_idx = np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	times_train = times[train_idx]
	data_train = data[train_idx]
	errs_train = errs[train_idx]
	if test_size > 0:
		times_test = times[test_idx]
		data_test = data[test_idx]
		errs_test = errs[test_idx]
	else:
		times_test = times
		data_test = data
		errs_test = errs
	logging.info("using %s of %s samples for training.", len(times_train), ndata)
	logging.info("using %s of %s samples for testing.", len(times_test), ndata)
	return (times_train, data_train, errs_train,
			times_test, data_test, errs_test)

# Review note: this code seems to be duplicated in the project.
def _r_sun_earth(time, tfmt="jyear"):
	"""First order approximation of the Sun-Earth distance

	The Sun-to-Earth distance can be used to (un-)normalize proxies
	to the actual distance to the Sun instead of 1 AU.

	Parameters
	----------
	time : float
		Time value in the units given by 'tfmt'.
	tfmt : str, optional
		The units of 'time' as supported by the
		astropy.time time formats. Default: 'jyear'.

	Returns
	-------
	dist : float
		The Sun-Earth distance at the given day of year in AU.
	"""
	from astropy.time import Time
	tdoy = Time(time, format=tfmt)
	tdoy.format = "yday"
	doy = int(tdoy.value.split(':')[1])
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))

# Review note: this code seems to be duplicated in the project.
def main():
	logging.basicConfig(level=logging.WARNING,
			format="[%(levelname)-8s] (%(asctime)s) "
				"%(filename)s:%(lineno)d %(message)s",
			datefmt="%Y-%m-%d %H:%M:%S %z")

	args = parser.parse_args()

	logging.info("command line arguments: %s", args)
	if args.quiet:
		logging.getLogger().setLevel(logging.ERROR)
	elif args.verbose:
		logging.getLogger().setLevel(logging.INFO)
	else:
		logging.getLogger().setLevel(args.loglevel)

	from numpy.distutils.system_info import get_info
	try:
		ob_lib_dirs = get_info("openblas")["library_dirs"]
	except KeyError:
		ob_lib_dirs = []
	for oblas_path in ob_lib_dirs:
		oblas_name = "{0}/libopenblas.so".format(oblas_path)
		logging.info("Trying %s", oblas_name)
		try:
			oblas_lib = ctypes.cdll.LoadLibrary(oblas_name)
			oblas_cores = oblas_lib.openblas_get_num_threads()
			oblas_lib.openblas_set_num_threads(args.openblas_threads)
			logging.info("Using %s/%s Openblas thread(s).",
					oblas_lib.openblas_get_num_threads(), oblas_cores)
		except:
			logging.info("Setting number of openblas threads failed.")

	if args.random_seed is not None:
		np.random.seed(args.random_seed)

	if args.proxies:
		proxies = args.proxies.split(',')
		proxy_dict = dict(_p.split(':') for _p in proxies)
	else:
		proxy_dict = {}
	lag_dict = {pn: 0 for pn in proxy_dict.keys()}

	# Post-processing of arguments...
	# List of proxy lag fits from csv
	fit_lags = args.fit_lags.split(',')
	# List of proxy lifetime fits from csv
	fit_lifetimes = args.fit_lifetimes.split(',')
	fit_annlifetimes = args.fit_annlifetimes.split(',')
	# List of proxy lag times from csv
	lag_dict.update(dict(_ls.split(':') for _ls in args.lag_times.split(',')))
	# List of cycles (frequencies in 1/year) from argument list (csv)
	try:
		freqs = list(map(float, args.freqs.split(',')))
	except ValueError:
		freqs = []
	args.freqs = freqs
	# List of initial parameter values
	initial = None
	if args.initial is not None:
		try:
			initial = list(map(float, args.initial.split(',')))
		except ValueError:
			pass
	# List of GP kernels from argument list (csv)
	kernls = args.kernels.split(',')

	lat = args.latitude
	alt = args.altitude
	logging.info("location: %.0f°N %.0f km", lat, alt)

	no_ys, no_dens, no_errs, no_szas = load_scia_dzm(args.file, alt, lat,
			tfmt=args.time_format,
			scale=args.scale,
			#subsample_factor=args.random_subsample,
			#subsample_method="random",
			akd_threshold=args.akd_threshold,
			cnt_threshold=args.cnt_threshold,
			center=args.center_data,
			season=args.season,
			SPEs=args.exclude_spe)

	(no_ys_train, no_dens_train, no_errs_train,
		no_ys_test, no_dens_test, no_errs_test) = _train_test_split(
				no_ys, no_dens, no_errs, args.train_fraction,
				args.test_fraction, args.random_train_test)

	sza_intp = interp1d(no_ys, no_szas, bounds_error=False)

	max_amp = 1e10 * args.scale
	max_days = 100

	proxy_config = {}
	for pn, pf in proxy_dict.items():
		pt, pp = load_solar_gm_table(pf, cols=[0, 1], names=["time", pn], tfmt=args.time_format)
		# use log of proxy values
		pv = pp[pn].to_numpy()
		if pn in args.log_proxies.split(','):
			pv = np.log(pv)
		# normalize to sun--earth distance squared
		if pn in args.norm_proxies_distSEsq.split(','):
			rad_sun_earth = np.vectorize(_r_sun_earth)(pt, tfmt=args.time_format)
			pv /= rad_sun_earth**2
		# normalize by cos(SZA)
		if pn in args.norm_proxies_SZA.split(',') and sza_intp is not None:
			pv *= np.cos(np.radians(sza_intp(pt)))
		proxy_config.update({pn:
			dict(times=pt, values=pv,
				center=pn in args.center_proxies.split(','),
				positive=pn in args.positive_proxies.split(','),
				lag=float(lag_dict[pn]),
				max_amp=max_amp, max_days=max_days,
				sza_intp=sza_intp if args.use_sza else None,
			)}
		)

	model = trace_gas_model(constant=args.fit_offset,
			proxy_config=proxy_config, **vars(args))

	logging.debug("model dict: %s", model.get_parameter_dict())
	model.freeze_all_parameters()
	# thaw parameters according to requested fits
	for pn in proxy_dict.keys():
		model.thaw_parameter("{0}:amp".format(pn))
		if pn in fit_lags:
			model.thaw_parameter("{0}:lag".format(pn))
		if pn in fit_lifetimes:
			model.set_parameter("{0}:tau0".format(pn), 1e-3)
			model.thaw_parameter("{0}:tau0".format(pn))
			if pn in fit_annlifetimes:
				model.thaw_parameter("{0}:taucos1".format(pn))
				model.thaw_parameter("{0}:tausin1".format(pn))
		else:
			model.set_parameter("{0}:ltscan".format(pn), 0)
	for freq in freqs:
		if not args.fit_phase:
			model.thaw_parameter("f{0:.0f}:cos".format(freq))
			model.thaw_parameter("f{0:.0f}:sin".format(freq))
		else:
			model.thaw_parameter("f{0:.0f}:amp".format(freq))
			model.thaw_parameter("f{0:.0f}:phase".format(freq))
	if args.fit_offset:
		#model.set_parameter("offset:value", -100.)
		#model.set_parameter("offset:value", 0)
		model.thaw_parameter("offset:value")

	if initial is not None:
		model.set_parameter_vector(initial)
	# model.thaw_parameter("GM:ltscan")
	logging.debug("params: %s", model.get_parameter_dict())
	logging.debug("param names: %s", model.get_parameter_names())
	logging.debug("param vector: %s", model.get_parameter_vector())
	logging.debug("param bounds: %s", model.get_parameter_bounds())
	#logging.debug("model value: %s", model.get_value(no_ys))
	#logging.debug("default log likelihood: %s", model.log_likelihood(model.vector))

	# setup the Gaussian Process kernel
	kernel_base = (1e7 * args.scale)**2
	ksub = args.name_suffix

	solver = "basic"
	skwargs = {}
	if args.HODLR_Solver:
		solver = "HODLR"
		#skwargs = {"tol": 1e-3}

	if args.george:
		gpname, kernel = setup_george_kernel(kernls,
				kernel_base=kernel_base, fit_bias=args.fit_bias)
		gpmodel = george.GP(kernel, mean=model,
			white_noise=1.e-25, fit_white_noise=args.fit_white,
			solver=george_solvers[solver], **skwargs)
		# the george interface does not allow setting the bounds in
		# the kernel initialization so we prepare simple default bounds
		kernel_bounds = [(-0.3 * max_amp, 0.3 * max_amp)
				for _ in gpmodel.kernel.get_parameter_names()]
		bounds = gpmodel.mean.get_parameter_bounds() + kernel_bounds
	else:
		gpname, cel_terms = setup_celerite_terms(kernls,
				fit_bias=args.fit_bias, fit_white=args.fit_white)
		gpmodel = celerite.GP(cel_terms, mean=model,
			fit_white_noise=args.fit_white,
			fit_mean=True)
		bounds = gpmodel.get_parameter_bounds()
	gpmodel.compute(no_ys_train, no_errs_train)
	logging.debug("gpmodel params: %s", gpmodel.get_parameter_dict())
	logging.debug("gpmodel bounds: %s", bounds)
	logging.debug("initial log likelihood: %s", gpmodel.log_likelihood(no_dens_train))
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)
	model_name = "_".join(gpmodel.mean.get_parameter_names()).replace(':', '')
	gpmodel_name = model_name + gpname
	logging.info("GP model name: %s", gpmodel_name)

	pre_opt = False
	if args.optimize > 0:
		def gpmodel_mean(x, *p):
			gpmodel.set_parameter_vector(p)
			return gpmodel.mean.get_value(x)

		def gpmodel_res(x, *p):
			gpmodel.set_parameter_vector(p)
			return (gpmodel.mean.get_value(x) - no_dens_train) / no_errs_train

		# log posterior = log likelihood + log prior;
		# minimized below via its negative (nlpost)
		def lpost(p, y, gp):
			gp.set_parameter_vector(p)
			return gp.log_likelihood(y, quiet=True) + gp.log_prior()

		def nlpost(p, y, gp):
			# Review note: the variable lpost does not seem to be defined in
			# case args.optimize > 0 above is False. Are you sure this can
			# never be the case?
			lp = lpost(p, y, gp)
			return -lp if np.isfinite(lp) else 1e25

		def grad_nlpost(p, y, gp):
			gp.set_parameter_vector(p)
			grad_ll = gp.grad_log_likelihood(y)
			if isinstance(grad_ll, tuple):
				# celerite
				return -grad_ll[1]
			# george
			return -grad_ll

		jacobian = grad_nlpost if gpmodel.kernel.vector_size else None
		if args.optimize == 1:
			resop_gp = op.minimize(
				nlpost,
				gpmodel.get_parameter_vector(),
				args=(no_dens_train, gpmodel),
				bounds=bounds,
				# method="l-bfgs-b", options=dict(disp=True, maxcor=100, eps=1e-9, ftol=2e-15, gtol=1e-8))
				method="l-bfgs-b", jac=jacobian)
				# method="tnc", options=dict(disp=True, maxiter=500, xtol=1e-12))
				# method="nelder-mead", options=dict(disp=True, maxfev=100000, fatol=1.49012e-8, xatol=1.49012e-8))
				# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08))
		if args.optimize == 2:
			resop_gp = op.differential_evolution(
				nlpost,
				bounds=bounds,
				args=(no_dens_train, gpmodel),
				popsize=2 * args.walkers, tol=0.01)
		if args.optimize == 3:
			resop_bh = op.basinhopping(
				nlpost,
				gpmodel.get_parameter_vector(),
				niter=200,
				minimizer_kwargs=dict(
					args=(no_dens_train, gpmodel),
					bounds=bounds,
					# method="tnc"))
					# method="l-bfgs-b", options=dict(maxcor=100)))
					method="l-bfgs-b", jac=jacobian))
					# method="Nelder-Mead"))
					# method="BFGS"))
					# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08)))
			logging.debug("optimization result: %s", resop_bh)
			resop_gp = resop_bh.lowest_optimization_result
		if args.optimize == 4:
			resop_gp, cov_gp = op.curve_fit(
				gpmodel_mean,
				no_ys_train, no_dens_train, gpmodel.get_parameter_vector(),
				bounds=tuple(np.array(bounds).T),
				# method='lm',
				# absolute_sigma=True,
				sigma=no_errs_train)
			print(resop_gp, np.sqrt(np.diag(cov_gp)))
		# Review note: the variable resop_gp does not seem to be defined in
		# case args.optimize == 1 above is False. Are you sure this can
		# never be the case?
		logging.info("%s", resop_gp.message)
		logging.debug("optimization result: %s", resop_gp)
		logging.info("gpmodel dict: %s", gpmodel.get_parameter_dict())
		logging.info("log posterior trained: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.info("log posterior test: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.info("log posterior all: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		# cross check to make sure that the gpmodel parameter vector is really
		# set to the fitted parameters
		logging.info("opt. model vector: %s", resop_gp.x)
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 1: %s", lpost(resop_gp.x, no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 1: %s", lpost(resop_gp.x, no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 1: %s", lpost(resop_gp.x, no_dens, gpmodel))
		logging.debug("opt. model vector: %s", gpmodel.get_parameter_vector())
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		pre_opt = resop_gp.success
	try:
		logging.info("GM lt: %s", gpmodel.get_parameter("mean:GM:tau0"))
	except ValueError:
		pass
	logging.info("(GP) model: %s", gpmodel.kernel)
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)

	bestfit = gpmodel.get_parameter_vector()
	filename_base = ("NO_regress_fit_{0}_{1:.0f}_{2:.0f}_{{0}}_{3}"
					.format(gpmodel_name, lat * 10, alt, ksub))

	if args.mcmc:
		gpmodel.compute(no_ys_train, no_errs_train)
		samples, lnp = mcmc_sample_model(gpmodel,
				no_dens_train,
				beta=1.0,
				nwalkers=args.walkers, nburnin=args.burn_in,
				nprod=args.production, nthreads=args.threads,
				show_progress=args.progress,
				optimized=pre_opt, bounds=bounds, return_logpost=True)

		if args.train_fraction < 1. or args.test_fraction < 1.:
			logging.info("Statistics for the test samples")
			mcmc_statistics(gpmodel,
					no_ys_test, no_dens_test, no_errs_test,
					no_ys_train, no_dens_train, no_errs_train,
					samples, lnp,
			)
		logging.info("Statistics for all samples")
		mcmc_statistics(gpmodel,
				no_ys, no_dens, no_errs,
				no_ys_train, no_dens_train, no_errs_train,
				samples, lnp,
		)

		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
		if args.plot_corner:
			import corner
			# Corner plot of the sampled parameters
			fig = corner.corner(samples,
					quantiles=[0.025, 0.5, 0.975],
					show_titles=True,
					labels=gpmodel.get_parameter_names(),
					truths=bestfit,
					hist_args=dict(normed=True))
			fig.savefig(filename_base.format("corner") + ".pdf", transparent=True)

		if args.save_samples:
			if args.samples_format in ["npz"]:
				# save the samples compressed to save space.
				np.savez_compressed(filename_base.format("sampls") + ".npz",
					samples=samples)
			if args.samples_format in ["nc", "netcdf4"]:
				save_samples_netcdf(filename_base.format("sampls") + ".nc",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
			if args.samples_format in ["h5", "hdf5"]:
				save_samples_netcdf(filename_base.format("sampls") + ".h5",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
		# MCMC finished here

	# set the model times and errors to use the full data set for plotting
	gpmodel.compute(no_ys, no_errs)
	if args.save_model:
		try:
			# python 2
			import cPickle as pickle
		except ImportError:
			# python 3
			import pickle
		# pickle and save the model
		with open(filename_base.format("model") + ".pkl", "wb") as f:
			pickle.dump((gpmodel), f, -1)

	if args.plot_samples and args.mcmc:
		# Review note: the variable samples does not seem to be defined in
		# case args.mcmc above is False. Are you sure this can never be
		# the case?
		plot_random_samples(gpmodel, no_ys, no_dens, no_errs,
				samples, args.scale,
				filename_base.format("sampls") + ".pdf",
				size=4, extra_years=[4, 2])

	if args.plot_median:
		# Review note: the variable sampl_percs does not seem to be defined
		# in case args.mcmc above is False. Are you sure this can never be
		# the case?
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1],
				filename_base.format("median") + ".pdf")
	if args.plot_residuals:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1], args.scale,
				filename_base.format("medres") + ".pdf")
	if args.plot_maxlnp:
		# Review note: the variable lnp does not seem to be defined in
		# case args.mcmc above is False. Are you sure this can never be
		# the case?
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)],
				filename_base.format("maxlnp") + ".pdf")
	if args.plot_maxlnpres:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)], args.scale,
				filename_base.format("mlpres") + ".pdf")

	labels = gpmodel.get_parameter_names()
	logging.info("param percentiles [2.5, 50, 97.5]:")
	for pc, label in zip(sampl_percs.T, labels):
		median = pc[1]
		pc_minus = median - pc[0]
		pc_plus = pc[2] - median
		logging.debug("%s: %s", label, pc)
		logging.info("%s: %.6f (- %.6f) (+ %.6f)", label,
				median, pc_minus, pc_plus)

	logging.info("Finished successfully.")


if __name__ == "__main__":
	main()