Completed
Push — master ( ebb4fb...323695 )
by
unknown
01:25
created

zipline.lib._normalize_array()   B

Complexity

Conditions 5

Size

Total Lines 42

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 5
dl 0
loc 42
rs 8.0896
1
from textwrap import dedent
2
3
from numpy import (
4
    bool_,
5
    dtype,
6
    float32,
7
    float64,
8
    int32,
9
    int64,
10
    ndarray,
11
    uint32,
12
    uint8,
13
)
14
from zipline.errors import (
15
    WindowLengthNotPositive,
16
    WindowLengthTooLong,
17
)
18
from zipline.utils.numpy_utils import (
19
    datetime64ns_dtype,
20
    default_fillvalue_for_dtype,
21
    float64_dtype,
22
    int64_dtype,
23
    uint8_dtype,
24
)
25
from zipline.utils.memoize import lazyval
26
from zipline.utils.sentinel import sentinel
27
28
# These class names are all the same because of our bootleg templating system.
29
from ._float64window import AdjustedArrayWindow as Float64Window
30
from ._int64window import AdjustedArrayWindow as Int64Window
31
from ._uint8window import AdjustedArrayWindow as UInt8Window
32
33
# Default for the ``fillvalue`` argument of AdjustedArray: when this sentinel
# is received, the fill value is looked up from the dtype of the data.
Infer = sentinel(
    'Infer',
    "Sentinel used to say 'infer missing_value from data type.'"
)
# Pass as ``mask`` to AdjustedArray to skip masking entirely.
NOMASK = None
# Numeric dtypes that _normalize_array coerces to (and views as) float64.
SUPPORTED_NUMERIC_DTYPES = frozenset(
    map(dtype, [float32, float64, int32, int64, uint32])
)
# Maps the dtype of the coerced storage buffer to the window iterator class
# used by AdjustedArray to traverse it.
CONCRETE_WINDOW_TYPES = {
    float64_dtype: Float64Window,
    int64_dtype: Int64Window,
    uint8_dtype: UInt8Window,
}
46
47
48
def _normalize_array(data):
    """
    Coerce buffer data for an AdjustedArray into a standard scalar
    representation, returning the coerced array and a numpy dtype object to use
    as a view type when providing public view into the data.

    Semantically numerical data (float*, int*, uint*) is coerced to float64 and
    viewed as float64.  We coerce integral data to float so that we can use NaN
    as a missing value.

    datetime[*] data is coerced to int64 with a viewtype of ``datetime64[ns]``.

    ``bool_`` data is coerced to uint8 with a viewtype of ``bool_``

    Parameters
    ----------
    data : np.ndarray

    Returns
    -------
    coerced, viewtype : (np.ndarray, np.dtype)

    Raises
    ------
    ValueError
        If converting datetime data to ``datetime64[ns]`` overflows.
    TypeError
        If the dtype of ``data`` is not bool, one of
        SUPPORTED_NUMERIC_DTYPES, or a datetime dtype.
    """
    data_dtype = data.dtype
    if data_dtype == bool_:
        return data.astype(uint8), dtype(bool_)
    elif data_dtype in SUPPORTED_NUMERIC_DTYPES:
        return data.astype(float64), dtype(float64)
    elif data_dtype.name.startswith('datetime'):
        try:
            outarray = data.astype('datetime64[ns]').view('int64')
            return outarray, datetime64ns_dtype
        except OverflowError:
            # The %-interpolation must happen on the message string, inside
            # the ValueError call; applying it to the exception instance
            # would itself fail with an unrelated TypeError.
            raise ValueError(
                "AdjustedArray received a datetime array "
                "not representable as datetime64[ns].\n"
                "Min Date: %s\n"
                "Max Date: %s\n"
                % (data.min(), data.max())
            )
    else:
        # Report the dtype of the offending array, not the numpy ``dtype``
        # callable imported at module level.
        raise TypeError(
            "Don't know how to construct AdjustedArray "
            "on data of type %s." % data_dtype
        )
