Passed
Push — master (e8b27e...e8dfc7)
by Stefan
created 03:28

sciapy.regress.__main__ (rating: F)

Complexity

Total Complexity 67

Size/Duplication

Total Lines 551
Duplicated Lines 90.2%

Test Coverage

Coverage 80.87%

Importance

Changes 0
Metric   Value
eloc     402
dl       497
loc      551
ccs      241
cts      298
cp       0.8087
rs       3.04
c        0
b        0
f        0
wmc      67

(Cross-checks within this report: dl/loc = 497/551 = 90.2% matches "Duplicated Lines"; ccs/cts = 241/298 = 0.8087 matches cp and the 80.87% coverage; wmc = 67 matches "Total Complexity".)

4 Functions

Rating   Name                    Duplication   Size   Complexity
A        _r_sun_earth()           24            24     1
A        _train_test_split()      29            29     3
B        save_samples_netcdf()    43            43     7
F        main()                  401           401    56

(Each function is flagged as duplicated in full: the duplication sizes sum to 497 lines = dl, and the complexities sum to 67 = wmc.)

How to fix

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

Common duplication problems can usually be solved by extracting the repeated code into a shared function, class, or module, as the sketch below illustrates.
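For instance, save_samples_netcdf(), _train_test_split(), and _r_sun_earth() below are each flagged as duplicated in full. A minimal sketch of the usual fix, assuming the copies live in two hypothetical modules (module_a.py and module_b.py are not part of sciapy), is to move the shared function into one place and import it:

# _utils.py -- hypothetical shared module (illustrative sketch)
import numpy as np

def r_sun_earth(doy):
	"""First-order Sun-Earth distance in AU for a given day of year."""
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))

# module_a.py and module_b.py then both do
#     from ._utils import r_sun_earth
# instead of each carrying its own copy of the function.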

Complexity

 Tip: Before tackling complexity, make sure to eliminate any duplication first. This can often reduce the size of classes significantly.

Complex classes like sciapy.regress.__main__ often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
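As a hedged sketch of that step (ProxyConfig is hypothetical, not part of sciapy): main() below mixes argument post-processing, model setup, optimization, and plotting, and keeps the proxy-related state in loose locals (proxy_dict, lag_dict, proxy_config) that share a prefix. Grouping them is the kind of Extract Class meant here:

# Illustrative Extract Class sketch; all names are hypothetical.
class ProxyConfig:
	"""Collects the proxy-related settings that main() keeps in loose dicts."""
	def __init__(self, proxies=""):
		# comma-separated "name:file" pairs, as parsed in main() below
		self.files = dict(p.split(':') for p in proxies.split(',') if p)
		self.lags = {name: 0 for name in self.files}

	def update_lags(self, lag_times):
		self.lags.update(dict(l.split(':') for l in lag_times.split(',') if l))

# usage, with hypothetical file names:
pc = ProxyConfig("Lya:lyman_alpha.dat,GM:geomag.dat")

main() would then pass one ProxyConfig around instead of three parallel dictionaries, shrinking its own body and its complexity score.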

# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2018 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY data regression command line interface

Command line main program for regression analysis of SCIAMACHY
daily zonal mean time series (NO for now).
"""

import ctypes
import logging
from os import path

import numpy as np
import scipy.optimize as op
from scipy.interpolate import interp1d

import george
import celerite

import matplotlib as mpl
# switch off X11 rendering
mpl.use("Agg")

from .load_data import load_solar_gm_table, load_scia_dzm
from .models_cel import trace_gas_model
from .mcmc import mcmc_sample_model
from .statistics import mcmc_statistics

from ._gpkernels import (george_solvers,
		setup_george_kernel, setup_celerite_terms)
from ._plot import (plot_single_sample_and_residuals,
		plot_residual, plot_random_samples)
from ._options import parser

# [inspection] View Code Duplication: this code seems to be duplicated in
# your project.
def save_samples_netcdf(filename, model, alt, lat, samples,
		scale=1e-6,
		lnpost=None, compressed=False):
	from xarray import Dataset
	smpl_ds = Dataset(dict([(pname, (["lat", "alt", "sample"],
				samples[..., i].reshape(1, 1, -1)))
				for i, pname in enumerate(model.get_parameter_names())]
				# + [("lpost", (["lat", "alt", "sample"], lnp.reshape(1, 1, -1)))]
			),
			coords={"lat": [lat], "alt": [alt]})

	for modn in model.mean.models:
		modl = model.mean.models[modn]
		if hasattr(modl, "mean"):
			smpl_ds.attrs[modn + ":mean"] = modl.mean

	units = {"kernel": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale))},
			"mean": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale)),
				"cos": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"sin": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"val": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"amp": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"tau": "d"}}
	for pname in smpl_ds.data_vars:
		_pp = pname.split(':')
		for _n, _u in units[_pp[0]].items():
			if _pp[-1].startswith(_n):
				logging.debug("units for %s: %s", pname, _u)
				smpl_ds[pname].attrs["units"] = _u

	smpl_ds["alt"].attrs = {"long_name": "altitude", "units": "km"}
	smpl_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}

	_encoding = None
	if compressed:
		_encoding = {var: {"zlib": True, "complevel": 1}
					for var in smpl_ds.data_vars}
	smpl_ds.to_netcdf(filename, encoding=_encoding)
	smpl_ds.close()

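# Illustrative usage sketch (not part of the reviewed module). The stub below
# only mimics the interface save_samples_netcdf() relies on; the names and
# shapes are assumptions:
#
#     class _StubMean:
#         models = {}
#
#     class _StubModel:
#         mean = _StubMean()
#         def get_parameter_names(self):
#             return ["mean:val", "kernel:log_sigma"]
#
#     smpls = np.random.randn(100, 2)  # 100 samples of 2 parameters
#     save_samples_netcdf("samples.nc", _StubModel(), alt=70., lat=65.,
#             samples=smpls, compressed=True)
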
# [inspection] View Code Duplication: this code seems to be duplicated in
# your project.
def _train_test_split(times, data, errs, train_frac,
		test_frac, randomize):
	# split the data into training and test subsets according to the
	# fraction given (default is 1, i.e. no splitting)
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# randomize if requested
	if randomize:
		permut_idx = np.random.permutation(np.arange(ndata))
	else:
		permut_idx = np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	times_train = times[train_idx]
	data_train = data[train_idx]
	errs_train = errs[train_idx]
	if test_size > 0:
		times_test = times[test_idx]
		data_test = data[test_idx]
		errs_test = errs[test_idx]
	else:
		times_test = times
		data_test = data
		errs_test = errs
	logging.info("using %s of %s samples for training.", len(times_train), ndata)
	logging.info("using %s of %s samples for testing.", len(times_test), ndata)
	return (times_train, data_train, errs_train,
			times_test, data_test, errs_test)

def _r_sun_earth(time, tfmt="jyear"):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
121
	"""First order approximation of the Sun-Earth distance
122
123
	The Sun-to-Earth distance can be used to (un-)normalize proxies
124
	to the actual distance to the Sun instead of 1 AU.
125
126
	Parameters
127
	----------
128
	time : float
129
		Time value in the units given by 'tfmt'.
130
	tfmt : str, optional
131
		The units of 'time' as supported by the
132
		astropy.time time formats. Default: 'jyear'.
133
134
	Returns
135
	-------
136
	dist : float
137
		The Sun-Earth distance at the given day of year in AU.
138
	"""
139
	from astropy.time import Time
140
	tdoy = Time(time, format=tfmt)
141
	tdoy.format = "yday"
142
	doy = int(tdoy.value.split(':')[1])
143
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))
144
145
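# Illustrative check (assumes astropy is installed; the values follow from
# the first-order formula above, which varies between 1 - 0.01672 AU near
# perihelion and 1 + 0.01672 AU near aphelion):
#
#     _r_sun_earth(2018.01)  # early January, near perihelion: ~0.983 AU
#     _r_sun_earth(2018.5)   # early July, near aphelion: ~1.017 AU
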
# [inspection] View Code Duplication: this code seems to be duplicated in
# your project.
def main():
	logging.basicConfig(level=logging.WARNING,
			format="[%(levelname)-8s] (%(asctime)s) "
				"%(filename)s:%(lineno)d %(message)s",
			datefmt="%Y-%m-%d %H:%M:%S %z")

	args = parser.parse_args()

	logging.info("command line arguments: %s", args)
	if args.quiet:
		logging.getLogger().setLevel(logging.ERROR)
	elif args.verbose:
		logging.getLogger().setLevel(logging.INFO)
	else:
		logging.getLogger().setLevel(args.loglevel)

	from numpy.distutils.system_info import get_info
	try:
		ob_lib_dirs = get_info("openblas")["library_dirs"]
	except KeyError:
		ob_lib_dirs = []
	for oblas_path in ob_lib_dirs:
		oblas_name = "{0}/libopenblas.so".format(oblas_path)
		logging.info("Trying %s", oblas_name)
		try:
			oblas_lib = ctypes.cdll.LoadLibrary(oblas_name)
			oblas_cores = oblas_lib.openblas_get_num_threads()
			oblas_lib.openblas_set_num_threads(args.openblas_threads)
			logging.info("Using %s/%s Openblas thread(s).",
					oblas_lib.openblas_get_num_threads(), oblas_cores)
		except:
			logging.info("Setting number of openblas threads failed.")

	if args.random_seed is not None:
		np.random.seed(args.random_seed)

	if args.proxies:
		proxies = args.proxies.split(',')
		proxy_dict = dict(_p.split(':') for _p in proxies)
	else:
		proxy_dict = {}
	lag_dict = {pn: 0 for pn in proxy_dict.keys()}
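	# Illustrative note (format inferred from the parsing above, not from the
	# CLI documentation): --proxies expects comma-separated "name:file" pairs,
	# e.g. --proxies "Lya:lyman_alpha.dat,GM:geomag.dat" (hypothetical files),
	# giving proxy_dict == {"Lya": "lyman_alpha.dat", "GM": "geomag.dat"} and
	# initializing lag_dict == {"Lya": 0, "GM": 0}.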

	# Post-processing of arguments...
	# List of proxy lag fits from csv
	fit_lags = args.fit_lags.split(',')
	# List of proxy lifetime fits from csv
	fit_lifetimes = args.fit_lifetimes.split(',')
	fit_annlifetimes = args.fit_annlifetimes.split(',')
	# List of proxy lag times from csv
	lag_dict.update(dict(_ls.split(':') for _ls in args.lag_times.split(',')))
	# List of cycles (frequencies in 1/year) from argument list (csv)
	try:
		freqs = list(map(float, args.freqs.split(',')))
	except ValueError:
		freqs = []
	args.freqs = freqs
	# List of initial parameter values
	initial = None
	if args.initial is not None:
		try:
			initial = list(map(float, args.initial.split(',')))
		except ValueError:
			pass
	# List of GP kernels from argument list (csv)
	kernls = args.kernels.split(',')

	lat = args.latitude
	alt = args.altitude
	logging.info("location: %.0f°N %.0f km", lat, alt)

	no_ys, no_dens, no_errs, no_szas = load_scia_dzm(args.file, alt, lat,
			tfmt=args.time_format,
			scale=args.scale,
			#subsample_factor=args.random_subsample,
			#subsample_method="random",
			akd_threshold=args.akd_threshold,
			cnt_threshold=args.cnt_threshold,
			center=args.center_data,
			season=args.season,
			SPEs=args.exclude_spe)

	(no_ys_train, no_dens_train, no_errs_train,
		no_ys_test, no_dens_test, no_errs_test) = _train_test_split(
				no_ys, no_dens, no_errs, args.train_fraction,
				args.test_fraction, args.random_train_test)

	sza_intp = interp1d(no_ys, no_szas, bounds_error=False)

	max_amp = 1e10 * args.scale
	max_days = 100

	proxy_config = {}
	for pn, pf in proxy_dict.items():
		pt, pp = load_solar_gm_table(path.expanduser(pf),
				cols=[0, 1], names=["time", pn], tfmt=args.time_format)
		pv = pp[pn]
		# use log of proxy values if desired
		if pn in args.log_proxies.split(','):
			pv = np.log(pv)
		# normalize to sun--earth distance squared
		if pn in args.norm_proxies_distSEsq.split(','):
			rad_sun_earth = np.vectorize(_r_sun_earth)(pt, tfmt=args.time_format)
			pv /= rad_sun_earth**2
		# normalize by cos(SZA)
		if pn in args.norm_proxies_SZA.split(',') and sza_intp is not None:
			pv *= np.cos(np.radians(sza_intp(pt)))
		proxy_config.update({pn:
			dict(times=pt, values=pv,
				center=pn in args.center_proxies.split(','),
				positive=pn in args.positive_proxies.split(','),
				lag=float(lag_dict[pn]),
				max_amp=max_amp, max_days=max_days,
				sza_intp=sza_intp if args.use_sza else None,
			)}
		)

	model = trace_gas_model(constant=args.fit_offset,
			proxy_config=proxy_config, **vars(args))

	logging.debug("model dict: %s", model.get_parameter_dict())
	model.freeze_all_parameters()
	# thaw parameters according to requested fits
	for pn in proxy_dict.keys():
		model.thaw_parameter("{0}:amp".format(pn))
		if pn in fit_lags:
			model.thaw_parameter("{0}:lag".format(pn))
		if pn in fit_lifetimes:
			model.set_parameter("{0}:tau0".format(pn), 1e-3)
			model.thaw_parameter("{0}:tau0".format(pn))
			if pn in fit_annlifetimes:
				model.thaw_parameter("{0}:taucos1".format(pn))
				model.thaw_parameter("{0}:tausin1".format(pn))
		else:
			model.set_parameter("{0}:ltscan".format(pn), 0)
	for freq in freqs:
		if not args.fit_phase:
			model.thaw_parameter("f{0:.0f}:cos".format(freq))
			model.thaw_parameter("f{0:.0f}:sin".format(freq))
		else:
			model.thaw_parameter("f{0:.0f}:amp".format(freq))
			model.thaw_parameter("f{0:.0f}:phase".format(freq))
	if args.fit_offset:
		#model.set_parameter("offset:value", -100.)
		#model.set_parameter("offset:value", 0)
		model.thaw_parameter("offset:value")

	if initial is not None:
		model.set_parameter_vector(initial)
	# model.thaw_parameter("GM:ltscan")
	logging.debug("params: %s", model.get_parameter_dict())
	logging.debug("param names: %s", model.get_parameter_names())
	logging.debug("param vector: %s", model.get_parameter_vector())
	logging.debug("param bounds: %s", model.get_parameter_bounds())
	#logging.debug("model value: %s", model.get_value(no_ys))
	#logging.debug("default log likelihood: %s", model.log_likelihood(model.vector))

	# setup the Gaussian Process kernel
	kernel_base = (1e7 * args.scale)**2
	ksub = args.name_suffix

	solver = "basic"
	skwargs = {}
	if args.HODLR_Solver:
		solver = "HODLR"
		#skwargs = {"tol": 1e-3}

	if args.george:
		gpname, kernel = setup_george_kernel(kernls,
				kernel_base=kernel_base, fit_bias=args.fit_bias)
		gpmodel = george.GP(kernel, mean=model,
			white_noise=1.e-25, fit_white_noise=args.fit_white,
			solver=george_solvers[solver], **skwargs)
		# the george interface does not allow setting the bounds in
		# the kernel initialization so we prepare simple default bounds
		kernel_bounds = [(-0.3 * max_amp, 0.3 * max_amp)
				for _ in gpmodel.kernel.get_parameter_names()]
		bounds = gpmodel.mean.get_parameter_bounds() + kernel_bounds
	else:
		gpname, cel_terms = setup_celerite_terms(kernls,
				fit_bias=args.fit_bias, fit_white=args.fit_white)
		gpmodel = celerite.GP(cel_terms, mean=model,
			fit_white_noise=args.fit_white,
			fit_mean=True)
		bounds = gpmodel.get_parameter_bounds()
	gpmodel.compute(no_ys_train, no_errs_train)
	logging.debug("gpmodel params: %s", gpmodel.get_parameter_dict())
	logging.debug("gpmodel bounds: %s", bounds)
	logging.debug("initial log likelihood: %s", gpmodel.log_likelihood(no_dens_train))
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)
	model_name = "_".join(gpmodel.mean.get_parameter_names()).replace(':', '')
	gpmodel_name = model_name + gpname
	logging.info("GP model name: %s", gpmodel_name)

	pre_opt = False
	if args.optimize > 0:
		def gpmodel_mean(x, *p):
			gpmodel.set_parameter_vector(p)
			return gpmodel.mean.get_value(x)

		def gpmodel_res(x, *p):
			gpmodel.set_parameter_vector(p)
			return (gpmodel.mean.get_value(x) - no_dens_train) / no_errs_train

		def lpost(p, y, gp):
			gp.set_parameter_vector(p)
			return gp.log_likelihood(y, quiet=True) + gp.log_prior()

		def nlpost(p, y, gp):
			lp = lpost(p, y, gp)
			# [inspection] The variable lpost does not seem to be defined in
			# case args.optimize > 0 on line 342 is False. Are you sure this
			# can never be the case?
			# return a large but finite value to guard the minimizers
			# against non-finite posteriors
			return -lp if np.isfinite(lp) else 1e25

		def grad_nlpost(p, y, gp):
			gp.set_parameter_vector(p)
			grad_ll = gp.grad_log_likelihood(y)
			if isinstance(grad_ll, tuple):
				# celerite
				return -grad_ll[1]
			# george
			return -grad_ll

		jacobian = grad_nlpost if gpmodel.kernel.vector_size else None
		if args.optimize == 1:
			resop_gp = op.minimize(
				nlpost,
				gpmodel.get_parameter_vector(),
				args=(no_dens_train, gpmodel),
				bounds=bounds,
				# method="l-bfgs-b", options=dict(disp=True, maxcor=100, eps=1e-9, ftol=2e-15, gtol=1e-8))
				method="l-bfgs-b", jac=jacobian)
				# method="tnc", options=dict(disp=True, maxiter=500, xtol=1e-12))
				# method="nelder-mead", options=dict(disp=True, maxfev=100000, fatol=1.49012e-8, xatol=1.49012e-8))
				# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08))
		if args.optimize == 2:
			resop_gp = op.differential_evolution(
				nlpost,
				bounds=bounds,
				args=(no_dens_train, gpmodel),
				popsize=2 * args.walkers, tol=0.01)
		if args.optimize == 3:
			resop_bh = op.basinhopping(
				nlpost,
				gpmodel.get_parameter_vector(),
				niter=200,
				minimizer_kwargs=dict(
					args=(no_dens_train, gpmodel),
					bounds=bounds,
					# method="tnc"))
					# method="l-bfgs-b", options=dict(maxcor=100)))
					method="l-bfgs-b", jac=jacobian))
					# method="Nelder-Mead"))
					# method="BFGS"))
					# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08)))
			logging.debug("optimization result: %s", resop_bh)
			resop_gp = resop_bh.lowest_optimization_result
		if args.optimize == 4:
			resop_gp, cov_gp = op.curve_fit(
				gpmodel_mean,
				no_ys_train, no_dens_train, gpmodel.get_parameter_vector(),
				bounds=tuple(np.array(bounds).T),
				# method='lm',
				# absolute_sigma=True,
				sigma=no_errs_train)
			print(resop_gp, np.sqrt(np.diag(cov_gp)))
		# [inspection] The variable resop_gp does not seem to be defined in
		# case args.optimize == 1 on line 369 is False. Are you sure this can
		# never be the case?
		logging.info("%s", resop_gp.message)
		logging.debug("optimization result: %s", resop_gp)
		logging.info("gpmodel dict: %s", gpmodel.get_parameter_dict())
		logging.info("log posterior trained: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.info("log posterior test: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.info("log posterior all: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		# cross check to make sure that the gpmodel parameter vector is really
		# set to the fitted parameters
		logging.info("opt. model vector: %s", resop_gp.x)
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 1: %s", lpost(resop_gp.x, no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 1: %s", lpost(resop_gp.x, no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 1: %s", lpost(resop_gp.x, no_dens, gpmodel))
		logging.debug("opt. model vector: %s", gpmodel.get_parameter_vector())
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		pre_opt = resop_gp.success
	try:
		logging.info("GM lt: %s", gpmodel.get_parameter("mean:GM:tau0"))
	except ValueError:
		pass
	logging.info("(GP) model: %s", gpmodel.kernel)
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)

	bestfit = gpmodel.get_parameter_vector()
	filename_base = path.join(
		args.output_path,
		"NO_regress_fit_{0}_{1:.0f}_{2:.0f}_{{0}}_{3}"
		.format(gpmodel_name, lat * 10, alt, ksub),
	)

	if args.mcmc:
		gpmodel.compute(no_ys_train, no_errs_train)
		samples, lnp = mcmc_sample_model(gpmodel,
				no_dens_train,
				beta=1.0,
				nwalkers=args.walkers, nburnin=args.burn_in,
				nprod=args.production, nthreads=args.threads,
				show_progress=args.progress,
				optimized=pre_opt, bounds=bounds, return_logpost=True)

		if args.train_fraction < 1. or args.test_fraction < 1.:
			logging.info("Statistics for the test samples")
			mcmc_statistics(gpmodel,
					no_ys_test, no_dens_test, no_errs_test,
					no_ys_train, no_dens_train, no_errs_train,
					samples, lnp,
			)
		logging.info("Statistics for all samples")
		mcmc_statistics(gpmodel,
				no_ys, no_dens, no_errs,
				no_ys_train, no_dens_train, no_errs_train,
				samples, lnp,
		)

		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
		if args.plot_corner:
			import corner
			# Corner plot of the sampled parameters
			fig = corner.corner(samples,
					quantiles=[0.025, 0.5, 0.975],
					show_titles=True,
					labels=gpmodel.get_parameter_names(),
					truths=bestfit,
					hist_args=dict(normed=True))
			fig.savefig(filename_base.format("corner") + ".pdf", transparent=True)

		if args.save_samples:
			if args.samples_format in ["npz"]:
				# save the samples compressed to save space.
				np.savez_compressed(filename_base.format("sampls") + ".npz",
					samples=samples)
			if args.samples_format in ["nc", "netcdf4"]:
				save_samples_netcdf(filename_base.format("sampls") + ".nc",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
			if args.samples_format in ["h5", "hdf5"]:
				save_samples_netcdf(filename_base.format("sampls") + ".h5",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
		# MCMC finished here

	# set the model times and errors to use the full data set for plotting
	gpmodel.compute(no_ys, no_errs)
	if args.save_model:
		try:
			# python 2
			import cPickle as pickle
		except ImportError:
			# python 3
			import pickle
		# pickle and save the model
		with open(filename_base.format("model") + ".pkl", "wb") as f:
			pickle.dump((gpmodel), f, -1)

	if args.plot_samples and args.mcmc:
		# [inspection] The variable samples does not seem to be defined in
		# case args.mcmc on line 451 is False. Are you sure this can never
		# be the case?
		plot_random_samples(gpmodel, no_ys, no_dens, no_errs,
				samples, args.scale,
				filename_base.format("sampls") + ".pdf",
				size=4, extra_years=[4, 2])

	if args.plot_median:
		# [inspection] The variable sampl_percs does not seem to be defined
		# in case args.mcmc on line 451 is False. Are you sure this can never
		# be the case?
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1],
				filename_base.format("median") + ".pdf")
	if args.plot_residuals:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1], args.scale,
				filename_base.format("medres") + ".pdf")
	if args.plot_maxlnp:
		# [inspection] The variable lnp does not seem to be defined in case
		# args.mcmc on line 451 is False. Are you sure this can never be
		# the case?
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)],
				filename_base.format("maxlnp") + ".pdf")
	if args.plot_maxlnpres:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)], args.scale,
				filename_base.format("mlpres") + ".pdf")

	labels = gpmodel.get_parameter_names()
	logging.info("param percentiles [2.5, 50, 97.5]:")
	for pc, label in zip(sampl_percs.T, labels):
		median = pc[1]
		pc_minus = median - pc[0]
		pc_plus = pc[2] - median
		logging.debug("%s: %s", label, pc)
		logging.info("%s: %.6f (- %.6f) (+ %.6f)", label,
				median, pc_minus, pc_plus)
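	# Illustrative worked example: for pc == [1.0, 1.5, 2.2] (the 2.5th, 50th,
	# and 97.5th percentiles), median == 1.5, pc_minus == 0.5, pc_plus == 0.7,
	# so the log line reads "<name>: 1.500000 (- 0.500000) (+ 0.700000)".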

	logging.info("Finished successfully.")


if __name__ == "__main__":
	main()