Completed
Pull Request — master (#858)
by Eddie
10:07 queued 01:13
created

execute_transaction()   A

Complexity

Conditions 2

Size

Total Lines 13

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 13
rs 9.4286
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from __future__ import division
17
18
import logbook
19
import numpy as np
20
from collections import namedtuple
21
from zipline.finance.performance.position import Position
22
from zipline.finance.transaction import Transaction
23
24
try:
25
    # optional cython based OrderedDict
26
    from cyordereddict import OrderedDict
27
except ImportError:
28
    from collections import OrderedDict
29
from six import iteritems, itervalues
30
31
from zipline.protocol import Event, DATASOURCE_TYPE
32
from zipline.utils.serialization_utils import (
33
    VERSION_LABEL
34
)
35
36
import zipline.protocol as zp
37
from zipline.assets import (
38
    Equity, Future
39
)
40
from zipline.errors import PositionTrackerMissingAssetFinder
41
from . position import positiondict
42
43
# Module-level logbook logger, registered under the name 'Performance'.
log = logbook.Logger('Performance')
44
45
46
# Immutable record of the aggregate figures returned by
# PositionTracker.stats().  Field order is part of the tuple's interface.
PositionStats = namedtuple(
    'PositionStats',
    [
        'net_exposure',
        'gross_value',
        'gross_exposure',
        'short_value',
        'short_exposure',
        'shorts_count',
        'long_value',
        'long_exposure',
        'longs_count',
        'net_value',
    ],
)
57
58
59
def calc_position_values(amounts, last_sale_prices, value_multipliers):
    """
    Compute ``amount * price * multiplier`` for each position.

    Parameters
    ----------
    amounts : iterable
        Position amounts.
    last_sale_prices : iterable
        Last sale prices, paired positionally with ``amounts``.
    value_multipliers : mapping
        Only its values are used, taken in the mapping's iteration order
        and paired positionally with the other two inputs.

    Returns
    -------
    list
        One product per zipped triple; ``zip`` stops at the shortest input.
    """
    multipliers = itervalues(value_multipliers)
    return [
        amount * price * multiplier
        for amount, price, multiplier
        in zip(amounts, last_sale_prices, multipliers)
    ]
71
72
73
def calc_net(values):
    """
    Sum ``values``, starting from a numpy float64 zero.

    Seeding the sum with a float64 means an empty input yields 0.0 rather
    than the integer 0.
    """
    zero = np.float64(0.0)
    return sum(values, zero)
76
77
78
def calc_position_exposures(amounts, last_sale_prices, exposure_multipliers):
    """
    Compute ``amount * price * multiplier`` for each position.

    Parameters
    ----------
    amounts : iterable
        Position amounts.
    last_sale_prices : iterable
        Last sale prices, paired positionally with ``amounts``.
    exposure_multipliers : mapping
        Only its values are used, taken in the mapping's iteration order
        and paired positionally with the other two inputs.

    Returns
    -------
    list
        One product per zipped triple; ``zip`` stops at the shortest input.
    """
    multipliers = itervalues(exposure_multipliers)
    return [
        amount * price * multiplier
        for amount, price, multiplier
        in zip(amounts, last_sale_prices, multipliers)
    ]
90
91
92
def calc_long_value(position_values):
    """Return the sum of the strictly positive entries of position_values."""
    longs = (value for value in position_values if value > 0)
    return sum(longs)
94
95
96
def calc_short_value(position_values):
    """Return the sum of the strictly negative entries of position_values."""
    shorts = (value for value in position_values if value < 0)
    return sum(shorts)
98
99
100
def calc_long_exposure(position_exposures):
    """Return the sum of the strictly positive entries of position_exposures."""
    longs = (exposure for exposure in position_exposures if exposure > 0)
    return sum(longs)
102
103
104
def calc_short_exposure(position_exposures):
    """Return the sum of the strictly negative entries of position_exposures."""
    shorts = (exposure for exposure in position_exposures if exposure < 0)
    return sum(shorts)
106
107
108
def calc_longs_count(position_exposures):
    """Return how many entries of position_exposures are strictly positive."""
    return sum(1 for exposure in position_exposures if exposure > 0)
110
111
112
def calc_shorts_count(position_exposures):
    """Return how many entries of position_exposures are strictly negative."""
    return sum(1 for exposure in position_exposures if exposure < 0)
114
115
116
def calc_gross_exposure(long_exposure, short_exposure):
    """Return long_exposure plus the absolute value of short_exposure."""
    short_magnitude = abs(short_exposure)
    return long_exposure + short_magnitude
118
119
120
def calc_gross_value(long_value, short_value):
    """Return long_value plus the absolute value of short_value."""
    short_magnitude = abs(short_value)
    return long_value + short_magnitude
122
123
124
class PositionTracker(object):
    """
    Bookkeeping for the positions held during a simulation.

    Holds a ``sid -> Position`` mapping together with the per-sid value and
    exposure multipliers used by ``stats``, the dividends that have been
    earned but not yet paid, and the dates by which positions must be
    auto-closed.
    """

    def __init__(self, asset_finder, data_portal, data_frequency):
        """
        Parameters
        ----------
        asset_finder : object or None
            Used by ``_update_asset`` to resolve a sid through
            ``retrieve_asset``.  If None, the first sid without multipliers
            raises PositionTrackerMissingAssetFinder.
        data_portal : object
            Queried through ``get_spot_value`` for 'close' prices.
        data_frequency : object
            Passed through unchanged to ``data_portal.get_spot_value``.
        """
        self.asset_finder = asset_finder

        # FIXME really want to avoid storing a data portal here,
        # but the path to get to maybe_create_close_position_transaction
        # is long and tortuous
        self._data_portal = data_portal

        # sid => position object
        self.positions = positiondict()
        # sid => multiplier, kept in insertion order; consumed by stats()
        self._position_value_multipliers = OrderedDict()
        self._position_exposure_multipliers = OrderedDict()
        # pay_date => list of payments owed on that date
        self._unpaid_dividends = {}
        self._unpaid_stock_dividends = {}
        self._positions_store = zp.Positions()

        # Dict, keyed on dates, that contains the set of sids whose
        # positions must be closed by that date
        self._auto_close_position_sids = {}

        self.data_frequency = data_frequency

    def _get_or_create_position(self, sid):
        """Return the Position stored for ``sid``, inserting one if absent."""
        if sid in self.positions:
            return self.positions[sid]
        position = self.positions[sid] = Position(sid)
        return position

    def _update_asset(self, sid):
        """
        Make sure value and exposure multipliers are recorded for ``sid``.

        Does nothing when both multipliers are already known.  Otherwise the
        asset is looked up through the asset finder: an Equity gets
        multipliers of 1 and 1, a Future gets a value multiplier of 0, an
        exposure multiplier equal to its ``contract_multiplier``, and is
        registered for auto-close on its ``auto_close_date``.

        Raises
        ------
        PositionTrackerMissingAssetFinder
            If the multipliers are unknown and there is no asset finder.
        """
        if (sid in self._position_value_multipliers and
                sid in self._position_exposure_multipliers):
            return

        # Check if there is an AssetFinder
        if self.asset_finder is None:
            raise PositionTrackerMissingAssetFinder()

        # NOTE(review): an asset that is neither an Equity nor a Future gets
        # no multipliers at all -- confirm that cannot happen, since stats()
        # pairs multipliers with positions by order.
        asset = self.asset_finder.retrieve_asset(sid)
        if isinstance(asset, Equity):
            self._position_value_multipliers[sid] = 1
            self._position_exposure_multipliers[sid] = 1
        if isinstance(asset, Future):
            self._position_value_multipliers[sid] = 0
            self._position_exposure_multipliers[sid] = \
                asset.contract_multiplier
            # Futures auto-close timing is controlled by the Future's
            # auto_close_date property
            self._insert_auto_close_position_date(
                dt=asset.auto_close_date,
                sid=sid
            )

    def _insert_auto_close_position_date(self, dt, sid):
        """
        Inserts the given SID in to the set of positions to be auto-closed by
        the given dt.  A ``dt`` of None is ignored.

        Parameters
        ----------
        dt : pandas.Timestamp
            The date before-which the given SID will be auto-closed
        sid : int
            The SID of the Asset to be auto-closed
        """
        if dt is None:
            return
        self._auto_close_position_sids.setdefault(dt, set()).add(sid)

    def auto_close_position_events(self, next_trading_day):
        """
        Generates CLOSE_POSITION events for any SIDs whose auto-close date is
        before or equal to the given date.

        The matched dates are dropped from the tracker only after every
        event has been yielded, i.e. once the generator is exhausted.

        Parameters
        ----------
        next_trading_day : pandas.Timestamp
            The time before-which certain Assets need to be closed

        Yields
        ------
        Event
            A close position event for any sids that should be closed before
            the next_trading_day parameter
        """
        past_asset_end_dates = set()

        # Check the auto_close_position_dates dict for SIDs to close
        for date, sids in self._auto_close_position_sids.items():
            if date > next_trading_day:
                continue
            past_asset_end_dates.add(date)

            for sid in sids:
                yield Event({
                    'dt': date,
                    'type': DATASOURCE_TYPE.CLOSE_POSITION,
                    'sid': sid,
                })

        # Clear out past dates
        for date in past_asset_end_dates:
            self._auto_close_position_sids.pop(date)

    def update_positions(self, positions):
        """
        Merge a ``sid -> position`` dict into the tracker in one batch and
        make sure multipliers exist for every sid in it.
        """
        self.positions.update(positions)
        for sid in positions:
            self._update_asset(sid)

    def update_position(self, sid, amount=None, last_sale_price=None,
                        last_sale_date=None, cost_basis=None):
        """
        Overwrite selected fields of the position for ``sid``, creating the
        position first if it does not exist.

        Arguments left as None are not touched.  Multipliers for ``sid`` are
        only ensured when ``amount`` is given.
        """
        position = self._get_or_create_position(sid)

        if amount is not None:
            position.amount = amount
            self._update_asset(sid=sid)
        if last_sale_price is not None:
            position.last_sale_price = last_sale_price
        if last_sale_date is not None:
            position.last_sale_date = last_sale_date
        if cost_basis is not None:
            position.cost_basis = cost_basis

    def execute_transaction(self, txn):
        """
        Apply ``txn`` to the position for ``txn.sid``, creating the position
        first if it does not exist, then ensure multipliers for that sid.
        """
        sid = txn.sid
        position = self._get_or_create_position(sid)
        position.update(txn)
        self._update_asset(sid)

    def handle_commission(self, sid, cost):
        """Adjust the cost basis of ``sid`` by ``cost`` if we hold it."""
        if sid in self.positions:
            self.positions[sid].adjust_commission_cost_basis(sid, cost)

    def handle_splits(self, splits):
        """
        Processes a list of splits by modifying any positions as needed.

        Every split whose sid is currently held is applied; splits for sids
        that are not held are skipped.

        Parameters
        ----------
        splits: list
            A list of splits.  Each split is a tuple of (sid, ratio).

        Returns
        -------
        number
            The leftover cash from fractional shares, summed over all the
            splits that were applied.  0 if no split was applied.
        """
        total_leftover_cash = 0

        for split in splits:
            sid = split[0]
            if sid not in self.positions:
                continue

            # Make the position object handle the split. It returns the
            # leftover cash from a fractional share, if there is any.
            leftover_cash = self.positions[sid].handle_split(sid, split[1])
            self._update_asset(sid)
            if leftover_cash:
                total_leftover_cash += leftover_cash

        return total_leftover_cash

    def earn_dividends(self, dividends, stock_dividends):
        """
        Given a list of dividends whose ex_dates are all the next trading day,
        calculate and store the cash and/or stock payments to be paid on each
        dividend's pay date.
        """
        for dividend in dividends:
            # Store the earned dividends so that they can be paid on the
            # dividends' pay_dates.
            div_owed = self.positions[dividend.sid].earn_dividend(dividend)
            self._unpaid_dividends.setdefault(
                dividend.pay_date, []).append(div_owed)

        for stock_dividend in stock_dividends:
            div_owed = self.positions[stock_dividend.sid].earn_stock_dividend(
                stock_dividend)
            self._unpaid_stock_dividends.setdefault(
                stock_dividend.pay_date, []).append(div_owed)

    def pay_dividends(self, next_trading_day):
        """
        Returns a cash payment based on the dividends that should be paid out
        according to the accumulated bookkeeping of earned, unpaid, and stock
        dividends.

        Both the cash and the stock dividends stored under
        ``next_trading_day`` are removed from the unpaid bookkeeping, so
        calling this again for the same day pays nothing further.
        """
        # Mark these dividends as paid by dropping them from our unpaid
        payments = self._unpaid_dividends.pop(next_trading_day, [])

        # Sum the cash owed.  An amount may be negative for a short
        # position, representing the fact that we're required to reimburse
        # the owner of the stock for any dividends paid while borrowing.
        net_cash_payment = 0.0
        for payment in payments:
            net_cash_payment += payment['amount']

        # Add stock for any stock dividends paid.  Again, the values here may
        # be negative in the case of short positions.
        stock_payments = self._unpaid_stock_dividends.pop(
            next_trading_day, [])

        for stock_payment in stock_payments:
            stock = stock_payment['payment_sid']
            share_count = stock_payment['share_count']
            # note we create a Position for stock dividend if we don't
            # already own the asset
            position = self._get_or_create_position(stock)
            position.amount += share_count
            self._update_asset(stock)

        return net_cash_payment

    def maybe_create_close_position_transaction(self, event):
        """
        Build a Transaction that flattens the position for ``event.sid``.

        Returns None when ``self.positions.get(event.sid)`` is falsy.
        Otherwise the transaction sells (or buys back) the full amount at
        the 'close' spot value for ``event.dt``, with zero commission and an
        order id of 0.
        """
        position = self.positions.get(event.sid)
        if not position:
            return None

        amount = position.amount
        price = self._data_portal.get_spot_value(
            event.sid, 'close', event.dt, self.data_frequency)

        return Transaction(
            sid=event.sid,
            amount=(-1 * amount),
            dt=event.dt,
            price=price,
            commission=0,
            order_id=0
        )

    def get_positions(self):
        """
        Refresh and return the cached ``zp.Positions`` store.

        Positions whose amount is 0 are removed from the store; every other
        position has its amount, cost basis, last sale price and last sale
        date copied across.
        """
        positions = self._positions_store

        for sid, pos in iteritems(self.positions):

            if pos.amount == 0:
                # Clear out the position if it has become empty since the last
                # time get_positions was called.  Catching the KeyError is
                # faster than checking `if sid in positions`, and this can be
                # potentially called in a tight inner loop.
                try:
                    del positions[sid]
                except KeyError:
                    pass
                continue

            # Note that this will create a position if we don't currently have
            # an entry
            position = positions[sid]
            position.amount = pos.amount
            position.cost_basis = pos.cost_basis
            position.last_sale_price = pos.last_sale_price
            position.last_sale_date = pos.last_sale_date

        return positions

    def get_positions_list(self):
        """Return ``to_dict()`` of every position with a non-zero amount."""
        return [
            pos.to_dict()
            for pos in itervalues(self.positions)
            if pos.amount != 0
        ]

    def sync_last_sale_prices(self, dt):
        """
        Set every position's last sale price to the data portal's 'close'
        spot value at ``dt``.
        """
        data_portal = self._data_portal
        for sid, position in iteritems(self.positions):
            position.last_sale_price = data_portal.get_spot_value(
                sid, 'close', dt, self.data_frequency)

    def stats(self):
        """
        Compute aggregate value and exposure figures for all positions.

        Returns
        -------
        PositionStats
        """
        amounts = []
        last_sale_prices = []
        for pos in itervalues(self.positions):
            amounts.append(pos.amount)
            last_sale_prices.append(pos.last_sale_price)

        # NOTE(review): multipliers are paired with positions purely by
        # iteration order, which assumes self.positions and the multiplier
        # dicts hold the same sids in the same order -- confirm.
        position_values = calc_position_values(
            amounts,
            last_sale_prices,
            self._position_value_multipliers
        )

        position_exposures = calc_position_exposures(
            amounts,
            last_sale_prices,
            self._position_exposure_multipliers
        )

        long_value = calc_long_value(position_values)
        short_value = calc_short_value(position_values)
        long_exposure = calc_long_exposure(position_exposures)
        short_exposure = calc_short_exposure(position_exposures)

        return PositionStats(
            long_value=long_value,
            gross_value=calc_gross_value(long_value, short_value),
            short_value=short_value,
            long_exposure=long_exposure,
            short_exposure=short_exposure,
            gross_exposure=calc_gross_exposure(long_exposure, short_exposure),
            net_exposure=calc_net(position_exposures),
            longs_count=calc_longs_count(position_exposures),
            shorts_count=calc_shorts_count(position_exposures),
            net_value=calc_net(position_values)
        )

    def __getstate__(self):
        """
        Return the picklable state.

        The data portal, the multipliers and the positions store are not
        saved; ``__setstate__`` rebuilds or resets them.
        """
        STATE_VERSION = 3

        return {
            'asset_finder': self.asset_finder,
            # Saved as a plain dict rather than a positiondict
            'positions': dict(self.positions),
            'unpaid_dividends': self._unpaid_dividends,
            'unpaid_stock_dividends': self._unpaid_stock_dividends,
            'auto_close_position_sids': self._auto_close_position_sids,
            'data_frequency': self.data_frequency,
            VERSION_LABEL: STATE_VERSION,
        }

    def __setstate__(self, state):
        """
        Restore the tracker from a dict produced by ``__getstate__``.

        The version entry is popped from ``state``.  The data portal is not
        part of the saved state and is left as None.

        Raises
        ------
        BaseException
            If the saved state version is older than the oldest supported.
        """
        OLDEST_SUPPORTED_STATE = 3
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # NOTE(review): BaseException escapes `except Exception`
            # handlers; kept as-is because callers may depend on it.
            raise BaseException("PositionTracker saved state is too old.")

        self.asset_finder = state['asset_finder']
        self.positions = positiondict()
        self.data_frequency = state['data_frequency']
        # note that positions_store is temporary and gets regened from
        # .positions
        self._positions_store = zp.Positions()

        self._unpaid_dividends = state['unpaid_dividends']
        self._unpaid_stock_dividends = state['unpaid_stock_dividends']
        self._auto_close_position_sids = state['auto_close_position_sids']

        # Start from empty multipliers; they are rebuilt just below
        self._position_value_multipliers = OrderedDict()
        self._position_exposure_multipliers = OrderedDict()

        # Rebuilding the multipliers goes through _update_asset, so a saved
        # asset_finder of None raises PositionTrackerMissingAssetFinder as
        # soon as there is at least one saved position.
        self.update_positions(state['positions'])

        # FIXME
        self._data_portal = None
503