Passed
Pull Request — master (#2093)
by Axel
02:43
created

InterpolationScale.inverse()   A

Complexity

Conditions 2

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 5
nop 2
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
"""Interpolation utilities"""
3
import numpy as np
4
from scipy.interpolate import RegularGridInterpolator
5
from astropy import units as u
6
7
8
# Names exported by ``from <module> import *``.
__all__ = ["ScaledRegularGridInterpolator", "interpolation_scale"]
9
10
11
class ScaledRegularGridInterpolator:
    """Wrapper of `scipy.interpolate.RegularGridInterpolator` with scaling.

    Grid points and values are transformed with the chosen scales before the
    interpolator is built; on evaluation the requested points are transformed
    the same way and the interpolated result is transformed back.

    Parameters
    ----------
    points : tuple of `~numpy.ndarray` or `~astropy.units.Quantity`
        Tuple of points passed to `RegularGridInterpolator`.
    values : `~numpy.ndarray`
        Values passed to `RegularGridInterpolator`.
    points_scale : tuple of str
        Interpolation scale used for the points, one entry per axis.
        If None, 'lin' is used for every axis.
    values_scale : {'lin', 'log', 'sqrt'}
        Interpolation scaling applied to values. If the values vary over many
        magnitudes a 'log' scaling is recommended.
    extrapolate : bool
        If True, ``bounds_error=False`` and ``fill_value=None`` are used as
        defaults for `RegularGridInterpolator` unless given in ``kwargs``.
    **kwargs : dict
        Keyword arguments passed to `RegularGridInterpolator`.
    """

    def __init__(self, points, values, points_scale=None, values_scale="lin", extrapolate=True, **kwargs):
        if points_scale is None:
            points_scale = ["lin"] * len(points)

        self.scale_points = [interpolation_scale(name) for name in points_scale]
        self.scale = interpolation_scale(values_scale)

        grid = tuple(
            scaler(axis) for axis, scaler in zip(points, self.scale_points)
        )
        data = self.scale(values)

        if extrapolate:
            # Explicit user settings take precedence over these defaults.
            kwargs = {"bounds_error": False, "fill_value": None, **kwargs}

        self._interpolate = RegularGridInterpolator(
            points=grid, values=data, **kwargs
        )

    def __call__(self, points, method="linear", clip=True, **kwargs):
        """Interpolate data points.

        Parameters
        ----------
        points : tuple of `np.ndarray` or `~astropy.units.Quantity`
            Tuple of coordinate arrays of the form (x_1, x_2, x_3, ...). Arrays
            are broadcasted internally.
        method : {"linear", "nearest"}
            Linear or nearest neighbour interpolation.
        clip : bool
            Clip values at zero after interpolation.
        **kwargs : dict
            Keyword arguments passed on to the interpolator call.

        Returns
        -------
        values : array-like
            Back-scaled interpolated values with the broadcast shape of
            ``points``.
        """
        scaled = tuple(
            scaler(coord) for scaler, coord in zip(self.scale_points, points)
        )
        coords = np.broadcast_arrays(*scaled)

        # One row per requested point, one column per axis.
        sample = np.stack([coord.flat for coord in coords]).T

        flat_result = self._interpolate(sample, method, **kwargs)
        result = self.scale.inverse(flat_result.reshape(coords[0].shape))

        return np.clip(result, 0, np.inf) if clip else result
76
77
78
def interpolation_scale(scale="lin"):
    """Return the scale object matching a scaling name.

    Parameters
    ----------
    scale : {"lin", "linear", "log", "sqrt"}
        Name of the interpolation scaling; "linear" is accepted as an
        alias of "lin".

    Returns
    -------
    scale : `InterpolationScale`
        New instance of `LinearScale`, `LogScale` or `SqrtScale`.

    Raises
    ------
    ValueError
        If ``scale`` is not one of the supported names.
    """
    if scale in ("lin", "linear"):
        return LinearScale()

    if scale == "log":
        return LogScale()

    if scale == "sqrt":
        return SqrtScale()

    message = "Not a valid value scaling mode: '{}'.".format(scale)
    raise ValueError(message)
94
95
96
class InterpolationScale:
    """Interpolation scale base class.

    Subclasses provide ``_scale`` (forward transform) and ``_inverse``
    (backward transform). The unit of the first `~astropy.units.Quantity`
    passed to the scale is stored on the instance and re-attached by
    `inverse`.
    """

    def __call__(self, values):
        """Strip the unit from ``values`` and apply the forward transform."""
        if hasattr(self, "_unit"):
            # A unit was stored earlier: convert the input to that unit.
            bare = values.to_value(self._unit)
        elif isinstance(values, u.Quantity):
            # First quantity seen: remember its unit for later calls.
            self._unit = values.unit
            bare = values.value
        else:
            bare = values
        return self._scale(bare)

    def inverse(self, values):
        """Apply the backward transform and re-attach the stored unit, if any."""
        restored = self._inverse(values)
        if not hasattr(self, "_unit"):
            return restored
        return u.Quantity(restored, self._unit, copy=False)
113
114
115
class LogScale(InterpolationScale):
    """Natural-logarithm scaling.

    Inputs are raised to at least ``tiny`` before the logarithm is taken.
    """

    # Smallest positive normal float32; lower bound applied before the log.
    tiny = np.finfo(np.float32).tiny

    def _scale(self, values):
        return np.log(np.clip(values, self.tiny, np.inf))

    def _inverse(self, values):
        return np.exp(values)
126
127
128
class SqrtScale(InterpolationScale):
    """Signed square-root scaling.

    The forward transform is ``sign(x) * sqrt(|x|)``, so negative inputs map
    to negative scaled values. The backward transform is ``sign(y) * y**2``,
    which restores the sign and makes the round trip the identity.
    """

    def _scale(self, values):
        sign = np.sign(values)
        # ``sign * values`` equals ``abs(values)``, keeping the sqrt real.
        return sign * np.sqrt(sign * values)

    def _inverse(self, values):
        # Squaring alone would turn negative scaled values into positive
        # ones; multiply by the sign so it undoes ``_scale`` exactly.
        return np.sign(values) * np.power(values, 2)
137
138
139
class LinearScale(InterpolationScale):
    """Linear scaling: values pass through unchanged in both directions."""

    def _scale(self, values):
        return values

    # The backward transform of the identity is the identity itself.
    _inverse = _scale
147