Completed: Pull Request against master (#905), created by unknown at 01:31.

zipline.pipeline.loaders.blaze.BlazeEarningsCalendarLoader    Rating: A

Complexity
    Total Complexity: 6

Size/Duplication
    Total Lines: 107
    Duplicated Lines: 0 %

Metric   Value
dl       0
loc      107
rs       10
wmc      6

3 Methods

Rating   Name                    Duplication   Size   Complexity
B        load_adjusted_array()   0             46     3
A        mkseries()              0              7     1
A        __init__()              0             18     3
import blaze as bz
from datashape import istabular
from odo import odo
import pandas as pd
from six import iteritems
from toolz import valmap

from .core import TS_FIELD_NAME, SID_FIELD_NAME
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader


ANNOUNCEMENT_FIELD_NAME = 'announcement_date'


def bind_expression_to_resources(expr, resources):
    """
    Bind a Blaze expression to resources.

    Parameters
    ----------
    expr : bz.Expr
        The expression to which we want to bind resources.
    resources : dict[bz.Symbol -> any]
        Mapping from the atomic terms of ``expr`` to actual data resources.

    Returns
    -------
    bound_expr : bz.Expr
        ``expr`` with bound resources.
    """
    # bind the resources into the expression
    if resources is None:
        resources = {}

    # _subs stands for substitute.  It's not actually private, blaze just
    # prefixes symbol-manipulation methods with underscores to prevent
    # collisions with data column names.
    return expr._subs({
        k: bz.Data(v, dshape=k.dshape) for k, v in iteritems(resources)
    })


class BlazeEarningsCalendarLoader(PipelineLoader):
    """A pipeline loader for the ``EarningsCalendar`` dataset that loads
    data from a blaze expression.

    Parameters
    ----------
    expr : Expr
        The expression representing the data to load.
    resources : dict, optional
        Mapping from the atomic terms of ``expr`` to actual data resources.
    odo_kwargs : dict, optional
        Extra keyword arguments to pass to odo when executing the expression.

    Notes
    -----
    The expression should have a tabular dshape of::

       Dim * {{
           {SID_FIELD_NAME}: int64,
           {TS_FIELD_NAME}: datetime64,
           {ANNOUNCEMENT_FIELD_NAME}: datetime64,
       }}

    Where each row of the table is a record including the sid to identify the
    company, the timestamp where we learned about the announcement, and the
    date when the earnings will be announced.

    If the '{TS_FIELD_NAME}' field is not included it is assumed that we
    start the backtest with knowledge of all announcements.
    """
    __doc__ = __doc__.format(
        TS_FIELD_NAME=TS_FIELD_NAME,
        SID_FIELD_NAME=SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME=ANNOUNCEMENT_FIELD_NAME,
    )

    _expected_fields = frozenset({
        TS_FIELD_NAME,
        SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME,
    })

    def __init__(self,
                 expr,
                 resources=None,
                 compute_kwargs=None,
                 odo_kwargs=None):
        dshape = expr.dshape

        if not istabular(dshape):
            raise ValueError(
                'expression dshape must be tabular, got: %s' % dshape,
            )

        expected_fields = self._expected_fields
        self._expr = bind_expression_to_resources(
            expr[list(expected_fields)],
            resources,
        )
        self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}

    def load_adjusted_array(self, columns, dates, assets, mask):
        expr = self._expr
        filtered = expr[expr[TS_FIELD_NAME] <= dates[0]]
        lower = odo(
            bz.by(
                filtered[SID_FIELD_NAME],
                timestamp=filtered[TS_FIELD_NAME].max(),
            ).timestamp.min(),
            pd.Timestamp,
            **self._odo_kwargs
        )
        if pd.isnull(lower):
            # If there is no lower date, just query for data in the date
            # range. It must all be null anyways.
            lower = dates[0]

        raw = odo(
            expr[
                (expr[TS_FIELD_NAME] >= lower) &
                (expr[TS_FIELD_NAME] <= dates[-1])
            ],
            pd.DataFrame,
            **self._odo_kwargs
        )

        sids = raw.loc[:, SID_FIELD_NAME]
        raw.drop(
            sids[~(sids.isin(assets) | sids.notnull())].index,
            inplace=True
        )

        gb = raw.groupby(SID_FIELD_NAME)

        def mkseries(idx, raw_loc=raw.loc):
            vs = raw_loc[
                idx, [TS_FIELD_NAME, ANNOUNCEMENT_FIELD_NAME]
            ].values
            return pd.Series(
                index=pd.DatetimeIndex(vs[:, 0]),
                data=vs[:, 1],
            )

        return EarningsCalendarLoader(
            dates,
            valmap(mkseries, gb.groups),
        ).load_adjusted_array(columns, dates, assets, mask)
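
For context on how this loader might be wired up, here is a minimal sketch that is not part of the reviewed file. It builds a small in-memory events table matching the dshape described in the class docstring, wraps it in a blaze expression, and hands it to the loader; the sample sids and dates are invented for illustration, and it assumes the field-name constants are importable from the analyzed module.

import blaze as bz
import pandas as pd

from zipline.pipeline.loaders.blaze import (
    ANNOUNCEMENT_FIELD_NAME,
    SID_FIELD_NAME,
    TS_FIELD_NAME,
    BlazeEarningsCalendarLoader,
)

# Toy events table: one row per (sid, timestamp) pair giving the announcement
# date that was known as of that timestamp.
events = pd.DataFrame({
    SID_FIELD_NAME: [1, 1, 2],
    TS_FIELD_NAME: pd.to_datetime(['2014-01-01', '2014-01-15', '2014-01-01']),
    ANNOUNCEMENT_FIELD_NAME: pd.to_datetime(
        ['2014-01-20', '2014-01-21', '2014-02-10']
    ),
})

# Wrap the concrete data in a blaze expression; the loader selects only the
# expected fields and executes the query with odo inside load_adjusted_array.
loader = BlazeEarningsCalendarLoader(bz.Data(events))

Alternatively, a bare bz.Symbol can be passed as ``expr`` together with a ``resources`` mapping from that symbol to the concrete data, in which case bind_expression_to_resources substitutes the data into the expression at construction time.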