Total Complexity | 67 |
Total Lines | 405 |
Duplicated Lines | 0 % |
Complex classes like zipline.finance.TradingEnvironment often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common way to find such a component is to look for fields and methods that share the same prefix or suffix.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # |
||
60 | class TradingEnvironment(object): |
||
61 | |||
62 | # Token used as a substitute for pickling objects that contain a |
||
63 | # reference to a TradingEnvironment |
||
64 | PERSISTENT_TOKEN = "<TradingEnvironment>" |
||
65 | |||
def __init__(
    self,
    load=None,
    bm_symbol='^GSPC',
    exchange_tz="US/Eastern",
    max_date=None,
    env_trading_calendar=tradingcalendar,
    asset_db_path=':memory:'
):
    """
    Build the environment from a trading calendar, market data and an
    asset database.

    @load is a function that returns benchmark_returns and treasury_curves.
    The treasury_curves are expected to be a DataFrame with an index of
    dates and columns of the curve names, e.g. '10year', '1month', etc.
    When @load is falsy, load_market_data is used instead.

    @max_date, when truthy, restricts both the trading days and the
    treasury curves to dates on or before it.

    @asset_db_path is either a string, which is prefixed with
    'sqlite:///' to create and initialise an engine, or any other object,
    which is stored as the engine as-is. If the resulting engine is None
    no AssetFinder is created.
    """
    calendar = env_trading_calendar
    self.trading_day = calendar.trading_day.copy()

    # Every trading day known to the calendar, optionally capped.
    calendar_days = calendar.trading_days
    if max_date:
        calendar_days = calendar_days[calendar_days <= max_date]
    self.trading_days = calendar_days.copy()

    self.first_trading_day = self.trading_days[0]
    self.last_trading_day = self.trading_days[-1]

    self.early_closes = calendar.get_early_closes(
        self.first_trading_day, self.last_trading_day)
    self.open_and_closes = calendar.open_and_closes.loc[self.trading_days]

    self.bm_symbol = bm_symbol
    loader = load if load else load_market_data
    benchmark_returns, treasury_curves = loader(
        self.trading_day, self.trading_days, self.bm_symbol)

    if max_date:
        # Mask the treasury curves down to the current date.
        # In the case of live trading, the last date in the treasury
        # curves would be the day before the date considered to be
        # 'today'.
        treasury_curves = treasury_curves[treasury_curves.index <= max_date]

    self.benchmark_returns = benchmark_returns
    self.treasury_curves = treasury_curves

    self.exchange_tz = exchange_tz

    if isinstance(asset_db_path, string_types):
        engine = create_engine('sqlite:///%s' % asset_db_path)
        self.engine = engine
        AssetDBWriterFromDictionary().init_db(engine)
    else:
        engine = asset_db_path
        self.engine = engine

    self.asset_finder = AssetFinder(engine) if engine is not None else None
||
127 | |||
def write_data(self,
               engine=None,
               equities_data=None,
               futures_data=None,
               exchanges_data=None,
               root_symbols_data=None,
               equities_df=None,
               futures_df=None,
               exchanges_df=None,
               root_symbols_df=None,
               equities_identifiers=None,
               futures_identifiers=None,
               exchanges_identifiers=None,
               root_symbols_identifiers=None,
               allow_sid_assignment=True):
    """ Write the supplied data to the database.

    DataFrames are written first, then dictionaries, then identifier
    lists. The identifier lists are always handed to the list writer,
    even when none were supplied.

    Parameters
    ----------
    engine: optional
        When truthy, replaces self.engine before anything is written.
    equities_data: dict, optional
        A dictionary of equity metadata
    futures_data: dict, optional
        A dictionary of futures metadata
    exchanges_data: dict, optional
        A dictionary of exchanges metadata
    root_symbols_data: dict, optional
        A dictionary of root symbols metadata
    equities_df: pandas.DataFrame, optional
        A pandas.DataFrame of equity metadata
    futures_df: pandas.DataFrame, optional
        A pandas.DataFrame of futures metadata
    exchanges_df: pandas.DataFrame, optional
        A pandas.DataFrame of exchanges metadata
    root_symbols_df: pandas.DataFrame, optional
        A pandas.DataFrame of root symbols metadata
    equities_identifiers: list, optional
        A list of equities identifiers (sids, symbols, Assets)
    futures_identifiers: list, optional
        A list of futures identifiers (sids, symbols, Assets)
    exchanges_identifiers: list, optional
        A list of exchanges identifiers (ids or names)
    root_symbols_identifiers: list, optional
        A list of root symbols identifiers (ids or symbols)
    allow_sid_assignment: bool, optional
        Forwarded unchanged to the identifier-list writer.
    """
    if engine:
        self.engine = engine

    # Only invoke the DataFrame writer when at least one frame was given.
    frames = (equities_df, futures_df, exchanges_df, root_symbols_df)
    if any(frame is not None for frame in frames):
        self._write_data_dataframes(*frames)

    # Likewise for the dictionary writer.
    mappings = (equities_data, futures_data,
                exchanges_data, root_symbols_data)
    if any(mapping is not None for mapping in mappings):
        self._write_data_dicts(*mappings)

    # These could be lists or other iterables such as a pandas.Index.
    # For simplicity, don't check whether data has been provided.
    self._write_data_lists(
        equities_identifiers,
        futures_identifiers,
        exchanges_identifiers,
        root_symbols_identifiers,
        allow_sid_assignment=allow_sid_assignment,
    )
