| Total Complexity | 67 |
| Total Lines | 405 |
| Duplicated Lines | 0 % |
Complex classes like zipline.finance.TradingEnvironment often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | # |
||
class TradingEnvironment(object):
    """
    Trading context for a simulation: the trading calendar (trading days,
    market opens/closes, early closes), the market data returned by the
    ``load`` callable (benchmark returns and treasury curves), the exchange
    timezone, and the asset database engine together with its AssetFinder.
    """

    # Token used as a substitute for pickling objects that contain a
    # reference to a TradingEnvironment
    PERSISTENT_TOKEN = "<TradingEnvironment>"

    def __init__(
        self,
        load=None,
        bm_symbol='^GSPC',
        exchange_tz="US/Eastern",
        max_date=None,
        env_trading_calendar=tradingcalendar,
        asset_db_path=':memory:'
    ):
        """
        Parameters
        ----------
        load : callable, optional
            Called as ``load(trading_day, trading_days, bm_symbol)`` and must
            return ``(benchmark_returns, treasury_curves)``. The
            treasury_curves are expected to be a DataFrame with an index of
            dates and columns of the curve names, e.g. '10year', '1month',
            etc. Falls back to ``load_market_data`` when not supplied.
        bm_symbol : str, optional
            Benchmark symbol; stored on the instance and forwarded to
            ``load``.
        exchange_tz : str, optional
            Timezone name used by ``utc_dt_in_exchange`` and
            ``exchange_dt_in_utc``.
        max_date : datetime-like, optional
            When given, trading days and treasury curve rows dated after it
            are dropped.
        env_trading_calendar : object, optional
            Provides ``trading_day``, ``trading_days``, ``open_and_closes``
            and ``get_early_closes``.
        asset_db_path : str or engine or None, optional
            A string is turned into a ``'sqlite:///<path>'`` URL, an engine is
            created for it and ``AssetDBWriterFromDictionary().init_db`` is
            run on that engine. Any other value is used as the engine
            directly; ``None`` leaves ``asset_finder`` set to ``None``.
        """
        self.trading_day = env_trading_calendar.trading_day.copy()

        # `tc_td` is short for "trading calendar trading days"
        tc_td = env_trading_calendar.trading_days

        if max_date:
            self.trading_days = tc_td[tc_td <= max_date].copy()
        else:
            self.trading_days = tc_td.copy()

        # NOTE(review): raises IndexError if max_date precedes every
        # calendar day, because trading_days is then empty.
        self.first_trading_day = self.trading_days[0]
        self.last_trading_day = self.trading_days[-1]

        self.early_closes = env_trading_calendar.get_early_closes(
            self.first_trading_day, self.last_trading_day)

        self.open_and_closes = env_trading_calendar.open_and_closes.loc[
            self.trading_days]

        self.bm_symbol = bm_symbol
        if not load:
            load = load_market_data

        self.benchmark_returns, self.treasury_curves = \
            load(self.trading_day, self.trading_days, self.bm_symbol)

        if max_date:
            tr_c = self.treasury_curves
            # Mask the treasury curves down to the current date.
            # In the case of live trading, the last date in the treasury
            # curves would be the day before the date considered to be
            # 'today'.
            self.treasury_curves = tr_c[tr_c.index <= max_date]

        self.exchange_tz = exchange_tz

        if isinstance(asset_db_path, string_types):
            asset_db_path = 'sqlite:///%s' % asset_db_path
            self.engine = engine = create_engine(asset_db_path)
            AssetDBWriterFromDictionary().init_db(engine)
        else:
            self.engine = engine = asset_db_path

        if engine is not None:
            self.asset_finder = AssetFinder(engine)
        else:
            self.asset_finder = None

    def write_data(self,
                   engine=None,
                   equities_data=None,
                   futures_data=None,
                   exchanges_data=None,
                   root_symbols_data=None,
                   equities_df=None,
                   futures_df=None,
                   exchanges_df=None,
                   root_symbols_df=None,
                   equities_identifiers=None,
                   futures_identifiers=None,
                   exchanges_identifiers=None,
                   root_symbols_identifiers=None,
                   allow_sid_assignment=True):
        """ Write the supplied data to the database.

        DataFrame inputs are written first (only if at least one is given),
        then dict inputs (only if at least one is given), and finally the
        identifier lists, which are always passed to the list writer.

        Parameters
        ----------
        engine: engine, optional
            If truthy, replaces ``self.engine`` before anything is written.
        equities_data: dict, optional
            A dictionary of equity metadata
        futures_data: dict, optional
            A dictionary of futures metadata
        exchanges_data: dict, optional
            A dictionary of exchanges metadata
        root_symbols_data: dict, optional
            A dictionary of root symbols metadata
        equities_df: pandas.DataFrame, optional
            A pandas.DataFrame of equity metadata
        futures_df: pandas.DataFrame, optional
            A pandas.DataFrame of futures metadata
        exchanges_df: pandas.DataFrame, optional
            A pandas.DataFrame of exchanges metadata
        root_symbols_df: pandas.DataFrame, optional
            A pandas.DataFrame of root symbols metadata
        equities_identifiers: list, optional
            A list of equities identifiers (sids, symbols, Assets)
        futures_identifiers: list, optional
            A list of futures identifiers (sids, symbols, Assets)
        exchanges_identifiers: list, optional
            A list of exchanges identifiers (ids or names)
        root_symbols_identifiers: list, optional
            A list of root symbols identifiers (ids or symbols)
        allow_sid_assignment: bool, optional
            Forwarded only to the identifier-list writer.
        """
        if engine:
            self.engine = engine

        # If any pandas.DataFrame data has been provided,
        # write it to the database.
        if (equities_df is not None or futures_df is not None or
                exchanges_df is not None or root_symbols_df is not None):
            self._write_data_dataframes(equities_df, futures_df,
                                        exchanges_df, root_symbols_df)

        if (equities_data is not None or futures_data is not None or
                exchanges_data is not None or root_symbols_data is not None):
            self._write_data_dicts(equities_data, futures_data,
                                   exchanges_data, root_symbols_data)

        # These could be lists or other iterables such as a pandas.Index.
        # For simplicity, don't check whether data has been provided.
        self._write_data_lists(equities_identifiers,
                               futures_identifiers,
                               exchanges_identifiers,
                               root_symbols_identifiers,
                               allow_sid_assignment=allow_sid_assignment)

    def _write_data_lists(self, equities=None, futures=None, exchanges=None,
                          root_symbols=None, allow_sid_assignment=True):
        """Write identifier lists to ``self.engine`` via the list writer."""
        AssetDBWriterFromList(equities, futures, exchanges, root_symbols)\
            .write_all(self.engine, allow_sid_assignment=allow_sid_assignment)

    def _write_data_dicts(self, equities=None, futures=None, exchanges=None,
                          root_symbols=None):
        """Write metadata dicts to ``self.engine`` via the dict writer."""
        AssetDBWriterFromDictionary(equities, futures, exchanges, root_symbols)\
            .write_all(self.engine)

    def _write_data_dataframes(self, equities=None, futures=None,
                               exchanges=None, root_symbols=None):
        """Write metadata DataFrames to ``self.engine``."""
        AssetDBWriterFromDataFrame(equities, futures, exchanges, root_symbols)\
            .write_all(self.engine)

    def normalize_date(self, test_date):
        """Return ``test_date`` as a UTC Timestamp truncated to midnight."""
        test_date = pd.Timestamp(test_date, tz='UTC')
        return pd.tseries.tools.normalize_date(test_date)

    def utc_dt_in_exchange(self, dt):
        """Convert ``dt`` to the exchange timezone.

        NOTE(review): ``tz_convert`` requires ``dt`` to be tz-aware.
        """
        return pd.Timestamp(dt).tz_convert(self.exchange_tz)

    def exchange_dt_in_utc(self, dt):
        """Interpret ``dt`` in the exchange timezone and convert it to UTC."""
        return pd.Timestamp(dt, tz=self.exchange_tz).tz_convert('UTC')

    def is_market_hours(self, test_date):
        """
        Return True if ``test_date`` is on a trading day and lies between
        that day's market open and close, inclusive.
        """
        if not self.is_trading_day(test_date):
            return False

        mkt_open, mkt_close = self.get_open_and_close(test_date)
        return mkt_open <= test_date <= mkt_close

    def is_trading_day(self, test_date):
        """Return True if the normalized ``test_date`` is a trading day."""
        dt = self.normalize_date(test_date)
        return (dt in self.trading_days)

    def next_trading_day(self, test_date):
        """
        Return the first trading day strictly after the normalized
        ``test_date``, or None if the calendar has no such day.
        """
        dt = self.normalize_date(test_date)
        delta = datetime.timedelta(days=1)

        while dt <= self.last_trading_day:
            dt += delta
            if dt in self.trading_days:
                return dt

        return None

    def previous_trading_day(self, test_date):
        """
        Return the last trading day strictly before the normalized
        ``test_date``, or None if the calendar has no such day.
        """
        dt = self.normalize_date(test_date)
        delta = datetime.timedelta(days=-1)

        while self.first_trading_day < dt:
            dt += delta
            if dt in self.trading_days:
                return dt

        return None

    def add_trading_days(self, n, date):
        """
        Adds n trading days to date. If this would fall outside of the
        trading calendar, a NoFurtherDataError is raised.

        :Arguments:
            n : int
                The number of days to add to date, this can be positive or
                negative.
            date : datetime
                The date to add to.

        :Returns:
            new_date : datetime
                n trading days added to date.

        :Raises:
            NoFurtherDataError
                If the result would fall before the first or after the last
                trading day (including the n == 1 / n == -1 cases).
        """
        # n == +/-1 keep their dedicated lookups: for a non-trading `date`,
        # get_index() resolves to the *preceding* trading day, so -1 via the
        # index path would land one day further back than intended.
        if n == 1:
            new_date = self.next_trading_day(date)
        elif n == -1:
            new_date = self.previous_trading_day(date)
        else:
            idx = self.get_index(date) + n
            if 0 <= idx < len(self.trading_days):
                new_date = self.trading_days[idx]
            else:
                new_date = None

        if new_date is None:
            raise NoFurtherDataError(
                msg='Cannot add %d days to %s' % (n, date)
            )

        return new_date

    def days_in_range(self, start, end):
        """Return the trading days between ``start`` and ``end``, inclusive."""
        mask = ((self.trading_days >= start) &
                (self.trading_days <= end))
        return self.trading_days[mask]

    def opens_in_range(self, start, end):
        """Return market opens labelled ``start`` through ``end`` (inclusive)."""
        return self.open_and_closes.market_open.loc[start:end]

    def closes_in_range(self, start, end):
        """Return market closes labelled ``start`` through ``end`` (inclusive)."""
        return self.open_and_closes.market_close.loc[start:end]

    def minutes_for_days_in_range(self, start, end):
        """
        Get all market minutes for the days between start and end, inclusive.

        Whole trading days are returned: minutes on the boundary days that
        fall before ``start`` or after ``end`` are NOT trimmed. Returns an
        empty UTC DatetimeIndex if no trading day falls in the range.
        """
        start_date = self.normalize_date(start)
        end_date = self.normalize_date(end)

        all_minutes = [
            self.market_minutes_for_day(day)
            for day in self.days_in_range(start_date, end_date)
        ]

        if not all_minutes:
            # np.concatenate raises ValueError on an empty sequence.
            return pd.DatetimeIndex([], tz='UTC')

        # Concatenate every day's minutes into a single index.
        return pd.DatetimeIndex(
            np.concatenate(all_minutes), copy=False, tz='UTC',
        )

    def next_open_and_close(self, start_date):
        """
        Given the start_date, returns the next open and close of
        the market.

        Raises NoFurtherDataError if there is no later trading day.
        """
        next_open = self.next_trading_day(start_date)

        if next_open is None:
            raise NoFurtherDataError(
                msg=("Attempt to backtest beyond available history. "
                     "Last known date: %s" % self.last_trading_day)
            )

        return self.get_open_and_close(next_open)

    def previous_open_and_close(self, start_date):
        """
        Given the start_date, returns the previous open and close of the
        market.

        Raises NoFurtherDataError if there is no earlier trading day.
        """
        previous = self.previous_trading_day(start_date)

        if previous is None:
            raise NoFurtherDataError(
                msg=("Attempt to backtest beyond available history. "
                     "First known date: %s" % self.first_trading_day)
            )
        return self.get_open_and_close(previous)

    def next_market_minute(self, start):
        """
        Get the next market minute after @start. This is either the immediate
        next minute, the open of the same day if @start is before the market
        open on a trading day, or the open of the next market day after @start.
        """
        if self.is_trading_day(start):
            market_open, market_close = self.get_open_and_close(start)
            # If start before market open on a trading day, return market open.
            if start < market_open:
                return market_open
            # If start is during trading hours, then get the next minute.
            elif start < market_close:
                return start + datetime.timedelta(minutes=1)
        # If start is not in a trading day, or is after the market close
        # then return the open of the *next* trading day.
        return self.next_open_and_close(start)[0]

    def previous_market_minute(self, start):
        """
        Get the previous market minute before @start. This is either the
        immediate previous minute, the close of the same day if @start is
        after the close on a trading day, or the close of the market day
        before @start.
        """
        if self.is_trading_day(start):
            market_open, market_close = self.get_open_and_close(start)
            # If start after the market close, return market close.
            if start > market_close:
                return market_close
            # If start is during trading hours, then get previous minute.
            if start > market_open:
                return start - datetime.timedelta(minutes=1)
        # If start is not a trading day, or is before the market open
        # then return the close of the *previous* trading day.
        return self.previous_open_and_close(start)[1]

    def get_open_and_close(self, day):
        """
        Return ``(market_open, market_close)`` for the date of ``day``.

        NOTE(review): reads the first two columns of ``open_and_closes`` by
        position, so it assumes they are market_open and market_close in
        that order. A date missing from the index raises from ``get_loc``.
        """
        index = self.open_and_closes.index.get_loc(day.date())
        todays_minutes = self.open_and_closes.values[index]
        return todays_minutes[0], todays_minutes[1]

    def market_minutes_for_day(self, stamp):
        """Return every minute from open to close of ``stamp``'s day."""
        market_open, market_close = self.get_open_and_close(stamp)
        return pd.date_range(market_open, market_close, freq='T')

    def open_close_window(self, start, count, offset=0, step=1):
        """
        Return a DataFrame containing `count` market opens and closes,
        beginning with `start` + `offset` trading days and continuing `step`
        trading days at a time.
        """
        # TODO: Correctly handle end of data.
        # NOTE(review): positions past the end raise from iloc, while
        # negative positions silently wrap around to the end of the calendar.
        start_idx = self.get_index(start) + offset
        stop_idx = start_idx + (count * step)

        index = np.arange(start_idx, stop_idx, step)

        return self.open_and_closes.iloc[index]

    def market_minute_window(self, start, count, step=1):
        """
        Return a DatetimeIndex containing `count` market minutes, starting with
        `start` and continuing `step` minutes at a time. A negative `step`
        walks backwards through earlier trading days.

        Raises ValueError if `start` is not within market hours.
        """
        if not self.is_market_hours(start):
            raise ValueError("market_minute_window starting at "
                             "non-market time {minute}".format(minute=start))

        all_minutes = []

        current_day_minutes = self.market_minutes_for_day(start)
        first_minute_idx = current_day_minutes.searchsorted(start)
        minutes_in_range = current_day_minutes[first_minute_idx::step]

        # Build up list of lists of days' market minutes until we have count
        # minutes stored altogether.
        while True:

            if len(minutes_in_range) >= count:
                # Truncate off extra minutes
                minutes_in_range = minutes_in_range[:count]

            all_minutes.append(minutes_in_range)
            count -= len(minutes_in_range)
            if count <= 0:
                break

            # Move to the adjacent trading day in the direction of `step`:
            # forwards starts from its open, backwards from its close.
            if step > 0:
                start, _ = self.next_open_and_close(start)
            else:
                _, start = self.previous_open_and_close(start)
            current_day_minutes = self.market_minutes_for_day(start)

            minutes_in_range = current_day_minutes[::step]

        # Concatenate all the accumulated minutes.
        return pd.DatetimeIndex(
            np.concatenate(all_minutes), copy=False, tz='UTC',
        )

    def trading_day_distance(self, first_date, second_date):
        """
        Return the number of trading-day positions from ``first_date`` to
        ``second_date`` (negative if ``second_date`` is earlier), or None if
        either normalized date is after the last trading day.
        """
        first_date = self.normalize_date(first_date)
        second_date = self.normalize_date(second_date)

        # TODO: May be able to replace the following with searchsorted.
        # Find leftmost item greater than or equal to day
        i = bisect.bisect_left(self.trading_days, first_date)
        if i == len(self.trading_days):  # nothing found
            return None
        j = bisect.bisect_left(self.trading_days, second_date)
        if j == len(self.trading_days):
            return None

        return j - i

    def get_index(self, dt):
        """
        Return the index of the given @dt, or the index of the preceding
        trading day if the given dt is not in the trading calendar.
        """
        ndt = self.normalize_date(dt)
        if ndt in self.trading_days:
            return self.trading_days.searchsorted(ndt)
        else:
            return self.trading_days.searchsorted(ndt) - 1
| 465 | |||
| 567 |