Passed
Push — master ( e93741...f4fcef )
by Stefan
10:45
created

test_regress_models_theano.xx()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 0
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# vim:fileencoding=utf-8
3
#
4
# Copyright (c) 2018-2022 Stefan Bender
5
#
6
# This module is part of sciapy.
7
# sciapy is free software: you can redistribute it or modify
8
# it under the terms of the GNU General Public License as published
9
# by the Free Software Foundation, version 2.
10
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
11
"""SCIAMACHY regression module tests
12
"""
13
import numpy as np
14
15
import pytest
16
17
try:
18
	import pymc3 as pm
19
except ImportError:
20
	pytest.skip("Theano/PyMC3 packages not installed", allow_module_level=True)
21
22
try:
23
	from sciapy.regress.models_theano import (
24
		HarmonicModelCosineSine,
25
		HarmonicModelAmpPhase,
26
		LifetimeModel,
27
		ProxyModel,
28
	)
29
except ImportError:
30
	pytest.skip("Theano/PyMC3 interface not installed", allow_module_level=True)
31
32
33
@pytest.fixture(scope="module")
def xx():
	"""Two years of daily Modified Julian Day time stamps.

	The grid starts at MJD 51544.5 (2000-01-01, noon) and contains
	``2 * 365 + 1`` points spaced exactly one day apart.  It is returned
	as a C-contiguous float64 array.  Being module-scoped, the fixture is
	built once per test module.
	"""
	n_days = 2 * 365. + 1
	day_offsets = np.arange(0., n_days, 1.)
	return np.ascontiguousarray(51544.5 + day_offsets, dtype=np.float64)
38
39
40
def ys(xs, c, s):
	"""Evaluate ``c * cos(2 pi xs) + s * sin(2 pi xs)``.

	The wave completes one cycle per unit of ``xs``.  The result is a
	C-contiguous float64 array with the same shape as ``xs``.
	"""
	angle = 2 * np.pi * xs
	signal = c * np.cos(angle) + s * np.sin(angle)
	return np.ascontiguousarray(signal, dtype=np.float64)
43
44
45
@pytest.mark.parametrize(
	"c, s",
	[(0.5, 2.0), (1.0, 0.5), (1.0, 1.0)],
)
def test_harmonics_theano(xx, c, s):
	"""Both harmonic parametrisations must describe the same wave.

	Synthetic data ``c * cos(2 pi t) + s * sin(2 pi t)`` (t in fractional
	years) plus Gaussian noise with standard deviation 0.5 are fitted twice:

	1. with ``HarmonicModelCosineSine``, recording amplitude and phase as
	   deterministic variables;
	2. with ``HarmonicModelAmpPhase``, sampling amplitude and phase directly.

	The cosine/sine posterior medians of fit 1 must match ``(c, s)``, and the
	amplitude/phase posterior medians of both fits must agree.
	"""
	sampler_opts = dict(tune=400, draws=400, chains=2, return_inferencedata=True)

	def _posterior_median(idata, names):
		# median over all chains and draws of the selected variables, as one array
		return idata.posterior.median(dim=("chain", "draw"))[names].to_array()

	# fixed seed so that the synthetic noise is reproducible
	np.random.seed(93457)
	# MJD -> fractional years (MJD 44.25 maps to 1859.0)
	t_years = 1859 + (xx - 44.25) / 365.25
	y_obs = ys(t_years, c, s)
	y_obs += 0.5 * np.random.randn(t_years.shape[0])

	# fit 1: cosine and sine coefficients
	with pm.Model():
		cos_coef = pm.Normal("cos", mu=0.0, sigma=4.0)
		sin_coef = pm.Normal("sin", mu=0.0, sigma=4.0)
		harm_cs = HarmonicModelCosineSine(1., cos_coef, sin_coef)
		wave_cs = harm_cs.get_value(t_years)
		# derived amplitude and phase, compared against fit 2 below
		pm.Deterministic("amp", harm_cs.get_amplitude())
		pm.Deterministic("phase", harm_cs.get_phase())
		pm.Normal("obs", mu=wave_cs, observed=y_obs)
		idata_cs = pm.sample(**sampler_opts)

	# fit 2: amplitude and phase as the free parameters
	with pm.Model():
		amp_rv = pm.HalfNormal("amp", sigma=4.0)
		phase_rv = pm.Normal("phase", mu=0.0, sigma=4.0)
		harm_ap = HarmonicModelAmpPhase(1., amp_rv, phase_rv)
		wave_ap = harm_ap.get_value(t_years)
		pm.Normal("obs", mu=wave_ap, observed=y_obs)
		idata_ap = pm.sample(**sampler_opts)

	np.testing.assert_allclose(
		_posterior_median(idata_cs, ["cos", "sin"]),
		(c, s),
		atol=2e-2,
	)
	np.testing.assert_allclose(
		_posterior_median(idata_cs, ["amp", "phase"]),
		_posterior_median(idata_ap, ["amp", "phase"]),
		atol=6e-3,
	)
