Completed
Pull Request — master (#940)
by Joe
01:28
created

load_adjusted_array()   B

Complexity

Conditions 2

Size

Total Lines 29

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 29
rs 8.8571

1 Method

Rating   Name   Duplication   Size   Complexity  
A zipline.pipeline.loaders.blaze.BlazeEarningsCalendarLoader.mkseries() 0 7 1
1
from datashape import istabular
2
import pandas as pd
3
from toolz import valmap
4
5
from .core import (
6
    TS_FIELD_NAME,
7
    SID_FIELD_NAME,
8
    bind_expression_to_resources,
9
    ffill_query_in_range,
10
)
11
from zipline.pipeline.data import EarningsCalendar
12
from zipline.pipeline.loaders.base import PipelineLoader
13
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
14
15
16
# Name of the column holding the date on which a company's earnings are
# announced. It is one of the fields projected out of the loader's input
# expression and is interpolated into the loader's docstring.
ANNOUNCEMENT_FIELD_NAME = 'announcement_date'
17
18
19
class BlazeEarningsCalendarLoader(PipelineLoader):
    """A pipeline loader for the ``EarningsCalendar`` dataset that loads
    data from a blaze expression.

    Parameters
    ----------
    expr : Expr
        The expression representing the data to load.
    resources : dict, optional
        Mapping from the atomic terms of ``expr`` to actual data resources.
    odo_kwargs : dict, optional
        Extra keyword arguments to pass to odo when executing the expression.
    dataset : DataSet, optional
        The dataset whose columns this loader serves. Defaults to
        ``EarningsCalendar``.

    Notes
    -----
    The expression should have a tabular dshape of::

       Dim * {{
           {SID_FIELD_NAME}: int64,
           {TS_FIELD_NAME}: datetime,
           {ANNOUNCEMENT_FIELD_NAME}: ?datetime,
       }}

    Where each row of the table is a record including the sid to identify the
    company, the timestamp where we learned about the announcement, and the
    date when the earnings will be announced.

    All three fields, including '{TS_FIELD_NAME}', are required: the loader
    projects ``expr`` onto exactly these columns.
    """
    __doc__ = __doc__.format(
        TS_FIELD_NAME=TS_FIELD_NAME,
        SID_FIELD_NAME=SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME=ANNOUNCEMENT_FIELD_NAME,
    )

    # The columns projected out of the user-supplied expression.
    _expected_fields = frozenset({
        TS_FIELD_NAME,
        SID_FIELD_NAME,
        ANNOUNCEMENT_FIELD_NAME,
    })

    def __init__(self,
                 expr,
                 resources=None,
                 odo_kwargs=None,
                 dataset=EarningsCalendar):
        dshape = expr.dshape

        if not istabular(dshape):
            raise ValueError(
                'expression dshape must be tabular, got: %s' % dshape,
            )

        # Use ``sorted`` rather than ``list``: a frozenset of strings iterates
        # in an order that can differ between interpreter runs (string hash
        # randomization). Sorting gives a reproducible projection.
        self._expr = bind_expression_to_resources(
            expr[sorted(self._expected_fields)],
            resources,
        )
        self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
        self._dataset = dataset

    @staticmethod
    def _mkseries(frame):
        """Build one sid's announcement Series from that sid's rows.

        The result is indexed by the ``TS_FIELD_NAME`` values, and its data
        are the ``ANNOUNCEMENT_FIELD_NAME`` values of ``frame``.
        """
        # Pull each column separately so each keeps its own dtype. The old
        # code took a 2-D ``.values`` over both columns, which could upcast.
        return pd.Series(
            index=pd.DatetimeIndex(frame[TS_FIELD_NAME].values),
            data=frame[ANNOUNCEMENT_FIELD_NAME].values,
        )

    def load_adjusted_array(self, columns, dates, assets, mask):
        """Load ``columns`` for ``assets`` over ``dates``.

        Queries the bound expression with ``ffill_query_in_range`` between
        ``dates[0]`` and ``dates[-1]`` (so ``dates`` must be non-empty). It
        keeps only rows whose sid is in ``assets`` and builds one
        announcement Series per sid. It then delegates to
        ``EarningsCalendarLoader.load_adjusted_array`` and returns that
        call's result.
        """
        raw = ffill_query_in_range(
            self._expr,
            dates[0],
            dates[-1],
            self._odo_kwargs,
        )
        # Filter with a boolean mask instead of dropping by index label. If
        # the index is not unique, a label-based drop would also remove rows
        # that belong to requested assets.
        raw = raw[raw[SID_FIELD_NAME].isin(assets)]

        # Iterate the groupby directly so each sid's Series is built from
        # exactly that sid's rows. Re-selecting through ``raw.loc`` with
        # group labels could mix sids when the index is not unique.
        frames_by_sid = {
            sid: frame for sid, frame in raw.groupby(SID_FIELD_NAME)
        }

        return EarningsCalendarLoader(
            dates,
            valmap(self._mkseries, frames_by_sid),
            dataset=self._dataset,
        ).load_adjusted_array(columns, dates, assets, mask)
110