sciapy.regress.__main__ (rating: F)

Complexity

Total Complexity 66

Size/Duplication

Total Lines 558
Duplicated Lines 0%

Test Coverage

Coverage 84.59%

Importance

Changes 0
Metric                           Value
eloc (executable lines of code)  402
dl (duplicated lines)            0
loc (total lines)                558
ccs (covered statements)         247
cts (total statements)           292
cp (coverage)                    0.8459
rs                               3.12
c                                0
b                                0
f                                0
wmc (total complexity)           66
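
The coverage percentage follows from the statement counts: cp = ccs / cts = 247 / 292 ≈ 0.8459, i.e. the 84.59% reported under Test Coverage; wmc 66 likewise matches the Total Complexity above.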

4 Functions

Rating  Name                   Duplication  Size  Complexity
A       _r_sun_earth()         0            24    1
A       _train_test_split()    0            29    3
B       save_samples_netcdf()  0            46    8
F       main()                 0            404   54

How to fix: Complexity

Complex modules like sciapy.regress.__main__ often do a lot of different things. To break such a module down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined which members belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster. A sketch of the function-level variant follows below.
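In this module the F-rated component is not a class at all but the 404-line main() function (complexity 54 of the module's 66), so the function-level variant of the same idea, Extract Function (Extract Method), is the more direct first step: each cohesive block of main() (argument post-processing, data loading, model setup, optimization, MCMC sampling, plotting) can become its own helper. A minimal, self-contained sketch of the idea; the names parse_config, fit and report are illustrative, not part of sciapy:

import argparse

def parse_config(argv=None):
	# one cohesive block: turn the command line into a config object
	parser = argparse.ArgumentParser()
	parser.add_argument("--scale", type=float, default=1e-6)
	return parser.parse_args(argv)

def fit(config):
	# a second block: everything that builds and fits the model
	return {"amp": 2.5 * config.scale}

def report(result):
	# a third block: reporting and plotting
	for name, value in sorted(result.items()):
		print("{0}: {1:.3g}".format(name, value))

def main(argv=None):
	# main() shrinks to a thin driver over the extracted helpers,
	# and each helper can be tested on its own
	report(fit(parse_config(argv)))

if __name__ == "__main__":
	main([])

Splitting main() along such seams moves each branch-heavy block behind a small interface, which is what brings the per-function complexity numbers down.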

# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2018 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY data regression command line interface

Command line main program for regression analysis of SCIAMACHY
daily zonal mean time series (NO for now).
"""
import ctypes
import logging
from os import environ, path

import numpy as np
import scipy.optimize as op
from scipy.interpolate import interp1d

import george
import celerite

import matplotlib as mpl
# switch off X11 rendering
mpl.use("Agg")

from regressproxy.models_cel import proxy_model_set as trace_gas_model

from .load_data import load_solar_gm_table, load_scia_dzm
from .mcmc import mcmc_sample_model
from .statistics import mcmc_statistics

from ._gpkernels import (george_solvers,
		setup_george_kernel, setup_celerite_terms)
from ._plot import (plot_single_sample_and_residuals,
		plot_residual, plot_random_samples)
from ._options import parser

def save_samples_netcdf(filename, model, alt, lat, samples,
		scale=1e-6,
		lnpost=None, compressed=False):
	from xarray import Dataset
	smpl_ds = Dataset(dict([(pname, (["lat", "alt", "sample"],
				samples[..., i].reshape(1, 1, -1)))
				for i, pname in enumerate(model.get_parameter_names())]
				# + [("lpost", (["lat", "alt", "sample"], lnp.reshape(1, 1, -1)))]
			),
			coords={"lat": [lat], "alt": [alt]})

	for modn in model.mean.models:
		modl = model.mean.models[modn]
		if hasattr(modl, "mean"):
			smpl_ds.attrs[modn + ":mean"] = modl.mean

	units = {"kernel": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale))},
			"mean": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale)),
				"cos": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"sin": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"val": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"amp": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"tau": "d"}}
	for pname in smpl_ds.data_vars:
		_pp = pname.split(':')
		for _n, _u in units.get(_pp[0], {}).items():
			if _pp[-1].startswith(_n):
				logging.debug("units for %s: %s", pname, _u)
				smpl_ds[pname].attrs["units"] = _u

	smpl_ds["alt"].attrs = {"long_name": "altitude", "units": "km"}
	smpl_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}

	_encoding = None
	if compressed:
		_encoding = {var: {"zlib": True, "complevel": 1}
					for var in smpl_ds.data_vars}
	try:
		smpl_ds.to_netcdf(filename, encoding=_encoding)
	except ValueError:
		smpl_ds.to_netcdf(filename)
	smpl_ds.close()

def _train_test_split(times, data, errs, train_frac,
		test_frac, randomize):
	# split the data into training and test subsets according to the
	# fraction given (default is 1, i.e. no splitting)
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# randomize if requested
	if randomize:
		permut_idx = np.random.permutation(np.arange(ndata))
	else:
		permut_idx = np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	times_train = times[train_idx]
	data_train = data[train_idx]
	errs_train = errs[train_idx]
	if test_size > 0:
		times_test = times[test_idx]
		data_test = data[test_idx]
		errs_test = errs[test_idx]
	else:
		times_test = times
		data_test = data
		errs_test = errs
	logging.info("using %s of %s samples for training.", len(times_train), ndata)
	logging.info("using %s of %s samples for testing.", len(times_test), ndata)
	return (times_train, data_train, errs_train,
			times_test, data_test, errs_test)

def _r_sun_earth(time, tfmt="jyear"):
	"""First order approximation of the Sun-Earth distance

	The Sun-to-Earth distance can be used to (un-)normalize proxies
	to the actual distance to the Sun instead of 1 AU.

	Parameters
	----------
	time : float
		Time value in the units given by 'tfmt'.
	tfmt : str, optional
		The units of 'time' as supported by the
		astropy.time time formats. Default: 'jyear'.

	Returns
	-------
	dist : float
		The Sun-Earth distance at the given day of year in AU.
	"""
	from astropy.time import Time
	tdoy = Time(time, format=tfmt)
	tdoy.format = "yday"
	doy = int(tdoy.value.split(':')[1])
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))

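# Editor's note (illustrative sketch, not part of the module source):
# a quick sanity check of _r_sun_earth near perihelion. In "jyear"
# format, 2007.0 is 2007-01-01, i.e. day-of-year 1, so
#   _r_sun_earth(2007.0) = 1 - 0.01672 * cos(2*pi / 365.256363 * (1 - 4))
#                        ≈ 0.9833 AU,
# close to the perihelion distance, while around early July (doy ~ 187)
# the value approaches the aphelion distance of about 1.0167 AU.
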
def main():
	logging.basicConfig(level=logging.WARNING,
			format="[%(levelname)-8s] (%(asctime)s) "
				"%(filename)s:%(lineno)d %(message)s",
			datefmt="%Y-%m-%d %H:%M:%S %z")

	args = parser.parse_args()

	logging.info("command line arguments: %s", args)
	if args.quiet:
		logging.getLogger().setLevel(logging.ERROR)
	elif args.verbose:
		logging.getLogger().setLevel(logging.INFO)
	else:
		logging.getLogger().setLevel(args.loglevel)

	if args.openblas_threads is not None:
		# Sets OpenMP/Openblas number of threads if set,
		# else uses the environment's and library's defaults
		environ["OMP_NUM_THREADS"] = str(args.openblas_threads)
		environ["OPENBLAS_NUM_THREADS"] = str(args.openblas_threads)

	if args.random_seed is not None:
		np.random.seed(args.random_seed)

	if args.proxies:
		proxies = args.proxies.split(',')
		proxy_dict = dict(_p.split(':') for _p in proxies)
	else:
		proxy_dict = {}
	lag_dict = {pn: 0 for pn in proxy_dict.keys()}

	# Post-processing of arguments...
	# List of proxy lag fits from csv
	fit_lags = args.fit_lags.split(',')
	# List of proxy lifetime fits from csv
	fit_lifetimes = args.fit_lifetimes.split(',')
	fit_annlifetimes = args.fit_annlifetimes.split(',')
	# List of proxy lag times from csv
	lag_dict.update(dict(_ls.split(':') for _ls in args.lag_times.split(',')))
	# List of cycles (frequencies in 1/year) from argument list (csv)
	try:
		freqs = list(map(float, args.freqs.split(',')))
	except ValueError:
		freqs = []
	args.freqs = freqs
	# List of initial parameter values
	initial = None
	if args.initial is not None:
		try:
			initial = list(map(float, args.initial.split(',')))
		except ValueError:
			pass
	# List of GP kernels from argument list (csv)
	kernls = args.kernels.split(',')

	lat = args.latitude
	alt = args.altitude
	logging.info("location: %.0f°N %.0f km", lat, alt)

	no_ys, no_dens, no_errs, no_szas = load_scia_dzm(args.file, alt, lat,
			tfmt=args.time_format,
			scale=args.scale,
			#subsample_factor=args.random_subsample,
			#subsample_method="random",
			akd_threshold=args.akd_threshold,
			cnt_threshold=args.cnt_threshold,
			center=args.center_data,
			season=args.season,
			SPEs=args.exclude_spe)

	(no_ys_train, no_dens_train, no_errs_train,
		no_ys_test, no_dens_test, no_errs_test) = _train_test_split(
				no_ys, no_dens, no_errs, args.train_fraction,
				args.test_fraction, args.random_train_test)

	sza_intp = interp1d(no_ys, no_szas, bounds_error=False)

	max_amp = 1e10 * args.scale
	max_days = 100

	proxy_config = {}
	for pn, pf in proxy_dict.items():
		pt, pp = load_solar_gm_table(path.expanduser(pf),
				cols=[0, 1], names=["time", pn], tfmt=args.time_format)
		pv = pp[pn]
		# use log of proxy values if desired
		if pn in args.log_proxies.split(','):
			pv = np.log(pv)
		# normalize to sun--earth distance squared
		if pn in args.norm_proxies_distSEsq.split(','):
			rad_sun_earth = np.vectorize(_r_sun_earth)(pt, tfmt=args.time_format)
			pv /= rad_sun_earth**2
		# normalize by cos(SZA)
		if pn in args.norm_proxies_SZA.split(',') and sza_intp is not None:
			pv *= np.cos(np.radians(sza_intp(pt)))
		proxy_config.update({pn:
			dict(times=pt, values=pv,
				center=pn in args.center_proxies.split(','),
				positive=pn in args.positive_proxies.split(','),
				lag=float(lag_dict[pn]),
				max_amp=max_amp, max_days=max_days,
				sza_intp=sza_intp if args.use_sza else None,
			)}
		)

	model = trace_gas_model(constant=args.fit_offset,
			proxy_config=proxy_config, **vars(args))

	logging.debug("model dict: %s", model.get_parameter_dict())
	model.freeze_all_parameters()
	# thaw parameters according to requested fits
	for pn in proxy_dict.keys():
		model.thaw_parameter("{0}:amp".format(pn))
		if pn in fit_lags:
			model.thaw_parameter("{0}:lag".format(pn))
		if pn in fit_lifetimes:
			model.set_parameter("{0}:tau0".format(pn), 1e-3)
			model.thaw_parameter("{0}:tau0".format(pn))
			if pn in fit_annlifetimes:
				model.thaw_parameter("{0}:taucos1".format(pn))
				model.thaw_parameter("{0}:tausin1".format(pn))
		else:
			model.set_parameter("{0}:ltscan".format(pn), 0)
	for freq in freqs:
		if not args.fit_phase:
			model.thaw_parameter("f{0:.0f}:cos".format(freq))
			model.thaw_parameter("f{0:.0f}:sin".format(freq))
		else:
			model.thaw_parameter("f{0:.0f}:amp".format(freq))
			model.thaw_parameter("f{0:.0f}:phase".format(freq))
	if args.fit_offset:
		#model.set_parameter("offset:value", -100.)
		#model.set_parameter("offset:value", 0)
		model.thaw_parameter("offset:value")

	if initial is not None:
		model.set_parameter_vector(initial)
	# model.thaw_parameter("GM:ltscan")
	logging.debug("params: %s", model.get_parameter_dict())
	logging.debug("param names: %s", model.get_parameter_names())
	logging.debug("param vector: %s", model.get_parameter_vector())
	logging.debug("param bounds: %s", model.get_parameter_bounds())
	#logging.debug("model value: %s", model.get_value(no_ys))
	#logging.debug("default log likelihood: %s", model.log_likelihood(model.vector))

	# setup the Gaussian Process kernel
	kernel_base = (1e7 * args.scale)**2
	ksub = args.name_suffix

	solver = "basic"
	skwargs = {}
	if args.HODLR_Solver:
		solver = "HODLR"
		#skwargs = {"tol": 1e-3}

	if args.george:
		gpname, kernel = setup_george_kernel(kernls,
				kernel_base=kernel_base, fit_bias=args.fit_bias)
		gpmodel = george.GP(kernel, mean=model,
			white_noise=1.e-25, fit_white_noise=args.fit_white,
			solver=george_solvers[solver], **skwargs)
		# the george interface does not allow setting the bounds in
		# the kernel initialization so we prepare simple default bounds
		kernel_bounds = [(-0.3 * max_amp, 0.3 * max_amp)] * args.fit_white + [
			(-0.3 * max_amp, 0.3 * max_amp)
			for _ in gpmodel.kernel.get_parameter_names()
		]
		bounds = gpmodel.mean.get_parameter_bounds() + kernel_bounds
	else:
		gpname, cel_terms = setup_celerite_terms(kernls,
				fit_bias=args.fit_bias, fit_white=args.fit_white)
		gpmodel = celerite.GP(cel_terms, mean=model,
			fit_white_noise=args.fit_white,
			fit_mean=True)
		bounds = gpmodel.get_parameter_bounds()
	gpmodel.compute(no_ys_train, no_errs_train)
	logging.debug("gpmodel params: %s", gpmodel.get_parameter_dict())
	logging.debug("gpmodel bounds: %s", bounds)
	logging.debug("initial log likelihood: %s", gpmodel.log_likelihood(no_dens_train))
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)
	model_name = "_".join(gpmodel.mean.get_parameter_names()).replace(':', '')
	gpmodel_name = model_name + gpname
	logging.info("GP model name: %s", gpmodel_name)

	pre_opt = False
	if args.optimize > 0:
		def gpmodel_mean(x, *p):
			gpmodel.set_parameter_vector(p)
			return gpmodel.mean.get_value(x)

		def gpmodel_res(x, *p):
			gpmodel.set_parameter_vector(p)
			return (gpmodel.mean.get_value(x) - no_dens_train) / no_errs_train

		def lpost(p, y, gp):
			gp.set_parameter_vector(p)
			return gp.log_likelihood(y, quiet=True) + gp.log_prior()

[ignored issue] The variable lpost does not seem to be defined in case args.optimize > 0 is False. Are you sure this can never be the case?

		def nlpost(p, y, gp):
			lp = lpost(p, y, gp)
			return -lp if np.isfinite(lp) else 1e25

		def grad_nlpost(p, y, gp):
			gp.set_parameter_vector(p)
			grad_ll = gp.grad_log_likelihood(y)
			if isinstance(grad_ll, tuple):
				# celerite
				return -grad_ll[1]
			# george
			return -grad_ll

		jacobian = grad_nlpost if gpmodel.kernel.vector_size else None
		if args.optimize == 1:
			resop_gp = op.minimize(
				nlpost,
				gpmodel.get_parameter_vector(),
				args=(no_dens_train, gpmodel),
				bounds=bounds,
				# method="l-bfgs-b", options=dict(disp=True, maxcor=100, eps=1e-9, ftol=2e-15, gtol=1e-8))
				method="l-bfgs-b", jac=jacobian)
				# method="tnc", options=dict(disp=True, maxiter=500, xtol=1e-12))
				# method="nelder-mead", options=dict(disp=True, maxfev=100000, fatol=1.49012e-8, xatol=1.49012e-8))
				# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08))
		elif args.optimize == 2:
			resop_gp = op.differential_evolution(
				nlpost,
				bounds=bounds,
				args=(no_dens_train, gpmodel),
				popsize=2 * args.walkers, tol=0.01)
		elif args.optimize == 3:
			resop_bh = op.basinhopping(
				nlpost,
				gpmodel.get_parameter_vector(),
				niter=200,
				minimizer_kwargs=dict(
					args=(no_dens_train, gpmodel),
					bounds=bounds,
					# method="tnc"))
					# method="l-bfgs-b", options=dict(maxcor=100)))
					method="l-bfgs-b", jac=jacobian))
					# method="Nelder-Mead"))
					# method="BFGS"))
					# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08)))
			logging.debug("optimization result: %s", resop_bh)
			resop_gp = resop_bh.lowest_optimization_result
		elif args.optimize == 4:
			resop, cov_gp = op.curve_fit(
				gpmodel_mean,
				no_ys_train, no_dens_train, gpmodel.get_parameter_vector(),
				bounds=tuple(np.array(bounds).T),
				# method='lm',
				# absolute_sigma=True,
				sigma=no_errs_train)
			resop_gp = op.OptimizeResult(dict(
				x=resop,
				success=True,
				message="Curve fit successful.",
			))
			logging.debug("curve fit %s, std %s:", resop, np.sqrt(np.diag(cov_gp)))
		else:
			logging.warn("unsupported optimization method: %s", args.optimize)
			resop_gp = op.OptimizeResult(dict(
				x=gpmodel.get_parameter_vector(),
				success=False,
				message="unsupported optimization method: {0}".format(args.optimize),
			))
		logging.info("%s", resop_gp.message)
		logging.debug("optimization result: %s", resop_gp)
		logging.info("gpmodel dict: %s", gpmodel.get_parameter_dict())
		logging.info("log posterior trained: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.info("log posterior test: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.info("log posterior all: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		# cross check to make sure that the gpmodel parameter vector is really
		# set to the fitted parameters
		logging.info("opt. model vector: %s", resop_gp.x)
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 1: %s", lpost(resop_gp.x, no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 1: %s", lpost(resop_gp.x, no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 1: %s", lpost(resop_gp.x, no_dens, gpmodel))
		logging.debug("opt. model vector: %s", gpmodel.get_parameter_vector())
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		pre_opt = resop_gp.success
	try:
		logging.info("GM lt: %s", gpmodel.get_parameter("mean:GM:tau0"))
	except ValueError:
		pass
	logging.info("(GP) model: %s", gpmodel.kernel)
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)

	bestfit = gpmodel.get_parameter_vector()
	filename_base = path.join(
		args.output_path,
		"NO_regress_fit_{0}_{1:.0f}_{2:.0f}_{{0}}_{3}"
		.format(gpmodel_name, lat * 10, alt, ksub),
	)

	if args.mcmc:
		gpmodel.compute(no_ys_train, no_errs_train)
		samples, lnp = mcmc_sample_model(gpmodel,
				no_dens_train,
				beta=1.0,
				nwalkers=args.walkers, nburnin=args.burn_in,
				nprod=args.production, nthreads=args.threads,
				show_progress=args.progress,
				optimized=pre_opt, bounds=bounds, return_logpost=True)

		if args.train_fraction < 1. or args.test_fraction < 1.:
			logging.info("Statistics for the test samples")
			mcmc_statistics(gpmodel,
					no_ys_test, no_dens_test, no_errs_test,
					no_ys_train, no_dens_train, no_errs_train,
					samples, lnp,
			)
		logging.info("Statistics for all samples")
		mcmc_statistics(gpmodel,
				no_ys, no_dens, no_errs,
				no_ys_train, no_dens_train, no_errs_train,
				samples, lnp,
		)

		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
		if args.plot_corner:
			import corner
			# Corner plot of the sampled parameters
			fig = corner.corner(samples,
					quantiles=[0.025, 0.5, 0.975],
					show_titles=True,
					labels=gpmodel.get_parameter_names(),
					truths=bestfit,
					hist_args=dict(normed=True))
			fig.savefig(filename_base.format("corner") + ".pdf", transparent=True)

		if args.save_samples:
			if args.samples_format in ["npz"]:
				# save the samples compressed to save space.
				np.savez_compressed(filename_base.format("sampls") + ".npz",
					samples=samples)
			if args.samples_format in ["nc", "netcdf4"]:
				save_samples_netcdf(filename_base.format("sampls") + ".nc",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
			if args.samples_format in ["h5", "hdf5"]:
				save_samples_netcdf(filename_base.format("sampls") + ".h5",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
		# MCMC finished here

	# set the model times and errors to use the full data set for plotting
	gpmodel.compute(no_ys, no_errs)
	if args.save_model:
		try:
			# python 2
			import cPickle as pickle
		except ImportError:
			# python 3
			import pickle
		# pickle and save the model
		with open(filename_base.format("model") + ".pkl", "wb") as f:
			pickle.dump((gpmodel), f, -1)

[ignored issue] The variable samples does not seem to be defined in case args.mcmc is False. Are you sure this can never be the case?

	if args.plot_samples and args.mcmc:
		plot_random_samples(gpmodel, no_ys, no_dens, no_errs,
				samples, args.scale,
				filename_base.format("sampls") + ".pdf",
				size=4, extra_years=[4, 2])

[ignored issue] The variable sampl_percs does not seem to be defined in case args.mcmc is False. Are you sure this can never be the case?

	if args.plot_median:
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1],
				filename_base.format("median") + ".pdf")
	if args.plot_residuals:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1], args.scale,
				filename_base.format("medres") + ".pdf")

[ignored issue] The variable lnp does not seem to be defined in case args.mcmc is False. Are you sure this can never be the case?

	if args.plot_maxlnp:
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)],
				filename_base.format("maxlnp") + ".pdf")
	if args.plot_maxlnpres:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)], args.scale,
				filename_base.format("mlpres") + ".pdf")

	labels = gpmodel.get_parameter_names()
	logging.info("param percentiles [2.5, 50, 97.5]:")
	for pc, label in zip(sampl_percs.T, labels):
		median = pc[1]
		pc_minus = median - pc[0]
		pc_plus = pc[2] - median
		logging.debug("%s: %s", label, pc)
		logging.info("%s: %.6f (- %.6f) (+ %.6f)", label,
				median, pc_minus, pc_plus)

	logging.info("Finished successfully.")


if __name__ == "__main__":
	main()
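
The ignored issues attached to the plotting calls all flag the same pattern: samples, lnp and sampl_percs are only bound inside the "if args.mcmc:" block, yet the plot_median, plot_residuals, plot_maxlnp and plot_maxlnpres branches use them unconditionally, so a run without MCMC sampling but with one of those plot options would raise a NameError. A minimal, self-contained sketch of one way to close that gap, binding the names up front and failing early; run() and its arguments are illustrative stand-ins, not sciapy API:

import numpy as np

def run(mcmc, plot_median):
	# bind the MCMC products before the branch, so later code can
	# test for None instead of hitting a NameError
	samples, lnp, sampl_percs = None, None, None
	if mcmc:
		# stand-ins for mcmc_sample_model() and the percentile call
		samples = np.random.randn(100, 3)
		lnp = np.random.randn(100)
		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
	if plot_median:
		if sampl_percs is None:
			# fail early with a clear message instead of a NameError
			raise SystemExit("plotting the median requires MCMC sampling")
		print("median parameters:", sampl_percs[1])

run(mcmc=True, plot_median=True)
run(mcmc=False, plot_median=False)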