Total Complexity | 60 |
Total Lines | 493 |
Duplicated Lines | 0 % |
Complex classes like zipline.finance.performance.PerformanceTracker often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # |
||
class PerformanceTracker(object):
    """
    Tracks the performance of the algorithm.

    The tracker owns a single PositionTracker that is shared by two
    PerformancePeriod objects:

    - ``cumulative_performance`` spans the whole simulation.
    - ``todays_performance`` spans the current market day and is rolled
      over at each market close.

    Trade, transaction, order, commission, split and dividend events are
    forwarded to both periods, and ``to_dict`` snapshots the state as a
    perf packet.
    """

    def __init__(self, sim_params, env):
        """
        Build a tracker for the simulation described by ``sim_params``.

        Parameters
        ----------
        sim_params : object
            Simulation parameters. The attributes read here are
            period_start, period_end, first_open, last_close,
            days_in_period, capital_base and emission_rate.
        env : object
            Trading environment. The attributes read here are exchange_tz,
            trading_days, asset_finder and get_open_and_close.

        Notes
        -----
        ``all_benchmark_returns`` and ``cumulative_risk_metrics`` are only
        created when emission_rate is 'daily' or 'minute'.
        """
        self.sim_params = sim_params
        self.env = env

        self.period_start = self.sim_params.period_start
        self.period_end = self.sim_params.period_end
        self.last_close = self.sim_params.last_close
        first_open = self.sim_params.first_open.tz_convert(
            self.env.exchange_tz
        )
        # The calendar date of the first open in exchange time, re-labelled
        # as a UTC midnight timestamp.
        self.day = pd.Timestamp(datetime(first_open.year, first_open.month,
                                         first_open.day), tz='UTC')
        self.market_open, self.market_close = env.get_open_and_close(self.day)
        self.total_days = self.sim_params.days_in_period
        self.capital_base = self.sim_params.capital_base
        self.emission_rate = sim_params.emission_rate

        # Restrict the environment's trading days to the simulation window.
        all_trading_days = env.trading_days
        mask = ((all_trading_days >= normalize_date(self.period_start)) &
                (all_trading_days <= normalize_date(self.period_end)))

        self.trading_days = all_trading_days[mask]

        # Starts with no rows and no columns; update_dividends fills it.
        self.dividend_frame = pd.DataFrame()
        self._dividend_count = 0

        self.position_tracker = PositionTracker(asset_finder=env.asset_finder)

        if self.emission_rate == 'daily':
            self.all_benchmark_returns = pd.Series(
                index=self.trading_days)
            self.cumulative_risk_metrics = \
                risk.RiskMetricsCumulative(self.sim_params, self.env)

        elif self.emission_rate == 'minute':
            self.all_benchmark_returns = pd.Series(index=pd.date_range(
                self.sim_params.first_open, self.sim_params.last_close,
                freq='Min'))

            self.cumulative_risk_metrics = \
                risk.RiskMetricsCumulative(self.sim_params, self.env,
                                           create_first_day_stats=True)

        # this performance period will span the entire simulation from
        # inception.
        self.cumulative_performance = PerformancePeriod(
            # initial cash is your capital base.
            starting_cash=self.capital_base,
            # the cumulative period will be calculated over the entire test.
            period_open=self.period_start,
            period_close=self.period_end,
            # don't save the transactions for the cumulative
            # period
            keep_transactions=False,
            keep_orders=False,
            # don't serialize positions for cumulative period
            serialize_positions=False,
            asset_finder=self.env.asset_finder,
        )
        self.cumulative_performance.position_tracker = self.position_tracker

        # this performance period will span just the current market day
        self.todays_performance = PerformancePeriod(
            # initial cash is your capital base.
            starting_cash=self.capital_base,
            # the daily period will be calculated for the market day
            period_open=self.market_open,
            period_close=self.market_close,
            keep_transactions=True,
            keep_orders=True,
            serialize_positions=True,
            asset_finder=self.env.asset_finder,
        )
        self.todays_performance.position_tracker = self.position_tracker

        self.saved_dt = self.period_start
        # Number of completed market days. It starts at zero and is
        # incremented at each market close, so `progress` reaches 100% once
        # total_days closes have been handled.
        self.day_count = 0.0
        self.txn_count = 0

        # Cached account object; rebuilt lazily by get_account.
        self.account_needs_update = True
        self._account = None

    def __repr__(self):
        return "%s(%r)" % (
            self.__class__.__name__,
            {'simulation parameters': self.sim_params})

    @property
    def progress(self):
        """
        Fraction of the simulation completed.

        Returns 1.0 (a placeholder) for minute emission and
        ``day_count / total_days`` for daily emission. Any other emission
        rate falls through and yields None.
        """
        if self.emission_rate == 'minute':
            # Fake a value
            return 1.0
        elif self.emission_rate == 'daily':
            return self.day_count / self.total_days

    def set_date(self, date):
        """
        Record ``date`` as the current simulation time.

        Only has an effect for minute emission, where it also moves the
        close of today's performance period to ``date``.
        """
        if self.emission_rate == 'minute':
            self.saved_dt = date
            self.todays_performance.period_close = self.saved_dt

    def update_dividends(self, new_dividends):
        """
        Update our dividend frame with new dividends.  @new_dividends should be
        a DataFrame with columns containing at least the entries in
        zipline.protocol.DIVIDEND_FIELDS.

        Note that an 'id' column is written into @new_dividends in place.
        """

        # Mark each new dividend with a unique integer id.  This ensures that
        # we can differentiate dividends whose date/sid fields are otherwise
        # identical.
        new_dividends['id'] = np.arange(
            self._dividend_count,
            self._dividend_count + len(new_dividends),
        )
        self._dividend_count += len(new_dividends)

        # NOTE(review): DataFrame.sort is the legacy pandas spelling; it is
        # kept because the rest of this file targets the same pandas API.
        self.dividend_frame = pd.concat(
            [self.dividend_frame, new_dividends]
        ).sort(['pay_date', 'ex_date']).set_index('id', drop=False)

    def initialize_dividends_from_other(self, other):
        """
        Helper for copying dividends to a new PerformanceTracker while
        preserving dividend count.  Useful if a simulation needs to create a
        new PerformanceTracker mid-stream and wants to preserve stored dividend
        info.

        The frame is shared by reference with ``other``, not copied.

        Note that this does not copy unpaid dividends.
        """
        self.dividend_frame = other.dividend_frame
        self._dividend_count = other._dividend_count

    def handle_sid_removed_from_universe(self, sid):
        """
        This method handles any behaviors that must occur when a SID leaves the
        universe of the TradingAlgorithm.

        Parameters
        ----------
        sid : int
            The sid of the Asset being removed from the universe.
        """
        # The frame starts life as a column-less pd.DataFrame(); until
        # dividends have been loaded there is no 'sid' column to filter on
        # and attribute access would raise AttributeError.
        if 'sid' not in self.dividend_frame.columns:
            return

        # Drop any dividends for the sid from the dividends frame
        self.dividend_frame = self.dividend_frame[
            self.dividend_frame.sid != sid
        ]

    def update_performance(self):
        """Recalculate both performance periods."""
        # calculate performance as of last trade
        self.cumulative_performance.calculate_performance()
        self.todays_performance.calculate_performance()

    def get_portfolio(self, performance_needs_update):
        """
        Return the cumulative period as a portfolio.

        When ``performance_needs_update`` is true, performance is
        recalculated first and the cached account is marked stale.
        """
        if performance_needs_update:
            self.update_performance()
            self.account_needs_update = True
        return self.cumulative_performance.as_portfolio()

    def get_account(self, performance_needs_update):
        """
        Return the cached account, rebuilding it if it is stale.

        When ``performance_needs_update`` is true, performance is
        recalculated first, which also forces the account to be rebuilt.
        """
        if performance_needs_update:
            self.update_performance()
            self.account_needs_update = True
        if self.account_needs_update:
            self._update_account()
        return self._account

    def _update_account(self):
        """Rebuild the cached account from the cumulative period."""
        self._account = self.cumulative_performance.as_account()
        self.account_needs_update = False

    def to_dict(self, emission_type=None):
        """
        Creates a dictionary representing the state of this tracker.
        Returns a dict object of the form described in header comments.

        Parameters
        ----------
        emission_type : str, optional
            'daily' or 'minute'. Defaults to this tracker's emission rate.

        Raises
        ------
        ValueError
            If the emission type is neither 'daily' nor 'minute'.
        """

        # Default to the emission rate of this tracker if no type is provided
        if emission_type is None:
            emission_type = self.emission_rate

        _dict = {
            'period_start': self.period_start,
            'period_end': self.period_end,
            'capital_base': self.capital_base,
            'cumulative_perf': self.cumulative_performance.to_dict(),
            'progress': self.progress,
            'cumulative_risk_metrics': self.cumulative_risk_metrics.to_dict()
        }
        if emission_type == 'daily':
            _dict['daily_perf'] = self.todays_performance.to_dict()
        elif emission_type == 'minute':
            _dict['minute_perf'] = self.todays_performance.to_dict(
                self.saved_dt)
        else:
            raise ValueError("Invalid emission type: %s" % emission_type)

        return _dict

    def _handle_event_price(self, event):
        """Pass the event to the position tracker as a last-sale update."""
        self.position_tracker.update_last_sale(event)

    def process_trade(self, event):
        """Handle a trade event by updating the last sale price."""
        self._handle_event_price(event)

    def process_transaction(self, event):
        """
        Handle a transaction: update the last sale, count it, and apply it
        to both performance periods and to the position tracker.
        """
        self._handle_event_price(event)
        self.txn_count += 1
        self.cumulative_performance.handle_execution(event)
        self.todays_performance.handle_execution(event)
        self.position_tracker.execute_transaction(event)

    def process_dividend(self, dividend):
        """Dividend events are not processed here; only a log line is emitted."""

        log.info("Ignoring DIVIDEND event.")

    def process_split(self, event):
        """
        Apply a split to the positions and credit any positive leftover
        cash to both performance periods.
        """
        leftover_cash = self.position_tracker.handle_split(event)
        if leftover_cash > 0:
            self.cumulative_performance.handle_cash_payment(leftover_cash)
            self.todays_performance.handle_cash_payment(leftover_cash)

    def process_order(self, event):
        """Record an order on both performance periods."""
        self.cumulative_performance.record_order(event)
        self.todays_performance.record_order(event)

    def process_commission(self, commission):
        """
        Apply a commission's cost to the position tracker (keyed by the
        commission's sid) and to both performance periods.
        """
        sid = commission.sid
        cost = commission.cost

        self.position_tracker.handle_commission(sid, cost)
        self.cumulative_performance.handle_commission(cost)
        self.todays_performance.handle_commission(cost)

    def process_benchmark(self, event):
        """
        Store the benchmark return carried by ``event``.

        Raises
        ------
        AssertionError
            If the (possibly normalized) event timestamp is not present in
            the index of ``all_benchmark_returns``.
        """
        if self.sim_params.data_frequency == 'minute' and \
                self.sim_params.emission_rate == 'daily':
            # Minute data benchmarks should have a timestamp of market
            # close, so that calculations are triggered at the right time.
            # However, risk module uses midnight as the 'day'
            # marker for returns, so adjust back to midnight.
            midnight = pd.tseries.tools.normalize_date(event.dt)
        else:
            midnight = event.dt

        if midnight not in self.all_benchmark_returns.index:
            raise AssertionError(
                ("Date %s not allocated in all_benchmark_returns. "
                 "Calendar seems to mismatch with benchmark. "
                 "Benchmark container is=%s" %
                 (midnight,
                  self.all_benchmark_returns.index)))

        self.all_benchmark_returns[midnight] = event.returns

    def process_close_position(self, event):
        """
        Handle a CLOSE_POSITION event, processing the closing transaction
        if the position tracker creates one.
        """

        # CLOSE_POSITION events that contain prices that must be handled as
        # a final trade event
        if 'price' in event:
            self.process_trade(event)

        txn = self.position_tracker.\
            maybe_create_close_position_transaction(event)
        if txn:
            self.process_transaction(txn)

    def check_upcoming_dividends(self, next_trading_day):
        """
        Check if we currently own any stocks with dividends whose ex_date is
        the next trading day.  Track how much we should be paid on those
        dividends' pay dates.

        Then check if we are owed cash/stock for any dividends whose pay date
        is the next trading day.  Apply all such benefits to both performance
        periods.
        """
        if len(self.dividend_frame) == 0:
            # We don't currently know about any dividends for this simulation
            # period, so bail.
            return

        # Dividends whose ex_date is the next trading day.  We need to check if
        # we own any of these stocks so we know to pay them out when the pay
        # date comes.
        ex_date_mask = (self.dividend_frame['ex_date'] == next_trading_day)
        dividends_earnable = self.dividend_frame[ex_date_mask]

        # Dividends whose pay date is the next trading day.  If we held any of
        # these stocks on midnight before the ex_date, we need to pay these out
        # now.
        pay_date_mask = (self.dividend_frame['pay_date'] == next_trading_day)
        dividends_payable = self.dividend_frame[pay_date_mask]

        position_tracker = self.position_tracker
        if len(dividends_earnable):
            position_tracker.earn_dividends(dividends_earnable)

        if not len(dividends_payable):
            return

        net_cash_payment = position_tracker.pay_dividends(dividends_payable)

        self.cumulative_performance.handle_dividends_paid(net_cash_payment)
        self.todays_performance.handle_dividends_paid(net_cash_payment)

    def check_asset_auto_closes(self, next_trading_day):
        """
        Check if the position tracker currently owns any Assets with an
        auto-close date that is the next trading day. Close those positions.

        Parameters
        ----------
        next_trading_day : pandas.Timestamp
            The next trading day of the simulation
        """
        auto_close_events = self.position_tracker.auto_close_position_events(
            next_trading_day=next_trading_day
        )
        for event in auto_close_events:
            self.process_close_position(event)

    def handle_minute_close(self, dt):
        """
        Handles the close of the given minute. This includes handling
        market-close functions if the given minute is the end of the market
        day.

        Parameters
        ----------
        dt : Timestamp
            The minute that is ending

        Returns
        -------
        (dict, dict/None)
            A tuple of the minute perf packet and daily perf packet.
            If the market day has not ended, the daily perf packet is None.
        """
        self.update_performance()
        todays_date = normalize_date(dt)
        account = self.get_account(False)

        bench_returns = self.all_benchmark_returns.loc[todays_date:dt]
        # cumulative returns
        bench_since_open = (1. + bench_returns).prod() - 1

        self.cumulative_risk_metrics.update(todays_date,
                                            self.todays_performance.returns,
                                            bench_since_open,
                                            account.leverage)

        minute_packet = self.to_dict(emission_type='minute')

        # if this is the close, update dividends for the next day.
        # Return the performance tuple
        if dt == self.market_close:
            return (minute_packet, self._handle_market_close(todays_date))
        else:
            return (minute_packet, None)

    def handle_market_close_daily(self):
        """
        Function called after handle_data when running with daily emission
        rate.

        Returns the daily perf packet produced by the market-close handling.
        """
        self.update_performance()
        completed_date = self.day
        account = self.get_account(False)

        # update risk metrics for cumulative performance
        self.cumulative_risk_metrics.update(
            completed_date,
            self.todays_performance.returns,
            self.all_benchmark_returns[completed_date],
            account.leverage)

        return self._handle_market_close(completed_date)

    def _handle_market_close(self, completed_date):
        """
        Close out ``completed_date`` and return the daily perf packet.

        Unless this was the final market close of the simulation, the
        market day markers are also advanced and today's performance period
        is rolled over.
        """

        # increment the day counter before we move markers forward.
        self.day_count += 1.0

        # Look up the trading day that follows the one just completed; it
        # drives the auto-close and dividend checks below.
        next_trading_day = self.env.next_trading_day(completed_date)

        # Check if any assets need to be auto-closed before generating today's
        # perf period
        if next_trading_day:
            self.check_asset_auto_closes(next_trading_day=next_trading_day)

        # Take a snapshot of our current performance to return to the
        # browser.
        daily_update = self.to_dict(emission_type='daily')

        # On the last day of the test, don't create tomorrow's performance
        # period.  We may not be able to find the next trading day if we're at
        # the end of our historical data
        if self.market_close >= self.last_close:
            return daily_update

        # move the market day markers forward
        self.market_open, self.market_close = \
            self.env.next_open_and_close(self.day)
        self.day = self.env.next_trading_day(self.day)

        # Roll over positions to current day.
        self.todays_performance.rollover()
        self.todays_performance.period_open = self.market_open
        self.todays_performance.period_close = self.market_close

        # If the next trading day is irrelevant, then return the daily packet
        if (next_trading_day is None) or (next_trading_day >= self.last_close):
            return daily_update

        # Check for any upcoming dividends, then return the daily perf
        # packet. (Auto-closes were already handled above, before the
        # snapshot was taken.)
        self.check_upcoming_dividends(next_trading_day=next_trading_day)
        return daily_update

    def handle_simulation_end(self):
        """
        When the simulation is complete, run the full period risk report
        and return it as a dict. The report object is also kept on
        ``self.risk_report``.
        """

        log_msg = "Simulated {n} trading days out of {m}."
        log.info(log_msg.format(n=int(self.day_count), m=self.total_days))
        log.info("first open: {d}".format(
            d=self.sim_params.first_open))
        log.info("last close: {d}".format(
            d=self.sim_params.last_close))

        bms = pd.Series(
            index=self.cumulative_risk_metrics.cont_index,
            data=self.cumulative_risk_metrics.benchmark_returns_cont)
        ars = pd.Series(
            index=self.cumulative_risk_metrics.cont_index,
            data=self.cumulative_risk_metrics.algorithm_returns_cont)
        acl = self.cumulative_risk_metrics.algorithm_cumulative_leverages
        self.risk_report = risk.RiskReport(
            ars,
            self.sim_params,
            benchmark_returns=bms,
            algorithm_leverages=acl,
            env=self.env)

        risk_dict = self.risk_report.to_dict()
        return risk_dict

    def __getstate__(self):
        """
        Build the picklable state: every attribute whose name does not
        start with an underscore, plus the dividend counter, with the
        dividend frame stored as a pickled byte string and a version stamp.
        """
        state_dict = \
            {k: v for k, v in iteritems(self.__dict__)
             if not k.startswith('_')}

        state_dict['dividend_frame'] = pickle.dumps(self.dividend_frame)

        # Private, but required to keep dividend ids unique after a restore.
        state_dict['_dividend_count'] = self._dividend_count

        STATE_VERSION = 4
        state_dict[VERSION_LABEL] = STATE_VERSION

        return state_dict

    def __setstate__(self, state):
        """
        Restore a tracker from a dict produced by ``__getstate__``.

        The version entry is popped from ``state`` in place. A
        BaseException is raised if the saved version is older than the
        oldest supported one.
        """

        OLDEST_SUPPORTED_STATE = 4
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # NOTE(review): BaseException is not caught by
            # `except Exception` handlers; the type is kept unchanged for
            # backward compatibility with existing callers.
            raise BaseException("PerformanceTracker saved state is too old.")

        self.__dict__.update(state)

        # Handle the dividend frame specially
        # NOTE: pickle.loads must only ever be given trusted saved state.
        self.dividend_frame = pickle.loads(state['dividend_frame'])

        # __getstate__ drops underscore-prefixed attributes, so the cached
        # account is never serialized. Recreate the attribute and force a
        # refresh; otherwise get_account(False) would raise AttributeError
        # whenever the saved account_needs_update flag was False.
        self._account = None
        self.account_needs_update = True

        # properly setup the perf periods
        p_types = ['cumulative', 'todays']
        for p_type in p_types:
            name = p_type + '_performance'
            period = getattr(self, name, None)
            if period is None:
                continue
            # Assigns the private attribute directly rather than the public
            # position_tracker attribute used in __init__.
            period._position_tracker = self.position_tracker
||
574 |