Completed
Push — master ( 3e0e7c...40f314 )
by Ryan
01:16
created

patch_round()   A

Complexity

Conditions 1

Size

Total Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 9
rs 9.6666
1
# Copyright (c) 2008-2015 MetPy Developers.
2
# Distributed under the terms of the BSD 3-Clause License.
3
# SPDX-License-Identifier: BSD-3-Clause
4
r"""Collection of utilities for testing.
5
6
This includes:
7
* unit-aware test functions
8
* code for testing matplotlib figures
9
"""
10
11
import numpy as np
12
import numpy.testing
13
from pint import DimensionalityError
14
import pytest
15
16
from .units import units
17
18
19
def check_and_drop_units(actual, desired):
    r"""Check that the units on the passed in arrays are compatible; return the magnitudes.

    Parameters
    ----------
    actual : `pint.Quantity` or array-like

    desired : `pint.Quantity` or array-like

    Returns
    -------
    actual, desired
        array-like versions of `actual` and `desired` once they have been
        coerced to compatible units.

    Raises
    ------
    AssertionError
        If the units on the passed in objects are not compatible.

    """
    # Units that ``actual`` will be converted to. Computed up front so the error
    # message below never has to touch ``desired.units`` when ``desired`` is unitless;
    # doing so inside the ``except`` handler would raise an AttributeError that the
    # sibling ``except AttributeError`` clause cannot catch.
    target_units = getattr(desired, 'units', 'dimensionless')
    try:
        # If the desired result has units, add dimensionless units if necessary, then
        # ensure that this is compatible to the desired result.
        if hasattr(desired, 'units'):
            if not hasattr(actual, 'units'):
                actual = units.Quantity(actual, 'dimensionless')
            actual = actual.to(desired.units)
        # Otherwise, the desired result has no units. Convert the actual result to
        # dimensionless units if it is a united quantity.
        elif hasattr(actual, 'units'):
            actual = actual.to('dimensionless')
    except DimensionalityError:
        # In both branches above ``actual`` carries units by the time ``.to`` raises.
        raise AssertionError('Units are not compatible: {} should be {}'.format(actual.units,
                                                                                target_units))
    except AttributeError:
        # NOTE(review): presumably tolerates objects exposing ``units`` without a
        # usable ``.to``; such values fall through and are compared as-is.
        pass

    # Strip units (if any) so plain numpy comparisons can be applied.
    if hasattr(actual, 'magnitude'):
        actual = actual.magnitude
    if hasattr(desired, 'magnitude'):
        desired = desired.magnitude

    return actual, desired
63
64
65
def assert_almost_equal(actual, desired, decimal=7):
    """Assert that two values agree to ``decimal`` places, with compatible units.

    Units are reconciled and stripped by `check_and_drop_units`; the resulting
    magnitudes are compared with :func:`numpy.testing.assert_almost_equal`.
    """
    actual_mag, desired_mag = check_and_drop_units(actual, desired)
    numpy.testing.assert_almost_equal(actual_mag, desired_mag, decimal=decimal)
72
73
74
def assert_array_almost_equal(actual, desired, decimal=7):
    """Assert that two arrays agree element-wise to ``decimal`` places, with compatible units.

    Units are reconciled and stripped by `check_and_drop_units`; the resulting
    magnitudes are compared with :func:`numpy.testing.assert_array_almost_equal`.
    """
    actual_mag, desired_mag = check_and_drop_units(actual, desired)
    numpy.testing.assert_array_almost_equal(actual_mag, desired_mag, decimal=decimal)
81
82
83
def assert_array_equal(actual, desired):
    """Assert that two arrays are exactly equal, with compatible units.

    Units are reconciled and stripped by `check_and_drop_units`; the resulting
    magnitudes are compared with :func:`numpy.testing.assert_array_equal`.
    """
    magnitudes = check_and_drop_units(actual, desired)
    numpy.testing.assert_array_equal(*magnitudes)
90
91
92
@pytest.fixture(scope='module', autouse=True)
def set_agg_backend():
    """Run a test module with matplotlib's Agg backend selected.

    The backend in use before the switch is remembered and reinstated in a
    ``finally`` clause once the fixture is torn down.
    """
    from matplotlib import pyplot

    original_backend = pyplot.get_backend()
    use_backend = pyplot.switch_backend
    try:
        use_backend('agg')
        yield
    finally:
        use_backend(original_backend)
102
103
104
@pytest.fixture(autouse=True)
def patch_round(monkeypatch):
    """Fixture to patch builtin round using numpy's.

    This works around the fact that built-in round changed between Python 2 and 3. This
    is probably not needed once we're testing on matplotlib 2.0, which has been updated
    to use numpy's throughout.

    The replacement is set as an attribute of the builtins module itself rather than
    as an item of ``__builtins__``. ``__builtins__`` is a CPython implementation
    detail: it is a dict in imported modules but the module object in ``__main__``,
    so item assignment on it is not portable.
    """
    try:
        import builtins  # Python 3
    except ImportError:
        import __builtin__ as builtins  # Python 2
    # Applied through ``monkeypatch`` so pytest can restore the original ``round``.
    monkeypatch.setattr(builtins, 'round', np.round)
113