Completed
Pull Request — master (#947)
by Joe
01:19
created

__init__()   B

Complexity

Conditions 3

Size

Total Lines 24

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 24
rs 8.9714
1
import datetime
2
3
from datashape import istabular
4
import pandas as pd
5
from toolz import valmap
6
7
from .core import (
8
    TS_FIELD_NAME,
9
    SID_FIELD_NAME,
10
    bind_expression_to_resources,
11
    ffill_query_in_range,
12
)
13
from zipline.pipeline.data import EarningsCalendar
14
from zipline.pipeline.loaders.base import PipelineLoader
15
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
16
from zipline.pipeline.loaders.utils import (
17
    normalize_data_query_time,
18
    normalize_timestamp_to_query_time,
19
)
20
from zipline.utils.input_validation import ensure_timezone
21
from zipline.utils.preprocess import preprocess
22
# Name of the column in the raw blaze data that holds the earnings
# announcement date for each record (see the class docstring's dshape).
ANNOUNCEMENT_FIELD_NAME = 'announcement_date'
class BlazeEarningsCalendarLoader(PipelineLoader):
    """A pipeline loader for the ``EarningsCalendar`` dataset that loads
    data from a blaze expression.

    Parameters
    ----------
    expr : Expr
        The expression representing the data to load.
    resources : dict, optional
        Mapping from the atomic terms of ``expr`` to actual data resources.
    odo_kwargs : dict, optional
        Extra keyword arguments to pass to odo when executing the expression.
    data_query_time : time, optional
        The time to use for the data query cutoff.
    data_query_tz : tzinfo or str
        The timezone to use for the data query cutoff.

    Notes
    -----
    The expression should have a tabular dshape of::

       Dim * {{
           {SID_FIELD_NAME}: int64,
           {TS_FIELD_NAME}: datetime,
           {ANNOUNCEMENT_FIELD_NAME}: ?datetime,
       }}

    Where each row of the table is a record including the sid to identify the
    company, the timestamp where we learned about the announcement, and the
    date when the earnings will be announced.

    If the '{TS_FIELD_NAME}' field is not included it is assumed that we
    start the backtest with knowledge of all announcements.
    """
    # The class docstring doubles as user-facing documentation, so the field
    # name placeholders are filled in with the actual column names here.
    __doc__ = __doc__.format(
        TS_FIELD_NAME=TS_FIELD_NAME,
        SID_FIELD_NAME=SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME=ANNOUNCEMENT_FIELD_NAME,
    )

    # Columns that must be present in ``expr``; the expression is projected
    # down to exactly these fields in ``__init__``.
    _expected_fields = frozenset({
        TS_FIELD_NAME,
        SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME,
    })

    @preprocess(data_query_tz=ensure_timezone)
    def __init__(self,
                 expr,
                 resources=None,
                 odo_kwargs=None,
                 data_query_time=datetime.time(0),
                 data_query_tz='utc',
                 dataset=EarningsCalendar):
        dshape = expr.dshape

        # We can only iterate rows out of a tabular expression; reject
        # anything else up front with a clear message.
        if not istabular(dshape):
            raise ValueError(
                'expression dshape must be tabular, got: %s' % dshape,
            )

        expected_fields = self._expected_fields
        # Project the expression down to just the columns we need and bind
        # any unbound terms to the concrete data resources supplied.
        self._expr = bind_expression_to_resources(
            expr[list(expected_fields)],
            resources,
        )
        self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
        self._dataset = dataset
        self._data_query_time = data_query_time
        self._data_query_tz = data_query_tz

    def load_adjusted_array(self, columns, dates, assets, mask):
        """Load the announcement data for ``dates`` x ``assets`` by
        delegating to an in-memory ``EarningsCalendarLoader``.
        """
        data_query_time = self._data_query_time
        data_query_tz = self._data_query_tz
        # Pull the raw records for the requested date range, with the start
        # and end normalized to the configured data query cutoff time.
        raw = ffill_query_in_range(
            self._expr,
            normalize_data_query_time(
                dates[0],
                data_query_time,
                data_query_tz,
            ),
            normalize_data_query_time(
                dates[-1],
                data_query_time,
                data_query_tz,
            ),
            self._odo_kwargs,
        )
        # Discard rows for sids that were not requested.
        sids = raw.loc[:, SID_FIELD_NAME]
        raw.drop(
            sids[~sids.isin(assets)].index,
            inplace=True
        )
        # Shift each record's timestamp onto the query-time grid so that
        # knowledge dates line up with the simulation's data query cutoff.
        normalize_timestamp_to_query_time(
            raw,
            data_query_time,
            data_query_tz,
            inplace=True,
            ts_field=TS_FIELD_NAME,
        )

        gb = raw.groupby(SID_FIELD_NAME)

        def mkseries(idx, raw_loc=raw.loc):
            # Build a per-sid series mapping knowledge timestamp ->
            # announcement date, the input format EarningsCalendarLoader
            # expects.
            vs = raw_loc[
                idx, [TS_FIELD_NAME, ANNOUNCEMENT_FIELD_NAME]
            ].values
            return pd.Series(
                index=pd.DatetimeIndex(vs[:, 0]),
                data=vs[:, 1],
            )

        return EarningsCalendarLoader(
            dates,
            valmap(mkseries, gb.groups),
            dataset=self._dataset,
        ).load_adjusted_array(columns, dates, assets, mask)