Completed
Pull Request — master (#920)
by Eddie
01:18
created

zipline.finance.performance.calc_position_stats()   A

Complexity

Conditions 2

Size

Total Lines 56

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 56
rs 9.7252

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

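Commonly applied refactorings include Extract Method: move a cohesive fragment of the method body into a new, well-named method and call it from the original.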

1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from __future__ import division
17
18
import logbook
19
import numpy as np
20
import pandas as pd
21
from pandas.lib import checknull
22
from collections import namedtuple
23
try:
24
    # optional cython based OrderedDict
25
    from cyordereddict import OrderedDict
26
except ImportError:
27
    from collections import OrderedDict
28
from six import iteritems, itervalues
29
30
from zipline.protocol import Event, DATASOURCE_TYPE
31
from zipline.finance.transaction import Transaction
32
from zipline.utils.serialization_utils import (
33
    VERSION_LABEL
34
)
35
36
import zipline.protocol as zp
37
from zipline.assets import (
38
    Equity, Future
39
)
40
from zipline.errors import PositionTrackerMissingAssetFinder
41
from . position import positiondict
42
43
# Module-level logbook logger; records are emitted under the channel name
# 'Performance'.
log = logbook.Logger('Performance')
44
45
46
# Immutable record of the aggregate statistics produced by
# calc_position_stats().  "*_value" fields are built from value multipliers,
# "*_exposure" fields from exposure multipliers; the *_count fields count
# positions with strictly positive / strictly negative exposure.
PositionStats = namedtuple(
    'PositionStats',
    'net_exposure gross_value gross_exposure '
    'short_value short_exposure shorts_count '
    'long_value long_exposure longs_count '
    'net_value',
)
57
58
59
def calc_position_values(amounts,
                         last_sale_prices,
                         value_multipliers):
    """
    Compute the value of each position as ``price * amount * multiplier``.

    Parameters
    ----------
    amounts : iterable
        Signed position sizes.
    last_sale_prices : iterable
        Last sale price of each position, in the same order as ``amounts``.
    value_multipliers : mapping
        Its values are consumed in iteration order and paired positionally
        with ``amounts`` and ``last_sale_prices``; the caller is responsible
        for keeping the three aligned.

    Returns
    -------
    list
        One value per aligned triple.  Pairing stops at the shortest input.
    """
    position_values = []
    for amount, price, multiplier in zip(amounts,
                                         last_sale_prices,
                                         value_multipliers.values()):
        position_values.append(price * amount * multiplier)
    return position_values
71
72
73
def calc_net(values):
    """
    Return the signed total of ``values`` as a ``numpy.float64``.

    The running total starts at ``np.float64(0.0)``, so an empty input yields
    ``0.0`` (a float64) rather than the integer ``0``.
    """
    net_total = np.float64()
    for entry in values:
        net_total = net_total + entry
    return net_total
76
77
78
def calc_position_exposures(amounts,
                            last_sale_prices,
                            exposure_multipliers):
    """
    Compute the exposure of each position as ``price * amount * multiplier``.

    ``exposure_multipliers`` is a mapping whose values are paired, in
    iteration order, with ``amounts`` and ``last_sale_prices``; the caller
    must keep them aligned.  Pairing stops at the shortest input.
    """
    return [
        price * amount * multiplier
        for amount, price, multiplier
        in zip(amounts, last_sale_prices, exposure_multipliers.values())
    ]
90
91
92
def calc_long_value(position_values):
    """Sum the strictly positive position values (``0`` when there are none)."""
    positive_values = [value for value in position_values if value > 0]
    return sum(positive_values)
94
95
96
def calc_short_value(position_values):
    """Sum the strictly negative position values (``0`` when there are none)."""
    return sum(filter(lambda value: value < 0, position_values))
98
99
100
def calc_long_exposure(position_exposures):
    """Sum the strictly positive exposures (``0`` when there are none)."""
    longs = [exposure for exposure in position_exposures if 0 < exposure]
    return sum(longs)
102
103
104
def calc_short_exposure(position_exposures):
    """Sum the strictly negative exposures (``0`` when there are none)."""
    shorts = filter(lambda exposure: exposure < 0, position_exposures)
    return sum(shorts)
106
107
108
def calc_longs_count(position_exposures):
    """Count the positions whose exposure is strictly positive."""
    return len([exposure for exposure in position_exposures if exposure > 0])
110
111
112
def calc_shorts_count(position_exposures):
    """Count the positions whose exposure is strictly negative."""
    short_positions = 0
    for exposure in position_exposures:
        if exposure < 0:
            short_positions += 1
    return short_positions
114
115
116
def calc_gross_exposure(long_exposure, short_exposure):
    """Return long exposure plus the magnitude of short exposure."""
    short_magnitude = abs(short_exposure)
    return long_exposure + short_magnitude
118
119
120
def calc_gross_value(long_value, short_value):
    """Return long value plus the magnitude of short value."""
    absolute_short = abs(short_value)
    return long_value + absolute_short
122
123
124
def _aligned_position_inputs(positions,
                             position_value_multipliers,
                             position_exposure_multipliers):
    """
    Collect the per-position inputs of calc_position_stats, aligned by sid.

    Returns ``(amounts, last_sale_prices, value_multipliers,
    exposure_multipliers)``.  The two lists and the two OrderedDicts all
    follow the iteration order of ``positions``, so entry ``i`` of each one
    refers to the same sid.  Positions whose sid has no registered
    multipliers are skipped.
    """
    amounts = []
    last_sale_prices = []
    value_multipliers = OrderedDict()
    exposure_multipliers = OrderedDict()
    for sid, pos in iteritems(positions):
        try:
            value_multiplier = position_value_multipliers[sid]
            exposure_multiplier = position_exposure_multipliers[sid]
        except KeyError:
            # No multipliers were registered for this sid, so it cannot be
            # valued.  Leave it out rather than pairing it with some other
            # asset's multiplier.
            continue
        amounts.append(pos.amount)
        last_sale_prices.append(pos.last_sale_price)
        value_multipliers[sid] = value_multiplier
        exposure_multipliers[sid] = exposure_multiplier
    return amounts, last_sale_prices, value_multipliers, exposure_multipliers


def calc_position_stats(positions,
                        position_value_multipliers,
                        position_exposure_multipliers):
    """
    Compute aggregate value and exposure statistics over ``positions``.

    Parameters
    ----------
    positions : dict
        sid => position object exposing ``amount`` and ``last_sale_price``.
    position_value_multipliers : dict
        sid => multiplier turning ``price * amount`` into a position value.
    position_exposure_multipliers : dict
        sid => multiplier turning ``price * amount`` into a position
        exposure.

    Returns
    -------
    PositionStats
        ``long_*`` / ``short_*`` sum the strictly positive / negative
        per-position values or exposures; ``gross_*`` is long plus the
        magnitude of short; ``net_*`` is the float64 sum of all entries;
        ``longs_count`` / ``shorts_count`` count positions with positive /
        negative exposure.

    Notes
    -----
    Multipliers are matched to positions by sid.  The previous version
    paired them by iteration order of two different dicts, which could
    attach one asset's multiplier to another asset.  Positions without
    registered multipliers are excluded from every statistic.
    """
    (amounts,
     last_sale_prices,
     value_multipliers,
     exposure_multipliers) = _aligned_position_inputs(
        positions,
        position_value_multipliers,
        position_exposure_multipliers,
    )

    position_values = calc_position_values(
        amounts,
        last_sale_prices,
        value_multipliers,
    )
    position_exposures = calc_position_exposures(
        amounts,
        last_sale_prices,
        exposure_multipliers,
    )

    long_value = calc_long_value(position_values)
    short_value = calc_short_value(position_values)
    long_exposure = calc_long_exposure(position_exposures)
    short_exposure = calc_short_exposure(position_exposures)

    return PositionStats(
        long_value=long_value,
        gross_value=calc_gross_value(long_value, short_value),
        short_value=short_value,
        long_exposure=long_exposure,
        short_exposure=short_exposure,
        gross_exposure=calc_gross_exposure(long_exposure, short_exposure),
        net_exposure=calc_net(position_exposures),
        longs_count=calc_longs_count(position_exposures),
        shorts_count=calc_shorts_count(position_exposures),
        net_value=calc_net(position_values),
    )
181
182
183
class PositionTracker(object):
    """
    Keep track of open positions and of the per-asset multipliers needed to
    value them.

    Per-instance state:

    * ``positions`` -- sid => position object.
    * value / exposure / payout multipliers per sid, registered lazily by
      ``_update_asset``.
    * ``_unpaid_dividends`` -- dividends earned on their ex date (see
      ``earn_dividends``) and not yet paid (see ``pay_dividends``).
    * ``_auto_close_position_sids`` -- date => set of sids to auto-close.
    """

    def __init__(self, asset_finder):
        self.asset_finder = asset_finder

        # sid => position object
        self.positions = positiondict()
        # Per-sid multipliers for quick calculation of position values,
        # exposures and cash payouts; filled in lazily by _update_asset().
        self._position_value_multipliers = OrderedDict()
        self._position_exposure_multipliers = OrderedDict()
        self._position_payout_multipliers = OrderedDict()
        # Dividends that have been earned but not yet paid.
        self._unpaid_dividends = pd.DataFrame(
            columns=zp.DIVIDEND_PAYMENT_FIELDS,
        )
        # Container reused and refreshed by get_positions().
        self._positions_store = zp.Positions()

        # Dict, keyed on dates, whose values are sets of sids (of Assets in
        # this tracker's positions) that must be auto-closed by that date.
        self._auto_close_position_sids = {}

    def _update_asset(self, sid):
        """
        Make sure value/exposure/payout multipliers are registered for sid.

        Does nothing if all three multipliers already exist.  Otherwise the
        asset is fetched from ``self.asset_finder`` and multipliers are
        recorded by asset type:

        * Equity: value 1, exposure 1, payout 0.
        * Future: value 0, exposure and payout equal to the contract
          multiplier; its ``auto_close_date`` is scheduled for auto-close.

        Assets of any other type get no multipliers.

        Raises
        ------
        PositionTrackerMissingAssetFinder
            If multipliers are missing and no asset finder is configured.
        """
        try:
            self._position_value_multipliers[sid]
            self._position_exposure_multipliers[sid]
            self._position_payout_multipliers[sid]
        except KeyError:
            # Check if there is an AssetFinder
            if self.asset_finder is None:
                raise PositionTrackerMissingAssetFinder()

            # Collect the value multipliers from applicable sids
            asset = self.asset_finder.retrieve_asset(sid)
            if isinstance(asset, Equity):
                self._position_value_multipliers[sid] = 1
                self._position_exposure_multipliers[sid] = 1
                self._position_payout_multipliers[sid] = 0
            if isinstance(asset, Future):
                # A value multiplier of 0 keeps futures out of position
                # value; their price moves are instead turned into cash by
                # update_last_sale() via the payout multiplier.
                self._position_value_multipliers[sid] = 0
                self._position_exposure_multipliers[sid] = \
                    asset.contract_multiplier
                self._position_payout_multipliers[sid] = \
                    asset.contract_multiplier
                # Futures auto-close timing is controlled by the Future's
                # auto_close_date property
                self._insert_auto_close_position_date(
                    dt=asset.auto_close_date,
                    sid=sid
                )

    def _insert_auto_close_position_date(self, dt, sid):
        """
        Inserts the given SID in to the list of positions to be auto-closed by
        the given dt.

        Parameters
        ----------
        dt : pandas.Timestamp
            The date before-which the given SID will be auto-closed.  When
            None, nothing is scheduled.
        sid : int
            The SID of the Asset to be auto-closed
        """
        if dt is not None:
            self._auto_close_position_sids.setdefault(dt, set()).add(sid)

    def auto_close_position_events(self, next_trading_day):
        """
        Generates CLOSE_POSITION events for any SIDs whose auto-close date is
        before or equal to the given date.

        Parameters
        ----------
        next_trading_day : pandas.Timestamp
            The time before-which certain Assets need to be closed

        Yields
        ------
        Event
            A close position event for any sids that should be closed before
            the next_trading_day parameter

        Notes
        -----
        The yielded dates are removed from the auto-close schedule only once
        the generator has been fully consumed.
        """
        past_asset_end_dates = set()

        # Check the auto_close_position_dates dict for SIDs to close
        for date, sids in self._auto_close_position_sids.items():
            if date > next_trading_day:
                continue
            past_asset_end_dates.add(date)

            for sid in sids:
                # Yield a CLOSE_POSITION event
                event = Event({
                    'dt': date,
                    'type': DATASOURCE_TYPE.CLOSE_POSITION,
                    'sid': sid,
                })
                yield event

        # Clear out past dates
        while past_asset_end_dates:
            self._auto_close_position_sids.pop(past_asset_end_dates.pop())

    def update_last_sale(self, event):
        """
        Record a trade event's price on the position it refers to.

        Returns the cash adjustment produced by the price move:
        ``(new_price - old_price) * payout_multiplier * amount``.  Returns 0
        when the event's sid is not held or its price is null.
        """
        # NOTE, PerformanceTracker already vetted as TRADE type
        sid = event.sid
        if sid not in self.positions:
            return 0

        price = event.price

        if checknull(price):
            return 0

        pos = self.positions[sid]
        old_price = pos.last_sale_price
        pos.last_sale_date = event.dt
        pos.last_sale_price = price

        # Calculate cash adjustment on assets with multipliers
        return ((price - old_price) * self._position_payout_multipliers[sid]
                * pos.amount)

    def update_positions(self, positions):
        """
        Merge ``positions`` (sid => position) into the tracker and register
        multipliers for every sid in it.
        """
        # update positions in batch
        self.positions.update(positions)
        for sid in positions:
            self._update_asset(sid)

    def update_position(self, sid, amount=None, last_sale_price=None,
                        last_sale_date=None, cost_basis=None):
        """
        Overwrite selected fields of the position for ``sid``.

        Only arguments that are not None are applied.  Multipliers are
        (re)registered only when ``amount`` is given.
        """
        pos = self.positions[sid]

        if amount is not None:
            pos.amount = amount
            self._update_asset(sid=sid)
        if last_sale_price is not None:
            pos.last_sale_price = last_sale_price
        if last_sale_date is not None:
            pos.last_sale_date = last_sale_date
        if cost_basis is not None:
            pos.cost_basis = cost_basis

    def execute_transaction(self, txn):
        """Apply ``txn`` to its position and register the asset's multipliers."""
        # Update Position
        # ----------------
        sid = txn.sid
        position = self.positions[sid]
        position.update(txn)
        self._update_asset(sid)

    def handle_commission(self, sid, cost):
        """Fold a commission ``cost`` into the cost basis of a held sid."""
        # Adjust the cost basis of the stock if we own it
        if sid in self.positions:
            self.positions[sid].adjust_commission_cost_basis(sid, cost)

    def handle_split(self, split):
        """
        Apply ``split`` to the matching position, if held.

        Returns the leftover cash reported by the position, or None when the
        split's sid is not held.
        """
        if split.sid in self.positions:
            # Make the position object handle the split. It returns the
            # leftover cash from a fractional share, if there is any.
            position = self.positions[split.sid]
            leftover_cash = position.handle_split(split.sid, split.ratio)
            self._update_asset(split.sid)
            return leftover_cash

    def _maybe_earn_dividend(self, dividend):
        """
        Take a historical dividend record and return a Series with fields in
        zipline.protocol.DIVIDEND_FIELDS (plus an 'id' field) representing
        the cash/stock amount we are owed when the dividend is paid.
        """
        if dividend['sid'] in self.positions:
            return self.positions[dividend['sid']].earn_dividend(dividend)
        else:
            return zp.dividend_payment()

    def earn_dividends(self, dividend_frame):
        """
        Given a frame of dividends whose ex_dates are all the next trading day,
        calculate and store the cash and/or stock payments to be paid on each
        dividend's pay date.
        """
        earned = dividend_frame.apply(self._maybe_earn_dividend, axis=1)\
                               .dropna(how='all')
        if len(earned) > 0:
            # Store the earned dividends so that they can be paid on the
            # dividends' pay_dates.
            self._unpaid_dividends = pd.concat(
                [self._unpaid_dividends, earned],
            )

    def _maybe_pay_dividend(self, dividend):
        """
        Take a historical dividend record, look up any stored record of
        cash/stock we are owed for that dividend, and return a Series
        with fields drawn from zipline.protocol.DIVIDEND_PAYMENT_FIELDS.
        """
        try:
            unpaid_dividend = self._unpaid_dividends.loc[dividend['id']]
            return unpaid_dividend
        except KeyError:
            return zp.dividend_payment()

    def pay_dividends(self, dividend_frame):
        """
        Given a frame of dividends whose pay_dates are all the next trading
        day, grant the cash and/or stock payments that were calculated on the
        given dividends' ex dates.

        Returns the net cash paid (NaN cash amounts count as 0).
        """
        payments = dividend_frame.apply(self._maybe_pay_dividend, axis=1)\
                                 .dropna(how='all')

        # Mark these dividends as paid by dropping them from our unpaid
        # table.  DataFrame.drop returns a new frame rather than modifying
        # in place, so the result must be stored back; otherwise paid
        # dividends would stay in the unpaid table forever.
        self._unpaid_dividends = self._unpaid_dividends.drop(payments.index)

        # Add stock for any stock dividends paid.  Again, the values here may
        # be negative in the case of short positions.
        stock_payments = payments[payments['payment_sid'].notnull()]
        for _, row in stock_payments.iterrows():
            stock = row['payment_sid']
            share_count = row['share_count']
            # note we create a Position for stock dividend if we don't
            # already own the asset
            position = self.positions[stock]

            position.amount += share_count
            self._update_asset(stock)

        # Add cash equal to the net cash payed from all dividends.  Note that
        # "negative cash" is effectively paid if we're short an asset,
        # representing the fact that we're required to reimburse the owner of
        # the stock for any dividends paid while borrowing.
        net_cash_payment = payments['cash_amount'].fillna(0).sum()
        return net_cash_payment

    def maybe_create_close_position_transaction(self, event):
        """
        Build a Transaction that fully closes the position in ``event.sid``.

        The event's price is used when it carries one, otherwise the
        position's last sale price.  Returns None when the sid is not held
        or the position is already flat.
        """
        try:
            pos = self.positions[event.sid]
            amount = pos.amount
            if amount == 0:
                return None
        except KeyError:
            return None
        if 'price' in event:
            price = event.price
        else:
            price = pos.last_sale_price
        txn = Transaction(
            sid=event.sid,
            amount=(-1 * pos.amount),
            dt=event.dt,
            price=price,
            commission=0,
            order_id=0
        )
        return txn

    def get_positions(self):
        """
        Refresh and return the reusable ``zp.Positions`` snapshot.

        Non-zero positions have amount, cost basis and last sale price
        copied in; entries whose position has become flat are removed.
        """
        positions = self._positions_store

        for sid, pos in iteritems(self.positions):

            if pos.amount == 0:
                # Clear out the position if it has become empty since the last
                # time get_positions was called.  Catching the KeyError is
                # faster than checking `if sid in positions`, and this can be
                # potentially called in a tight inner loop.
                try:
                    del positions[sid]
                except KeyError:
                    pass
                continue

            # Note that this will create a position if we don't currently have
            # an entry
            position = positions[sid]
            position.amount = pos.amount
            position.cost_basis = pos.cost_basis
            position.last_sale_price = pos.last_sale_price
        return positions

    def get_positions_list(self):
        """Return ``to_dict()`` of every position with a non-zero amount."""
        positions = []
        for sid, pos in iteritems(self.positions):
            if pos.amount != 0:
                positions.append(pos.to_dict())
        return positions

    def stats(self):
        """Return the PositionStats for the tracked positions."""
        return calc_position_stats(self.positions,
                                   self._position_value_multipliers,
                                   self._position_exposure_multipliers)

    def __getstate__(self):
        state_dict = {}

        state_dict['asset_finder'] = self.asset_finder
        state_dict['positions'] = dict(self.positions)
        state_dict['unpaid_dividends'] = self._unpaid_dividends
        state_dict['auto_close_position_sids'] = self._auto_close_position_sids

        STATE_VERSION = 3
        state_dict[VERSION_LABEL] = STATE_VERSION
        return state_dict

    def __setstate__(self, state):
        OLDEST_SUPPORTED_STATE = 3
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # NOTE(review): a bare BaseException slips past `except
            # Exception` handlers; a narrower type may be preferable.
            raise BaseException("PositionTracker saved state is too old.")

        self.asset_finder = state['asset_finder']
        self.positions = positiondict()
        # note that positions_store is temporary and gets regened from
        # .positions
        self._positions_store = zp.Positions()

        self._unpaid_dividends = state['unpaid_dividends']
        self._auto_close_position_sids = state['auto_close_position_sids']

        # Arrays for quick calculations of positions value
        self._position_value_multipliers = OrderedDict()
        self._position_exposure_multipliers = OrderedDict()
        self._position_payout_multipliers = OrderedDict()

        # Rebuild positions and their multipliers via update_positions(),
        # which needs self.asset_finder to look up any asset.
        self.update_positions(state['positions'])
511