Pull Request — master (#947) by Joe
Status: Completed · created 01:25

Grade: A

Complexity: Conditions: 1
Size: Total Lines: 7
Duplication: Lines: 0, Ratio: 0 %

Metric  Value
------  ------
cc      1
dl      0
loc     7
rs      9.4286
import datetime

import blaze as bz
from datashape import istabular
from odo import odo
import pandas as pd
from six import iteritems
from toolz import valmap

from .core import TS_FIELD_NAME, SID_FIELD_NAME, overwrite_novel_deltas
from zipline.pipeline.data import EarningsCalendar
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
from zipline.pipeline.loaders.utils import (
    normalize_data_query_time,
    normalize_timestamp_to_query_time,
)
from zipline.utils.input_validation import ensure_timezone
from zipline.utils.preprocess import preprocess


ANNOUNCEMENT_FIELD_NAME = 'announcement_date'


def bind_expression_to_resources(expr, resources):
    """
    Bind a Blaze expression to resources.

    Parameters
    ----------
    expr : bz.Expr
        The expression to which we want to bind resources.
    resources : dict[bz.Symbol -> any]
        Mapping from the atomic terms of ``expr`` to actual data resources.

    Returns
    -------
    bound_expr : bz.Expr
        ``expr`` with bound resources.
    """
    # bind the resources into the expression
    if resources is None:
        resources = {}

    # _subs stands for substitute.  It's not actually private; blaze just
    # prefixes symbol-manipulation methods with underscores to prevent
    # collisions with data column names.
    return expr._subs({
        k: bz.Data(v, dshape=k.dshape) for k, v in iteritems(resources)
    })
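

# A minimal usage sketch (not part of the original file): bind a symbol to an
# in-memory DataFrame. The names ``events`` and ``events_frame`` are
# hypothetical; any resource that odo understands should work here.
#
#     events = bz.symbol(
#         'events',
#         'var * {sid: int64, timestamp: datetime, announcement_date: datetime}',
#     )
#     bound = bind_expression_to_resources(events, {events: events_frame})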


class BlazeEarningsCalendarLoader(PipelineLoader):
    """A pipeline loader for the ``EarningsCalendar`` dataset that loads
    data from a blaze expression.

    Parameters
    ----------
    expr : Expr
        The expression representing the data to load.
    resources : dict, optional
        Mapping from the atomic terms of ``expr`` to actual data resources.
    odo_kwargs : dict, optional
        Extra keyword arguments to pass to odo when executing the expression.
    data_query_time : time, optional
        The time to use for the data query cutoff.
    data_query_tz : tzinfo or str
        The timezone to use for the data query cutoff.

    Notes
    -----
    The expression should have a tabular dshape of::

       Dim * {{
           {SID_FIELD_NAME}: int64,
           {TS_FIELD_NAME}: datetime64,
           {ANNOUNCEMENT_FIELD_NAME}: datetime64,
       }}

    Where each row of the table is a record including the sid to identify the
    company, the timestamp when we learned about the announcement, and the
    date when the earnings will be announced.

    If the '{TS_FIELD_NAME}' field is not included, it is assumed that we
    start the backtest with knowledge of all announcements.
    """
    __doc__ = __doc__.format(
        TS_FIELD_NAME=TS_FIELD_NAME,
        SID_FIELD_NAME=SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME=ANNOUNCEMENT_FIELD_NAME,
    )

    _expected_fields = frozenset({
        TS_FIELD_NAME,
        SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME,
    })
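
    # For reference, a concrete dshape matching the Notes section above
    # (a hedged illustration, not part of the original file):
    #
    #     var * {sid: int64, timestamp: datetime, announcement_date: datetime}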

    @preprocess(data_query_tz=ensure_timezone)
    def __init__(self,
                 expr,
                 resources=None,
                 compute_kwargs=None,
                 odo_kwargs=None,
                 data_query_time=datetime.time(0),
                 data_query_tz='utc',
                 dataset=EarningsCalendar):
        dshape = expr.dshape

        if not istabular(dshape):
            raise ValueError(
                'expression dshape must be tabular, got: %s' % dshape,
            )

        expected_fields = self._expected_fields
        self._expr = bind_expression_to_resources(
            expr[list(expected_fields)],
            resources,
        )
        self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
        self._dataset = dataset
        self._data_query_time = data_query_time
        self._data_query_tz = data_query_tz
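
    # A hedged construction sketch (not part of the original file): an
    # in-memory DataFrame can be wrapped directly, so no separate
    # ``resources`` mapping is needed. ``events_frame`` is hypothetical.
    #
    #     loader = BlazeEarningsCalendarLoader(
    #         bz.Data(events_frame),
    #         data_query_time=datetime.time(8, 45),
    #         data_query_tz='US/Eastern',
    #     )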

    def load_adjusted_array(self, columns, dates, assets, mask):
        data_query_time = self._data_query_time
        data_query_tz = self._data_query_tz
        expr = self._expr

        # Restrict to rows we knew about on or before the start of the query
        # window, measured at the data query cutoff time.
        filtered = expr[
            expr[TS_FIELD_NAME] <=
            normalize_data_query_time(
                dates[0],
                data_query_time,
                data_query_tz,
            )
        ]
        # For each sid, find its most recent timestamp at or before the window
        # start, then take the earliest of those across sids. Querying from
        # this lower bound keeps every sid's latest pre-window row.
        lower = odo(
            bz.by(
                filtered[SID_FIELD_NAME],
                timestamp=filtered[TS_FIELD_NAME].max(),
            ).timestamp.min(),
            pd.Timestamp,
            **self._odo_kwargs
        )
        if pd.isnull(lower):
            # If there is no lower date, just query for data in the date
            # range. It must all be null anyway.
            lower = dates[0]

        upper = normalize_data_query_time(
            dates[-1],
            data_query_time,
            data_query_tz,
        )
        raw = odo(
            expr[
                (expr[TS_FIELD_NAME] >= lower) &
                (expr[TS_FIELD_NAME] <= upper)
            ],
            pd.DataFrame,
            **self._odo_kwargs
        )
        raw[TS_FIELD_NAME] = raw[TS_FIELD_NAME].astype('datetime64[ns]')
        # Drop rows whose sid is null or outside the requested asset universe.
        sids = raw.loc[:, SID_FIELD_NAME]
        raw.drop(
            sids[~(sids.isin(assets) & sids.notnull())].index,
            inplace=True
        )
        normalize_timestamp_to_query_time(
            raw,
            data_query_time,
            data_query_tz,
            inplace=True,
            ts_field=TS_FIELD_NAME,
        )

        gb = raw.groupby(SID_FIELD_NAME)

        # Build a per-sid series mapping knowledge timestamp to announcement
        # date, the input format EarningsCalendarLoader expects.
        def mkseries(idx, raw_loc=raw.loc):
            vs = raw_loc[
                idx, [TS_FIELD_NAME, ANNOUNCEMENT_FIELD_NAME]
            ].values
            return pd.Series(
                index=pd.DatetimeIndex(vs[:, 0]),
                data=vs[:, 1],
            )

        # Delegate the actual array construction to the non-blaze loader.
        return EarningsCalendarLoader(
            dates,
            valmap(mkseries, gb.groups),
            dataset=self._dataset,
        ).load_adjusted_array(columns, dates, assets, mask)
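

# A hedged end-to-end sketch (not part of the original file): feed the loader
# a small events table and query a date window. All names below are
# hypothetical, and ``np`` would be ``numpy`` imported separately.
#
#     frame = pd.DataFrame({
#         'sid': [1, 1, 2],
#         'timestamp': pd.to_datetime(
#             ['2014-01-01', '2014-01-15', '2014-01-01'],
#         ),
#         'announcement_date': pd.to_datetime(
#             ['2014-01-20', '2014-02-10', '2014-01-25'],
#         ),
#     })
#     loader = BlazeEarningsCalendarLoader(bz.Data(frame))
#     dates = pd.date_range('2014-01-02', '2014-01-31', tz='utc')
#     out = loader.load_adjusted_array(
#         [EarningsCalendar.next_announcement],
#         dates,
#         pd.Index([1, 2]),
#         np.ones((len(dates), 2), dtype=bool),
#     )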