Passed
Push — master ( 33eda6...e9fe96 )
by Stefan
06:32
created

sciapy.regress.models_theano.ProxyModel._lt_corr()   A

Complexity

Conditions 2

Size

Total Lines 16
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 10
nop 3
dl 0
loc 16
rs 9.9
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# vim:fileencoding=utf-8
3
#
4
# Copyright (c) 2022 Stefan Bender
5
#
6
# This module is part of sciapy.
7
# sciapy is free software: you can redistribute it or modify
8
# it under the terms of the GNU General Public License as published
9
# by the Free Software Foundation, version 2.
10
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
11
"""SCIAMACHY regression models (theano/pymc3 version)
12
13
Model classes for SCIAMACHY data regression fits using
14
:mod:`theano` for :mod:`pymc3`.
15
16
This interface is still experimental.
17
"""
18
from __future__ import absolute_import, division, print_function
19
from warnings import warn
20
21
import numpy as np
22
23
try:
24
	import aesara_theano_fallback.tensor as tt
25
except ImportError as err:
26
	raise ImportError(
27
		"The `aesara_theano_fallback` package is required for the `theano` model interface."
28
	).with_traceback(err.__traceback__)
29
try:
30
	import pymc3 as pm
31
except ImportError as err:
32
	raise ImportError(
33
		"The `pymc3` package is required for the `theano` model interface."
34
	).with_traceback(err.__traceback__)
35
36
__all__ = [
37
	"HarmonicModelCosineSine", "HarmonicModelAmpPhase",
38
	"LifetimeModel",
39
	"ProxyModel",
40
	"ModelSet",
41
	"setup_proxy_model_theano",
42
	"trace_gas_modelset",
43
]
44
45
46
class HarmonicModelCosineSine:
	"""Model for harmonic terms

	Models a harmonic term as the sum of a cosine and a sine part,
	``cos * cos(omega * t) + sin * sin(omega * t)`` with
	``omega = 2 * pi * freq``. The total amplitude and phase of the
	equivalent single sine wave are available via :meth:`get_amplitude`
	and :meth:`get_phase`.

	Parameters
	----------
	freq : float
		The frequency, in inverse units of the time axis `t`
		(years^-1 for times in fractional years).
	cos : float or tensor
		The amplitude of the cosine part
	sin : float or tensor
		The amplitude of the sine part
	"""
	def __init__(self, freq, cos, sin):
		def _f64(x):
			# all parameters are stored as float64 tensors
			return tt.as_tensor_variable(x).astype("float64")

		self.omega = _f64(2 * np.pi * freq)
		self.cos = _f64(cos)
		self.sin = _f64(sin)

	def get_value(self, t):
		"""Harmonic term evaluated at the times `t`."""
		wt = self.omega * tt.as_tensor_variable(t).astype("float64")
		cos_part = self.cos * tt.cos(wt)
		sin_part = self.sin * tt.sin(wt)
		return cos_part + sin_part

	def get_amplitude(self):
		"""Total amplitude ``sqrt(cos**2 + sin**2)``."""
		squared_sum = self.cos**2 + self.sin**2
		return tt.sqrt(squared_sum)

	def get_phase(self):
		"""Phase `phi` such that the term equals ``amp * sin(omega * t + phi)``."""
		return tt.arctan2(self.cos, self.sin)

	def compute_gradient(self, t):
		"""Partial derivatives with respect to (freq, cos, sin) at times `t`."""
		t = tt.as_tensor_variable(t).astype("float64")
		wt = self.omega * t
		d_cos = tt.cos(wt)
		d_sin = tt.sin(wt)
		# d/dfreq of cos*cos(2 pi freq t) + sin*sin(2 pi freq t)
		d_freq = 2 * np.pi * t * (self.sin * d_cos - self.cos * d_sin)
		return d_freq, d_cos, d_sin
