| Total Complexity | 47 |
| Total Lines | 462 |
| Duplicated Lines | 0 % |
Complex classes like zipline.finance.performance.PerformanceTracker often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
| 1 | # |
||
class PerformanceTracker(object):
    """
    Tracks the performance of the algorithm.

    Holds a ``PositionTracker``, two performance periods -- ``"Cumulative"``
    (spanning ``period_start``..``period_end``) and ``"Daily"`` (spanning the
    current market session) -- and cumulative risk metrics, and builds the
    perf packets emitted at minute and/or day boundaries.

    Parameters
    ----------
    sim_params : object
        Simulation parameters. Reads ``period_start``, ``period_end``,
        ``first_open``, ``last_close``, ``days_in_period``, ``capital_base``,
        ``data_frequency`` and ``emission_rate`` (which must be ``'daily'``
        or ``'minute'``).
    env : object
        Trading environment. Reads ``exchange_tz``, ``trading_days``,
        ``asset_finder`` and calls ``get_open_and_close``,
        ``next_open_and_close`` and ``next_trading_day``.
    data_portal : object or None
        Price source forwarded to the position tracker and period stats.
        When ``None``, no adjustment reader is available and dividend
        checks are skipped.

    Raises
    ------
    ValueError
        If ``sim_params.emission_rate`` is neither ``'daily'`` nor
        ``'minute'``.
    """

    def __init__(self, sim_params, env, data_portal):
        # Fail fast: the benchmark series and risk metrics below only exist
        # for these two rates, and every emitting method branches on them.
        if sim_params.emission_rate not in ('daily', 'minute'):
            raise ValueError(
                "Invalid emission rate: %s" % sim_params.emission_rate)

        self.sim_params = sim_params
        self.env = env

        self.period_start = self.sim_params.period_start
        self.period_end = self.sim_params.period_end
        self.last_close = self.sim_params.last_close
        first_open = self.sim_params.first_open.tz_convert(
            self.env.exchange_tz
        )
        # Midnight (UTC-labelled) of the first session's date as seen in
        # the exchange's timezone.
        self.day = pd.Timestamp(datetime(first_open.year, first_open.month,
                                         first_open.day), tz='UTC')
        self.market_open, self.market_close = env.get_open_and_close(self.day)
        self.total_days = self.sim_params.days_in_period
        self.capital_base = self.sim_params.capital_base
        self.emission_rate = sim_params.emission_rate

        all_trading_days = env.trading_days
        mask = ((all_trading_days >= normalize_date(self.period_start)) &
                (all_trading_days <= normalize_date(self.period_end)))

        self.trading_days = all_trading_days[mask]

        self._data_portal = data_portal
        if data_portal is not None:
            self._adjustment_reader = data_portal._adjustment_reader
        else:
            self._adjustment_reader = None

        self.position_tracker = PositionTracker(
            asset_finder=env.asset_finder,
            data_portal=data_portal,
            data_frequency=self.sim_params.data_frequency)

        if self.emission_rate == 'daily':
            # One benchmark slot per trading day in the simulation.
            self.all_benchmark_returns = pd.Series(
                index=self.trading_days)
            self.cumulative_risk_metrics = \
                risk.RiskMetricsCumulative(self.sim_params, self.env)
        elif self.emission_rate == 'minute':
            # One benchmark slot per minute from first open to last close.
            self.all_benchmark_returns = pd.Series(index=pd.date_range(
                self.sim_params.first_open, self.sim_params.last_close,
                freq='Min'))

            self.cumulative_risk_metrics = \
                risk.RiskMetricsCumulative(self.sim_params, self.env,
                                           create_first_day_stats=True)

        # this performance period will span the entire simulation from
        # inception.
        self.cumulative_performance = PerformancePeriod(
            # initial cash is your capital base.
            starting_cash=self.capital_base,
            data_frequency=self.sim_params.data_frequency,
            # the cumulative period will be calculated over the entire test.
            period_open=self.period_start,
            period_close=self.period_end,
            # don't save the transactions for the cumulative
            # period
            keep_transactions=False,
            keep_orders=False,
            # don't serialize positions for cumulative period
            serialize_positions=False,
            asset_finder=self.env.asset_finder,
            name="Cumulative"
        )

        # this performance period will span just the current market day
        self.todays_performance = PerformancePeriod(
            # initial cash is your capital base.
            starting_cash=self.capital_base,
            data_frequency=self.sim_params.data_frequency,
            # the daily period will be calculated for the market day
            period_open=self.market_open,
            period_close=self.market_close,
            keep_transactions=True,
            keep_orders=True,
            serialize_positions=True,
            asset_finder=self.env.asset_finder,
            name="Daily"
        )

        self.saved_dt = self.period_start
        # Number of completed market days. It is incremented in
        # _handle_market_close *before* the daily packet is built, so
        # ``progress`` can reach 100% on the final day.
        self.day_count = 0.0
        self.txn_count = 0

        self.account_needs_update = True
        self._account = None

        self._perf_periods = [self.cumulative_performance,
                              self.todays_performance]

    @property
    def perf_periods(self):
        """The ``[cumulative, daily]`` performance periods."""
        return self._perf_periods

    def __repr__(self):
        return "%s(%r)" % (
            self.__class__.__name__,
            {'simulation parameters': self.sim_params})

    @property
    def progress(self):
        """
        Fraction of the simulation completed.

        For ``'daily'`` emission this is ``day_count / total_days``. For
        ``'minute'`` emission a placeholder of ``1.0`` is returned.
        """
        if self.emission_rate == 'minute':
            # Fake a value
            return 1.0
        elif self.emission_rate == 'daily':
            return self.day_count / self.total_days

    def set_date(self, date):
        """
        Record the current simulation dt.

        Only has an effect for ``'minute'`` emission. There it also moves
        the daily period's close to ``date``.
        """
        if self.emission_rate == 'minute':
            self.saved_dt = date
            self.todays_performance.period_close = self.saved_dt

    def _period_stats(self, period, pos_stats):
        """Compute ``period``'s stats for the currently held positions."""
        return period.stats(
            self.position_tracker.positions, pos_stats, self._data_portal)

    def get_portfolio(self, dt):
        """
        Build the cumulative portfolio view as of ``dt``.

        Last-sale prices are synced to ``dt`` before stats are computed.
        """
        position_tracker = self.position_tracker
        position_tracker.sync_last_sale_prices(dt)
        pos_stats = position_tracker.stats()
        period_stats = self._period_stats(
            self.cumulative_performance, pos_stats)
        return self.cumulative_performance.as_portfolio(
            pos_stats,
            period_stats,
            position_tracker,
            dt)

    def get_account(self, dt):
        """
        Build the cumulative account view as of ``dt``.

        Prices are synced to ``dt``. The result is also cached on
        ``self._account``.
        """
        self.position_tracker.sync_last_sale_prices(dt)
        pos_stats = self.position_tracker.stats()
        period_stats = self._period_stats(
            self.cumulative_performance, pos_stats)
        self._account = self.cumulative_performance.as_account(
            pos_stats, period_stats)
        return self._account

    def to_dict(self, emission_type=None):
        """
        Wrapper for serialization compatibility.

        Computes fresh position, cumulative and daily stats without syncing
        prices, then delegates to ``_to_dict``.
        """
        pos_stats = self.position_tracker.stats()
        cumulative_stats = self._period_stats(
            self.cumulative_performance, pos_stats)
        todays_stats = self._period_stats(
            self.todays_performance, pos_stats)

        return self._to_dict(pos_stats,
                             cumulative_stats,
                             todays_stats,
                             emission_type)

    def _to_dict(self, pos_stats, cumulative_stats, todays_stats,
                 emission_type=None):
        """
        Creates a dictionary representing the state of this tracker.
        Returns a dict object of the form described in header comments.

        Use this method internally, when stats are available.

        Raises
        ------
        ValueError
            If ``emission_type`` (after defaulting to ``self.emission_rate``)
            is neither ``'daily'`` nor ``'minute'``.
        """
        # Default to the emission rate of this tracker if no type is provided
        if emission_type is None:
            emission_type = self.emission_rate

        position_tracker = self.position_tracker

        _dict = {
            'period_start': self.period_start,
            'period_end': self.period_end,
            'capital_base': self.capital_base,
            'cumulative_perf': self.cumulative_performance.to_dict(
                pos_stats, cumulative_stats, position_tracker,
            ),
            'progress': self.progress,
            'cumulative_risk_metrics': self.cumulative_risk_metrics.to_dict()
        }
        if emission_type == 'daily':
            _dict['daily_perf'] = self.todays_performance.to_dict(
                pos_stats,
                todays_stats,
                position_tracker)
        elif emission_type == 'minute':
            _dict['minute_perf'] = self.todays_performance.to_dict(
                pos_stats,
                todays_stats,
                position_tracker,
                self.saved_dt)
        else:
            raise ValueError("Invalid emission type: %s" % emission_type)

        return _dict

    def copy_state_from(self, other_perf_tracker):
        """
        Copy benchmark returns and pending dividend state from another
        tracker.
        """
        self.all_benchmark_returns = other_perf_tracker.all_benchmark_returns

        if other_perf_tracker.position_tracker:
            self.position_tracker._unpaid_dividends = \
                other_perf_tracker.position_tracker._unpaid_dividends

            self.position_tracker._unpaid_stock_dividends = \
                other_perf_tracker.position_tracker._unpaid_stock_dividends

    def process_transaction(self, transaction):
        """Apply a fill to the positions and to both perf periods."""
        self.txn_count += 1
        self.position_tracker.execute_transaction(transaction)
        self.cumulative_performance.handle_execution(transaction)
        self.todays_performance.handle_execution(transaction)

    def handle_splits(self, splits):
        """
        Apply splits to positions.

        Any positive leftover cash returned by the position tracker is
        credited to both perf periods.
        """
        leftover_cash = self.position_tracker.handle_splits(splits)
        if leftover_cash > 0:
            self.cumulative_performance.handle_cash_payment(leftover_cash)
            self.todays_performance.handle_cash_payment(leftover_cash)

    def process_order(self, event):
        """Record an order event in both perf periods."""
        self.cumulative_performance.record_order(event)
        self.todays_performance.record_order(event)

    def process_commission(self, commission):
        """
        Apply a commission to the position and both perf periods.

        ``commission`` must support ``["sid"]`` and ``["cost"]`` lookups.
        """
        sid = commission["sid"]
        cost = commission["cost"]

        self.position_tracker.handle_commission(sid, cost)
        self.cumulative_performance.handle_commission(cost)
        self.todays_performance.handle_commission(cost)

    def process_close_position(self, event):
        """
        Close a position for ``event``.

        A transaction is processed only if the position tracker produces
        one.
        """
        txn = self.position_tracker.\
            maybe_create_close_position_transaction(event)
        if txn:
            self.process_transaction(txn)

    def check_upcoming_dividends(self, next_trading_day):
        """
        Check if we currently own any stocks with dividends whose ex_date is
        the next trading day. Track how much we should be paid on those
        dividends' pay dates.

        Then check if we are owed cash/stock for any dividends whose pay date
        is the next trading day. Apply all such benefits, then recalculate
        performance.

        Does nothing when no adjustment reader is available.
        """
        if self._adjustment_reader is None:
            return
        position_tracker = self.position_tracker
        held_sids = set(position_tracker.positions)
        # Dividends whose ex_date is the next trading day. We need to check if
        # we own any of these stocks so we know to pay them out when the pay
        # date comes.
        if held_sids:
            dividends_earnable = self._adjustment_reader.\
                get_dividends_with_ex_date(held_sids, next_trading_day)
            stock_dividends = self._adjustment_reader.\
                get_stock_dividends_with_ex_date(held_sids, next_trading_day)
            position_tracker.earn_dividends(dividends_earnable,
                                            stock_dividends)

        net_cash_payment = position_tracker.pay_dividends(next_trading_day)
        if not net_cash_payment:
            return

        self.cumulative_performance.handle_dividends_paid(net_cash_payment)
        self.todays_performance.handle_dividends_paid(net_cash_payment)

    def check_asset_auto_closes(self, next_trading_day):
        """
        Check if the position tracker currently owns any Assets with an
        auto-close date that is the next trading day. Close those positions.

        Parameters
        ----------
        next_trading_day : pandas.Timestamp
            The next trading day of the simulation
        """
        auto_close_events = self.position_tracker.auto_close_position_events(
            next_trading_day=next_trading_day
        )
        for event in auto_close_events:
            self.process_close_position(event)

    def handle_minute_close(self, dt):
        """
        Handles the close of the given minute. This includes handling
        market-close functions if the given minute is the end of the market
        day.

        Parameters
        ----------
        dt : Timestamp
            The minute that is ending

        Returns
        -------
        (dict, dict/None)
            A tuple of the minute perf packet and daily perf packet.
            If the market day has not ended, the daily perf packet is None.
        """
        todays_date = normalize_date(dt)
        account = self.get_account(dt)

        bench_returns = self.all_benchmark_returns.loc[todays_date:dt]
        # Compound the per-minute benchmark returns since the session start.
        bench_since_open = (1. + bench_returns).prod() - 1

        self.position_tracker.sync_last_sale_prices(dt)
        pos_stats = self.position_tracker.stats()
        cumulative_stats = self._period_stats(
            self.cumulative_performance, pos_stats)
        todays_stats = self._period_stats(
            self.todays_performance, pos_stats)
        self.cumulative_risk_metrics.update(todays_date,
                                            todays_stats.returns,
                                            bench_since_open,
                                            account)

        minute_packet = self._to_dict(pos_stats,
                                      cumulative_stats,
                                      todays_stats,
                                      emission_type='minute')

        if dt == self.market_close:
            # if this is the last minute of the day, we also want to
            # emit a daily packet.
            return minute_packet, self._handle_market_close(todays_date,
                                                            pos_stats,
                                                            todays_stats)
        else:
            return minute_packet, None

    def handle_market_close_daily(self, dt):
        """
        Function called after handle_data when running with daily emission
        rate.

        Updates the cumulative risk metrics with today's returns and
        benchmark value, then returns the daily perf packet.
        """
        completed_date = normalize_date(dt)

        self.position_tracker.sync_last_sale_prices(dt)

        pos_stats = self.position_tracker.stats()
        todays_stats = self._period_stats(
            self.todays_performance, pos_stats)
        account = self.get_account(completed_date)

        # update risk metrics for cumulative performance
        benchmark_value = self.all_benchmark_returns[completed_date]

        self.cumulative_risk_metrics.update(
            completed_date,
            todays_stats.returns,
            benchmark_value,
            account)

        daily_packet = self._handle_market_close(completed_date,
                                                 pos_stats,
                                                 todays_stats)

        return daily_packet

    def _handle_market_close(self, completed_date, pos_stats, todays_stats):
        """
        Finish the market day ``completed_date`` and return its daily packet.

        Unless this is the last session, the method also advances the
        session markers and rolls the daily period over. It then applies
        any dividends for the next trading day.
        """
        # increment the day counter before we move markers forward.
        self.day_count += 1.0

        # Get the next trading day and, if it is past the bounds of this
        # simulation, return the daily perf packet
        next_trading_day = self.env.next_trading_day(completed_date)

        # Check if any assets need to be auto-closed before generating today's
        # perf period
        if next_trading_day:
            self.check_asset_auto_closes(next_trading_day=next_trading_day)

        # Take a snapshot of our current performance to return to the
        # browser.
        cumulative_stats = self._period_stats(
            self.cumulative_performance, pos_stats)
        daily_update = self._to_dict(pos_stats,
                                     cumulative_stats,
                                     todays_stats,
                                     emission_type='daily')

        # On the last day of the test, don't create tomorrow's performance
        # period. We may not be able to find the next trading day if we're at
        # the end of our historical data
        if self.market_close >= self.last_close:
            return daily_update

        # move the market day markers forward
        self.market_open, self.market_close = \
            self.env.next_open_and_close(self.day)
        self.day = self.env.next_trading_day(self.day)

        # Roll over positions to current day.
        self.todays_performance.rollover(pos_stats, todays_stats)
        self.todays_performance.period_open = self.market_open
        self.todays_performance.period_close = self.market_close

        # If the next trading day is irrelevant, then return the daily packet
        if (next_trading_day is None) or (next_trading_day >= self.last_close):
            return daily_update

        # Check for any dividends and auto-closes, then return the daily perf
        # packet
        self.check_upcoming_dividends(next_trading_day=next_trading_day)
        return daily_update

    def handle_simulation_end(self):
        """
        When the simulation is complete, run the full period risk report
        and send it out on the results socket.

        Returns the risk report as a dict and also stores it on
        ``self.risk_report``.
        """

        log_msg = "Simulated {n} trading days out of {m}."
        log.info(log_msg.format(n=int(self.day_count), m=self.total_days))
        log.info("first open: {d}".format(
            d=self.sim_params.first_open))
        log.info("last close: {d}".format(
            d=self.sim_params.last_close))

        bms = pd.Series(
            index=self.cumulative_risk_metrics.cont_index,
            data=self.cumulative_risk_metrics.benchmark_returns_cont)
        ars = pd.Series(
            index=self.cumulative_risk_metrics.cont_index,
            data=self.cumulative_risk_metrics.algorithm_returns_cont)
        acl = self.cumulative_risk_metrics.algorithm_cumulative_leverages
        self.risk_report = risk.RiskReport(
            ars,
            self.sim_params,
            benchmark_returns=bms,
            algorithm_leverages=acl,
            env=self.env)

        risk_dict = self.risk_report.to_dict()
        return risk_dict

    def __getstate__(self):
        """
        Pickle only public attributes, tagged with a state version.

        NOTE(review): attributes starting with ``_`` (``_data_portal``,
        ``_adjustment_reader``, ``_account``, ``_perf_periods``) are not
        saved. After unpickling they must be restored by the caller, or
        e.g. ``perf_periods`` will raise AttributeError.
        """
        state_dict = \
            {k: v for k, v in iteritems(self.__dict__)
             if not k.startswith('_')}

        STATE_VERSION = 4
        state_dict[VERSION_LABEL] = STATE_VERSION

        return state_dict

    def __setstate__(self, state):
        """
        Restore pickled state produced by ``__getstate__``.

        Raises
        ------
        ValueError
            If the saved state version is older than the oldest supported
            version.
        """

        OLDEST_SUPPORTED_STATE = 4
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # ValueError (a BaseException subclass, so existing handlers
            # still match) instead of a bare BaseException, which would
            # bypass ordinary ``except Exception`` handling.
            raise ValueError("PerformanceTracker saved state is too old.")

        self.__dict__.update(state)
| 541 |