||
194 | |||
def _write_data_lists(self, equities=None, futures=None, exchanges=None,
                      root_symbols=None, allow_sid_assignment=True):
    """
    Hand the identifier lists to an AssetDBWriterFromList and write them
    to self.engine, forwarding allow_sid_assignment.
    """
    writer = AssetDBWriterFromList(equities, futures, exchanges, root_symbols)
    writer.write_all(self.engine, allow_sid_assignment=allow_sid_assignment)
||
199 | |||
def _write_data_dicts(self, equities=None, futures=None, exchanges=None,
                      root_symbols=None):
    """
    Hand the metadata dictionaries to an AssetDBWriterFromDictionary and
    write them to self.engine.
    """
    writer = AssetDBWriterFromDictionary(
        equities, futures, exchanges, root_symbols)
    writer.write_all(self.engine)
||
204 | |||
def _write_data_dataframes(self, equities=None, futures=None,
                           exchanges=None, root_symbols=None):
    """
    Hand the metadata DataFrames to an AssetDBWriterFromDataFrame and
    write them to self.engine.
    """
    writer = AssetDBWriterFromDataFrame(
        equities, futures, exchanges, root_symbols)
    writer.write_all(self.engine)
||
209 | |||
def normalize_date(self, test_date):
    """
    Return midnight UTC of the day containing @test_date.

    Naive inputs are interpreted as UTC; timezone-aware inputs are
    converted to UTC before the time of day is dropped.

    The private helper pd.tseries.tools.normalize_date is not available
    in current pandas, and newer pandas rejects passing `tz` together
    with an already-aware value, so the localisation is done explicitly
    and the public Timestamp.normalize() is used.
    """
    stamp = pd.Timestamp(test_date)
    if stamp.tzinfo is None:
        stamp = stamp.tz_localize('UTC')
    else:
        stamp = stamp.tz_convert('UTC')
    return stamp.normalize()
||
213 | |||
def utc_dt_in_exchange(self, dt):
    """Convert @dt to a Timestamp expressed in self.exchange_tz."""
    stamp = pd.Timestamp(dt)
    return stamp.tz_convert(self.exchange_tz)
||
216 | |||
def exchange_dt_in_utc(self, dt):
    """Build a Timestamp from @dt in self.exchange_tz and convert it to UTC."""
    local_stamp = pd.Timestamp(dt, tz=self.exchange_tz)
    return local_stamp.tz_convert('UTC')
||
219 | |||
def is_market_hours(self, test_date):
    """
    Return True when @test_date falls on a trading day and lies between
    that day's open and close, both bounds inclusive.
    """
    if self.is_trading_day(test_date):
        session_open, session_close = self.get_open_and_close(test_date)
        return test_date >= session_open and test_date <= session_close

    return False
||
226 | |||
def is_trading_day(self, test_date):
    """Return whether the normalized @test_date is in self.trading_days."""
    normalized = self.normalize_date(test_date)
    return normalized in self.trading_days
||
230 | |||
def next_trading_day(self, test_date):
    """
    Return the first trading day strictly after the normalized
    @test_date, or None once the search walks past self.last_trading_day.
    """
    one_day = datetime.timedelta(days=1)
    candidate = self.normalize_date(test_date)

    # Step forward one calendar day at a time until a trading day is hit.
    while candidate <= self.last_trading_day:
        candidate = candidate + one_day
        if candidate in self.trading_days:
            return candidate

    return None
