Passed
Push — master ( 8653da...99a432 )
by Stefan
06:35
created

test_regress_models_theano   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 87
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 57
dl 0
loc 87
rs 10
c 0
b 0
f 0
wmc 5

3 Functions

Rating   Name   Duplication   Size   Complexity  
A test_harmonics_theano() 0 44 3
A xs() 0 4 1
A ys() 0 3 1
1
# -*- coding: utf-8 -*-
2
# vim:fileencoding=utf-8
3
#
4
# Copyright (c) 2018-2022 Stefan Bender
5
#
6
# This module is part of sciapy.
7
# sciapy is free software: you can redistribute it or modify
8
# it under the terms of the GNU General Public License as published
9
# by the Free Software Foundation, version 2.
10
# See accompanying LICENSE file or http://www.gnu.org/licenses/gpl-2.0.html.
11
"""SCIAMACHY regression module tests
12
"""
13
import numpy as np
14
15
import pytest
16
17
try:
18
	import pymc3 as pm
19
	import arviz as az
20
except (ImportError, ModuleNotFoundError):
21
	pytest.skip("Theano/PyMC3 packages not installed", allow_module_level=True)
22
23
try:
24
	from sciapy.regress.models_theano import (
25
		HarmonicModelCosineSine,
26
		HarmonicModelAmpPhase,
27
	)
28
except (ImportError, ModuleNotFoundError):
29
	pytest.skip("Theano/PyMC3 interface not installed", allow_module_level=True)
30
31
32
@pytest.fixture(scope="module")
def xs():
	"""Shared abscissa grid for all tests in this module.

	Returns 2048 evenly spaced float64 values covering the closed
	interval [0, 11.1], as a C-contiguous array.
	"""
	grid = np.linspace(0.0, 11.1, num=2048, dtype=np.float64)
	# linspace already yields a contiguous float64 array; this only
	# guarantees the layout explicitly (it returns ``grid`` unchanged).
	return np.ascontiguousarray(grid)
36
37
38
def ys(xs, c, s):
	"""Evaluate a unit-frequency harmonic on ``xs``.

	Computes ``c * cos(2 pi xs) + s * sin(2 pi xs)`` and returns it as a
	C-contiguous float64 array.
	"""
	angle = 2 * np.pi * xs
	signal = c * np.cos(angle) + s * np.sin(angle)
	return np.ascontiguousarray(signal, dtype=np.float64)
41
42
43
@pytest.mark.parametrize(
	"c, s",
	[
		(0.5, 2.0),
		(1.0, 0.5),
		(1.0, 1.0),
	],
)
def test_harmonics_theano(xs, c, s):
	"""Fit a noisy harmonic with both Theano harmonic parametrizations.

	The synthetic signal ``c * cos(2 pi x) + s * sin(2 pi x)`` plus
	Gaussian noise (std 0.5) is fitted once with
	``HarmonicModelCosineSine`` and once with ``HarmonicModelAmpPhase``.
	The test checks that

	1. the posterior medians of the cosine/sine model recover ``(c, s)``;
	2. the amplitude and phase derived from the cosine/sine model agree
	   with those sampled directly in the amplitude/phase model.
	"""
	seed = 93457
	# Seed the global RNG for the synthetic data. Keep this fixed: with
	# 2048 points and noise std 0.5, the estimation error of c and s is of
	# the same order as the atol=1e-2 used below.
	np.random.seed(seed)
	yp = ys(xs, c, s)
	yp += 0.5 * np.random.randn(xs.shape[0])

	# Seed the sampler as well; otherwise the Monte-Carlo scatter of the
	# posterior medians from two independent runs can exceed the tight
	# atol=3e-3 of the amplitude/phase comparison and make the test flaky.
	sample_kwargs = dict(
		tune=800,
		draws=800,
		chains=2,
		random_seed=seed,
		return_inferencedata=True,
	)

	# Model 1: cosine/sine coefficients as free parameters.
	with pm.Model():
		cos = pm.Normal("cos", mu=0.0, sigma=4.0)
		sin = pm.Normal("sin", mu=0.0, sigma=4.0)
		harm1 = HarmonicModelCosineSine(1., cos, sin)
		wave1 = harm1.get_value(xs)
		# Track amplitude and phase so they can be compared with model 2.
		pm.Deterministic("amp", harm1.get_amplitude())
		pm.Deterministic("phase", harm1.get_phase())
		# No sigma is given, so PyMC3's default unit scale is used; only
		# posterior medians (locations) are compared below.
		pm.Normal("obs", mu=0.0, observed=yp - wave1)
		trace1 = pm.sample(**sample_kwargs)

	# Model 2: amplitude and phase as free parameters.
	with pm.Model():
		amp2 = pm.HalfNormal("amp", sigma=4.0)
		phase2 = pm.Normal("phase", mu=0.0, sigma=4.0)
		harm2 = HarmonicModelAmpPhase(1., amp2, phase2)
		wave2 = harm2.get_value(xs)
		pm.Normal("obs", mu=0.0, observed=yp - wave2)
		trace2 = pm.sample(**sample_kwargs)

	median1 = trace1.posterior.median(dim=("chain", "draw"))
	median2 = trace2.posterior.median(dim=("chain", "draw"))

	np.testing.assert_allclose(
		median1[["cos", "sin"]].to_array(),
		(c, s),
		atol=1e-2,
	)
	np.testing.assert_allclose(
		median1[["amp", "phase"]].to_array(),
		median2[["amp", "phase"]].to_array(),
		atol=3e-3,
	)
88