Completed
Pull Request — master (#858)
by Eddie
02:03
created

zipline.SIDData   A

Complexity

Total Complexity 33

Size/Duplication

Total Lines 222
Duplicated Lines 0 %
Metric Value
dl 0
loc 222
rs 9.4
wmc 33
1
#
2
# Copyright 2013 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pandas as pd
18
19
from . utils.protocol_utils import Enum
20
21
from zipline.utils.serialization_utils import (
22
    VERSION_LABEL
23
)
24
25
# The datasource type of a message should completely determine which other
# fields that message carries. Names are passed to Enum in this order.
DATASOURCE_TYPE = Enum(
    'AS_TRADED_EQUITY', 'MERGER', 'SPLIT', 'DIVIDEND',
    'TRADE', 'TRANSACTION', 'ORDER', 'EMPTY', 'DONE',
    'CUSTOM', 'BENCHMARK', 'COMMISSION', 'CLOSE_POSITION',
)
42
43
# Field names / index labels expected on a dividend Series.
DIVIDEND_FIELDS = [
    'declared_date', 'ex_date', 'gross_amount', 'net_amount',
    'pay_date', 'payment_sid', 'ratio', 'sid',
]
54
# Field names / index labels expected on a dividend payment Series.
DIVIDEND_PAYMENT_FIELDS = [
    'id', 'payment_sid', 'cash_amount', 'share_count',
]
61
62
63
class Event(object):
    """
    A lightweight record whose fields are stored as instance attributes.

    Fields can be accessed either as attributes (``event.sid``) or with item
    syntax (``event['sid']``), and ``in`` / ``keys()`` report which fields
    are currently set.
    """

    def __init__(self, initial_values=None):
        """
        Parameters
        ----------
        initial_values : dict, optional
            Initial field values. When truthy, this mapping is installed
            directly as the instance ``__dict__`` -- it is NOT copied, so
            later changes to the event are visible through the caller's dict
            and vice versa. ``None`` or an empty mapping leaves the event
            with no fields.
        """
        if initial_values:
            self.__dict__ = initial_values

    def __getitem__(self, name):
        # Item access is an alias for attribute access, so a missing field
        # raises AttributeError rather than KeyError.
        return getattr(self, name)

    def __setitem__(self, name, value):
        setattr(self, name, value)

    def __delitem__(self, name):
        delattr(self, name)

    def keys(self):
        """Return the names of the fields currently set on this event."""
        return self.__dict__.keys()

    def __eq__(self, other):
        # Any object whose attribute dict matches compares equal, not only
        # other Event instances.
        return hasattr(other, '__dict__') and self.__dict__ == other.__dict__

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this method two
        # equal events would still compare "not equal" (identity-based).
        return not self.__eq__(other)

    def __contains__(self, name):
        return name in self.__dict__

    def __repr__(self):
        return "Event({0})".format(self.__dict__)

    def to_series(self, index=None):
        """
        Return the event's fields as a pandas Series.

        Parameters
        ----------
        index : sequence, optional
            Passed straight through to ``pd.Series`` as its ``index``.
        """
        return pd.Series(self.__dict__, index=index)
92
93
94
class Order(Event):
    """
    An order event.

    Adds no behavior of its own on top of ``Event``; it only gives order
    events a distinct type.
    """
96
97
98
class Portfolio(object):
    """
    Snapshot of portfolio-level values (cash, value, pnl, returns) together
    with the per-sid ``positions`` mapping.

    Fields are plain instance attributes and can also be read with item
    syntax (``portfolio['cash']``).
    """

    def __init__(self):
        self.capital_used = 0.0
        self.starting_cash = 0.0
        self.portfolio_value = 0.0
        self.pnl = 0.0
        self.returns = 0.0
        self.cash = 0.0
        self.positions = Positions()
        self.start_date = None
        self.positions_value = 0.0

    def __getitem__(self, key):
        return self.__dict__[key]

    def __repr__(self):
        return "Portfolio({0})".format(self.__dict__)

    def __getstate__(self):
        """
        Return a shallow copy of this portfolio's fields for serialization.

        The copy is tagged with the state version under ``VERSION_LABEL``.
        """
        state_dict = copy(self.__dict__)

        # Have to convert to primitive dict: Positions is a dict subclass
        # with auto-creating __missing__.
        state_dict['positions'] = dict(self.positions)

        STATE_VERSION = 1
        state_dict[VERSION_LABEL] = STATE_VERSION

        return state_dict

    def __setstate__(self, state):
        """
        Restore fields from a dict produced by ``__getstate__``.

        Raises
        ------
        KeyError
            If ``state`` lacks the ``VERSION_LABEL`` or ``'positions'`` entry.
        ValueError
            If the saved state version is older than the oldest supported one.
        """
        OLDEST_SUPPORTED_STATE = 1

        # Work on a copy so the caller's dict is not mutated by the pops.
        state = dict(state)
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # ValueError instead of bare BaseException so ordinary
            # ``except Exception`` handlers can catch it; it is still a
            # BaseException subclass for existing callers.
            raise ValueError("Portfolio saved state is too old.")

        # Rebuild the auto-creating Positions container from the plain dict.
        self.positions = Positions()
        self.positions.update(state.pop('positions'))

        self.__dict__.update(state)
141
142
143
class Account(object):
    '''
    The account object tracks information about the trading account. The
    values are updated as the algorithm runs and its keys remain unchanged.
    If connected to a broker, one can update these values with the trading
    account values as reported by the broker.

    Fields are plain instance attributes and can also be read with item
    syntax (``account['cushion']``).
    '''

    def __init__(self):
        self.settled_cash = 0.0
        self.accrued_interest = 0.0
        # buying_power, regt_margin and day_trades_remaining start at
        # infinity -- presumably meaning "no limit" until real values are
        # supplied; confirm against consumers.
        self.buying_power = float('inf')
        self.equity_with_loan = 0.0
        self.total_positions_value = 0.0
        self.regt_equity = 0.0
        self.regt_margin = float('inf')
        self.initial_margin_requirement = 0.0
        self.maintenance_margin_requirement = 0.0
        self.available_funds = 0.0
        self.excess_liquidity = 0.0
        self.cushion = 0.0
        self.day_trades_remaining = float('inf')
        self.leverage = 0.0
        self.net_leverage = 0.0
        self.net_liquidation = 0.0

    def __getitem__(self, key):
        return self.__dict__[key]

    def __repr__(self):
        return "Account({0})".format(self.__dict__)

    def __getstate__(self):
        """
        Return a shallow copy of this account's fields for serialization.

        The copy is tagged with the state version under ``VERSION_LABEL``.
        """
        state_dict = copy(self.__dict__)

        STATE_VERSION = 1
        state_dict[VERSION_LABEL] = STATE_VERSION

        return state_dict

    def __setstate__(self, state):
        """
        Restore fields from a dict produced by ``__getstate__``.

        Raises
        ------
        KeyError
            If ``state`` lacks the ``VERSION_LABEL`` entry.
        ValueError
            If the saved state version is older than the oldest supported one.
        """
        OLDEST_SUPPORTED_STATE = 1

        # Work on a copy so the caller's dict is not mutated by the pop.
        state = dict(state)
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # ValueError instead of bare BaseException so ordinary
            # ``except Exception`` handlers can catch it; it is still a
            # BaseException subclass for existing callers.
            raise ValueError("Account saved state is too old.")

        self.__dict__.update(state)
193
194
195
class Position(object):
    """
    Holding in a single sid: share amount, per-share cost basis and last
    sale price.

    Fields are plain instance attributes and can also be read with item
    syntax (``position['amount']``).
    """

    def __init__(self, sid):
        self.sid = sid
        self.amount = 0
        self.cost_basis = 0.0  # per share
        self.last_sale_price = 0.0

    def __getitem__(self, key):
        return self.__dict__[key]

    def __repr__(self):
        return "Position({0})".format(self.__dict__)

    def __getstate__(self):
        """
        Return a shallow copy of this position's fields for serialization.

        The copy is tagged with the state version under ``VERSION_LABEL``.
        """
        state_dict = copy(self.__dict__)

        STATE_VERSION = 1
        state_dict[VERSION_LABEL] = STATE_VERSION

        return state_dict

    def __setstate__(self, state):
        """
        Restore fields from a dict produced by ``__getstate__``.

        Raises
        ------
        KeyError
            If ``state`` lacks the ``VERSION_LABEL`` entry.
        ValueError
            If the saved state version is older than the oldest supported one.
        """
        OLDEST_SUPPORTED_STATE = 1

        # Work on a copy so the caller's dict is not mutated by the pop.
        state = dict(state)
        version = state.pop(VERSION_LABEL)

        if version < OLDEST_SUPPORTED_STATE:
            # ValueError instead of bare BaseException so ordinary
            # ``except Exception`` handlers can catch it; it is still a
            # BaseException subclass for existing callers.
            raise ValueError("Protocol Position saved state is too old.")

        self.__dict__.update(state)
226
227
228
class Positions(dict):
    """
    Mapping of sid -> Position.

    Looking up a sid that is not present creates a fresh ``Position`` for
    it, stores it in the mapping, and returns it.
    """

    def __missing__(self, key):
        # dict.__getitem__ calls this only when ``key`` is absent.
        fresh = Position(key)
        self[key] = fresh
        return fresh
234
235
236
class BarData(object):
    """
    Holds the event data for all sids for a given dt.

    This is what is passed as `data` to the `handle_data` function. Item
    lookups (``data[name]``) and ``fetcher_assets`` are delegated to the
    wrapped data portal.
    """

    def __init__(self, data_portal=None):
        # A missing or falsy portal is replaced by an empty dict. A dict has
        # neither get_equity_price_view nor get_fetcher_assets, so lookups
        # on such a BarData raise AttributeError.
        if data_portal:
            self.data_portal = data_portal
        else:
            self.data_portal = {}

    def __getitem__(self, name):
        portal = self.data_portal
        return portal.get_equity_price_view(name)

    def __iter__(self):
        # Without this, defining __getitem__ would let Python fall back to
        # the legacy integer-index iteration protocol; refuse explicitly.
        class_name = self.__class__.__name__
        raise TypeError('%r object is not iterable' % class_name)

    @property
    def fetcher_assets(self):
        """Result of the data portal's ``get_fetcher_assets()``."""
        portal = self.data_portal
        return portal.get_fetcher_assets()
256