Completed
Pull Request — master (#858)
by Eddie
10:07 queued 01:13
created

earn_dividend()   A

Complexity

Conditions 1

Size

Total Lines 9

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 9
rs 9.6667
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
"""
17
Position Tracking
18
=================
19
20
    +-----------------+----------------------------------------------------+
21
    | key             | value                                              |
22
    +=================+====================================================+
23
    | sid             | the sid for the asset held in this position        |
24
    +-----------------+----------------------------------------------------+
25
    | amount          | whole number of shares in the position             |
26
    +-----------------+----------------------------------------------------+
27
    | last_sale_price | price at last sale of the asset on the exchange    |
28
    +-----------------+----------------------------------------------------+
29
    | cost_basis      | the volume weighted average price paid per share   |
30
    +-----------------+----------------------------------------------------+
31
32
"""
33
34
from __future__ import division
35
from math import (
36
    copysign,
37
    floor,
38
)
39
40
from copy import copy
41
42
import logbook
43
44
from zipline.utils.serialization_utils import (
45
    VERSION_LABEL
46
)
47
48
# Module-level logbook logger; Position.handle_split reports the before/after
# state of a split adjustment and the returned cash through it.
log = logbook.Logger('Performance')
49
50
51
class Position(object):
    """
    Mutable state of a holding in a single asset, identified by ``sid``.

    Tracks the signed share count (negative for a short position), the
    per-share cost basis and the most recent sale price/date, and applies
    transactions, splits, dividends and commissions to that state.
    """

    def __init__(self, sid, amount=0, cost_basis=0.0,
                 last_sale_price=0.0, last_sale_date=None):
        """
        Parameters
        ----------
        sid : object
            Identifier of the asset held; checked against the ``sid`` given
            to ``update``, ``handle_split`` and
            ``adjust_commission_cost_basis``.
        amount : number, optional
            Signed number of shares held (negative for a short position).
        cost_basis : float, optional
            Average price paid per share.
        last_sale_price : float, optional
            Most recent price recorded for the asset.
        last_sale_date : optional
            Timestamp of ``last_sale_price``; ``update`` treats ``None`` as
            older than any transaction.
        """
        self.sid = sid
        self.amount = amount
        self.cost_basis = cost_basis  # per share
        self.last_sale_price = last_sale_price
        self.last_sale_date = last_sale_date

    def earn_dividend(self, dividend):
        """
        Register the number of shares we held at this dividend's ex date so
        that we can pay out the correct amount on the dividend's pay date.

        Parameters
        ----------
        dividend : object
            Must expose an ``amount`` attribute, which is multiplied by the
            number of shares currently held.

        Returns
        -------
        dict
            ``{'amount': self.amount * dividend.amount}`` -- the cash owed
            for the whole position.
        """
        return {'amount': self.amount * dividend.amount}

    def earn_stock_dividend(self, stock_dividend):
        """
        Compute the shares owed for a stock dividend based on the number of
        shares held now (i.e. at the dividend's ex date).

        Parameters
        ----------
        stock_dividend : object
            Must expose ``payment_sid`` (the asset paid out) and ``ratio``
            (shares paid per share held).

        Returns
        -------
        dict
            ``{'payment_sid': ..., 'share_count': ...}`` where
            ``share_count`` is ``floor(self.amount * ratio)``; any fractional
            share is dropped.
        """
        return {
            'payment_sid': stock_dividend.payment_sid,
            'share_count': floor(self.amount * float(stock_dividend.ratio)),
        }

    def handle_split(self, sid, ratio):
        """
        Update the position by the split ratio, and return the resulting
        fractional share that will be converted into cash.

        The share count becomes ``floor(amount / ratio)`` and the cost basis
        becomes ``cost_basis * ratio`` rounded to the nearest cent.

        Parameters
        ----------
        sid : object
            Must equal ``self.sid``.
        ratio : number
            Split ratio (old shares per new share).

        Returns
        -------
        float
            Cash value of the leftover fractional share (fraction times the
            new cost basis), rounded to the nearest cent.

        Raises
        ------
        Exception
            If ``sid`` does not match this position's ``sid``.
        """
        if self.sid != sid:
            raise Exception("updating split with the wrong sid!")

        log.info("handling split for sid = " + str(sid) +
                 ", ratio = " + str(ratio))
        log.info("before split: " + str(self))

        # adjust the # of shares by the ratio
        # (if we had 100 shares, and the ratio is 3,
        #  we now have 33 shares)
        # (old_share_count / ratio = new_share_count)
        # (old_price * ratio = new_price)

        # e.g., 33.333
        raw_share_count = self.amount / float(ratio)

        # e.g., 33. floor (not truncation) keeps the remainder below in
        # [0, 1): for a short position the share count is rounded away from
        # zero and the positive remainder is returned as cash, so the total
        # value is still preserved.
        full_share_count = floor(raw_share_count)

        # e.g., 0.333
        fractional_share_count = raw_share_count - full_share_count

        # adjust the cost basis to the nearest cent, e.g., 60.0
        new_cost_basis = round(self.cost_basis * ratio, 2)

        self.cost_basis = new_cost_basis
        self.amount = full_share_count

        return_cash = round(float(fractional_share_count * new_cost_basis), 2)

        log.info("after split: " + str(self))
        log.info("returning cash: " + str(return_cash))

        # return the leftover cash, which will be converted into cash
        # (rounded to the nearest cent)
        return return_cash

    def update(self, txn):
        """
        Apply a transaction to this position.

        Parameters
        ----------
        txn : object
            Must expose ``sid``, ``amount`` (signed share delta), ``price``
            and ``dt`` attributes.

        Raises
        ------
        Exception
            If ``txn.sid`` does not match this position's ``sid``.

        Notes
        -----
        * A transaction that flattens the position resets the cost basis to
          0.0 and leaves the last sale price/date untouched.
        * Adding in the same direction blends the cost basis as a
          volume-weighted average.
        * Reducing the position keeps the cost basis; crossing through zero
          sets it to the transaction price.
        """
        if self.sid != txn.sid:
            raise Exception('updating position with txn for a '
                            'different sid')

        total_shares = self.amount + txn.amount

        if total_shares == 0:
            self.cost_basis = 0.0
        else:
            # copysign(1, 0) is 1.0, so an empty position counts as long:
            # opening long averages to txn.price, opening short takes the
            # "crossed zero" branch and also lands on txn.price.
            prev_direction = copysign(1, self.amount)
            txn_direction = copysign(1, txn.amount)

            if prev_direction != txn_direction:
                # we're covering a short or closing a position
                if abs(txn.amount) > abs(self.amount):
                    # we've closed the position and gone short
                    # or covered the short position and gone long
                    self.cost_basis = txn.price
            else:
                prev_cost = self.cost_basis * self.amount
                txn_cost = txn.amount * txn.price
                total_cost = prev_cost + txn_cost
                self.cost_basis = total_cost / total_shares

            # Update the last sale price if txn is
            # best data we have so far
            if self.last_sale_date is None or txn.dt > self.last_sale_date:
                self.last_sale_price = txn.price
                self.last_sale_date = txn.dt

        self.amount = total_shares

    def adjust_commission_cost_basis(self, sid, cost):
        """
        A note about cost-basis in zipline: all positions are considered
        to share a cost basis, even if they were executed in different
        transactions with different commission costs, different prices, etc.

        Due to limitations about how zipline handles positions, zipline will
        currently spread an externally-delivered commission charge across
        all shares in a position.

        Parameters
        ----------
        sid : object
            Must equal ``self.sid``.
        cost : float
            Total commission; ``cost / self.amount`` is added to the
            per-share cost basis.

        Raises
        ------
        Exception
            If ``sid`` does not match this position's ``sid``.
        """
        if sid != self.sid:
            raise Exception('Updating a commission for a different sid?')
        if cost == 0.0:
            return

        # If we no longer hold this position, there is no cost basis to
        # adjust (and the division below would be by zero).
        if self.amount == 0:
            return

        prev_cost = self.cost_basis * self.amount
        new_cost = prev_cost + cost
        self.cost_basis = new_cost / self.amount

    def __repr__(self):
        # Implicit concatenation instead of a backslash continuation inside
        # the literal; the resulting template text is identical.
        template = ("sid: {sid}, amount: {amount}, cost_basis: {cost_basis}, "
                    "last_sale_price: {last_sale_price}")
        return template.format(
            sid=self.sid,
            amount=self.amount,
            cost_basis=self.cost_basis,
            last_sale_price=self.last_sale_price
        )

    def to_dict(self):
        """
        Creates a dictionary representing the state of this position.
        Returns a dict object of the form::

            {
                'sid': <sid>,
                'amount': <signed share count>,
                'cost_basis': <cost basis per share>,
                'last_sale_price': <last sale price>,
            }

        ``last_sale_date`` is not included.
        """
        return {
            'sid': self.sid,
            'amount': self.amount,
            'cost_basis': self.cost_basis,
            'last_sale_price': self.last_sale_price
        }

    def __getstate__(self):
        """Return a shallow copy of the instance dict tagged with a version."""
        state_dict = copy(self.__dict__)

        STATE_VERSION = 1
        state_dict[VERSION_LABEL] = STATE_VERSION

        return state_dict

    def __setstate__(self, state):
        """
        Restore a position from a state dict produced by ``__getstate__``.

        Raises
        ------
        ValueError
            If the saved state's version is older than the oldest version
            this class can load.
        KeyError
            If ``state`` carries no version label.
        """
        OLDEST_SUPPORTED_STATE = 1

        # Work on a copy so the caller's dict is not mutated by the pop.
        state = dict(state)
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # ValueError (an Exception subclass) instead of a bare
            # BaseException, so ordinary ``except Exception`` handlers see
            # it; code catching BaseException still catches it.
            raise ValueError("Position saved state is too old.")

        self.__dict__.update(state)
227
228
229
class positiondict(dict):
    """
    ``dict`` subclass whose lookups of absent keys evaluate to ``None``.

    Unlike ``collections.defaultdict``, a miss stores nothing: the mapping is
    left exactly as it was.
    """

    def __missing__(self, key):
        # dict.__getitem__ invokes this hook on a miss; answering with None
        # replaces the usual KeyError without inserting ``key``.
        missing_value = None
        return missing_value
232