Completed
Pull Request — master (#326)
by
unknown
01:49
created

test_interpolate_nan_linear()   A

Complexity

Conditions 1

Size

Total Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 8
rs 9.4285
1
# Copyright (c) 2008-2015 MetPy Developers.
2
# Distributed under the terms of the BSD 3-Clause License.
3
# SPDX-License-Identifier: BSD-3-Clause
4
"""Tests for `calc.tools` module."""
5
6
import numpy as np
7
import pytest
8
9
from metpy.calc import (find_intersections, interpolate_nans, nearest_intersection_idx,
10
                        resample_nn_1d)
11
from metpy.testing import assert_array_almost_equal, assert_array_equal
12
13
14
def test_resample_nn():
    """Check 1d nearest-neighbor resampling via ``resample_nn_1d``.

    Sampling ``np.arange(5.)`` at the points 2 and 3.8 is expected to give
    ``[2, 4]``.
    """
    source = np.arange(5.)
    targets = np.array([2, 3.8])
    expected = np.array([2, 4])
    result = resample_nn_1d(source, targets)
    assert_array_equal(expected, result)
21
22
23
def test_nearest_intersection_idx():
    """Check ``nearest_intersection_idx`` on a parabola crossed by a line.

    ``3 * x**2`` and ``100 * x - 650``, sampled at 17 points on [5, 30],
    are expected to yield the indices ``[2, 12]``.
    """
    grid = np.linspace(5, 30, 17)
    parabola = 3 * grid**2
    line = 100 * grid - 650
    assert_array_equal(np.array([2, 12]), nearest_intersection_idx(parabola, line))
31
32
33
@pytest.mark.parametrize('direction, expected', [
    ('all', np.array([[8.88, 24.44], [238.84, 1794.53]])),
    ('increasing', np.array([[24.44], [1794.53]])),
    ('decreasing', np.array([[8.88], [238.84]])),
])
def test_find_intersections(direction, expected):
    """Check ``find_intersections`` for each ``direction`` filter.

    The first row of ``expected`` appears to hold the x locations of the
    crossings and the second row the corresponding y values.
    """
    xs = np.linspace(5, 30, 17)
    curve_a = 3 * xs**2
    curve_b = 100 * xs - 650
    found = find_intersections(xs, curve_a, curve_b, direction=direction)
    # Expected values reflect this coarse 17-point sampling, not the exact
    # analytic roots, hence the 2-decimal comparison.
    assert_array_almost_equal(expected, found, 2)
45
46
47
def test_find_intersections_no_intersections():
    """Check ``find_intersections`` when the curves never cross in range.

    ``3 * x`` and ``5 * x + 5`` only meet at x = -2.5, which lies outside the
    sampled interval [5, 30], so an empty two-row result is expected.
    """
    xs = np.linspace(5, 30, 17)
    lower = 3 * xs
    upper = 5 * xs + 5
    empty = np.array([[], []])
    assert_array_equal(empty, find_intersections(xs, lower, upper))
56
57
58
def test_find_intersections_invalid_direction():
    """Ensure an unrecognized ``direction`` makes ``find_intersections`` raise."""
    xs = np.linspace(5, 30, 17)
    curve_a, curve_b = 3 * xs ** 2, 100 * xs - 650
    # 'increaing' is intentionally misspelled so it is not a valid direction.
    with pytest.raises(ValueError):
        find_intersections(xs, curve_a, curve_b, direction='increaing')
65
66
67
def test_interpolate_nan_linear():
    """Check that ``interpolate_nans`` fills NaN gaps in linear data.

    The data are exactly linear in x, so linearly interpolating across the
    blanked points should reproduce the original values.
    """
    xs = np.linspace(0, 20, 15)
    clean = 5 * xs + 3
    gapped = clean.copy()
    gapped[[1, 5, 11, 12]] = np.nan
    assert_array_almost_equal(clean, interpolate_nans(xs, gapped), 2)
75
76
77
def test_interpolate_nan_log():
    """Check ``interpolate_nans`` with ``kind='log'`` on log-linear data.

    ``y = 5 * ln(x) + 3`` is linear in ``ln(x)``, so log interpolation across
    the blanked points should reproduce the original values.
    """
    xs = np.logspace(1, 5, 15)
    clean = 5 * np.log(xs) + 3
    gapped = clean.copy()
    for idx in (1, 5, 11, 12):
        gapped[idx] = np.nan
    assert_array_almost_equal(clean, interpolate_nans(xs, gapped, kind='log'), 2)
85
86
87
def test_interpolate_nan_invalid():
    """Ensure an unknown ``kind`` makes ``interpolate_nans`` raise ``ValueError``."""
    xs = np.logspace(1, 5, 15)
    ys = 5 * np.log(xs) + 3
    with pytest.raises(ValueError):
        # 'loog' is intentionally misspelled so it is not a valid kind.
        interpolate_nans(xs, ys, kind='loog')
93