Passed
Push — master ( 253f52...0e3d3e )
by Stefan
05:19
created

test_regress_models_theano._test_data()   A

Complexity

Conditions 1

Size

Total Lines 17
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 15
nop 3
dl 0
loc 17
rs 9.65
c 0
b 0
f 0
1
# -*- coding: utf-8 -*-
2
# vim:fileencoding=utf-8
3
#
4
# Copyright (c) 2018-2022 Stefan Bender
5
#
6
# This module is part of sciapy.
7
# sciapy is free software: you can redistribute it or modify
8
# it under the terms of the GNU General Public License as published
9
# by the Free Software Foundation, version 2.
10
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
11
"""SCIAMACHY regression module tests
12
"""
13
import numpy as np
14
15
import pytest
16
17
try:
18
	import pymc3 as pm
19
	import arviz as az
20
except ImportError:
21
	pytest.skip("Theano/PyMC3 packages not installed", allow_module_level=True)
22
23
try:
24
	from sciapy.regress.models_theano import (
25
		HarmonicModelCosineSine,
26
		HarmonicModelAmpPhase,
27
		LifetimeModel,
28
		ProxyModel,
29
	)
30
except ImportError:
31
	pytest.skip("Theano/PyMC3 interface not installed", allow_module_level=True)
32
33
34
@pytest.fixture(scope="module")
def xs():
	"""Module-wide sample grid shared by the regression tests.

	Returns
	-------
	numpy.ndarray
		2048 evenly spaced points from 0 to 11.1 (both ends included),
		stored as a C-contiguous float64 array.
	"""
	grid = np.linspace(0., 11.1, num=2048)
	# guarantee C-contiguous float64 storage regardless of linspace defaults
	return np.ascontiguousarray(grid, dtype=np.float64)
38
39
40
def ys(xs, c, s):
	"""Evaluate a unit-frequency harmonic at ``xs``.

	Computes ``c * cos(2 pi xs) + s * sin(2 pi xs)``.

	Parameters
	----------
	xs : numpy.ndarray
		Evaluation points.
	c, s : float
		Cosine and sine coefficients.

	Returns
	-------
	numpy.ndarray
		The harmonic values as a C-contiguous float64 array.
	"""
	phase = 2 * np.pi * xs
	wave = c * np.cos(phase) + s * np.sin(phase)
	return np.ascontiguousarray(wave, dtype=np.float64)
43
44
45
@pytest.mark.parametrize(
	"c, s",
	[
		(0.5, 2.0),
		(1.0, 0.5),
		(1.0, 1.0),
	]
)
def test_harmonics_theano(xs, c, s):
	"""Fit a noisy harmonic with both harmonic-model parametrizations.

	Checks that the cosine/sine model recovers ``c`` and ``s``, and that
	the amplitude and phase derived from it agree with those fitted
	directly by the amplitude/phase model.
	"""
	# Initialize random number generator
	np.random.seed(93457)
	yp = ys(xs, c, s)
	yp += 0.5 * np.random.randn(xs.shape[0])

	# Model 1: cosine/sine coefficients
	with pm.Model():
		cos = pm.Normal("cos", mu=0.0, sigma=4.0)
		sin = pm.Normal("sin", mu=0.0, sigma=4.0)
		harm1 = HarmonicModelCosineSine(1., cos, sin)
		wave1 = harm1.get_value(xs)
		# add amplitude and phase for comparison
		pm.Deterministic("amp", harm1.get_amplitude())
		pm.Deterministic("phase", harm1.get_phase())
		resid1 = yp - wave1
		pm.Normal("obs", mu=0.0, observed=resid1)
		trace1 = pm.sample(tune=800, draws=800, chains=2, return_inferencedata=True)

	# Model 2: amplitude/phase parametrization of the same harmonic
	with pm.Model():
		amp2 = pm.HalfNormal("amp", sigma=4.0)
		phase2 = pm.Normal("phase", mu=0.0, sigma=4.0)
		harm2 = HarmonicModelAmpPhase(1., amp2, phase2)
		wave2 = harm2.get_value(xs)
		resid2 = yp - wave2
		pm.Normal("obs", mu=0.0, observed=resid2)
		trace2 = pm.sample(tune=800, draws=800, chains=2, return_inferencedata=True)

	medians1 = trace1.posterior.median(dim=("chain", "draw"))
	medians2 = trace2.posterior.median(dim=("chain", "draw"))
	# the cosine/sine model recovers the generating coefficients
	np.testing.assert_allclose(
		medians1[["cos", "sin"]].to_array(),
		(c, s),
		atol=1e-2,
	)
	# both parametrizations agree on amplitude and phase
	np.testing.assert_allclose(
		medians1[["amp", "phase"]].to_array(),
		medians2[["amp", "phase"]].to_array(),
		atol=3e-3,
	)