85
86
87
class HarmonicModelAmpPhase:
	"""Model for harmonic terms

	Models a harmonic term by amplitude and phase of a sine,
	``amp * sin(omega * t + phase)`` with ``omega = 2 * pi * freq``.

	Parameters
	----------
	freq : float
		The frequency, in inverse units of the time axis `t`
		(years^-1 for times in fractional years).
	amp : float or tensor
		The amplitude of the harmonic term
	phase : float or tensor
		The phase of the harmonic part
	"""
	def __init__(self, freq, amp, phase):
		def _f64(x):
			# all parameters are stored as float64 tensors
			return tt.as_tensor_variable(x).astype("float64")

		self.omega = _f64(2 * np.pi * freq)
		self.amp = _f64(amp)
		self.phase = _f64(phase)

	def get_value(self, t):
		"""Harmonic term evaluated at the times `t`."""
		arg = self.omega * tt.as_tensor_variable(t).astype("float64") + self.phase
		return self.amp * tt.sin(arg)

	def get_amplitude(self):
		"""The amplitude parameter itself."""
		return self.amp

	def get_phase(self):
		"""The phase parameter itself."""
		return self.phase

	def compute_gradient(self, t):
		"""Partial derivatives with respect to (freq, amp, phase) at times `t`."""
		t = tt.as_tensor_variable(t).astype("float64")
		arg = self.omega * t + self.phase
		d_amp = tt.sin(arg)
		d_phase = self.amp * tt.cos(arg)
		# chain rule: d(omega * t)/dfreq = 2 * pi * t
		d_freq = 2 * np.pi * t * d_phase
		return d_freq, d_amp, d_phase
122
123
124
class LifetimeModel:
	"""Model for variable lifetime

	Sums one or more harmonic terms and limits the result from below,
	i.e. returns ``max(lower, sum_i h_i.get_value(t))``.

	Parameters
	----------
	harmonics : HarmonicModelCosineSine or HarmonicModelAmpPhase or list
		A single harmonic model or a sequence of them, each providing a
		``get_value(t)`` method.
	lower : float, optional
		The lower limit of the returned values, default 0.
	"""
	def __init__(self, harmonics, lower=0.):
		# Sequences (list, tuple) support indexing via `__getitem__`;
		# a single model does not and is wrapped so it can be iterated.
		if not hasattr(harmonics, "__getitem__"):
			harmonics = [harmonics]
		self.harmonics = harmonics
		self.lower = lower

	def get_value(self, t):
		"""Summed harmonic terms at times `t`, limited from below by `lower`."""
		# accept plain sequences as well, consistent with the harmonic models
		t = tt.as_tensor_variable(t)
		tau_cs = tt.zeros(t.shape[:-1], dtype="float64")
		for h in self.harmonics:
			tau_cs += h.get_value(t)
		return tt.maximum(self.lower, tau_cs)
144
145
146
def _interp(x, xs, ys, fill_value=0.):
	"""Symbolic linear interpolation

	Linearly interpolates the values `ys` given at the points `xs`
	to the points `x`.

	Parameters
	----------
	x : tensor
		The points at which to interpolate.
	xs : tensor
		The sample points, sorted in ascending order (required by
		``searchsorted``).
	ys : tensor
		The values at `xs`.
	fill_value : float, optional
		The value returned for `x` outside of ``[xs[0], xs[-1]]``,
		default 0.

	Returns
	-------
	tensor
		The interpolated values (float64).
	"""
	n = xs.shape[0]
	idx = xs.searchsorted(x)
	# `searchsorted` uses the left side, so `x == xs[0]` gives idx == 0.
	# Compare with the first sample directly to keep that point in range;
	# `idx >= n` flags points beyond the last sample.
	out_of_bounds = tt.zeros(x.shape[:-1], dtype=bool)
	out_of_bounds |= (x < xs[0]) | (idx >= n)
	# clip so that xs[idx - 1] and xs[idx] are valid neighbours
	idx = tt.clip(idx, 1, n - 1)
	dl = x - xs[idx - 1]
	dr = xs[idx] - x
	d = dl + dr
	# weight of the left neighbour
	wl = dr / d
	ret = tt.zeros(x.shape[:-1], dtype="float64")
	ret += wl * ys[idx - 1] + (1 - wl) * ys[idx]
	ret = tt.switch(out_of_bounds, fill_value, ret)
	return ret
