Completed: Pull Request against master (#905), created by unknown at 01:34

load_adjusted_array() (rating: B)

Complexity: Conditions 5
Size: Total Lines 47
Duplication: Lines 0, Ratio 0 %

Metric  Value
cc      5
dl      0
loc     47
rs      8.1672

1 Method

Rating  Name                                                                    Duplication  Size  Complexity
A       zipline.pipeline.loaders.blaze.BlazeEarningsCalendarLoader.mkseries()  0            2     1
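The complexity and size figures above can be sanity-checked locally. Below is a minimal sketch using radon; radon is an assumption about tooling (the report does not name the analyzer that produced cc/dl/loc/rs), and the module path is hypothetical.

```python
# Minimal sketch: radon is an assumed tool, not necessarily the analyzer
# behind this report, and the module path below is hypothetical.
from radon.complexity import cc_visit, cc_rank
from radon.raw import analyze

with open('zipline/pipeline/loaders/blaze/earnings.py') as f:  # hypothetical path
    source = f.read()

print('loc:', analyze(source).loc)  # total line count for the module
for block in cc_visit(source):
    # block.complexity is the cyclomatic complexity; cc_rank maps it to A-F
    print(block.name, block.complexity, cc_rank(block.complexity))
```

The full source listing that these figures describe follows.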
```python
import blaze as bz
from datashape import istabular
from odo import odo
import pandas as pd
from six import iteritems
from toolz import valmap

from .core import TS_FIELD_NAME, SID_FIELD_NAME
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader


ANCMT_FIELD_NAME = 'announcement_date'


class BlazeEarningsCalendarLoader(PipelineLoader):
    """A pipeline loader for the ``EarningsCalendar`` dataset that loads
    data from a blaze expression.

    Parameters
    ----------
    expr : Expr
        The expression representing the data to load.
    resources : any, optional
        The resources to use when computing ``expr``. If ``expr`` is already
        bound to resources this can be omitted.
    odo_kwargs : dict, optional
        Extra keyword arguments to pass to odo when executing the expression.

    Notes
    -----
    The expression should have a tabular dshape of::

       Dim * {{
           {SID_FIELD_NAME}: int64,
           {TS_FIELD_NAME}: datetime64,
           {ANCMT_FIELD_NAME}: datetime64,
       }}

    Where each row of the table is a record including the sid to identify the
    company, the timestamp when we learned about the announcement, and the
    date when the earnings will be announced.

    If the '{TS_FIELD_NAME}' field is not included it is assumed that we
    start the backtest with knowledge of all announcements.
    """
    __doc__ = __doc__.format(
        TS_FIELD_NAME=TS_FIELD_NAME,
        SID_FIELD_NAME=SID_FIELD_NAME,
        ANCMT_FIELD_NAME=ANCMT_FIELD_NAME,
    )

    _expected_fields = frozenset({
        TS_FIELD_NAME,
        SID_FIELD_NAME,
        ANCMT_FIELD_NAME,
    })

    def __init__(self,
                 expr,
                 resources=None,
                 compute_kwargs=None,
                 odo_kwargs=None):
        dshape = expr.dshape

        if not istabular(dshape):
            raise ValueError(
                'expression dshape must be tabular, got: %s' % dshape,
            )

        expected_fields = self._expected_fields
        self._has_ts = has_ts = TS_FIELD_NAME in dshape.measure.dict
        if not has_ts:
            # This field is optional; don't require it of the expression.
            expected_fields = expected_fields - {TS_FIELD_NAME}

        # bind the resources into the expression
        if resources is None:
            resources = {}
        elif not isinstance(resources, dict):
            leaves = expr._leaves()
            if len(leaves) != 1:
                raise ValueError('no data resources found')

            resources = {leaves[0]: resources}

        # Project down to the expected fields and substitute the bound
        # resources in for the expression's leaves.
        self._expr = expr[list(expected_fields)]._subs({
            k: bz.Data(v, dshape=k.dshape) for k, v in iteritems(resources)
        })
        self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}

    def load_adjusted_array(self, columns, dates, assets, mask):
        expr = self._expr

        # For each sid, find its most recent record dated at or before the
        # start of the query window; the earliest of those per-sid timestamps
        # is the lower bound of the range query, so no sid loses its latest
        # pre-window record.
        filtered = expr[expr[TS_FIELD_NAME] <= dates[0]]
        lower = odo(
            bz.by(
                filtered[SID_FIELD_NAME],
                timestamp=filtered[TS_FIELD_NAME].max(),
            ).timestamp.min(),
            pd.Timestamp,
            **self._odo_kwargs or {}
        )
        if lower is pd.NaT:
            # If there is no lower date, just query for data in the date
            # range. It must all be null anyways.
            lower = dates[0]

        # Materialize everything known between ``lower`` and the end of the
        # query window as a DataFrame.
        raw = odo(
            expr[
                (expr[TS_FIELD_NAME] >= lower) &
                (expr[TS_FIELD_NAME] <= dates[-1])
            ],
            pd.DataFrame,
            **self._odo_kwargs or {}
        )

        # Drop rows whose sid is null or not in the requested asset universe.
        sids = raw.loc[:, SID_FIELD_NAME]
        raw.drop(
            sids[~(sids.isin(assets) & sids.notnull())].index,
            inplace=True
        )

        # Build, per sid, either a Series of announcement dates indexed by
        # the timestamp we learned about them, or a plain DatetimeIndex of
        # announcement dates when no timestamp column is available.
        gb = raw.groupby(SID_FIELD_NAME)
        if self._has_ts:
            def mkseries(idx, raw_loc=raw.loc):
                vs = raw_loc[idx, [TS_FIELD_NAME, ANCMT_FIELD_NAME]].values
                return pd.Series(
                    index=pd.DatetimeIndex(vs[:, 0]),
                    data=vs[:, 1],
                )
        else:
            def mkseries(idx, raw_loc=raw.loc):
                return pd.DatetimeIndex(raw_loc[idx, ANCMT_FIELD_NAME])

        # Delegate the actual array construction to the in-memory loader.
        return EarningsCalendarLoader(
            dates,
            valmap(mkseries, gb.groups),
        ).load_adjusted_array(columns, dates, assets, mask)
```
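For context, here is a rough sketch of how this loader could be constructed. The sample frame, its values, and both construction styles are illustrative assumptions rather than part of the change; the field constants are the ones defined in the module above.

```python
# Illustrative sketch only: the sample data below is an assumption.
# SID_FIELD_NAME, TS_FIELD_NAME, and ANCMT_FIELD_NAME are the constants
# defined in the module above.
import blaze as bz
import pandas as pd

df = pd.DataFrame({
    SID_FIELD_NAME: [1, 1, 2],
    TS_FIELD_NAME: pd.to_datetime(['2014-01-02', '2014-01-10', '2014-01-02']),
    ANCMT_FIELD_NAME: pd.to_datetime(['2014-01-15', '2014-01-20', '2014-01-18']),
})

# Option 1: bind the data into the expression up front.
loader = BlazeEarningsCalendarLoader(bz.Data(df))

# Option 2: keep the expression symbolic and pass the resource separately;
# __init__ binds it to the expression's single leaf.
events = bz.Symbol('events', bz.discover(df))
loader = BlazeEarningsCalendarLoader(events, resources=df)
```

In either case the pipeline engine, not the user, ends up calling load_adjusted_array with the requested columns, dates, assets, and mask.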