| Total Complexity | 44 |
| Total Lines | 342 |
| Duplicated Lines | 0 % |
Complex classes like zipline.finance.performance.PerformancePeriod often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | # |
||
class PerformancePeriod(object):
    """
    Accumulates cash flow, transactions and orders for a single
    performance period, and renders that state as a plain dict, a
    portfolio object or an account object.
    """

    def __init__(
            self,
            starting_cash,
            asset_finder,
            data_frequency,
            period_open=None,
            period_close=None,
            keep_transactions=True,
            keep_orders=False,
            serialize_positions=True,
            name=None):
        """
        Args:
            starting_cash: Cash balance at the start of the period.
            asset_finder: Object whose ``retrieve_asset(sid)`` is used to
                tell futures apart from other assets.
            data_frequency: Forwarded to the data portal price lookups
                made in ``stats``.
            period_open: Start of the period; also reported as the
                portfolio ``start_date``.
            period_close: End of the period; used as the pricing time
                in ``stats``.
            keep_transactions: If true, executed transactions are stored
                by dt and emitted by ``to_dict``.
            keep_orders: If true, orders are stored and emitted by
                ``to_dict``.
            serialize_positions: If true, ``to_dict`` includes the
                position list.
            name: Label stored on the instance.
        """
        self.asset_finder = asset_finder
        self.data_frequency = data_frequency

        self.period_open = period_open
        self.period_close = period_close

        self.period_cash_flow = 0.0

        self.starting_cash = starting_cash
        self.starting_value = 0.0
        self.starting_exposure = 0.0

        self.keep_transactions = keep_transactions
        self.keep_orders = keep_orders

        # dt -> list of transactions executed at that dt
        self.processed_transactions = {}
        # dt -> OrderedDict of order id -> order, as last modified at dt
        self.orders_by_modified = {}
        # order id -> order, most recently recorded last
        self.orders_by_id = OrderedDict()

        self.name = name

        # Objects to recycle via assigning new values when returning
        # portfolio/account information, so as to avoid creating a new
        # object for each event.
        self._portfolio_store = zp.Portfolio()
        self._account_store = zp.Account()
        self.serialize_positions = serialize_positions

        # This dict contains the known cash flow multipliers for sids and is
        # keyed on sid
        self._execution_cash_flow_multipliers = {}

    def rollover(self, pos_stats, prev_period_stats):
        """
        Start a new period: carry the previous ending values forward as
        the new starting values and clear the per-period accumulators.
        """
        self.starting_value = pos_stats.net_value
        self.starting_exposure = pos_stats.net_exposure
        self.starting_cash = prev_period_stats.ending_cash
        self.period_cash_flow = 0.0
        self.processed_transactions = {}
        self.orders_by_modified = {}
        self.orders_by_id = OrderedDict()

    def handle_dividends_paid(self, net_cash_payment):
        """Apply a dividend cash payment; falsy amounts are ignored."""
        if net_cash_payment:
            self.handle_cash_payment(net_cash_payment)

    def handle_cash_payment(self, payment_amount):
        """Add ``payment_amount`` to the period cash flow."""
        self.adjust_cash(payment_amount)

    def handle_commission(self, cost):
        """Deduct a commission ``cost`` from the period cash flow."""
        # Deduct from our total cash pool.
        self.adjust_cash(-cost)

    def adjust_cash(self, amount):
        """Add ``amount`` (may be negative) to the period cash flow."""
        self.period_cash_flow += amount

    def adjust_field(self, field, value):
        """Set the attribute named ``field`` on this period to ``value``."""
        setattr(self, field, value)

    def record_order(self, order):
        """
        Store ``order`` under both its modification dt and its id.

        Does nothing unless ``keep_orders`` is set. Re-recording an order
        moves it to the end of both orderings.
        """
        if not self.keep_orders:
            return

        try:
            dt_orders = self.orders_by_modified[order.dt]
        except KeyError:
            self.orders_by_modified[order.dt] = dt_orders = OrderedDict()
        else:
            # Drop any earlier entry so the re-insert below lands last.
            dt_orders.pop(order.id, None)
        dt_orders[order.id] = order

        # to preserve the order of the orders by modified date
        # we delete and add back. (ordered dictionary is sorted by
        # first insertion date).
        if order.id in self.orders_by_id:
            del self.orders_by_id[order.id]
        self.orders_by_id[order.id] = order

    def handle_execution(self, txn):
        """
        Apply the cash flow of an executed transaction and, when
        ``keep_transactions`` is set, store the transaction under its dt.
        """
        self.period_cash_flow += self._calculate_execution_cash_flow(txn)

        if self.keep_transactions:
            try:
                self.processed_transactions[txn.dt].append(txn)
            except KeyError:
                self.processed_transactions[txn.dt] = [txn]

    def _calculate_execution_cash_flow(self, txn):
        """
        Calculates the cash flow from executing the given transaction
        """
        # Check if the multiplier is cached. If it is not, look up the asset
        # and cache the multiplier.
        try:
            multiplier = self._execution_cash_flow_multipliers[txn.sid]
        except KeyError:
            asset = self.asset_finder.retrieve_asset(txn.sid)
            # Futures experience no cash flow on transactions
            if isinstance(asset, Future):
                multiplier = 0
            else:
                multiplier = 1
            self._execution_cash_flow_multipliers[txn.sid] = multiplier

        # Calculate and return the cash flow given the multiplier
        return -1 * txn.price * txn.amount * multiplier

    def stats(self, positions, pos_stats, data_portal):
        """
        Compute the period statistics via ``calc_period_stats``.

        The payout of every futures position is summed and passed along
        with the starting cash, starting value and period cash flow.

        Args:
            positions: Mapping of sid -> position.
            pos_stats: Position statistics, passed through unchanged.
            data_portal: Source of the previous and spot 'close' values.
        """
        # TODO: passing positions here seems off, since we have already
        # calculated pos_stats.
        futures_payouts = []
        for sid, pos in iteritems(positions):
            asset = self.asset_finder.retrieve_asset(sid)
            if not isinstance(asset, Future):
                continue

            old_price_dt = max(pos.last_sale_date, self.period_open)

            # Skip positions whose last sale is not earlier than the
            # period open.
            if old_price_dt == pos.last_sale_date:
                continue

            old_price = data_portal.get_previous_value(
                sid, 'close', old_price_dt, self.data_frequency
            )

            price = data_portal.get_spot_value(
                sid, 'close', self.period_close, self.data_frequency,
            )

            payout = (
                (price - old_price)
                *
                asset.contract_multiplier
                *
                pos.amount
            )
            futures_payouts.append(payout)

        futures_payout = sum(futures_payouts)

        return calc_period_stats(
            pos_stats,
            self.starting_cash,
            self.starting_value,
            self.period_cash_flow,
            futures_payout
        )

    def __core_dict(self, pos_stats, period_stats):
        """Build the dict of scalar fields shared by every ``to_dict``."""
        rval = {
            'ending_value': pos_stats.net_value,
            'ending_exposure': pos_stats.net_exposure,
            # this field is renamed to capital_used for backward
            # compatibility.
            'capital_used': self.period_cash_flow,
            'starting_value': self.starting_value,
            'starting_exposure': self.starting_exposure,
            'starting_cash': self.starting_cash,
            'ending_cash': period_stats.ending_cash,
            'portfolio_value': period_stats.portfolio_value,
            'pnl': period_stats.pnl,
            'returns': period_stats.returns,
            'period_open': self.period_open,
            'period_close': self.period_close,
            'gross_leverage': period_stats.gross_leverage,
            'net_leverage': period_stats.net_leverage,
            'short_exposure': pos_stats.short_exposure,
            'long_exposure': pos_stats.long_exposure,
            'short_value': pos_stats.short_value,
            'long_value': pos_stats.long_value,
            'longs_count': pos_stats.longs_count,
            'shorts_count': pos_stats.shorts_count,
        }

        return rval

    def to_dict(self, pos_stats, period_stats, position_tracker, dt=None):
        """
        Creates a dictionary representing the state of this performance
        period.

        The 'positions', 'transactions' and 'orders' keys are only present
        when ``serialize_positions``, ``keep_transactions`` and
        ``keep_orders`` respectively are set.

        Kwargs:
            dt (datetime): If present, only return transactions and orders
                recorded for the dt.
        """
        rval = self.__core_dict(pos_stats, period_stats)

        if self.serialize_positions:
            positions = position_tracker.get_positions_list()
            rval['positions'] = positions

        # we want the key to be absent, not just empty
        if self.keep_transactions:
            if dt:
                # Only include transactions for given dt
                try:
                    transactions = [x.to_dict()
                                    for x in self.processed_transactions[dt]]
                except KeyError:
                    transactions = []
            else:
                transactions = \
                    [y.to_dict()
                     for x in itervalues(self.processed_transactions)
                     for y in x]
            rval['transactions'] = transactions

        if self.keep_orders:
            if dt:
                # only include orders modified as of the given dt.
                try:
                    orders = [x.to_dict()
                              for x in itervalues(self.orders_by_modified[dt])]
                except KeyError:
                    orders = []
            else:
                orders = [x.to_dict() for x in itervalues(self.orders_by_id)]
            rval['orders'] = orders

        return rval

    def as_portfolio(self, pos_stats, period_stats, position_tracker, dt):
        """
        The purpose of this method is to provide a portfolio
        object to algorithms running inside the same trading
        client. The data needed is captured raw in a
        PerformancePeriod, and in this method we rename some
        fields for usability and remove extraneous fields.

        The same stored portfolio object is updated and returned on every
        call. ``dt`` is not used by this method.
        """
        # Recycles containing objects' Portfolio object
        # which is used for returning values.
        # as_portfolio is called in an inner loop,
        # so repeated object creation becomes too expensive
        portfolio = self._portfolio_store
        # maintaining the old name for the portfolio field for
        # backward compatibility
        portfolio.capital_used = self.period_cash_flow
        portfolio.starting_cash = self.starting_cash
        portfolio.portfolio_value = period_stats.portfolio_value
        portfolio.pnl = period_stats.pnl
        portfolio.returns = period_stats.returns
        portfolio.cash = period_stats.ending_cash
        portfolio.start_date = self.period_open
        portfolio.positions = position_tracker.get_positions()
        portfolio.positions_value = pos_stats.net_value
        portfolio.positions_exposure = pos_stats.net_exposure
        return portfolio

    def as_account(self, pos_stats, period_stats):
        """
        Populate and return the stored account object.

        The same stored account object is updated and returned on every
        call.
        """
        account = self._account_store

        # If no attribute is found on the PerformancePeriod resort to the
        # following default values. If an attribute is found use the existing
        # value. For instance, a broker may provide updates to these
        # attributes. In this case we do not want to over write the broker
        # values with the default values.
        account.settled_cash = \
            getattr(self, 'settled_cash', period_stats.ending_cash)
        account.accrued_interest = \
            getattr(self, 'accrued_interest', 0.0)
        account.buying_power = \
            getattr(self, 'buying_power', float('inf'))
        account.equity_with_loan = \
            getattr(self, 'equity_with_loan', period_stats.portfolio_value)
        account.total_positions_value = \
            getattr(self, 'total_positions_value', pos_stats.net_value)
        # Exposure gets its own account field; it was previously written
        # over total_positions_value.
        account.total_positions_exposure = \
            getattr(self, 'total_positions_exposure', pos_stats.net_exposure)
        account.regt_equity = \
            getattr(self, 'regt_equity', period_stats.ending_cash)
        account.regt_margin = \
            getattr(self, 'regt_margin', float('inf'))
        account.initial_margin_requirement = \
            getattr(self, 'initial_margin_requirement', 0.0)
        account.maintenance_margin_requirement = \
            getattr(self, 'maintenance_margin_requirement', 0.0)
        account.available_funds = \
            getattr(self, 'available_funds', period_stats.ending_cash)
        account.excess_liquidity = \
            getattr(self, 'excess_liquidity', period_stats.ending_cash)
        # NOTE(review): the default is evaluated even when a 'cushion'
        # attribute exists, so a zero portfolio_value can fail here.
        account.cushion = \
            getattr(self, 'cushion',
                    period_stats.ending_cash / period_stats.portfolio_value)
        account.day_trades_remaining = \
            getattr(self, 'day_trades_remaining', float('inf'))
        account.leverage = getattr(self, 'leverage',
                                   period_stats.gross_leverage)
        account.net_leverage = period_stats.net_leverage

        account.net_liquidation = getattr(self, 'net_liquidation',
                                          period_stats.net_liquidation)
        return account

    def __getstate__(self):
        """
        Return the picklable state: every public attribute, the recycled
        portfolio/account objects, plain-dict copies of the transaction
        and order containers, and a version label.
        """
        state_dict = {k: v for k, v in iteritems(self.__dict__)
                      if not k.startswith('_')}

        state_dict['_portfolio_store'] = self._portfolio_store
        state_dict['_account_store'] = self._account_store
        state_dict['data_frequency'] = self.data_frequency

        state_dict['processed_transactions'] = \
            dict(self.processed_transactions)
        state_dict['orders_by_id'] = \
            dict(self.orders_by_id)
        state_dict['orders_by_modified'] = \
            dict(self.orders_by_modified)

        STATE_VERSION = 3
        state_dict[VERSION_LABEL] = STATE_VERSION
        return state_dict

    def __setstate__(self, state):
        """
        Restore the instance from ``state`` as produced by
        ``__getstate__``.

        Raises:
            BaseException: If the saved state version is older than the
                oldest supported version.
        """
        OLDEST_SUPPORTED_STATE = 3
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # NOTE(review): BaseException is kept so callers see the same
            # exception type as before; a narrower type would be preferable.
            raise BaseException("PerformancePeriod saved state is too old.")

        processed_transactions = {}
        processed_transactions.update(state.pop('processed_transactions'))

        orders_by_id = OrderedDict()
        orders_by_id.update(state.pop('orders_by_id'))

        orders_by_modified = {}
        orders_by_modified.update(state.pop('orders_by_modified'))
        self.processed_transactions = processed_transactions
        self.orders_by_id = orders_by_id
        self.orders_by_modified = orders_by_modified

        # The multiplier cache is not pickled; start with an empty one.
        self._execution_cash_flow_multipliers = {}

        self.__dict__.update(state)
||
| 493 |