90
91
92
def _test_data(xs, values, f, c, s):
	"""Evaluate a noise-free ``ProxyModel`` signal with fixed parameters.

	The proxy uses amplitude 3, lag 2 and base lifetime ``tau0`` 1.  The
	lifetime is modulated by a ``HarmonicModelCosineSine(f, c, s)``,
	clipped below at 0 by ``LifetimeModel``.

	Parameters:
		xs: time grid at which the proxy is defined and evaluated.
		values: proxy values on ``xs``.
		f: first argument of ``HarmonicModelCosineSine`` (presumably the
			harmonic frequency -- confirm).  It also sets
			``days_per_time_unit = f * 365.25``.
		c, s: cosine and sine coefficients of the lifetime harmonic.

	Returns:
		The ``.eval()``-uated model output at ``xs``.
	"""
	# the expected medians in test_proxy_theano (3., 2., 1.) mirror these
	fixed_params = dict(amp=3., lag=2., tau0=1.)
	lifetime = LifetimeModel(HarmonicModelCosineSine(f, c, s), lower=0.)
	model = ProxyModel(
		xs, values,
		tau_harm=lifetime,
		tau_scan=10,
		days_per_time_unit=f * 365.25,
		**fixed_params
	)
	return model.get_value(xs).eval()
108
109
110
def _yy(x, c, s):
111
	_ys = np.zeros_like(x)
112
	_ys[10::20] = 10.
113
	return np.ascontiguousarray(_ys, dtype=np.float64)
114
115
116
@pytest.mark.long
@pytest.mark.parametrize(
	"f",
	[1., 1. / 365.25],
)
def test_proxy_theano(xx, f, c=3.0, s=1.0):
	"""Recover ``ProxyModel`` parameters from noisy synthetic data.

	The data are generated by ``_test_data``, with Gaussian noise of
	standard deviation 0.5 added.  The same model structure is then fitted
	with PyMC3.  The posterior medians must match the generating values:
	amp 3, lag 2, tau0 1, lifetime harmonic coefficients ``c``/``s``, and
	log-jitter log(0.5).

	``f == 1`` runs the model on a fractional-year time axis.
	``f == 1 / 365.25`` runs it directly in days (MJD).
	"""
	# fixed seed so that the synthetic noise is reproducible
	np.random.seed(93457)

	step = 1. / (f * 365.25)
	# f < 1: keep the time axis in days; otherwise convert MJD to fractional years
	t = xx * step if f < 1. else 1859 + (xx - 44.25) * step
	proxy_values = _yy(t, c, s)

	y_obs = _test_data(t, proxy_values, f, c, s)
	y_obs += 0.5 * np.random.randn(t.shape[0])

	# generating values, keyed by the (unprefixed) model variable names
	truth = {
		"amp": 3.,
		"lag": 2.,
		"tau0": 1.,
		"tau_cos1": c,
		"tau_sin1": s,
		"log_jitter": np.log(0.5),
	}

	# a named model prefixes all its variable names with "<name>_"
	with pm.Model(name="proxy") as model:
		p_amp = pm.Normal("amp", mu=0.0, sigma=4.0)
		p_lag = pm.Lognormal("lag", mu=0.0, sigma=4.0, testval=1.0)
		p_tau0 = pm.Lognormal("tau0", mu=0.0, sigma=4.0, testval=1.0)
		tau_cos = pm.Normal("tau_cos1", mu=0.0, sigma=10.0)
		tau_sin = pm.Normal("tau_sin1", mu=0.0, sigma=10.0)
		lifetime = LifetimeModel(HarmonicModelCosineSine(f, tau_cos, tau_sin), lower=0)

		proxy = ProxyModel(
			t, proxy_values,
			amp=p_amp,
			lag=p_lag,
			tau0=p_tau0,
			tau_harm=lifetime,
			tau_scan=10,
			days_per_time_unit=f * 365.25,
		)
		proxy_mean = proxy.get_value(t)
		# free observation noise ("jitter"), sampled on a log scale
		log_jitter = pm.Normal("log_jitter", mu=0.0, sigma=4.0)
		pm.Normal("obs", mu=proxy_mean, sigma=pm.math.exp(log_jitter), observed=y_obs)

		# MAP optimisation is run, but its result is not used below
		pm.find_MAP()
		trace = pm.sample(
			chains=2,
			draws=400,
			tune=400,
			random_seed=[286923464, 464329682],
			return_inferencedata=True,
		)

	medians = trace.posterior.median(dim=("chain", "draw"))
	np.testing.assert_allclose(
		medians[[model.name_for(name) for name in truth]].to_array(),
		list(truth.values()),
		atol=3e-2, rtol=1e-2,
	)
185