159
160
161
class ProxyModel:
	"""Model for proxy terms

	Models proxy terms with a finite and (semi-)annually varying life time.

	Parameters
	----------
	ptimes : (N,) array_like
		The data times of the proxy values
	pvalues : (N,) array_like
		The proxy values at `ptimes`
	amp : float
		The amplitude of the proxy term
	lag : float, optional
		The lag of the proxy value in days (converted to time units
		using `days_per_time_unit`).
	tau0 : float, optional
		The base life time of the proxy in days.
	tau_harm : LifetimeModel, optional
		The lifetime harmonic model with a lower limit.
	tau_scan : float, optional
		The number of days to sum the previous proxy values.
		No lifetime adjustments are calculated when set to zero.
		Negative values are not supported by this interface and
		are treated like zero.
	days_per_time_unit : float, optional
		The number of days per time unit, used to normalize the lifetime
		units. Use 365.25 if the times are in fractional years, or 1 if
		they are in days.
		Default: 365.25
	"""
	def __init__(
		self, ptimes, pvalues, amp,
		lag=0.,
		tau0=0.,
		tau_harm=None,
		tau_scan=0,
		days_per_time_unit=365.25,
	):
		# data
		self.times = tt.as_tensor_variable(ptimes).astype("float64")
		self.values = tt.as_tensor_variable(pvalues).astype("float64")
		# parameters
		self.amp = tt.as_tensor_variable(amp).astype("float64")
		self.days_per_time_unit = tt.as_tensor_variable(days_per_time_unit).astype("float64")
		# `lag` is given in days, store it in time units
		self.lag = tt.as_tensor_variable(lag / days_per_time_unit).astype("float64")
		self.tau0 = tt.as_tensor_variable(tau0).astype("float64")
		self.tau_harm = tau_harm
		self.tau_scan = tau_scan
		# the lifetime sum uses one-day steps, `bs` are the look-back
		# offsets 1, 2, ..., tau_scan days converted to time units
		dt = 1.0
		bs = np.arange(dt, tau_scan + dt, dt) / days_per_time_unit
		self.bs = tt.as_tensor_variable(bs).astype("float64")
		self.dt = tt.as_tensor_variable(dt).astype("float64")
		# `self.days_per_time_unit` and `self.times` are symbolic, so the
		# time-format decision is made on the raw (numeric) inputs.
		self.t_adj = tt.as_tensor_variable(
			self._time_adjustment(ptimes, days_per_time_unit)
		).astype("float64")

	@staticmethod
	def _time_adjustment(ptimes, days_per_time_unit):
		"""Time offset for the seasonal lifetime variation

		Makes "(m)jd" and "jyear" compatible for the lifetime seasonal
		variation. The julian epoch (the default) is slightly offset
		with respect to (modified) julian days.

		Returns 0. unless `days_per_time_unit` is 1 (times in days);
		then 13. for julian days and -44.25 for modified julian days.
		"""
		if days_per_time_unit != 1:
			return 0.
		try:
			pt = np.ravel(np.asarray(ptimes, dtype=float))
		except (TypeError, ValueError):
			# not numeric (e.g. symbolic) times, cannot decide
			return 0.
		if pt.size == 0:
			return 0.
		# discriminate between julian days and modified julian days,
		# 1.8e6 is year 216 in julian days and year 6787 in
		# modified julian days. It should be pretty safe to judge on
		# that for most use cases.
		if pt[0] > 1.8e6:
			# julian days
			return 13.
		# modified julian days
		return -44.25

	def _lt_corr(self, t, tau):
		"""Lifetime corrected values

		Corrects for a finite lifetime by summing over the last `tau_scan`
		days with an exponential decay given by the lifetime(s) `tau`
		(in days)::

			dt * sum_{k=1}^{tau_scan} exp(-k dt / tau) p(t - lag - k days)

		where ``p`` is the linearly interpolated proxy and ``dt`` one day.
		"""
		yp = tt.zeros(t.shape[:-1], dtype="float64")
		tauexp = tt.exp(-self.dt / tau)
		taufac = tt.ones(tau.shape[:-1], dtype="float64")
		for b in self.bs:
			# running product: exp(-k * dt / tau) at the k-th step
			taufac *= tauexp
			yp += taufac * _interp(
				t - self.lag - b,
				self.times, self.values,
			)
		return yp * self.dt

	def get_value(self, t):
		"""Proxy term at times `t`

		Returns ``amp * (p(t - lag) + lifetime correction)``, where the
		correction is only added for a positive `tau_scan`.
		"""
		t = tt.as_tensor_variable(t)
		proxy_val = _interp(
			t - self.lag,
			self.times, self.values,
		)
		if self.tau_scan <= 0:
			# no (supported) lifetime, nothing else to do
			return self.amp * proxy_val
		tau = self.tau0
		if self.tau_harm is not None:
			tau_cs = self.tau_harm.get_value(t + self.t_adj)
			tau += tau_cs
		proxy_val += self._lt_corr(t, tau)
		return self.amp * proxy_val
