Completed
Pull Request — master (#858)
by Eddie
01:43
created

zipline.finance.performance.PositionTracker   C

Complexity

Total Complexity 54

Size/Duplication

Total Lines 342
Duplicated Lines 0 %
Metric Value
dl 0
loc 342
rs 6.8539
wmc 54

18 Methods

Rating   Name   Duplication   Size   Complexity  
A __init__() 0 22 1
A update_positions() 0 5 2
A _insert_auto_close_position_date() 0 14 2
A handle_commission() 0 5 2
B get_positions() 0 24 4
B _update_asset() 0 27 5
B pay_dividends() 0 42 6
B update_position() 0 17 6
A maybe_create_close_position_transaction() 0 16 2
A execute_transaction() 0 13 2
A sync_last_sale_prices() 0 5 2
A get_positions_list() 0 6 3
B earn_dividends() 0 25 5
A __getstate__() 0 12 1
A handle_splits() 0 22 3
B auto_close_position_events() 0 36 5
A stats() 0 4 1
B __setstate__() 0 27 2

How to fix   Complexity   

Complex Class

Complex classes like zipline.finance.performance.PositionTracker often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields and methods that share the same prefix or suffix.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from __future__ import division
17
18
import logbook
19
import numpy as np
20
from collections import namedtuple
21
from zipline.finance.performance.position import Position
22
from zipline.finance.transaction import Transaction
23
24
try:
25
    # optional cython based OrderedDict
26
    from cyordereddict import OrderedDict
27
except ImportError:
28
    from collections import OrderedDict
29
from six import iteritems, itervalues
30
31
from zipline.protocol import Event, DATASOURCE_TYPE
32
from zipline.utils.serialization_utils import (
33
    VERSION_LABEL
34
)
35
36
import zipline.protocol as zp
37
from zipline.assets import (
38
    Equity, Future
39
)
40
from zipline.errors import PositionTrackerMissingAssetFinder
41
from . position import positiondict
42
43
# Module-level logbook logger, created with the name 'Performance'.
log = logbook.Logger('Performance')
44
45
46
# Immutable record of the aggregate figures returned by calc_position_stats.
# The field order is part of the tuple's positional interface.
PositionStats = namedtuple(
    'PositionStats',
    (
        'net_exposure',
        'gross_value',
        'gross_exposure',
        'short_value',
        'short_exposure',
        'shorts_count',
        'long_value',
        'long_exposure',
        'longs_count',
        'net_value',
    ),
)
57
58
59
def calc_position_values(amounts,
                         last_sale_prices,
                         value_multipliers):
    """
    Compute one value per position as amount * last sale price * multiplier.

    Parameters
    ----------
    amounts : iterable
        Position amounts.
    last_sale_prices : iterable
        Last sale prices, consumed in lockstep with ``amounts``.
    value_multipliers : mapping
        Only its values are used, in the mapping's iteration order; they
        are paired positionally with ``amounts`` and ``last_sale_prices``.

    Returns
    -------
    list
        The products, truncated to the shortest of the three inputs
        (``zip`` semantics).
    """
    values = []
    triples = zip(amounts, last_sale_prices, itervalues(value_multipliers))
    for amount, price, multiplier in triples:
        values.append(amount * price * multiplier)
    return values
71
72
73
def calc_net(values):
    """
    Add up ``values`` starting from a numpy float64 zero.

    The float64 start value means an empty input yields 0.0 rather than
    the integer 0.
    """
    total = np.float64()
    for value in values:
        total = total + value
    return total
76
77
78
def calc_position_exposures(amounts, last_sale_prices, exposure_multipliers):
    """
    Compute one exposure per position as
    amount * last sale price * multiplier.

    Parameters
    ----------
    amounts : iterable
        Position amounts.
    last_sale_prices : iterable
        Last sale prices, consumed in lockstep with ``amounts``.
    exposure_multipliers : mapping
        Only its values are used, in the mapping's iteration order; they
        are paired positionally with the other two inputs.

    Returns
    -------
    list
        The products, truncated to the shortest input (``zip`` semantics).
    """
    multipliers = itervalues(exposure_multipliers)
    rows = zip(amounts, last_sale_prices, multipliers)
    return [amount * price * multiplier for amount, price, multiplier in rows]
90
91
92
def calc_long_value(position_values):
    """Return the sum of the strictly positive entries (0 if none)."""
    long_values = [value for value in position_values if value > 0]
    return sum(long_values)
94
95
96
def calc_short_value(position_values):
    """Return the sum of the strictly negative entries (0 if none)."""
    total = 0
    for value in position_values:
        if value < 0:
            total = total + value
    return total
98
99
100
def calc_long_exposure(position_exposures):
    """Return the sum of the strictly positive exposures (0 if none)."""
    return sum(filter(lambda exposure: exposure > 0, position_exposures))
102
103
104
def calc_short_exposure(position_exposures):
    """Return the sum of the strictly negative exposures (0 if none)."""
    short_exposures = [
        exposure for exposure in position_exposures if exposure < 0
    ]
    return sum(short_exposures)
106
107
108
def calc_longs_count(position_exposures):
    """Return how many exposures are strictly positive."""
    return len([exposure for exposure in position_exposures if exposure > 0])
110
111
112
def calc_shorts_count(position_exposures):
    """Return how many exposures are strictly negative."""
    count = 0
    for exposure in position_exposures:
        if exposure < 0:
            count += 1
    return count
114
115
116
def calc_gross_exposure(long_exposure, short_exposure):
    """Return the long exposure plus the magnitude of the short exposure."""
    short_magnitude = abs(short_exposure)
    return long_exposure + short_magnitude
118
119
120
def calc_gross_value(long_value, short_value):
    """Return the long value plus the magnitude of the short value."""
    short_magnitude = abs(short_value)
    return long_value + short_magnitude
122
123
124
def calc_position_stats(positions,
                        position_value_multipliers,
                        position_exposure_multipliers):
    """
    Aggregate the tracked positions into a PositionStats record.

    Parameters
    ----------
    positions : mapping
        sid => position object exposing ``amount`` and ``last_sale_price``.
    position_value_multipliers : mapping
        sid => multiplier applied when computing position values.
    position_exposure_multipliers : mapping
        sid => multiplier applied when computing position exposures.

    Returns
    -------
    PositionStats
        Long/short/gross/net values and exposures plus long/short counts.

    Notes
    -----
    Multipliers are looked up by sid rather than paired positionally with
    the positions.  The two multiplier mappings are maintained separately
    from ``positions`` and are not guaranteed to contain the same keys in
    the same order (for example, ``PositionTracker.update_position`` can
    create a position without registering multipliers), so positional
    pairing could apply one asset's multiplier to another asset.  A
    position with no registered multiplier contributes 0.
    """
    amounts = []
    last_sale_prices = []
    # Per-call multiplier mappings whose iteration order matches the order
    # in which amounts/prices are collected below.
    value_multipliers = OrderedDict()
    exposure_multipliers = OrderedDict()
    for sid, pos in iteritems(positions):
        amounts.append(pos.amount)
        last_sale_prices.append(pos.last_sale_price)
        value_multipliers[sid] = position_value_multipliers.get(sid, 0)
        exposure_multipliers[sid] = \
            position_exposure_multipliers.get(sid, 0)

    position_values = calc_position_values(
        amounts,
        last_sale_prices,
        value_multipliers
    )

    position_exposures = calc_position_exposures(
        amounts,
        last_sale_prices,
        exposure_multipliers
    )

    long_value = calc_long_value(position_values)
    short_value = calc_short_value(position_values)
    gross_value = calc_gross_value(long_value, short_value)
    long_exposure = calc_long_exposure(position_exposures)
    short_exposure = calc_short_exposure(position_exposures)
    gross_exposure = calc_gross_exposure(long_exposure, short_exposure)
    net_exposure = calc_net(position_exposures)
    longs_count = calc_longs_count(position_exposures)
    shorts_count = calc_shorts_count(position_exposures)
    net_value = calc_net(position_values)

    return PositionStats(
        long_value=long_value,
        gross_value=gross_value,
        short_value=short_value,
        long_exposure=long_exposure,
        short_exposure=short_exposure,
        gross_exposure=gross_exposure,
        net_exposure=net_exposure,
        longs_count=longs_count,
        shorts_count=shorts_count,
        net_value=net_value
    )
168
169
170
class PositionTracker(object):
171
172
    def __init__(self, asset_finder, data_portal):
        """
        Create a tracker with no positions.

        Parameters
        ----------
        asset_finder : object or None
            Used by _update_asset (via ``retrieve_asset``) to look up the
            asset behind a sid the first time that sid is seen.
        data_portal : object
            Used for ``get_spot_value`` lookups when syncing last sale
            prices and when pricing close-position transactions.
        """
        self.asset_finder = asset_finder

        # FIXME: storing the data portal here is undesirable, but the call
        # path down to maybe_create_close_position_transaction is long and
        # tortuous.
        self._data_portal = data_portal

        # sid => position object, plus the store that get_positions()
        # refreshes and hands out.
        self.positions = positiondict()
        self._positions_store = zp.Positions()

        # Per-sid multipliers, filled in lazily by _update_asset.
        self._position_value_multipliers = OrderedDict()
        self._position_exposure_multipliers = OrderedDict()
        self._position_payout_multipliers = OrderedDict()

        # pay date => list of payments earned but not yet paid.
        self._unpaid_dividends = {}
        self._unpaid_stock_dividends = {}

        # auto-close date => set of sids to close by that date.
        self._auto_close_position_sids = {}
194
195
    def _update_asset(self, sid):
196
        try:
197
            self._position_value_multipliers[sid]
198
            self._position_exposure_multipliers[sid]
199
            self._position_payout_multipliers[sid]
200
        except KeyError:
201
            # Check if there is an AssetFinder
202
            if self.asset_finder is None:
203
                raise PositionTrackerMissingAssetFinder()
204
205
            # Collect the value multipliers from applicable sids
206
            asset = self.asset_finder.retrieve_asset(sid)
207
            if isinstance(asset, Equity):
208
                self._position_value_multipliers[sid] = 1
209
                self._position_exposure_multipliers[sid] = 1
210
                self._position_payout_multipliers[sid] = 0
211
            if isinstance(asset, Future):
212
                self._position_value_multipliers[sid] = 0
213
                self._position_exposure_multipliers[sid] = \
214
                    asset.contract_multiplier
215
                self._position_payout_multipliers[sid] = \
216
                    asset.contract_multiplier
217
                # Futures auto-close timing is controlled by the Future's
218
                # auto_close_date property
219
                self._insert_auto_close_position_date(
220
                    dt=asset.auto_close_date,
221
                    sid=sid
222
                )
223
224
    def _insert_auto_close_position_date(self, dt, sid):
225
        """
226
        Inserts the given SID in to the list of positions to be auto-closed by
227
        the given dt.
228
229
        Parameters
230
        ----------
231
        dt : pandas.Timestamp
232
            The date before-which the given SID will be auto-closed
233
        sid : int
234
            The SID of the Asset to be auto-closed
235
        """
236
        if dt is not None:
237
            self._auto_close_position_sids.setdefault(dt, set()).add(sid)
238
239
    def auto_close_position_events(self, next_trading_day):
240
        """
241
        Generates CLOSE_POSITION events for any SIDs whose auto-close date is
242
        before or equal to the given date.
243
244
        Parameters
245
        ----------
246
        next_trading_day : pandas.Timestamp
247
            The time before-which certain Assets need to be closed
248
249
        Yields
250
        ------
251
        Event
252
            A close position event for any sids that should be closed before
253
            the next_trading_day parameter
254
        """
255
        past_asset_end_dates = set()
256
257
        # Check the auto_close_position_dates dict for SIDs to close
258
        for date, sids in self._auto_close_position_sids.items():
259
            if date > next_trading_day:
260
                continue
261
            past_asset_end_dates.add(date)
262
263
            for sid in sids:
264
                # Yield a CLOSE_POSITION event
265
                event = Event({
266
                    'dt': date,
267
                    'type': DATASOURCE_TYPE.CLOSE_POSITION,
268
                    'sid': sid,
269
                })
270
                yield event
271
272
        # Clear out past dates
273
        while past_asset_end_dates:
274
            self._auto_close_position_sids.pop(past_asset_end_dates.pop())
275
276
    def update_positions(self, positions):
277
        # update positions in batch
278
        self.positions.update(positions)
279
        for sid, pos in iteritems(positions):
280
            self._update_asset(sid)
281
282
    def update_position(self, sid, amount=None, last_sale_price=None,
283
                        last_sale_date=None, cost_basis=None):
284
        if sid not in self.positions:
285
            position = Position(sid)
286
            self.positions[sid] = position
287
        else:
288
            position = self.positions[sid]
289
290
        if amount is not None:
291
            position.amount = amount
292
            self._update_asset(sid=sid)
293
        if last_sale_price is not None:
294
            position.last_sale_price = last_sale_price
295
        if last_sale_date is not None:
296
            position.last_sale_date = last_sale_date
297
        if cost_basis is not None:
298
            position.cost_basis = cost_basis
299
300
    def execute_transaction(self, txn):
301
        # Update Position
302
        # ----------------
303
        sid = txn.sid
304
305
        if sid not in self.positions:
306
            position = Position(sid)
307
            self.positions[sid] = position
308
        else:
309
            position = self.positions[sid]
310
311
        position.update(txn)
312
        self._update_asset(sid)
313
314
    def handle_commission(self, sid, cost):
315
        # Adjust the cost basis of the stock if we own it
316
        if sid in self.positions:
317
            self.positions[sid].\
318
                adjust_commission_cost_basis(sid, cost)
319
320
    def handle_splits(self, splits):
321
        """
322
        Processes a list of splits by modifying any positions as needed.
323
324
        Parameters
325
        ----------
326
        splits: list
327
            A list of splits.  Each split is a tuple of (sid, ratio).
328
329
        Returns
330
        -------
331
        None
332
        """
333
        for split in splits:
334
            sid = split[0]
335
            if sid in self.positions:
336
                # Make the position object handle the split. It returns the
337
                # leftover cash from a fractional share, if there is any.
338
                position = self.positions[sid]
339
                leftover_cash = position.handle_split(sid, split[1])
340
                self._update_asset(split[0])
341
                return leftover_cash
342
343
    def earn_dividends(self, dividends, stock_dividends):
344
        """
345
        Given a list of dividends whose ex_dates are all the next trading day,
346
        calculate and store the cash and/or stock payments to be paid on each
347
        dividend's pay date.
348
        """
349
        for dividend in dividends:
350
            # Store the earned dividends so that they can be paid on the
351
            # dividends' pay_dates.
352
            div_owed = self.positions[dividend.sid].earn_dividend(dividend)
353
            try:
354
                self._unpaid_dividends[dividend.pay_date].append(
355
                    div_owed)
356
            except KeyError:
357
                self._unpaid_dividends[dividend.pay_date] = [div_owed]
358
359
        for stock_dividend in stock_dividends:
360
            div_owed = self.positions[stock_dividend.sid].earn_stock_dividend(
361
                stock_dividend)
362
            try:
363
                self._unpaid_stock_dividends[stock_dividend.pay_date].\
364
                    append(div_owed)
365
            except KeyError:
366
                self._unpaid_stock_dividends[stock_dividend.pay_date] = \
367
                    [div_owed]
368
369
    def pay_dividends(self, next_trading_day):
370
        """
371
        Returns a cash payment based on the dividends that should be paid out
372
        according to the accumulated bookkeeping of earned, unpaid, and stock
373
        dividends.
374
        """
375
        net_cash_payment = 0.0
376
377
        try:
378
            payments = self._unpaid_dividends[next_trading_day]
379
            # Mark these dividends as paid by dropping them from our unpaid
380
            del self._unpaid_dividends[next_trading_day]
381
        except KeyError:
382
            payments = []
383
384
        # representing the fact that we're required to reimburse the owner of
385
        # the stock for any dividends paid while borrowing.
386
        for payment in payments:
387
            net_cash_payment += payment['amount']
388
389
        # Add stock for any stock dividends paid.  Again, the values here may
390
        # be negative in the case of short positions.
391
392
        try:
393
            stock_payments = self._unpaid_stock_dividends[next_trading_day]
394
        except:
395
            stock_payments = []
396
397
        for stock_payment in stock_payments:
398
            stock = stock_payment['payment_sid']
399
            share_count = stock_payment['share_count']
400
            # note we create a Position for stock dividend if we don't
401
            # already own the asset
402
            if stock in self.positions:
403
                position = self.positions[stock]
404
            else:
405
                position = self.positions[stock] = Position(stock)
406
407
            position.amount += share_count
408
            self._update_asset(stock)
409
410
        return net_cash_payment
411
412
    def maybe_create_close_position_transaction(self, event):
413
        if not self.positions.get(event.sid):
414
            return None
415
416
        amount = self.positions.get(event.sid).amount
417
        price = self._data_portal.get_spot_value(event.sid, 'close', event.dt)
418
419
        txn = Transaction(
420
            sid=event.sid,
421
            amount=(-1 * amount),
422
            dt=event.dt,
423
            price=price,
424
            commission=0,
425
            order_id=0
426
        )
427
        return txn
428
429
    def get_positions(self):
        """
        Refresh the positions store from the tracked positions and
        return it.

        Entries whose tracked amount is zero are removed from the store;
        all other entries have amount, cost basis, last sale price and
        last sale date copied over.
        """
        store = self._positions_store

        for sid, tracked in iteritems(self.positions):
            if tracked.amount == 0:
                # The position has become empty since the last call.
                # Catching the KeyError is faster than checking
                # `if sid in store`, and this can potentially be called
                # in a tight inner loop.
                try:
                    del store[sid]
                except KeyError:
                    pass
            else:
                # Indexing the store is relied upon to create the entry
                # when there is none yet.
                snapshot = store[sid]
                snapshot.amount = tracked.amount
                snapshot.cost_basis = tracked.cost_basis
                snapshot.last_sale_price = tracked.last_sale_price
                snapshot.last_sale_date = tracked.last_sale_date

        return store
453
454
    def get_positions_list(self):
        """
        Return ``to_dict()`` of every tracked position whose amount is
        non-zero, in the iteration order of the positions mapping.
        """
        return [
            position.to_dict()
            for position in itervalues(self.positions)
            if position.amount != 0
        ]
460
461
    def sync_last_sale_prices(self, dt):
        """
        Set every tracked position's last sale price to the data portal's
        'close' spot value for its sid at ``dt``.
        """
        for sid, position in iteritems(self.positions):
            close_price = self._data_portal.get_spot_value(sid, 'close', dt)
            position.last_sale_price = close_price
466
467
    def stats(self):
        """
        Return the PositionStats computed by calc_position_stats from the
        current positions and their value/exposure multipliers.
        """
        return calc_position_stats(
            positions=self.positions,
            position_value_multipliers=self._position_value_multipliers,
            position_exposure_multipliers=self._position_exposure_multipliers,
        )
471
472
    def __getstate__(self):
        """
        Return the state dict used for serialization.

        Saved: the asset finder, a plain-dict copy of the positions, both
        unpaid-dividend mappings, the auto-close bookkeeping and a version
        tag stored under VERSION_LABEL.  The data portal, the positions
        store and the multiplier mappings are not included.
        """
        STATE_VERSION = 3

        state_dict = {
            'asset_finder': self.asset_finder,
            'positions': dict(self.positions),
            'unpaid_dividends': self._unpaid_dividends,
            'unpaid_stock_dividends': self._unpaid_stock_dividends,
            'auto_close_position_sids': self._auto_close_position_sids,
        }
        state_dict[VERSION_LABEL] = STATE_VERSION
        return state_dict
484
485
    def __setstate__(self, state):
        """
        Restore the tracker from a state dict produced by __getstate__.

        Parameters
        ----------
        state : dict
            The saved state.  Its VERSION_LABEL entry is popped (the dict
            is modified in place).

        Raises
        ------
        ValueError
            If the saved state's version is older than the oldest
            supported version.

        Notes
        -----
        The data portal is not part of the saved state; ``_data_portal``
        is left as None and must be re-attached before any method that
        queries it is used.
        """
        OLDEST_SUPPORTED_STATE = 3
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # Raise an Exception subclass: a bare BaseException is not
            # caught by ordinary `except Exception` handlers.
            raise ValueError("PositionTracker saved state is too old.")

        self.asset_finder = state['asset_finder']
        self.positions = positiondict()
        # note that positions_store is temporary and gets regened from
        # .positions
        self._positions_store = zp.Positions()

        self._unpaid_dividends = state['unpaid_dividends']
        self._unpaid_stock_dividends = state['unpaid_stock_dividends']
        self._auto_close_position_sids = state['auto_close_position_sids']

        # Per-sid multipliers; these are not saved and get rebuilt by the
        # update_positions call below.
        self._position_value_multipliers = OrderedDict()
        self._position_exposure_multipliers = OrderedDict()
        self._position_payout_multipliers = OrderedDict()

        # Re-register every saved position.  This looks each sid up through
        # the restored asset finder, so it raises
        # PositionTrackerMissingAssetFinder if the saved finder was None
        # and there are positions to restore.
        self.update_positions(state['positions'])

        # FIXME: the data portal is not restored from the saved state.
        self._data_portal = None
512