analytic_solution()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 4
Bugs 0 Features 0
Metric Value
cc 1
c 4
b 0
f 0
dl 0
loc 8
rs 9.4285
1
from nose import tools
2
3
import numpy as np
4
from scipy import stats
5
6
from .. import basis_functions
7
from .. import problems
8
from .. import solvers
9
10
11
def analytic_solution(y, nL, alpha):
    """
    Closed-form solution to the differential equation describing the
    signaling equilibrium of the Spence (1974) model.

    The constant of integration is chosen so that the returned expression
    equals `nL` when evaluated at `y = yL(nL, alpha)`.

    """
    y_lower = yL(nL, alpha)
    exponent = 1 + alpha

    # constant of integration pinned down by the value at the lower bound
    scaled_start = (nL / y_lower**-alpha)**2
    constant = (exponent / 2) * scaled_start - y_lower**exponent

    radicand = 2 * (y**exponent + constant) / exponent
    return y**(-alpha) * radicand**0.5
19
20
21
def spence_model(y, n, alpha, **params):
    """
    Evaluate the Spence model expression at `y` and `n`.

    Computes ``(1 / n - alpha * n * y**(alpha - 1)) / y**alpha`` and returns
    it wrapped in a one-element list. Any extra keyword arguments are
    accepted and ignored.

    """
    numerator = n**-1 - alpha * n * y**(alpha - 1)
    denominator = y**alpha
    return [numerator / denominator]
23
24
25
def initial_condition(y, n, nL, alpha, **params):
    """
    Return the gap between `n` and the target value `nL`.

    The result is a one-element list that is zero exactly when `n == nL`.
    `y`, `alpha` and any extra keyword arguments are accepted but unused.

    """
    gap = n - nL
    return [gap]
27
28
29
def yL(nL, alpha):
    """
    Return ``(nL**2 * alpha)**(1 / (1 - alpha))``.

    This is the value of `y` at which the `spence_model` expression is zero
    when `n = nL`.

    """
    base = nL**2 * alpha
    power = 1 / (1 - alpha)
    return base**power
31
32
33
def initial_mesh(yL, yH, num, problem):
    """
    Build an evenly spaced grid on [yL, yH] and a matching initial guess.

    Returns a tuple ``(ys, ns)`` where `ys` holds `num` evenly spaced points
    and `ns` is ``problem.params['nL'] + sqrt(ys)``.

    """
    offset = problem.params['nL']
    grid = np.linspace(yL, yH, num=num)
    guess = offset + np.sqrt(grid)
    return grid, guess
37
38
39
# Value reported in failure messages below.
# NOTE(review): nothing in the visible source passes this to a random number
# generator -- confirm whether it is meant to seed anything.
random_seed = np.random.randint(2147483647)

# model parameters shared by the test problem and the analytic solution
params = dict(nL=1.0, alpha=0.15)

test_problem = problems.IVP(initial_condition, 1, 1, params, spence_model)
42
43
44
def test_bspline_collocation():
    """
    Tests B-spline collocation.

    Solves `test_problem` with a B-spline basis and checks that (1) the
    solver reports success, (2) the normalized residuals are small in
    absolute value on average, and (3) the numerical solution is close in
    absolute value, on average, to `analytic_solution`.

    """
    bspline_basis = basis_functions.BSplineBasis()
    solver = solvers.Solver(bspline_basis)

    boundary_points = (yL(**params), 10)
    ys, ns = initial_mesh(*boundary_points, num=250, problem=test_problem)

    # fit the basis to the initial guess to obtain starting coefficients
    tck, _ = bspline_basis.fit([ns], u=ys, k=5, s=0)
    knots, coefs, k = tck
    initial_coefs = np.hstack(coefs)

    basis_kwargs = {'knots': knots, 'degree': k, 'ext': 2}
    nodes = np.linspace(*boundary_points, num=249)
    solution = solver.solve(basis_kwargs, boundary_points, initial_coefs,
                            nodes, test_problem)

    # check that solver terminated successfully
    msg = "Solver failed!\nSeed: {}\nModel params: {}\n"
    failure_msg = msg.format(random_seed, test_problem.params)
    tools.assert_true(solution.result.success, msg=failure_msg)

    # compute the residuals
    normed_residuals = solution.normalize_residuals(ys)

    # check that residuals are close to zero on average; take absolute
    # values first so that errors of opposite sign cannot cancel out and
    # large negative residuals cannot pass the one-sided comparison
    mean_abs_residual = np.mean(np.abs(normed_residuals))
    tools.assert_true(mean_abs_residual < 1e-6, msg=failure_msg)

    # check that the numerical and analytic solutions are close (again in
    # absolute value, for the same reason)
    numeric_soln = solution.evaluate_solution(ys)
    analytic_soln = analytic_solution(ys, **test_problem.params)
    mean_abs_error = np.mean(np.abs(numeric_soln - analytic_soln))
    tools.assert_true(mean_abs_error < 1e-6, msg=failure_msg)
77