Completed
Pull Request — master (#858)
by Eddie
02:02
created

zipline.sources.RandomWalkSource._gen_events()   A

Complexity

Conditions 2

Size

Total Lines 15

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 15
rs 9.4286
1
#
2
# Copyright 2013 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
"""
17
A source to be used in testing.
18
"""
19
20
import pytz
21
22
from six.moves import filter
23
from datetime import datetime, timedelta
24
import itertools
25
26
from six.moves import range
27
28
from zipline.protocol import (
29
    Event,
30
    DATASOURCE_TYPE
31
)
32
from zipline.gens.utils import hash_args
33
34
35
def create_trade(sid, price, amount, datetime, source_id="test_factory"):
    """
    Build a synthetic TRADE event for ``sid`` stamped at ``datetime``.

    The trade, open and close prices all equal ``price``; the bar's low and
    high are set 5% below and 5% above it, and ``amount`` is stored as the
    volume.
    """
    event = Event()

    event.source_id = source_id
    event.type = DATASOURCE_TYPE.TRADE
    event.sid = sid
    event.dt = datetime

    # Flat bar: every price field equals the trade price (assigned in the
    # order price, close_price, open_price).
    event.price = event.close_price = event.open_price = price
    event.low = price * .95
    event.high = price * 1.05
    event.volume = amount

    return event
51
52
53
def date_gen(start,
             end,
             env,
             delta=timedelta(minutes=1),
             repeats=None):
    """
    Utility to generate a stream of dates.

    Yields timestamps from ``start`` (inclusive) up to ``end`` (exclusive),
    stepping by ``delta``.  Steps that fall outside the calendar described
    by ``env`` are pushed forward:

    * when ``delta`` is a whole multiple of one day, ``start`` is first
      truncated to midnight, and a step that is not ``env.is_trading_day``
      is replaced by ``env.next_trading_day(step)``;
    * otherwise a step that is not ``env.is_market_hours`` is replaced by
      the first element of ``env.next_open_and_close(step)``.

    Every timestamp is yielded ``repeats`` times when ``repeats`` is truthy
    and exactly once otherwise.
    """
    one_day = timedelta(days=1).total_seconds()
    daily_delta = delta.total_seconds() % one_day == 0

    # A falsy ``repeats`` (None or 0) means "yield each timestamp once".
    copies = repeats or 1

    if daily_delta:
        # Daily streams are anchored at midnight.
        current = start.replace(hour=0, minute=0, second=0, microsecond=0)
    else:
        current = start

    while current < end:
        for _ in range(copies):
            yield current

        current = current + delta
        # Skip forward past non-trading days / non-market minutes.
        if daily_delta:
            if not env.is_trading_day(current):
                current = env.next_trading_day(current)
        elif not env.is_market_hours(current):
            current = env.next_open_and_close(current)[0]
96
97
98
class SpecificEquityTrades(object):
    """
    Yields all events in event_list whose sid is contained in ``filter``.
    If no event_list is specified, generates an internal stream of
    synthetic trade events instead.  All events are returned when
    ``filter`` is None (or otherwise falsy).

    Configuration options (keyword arguments only):

    event_list : explicit list of events to replay; when given, the other
                 options default to values derived from it, and each
                 event's ``sid`` is rewritten in place to the resolved
                 asset's sid
    count      : integer number of trades; stored on the instance but NOT
                 used to bound the synthetic stream (start/end do that)
    sids       : identifiers of the assets to trade; each is resolved via
                 ``env.asset_finder.lookup_generic``
    start      : start date
    end        : end date (exclusive for the synthetic stream)
    delta      : timedelta between internal events
    concurrent : if True, the internal date stream repeats each date once
                 per sid, and every repeat is still paired with every sid
                 (NOTE(review): i.e. len(sids) trades per sid per date --
                 confirm this is intended)
    filter     : container of sids to keep; events whose sid is not in it
                 are dropped

    NOTE(review): the default start and end are identical, so with the
    default start/end/delta the synthetic stream is empty.
    """
    def __init__(self, env, *args, **kwargs):
        # We shouldn't get any positional arguments.
        assert len(args) == 0

        self.env = env

        # Default to None for event_list and filter.
        self.event_list = kwargs.get('event_list')
        self.filter = kwargs.get('filter')
        if self.event_list is not None:
            # If event_list is provided, extract parameters from there.
            # This isn't really clean and ultimately I think this
            # class should serve a single purpose (either take an
            # event_list or autocreate events).
            self.count = kwargs.get('count', len(self.event_list))
            self.start = kwargs.get('start', self.event_list[0].dt)
            self.end = kwargs.get('end', self.event_list[-1].dt)
            self.delta = kwargs.get('delta')
            if self.delta is None and len(self.event_list) > 1:
                # Infer the spacing from the first two events.  A
                # single-event list has no spacing, so delta stays None
                # instead of raising IndexError.
                self.delta = self.event_list[1].dt - self.event_list[0].dt
            self.concurrent = kwargs.get('concurrent', False)

            self.identifiers = kwargs.get(
                'sids',
                {event.sid for event in self.event_list}
            )
            assets_by_identifier = self._lookup_assets(env, self.identifiers)
            self.sids = [asset.sid for asset in assets_by_identifier.values()]
            # Rewrite each event's sid from its identifier to the resolved
            # asset's sid (mutates the caller's events in place).
            for event in self.event_list:
                event.sid = assets_by_identifier[event.sid].sid

        else:
            # Unpack config dictionary with default values.
            self.count = kwargs.get('count', 500)
            self.start = kwargs.get(
                'start',
                datetime(2008, 6, 6, 15, tzinfo=pytz.utc))
            self.end = kwargs.get(
                'end',
                datetime(2008, 6, 6, 15, tzinfo=pytz.utc))
            self.delta = kwargs.get(
                'delta',
                timedelta(minutes=1))
            self.concurrent = kwargs.get('concurrent', False)

            self.identifiers = kwargs.get('sids', [1, 2])
            assets_by_identifier = self._lookup_assets(env, self.identifiers)
            self.sids = [asset.sid for asset in assets_by_identifier.values()]

        # Hash_value for downstream sorting.
        self.arg_string = hash_args(*args, **kwargs)

        self.generator = self.create_fresh_generator()

    @staticmethod
    def _lookup_assets(env, identifiers):
        """
        Map each identifier to the first element returned by
        ``env.asset_finder.lookup_generic(identifier, as_of)``.

        One ``as_of`` timestamp is taken for the whole batch so every
        lookup uses the same date.  Duplicate identifiers collapse into a
        single entry, keeping first-seen order.
        """
        as_of = datetime.now()
        return {
            identifier: env.asset_finder.lookup_generic(identifier, as_of)[0]
            for identifier in identifiers
        }

    def __iter__(self):
        return self

    def next(self):
        # Python 2 iterator protocol.  Use the next() builtin: generators
        # and Python 3 filter objects have no .next() method.
        return next(self.generator)

    def __next__(self):
        return next(self.generator)

    def rewind(self):
        """Restart the event stream from the beginning."""
        self.generator = self.create_fresh_generator()

    def get_hash(self):
        """Return the source id: class name plus the hashed constructor args."""
        return self.__class__.__name__ + "-" + self.arg_string

    def update_source_id(self, gen):
        """Yield events from ``gen`` after stamping each with this source's id."""
        source_id = self.get_hash()
        for event in gen:
            event.source_id = source_id
            yield event

    def create_fresh_generator(self):
        """
        Build a new, unconsumed iterator over this source's events,
        applying the sid ``filter`` when one is set.
        """
        if self.event_list:
            unfiltered = self.update_source_id(iter(self.event_list))

        # Otherwise synthesise trades from a generated date stream.
        else:
            # In concurrent mode every date is repeated once per sid.
            repeats = len(self.sids) if self.concurrent else None
            date_generator = date_gen(
                start=self.start,
                end=self.end,
                delta=self.delta,
                repeats=repeats,
                env=self.env,
            )

            source_id = self.get_hash()

            # Nested generator rather than itertools.product: product
            # consumes its inputs up front, which would run the entire
            # date stream as soon as the generator was created.  Order is
            # unchanged: for each date, one trade per sid.
            unfiltered = (
                create_trade(
                    sid=sid,
                    price=float(i % 10) + 1.0,
                    amount=(i * 50) % 900 + 100,
                    datetime=date,
                    source_id=source_id,
                )
                for i, date in enumerate(date_generator)
                for sid in self.sids
            )

        # If a sid filter was given, keep only events whose sid is in it;
        # otherwise pass every event through.
        if self.filter:
            return filter(lambda event: event.sid in self.filter, unfiltered)
        return unfiltered
244