||
241 | |||
def previous_trading_day(self, test_date):
    """
    Return the last trading day strictly before the normalized
    @test_date, or None once the search reaches self.first_trading_day.
    """
    minus_one_day = datetime.timedelta(days=-1)
    candidate = self.normalize_date(test_date)

    # Step backward one calendar day at a time until a trading day is hit.
    while self.first_trading_day < candidate:
        candidate = candidate + minus_one_day
        if candidate in self.trading_days:
            return candidate

    return None
||
252 | |||
def add_trading_days(self, n, date):
    """
    Adds n trading days to date.

    For n other than 1 or -1 the result is looked up by position in
    self.trading_days, and a NoFurtherDataError is raised when that
    position falls outside the calendar. For n == 1 and n == -1 the call
    is delegated to next_trading_day / previous_trading_day, whose result
    is returned as-is (those methods return None at the calendar edge
    rather than raising).

    :Arguments:
        n : int
            The number of days to add to date, this can be positive or
            negative.
        date : datetime
            The date to add to.

    :Returns:
        new_date : datetime
            n trading days added to date.
    """
    if n == 1:
        return self.next_trading_day(date)
    elif n == -1:
        return self.previous_trading_day(date)

    target = self.get_index(date) + n
    if 0 <= target < len(self.trading_days):
        return self.trading_days[target]

    raise NoFurtherDataError(
        msg='Cannot add %d days to %s' % (n, date)
    )
||
281 | |||
def days_in_range(self, start, end):
    """Return the entries of self.trading_days between start and end, inclusive."""
    days = self.trading_days
    within = (days >= start) & (days <= end)
    return days[within]
||
286 | |||
def opens_in_range(self, start, end):
    """Return the market_open column of self.open_and_closes sliced by label from start to end."""
    market_opens = self.open_and_closes.market_open
    return market_opens.loc[start:end]
||
289 | |||
def closes_in_range(self, start, end):
    """Return the market_close column of self.open_and_closes sliced by label from start to end."""
    market_closes = self.open_and_closes.market_close
    return market_closes.loc[start:end]
||
292 | |||
def minutes_for_days_in_range(self, start, end):
    """
    Get all market minutes for the days between start and end, inclusive.

    Both bounds are normalized to dates first, so every minute of the
    first and last day is included; nothing is trimmed to the time of
    day carried by @start or @end.

    Returns an empty UTC DatetimeIndex when no trading day falls in the
    range (np.concatenate would otherwise raise on an empty list).
    """
    first_day = self.normalize_date(start)
    last_day = self.normalize_date(end)

    minutes_per_day = [
        self.market_minutes_for_day(day)
        for day in self.days_in_range(first_day, last_day)
    ]
    if not minutes_per_day:
        return pd.DatetimeIndex([], tz='UTC')

    # Concatenate the per-day minutes into one UTC index.
    return pd.DatetimeIndex(
        np.concatenate(minutes_per_day), copy=False, tz='UTC',
    )
||
309 | |||
def next_open_and_close(self, start_date):
    """
    Given the start_date, returns the open and close of the next
    trading day.

    Raises NoFurtherDataError when next_trading_day finds no later day.
    """
    upcoming_day = self.next_trading_day(start_date)

    if upcoming_day is not None:
        return self.get_open_and_close(upcoming_day)

    raise NoFurtherDataError(
        msg=("Attempt to backtest beyond available history. "
             "Last known date: %s" % self.last_trading_day)
    )
||
324 | |||
def previous_open_and_close(self, start_date):
    """
    Given the start_date, returns the open and close of the previous
    trading day.

    Raises NoFurtherDataError when previous_trading_day finds no earlier
    day.
    """
    earlier_day = self.previous_trading_day(start_date)

    if earlier_day is not None:
        return self.get_open_and_close(earlier_day)

    raise NoFurtherDataError(
        msg=("Attempt to backtest beyond available history. "
             "First known date: %s" % self.first_trading_day)
    )
||
338 | |||
def next_market_minute(self, start):
    """
    Get the next market minute after @start.

    On a trading day this is the market open when @start precedes it, or
    @start plus one minute when @start is before the close. In every
    other case it is the open of the next trading day.
    """
    if self.is_trading_day(start):
        session_open, session_close = self.get_open_and_close(start)
        # Before the bell: the first market minute is the open itself.
        if start < session_open:
            return session_open
        # Inside the session: simply advance by one minute.
        if start < session_close:
            return start + datetime.timedelta(minutes=1)

    # Not a trading day, or at/after the close: use the *next* day's open.
    next_open, _ = self.next_open_and_close(start)
    return next_open
