Passed
Push — master (65b079...f515c9) by Stefan, created 05:52

sciapy.regress.__main__ (rating: F)

Complexity

Total Complexity 67

Size/Duplication

Total Lines 544
Duplicated Lines 90.26 %

Test Coverage

Coverage 79.32%

Importance

Changes 0
Metric   Value    Meaning
eloc     396      executable lines of code
dl       491      duplicated lines (sum of the per-function duplication below)
loc      544      total lines
ccs      234      covered statements
cts      295      total statements
cp       0.7932   coverage (ccs / cts = 79.32 %)
rs       3.04     rating score
c        0
b        0
f        0
wmc      67       weighted method complexity (matches Total Complexity)

4 Functions

Rating   Name                    Duplication (lines)   Size (lines)   Complexity
A        _r_sun_earth()                           24             24            1
A        _train_test_split()                      29             29            3
F        main()                                  395            395           56
B        save_samples_netcdf()                    43             43            7

How to fix

Duplicated Code

Duplicated code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places. The usual fix is to extract the shared logic into a single helper that every call site uses, as in the sketch below.
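A minimal sketch of that extraction (the names are illustrative, not taken from sciapy): the same scaling logic, once copy-pasted into several loaders, moves into one helper.

import numpy as np

def _scale_series(values, scale=1e-6):
	# the formerly duplicated logic now lives in exactly one place
	return np.asarray(values) * scale

def load_series_a(raw):
	return _scale_series(raw)

def load_series_b(raw):
	# call sites keep their local defaults; only the shared logic is factored out
	return _scale_series(raw, scale=1e-7)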

Complexity

Tip: Before tackling complexity, make sure to eliminate any duplication first. This alone can often reduce the size of classes significantly.

Complex classes (and modules) like sciapy.regress.__main__ often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster. A sketch of Extract Class follows.
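A minimal Extract Class sketch, assuming a large class whose plotting methods share the "plot_" prefix (the names are illustrative, not from sciapy):

class Plotter:
	"""The cohesive plotting component, extracted from the former god class."""

	def __init__(self, data):
		self.data = data

	def plot_median(self):
		print("plotting median of", len(self.data), "points")

	def plot_residuals(self):
		print("plotting residuals of", len(self.data), "points")

class Analysis:
	"""What remains of the original class; it now delegates plotting."""

	def __init__(self, data):
		self.data = data
		self.plotter = Plotter(data)

Afterwards, callers use analysis.plotter.plot_median(), and the original class shrinks by one responsibility.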

Analyzed source of sciapy.regress.__main__, with the analyzer's inline findings kept as comments:

# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2018 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY data regression command line interface

Command line main program for regression analysis of SCIAMACHY
daily zonal mean time series (NO for now).
"""

import ctypes
import logging

import numpy as np
import scipy.optimize as op
from scipy.interpolate import interp1d

import george
import celerite

import matplotlib as mpl
# switch off X11 rendering
mpl.use("Agg")

from .load_data import load_solar_gm_table, load_scia_dzm
from .models_cel import trace_gas_model
from .mcmc import mcmc_sample_model
from .statistics import mcmc_statistics

from ._gpkernels import (george_solvers,
		setup_george_kernel, setup_celerite_terms)
from ._plot import (plot_single_sample_and_residuals,
		plot_residual, plot_random_samples)
from ._options import parser

# [analyzer] This code seems to be duplicated in your project.
def save_samples_netcdf(filename, model, alt, lat, samples,
		scale=1e-6,
		lnpost=None, compressed=False):
	from xarray import Dataset
	smpl_ds = Dataset(dict([(pname, (["lat", "alt", "sample"],
				samples[..., i].reshape(1, 1, -1)))
				for i, pname in enumerate(model.get_parameter_names())]
				# + [("lpost", (["lat", "alt", "sample"], lnp.reshape(1, 1, -1)))]
			),
			coords={"lat": [lat], "alt": [alt]})

	for modn in model.mean.models:
		modl = model.mean.models[modn]
		if hasattr(modl, "mean"):
			smpl_ds.attrs[modn + ":mean"] = modl.mean

	units = {"kernel": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale))},
			"mean": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale)),
				"cos": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"sin": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"val": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"amp": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"tau": "d"}}
	for pname in smpl_ds.data_vars:
		_pp = pname.split(':')
		for _n, _u in units[_pp[0]].items():
			if _pp[-1].startswith(_n):
				logging.debug("units for %s: %s", pname, _u)
				smpl_ds[pname].attrs["units"] = _u

	smpl_ds["alt"].attrs = {"long_name": "altitude", "units": "km"}
	smpl_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}

	_encoding = None
	if compressed:
		_encoding = {var: {"zlib": True, "complevel": 1}
					for var in smpl_ds.data_vars}
	smpl_ds.to_netcdf(filename, encoding=_encoding)
	smpl_ds.close()

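# Illustrative usage (a sketch, not part of the module): store the MCMC
# samples for a single 65°N / 70 km bin, compressed; "gpmodel" and "samples"
# are assumed to come from the fitting steps in main() below.
# save_samples_netcdf("NO_samples.nc", gpmodel, 70., 65., samples,
#		scale=1e-6, compressed=True)
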
# [analyzer] This code seems to be duplicated in your project.
def _train_test_split(times, data, errs, train_frac,
		test_frac, randomize):
	# split the data into training and test subsets according to the
	# fraction given (default is 1, i.e. no splitting)
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# randomize if requested
	if randomize:
		permut_idx = np.random.permutation(np.arange(ndata))
	else:
		permut_idx = np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	times_train = times[train_idx]
	data_train = data[train_idx]
	errs_train = errs[train_idx]
	if test_size > 0:
		times_test = times[test_idx]
		data_test = data[test_idx]
		errs_test = errs[test_idx]
	else:
		times_test = times
		data_test = data
		errs_test = errs
	logging.info("using %s of %s samples for training.", len(times_train), ndata)
	logging.info("using %s of %s samples for testing.", len(times_test), ndata)
	return (times_train, data_train, errs_train,
			times_test, data_test, errs_test)

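# A worked example of the split logic above (illustrative, not part of the
# module): ndata = 10, train_frac = 0.6 and test_frac = 0.2 give
# train_size = 6 and test_size = min(4, 2) = 2, i.e. without randomization
# indices 0-5 train and 6-7 test; with test_frac = 0 the test set falls
# back to the full data set.
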
# [analyzer] This code seems to be duplicated in your project.
def _r_sun_earth(time, tfmt="jyear"):
	"""First order approximation of the Sun-Earth distance

	The Sun-to-Earth distance can be used to (un-)normalize proxies
	to the actual distance to the Sun instead of 1 AU.

	Parameters
	----------
	time : float
		Time value in the units given by 'tfmt'.
	tfmt : str, optional
		The units of 'time' as supported by the
		astropy.time time formats. Default: 'jyear'.

	Returns
	-------
	dist : float
		The Sun-Earth distance at the given day of year in AU.
	"""
	from astropy.time import Time
	tdoy = Time(time, format=tfmt)
	tdoy.format = "yday"
	doy = int(tdoy.value.split(':')[1])
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))

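# A quick sanity check of the approximation above (illustrative, not part
# of the module): with eccentricity 0.01672 and perihelion near day-of-year
# 4, doy = 4 gives cos(0) = 1, i.e. 1 - 0.01672 = 0.98328 AU, and half an
# orbit later (doy ≈ 187) the cosine is about -1, giving ≈ 1.01672 AU.
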
# [analyzer] This code seems to be duplicated in your project.
def main():
	logging.basicConfig(level=logging.WARNING,
			format="[%(levelname)-8s] (%(asctime)s) "
				"%(filename)s:%(lineno)d %(message)s",
			datefmt="%Y-%m-%d %H:%M:%S %z")

	args = parser.parse_args()

	logging.info("command line arguments: %s", args)
	if args.quiet:
		logging.getLogger().setLevel(logging.ERROR)
	elif args.verbose:
		logging.getLogger().setLevel(logging.INFO)
	else:
		logging.getLogger().setLevel(args.loglevel)

	from numpy.distutils.system_info import get_info
	try:
		ob_lib_dirs = get_info("openblas")["library_dirs"]
	except KeyError:
		ob_lib_dirs = []
	for oblas_path in ob_lib_dirs:
		oblas_name = "{0}/libopenblas.so".format(oblas_path)
		logging.info("Trying %s", oblas_name)
		try:
			oblas_lib = ctypes.cdll.LoadLibrary(oblas_name)
			oblas_cores = oblas_lib.openblas_get_num_threads()
			oblas_lib.openblas_set_num_threads(args.openblas_threads)
			logging.info("Using %s/%s Openblas thread(s).",
					oblas_lib.openblas_get_num_threads(), oblas_cores)
		except:
			logging.info("Setting number of openblas threads failed.")

	if args.random_seed is not None:
		np.random.seed(args.random_seed)

	if args.proxies:
		proxies = args.proxies.split(',')
		proxy_dict = dict(_p.split(':') for _p in proxies)
	else:
		proxy_dict = {}
	lag_dict = {pn: 0 for pn in proxy_dict.keys()}

	# Post-processing of arguments...
	# List of proxy lag fits from csv
	fit_lags = args.fit_lags.split(',')
	# List of proxy lifetime fits from csv
	fit_lifetimes = args.fit_lifetimes.split(',')
	fit_annlifetimes = args.fit_annlifetimes.split(',')
	# List of proxy lag times from csv
	lag_dict.update(dict(_ls.split(':') for _ls in args.lag_times.split(',')))
	# List of cycles (frequencies in 1/year) from argument list (csv)
	try:
		freqs = list(map(float, args.freqs.split(',')))
	except ValueError:
		freqs = []
	args.freqs = freqs
	# List of initial parameter values
	initial = None
	if args.initial is not None:
		try:
			initial = list(map(float, args.initial.split(',')))
		except ValueError:
			pass
	# List of GP kernels from argument list (csv)
	kernls = args.kernels.split(',')

	lat = args.latitude
	alt = args.altitude
	logging.info("location: %.0f°N %.0f km", lat, alt)

	no_ys, no_dens, no_errs, no_szas = load_scia_dzm(args.file, alt, lat,
			tfmt=args.time_format,
			scale=args.scale,
			#subsample_factor=args.random_subsample,
			#subsample_method="random",
			akd_threshold=args.akd_threshold,
			cnt_threshold=args.cnt_threshold,
			center=args.center_data,
			season=args.season,
			SPEs=args.exclude_spe)

	(no_ys_train, no_dens_train, no_errs_train,
		no_ys_test, no_dens_test, no_errs_test) = _train_test_split(
				no_ys, no_dens, no_errs, args.train_fraction,
				args.test_fraction, args.random_train_test)

	sza_intp = interp1d(no_ys, no_szas, bounds_error=False)

	max_amp = 1e10 * args.scale
	max_days = 100

	proxy_config = {}
	for pn, pf in proxy_dict.items():
		pt, pp = load_solar_gm_table(pf, cols=[0, 1], names=["time", pn], tfmt=args.time_format)
		# use log of proxy values
		pv = np.log(pp[pn]) if pn in args.log_proxies.split(',') else pp[pn]
		# normalize to sun--earth distance squared
		if pn in args.norm_proxies_distSEsq.split(','):
			rad_sun_earth = np.vectorize(_r_sun_earth)(pt, tfmt=args.time_format)
			pv /= rad_sun_earth**2
		# normalize by cos(SZA)
		if pn in args.norm_proxies_SZA.split(',') and sza_intp is not None:
			pv *= np.cos(np.radians(sza_intp(pt)))
		proxy_config.update({pn:
			dict(times=pt, values=pv,
				center=pn in args.center_proxies.split(','),
				positive=pn in args.positive_proxies.split(','),
				lag=float(lag_dict[pn]),
				max_amp=max_amp, max_days=max_days,
				sza_intp=sza_intp if args.use_sza else None,
			)}
		)

	model = trace_gas_model(constant=args.fit_offset,
			proxy_config=proxy_config, **vars(args))

	logging.debug("model dict: %s", model.get_parameter_dict())
	model.freeze_all_parameters()
	# thaw parameters according to requested fits
	for pn in proxy_dict.keys():
		model.thaw_parameter("{0}:amp".format(pn))
		if pn in fit_lags:
			model.thaw_parameter("{0}:lag".format(pn))
		if pn in fit_lifetimes:
			model.set_parameter("{0}:tau0".format(pn), 1e-3)
			model.thaw_parameter("{0}:tau0".format(pn))
			if pn in fit_annlifetimes:
				model.thaw_parameter("{0}:taucos1".format(pn))
				model.thaw_parameter("{0}:tausin1".format(pn))
		else:
			model.set_parameter("{0}:ltscan".format(pn), 0)
	for freq in freqs:
		if not args.fit_phase:
			model.thaw_parameter("f{0:.0f}:cos".format(freq))
			model.thaw_parameter("f{0:.0f}:sin".format(freq))
		else:
			model.thaw_parameter("f{0:.0f}:amp".format(freq))
			model.thaw_parameter("f{0:.0f}:phase".format(freq))
	if args.fit_offset:
		#model.set_parameter("offset:value", -100.)
		#model.set_parameter("offset:value", 0)
		model.thaw_parameter("offset:value")

	if initial is not None:
		model.set_parameter_vector(initial)
	# model.thaw_parameter("GM:ltscan")
	logging.debug("params: %s", model.get_parameter_dict())
	logging.debug("param names: %s", model.get_parameter_names())
	logging.debug("param vector: %s", model.get_parameter_vector())
	logging.debug("param bounds: %s", model.get_parameter_bounds())
	#logging.debug("model value: %s", model.get_value(no_ys))
	#logging.debug("default log likelihood: %s", model.log_likelihood(model.vector))

	# setup the Gaussian Process kernel
	kernel_base = (1e7 * args.scale)**2
	ksub = args.name_suffix

	solver = "basic"
	skwargs = {}
	if args.HODLR_Solver:
		solver = "HODLR"
		#skwargs = {"tol": 1e-3}

	if args.george:
		gpname, kernel = setup_george_kernel(kernls,
				kernel_base=kernel_base, fit_bias=args.fit_bias)
		gpmodel = george.GP(kernel, mean=model,
			white_noise=1.e-25, fit_white_noise=args.fit_white,
			solver=george_solvers[solver], **skwargs)
		# the george interface does not allow setting the bounds in
		# the kernel initialization so we prepare simple default bounds
		kernel_bounds = [(-0.3 * max_amp, 0.3 * max_amp)
				for _ in gpmodel.kernel.get_parameter_names()]
		bounds = gpmodel.mean.get_parameter_bounds() + kernel_bounds
	else:
		gpname, cel_terms = setup_celerite_terms(kernls,
				fit_bias=args.fit_bias, fit_white=args.fit_white)
		gpmodel = celerite.GP(cel_terms, mean=model,
			fit_white_noise=args.fit_white,
			fit_mean=True)
		bounds = gpmodel.get_parameter_bounds()
	gpmodel.compute(no_ys_train, no_errs_train)
	logging.debug("gpmodel params: %s", gpmodel.get_parameter_dict())
	logging.debug("gpmodel bounds: %s", bounds)
	logging.debug("initial log likelihood: %s", gpmodel.log_likelihood(no_dens_train))
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)
	model_name = "_".join(gpmodel.mean.get_parameter_names()).replace(':', '')
	gpmodel_name = model_name + gpname
	logging.info("GP model name: %s", gpmodel_name)

	pre_opt = False
	if args.optimize > 0:
		def gpmodel_mean(x, *p):
			gpmodel.set_parameter_vector(p)
			return gpmodel.mean.get_value(x)

		def gpmodel_res(x, *p):
			gpmodel.set_parameter_vector(p)
			return (gpmodel.mean.get_value(x) - no_dens_train) / no_errs_train

		def lpost(p, y, gp):
			gp.set_parameter_vector(p)
			return gp.log_likelihood(y, quiet=True) + gp.log_prior()

		def nlpost(p, y, gp):
			# [analyzer] The variable "lpost" does not seem to be defined
			# in case "args.optimize > 0" above is False. Are you sure
			# this can never be the case?
			lp = lpost(p, y, gp)
			return -lp if np.isfinite(lp) else 1e25

		def grad_nlpost(p, y, gp):
			gp.set_parameter_vector(p)
			grad_ll = gp.grad_log_likelihood(y)
			if isinstance(grad_ll, tuple):
				# celerite
				return -grad_ll[1]
			# george
			return -grad_ll

		jacobian = grad_nlpost if gpmodel.kernel.vector_size else None
		if args.optimize == 1:
			resop_gp = op.minimize(
				nlpost,
				gpmodel.get_parameter_vector(),
				args=(no_dens_train, gpmodel),
				bounds=bounds,
				# method="l-bfgs-b", options=dict(disp=True, maxcor=100, eps=1e-9, ftol=2e-15, gtol=1e-8))
				method="l-bfgs-b", jac=jacobian)
				# method="tnc", options=dict(disp=True, maxiter=500, xtol=1e-12))
				# method="nelder-mead", options=dict(disp=True, maxfev=100000, fatol=1.49012e-8, xatol=1.49012e-8))
				# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08))
		if args.optimize == 2:
			resop_gp = op.differential_evolution(
				nlpost,
				bounds=bounds,
				args=(no_dens_train, gpmodel),
				popsize=2 * args.walkers, tol=0.01)
		if args.optimize == 3:
			resop_bh = op.basinhopping(
				nlpost,
				gpmodel.get_parameter_vector(),
				niter=200,
				minimizer_kwargs=dict(
					args=(no_dens_train, gpmodel),
					bounds=bounds,
					# method="tnc"))
					# method="l-bfgs-b", options=dict(maxcor=100)))
					method="l-bfgs-b", jac=jacobian))
					# method="Nelder-Mead"))
					# method="BFGS"))
					# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08)))
			logging.debug("optimization result: %s", resop_bh)
			resop_gp = resop_bh.lowest_optimization_result
		if args.optimize == 4:
			resop_gp, cov_gp = op.curve_fit(
				gpmodel_mean,
				no_ys_train, no_dens_train, gpmodel.get_parameter_vector(),
				bounds=tuple(np.array(bounds).T),
				# method='lm',
				# absolute_sigma=True,
				sigma=no_errs_train)
			print(resop_gp, np.sqrt(np.diag(cov_gp)))
		# [analyzer] The variable "resop_gp" does not seem to be defined
		# in case "args.optimize == 1" above is False. Are you sure this
		# can never be the case?
		logging.info("%s", resop_gp.message)
		logging.debug("optimization result: %s", resop_gp)
		logging.info("gpmodel dict: %s", gpmodel.get_parameter_dict())
		logging.info("log posterior trained: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.info("log posterior test: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.info("log posterior all: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		# cross check to make sure that the gpmodel parameter vector is really
		# set to the fitted parameters
		logging.info("opt. model vector: %s", resop_gp.x)
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 1: %s", lpost(resop_gp.x, no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 1: %s", lpost(resop_gp.x, no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 1: %s", lpost(resop_gp.x, no_dens, gpmodel))
		logging.debug("opt. model vector: %s", gpmodel.get_parameter_vector())
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		pre_opt = resop_gp.success
	try:
		logging.info("GM lt: %s", gpmodel.get_parameter("mean:GM:tau0"))
	except ValueError:
		pass
	logging.info("(GP) model: %s", gpmodel.kernel)
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)

	bestfit = gpmodel.get_parameter_vector()
	filename_base = ("NO_regress_fit_{0}_{1:.0f}_{2:.0f}_{{0}}_{3}"
					.format(gpmodel_name, lat * 10, alt, ksub))

	if args.mcmc:
		gpmodel.compute(no_ys_train, no_errs_train)
		samples, lnp = mcmc_sample_model(gpmodel,
				no_dens_train,
				beta=1.0,
				nwalkers=args.walkers, nburnin=args.burn_in,
				nprod=args.production, nthreads=args.threads,
				show_progress=args.progress,
				optimized=pre_opt, bounds=bounds, return_logpost=True)

		if args.train_fraction < 1. or args.test_fraction < 1.:
			logging.info("Statistics for the test samples")
			mcmc_statistics(gpmodel,
					no_ys_test, no_dens_test, no_errs_test,
					no_ys_train, no_dens_train, no_errs_train,
					samples, lnp,
			)
		logging.info("Statistics for all samples")
		mcmc_statistics(gpmodel,
				no_ys, no_dens, no_errs,
				no_ys_train, no_dens_train, no_errs_train,
				samples, lnp,
		)

		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
		if args.plot_corner:
			import corner
			# Corner plot of the sampled parameters
			fig = corner.corner(samples,
					quantiles=[0.025, 0.5, 0.975],
					show_titles=True,
					labels=gpmodel.get_parameter_names(),
					truths=bestfit,
					hist_args=dict(normed=True))
			fig.savefig(filename_base.format("corner") + ".pdf", transparent=True)

		if args.save_samples:
			if args.samples_format in ["npz"]:
				# save the samples compressed to save space.
				np.savez_compressed(filename_base.format("sampls") + ".npz",
					samples=samples)
			if args.samples_format in ["nc", "netcdf4"]:
				save_samples_netcdf(filename_base.format("sampls") + ".nc",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
			if args.samples_format in ["h5", "hdf5"]:
				save_samples_netcdf(filename_base.format("sampls") + ".h5",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
		# MCMC finished here

	# set the model times and errors to use the full data set for plotting
	gpmodel.compute(no_ys, no_errs)
	if args.save_model:
		try:
			# python 2
			import cPickle as pickle
		except ImportError:
			# python 3
			import pickle
		# pickle and save the model
		with open(filename_base.format("model") + ".pkl", "wb") as f:
			pickle.dump((gpmodel), f, -1)

	if args.plot_samples and args.mcmc:
		# [analyzer] The variable "samples" does not seem to be defined in
		# case "args.mcmc" above is False. Are you sure this can never be
		# the case?
		plot_random_samples(gpmodel, no_ys, no_dens, no_errs,
				samples, args.scale,
				filename_base.format("sampls") + ".pdf",
				size=4, extra_years=[4, 2])

	if args.plot_median:
		# [analyzer] The variable "sampl_percs" does not seem to be defined
		# in case "args.mcmc" above is False. Are you sure this can never
		# be the case?
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1],
				filename_base.format("median") + ".pdf")
	if args.plot_residuals:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1], args.scale,
				filename_base.format("medres") + ".pdf")
	if args.plot_maxlnp:
		# [analyzer] The variable "lnp" does not seem to be defined in case
		# "args.mcmc" above is False. Are you sure this can never be the case?
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)],
				filename_base.format("maxlnp") + ".pdf")
	if args.plot_maxlnpres:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)], args.scale,
				filename_base.format("mlpres") + ".pdf")

	labels = gpmodel.get_parameter_names()
	logging.info("param percentiles [2.5, 50, 97.5]:")
	for pc, label in zip(sampl_percs.T, labels):
		median = pc[1]
		pc_minus = median - pc[0]
		pc_plus = pc[2] - median
		logging.debug("%s: %s", label, pc)
		logging.info("%s: %.6f (- %.6f) (+ %.6f)", label,
				median, pc_minus, pc_plus)

	logging.info("Finished successfully.")


if __name__ == "__main__":
	main()