261
262
263
class ModelSet:
	"""Sum of several models

	Parameters
	----------
	models : iterable
		The models, each providing a ``get_value(t)`` method.
	"""
	def __init__(self, models):
		self.models = models

	def get_value(self, t):
		"""Sum of all model values at times `t` (starting from float64 zeros)."""
		start = tt.zeros(t.shape[:-1], dtype="float64")
		return sum((model.get_value(t) for model in self.models), start)
272
273
274
def setup_proxy_model_theano(
	model, name,
	times, values,
	max_amp=1e10, max_days=100,
	**kwargs
):
	"""Set up a :class:`ProxyModel` with :mod:`pymc3` priors

	Creates the random variables for the proxy parameters inside `model`,
	named after `name`, and returns the corresponding :class:`ProxyModel`.

	Parameters
	----------
	model : pymc3.Model
		The model in which the random variables are created.
	name : str
		The proxy name, used in the random variable names.
	times : (N,) array_like
		The data times of the proxy values.
	values : (N,) array_like
		The proxy values at `times`.
	max_amp : float, optional
		Standard deviation of the amplitude priors; for positive
		amplitudes ``log(max_amp)`` is the standard deviation of the
		log-amplitude prior. Default: 1e10
	max_days : float, optional
		``log(max_days)`` is the standard deviation of the log-lag and
		log-lifetime priors. Default: 100
	**kwargs : optional
		* fit_lag : bool -- fit the lag (in days), default False.
		* lag : float -- the fixed lag in days if not fitted, default 0.
		* lifetime_scan : int -- number of days to sum for the lifetime,
		  no lifetime is fitted if not positive, default 60.
		* positive : bool -- restrict the amplitude to positive values,
		  default False.
		* time_format : str -- formats ending in "d" are treated as
		  days, all others as fractional years, default "jyear".

	Returns
	-------
	proxy : ProxyModel
	"""
	warn(
		"This method to set up the `theano`/`pymc3` interface is experimental, "
		"and the interface will most likely change in future versions. "
		"It is recommended to use the `ProxyModel` class instead."
	)
	# extract setup from `kwargs`
	fit_lag = kwargs.get("fit_lag", False)
	lag = kwargs.get("lag", 0.)
	lifetime_scan = kwargs.get("lifetime_scan", 60)
	positive = kwargs.get("positive", False)
	time_format = kwargs.get("time_format", "jyear")
	days_per_time_unit = 1 if time_format.endswith("d") else 365.25

	with model:
		if positive:
			log_amp = pm.Normal("log_{0}_amp".format(name), mu=0.0, sd=np.log(max_amp))
			amp = pm.Deterministic("{0}_amp".format(name), pm.math.exp(log_amp))
		else:
			amp = pm.Normal("{0}_amp".format(name), mu=0.0, sd=max_amp)
		if fit_lag:
			log_lag = pm.Normal("log_{0}_lag".format(name), mu=-5.0, sd=np.log(max_days))
			lag = pm.Deterministic("{0}_lag".format(name), pm.math.exp(log_lag))
		if lifetime_scan > 0:
			log_tau0 = pm.Normal("log_{0}_tau0".format(name), mu=-5.0, sd=np.log(max_days))
			tau0 = pm.Deterministic("{0}_tau0".format(name), pm.math.exp(log_tau0))
			cos1 = pm.Normal("{0}_tau_cos1".format(name), mu=0.0, sd=max_amp)
			sin1 = pm.Normal("{0}_tau_sin1".format(name), mu=0.0, sd=max_amp)
			# annual lifetime variation: one cycle per 365.25 days,
			# expressed in inverse time units (1 for fractional years)
			harm1 = HarmonicModelCosineSine(days_per_time_unit / 365.25, cos1, sin1)
			tau1 = LifetimeModel(harm1, lower=0)
		else:
			tau0 = 0.
			tau1 = None
		proxy = ProxyModel(
			times, values,
			amp,
			lag=lag,
			tau0=tau0,
			tau_harm=tau1,
			tau_scan=lifetime_scan,
			days_per_time_unit=days_per_time_unit,
		)
	return proxy
321
322
323 View Code Duplication
def _default_proxy_config(tfmt="jyear"):
	"""Default proxy configuration

	Daily mean Lyman-alpha without lifetime and with an unrestricted
	amplitude, and the daily mean AE index ("GM") with a 60-day lifetime
	scan and a positive amplitude.

	Parameters
	----------
	tfmt : str, optional
		The time format passed to the data loaders, default "jyear".

	Returns
	-------
	dict
		Proxy names mapped to their setup keyword arguments.
	"""
	from .load_data import load_dailymeanLya, load_dailymeanAE

	# Lyman-alpha
	lya_times, lya_df = load_dailymeanLya(tfmt=tfmt)
	config = {
		"Lya": dict(
			times=lya_times,
			values=lya_df["Lya"],
			lifetime_scan=0,
			positive=False,
		),
	}
	# AE index
	ae_times, ae_df = load_dailymeanAE(name="GM", tfmt=tfmt)
	config["GM"] = dict(
		times=ae_times,
		values=ae_df["GM"],
		lifetime_scan=60,
		positive=True,
	)
	return config