||
356 | |||
def previous_market_minute(self, start):
    """
    Get the last market minute before @start.

    On a trading day this is the market close when @start is after it, or
    @start minus one minute when @start is after the open. In every other
    case it is the close of the previous trading day.
    """
    if self.is_trading_day(start):
        session_open, session_close = self.get_open_and_close(start)
        # After the bell: the last market minute is the close itself.
        if start > session_close:
            return session_close
        # Inside the session: simply step back by one minute.
        if start > session_open:
            return start - datetime.timedelta(minutes=1)

    # Not a trading day, or at/before the open: use the *previous* close.
    _, previous_close = self.previous_open_and_close(start)
    return previous_close
||
374 | |||
def get_open_and_close(self, day):
    """
    Return the first two values of the self.open_and_closes row whose
    index label is @day.date(), as a (first, second) tuple.
    """
    schedule = self.open_and_closes
    row = schedule.index.get_loc(day.date())
    bounds = schedule.values[row]
    return bounds[0], bounds[1]
||
379 | |||
def market_minutes_for_day(self, stamp):
    """
    Return every minute from the market open to the market close of the
    day containing @stamp, both endpoints included.

    The frequency is spelled 'min' rather than the 'T' alias, which
    recent pandas releases deprecate; both denote one-minute steps.
    """
    market_open, market_close = self.get_open_and_close(stamp)
    return pd.date_range(market_open, market_close, freq='min')
||
383 | |||
def open_close_window(self, start, count, offset=0, step=1):
    """
    Return a DataFrame containing `count` rows of market opens and
    closes, beginning at the row for `start` shifted by `offset` rows and
    advancing `step` rows of self.open_and_closes at a time.
    """
    # TODO: Correctly handle end of data.
    first_row = self.get_index(start) + offset
    rows = np.arange(first_row, first_row + (count * step), step)

    return self.open_and_closes.iloc[rows]
||
397 | |||
def market_minute_window(self, start, count, step=1):
    """
    Return a DatetimeIndex containing `count` market minutes, starting with
    `start` and continuing `step` minutes at a time.

    Raises ValueError when `start` is not within market hours.
    """
    if not self.is_market_hours(start):
        raise ValueError("market_minute_window starting at "
                         "non-market time {minute}".format(minute=start))

    # Minutes of the starting day, from `start` onward in `step` strides.
    day_minutes = self.market_minutes_for_day(start)
    chunk = day_minutes[day_minutes.searchsorted(start)::step]

    collected = []
    remaining = count

    # Gather one day's worth of minutes per pass until `count` minutes
    # have been stored altogether.
    while True:
        if len(chunk) >= remaining:
            # Truncate off extra minutes
            chunk = chunk[:remaining]

        collected.append(chunk)
        remaining -= len(chunk)
        if remaining <= 0:
            break

        # Move to the adjacent trading day in the direction of `step`.
        if step > 0:
            start, _ = self.next_open_and_close(start)
        else:
            _, start = self.previous_open_and_close(start)

        chunk = self.market_minutes_for_day(start)[::step]

    # Concatenate all the accumulated minutes.
    return pd.DatetimeIndex(
        np.concatenate(collected), copy=False, tz='UTC',
    )
||
439 | |||
def trading_day_distance(self, first_date, second_date):
    """
    Return the difference between the insertion positions of the two
    normalized dates in self.trading_days (second minus first), or None
    when either date sorts after every entry of self.trading_days.
    """
    first = self.normalize_date(first_date)
    second = self.normalize_date(second_date)

    # TODO: May be able to replace the following with searchsorted.
    # Position of the leftmost entry greater than or equal to `first`.
    first_pos = bisect.bisect_left(self.trading_days, first)
    if first_pos == len(self.trading_days):  # nothing found
        return None

    second_pos = bisect.bisect_left(self.trading_days, second)
    if second_pos == len(self.trading_days):
        return None

    return second_pos - first_pos
||
454 | |||
def get_index(self, dt):
    """
    Return the index of the given @dt, or the index of the preceding
    trading day if the given dt is not in the trading calendar.
    """
    ndt = self.normalize_date(dt)
    position = self.trading_days.searchsorted(ndt)
    # searchsorted yields the insertion point; for a date missing from the
    # calendar, the preceding trading day sits one slot earlier.
    return position if ndt in self.trading_days else position - 1
||
465 | |||
567 |