Completed
Pull Request — master (#881)
by
unknown
01:24
created

zipline.lib.AdjustedArray.data()   A

Complexity

Conditions 1

Size

Total Lines 6

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 6
rs 9.4286
1
from textwrap import dedent
2
3
from numpy import (
4
    bool_,
5
    dtype,
6
    float32,
7
    float64,
8
    int32,
9
    int64,
10
    nan,
11
    ndarray,
12
    uint32,
13
    uint8,
14
)
15
from zipline.errors import (
16
    WindowLengthNotPositive,
17
    WindowLengthTooLong,
18
)
19
from zipline.utils.numpy_utils import (
20
    datetime64ns_dtype,
21
    make_datetime64ns,
22
)
23
from zipline.utils.memoize import lazyval
24
from zipline.utils.sentinel import sentinel
25
26
# These class names are all the same because of our bootleg templating system.
27
from ._float64window import AdjustedArrayWindow as Float64Window
28
from ._int64window import AdjustedArrayWindow as Int64Window
29
from ._uint8window import AdjustedArrayWindow as UInt8Window
30
31
# Marker passed as `fillvalue` to request that a default be looked up from the
# array's dtype instead of being supplied by the caller.
Infer = sentinel(
    'Infer',
    "Sentinel used to say 'infer missing_value from data type.'",
)

# Value callers pass as `mask` to indicate that no mask is supplied.
NOMASK = None

# Numeric dtypes that `_normalize_array` coerces to float64.
SUPPORTED_NUMERIC_DTYPES = frozenset(
    dtype(scalar_type)
    for scalar_type in (float32, float64, int32, int64, uint32)
)

# Window implementation to use for each dtype of the coerced storage buffer.
CONCRETE_WINDOW_TYPES = {
    dtype(float64): Float64Window,
    dtype(int64): Int64Window,
    dtype(uint8): UInt8Window,
}

# Default fill values, keyed by dtype.  Only the dtypes listed here can have
# their fill value inferred.
_FILLVALUE_DEFAULTS = {
    dtype(float64): nan,
    dtype('datetime64[ns]'): make_datetime64ns('NaT'),
}
48
49
50
def default_fillvalue_for_dtype(dtype):
    """
    Look up the default fill value registered for `dtype`.

    Parameters
    ----------
    dtype : np.dtype
        The dtype to look up in ``_FILLVALUE_DEFAULTS``.

    Returns
    -------
    fillvalue : object
        The entry stored in ``_FILLVALUE_DEFAULTS`` for `dtype`.

    Raises
    ------
    KeyError
        If `dtype` has no entry in ``_FILLVALUE_DEFAULTS``.
    """
    # NOTE: inside this function the parameter shadows the numpy ``dtype``
    # imported at module level; the name is part of the public signature.
    fillvalue = _FILLVALUE_DEFAULTS[dtype]
    return fillvalue
55
56
57
def _normalize_array(data):
    """
    Coerce buffer data for an AdjustedArray into a standard scalar
    representation, returning the coerced array and a numpy dtype object to use
    as a view type when providing public view into the data.

    Data whose dtype is in ``SUPPORTED_NUMERIC_DTYPES`` is coerced to float64
    and viewed as float64.  We coerce integral data to float so that we can use
    NaN as a missing value.

    datetime64[*] data is coerced to ``datetime64[ns]``, stored as int64, and
    given a viewtype of ``datetime64[ns]``.

    ``bool_`` data is coerced to uint8 with a viewtype of ``bool_``.

    Parameters
    ----------
    data : np.ndarray

    Returns
    -------
    coerced, viewtype : (np.ndarray, np.dtype)

    Raises
    ------
    ValueError
        If `data` is a datetime array and converting it to ``datetime64[ns]``
        raises OverflowError.
    TypeError
        If the dtype of `data` is not one of the supported kinds above.
    """
    data_dtype = data.dtype
    if data_dtype == bool_:
        return data.astype(uint8), dtype(bool_)
    elif data_dtype in SUPPORTED_NUMERIC_DTYPES:
        return data.astype(float64), dtype(float64)
    elif data_dtype.name.startswith('datetime'):
        try:
            outarray = data.astype('datetime64[ns]').view('int64')
            return outarray, datetime64ns_dtype
        except OverflowError:
            # The message must be %-formatted *before* it is handed to
            # ValueError; applying ``%`` to the exception instance raises a
            # TypeError and hides the real problem.
            raise ValueError(
                "AdjustedArray received a datetime array "
                "not representable as datetime64[ns].\n"
                "Min Date: %s\n"
                "Max Date: %s\n"
                % (data.min(), data.max())
            )
    else:
        # Report the dtype of the offending array, not the numpy ``dtype``
        # class itself.
        raise TypeError(
            "Don't know how to construct AdjustedArray "
            "on data of type %s." % data_dtype
        )
100
101
102
class AdjustedArray(object):
    """
    An array that can be iterated with a variable-length window, and which can
    provide different views on data from different perspectives.

    Parameters
    ----------
    data : np.ndarray
        The baseline data values.
    mask : np.ndarray[bool] or NOMASK
        A mask indicating the locations of missing data, or ``NOMASK`` if no
        mask is supplied.  Must have the same shape as `data`.
    adjustments : dict[int -> list[Adjustment]]
        A dict mapping row indices to lists of adjustments to apply when we
        reach that row.
    fillvalue : object, optional
        A value to use to fill missing data in yielded windows.
        Default behavior is to infer a value based on the dtype of `data`.
        `NaN` is used for numeric data, and `NaT` is used for datetime data.

    Raises
    ------
    ValueError
        If `mask` is supplied and is not a bool array, or its shape differs
        from the shape of `data`.
    """
    __slots__ = ('_data', '_viewtype', '_mask', 'adjustments', '__weakref__')

    def __init__(self, data, mask, adjustments, fillvalue=Infer):
        self._data, self._viewtype = _normalize_array(data)
        self.adjustments = adjustments
        if fillvalue is Infer:
            # NOTE(review): the resolved fill value is not stored on the
            # instance or forwarded anywhere in this class; the lookup is kept
            # so dtypes without a registered default still fail here.
            fillvalue = default_fillvalue_for_dtype(self.data.dtype)

        if mask is not NOMASK:
            if mask.dtype != bool_:
                raise ValueError("Mask must be a bool array.")
            if data.shape != mask.shape:
                raise ValueError(
                    "Mask shape %s != data shape %s." %
                    (mask.shape, data.shape),
                )

        # Always populate the slot.  Because the class uses __slots__, leaving
        # `_mask` unassigned when no mask is given would make any later read
        # of it raise AttributeError.
        self._mask = mask

    @lazyval
    def data(self):
        """
        The data stored in this array, viewed as the public view dtype.
        """
        return self._data.view(self._viewtype)

    @lazyval
    def dtype(self):
        """
        The dtype of the data stored in this array.
        """
        return self._viewtype

    @lazyval
    def _iterator_type(self):
        """
        The iterator produced when `traverse` is called on this Array.
        """
        return CONCRETE_WINDOW_TYPES[self._data.dtype]

    def traverse(self, window_length, offset=0):
        """
        Produce an iterator of rolling windows of rows over our data.
        Each emitted window will have `window_length` rows.

        Parameters
        ----------
        window_length : int
            The number of rows in each emitted window.
        offset : int, optional
            Number of rows to skip before the first window.

        Raises
        ------
        WindowLengthNotPositive
            If window_length < 1.
        WindowLengthTooLong
            If window_length is greater than the number of rows of data.
        """
        # Validate against the stored buffer first so that an invalid request
        # does not pay for a copy of the data.
        _check_window_params(self._data, window_length)

        # The iterator receives its own copy of the buffer rather than the
        # array held by this object.
        data = self._data.copy()
        return self._iterator_type(
            data,
            self._viewtype,
            self.adjustments,
            offset,
            window_length,
        )

    def inspect(self):
        """
        Return a string representation of the data stored in this array.
        """
        return dedent(
            """\
            Adjusted Array ({dtype}):

            Data:
            {data!r}

            Adjustments:
            {adjustments}
            """
        ).format(
            dtype=self.dtype.name,
            data=self.data,
            adjustments=self.adjustments,
        )
201
202
203
def ensure_ndarray(ndarray_or_adjusted_array):
    """
    Convert the argument to a numpy ndarray.

    An ndarray is handed back unchanged.  For an AdjustedArray, the value of
    its ``data`` attribute is returned.

    Parameters
    ----------
    ndarray_or_adjusted_array : numpy.ndarray | zipline.data.adjusted_array

    Returns
    -------
    out : The input, converted to an ndarray.

    Raises
    ------
    TypeError
        If the input is neither an ndarray nor an AdjustedArray.
    """
    value = ndarray_or_adjusted_array

    # Already an ndarray: nothing to convert.
    if isinstance(value, ndarray):
        return value

    if isinstance(value, AdjustedArray):
        return value.data

    type_name = type(value).__name__
    raise TypeError("Can't convert %s to ndarray" % type_name)
227
228
229
def _check_window_params(data, window_length):
    """
    Validate that windows of `window_length` rows can be taken from `data`.

    Parameters
    ----------
    data : np.ndarray[ndim=2]
        The array of data to check.
    window_length : int
        Length of the desired window.

    Returns
    -------
    None

    Raises
    ------
    WindowLengthNotPositive
        If window_length < 1.
    WindowLengthTooLong
        If window_length is greater than the number of rows in `data`.
    """
    # A window needs at least one row.
    if window_length < 1:
        raise WindowLengthNotPositive(window_length=window_length)

    # A window cannot span more rows than the data provides.
    nrows = data.shape[0]
    if window_length > nrows:
        raise WindowLengthTooLong(nrows=nrows, window_length=window_length)
259