Passed
Push to master (9c78f3...07cced) by Stefan, 04:19, created

sciapy.regress.__main__._prepare_proxy_model() (rated C)

Complexity
Conditions: 10

Size
Total Lines: 38
Code Lines: 37

Duplication
Lines: 38
Ratio: 100 %

Code Coverage
Tests: 1
CRAP Score: 76.992

Importance
Changes: 0

Metric	Value
cc	10	(cyclomatic complexity)
eloc	37	(executable lines of code)
nop	8	(number of parameters)
dl	38	(duplicated lines)
loc	38	(lines of code)
ccs	1	(covered statements)
cts	8	(total statements)
cp	0.125	(statement coverage = ccs/cts)
crap	76.992	(CRAP score)
rs	5.9999
c	0
b	0
f	0
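The CRAP score above is consistent with the standard formula, which combines cyclomatic complexity (cc) with statement coverage (cp):

	CRAP = cc^2 * (1 - cp)^3 + cc = 10^2 * (1 - 0.125)^3 + 10 = 66.992 + 10 = 76.992

so the score is dominated by the low coverage: the single test exercises only 1 of the 8 statements (cp = 1/8 = 0.125).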

How to fix

Complexity

Complex code like sciapy.regress.__main__._prepare_proxy_model() often does a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for parameters, fields, or methods that share the same prefixes or suffixes.

Once you have determined the pieces that belong together, you can apply the Extract Class refactoring (or, for a free function like this one, Extract Method). If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
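Applied to this function, a minimal sketch of such an extraction could move just the bounds construction into a helper; the name _proxy_bounds is hypothetical, not part of sciapy:

	import numpy as np

	def _proxy_bounds(name, args, max_amp, max_days):
		# hypothetical helper: builds the parameter bounds that
		# _prepare_proxy_model() currently assembles inline
		fit_phase = args.fit_phase
		return dict([
			("amp",
				[0, max_amp] if name in args.positive_proxies.split(',')
				else [-max_amp, max_amp]),
			("lag", [0, max_days]),
			("tau0", [0, max_days]),
			("taucos1", [0, max_days] if fit_phase else [-max_days, max_days]),
			("tausin1", [-np.pi, np.pi] if fit_phase else [-max_days, max_days]),
			# semi-annual cycles for the lifetime
			("taucos2", [0, max_days] if fit_phase else [-max_days, max_days]),
			("tausin2", [-np.pi, np.pi] if fit_phase else [-max_days, max_days]),
			("ltscan", [0, 200])])

_prepare_proxy_model() would then pass bounds=_proxy_bounds(name, args, max_amp, max_days) to ProxyModel, removing about half of the ten conditions counted above.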

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more, or different, data.

There are several approaches to avoid long parameter lists:

- Introduce a parameter object and pass it instead of the individual values (Introduce Parameter Object), as sketched below.
- Pass the object the values already live on instead of unpacking it at the call site (Preserve Whole Object).
- Give rarely-varied parameters sensible defaults, as _prepare_proxy_model() already does for lag, sza_intp, max_amp, and max_days.
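As a sketch of the first approach, the four tuning parameters of _prepare_proxy_model() could travel in one container object; ProxyConfig and its field names are hypothetical, not part of sciapy:

	from dataclasses import dataclass

	@dataclass
	class ProxyConfig:
		# hypothetical parameter object bundling the keyword arguments
		# that _prepare_proxy_model() currently takes individually
		lag: float = 0.
		sza_intp: object = None  # solar-zenith-angle interpolator, or None
		max_amp: float = 1e10
		max_days: float = 100.

	def _prepare_proxy_model(name, times, values, args, config=None):
		config = config or ProxyConfig()
		# ...body unchanged, reading config.lag, config.sza_intp,
		# config.max_amp, and config.max_days...

This would drop the parameter count (nop) from 8 to 5 without touching call sites that relied on the defaults.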

# -*- coding: utf-8 -*-
# vim:fileencoding=utf-8
#
# Copyright (c) 2017-2018 Stefan Bender
#
# This module is part of sciapy.
# sciapy is free software: you can redistribute it or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, version 2.
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
"""SCIAMACHY data regression command line interface

Command line main program for regression analysis of SCIAMACHY
daily zonal mean time series (NO for now).
"""
import ctypes
import logging

import numpy as np
import scipy.optimize as op
from scipy.interpolate import interp1d

import george
import celerite

import matplotlib as mpl
# switch off X11 rendering
mpl.use("Agg")

from .load_data import load_solar_gm_table, load_scia_dzm
from .models_cel import CeleriteModelSet as NOModel
from .models_cel import ConstantModel, ProxyModel
from .models_cel import HarmonicModelCosineSine, HarmonicModelAmpPhase
from .mcmc import mcmc_sample_model
from .statistics import mcmc_statistics

from ._gpkernels import (george_solvers,
		setup_george_kernel, setup_celerite_terms)
from ._plot import (plot_single_sample_and_residuals,
		plot_residual, plot_random_samples)
from ._options import parser

# flagged: "This code seems to be duplicated in your project."
def save_samples_netcdf(filename, model, alt, lat, samples,
		scale=1e-6,
		lnpost=None, compressed=False):
	from xarray import Dataset
	smpl_ds = Dataset(dict([(pname, (["lat", "alt", "sample"],
				samples[..., i].reshape(1, 1, -1)))
				for i, pname in enumerate(model.get_parameter_names())]
				# + [("lpost", (["lat", "alt", "sample"], lnp.reshape(1, 1, -1)))]
			),
			coords={"lat": [lat], "alt": [alt]})

	for modn in model.mean.models:
		modl = model.mean.models[modn]
		if hasattr(modl, "mean"):
			smpl_ds.attrs[modn + ":mean"] = modl.mean

	units = {"kernel": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale))},
			"mean": {
				"log": "log(10$^{{{0:.0f}}}$ cm$^{{-3}}$)"
						.format(-np.log10(scale)),
				"cos": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"sin": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"val": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"amp": "10$^{{{0:.0f}}}$ cm$^{{-3}}$".format(-np.log10(scale)),
				"tau": "d"}}
	for pname in smpl_ds.data_vars:
		_pp = pname.split(':')
		for _n, _u in units[_pp[0]].items():
			if _pp[-1].startswith(_n):
				logging.debug("units for %s: %s", pname, _u)
				smpl_ds[pname].attrs["units"] = _u

	smpl_ds["alt"].attrs = {"long_name": "altitude", "units": "km"}
	smpl_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}

	_encoding = None
	if compressed:
		_encoding = {var: {"zlib": True, "complevel": 1}
					for var in smpl_ds.data_vars}
	smpl_ds.to_netcdf(filename, encoding=_encoding)
	smpl_ds.close()

# flagged: "This code seems to be duplicated in your project."
def _train_test_split(times, data, errs, train_frac,
		test_frac, randomize):
	# split the data into training and test subsets according to the
	# fraction given (default is 1, i.e. no splitting)
	ndata = len(times)
	train_size = int(ndata * train_frac)
	test_size = min(ndata - train_size, int(ndata * test_frac))
	# randomize if requested
	if randomize:
		permut_idx = np.random.permutation(np.arange(ndata))
	else:
		permut_idx = np.arange(ndata)
	train_idx = np.sort(permut_idx[:train_size])
	test_idx = np.sort(permut_idx[train_size:train_size + test_size])
	times_train = times[train_idx]
	data_train = data[train_idx]
	errs_train = errs[train_idx]
	if test_size > 0:
		times_test = times[test_idx]
		data_test = data[test_idx]
		errs_test = errs[test_idx]
	else:
		times_test = times
		data_test = data
		errs_test = errs
	logging.info("using %s of %s samples for training.", len(times_train), ndata)
	logging.info("using %s of %s samples for testing.", len(times_test), ndata)
	return (times_train, data_train, errs_train,
			times_test, data_test, errs_test)

# flagged: "This code seems to be duplicated in your project."
def _r_sun_earth(time, tfmt="jyear"):
	"""First order approximation of the Sun-Earth distance

	The Sun-to-Earth distance can be used to (un-)normalize proxies
	to the actual distance to the Sun instead of 1 AU.

	Parameters
	----------
	time : float
		Time value in the units given by 'tfmt'.
	tfmt : str, optional
		The units of 'time' as supported by the
		astropy.time time formats. Default: 'jyear'.

	Returns
	-------
	dist : float
		The Sun-Earth distance at the given day of year in AU.
	"""
	from astropy.time import Time
	tdoy = Time(time, format=tfmt)
	tdoy.format = "yday"
	doy = int(tdoy.value.split(':')[1])
	return 1 - 0.01672 * np.cos(2 * np.pi / 365.256363 * (doy - 4))

# flagged: "This code seems to be duplicated in your project."
def _prepare_proxy_model(name, times, values,
		args,
		lag=0,
		sza_intp=None,
		max_amp=1e10, max_days=100):
	pv = np.log(values) if name in args.log_proxies.split(',') else values
	if name in args.norm_proxies_distSEsq:
		rad_sun_earth = np.vectorize(_r_sun_earth)(times, tfmt=args.time_format)
		pv /= rad_sun_earth**2
	if name in args.norm_proxies_SZA:
		pv *= np.cos(np.radians(sza_intp(times)))
	return (
		name,
		ProxyModel(times, pv,
			center=name in args.center_proxies.split(','),
			sza_intp=sza_intp,
			fit_phase=args.fit_phase,
			lifetime_prior=args.lifetime_prior,
			lifetime_metric=args.lifetime_metric,
			days_per_time_unit=1 if args.time_format.endswith("d") else 365.25,
			amp=0.,
			lag=lag,
			tau0=0,
			taucos1=0, tausin1=0,
			taucos2=0, tausin2=0,
			ltscan=args.lifetime_scan,
			bounds=dict([
				("amp",
					[0, max_amp] if name in args.positive_proxies.split(',')
					else [-max_amp, max_amp]),
				("lag", [0, max_days]),
				("tau0", [0, max_days]),
				("taucos1", [0, max_days] if args.fit_phase else [-max_days, max_days]),
				("tausin1", [-np.pi, np.pi] if args.fit_phase else [-max_days, max_days]),
				# semi-annual cycles for the life time
				("taucos2", [0, max_days] if args.fit_phase else [-max_days, max_days]),
				("tausin2", [-np.pi, np.pi] if args.fit_phase else [-max_days, max_days]),
				("ltscan", [0, 200])])
		))

# flagged: "This code seems to be duplicated in your project."
def main():
	logging.basicConfig(level=logging.WARNING,
			format="[%(levelname)-8s] (%(asctime)s) "
				"%(filename)s:%(lineno)d %(message)s",
			datefmt="%Y-%m-%d %H:%M:%S %z")

	args = parser.parse_args()

	logging.info("command line arguments: %s", args)
	if args.quiet:
		logging.getLogger().setLevel(logging.ERROR)
	elif args.verbose:
		logging.getLogger().setLevel(logging.INFO)
	else:
		logging.getLogger().setLevel(args.loglevel)

	from numpy.distutils.system_info import get_info
	for oblas_path in get_info("openblas")["library_dirs"]:
		oblas_name = "{0}/libopenblas.so".format(oblas_path)
		logging.info("Trying %s", oblas_name)
		try:
			oblas_lib = ctypes.cdll.LoadLibrary(oblas_name)
			oblas_cores = oblas_lib.openblas_get_num_threads()
			oblas_lib.openblas_set_num_threads(args.openblas_threads)
			logging.info("Using %s/%s Openblas thread(s).",
					oblas_lib.openblas_get_num_threads(), oblas_cores)
		except Exception:
			# nb: a bare "except:" here would also swallow KeyboardInterrupt
			logging.info("Setting number of openblas threads failed.")

	if args.random_seed is not None:
		np.random.seed(args.random_seed)

	if args.proxies:
		proxies = args.proxies.split(',')
		proxy_dict = dict(_p.split(':') for _p in proxies)
	else:
		proxy_dict = {}
	lag_dict = {pn: 0 for pn in proxy_dict.keys()}

	# Post-processing of arguments...
	# List of proxy lag fits from csv
	fit_lags = args.fit_lags.split(',')
	# List of proxy lifetime fits from csv
	fit_lifetimes = args.fit_lifetimes.split(',')
	fit_annlifetimes = args.fit_annlifetimes.split(',')
	# List of proxy lag times from csv
	lag_dict.update(dict(_ls.split(':') for _ls in args.lag_times.split(',')))
	# List of cycles (frequencies in 1/year) from argument list (csv)
	try:
		freqs = list(map(float, args.freqs.split(',')))
	except ValueError:
		freqs = []
	# List of initial parameter values
	initial = None
	if args.initial is not None:
		try:
			initial = list(map(float, args.initial.split(',')))
		except ValueError:
			pass
	# List of GP kernels from argument list (csv)
	kernls = args.kernels.split(',')

	lat = args.latitude
	alt = args.altitude
	logging.info("location: %.0f°N %.0f km", lat, alt)

	no_ys, no_dens, no_errs, no_szas = load_scia_dzm(args.file, alt, lat,
			tfmt=args.time_format,
			scale=args.scale,
			#subsample_factor=args.random_subsample,
			#subsample_method="random",
			akd_threshold=args.akd_threshold,
			cnt_threshold=args.cnt_threshold,
			center=args.center_data,
			season=args.season,
			SPEs=args.exclude_spe)

	(no_ys_train, no_dens_train, no_errs_train,
		no_ys_test, no_dens_test, no_errs_test) = _train_test_split(
				no_ys, no_dens, no_errs, args.train_fraction,
				args.test_fraction, args.random_train_test)

	sza_intp = interp1d(no_ys, no_szas, fill_value="extrapolate")

	max_amp = 1e10 * args.scale
	max_days = 100

	harmonic_models = []
	for freq in freqs:
		if not args.fit_phase:
			harmonic_models.append(("f{0:.0f}".format(freq),
				HarmonicModelCosineSine(freq=freq,
					cos=0, sin=0,
					bounds=dict([
						("cos", [-max_amp, max_amp]),
						("sin", [-max_amp, max_amp])])
			)))
		else:
			harmonic_models.append(("f{0:.0f}".format(freq),
				HarmonicModelAmpPhase(freq=freq,
					amp=0, phase=0,
					bounds=dict([
						# ("amp", [-max_amp, max_amp]),
						("amp", [0, max_amp]),
						("phase", [-np.pi, np.pi])])
			)))
	proxy_models = []
	for pn, pf in proxy_dict.items():
		pt, pp = load_solar_gm_table(pf, cols=[0, 1], names=["time", pn], tfmt=args.time_format)
		proxy_models.append(
			_prepare_proxy_model(pn, pt, pp[pn], args,
				lag=float(lag_dict[pn]),
				sza_intp=sza_intp if args.use_sza else None,
				max_amp=max_amp, max_days=max_days,
			)
		)
		logging.info("%s mean: %s", pn, proxy_models[-1][1].mean)
	offset_model = [("offset",
			ConstantModel(value=0.,
					bounds={"value": [-max_amp, max_amp]}))]

	model = NOModel(offset_model + harmonic_models + proxy_models)

	logging.debug("model dict: %s", model.get_parameter_dict())
	model.freeze_all_parameters()
	# thaw parameters according to requested fits
	for pn in proxy_dict.keys():
		model.thaw_parameter("{0}:amp".format(pn))
		if pn in fit_lags:
			model.thaw_parameter("{0}:lag".format(pn))
		if pn in fit_lifetimes:
			model.set_parameter("{0}:tau0".format(pn), 1e-3)
			model.thaw_parameter("{0}:tau0".format(pn))
			if pn in fit_annlifetimes:
				model.thaw_parameter("{0}:taucos1".format(pn))
				model.thaw_parameter("{0}:tausin1".format(pn))
	for freq in freqs:
		if not args.fit_phase:
			model.thaw_parameter("f{0:.0f}:cos".format(freq))
			model.thaw_parameter("f{0:.0f}:sin".format(freq))
		else:
			model.thaw_parameter("f{0:.0f}:amp".format(freq))
			model.thaw_parameter("f{0:.0f}:phase".format(freq))
	if args.fit_offset:
		#model.set_parameter("offset:value", -100.)
		#model.set_parameter("offset:value", 0)
		model.thaw_parameter("offset:value")

	if initial is not None:
		model.set_parameter_vector(initial)
	# model.thaw_parameter("GM:ltscan")
	logging.debug("params: %s", model.get_parameter_dict())
	logging.debug("param names: %s", model.get_parameter_names())
	logging.debug("param vector: %s", model.get_parameter_vector())
	logging.debug("param bounds: %s", model.get_parameter_bounds())
	#logging.debug("model value: %s", model.get_value(no_ys))
	#logging.debug("default log likelihood: %s", model.log_likelihood(model.vector))

	# setup the Gaussian Process kernel
	kernel_base = (1e7 * args.scale)**2
	ksub = args.name_suffix

	solver = "basic"
	skwargs = {}
	if args.HODLR_Solver:
		solver = "HODLR"
		#skwargs = {"tol": 1e-3}

	if args.george:
		gpname, kernel = setup_george_kernel(kernls,
				kernel_base=kernel_base, fit_bias=args.fit_bias)
		gpmodel = george.GP(kernel, mean=model,
			white_noise=1.e-25, fit_white_noise=args.fit_white,
			solver=george_solvers[solver], **skwargs)
		# the george interface does not allow setting the bounds in
		# the kernel initialization so we prepare simple default bounds
		kernel_bounds = [(-0.3 * max_amp, 0.3 * max_amp)
				for _ in gpmodel.kernel.get_parameter_names()]
		bounds = gpmodel.mean.get_parameter_bounds() + kernel_bounds
	else:
		gpname, cel_terms = setup_celerite_terms(kernls,
				fit_bias=args.fit_bias, fit_white=args.fit_white)
		gpmodel = celerite.GP(cel_terms, mean=model,
			fit_white_noise=args.fit_white,
			fit_mean=True)
		bounds = gpmodel.get_parameter_bounds()
	gpmodel.compute(no_ys_train, no_errs_train)
	logging.debug("gpmodel params: %s", gpmodel.get_parameter_dict())
	logging.debug("gpmodel bounds: %s", bounds)
	logging.debug("initial log likelihood: %s", gpmodel.log_likelihood(no_dens_train))
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)
	model_name = "_".join(gpmodel.mean.get_parameter_names()).replace(':', '')
	gpmodel_name = model_name + gpname
	logging.info("GP model name: %s", gpmodel_name)

	pre_opt = False
	if args.optimize > 0:
		def gpmodel_mean(x, *p):
			gpmodel.set_parameter_vector(p)
			return gpmodel.mean.get_value(x)

		def gpmodel_res(x, *p):
			gpmodel.set_parameter_vector(p)
			return (gpmodel.mean.get_value(x) - no_dens_train) / no_errs_train

		def lpost(p, y, gp):
			gp.set_parameter_vector(p)
			return gp.log_likelihood(y, quiet=True) + gp.log_prior()

		def nlpost(p, y, gp):
			# flagged: "The variable lpost does not seem to be defined in case
			# args.optimize > 0 is False. Are you sure this can never be the
			# case?" (nlpost itself only exists inside the same branch)
			lp = lpost(p, y, gp)
			return -lp if np.isfinite(lp) else 1e25

		def grad_nlpost(p, y, gp):
			gp.set_parameter_vector(p)
			grad_ll = gp.grad_log_likelihood(y)
			if isinstance(grad_ll, tuple):
				# celerite
				return -grad_ll[1]
			# george
			return -grad_ll

		if args.optimize == 1:
			resop_gp = op.minimize(
				nlpost,
				gpmodel.get_parameter_vector(),
				args=(no_dens_train, gpmodel),
				bounds=bounds,
				# method="l-bfgs-b", options=dict(disp=True, maxcor=100, eps=1e-9, ftol=2e-15, gtol=1e-8))
				method="l-bfgs-b", jac=grad_nlpost)
				# method="tnc", options=dict(disp=True, maxiter=500, xtol=1e-12))
				# method="nelder-mead", options=dict(disp=True, maxfev=100000, fatol=1.49012e-8, xatol=1.49012e-8))
				# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08))
		if args.optimize == 2:
			resop_gp = op.differential_evolution(
				nlpost,
				bounds=bounds,
				args=(no_dens_train, gpmodel),
				popsize=2 * args.walkers, tol=0.01)
		if args.optimize == 3:
			resop_bh = op.basinhopping(
				nlpost,
				gpmodel.get_parameter_vector(),
				niter=200,
				minimizer_kwargs=dict(
					args=(no_dens_train, gpmodel),
					bounds=bounds,
					# method="tnc"))
					# method="l-bfgs-b", options=dict(maxcor=100)))
					method="l-bfgs-b", jac=grad_nlpost))
					# method="Nelder-Mead"))
					# method="BFGS"))
					# method="Powell", options=dict(ftol=1.49012e-08, xtol=1.49012e-08)))
			logging.debug("optimization result: %s", resop_bh)
			resop_gp = resop_bh.lowest_optimization_result
		if args.optimize == 4:
			resop_gp, cov_gp = op.curve_fit(
				gpmodel_mean,
				no_ys_train, no_dens_train, gpmodel.get_parameter_vector(),
				bounds=tuple(np.array(bounds).T),
				# method='lm',
				# absolute_sigma=True,
				sigma=no_errs_train)
			print(resop_gp, np.sqrt(np.diag(cov_gp)))
		logging.info("%s", resop_gp.message)
0 ignored issues
show
introduced by
The variable resop_gp does not seem to be defined in case args.optimize == 1 on line 411 is False. Are you sure this can never be the case?
Loading history...
		logging.debug("optimization result: %s", resop_gp)
		logging.info("gpmodel dict: %s", gpmodel.get_parameter_dict())
		logging.info("log posterior trained: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.info("log posterior test: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.info("log posterior all: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		# cross check to make sure that the gpmodel parameter vector is really
		# set to the fitted parameters
		logging.info("opt. model vector: %s", resop_gp.x)
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 1: %s", lpost(resop_gp.x, no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 1: %s", lpost(resop_gp.x, no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 1: %s", lpost(resop_gp.x, no_dens, gpmodel))
		logging.debug("opt. model vector: %s", gpmodel.get_parameter_vector())
		gpmodel.compute(no_ys_train, no_errs_train)
		logging.debug("opt. log posterior trained 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_train, gpmodel))
		gpmodel.compute(no_ys_test, no_errs_test)
		logging.debug("opt. log posterior test 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens_test, gpmodel))
		gpmodel.compute(no_ys, no_errs)
		logging.debug("opt. log posterior all 2: %s", lpost(gpmodel.get_parameter_vector(), no_dens, gpmodel))
		pre_opt = resop_gp.success
	try:
		logging.info("GM lt: %s", gpmodel.get_parameter("mean:GM:tau0"))
	except ValueError:
		pass
	logging.info("(GP) model: %s", gpmodel.kernel)
	if isinstance(gpmodel, celerite.GP):
		logging.info("(GP) jitter: %s", gpmodel.kernel.jitter)

	bestfit = gpmodel.get_parameter_vector()
	filename_base = ("NO_regress_fit_{0}_{1:.0f}_{2:.0f}_{{0}}_{3}"
					.format(gpmodel_name, lat * 10, alt, ksub))

	if args.mcmc:
		gpmodel.compute(no_ys_train, no_errs_train)
		samples, lnp = mcmc_sample_model(gpmodel,
				no_dens_train,
				beta=1.0,
				nwalkers=args.walkers, nburnin=args.burn_in,
				nprod=args.production, nthreads=args.threads,
				show_progress=args.progress,
				optimized=pre_opt, bounds=bounds, return_logpost=True)

		if args.train_fraction < 1. or args.test_fraction < 1.:
			logging.info("Statistics for the test samples")
			mcmc_statistics(gpmodel,
					no_ys_test, no_dens_test, no_errs_test,
					no_ys_train, no_dens_train, no_errs_train,
					samples, lnp,
			)
		logging.info("Statistics for all samples")
		mcmc_statistics(gpmodel,
				no_ys, no_dens, no_errs,
				no_ys_train, no_dens_train, no_errs_train,
				samples, lnp,
		)

		sampl_percs = np.percentile(samples, [2.5, 50, 97.5], axis=0)
		if args.plot_corner:
			import corner
			# Corner plot of the sampled parameters
			fig = corner.corner(samples,
					quantiles=[0.025, 0.5, 0.975],
					show_titles=True,
					labels=gpmodel.get_parameter_names(),
					truths=bestfit,
					hist_args=dict(normed=True))
			fig.savefig(filename_base.format("corner") + ".pdf", transparent=True)

		if args.save_samples:
			if args.samples_format in ["npz"]:
				# save the samples compressed to save space.
				np.savez_compressed(filename_base.format("sampls") + ".npz",
					samples=samples)
			if args.samples_format in ["nc", "netcdf4"]:
				save_samples_netcdf(filename_base.format("sampls") + ".nc",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
			if args.samples_format in ["h5", "hdf5"]:
				save_samples_netcdf(filename_base.format("sampls") + ".h5",
					gpmodel, alt, lat, samples, scale=args.scale, compressed=True)
		# MCMC finished here

	# set the model times and errors to use the full data set for plotting
	gpmodel.compute(no_ys, no_errs)
	if args.save_model:
		try:
			# python 2
			import cPickle as pickle
		except ImportError:
			# python 3
			import pickle
		# pickle and save the model
		with open(filename_base.format("model") + ".pkl", "wb") as f:
			pickle.dump((gpmodel), f, -1)

	if args.plot_samples and args.mcmc:
		# flagged: "The variable samples does not seem to be defined in case
		# args.mcmc is False." (guarded here by the "and args.mcmc" check)
		plot_random_samples(gpmodel, no_ys, no_dens, no_errs,
				samples, args.scale,
				filename_base.format("sampls") + ".pdf",
				size=4, extra_years=[4, 2])

	if args.plot_median:
		# flagged: "The variable sampl_percs does not seem to be defined in
		# case args.mcmc is False."
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1],
				filename_base.format("median") + ".pdf")
	if args.plot_residuals:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				sampl_percs[1], args.scale,
				filename_base.format("medres") + ".pdf")
	if args.plot_maxlnp:
		# flagged: "The variable lnp does not seem to be defined in case
		# args.mcmc is False."
		plot_single_sample_and_residuals(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)],
				filename_base.format("maxlnp") + ".pdf")
	if args.plot_maxlnpres:
		plot_residual(gpmodel, no_ys, no_dens, no_errs,
				samples[np.argmax(lnp)], args.scale,
				filename_base.format("mlpres") + ".pdf")

	labels = gpmodel.get_parameter_names()
	logging.info("param percentiles [2.5, 50, 97.5]:")
	for pc, label in zip(sampl_percs.T, labels):
		median = pc[1]
		pc_minus = median - pc[0]
		pc_plus = pc[2] - median
		logging.debug("%s: %s", label, pc)
		logging.info("%s: %.6f (- %.6f) (+ %.6f)", label,
				median, pc_minus, pc_plus)

	logging.info("Finished successfully.")


if __name__ == "__main__":
	main()
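Since the code lives in sciapy.regress.__main__ and guards main() with the usual __name__ check, the regression CLI runs as a package entry point. A minimal sketch of an invocation (the actual options are defined in ._options.parser and are not part of this report):

	python -m sciapy.regress --help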