91
92
93
class AdjustedArray(object):
    """
    An array that can be iterated with a variable-length window, and which can
    provide different views on data from different perspectives.

    Parameters
    ----------
    data : np.ndarray
        The baseline data values.
    mask : np.ndarray[bool]
        A boolean array with the same shape as `data`, or ``NOMASK`` to skip
        masking.  Locations where the mask is False are overwritten with
        `fillvalue`.
    adjustments : dict[int -> list[Adjustment]]
        A dict mapping row indices to lists of adjustments to apply when we
        reach that row.
    fillvalue : object, optional
        A value to use to fill missing data in yielded windows.
        Default behavior is to infer a value based on the dtype of `data`.
        `NaN` is used for numeric data, and `NaT` is used for datetime data.

    Raises
    ------
    ValueError
        If `mask` is not a bool array, or if its shape differs from the shape
        of `data`.
    """
    __slots__ = ('_data', '_viewtype', 'adjustments', '__weakref__')

    def __init__(self, data, mask, adjustments, fillvalue=Infer):
        # ``_data`` is the coerced storage buffer; ``_viewtype`` is the dtype
        # through which that buffer is exposed publicly.
        self._data, self._viewtype = _normalize_array(data)
        self.adjustments = adjustments
        if fillvalue is Infer:
            fillvalue = default_fillvalue_for_dtype(self.data.dtype)

        if mask is not NOMASK:
            if mask.dtype != bool_:
                raise ValueError("Mask must be a bool array.")
            if data.shape != mask.shape:
                raise ValueError(
                    "Mask shape %s != data shape %s." %
                    (mask.shape, data.shape),
                )
            # Fill every location where the mask is False.
            self._data[~mask] = fillvalue

    @lazyval
    def data(self):
        """
        The data stored in this array, as a view of the storage buffer with
        the public view dtype.
        """
        return self._data.view(self._viewtype)

    @lazyval
    def dtype(self):
        """
        The dtype of the data stored in this array.
        """
        return self._viewtype

    @lazyval
    def _iterator_type(self):
        """
        The iterator produced when `traverse` is called on this Array.

        Looked up in CONCRETE_WINDOW_TYPES by the dtype of the storage buffer.
        """
        return CONCRETE_WINDOW_TYPES[self._data.dtype]

    def traverse(self, window_length, offset=0):
        """
        Produce an iterator of rolling windows of rows over our data.
        Each emitted window will have `window_length` rows.

        Parameters
        ----------
        window_length : int
            The number of rows in each emitted window.
        offset : int, optional
            Number of rows to skip before the first window.

        Raises
        ------
        WindowLengthNotPositive
            If window_length < 1.
        WindowLengthTooLong
            If window_length is greater than the number of rows of data.
        """
        # The iterator receives its own copy of the storage buffer rather
        # than the buffer held by this object.
        data = self._data.copy()
        _check_window_params(data, window_length)
        return self._iterator_type(
            data,
            self._viewtype,
            self.adjustments,
            offset,
            window_length,
        )

    def inspect(self):
        """
        Return a string representation of the data stored in this array.
        """
        return dedent(
            """\
            Adjusted Array ({dtype}):

            Data:
            {data!r}

            Adjustments:
            {adjustments}
            """
        ).format(
            dtype=self.dtype.name,
            data=self.data,
            adjustments=self.adjustments,
        )
192
193
194
def ensure_ndarray(ndarray_or_adjusted_array):
    """
    Return the input as a numpy ndarray.

    This is a no-op if the input is already an ndarray.  If the input is an
    adjusted_array, this returns its ``data`` attribute.

    Parameters
    ----------
    ndarray_or_adjusted_array : numpy.ndarray | zipline.data.adjusted_array

    Returns
    -------
    out : The input, converted to an ndarray.

    Raises
    ------
    TypeError
        If the input is neither an ndarray nor an AdjustedArray.
    """
    if isinstance(ndarray_or_adjusted_array, ndarray):
        return ndarray_or_adjusted_array
    elif isinstance(ndarray_or_adjusted_array, AdjustedArray):
        return ndarray_or_adjusted_array.data
    else:
        raise TypeError(
            "Can't convert %s to ndarray" %
            type(ndarray_or_adjusted_array).__name__
        )
218
219
220
def _check_window_params(data, window_length):
    """
    Check that a window of length `window_length` is well-defined on `data`.

    Parameters
    ----------
    data : np.ndarray[ndim=2]
        The array of data to check.
    window_length : int
        Length of the desired window.

    Returns
    -------
    None

    Raises
    ------
    WindowLengthNotPositive
        If window_length < 1.
    WindowLengthTooLong
        If window_length is greater than the number of rows in `data`.
    """
    if window_length < 1:
        raise WindowLengthNotPositive(window_length=window_length)

    # A window exactly as long as the data is allowed; only longer is an
    # error.
    if window_length > data.shape[0]:
        raise WindowLengthTooLong(
            nrows=data.shape[0],
            window_length=window_length,
        )
250