Completed
Push — master ( 1f137d...7a6ba4 )
by Joe
01:31
created

__init__()   A

Complexity

Conditions 3

Size

Total Lines 20

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 20
rs 9.4286
1
import blaze as bz
2
from datashape import istabular
3
from odo import odo
4
import pandas as pd
5
from six import iteritems
6
from toolz import valmap
7
8
from .core import TS_FIELD_NAME, SID_FIELD_NAME
9
from zipline.pipeline.data import EarningsCalendar
10
from zipline.pipeline.loaders.base import PipelineLoader
11
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
12
13
14
# Name of the expression field that holds the date on which the earnings
# will be announced.
ANNOUNCEMENT_FIELD_NAME = 'announcement_date'
15
16
17
def bind_expression_to_resources(expr, resources):
    """
    Substitute concrete data resources for the atomic terms of a Blaze
    expression.

    Parameters
    ----------
    expr : bz.Expr
        The expression whose atomic terms should be replaced.
    resources : dict[bz.Symbol -> any] or None
        Mapping from the atomic terms of ``expr`` to the data resources that
        back them. ``None`` is treated as an empty mapping.

    Returns
    -------
    bound_expr : bz.Expr
        The result of substituting every key of ``resources`` in ``expr``
        with a ``bz.Data`` wrapper around the corresponding resource.
    """
    # No mapping supplied means there is nothing to substitute.
    resources = {} if resources is None else resources

    # Wrap each resource in ``bz.Data``, reusing the dshape of the symbol it
    # replaces.
    substitutions = {}
    for symbol, resource in iteritems(resources):
        substitutions[symbol] = bz.Data(resource, dshape=symbol.dshape)

    # ``_subs`` means "substitute".  The leading underscore does not mark it
    # as private; blaze prefixes its symbol-manipulation methods this way so
    # they cannot collide with data column names.
    return expr._subs(substitutions)
43
44
45
class BlazeEarningsCalendarLoader(PipelineLoader):
    """A pipeline loader for the ``EarningsCalendar`` dataset that loads
    data from a blaze expression.

    Parameters
    ----------
    expr : Expr
        The expression representing the data to load.
    resources : dict, optional
        Mapping from the atomic terms of ``expr`` to actual data resources.
    compute_kwargs : dict, optional
        Accepted but not used by this loader.
    odo_kwargs : dict, optional
        Extra keyword arguments to pass to odo when executing the expression.
    dataset : optional
        Passed through as the ``dataset`` argument of the
        ``EarningsCalendarLoader`` that performs the final load. Defaults to
        ``EarningsCalendar``.

    Notes
    -----
    The expression should have a tabular dshape of::

       Dim * {{
           {SID_FIELD_NAME}: int64,
           {TS_FIELD_NAME}: datetime64,
           {ANNOUNCEMENT_FIELD_NAME}: datetime64,
       }}

    Where each row of the table is a record including the sid to identify the
    company, the timestamp where we learned about the announcement, and the
    date when the earnings will be announced.

    If the '{TS_FIELD_NAME}' field is not included it is assumed that we
    start the backtest with knowledge of all announcements.
    """
    # Fill the field-name placeholders of the docstring above.  Inside the
    # class body ``__doc__`` is already bound to that docstring.
    __doc__ = __doc__.format(
        TS_FIELD_NAME=TS_FIELD_NAME,
        SID_FIELD_NAME=SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME=ANNOUNCEMENT_FIELD_NAME,
    )

    # The columns projected out of the user's expression.
    _expected_fields = frozenset({
        TS_FIELD_NAME,
        SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME,
    })

    def __init__(self,
                 expr,
                 resources=None,
                 compute_kwargs=None,
                 odo_kwargs=None,
                 dataset=EarningsCalendar):
        """Validate ``expr`` and bind it to ``resources``.

        Raises
        ------
        ValueError
            If the dshape of ``expr`` is not tabular.
        """
        dshape = expr.dshape

        if not istabular(dshape):
            raise ValueError(
                'expression dshape must be tabular, got: %s' % dshape,
            )

        expected_fields = self._expected_fields
        # NOTE(review): the class docstring says the timestamp field may be
        # omitted, but this projection (and the queries in
        # ``load_adjusted_array``) always reference it -- confirm whether a
        # missing timestamp field is actually supported.
        self._expr = bind_expression_to_resources(
            expr[list(expected_fields)],
            resources,
        )
        self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
        self._dataset = dataset

    def load_adjusted_array(self, columns, dates, assets, mask):
        """Query the bound expression and delegate the load to an
        ``EarningsCalendarLoader`` built from the result.

        Parameters
        ----------
        columns, dates, assets, mask
            Forwarded unchanged to
            ``EarningsCalendarLoader.load_adjusted_array``.  ``dates`` is
            also used to bound the query (``dates[0]`` and ``dates[-1]``)
            and ``assets`` to filter the queried rows by sid.

        Returns
        -------
        The value returned by ``EarningsCalendarLoader.load_adjusted_array``.
        """
        expr = self._expr

        # Records known at or before the first requested date.
        filtered = expr[expr[TS_FIELD_NAME] <= dates[0]]
        # For every sid take the latest such timestamp; the earliest of
        # those across all sids is the lower bound of the data query.
        lower = odo(
            bz.by(
                filtered[SID_FIELD_NAME],
                timestamp=filtered[TS_FIELD_NAME].max(),
            ).timestamp.min(),
            pd.Timestamp,
            **self._odo_kwargs
        )
        if pd.isnull(lower):
            # If there is no lower date, just query for data in the date
            # range. It must all be null anyways.
            lower = dates[0]

        raw = odo(
            expr[
                (expr[TS_FIELD_NAME] >= lower) &
                (expr[TS_FIELD_NAME] <= dates[-1])
            ],
            pd.DataFrame,
            **self._odo_kwargs
        )

        # Discard rows for sids that were not requested.  The previous
        # condition, ``~(sids.isin(assets) | sids.notnull())``, only matched
        # rows whose sid was null, so rows for unrequested assets were kept.
        sids = raw.loc[:, SID_FIELD_NAME]
        raw.drop(
            sids[~sids.isin(assets)].index,
            inplace=True
        )

        gb = raw.groupby(SID_FIELD_NAME)

        def mkseries(idx, raw_loc=raw.loc):
            # Build, for one sid's row labels, a series of announcement
            # dates indexed by the timestamp at which each became known.
            vs = raw_loc[
                idx, [TS_FIELD_NAME, ANNOUNCEMENT_FIELD_NAME]
            ].values
            return pd.Series(
                index=pd.DatetimeIndex(vs[:, 0]),
                data=vs[:, 1],
            )

        return EarningsCalendarLoader(
            dates,
            valmap(mkseries, gb.groups),
            dataset=self._dataset,
        ).load_adjusted_array(columns, dates, assets, mask)
155