Total Complexity | 44 |
Total Lines | 342 |
Duplicated Lines | 0 % |
Complex classes like zipline.finance.performance.PerformancePeriod often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # |
||
class PerformancePeriod(object):
    """
    Cash-flow, transaction and order bookkeeping for one performance period
    bounded by ``period_open`` and ``period_close``.

    Position-level figures are not stored here; callers pass them in
    (``pos_stats``, ``period_stats``, ``position_tracker``) when building
    stats, dicts, portfolios and accounts.
    """

    def __init__(
            self,
            starting_cash,
            asset_finder,
            data_frequency,
            period_open=None,
            period_close=None,
            keep_transactions=True,
            keep_orders=False,
            serialize_positions=True,
            name=None):
        """
        Parameters
        ----------
        starting_cash
            Cash on hand at the start of the period.
        asset_finder
            Object exposing ``retrieve_asset(sid)``; used to tell futures
            apart from other assets.
        data_frequency
            Passed through to the data portal when pricing futures.
        period_open, period_close
            Bounds of the period.
        keep_transactions : bool
            Record executed transactions so ``to_dict`` can report them.
        keep_orders : bool
            Record orders so ``to_dict`` can report them.
        serialize_positions : bool
            Include the positions list in ``to_dict`` output.
        name
            Optional label for this period.
        """
        self.asset_finder = asset_finder
        self.data_frequency = data_frequency

        self.period_open = period_open
        self.period_close = period_close

        # Net cash flow accumulated during the period (executions,
        # commissions, dividends and other cash payments).
        self.period_cash_flow = 0.0

        self.starting_cash = starting_cash
        self.starting_value = 0.0
        self.starting_exposure = 0.0

        self.keep_transactions = keep_transactions
        self.keep_orders = keep_orders

        # dt -> list of transactions executed at that dt.
        self.processed_transactions = {}
        # dt -> OrderedDict(order id -> order) of orders modified at that dt.
        self.orders_by_modified = {}
        # order id -> order, iterated in order of most recent modification.
        self.orders_by_id = OrderedDict()

        self.name = name

        # Objects recycled by assigning new values when returning portfolio
        # and account information, so as to avoid creating a new object for
        # each event.
        self._portfolio_store = zp.Portfolio()
        self._account_store = zp.Account()
        self.serialize_positions = serialize_positions

        # This dict contains the known cash flow multipliers for sids and is
        # keyed on sid. It is a cache and is not serialized.
        self._execution_cash_flow_multipliers = {}

    def rollover(self, pos_stats, prev_period_stats):
        """
        Reset this period so it starts where the previous one ended.

        The starting value/exposure are taken from ``pos_stats``, the
        starting cash from ``prev_period_stats.ending_cash``; the period cash
        flow and recorded transactions/orders are cleared.
        """
        self.starting_value = pos_stats.net_value
        self.starting_exposure = pos_stats.net_exposure
        self.starting_cash = prev_period_stats.ending_cash
        self.period_cash_flow = 0.0
        self.processed_transactions = {}
        self.orders_by_modified = {}
        self.orders_by_id = OrderedDict()

    def handle_dividends_paid(self, net_cash_payment):
        """Add a dividend payment to the period cash flow, if non-zero."""
        if net_cash_payment:
            self.handle_cash_payment(net_cash_payment)

    def handle_cash_payment(self, payment_amount):
        """Add ``payment_amount`` to the period cash flow."""
        self.adjust_cash(payment_amount)

    def handle_commission(self, cost):
        """Deduct a commission ``cost`` from the period cash flow."""
        # Deduct from our total cash pool.
        self.adjust_cash(-cost)

    def adjust_cash(self, amount):
        """Add ``amount`` (may be negative) to the period cash flow."""
        self.period_cash_flow += amount

    def adjust_field(self, field, value):
        """Set attribute ``field`` of this period to ``value``."""
        setattr(self, field, value)

    def record_order(self, order):
        """
        Record ``order`` (keyed by ``order.id``) under its modification dt.

        No-op unless ``keep_orders`` is set. Re-recording an order moves it
        to the end of its dt bucket and of ``orders_by_id``, so iteration
        order reflects the most recent modification.
        """
        if not self.keep_orders:
            return

        dt_orders = self.orders_by_modified.get(order.dt)
        if dt_orders is None:
            dt_orders = self.orders_by_modified[order.dt] = OrderedDict()
        else:
            # An OrderedDict keeps an entry's first-insertion position, so
            # delete and re-add to move the order to the end.
            dt_orders.pop(order.id, None)
        dt_orders[order.id] = order

        # Same delete-and-re-add to preserve the order of the orders by
        # modified date.
        self.orders_by_id.pop(order.id, None)
        self.orders_by_id[order.id] = order

    def handle_execution(self, txn):
        """
        Apply the cash flow of an executed transaction and, if
        ``keep_transactions`` is set, record it under ``txn.dt``.
        """
        self.period_cash_flow += self._calculate_execution_cash_flow(txn)

        if self.keep_transactions:
            self.processed_transactions.setdefault(txn.dt, []).append(txn)

    def _calculate_execution_cash_flow(self, txn):
        """
        Calculates the cash flow from executing the given transaction:
        ``-price * amount`` for ordinary assets and 0 for futures.
        """
        # Check if the multiplier is cached. If it is not, look up the asset
        # and cache the multiplier.
        try:
            multiplier = self._execution_cash_flow_multipliers[txn.sid]
        except KeyError:
            asset = self.asset_finder.retrieve_asset(txn.sid)
            # Futures experience no cash flow on transactions
            if isinstance(asset, Future):
                multiplier = 0
            else:
                multiplier = 1
            self._execution_cash_flow_multipliers[txn.sid] = multiplier

        # Calculate and return the cash flow given the multiplier
        return -1 * txn.price * txn.amount * multiplier

    def stats(self, positions, pos_stats, data_portal):
        """
        Compute the period statistics via ``calc_period_stats``.

        For every futures position a payout of
        ``(close at period_close - previous close at old_price_dt)
        * contract_multiplier * amount`` is accumulated, where
        ``old_price_dt = max(pos.last_sale_date, period_open)``. Positions
        whose ``last_sale_date`` is at or after ``period_open`` contribute
        no payout. The summed payout is passed as ``futures_payout``.
        """
        # TODO: passing positions here seems off, since we have already
        # calculated pos_stats.
        futures_payouts = []
        for sid, pos in iteritems(positions):
            asset = self.asset_finder.retrieve_asset(sid)
            if isinstance(asset, Future):
                old_price_dt = max(pos.last_sale_date, self.period_open)

                if old_price_dt == pos.last_sale_date:
                    continue

                old_price = data_portal.get_previous_value(
                    sid, 'close', old_price_dt, self.data_frequency
                )

                price = data_portal.get_spot_value(
                    sid, 'close', self.period_close, self.data_frequency,
                )

                payout = (
                    (price - old_price)
                    *
                    asset.contract_multiplier
                    *
                    pos.amount
                )
                futures_payouts.append(payout)

        futures_payout = sum(futures_payouts)

        return calc_period_stats(
            pos_stats,
            self.starting_cash,
            self.starting_value,
            self.period_cash_flow,
            futures_payout
        )

    def __core_dict(self, pos_stats, period_stats):
        """Return the period/position fields always present in ``to_dict``."""
        rval = {
            'ending_value': pos_stats.net_value,
            'ending_exposure': pos_stats.net_exposure,
            # this field is renamed to capital_used for backward
            # compatibility.
            'capital_used': self.period_cash_flow,
            'starting_value': self.starting_value,
            'starting_exposure': self.starting_exposure,
            'starting_cash': self.starting_cash,
            'ending_cash': period_stats.ending_cash,
            'portfolio_value': period_stats.portfolio_value,
            'pnl': period_stats.pnl,
            'returns': period_stats.returns,
            'period_open': self.period_open,
            'period_close': self.period_close,
            'gross_leverage': period_stats.gross_leverage,
            'net_leverage': period_stats.net_leverage,
            'short_exposure': pos_stats.short_exposure,
            'long_exposure': pos_stats.long_exposure,
            'short_value': pos_stats.short_value,
            'long_value': pos_stats.long_value,
            'longs_count': pos_stats.longs_count,
            'shorts_count': pos_stats.shorts_count,
        }

        return rval

    def to_dict(self, pos_stats, period_stats, position_tracker, dt=None):
        """
        Creates a dictionary representing the state of this performance
        period.

        The core period/position fields are always present. ``positions``
        is added when ``serialize_positions`` is set, ``transactions`` when
        ``keep_transactions`` is set and ``orders`` when ``keep_orders`` is
        set; otherwise those keys are absent.

        Kwargs:
            dt (datetime): If present, only return transactions and orders
                for the dt.
        """
        rval = self.__core_dict(pos_stats, period_stats)

        if self.serialize_positions:
            positions = position_tracker.get_positions_list()
            rval['positions'] = positions

        # we want the key to be absent, not just empty
        if self.keep_transactions:
            if dt:
                # Only include transactions for given dt
                transactions = [
                    txn.to_dict()
                    for txn in self.processed_transactions.get(dt, ())
                ]
            else:
                transactions = [
                    txn.to_dict()
                    for dt_txns in itervalues(self.processed_transactions)
                    for txn in dt_txns
                ]
            rval['transactions'] = transactions

        if self.keep_orders:
            if dt:
                # only include orders modified as of the given dt.
                orders = [
                    order.to_dict()
                    for order in itervalues(
                        self.orders_by_modified.get(dt, {})
                    )
                ]
            else:
                orders = [
                    order.to_dict()
                    for order in itervalues(self.orders_by_id)
                ]
            rval['orders'] = orders

        return rval

    def as_portfolio(self, pos_stats, period_stats, position_tracker, dt):
        """
        The purpose of this method is to provide a portfolio
        object to algorithms running inside the same trading
        client. The data needed is captured raw in a
        PerformancePeriod, and in this method we rename some
        fields for usability and remove extraneous fields.
        """
        # Recycles containing objects' Portfolio object
        # which is used for returning values.
        # as_portfolio is called in an inner loop,
        # so repeated object creation becomes too expensive
        portfolio = self._portfolio_store
        # maintaining the old name for the portfolio field for
        # backward compatibility
        portfolio.capital_used = self.period_cash_flow
        portfolio.starting_cash = self.starting_cash
        portfolio.portfolio_value = period_stats.portfolio_value
        portfolio.pnl = period_stats.pnl
        portfolio.returns = period_stats.returns
        portfolio.cash = period_stats.ending_cash
        portfolio.start_date = self.period_open
        portfolio.positions = position_tracker.get_positions()
        portfolio.positions_value = pos_stats.net_value
        portfolio.positions_exposure = pos_stats.net_exposure
        return portfolio

    def as_account(self, pos_stats, period_stats):
        """
        Populate and return the recycled Account object.

        Each field takes the value of the same-named attribute on this
        period when one exists, and a default derived from ``pos_stats`` /
        ``period_stats`` otherwise (``net_leverage`` always comes from
        ``period_stats``).
        """
        account = self._account_store

        # If no attribute is found on the PerformancePeriod resort to the
        # following default values. If an attribute is found use the existing
        # value. For instance, a broker may provide updates to these
        # attributes. In this case we do not want to over write the broker
        # values with the default values.
        account.settled_cash = \
            getattr(self, 'settled_cash', period_stats.ending_cash)
        account.accrued_interest = \
            getattr(self, 'accrued_interest', 0.0)
        account.buying_power = \
            getattr(self, 'buying_power', float('inf'))
        account.equity_with_loan = \
            getattr(self, 'equity_with_loan', period_stats.portfolio_value)
        account.total_positions_value = \
            getattr(self, 'total_positions_value', pos_stats.net_value)
        # Previously this default was written to total_positions_value,
        # clobbering the value set just above.
        account.total_positions_exposure = \
            getattr(self, 'total_positions_exposure', pos_stats.net_exposure)
        account.regt_equity = \
            getattr(self, 'regt_equity', period_stats.ending_cash)
        account.regt_margin = \
            getattr(self, 'regt_margin', float('inf'))
        account.initial_margin_requirement = \
            getattr(self, 'initial_margin_requirement', 0.0)
        account.maintenance_margin_requirement = \
            getattr(self, 'maintenance_margin_requirement', 0.0)
        account.available_funds = \
            getattr(self, 'available_funds', period_stats.ending_cash)
        account.excess_liquidity = \
            getattr(self, 'excess_liquidity', period_stats.ending_cash)
        # The default is only computed when no override exists, so an
        # override is not preceded by a division whose denominator
        # (portfolio_value) may be zero.
        try:
            account.cushion = self.cushion
        except AttributeError:
            account.cushion = \
                period_stats.ending_cash / period_stats.portfolio_value
        account.day_trades_remaining = \
            getattr(self, 'day_trades_remaining', float('inf'))
        account.leverage = getattr(self, 'leverage',
                                   period_stats.gross_leverage)
        account.net_leverage = period_stats.net_leverage

        account.net_liquidation = getattr(self, 'net_liquidation',
                                          period_stats.net_liquidation)
        return account

    def __getstate__(self):
        """
        Return the picklable state: all public attributes plus the recycled
        portfolio/account stores, tagged with the state version.
        """
        # Private attributes (e.g. the multiplier cache) are dropped, except
        # for the stores explicitly re-added below.
        state_dict = {k: v for k, v in iteritems(self.__dict__)
                      if not k.startswith('_')}

        state_dict['_portfolio_store'] = self._portfolio_store
        state_dict['_account_store'] = self._account_store
        state_dict['data_frequency'] = self.data_frequency

        state_dict['processed_transactions'] = \
            dict(self.processed_transactions)
        # Copy as an OrderedDict: a plain dict does not preserve the
        # modification order on every supported Python version.
        state_dict['orders_by_id'] = \
            OrderedDict(self.orders_by_id)
        state_dict['orders_by_modified'] = \
            dict(self.orders_by_modified)

        STATE_VERSION = 3
        state_dict[VERSION_LABEL] = STATE_VERSION
        return state_dict

    def __setstate__(self, state):
        """
        Restore from a state produced by ``__getstate__``.

        ``state`` is mutated (consumed keys are popped).

        Raises
        ------
        BaseException
            If the saved state version is older than the oldest supported.
        """
        OLDEST_SUPPORTED_STATE = 3
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            raise BaseException("PerformancePeriod saved state is too old.")

        self.processed_transactions = dict(
            state.pop('processed_transactions')
        )
        # Accepts both a plain dict (older states) and an OrderedDict.
        self.orders_by_id = OrderedDict(state.pop('orders_by_id'))
        self.orders_by_modified = dict(state.pop('orders_by_modified'))

        # The multiplier cache is not serialized; it is rebuilt lazily.
        self._execution_cash_flow_multipliers = {}

        self.__dict__.update(state)