347
348
349
def trace_gas_modelset(constant=True, freqs=None, proxy_config=None, **kwargs):
	"""Trace gas model set

	Sets up the trace gas model for easy access. All parameters are optional,
	defaults to an offset, no harmonics, proxies are uncentered and unscaled
	Lyman-alpha and AE. AE with only positive amplitude and a seasonally
	varying lifetime.

	Parameters
	----------
	constant : bool, optional
		Whether or not to include a constant (offset) term, default is True.
	freqs : list, optional
		Frequencies of the harmonic terms in a^-1 (inverse years).
		Default: None (no harmonic terms).
	proxy_config : dict, optional
		Proxy configuration if different from the standard setup.
		Maps proxy names to keyword-argument dicts for
		`setup_proxy_model_theano()`; the passed dicts are not modified.
	**kwargs : optional
		Additional keyword arguments, all of them are also passed on to
		the proxy setup with `setup_proxy_model_theano()`, entries of the
		individual proxy configuration take precedence. For now, supported
		are the following:

		* fit_phase : bool
			fit amplitude and phase instead of sine and cosine
		* scale : float
			the factor by which the data is scaled, used to constrain
			the maximum and minimum amplitudes to be fitted.
		* time_format : string
			The `astropy.time.Time` format string to setup the time axis.
		* days_per_time_unit : float
			The number of days per time unit, used to normalize the frequencies
			for the harmonic terms. Use 365.25 if the times are in fractional years,
			1 if they are in days. Default: 365.25
		* max_amp : float
			Maximum magnitude of the coefficients, used to constrain the
			parameter search. Default: 1e10 * scale
		* max_days : float
			Maximum magnitude of the lifetimes, used to constrain the
			parameter search. Default: 100

	Returns
	-------
	model : pymc3.Model
		The model containing all random variables.
	modelset : ModelSet
		The harmonic and proxy terms.
	offset : float or pymc3 random variable
		The constant offset term, 0. if `constant` is False.
	"""
	warn(
		"This method to set up the `theano`/`pymc3` interface is experimental, "
		"and the interface will most likely change in future versions. "
		"It is recommended to use the `ProxyModel` class instead."
	)
	fit_phase = kwargs.get("fit_phase", False)
	scale = kwargs.get("scale", 1e-6)
	tfmt = kwargs.get("time_format", "jyear")
	delta_t = kwargs.get("days_per_time_unit", 365.25)

	max_amp = kwargs.pop("max_amp", 1e10 * scale)
	max_days = kwargs.pop("max_days", 100)

	# `None` means no harmonic terms
	freqs = [] if freqs is None else freqs
	proxy_config = proxy_config or _default_proxy_config(tfmt=tfmt)

	with pm.Model() as model:
		offset = 0.
		if constant:
			offset = pm.Normal("offset", mu=0.0, sd=max_amp)

		modelset = []
		for freq in freqs:
			# convert the frequency from inverse years to inverse time units
			freq_tu = freq * delta_t / 365.25
			if not fit_phase:
				cos = pm.Normal("cos{0}".format(freq), mu=0., sd=max_amp)
				sin = pm.Normal("sin{0}".format(freq), mu=0., sd=max_amp)
				harm = HarmonicModelCosineSine(freq_tu, cos, sin)
			else:
				amp = pm.Normal("amp{0}".format(freq), mu=0., sd=max_amp)
				phase = pm.Normal("phase{0}".format(freq), mu=0., sd=max_amp)
				harm = HarmonicModelAmpPhase(freq_tu, amp, phase)
			modelset.append(harm)

		for pn, conf in proxy_config.items():
			# precedence: global defaults < kwargs < proxy configuration;
			# build a fresh dict so neither `kwargs` nor `conf` is modified
			kw = dict(kwargs, max_amp=max_amp, max_days=max_days)
			kw.update(conf)
			modelset.append(
				setup_proxy_model_theano(model, pn, **kw)
			)

	return model, ModelSet(modelset), offset
443