90
91
92
def _test_data(xs, c, s, *, amp=3., lag=2., tau0=1., tau_scan=10):
	"""Generate noise-free synthetic data from a ``ProxyModel``.

	The proxy time series is the harmonic ``ys(xs, c, s)``. The same ``c``
	and ``s`` also set the coefficients of the harmonic lifetime
	variation (``HarmonicModelCosineSine(1., c, s)``).

	Parameters
	----------
	xs : numpy.ndarray
		Sample times, used both as the proxy times and as the points at
		which the model is evaluated.
	c, s : float
		Cosine and sine coefficients of the proxy values and of the
		lifetime harmonic.
	amp, lag, tau0 : float, optional
		Proxy amplitude, lag, and constant lifetime term passed to
		``ProxyModel`` (defaults 3, 2 and 1).
	tau_scan : int, optional
		Forwarded unchanged as ``ProxyModel(tau_scan=...)`` (default 10).

	Returns
	-------
	The model values at ``xs``, evaluated numerically via ``.eval()``.
	"""
	# generate proxy "values"
	values = ys(xs, c, s)
	# harmonic lifetime variation wrapped in LifetimeModel with lower=0.
	# (presumably a lower bound on the lifetime -- confirm in models_theano)
	harm0 = HarmonicModelCosineSine(1., c, s)
	tau_lt0 = LifetimeModel(harm0, lower=0.)
	proxy0 = ProxyModel(
		xs, values,
		amp=amp,
		lag=lag,
		tau0=tau0,
		tau_harm=tau_lt0,
		tau_scan=tau_scan,
	)
	return proxy0.get_value(xs).eval()
109
110
111
@pytest.mark.long
def test_proxy_theano(xs, c=3.0, s=1.0):
	"""Recover proxy-model parameters from noisy synthetic proxy data.

	The data are ``_test_data(xs, c, s)`` plus Gaussian noise with standard
	deviation 0.5. The positive parameters (amplitude, lag, constant
	lifetime) and the noise scale are sampled in log space. The posterior
	medians are compared with the generating values.
	"""
	# Initialize random number generator
	np.random.seed(93457)

	# proxy "values"
	values = ys(xs, c, s)

	yp = _test_data(xs, c, s)
	yp += 0.5 * np.random.randn(xs.shape[0])

	# using "name" prefixes all variables with <name>_
	with pm.Model(name="proxy"):
		# amplitude (sampled as log, exponentiated to stay positive)
		plamp = pm.Normal("log_amp", mu=0.0, sigma=np.log(10.0))
		pamp = pm.Deterministic("amp", pm.math.exp(plamp))
		# lag (sampled as log)
		pllag = pm.Normal("log_lag", mu=0.0, sigma=np.log(10.0))
		plag = pm.Deterministic("lag", pm.math.exp(pllag))
		# lifetime: constant part (sampled as log) ...
		pltau0 = pm.Normal("log_tau0", mu=0.0, sigma=np.log(10.0))
		ptau0 = pm.Deterministic("tau0", pm.math.exp(pltau0))
		# ... plus a harmonic variation with free cos/sin coefficients
		cos1 = pm.Normal("tau_cos1", mu=0.0, sigma=10.0)
		sin1 = pm.Normal("tau_sin1", mu=0.0, sigma=10.0)
		harm1 = HarmonicModelCosineSine(1., cos1, sin1)
		tau1 = LifetimeModel(harm1, lower=0)

		proxy = ProxyModel(
			xs, values,
			amp=pamp,
			lag=plag,
			tau0=ptau0,
			tau_harm=tau1,
			tau_scan=10,
		)
		prox1 = proxy.get_value(xs)
		# Include "jitter": the observation noise scale, fitted in log space
		log_jitter = pm.Normal("log_jitter", mu=0.0, sigma=4.0)
		pm.Normal("obs", mu=prox1, sigma=pm.math.exp(log_jitter), observed=yp)

		# start sampling from the MAP estimate
		maxlp0 = pm.find_MAP()
		trace = pm.sample(
			chains=2,
			draws=1000,
			tune=1000,
			init="jitter+adapt_full",
			random_seed=[286923464, 464329682],
			return_inferencedata=True,
			start=maxlp0,
			target_accept=0.9,
		)

	medians = trace.posterior.median(dim=("chain", "draw"))
	np.testing.assert_allclose(
		medians[[
			"proxy_amp", "proxy_lag", "proxy_tau0",
			"proxy_tau_cos1", "proxy_tau_sin1",
			"proxy_log_jitter",
		]].to_array(),
		# generating values: the amp/lag/tau0 used by `_test_data`, the
		# lifetime harmonic coefficients, and log of the 0.5 noise level
		(3., 2., 1., c, s, np.log(0.5)),
		atol=3e-2, rtol=1e-2,
	)
173