Total Complexity | 44 |
Total Lines | 339 |
Duplicated Lines | 0 % |
Complex classes like zipline.finance.performance.PerformancePeriod often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # |
||
class PerformancePeriod(object):
    """
    Cash-flow, transaction and order bookkeeping for one performance period.

    Position-level statistics (``pos_stats``) and derived period statistics
    (``period_stats``) are computed elsewhere and passed into the reporting
    methods (``stats``, ``to_dict``, ``as_portfolio`` and ``as_account``).
    """

    def __init__(
            self,
            starting_cash,
            asset_finder,
            period_open=None,
            period_close=None,
            keep_transactions=True,
            keep_orders=False,
            serialize_positions=True,
            name=None):

        self.asset_finder = asset_finder

        self.period_open = period_open
        self.period_close = period_close

        # Net cash flow accumulated during this period: transaction cash
        # flows, commissions and cash payments (e.g. dividends).
        self.period_cash_flow = 0.0

        self.starting_cash = starting_cash
        self.starting_value = 0.0
        self.starting_exposure = 0.0

        self.keep_transactions = keep_transactions
        self.keep_orders = keep_orders

        # txn.dt -> list of transactions recorded at that dt
        self.processed_transactions = {}
        # order.dt -> OrderedDict(order.id -> order) modified at that dt
        self.orders_by_modified = {}
        # order.id -> order, iterated in order of most recent modification
        self.orders_by_id = OrderedDict()

        self.name = name

        # Objects recycled by assigning new values when returning portfolio
        # and account information, so that a new object is not created for
        # each event.
        self._portfolio_store = zp.Portfolio()
        self._account_store = zp.Account()
        self.serialize_positions = serialize_positions

        # Cache of execution cash flow multipliers, keyed on sid.
        self._execution_cash_flow_multipliers = {}

    def rollover(self, pos_stats, prev_period_stats):
        """
        Start a new period from the previous one's ending state.

        The starting value/exposure come from ``pos_stats`` and the starting
        cash from ``prev_period_stats.ending_cash``; the period cash flow and
        recorded transactions/orders are reset.
        """
        self.starting_value = pos_stats.net_value
        self.starting_exposure = pos_stats.net_exposure
        self.starting_cash = prev_period_stats.ending_cash
        self.period_cash_flow = 0.0
        self.processed_transactions = {}
        self.orders_by_modified = {}
        self.orders_by_id = OrderedDict()

    def handle_dividends_paid(self, net_cash_payment):
        """Add a net dividend payment to the cash flow, ignoring zero."""
        if net_cash_payment:
            self.handle_cash_payment(net_cash_payment)

    def handle_cash_payment(self, payment_amount):
        """Add ``payment_amount`` to the period cash flow."""
        self.adjust_cash(payment_amount)

    def handle_commission(self, cost):
        """Deduct a commission ``cost`` from the period cash flow."""
        self.adjust_cash(-cost)

    def adjust_cash(self, amount):
        """Add ``amount`` (possibly negative) to the period cash flow."""
        self.period_cash_flow += amount

    def adjust_field(self, field, value):
        """
        Set an arbitrary attribute on this period.

        ``as_account`` prefers such attributes over its computed defaults.
        """
        setattr(self, field, value)

    def record_order(self, order):
        """
        Record the latest state of ``order`` when ``keep_orders`` is set.

        The order is stored under its modification dt and by id. Existing
        entries are removed before re-insertion so that iteration order
        reflects the most recent modification (an OrderedDict keeps the
        position of the first insertion when a key is merely reassigned).
        """
        if not self.keep_orders:
            return

        dt_orders = self.orders_by_modified.get(order.dt)
        if dt_orders is None:
            dt_orders = self.orders_by_modified[order.dt] = OrderedDict()
        dt_orders.pop(order.id, None)
        dt_orders[order.id] = order

        self.orders_by_id.pop(order.id, None)
        self.orders_by_id[order.id] = order

    def handle_execution(self, txn):
        """
        Apply the cash flow of ``txn`` and, if ``keep_transactions`` is set,
        store it under ``txn.dt``.
        """
        self.period_cash_flow += self._calculate_execution_cash_flow(txn)

        if self.keep_transactions:
            self.processed_transactions.setdefault(txn.dt, []).append(txn)

    def _calculate_execution_cash_flow(self, txn):
        """
        Calculate the cash flow from executing the given transaction.

        Returns ``-price * amount`` for regular assets and zero for futures.
        The per-sid multiplier is cached after the first asset lookup.
        """
        try:
            multiplier = self._execution_cash_flow_multipliers[txn.sid]
        except KeyError:
            asset = self.asset_finder.retrieve_asset(txn.sid)
            # Futures experience no cash flow on transactions.
            multiplier = 0 if isinstance(asset, Future) else 1
            self._execution_cash_flow_multipliers[txn.sid] = multiplier

        return -1 * txn.price * txn.amount * multiplier

    def stats(self, positions, pos_stats, data_portal):
        """
        Compute period statistics via ``calc_period_stats``.

        For each Future position whose last sale date is earlier than
        ``period_open``, the payout
        ``(close at period_close - previous close at period_open)
        * contract_multiplier * amount`` is accumulated and passed along as
        the futures payout. Futures last traded at or after ``period_open``
        are skipped.
        """
        # TODO: passing positions here seems off, since we have already
        # calculated pos_stats.
        futures_payouts = []
        for sid, pos in iteritems(positions):
            asset = self.asset_finder.retrieve_asset(sid)
            if not isinstance(asset, Future):
                continue

            old_price_dt = max(pos.last_sale_date, self.period_open)
            if old_price_dt == pos.last_sale_date:
                continue

            old_price = data_portal.get_previous_value(
                sid, 'close', dt=old_price_dt
            )
            price = data_portal.get_spot_value(
                sid, 'close', dt=self.period_close
            )
            futures_payouts.append(
                (price - old_price) * asset.contract_multiplier * pos.amount
            )

        futures_payout = sum(futures_payouts)

        return calc_period_stats(
            pos_stats,
            self.starting_cash,
            self.starting_value,
            self.period_cash_flow,
            futures_payout
        )

    def __core_dict(self, pos_stats, period_stats):
        """Build the fields shared by every ``to_dict`` result."""
        rval = {
            'ending_value': pos_stats.net_value,
            'ending_exposure': pos_stats.net_exposure,
            # this field is renamed to capital_used for backward
            # compatibility.
            'capital_used': self.period_cash_flow,
            'starting_value': self.starting_value,
            'starting_exposure': self.starting_exposure,
            'starting_cash': self.starting_cash,
            'ending_cash': period_stats.ending_cash,
            'portfolio_value': period_stats.portfolio_value,
            'pnl': period_stats.pnl,
            'returns': period_stats.returns,
            'period_open': self.period_open,
            'period_close': self.period_close,
            'gross_leverage': period_stats.gross_leverage,
            'net_leverage': period_stats.net_leverage,
            'short_exposure': pos_stats.short_exposure,
            'long_exposure': pos_stats.long_exposure,
            'short_value': pos_stats.short_value,
            'long_value': pos_stats.long_value,
            'longs_count': pos_stats.longs_count,
            'shorts_count': pos_stats.shorts_count,
        }

        return rval

    def to_dict(self, pos_stats, period_stats, position_tracker, dt=None):
        """
        Create a dictionary representing the state of this performance
        period.

        The 'positions', 'transactions' and 'orders' keys are present only
        when ``serialize_positions``, ``keep_transactions`` and
        ``keep_orders`` respectively are set (absent, not merely empty).

        Kwargs:
            dt (datetime): If present, only return transactions and orders
                recorded for that dt.
        """
        rval = self.__core_dict(pos_stats, period_stats)

        if self.serialize_positions:
            rval['positions'] = position_tracker.get_positions_list()

        if self.keep_transactions:
            if dt is not None:
                txns = self.processed_transactions.get(dt, [])
            else:
                txns = (
                    txn
                    for dt_txns in itervalues(self.processed_transactions)
                    for txn in dt_txns
                )
            rval['transactions'] = [txn.to_dict() for txn in txns]

        if self.keep_orders:
            if dt is not None:
                # Only orders modified as of the given dt.
                orders = itervalues(self.orders_by_modified.get(dt, {}))
            else:
                orders = itervalues(self.orders_by_id)
            rval['orders'] = [order.to_dict() for order in orders]

        return rval

    def as_portfolio(self, pos_stats, period_stats, position_tracker, dt):
        """
        The purpose of this method is to provide a portfolio
        object to algorithms running inside the same trading
        client. The data needed is captured raw in a
        PerformancePeriod, and in this method we rename some
        fields for usability and remove extraneous fields.
        """
        # Recycles the contained Portfolio object which is used for
        # returning values: as_portfolio is called in an inner loop,
        # so repeated object creation becomes too expensive.
        portfolio = self._portfolio_store
        # maintaining the old name for the portfolio field for
        # backward compatibility
        portfolio.capital_used = self.period_cash_flow
        portfolio.starting_cash = self.starting_cash
        portfolio.portfolio_value = period_stats.portfolio_value
        portfolio.pnl = period_stats.pnl
        portfolio.returns = period_stats.returns
        portfolio.cash = period_stats.ending_cash
        portfolio.start_date = self.period_open
        portfolio.positions = position_tracker.get_positions()
        portfolio.positions_value = pos_stats.net_value
        portfolio.positions_exposure = pos_stats.net_exposure
        return portfolio

    def as_account(self, pos_stats, period_stats):
        """
        Populate and return the recycled Account object for this period.

        For every field except ``net_leverage``, an attribute of the same
        name set on this PerformancePeriod (e.g. via ``adjust_field``, for
        instance with broker-provided values) takes precedence. Otherwise a
        default derived from ``pos_stats`` / ``period_stats`` is used, so
        broker values are never overwritten with defaults.
        """
        account = self._account_store
        ending_cash = period_stats.ending_cash

        account.settled_cash = getattr(self, 'settled_cash', ending_cash)
        account.accrued_interest = getattr(self, 'accrued_interest', 0.0)
        account.buying_power = getattr(self, 'buying_power', float('inf'))
        account.equity_with_loan = \
            getattr(self, 'equity_with_loan', period_stats.portfolio_value)
        account.total_positions_value = \
            getattr(self, 'total_positions_value', pos_stats.net_value)
        # Exposure has its own field; it must not overwrite
        # total_positions_value assigned above.
        account.total_positions_exposure = \
            getattr(self, 'total_positions_exposure', pos_stats.net_exposure)
        account.regt_equity = getattr(self, 'regt_equity', ending_cash)
        account.regt_margin = getattr(self, 'regt_margin', float('inf'))
        account.initial_margin_requirement = \
            getattr(self, 'initial_margin_requirement', 0.0)
        account.maintenance_margin_requirement = \
            getattr(self, 'maintenance_margin_requirement', 0.0)
        account.available_funds = \
            getattr(self, 'available_funds', ending_cash)
        account.excess_liquidity = \
            getattr(self, 'excess_liquidity', ending_cash)
        # The default is computed only when no override exists, so a
        # supplied cushion is not preceded by a division that can fail
        # (e.g. when portfolio_value is zero).
        if hasattr(self, 'cushion'):
            account.cushion = self.cushion
        else:
            account.cushion = ending_cash / period_stats.portfolio_value
        account.day_trades_remaining = \
            getattr(self, 'day_trades_remaining', float('inf'))
        account.leverage = \
            getattr(self, 'leverage', period_stats.gross_leverage)
        account.net_leverage = period_stats.net_leverage

        account.net_liquidation = \
            getattr(self, 'net_liquidation', period_stats.net_liquidation)
        return account

    def __getstate__(self):
        """
        Return picklable state.

        Private attributes are dropped except the recycled portfolio and
        account stores; the bookkeeping containers are stored as plain
        dicts together with a state version.
        """
        state_dict = {k: v for k, v in iteritems(self.__dict__)
                      if not k.startswith('_')}

        state_dict['_portfolio_store'] = self._portfolio_store
        state_dict['_account_store'] = self._account_store

        state_dict['processed_transactions'] = \
            dict(self.processed_transactions)
        state_dict['orders_by_id'] = \
            dict(self.orders_by_id)
        state_dict['orders_by_modified'] = \
            dict(self.orders_by_modified)

        STATE_VERSION = 3
        state_dict[VERSION_LABEL] = STATE_VERSION
        return state_dict

    def __setstate__(self, state):
        """
        Restore state produced by ``__getstate__``.

        Raises:
            ValueError: if the saved state version is older than the oldest
                supported one. (ValueError subclasses BaseException, which
                was raised previously, so existing handlers still match.)
        """
        OLDEST_SUPPORTED_STATE = 3
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            raise ValueError("PerformancePeriod saved state is too old.")

        self.processed_transactions = dict(
            state.pop('processed_transactions')
        )
        self.orders_by_id = OrderedDict(state.pop('orders_by_id'))
        self.orders_by_modified = dict(state.pop('orders_by_modified'))

        # The multiplier cache is not serialized; it is rebuilt lazily.
        self._execution_cash_flow_multipliers = {}

        self.__dict__.update(state)
490 |