1
|
|
|
# |
2
|
|
|
# Copyright 2015 Quantopian, Inc. |
3
|
|
|
# |
4
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
5
|
|
|
# you may not use this file except in compliance with the License. |
6
|
|
|
# You may obtain a copy of the License at |
7
|
|
|
# |
8
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
9
|
|
|
# |
10
|
|
|
# Unless required by applicable law or agreed to in writing, software |
11
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
12
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13
|
|
|
# See the License for the specific language governing permissions and |
14
|
|
|
# limitations under the License. |
15
|
|
|
from copy import copy |
16
|
|
|
|
17
|
|
|
import pytz |
18
|
|
|
import pandas as pd |
19
|
|
|
from pandas.tseries.tools import normalize_date |
20
|
|
|
import numpy as np |
21
|
|
|
|
22
|
|
|
from datetime import datetime |
23
|
|
|
from itertools import chain, repeat |
24
|
|
|
from numbers import Integral |
25
|
|
|
|
26
|
|
|
from six import ( |
27
|
|
|
exec_, |
28
|
|
|
iteritems, |
29
|
|
|
itervalues, |
30
|
|
|
string_types, |
31
|
|
|
) |
32
|
|
|
|
33
|
|
|
|
34
|
|
|
from zipline.errors import ( |
35
|
|
|
AttachPipelineAfterInitialize, |
36
|
|
|
HistoryInInitialize, |
37
|
|
|
NoSuchPipeline, |
38
|
|
|
OrderDuringInitialize, |
39
|
|
|
PipelineOutputDuringInitialize, |
40
|
|
|
RegisterAccountControlPostInit, |
41
|
|
|
RegisterTradingControlPostInit, |
42
|
|
|
SetBenchmarkOutsideInitialize, |
43
|
|
|
SetCommissionPostInit, |
44
|
|
|
SetSlippagePostInit, |
45
|
|
|
UnsupportedCommissionModel, |
46
|
|
|
UnsupportedDatetimeFormat, |
47
|
|
|
UnsupportedOrderParameters, |
48
|
|
|
UnsupportedSlippageModel, |
49
|
|
|
) |
50
|
|
|
from zipline.finance.trading import TradingEnvironment |
51
|
|
|
from zipline.finance.blotter import Blotter |
52
|
|
|
from zipline.finance.commission import PerShare, PerTrade, PerDollar |
53
|
|
|
from zipline.finance.controls import ( |
54
|
|
|
LongOnly, |
55
|
|
|
MaxOrderCount, |
56
|
|
|
MaxOrderSize, |
57
|
|
|
MaxPositionSize, |
58
|
|
|
MaxLeverage, |
59
|
|
|
RestrictedListOrder |
60
|
|
|
) |
61
|
|
|
from zipline.finance.execution import ( |
62
|
|
|
LimitOrder, |
63
|
|
|
MarketOrder, |
64
|
|
|
StopLimitOrder, |
65
|
|
|
StopOrder, |
66
|
|
|
) |
67
|
|
|
from zipline.finance.performance import PerformanceTracker |
68
|
|
|
from zipline.finance.slippage import ( |
69
|
|
|
VolumeShareSlippage, |
70
|
|
|
SlippageModel |
71
|
|
|
) |
72
|
|
|
from zipline.assets import Asset, Future |
73
|
|
|
from zipline.assets.futures import FutureChain |
74
|
|
|
from zipline.gens.tradesimulation import AlgorithmSimulator |
75
|
|
|
from zipline.pipeline.engine import ( |
76
|
|
|
NoOpPipelineEngine, |
77
|
|
|
SimplePipelineEngine, |
78
|
|
|
) |
79
|
|
|
from zipline.utils.api_support import ( |
80
|
|
|
api_method, |
81
|
|
|
require_initialized, |
82
|
|
|
require_not_initialized, |
83
|
|
|
ZiplineAPI, |
84
|
|
|
) |
85
|
|
|
from zipline.utils.input_validation import ensure_upper_case |
86
|
|
|
from zipline.utils.cache import CachedObject, Expired |
87
|
|
|
import zipline.utils.events |
88
|
|
|
from zipline.utils.events import ( |
89
|
|
|
EventManager, |
90
|
|
|
make_eventrule, |
91
|
|
|
DateRuleFactory, |
92
|
|
|
TimeRuleFactory, |
93
|
|
|
) |
94
|
|
|
from zipline.utils.factory import create_simulation_parameters |
95
|
|
|
from zipline.utils.math_utils import ( |
96
|
|
|
tolerant_equals, |
97
|
|
|
round_if_near_integer |
98
|
|
|
) |
99
|
|
|
from zipline.utils.preprocess import preprocess |
100
|
|
|
|
101
|
|
|
import zipline.protocol |
102
|
|
|
from zipline.sources.requests_csv import PandasRequestsCSV |
103
|
|
|
|
104
|
|
|
from zipline.gens.sim_engine import ( |
105
|
|
|
MinuteSimulationClock, |
106
|
|
|
DailySimulationClock, |
107
|
|
|
) |
108
|
|
|
from zipline.sources.benchmark_source import BenchmarkSource |
109
|
|
|
|
110
|
|
|
# Starting capital used when no ``capital_base`` keyword is supplied to
# TradingAlgorithm.
DEFAULT_CAPITAL_BASE = 1.0e5
111
|
|
|
|
112
|
|
|
|
113
|
|
|
class TradingAlgorithm(object): |
114
|
|
|
""" |
115
|
|
|
Base class for trading algorithms. Inherit and overload |
116
|
|
|
initialize() and handle_data(data). |
117
|
|
|
|
118
|
|
|
A new algorithm could look like this: |
119
|
|
|
``` |
120
|
|
|
from zipline.api import order, symbol |
121
|
|
|
|
122
|
|
|
def initialize(context): |
123
|
|
|
context.sid = symbol('AAPL') |
124
|
|
|
context.amount = 100 |
125
|
|
|
|
126
|
|
|
def handle_data(context, data): |
127
|
|
|
sid = context.sid |
128
|
|
|
amount = context.amount |
129
|
|
|
order(sid, amount) |
130
|
|
|
``` |
131
|
|
|
To run this algorithm, pass these functions to |
132
|
|
|
TradingAlgorithm: |
133
|
|
|
|
134
|
|
|
my_algo = TradingAlgorithm(initialize, handle_data) |
135
|
|
|
stats = my_algo.run(data) |
136
|
|
|
|
137
|
|
|
""" |
138
|
|
|
|
139
|
|
|
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
    :Optional:
        initialize : function
            Function that is called with a single
            argument at the beginning of the simulation.
        handle_data : function
            Function that is called with 2 arguments
            (context and data) on every bar.
        script : str
            Algoscript that contains initialize and
            handle_data function definition. Cannot be combined with the
            ``initialize``/``handle_data`` keyword arguments.
        data_frequency : {'daily', 'minute'}
            The duration of the bars.
        capital_base : float <default: 1.0e5>
            How much capital to start with.
        asset_finder : An AssetFinder object
            A new AssetFinder object to be used in this TradingEnvironment
            (NOTE(review): this constructor does not pop ``asset_finder``;
            ``self.asset_finder`` is taken from the TradingEnvironment --
            confirm whether this entry is stale).
        equities_metadata : can be either:
            - dict
            - pandas.DataFrame
            - object with 'read' property
            If dict is provided, it must have the following structure:
            * keys are the identifiers
            * values are dicts containing the metadata, with the metadata
              field name as the key
            If pandas.DataFrame is provided, it must have the
            following structure:
            * column names must be the metadata fields
            * index must be the different asset identifiers
            * array contents should be the metadata value
            If an object with a 'read' property is provided, 'read' must
            return rows containing at least one of 'sid' or 'symbol' along
            with the other metadata fields.
        identifiers : List
            Any asset identifiers that are not provided in the
            equities_metadata, but will be traded by this TradingAlgorithm

    Other keyword arguments consumed here (some only conditionally):
    ``namespace``, ``platform``, ``env``, ``futures_metadata``,
    ``sim_params``, ``start``, ``end``, ``get_pipeline_loader``,
    ``blotter``, ``create_event_context``, ``algo_filename``,
    ``before_trading_start``, ``analyze`` and ``benchmark_sid``.
    All remaining positional and keyword arguments are stored and later
    forwarded to ``initialize``.

    :Raises:
        ValueError
            If ``script`` is combined with ``initialize``/``handle_data``,
            or if ``script`` does not define ``handle_data``.
    """
    self.sources = []

    # List of trading controls to be used to validate orders.
    self.trading_controls = []

    # List of account controls to be checked on each bar.
    self.account_controls = []

    self._recorded_vars = {}
    self.namespace = kwargs.pop('namespace', {})

    self._platform = kwargs.pop('platform', 'zipline')

    self.logger = None

    self.data_portal = None

    # If an env has been provided, pop it
    self.trading_environment = kwargs.pop('env', None)

    if self.trading_environment is None:
        self.trading_environment = TradingEnvironment()

    # Update the TradingEnvironment with the provided asset metadata
    self.trading_environment.write_data(
        equities_data=kwargs.pop('equities_metadata', {}),
        equities_identifiers=kwargs.pop('identifiers', []),
        futures_data=kwargs.pop('futures_metadata', {}),
    )

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params is None:
        self.sim_params = create_simulation_parameters(
            capital_base=self.capital_base,
            start=kwargs.pop('start', None),
            end=kwargs.pop('end', None),
            env=self.trading_environment,
        )
    else:
        self.sim_params.update_internal_from_env(self.trading_environment)

    # Alternative way of setting data_frequency for backwards
    # compatibility. This must happen before the default Blotter below is
    # built, because that Blotter reads self.data_frequency. Applying the
    # override afterwards left the Blotter on the pre-override frequency.
    # NOTE(review): data_frequency is presumably a property backed by
    # self.sim_params (defined outside this view) -- confirm.
    if 'data_frequency' in kwargs:
        self.data_frequency = kwargs.pop('data_frequency')

    self.perf_tracker = None
    # Pull in the environment's new AssetFinder for quick reference
    self.asset_finder = self.trading_environment.asset_finder

    # Initialize Pipeline API data.
    self.init_engine(kwargs.pop('get_pipeline_loader', None))
    self._pipelines = {}
    # Create an always-expired cache so that we compute the first time data
    # is requested.
    self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

    self.blotter = kwargs.pop('blotter', None)
    if self.blotter is None:
        self.blotter = Blotter(
            slippage_func=VolumeShareSlippage(),
            commission=PerShare(),
            data_frequency=self.data_frequency,
        )

    # The symbol lookup date specifies the date to use when resolving
    # symbols to sids, and can be set using set_symbol_lookup_date()
    self._symbol_lookup_date = None

    self.portfolio_needs_update = True
    self.account_needs_update = True
    self.performance_needs_update = True
    self._portfolio = None
    self._account = None

    # If string is passed in, execute and get reference to
    # functions.
    self.algoscript = kwargs.pop('script', None)

    # A script and explicit initialize/handle_data functions are mutually
    # exclusive. This check previously lived inside the branch that only
    # runs when no script was given, so it could never fire.
    if self.algoscript is not None and (kwargs.get('initialize') or
                                        kwargs.get('handle_data')):
        raise ValueError(
            'You can not set script and initialize/handle_data.'
        )

    self._initialize = None
    self._before_trading_start = None
    self._analyze = None

    self.event_manager = EventManager(
        create_context=kwargs.pop('create_event_context', None),
    )

    if self.algoscript is not None:
        filename = kwargs.pop('algo_filename', None)
        if filename is None:
            filename = '<string>'
        code = compile(self.algoscript, filename, 'exec')
        exec_(code, self.namespace)
        self._initialize = self.namespace.get('initialize')
        if 'handle_data' not in self.namespace:
            raise ValueError('You must define a handle_data function.')
        else:
            self._handle_data = self.namespace['handle_data']

        self._before_trading_start = \
            self.namespace.get('before_trading_start')
        # Optional analyze function, gets called after run
        self._analyze = self.namespace.get('analyze')

    elif kwargs.get('initialize') and kwargs.get('handle_data'):
        self._initialize = kwargs.pop('initialize')
        self._handle_data = kwargs.pop('handle_data')
        self._before_trading_start = kwargs.pop('before_trading_start',
                                                None)
        self._analyze = kwargs.pop('analyze', None)

    self.event_manager.add_event(
        zipline.utils.events.Event(
            zipline.utils.events.Always(),
            # We pass handle_data.__func__ to get the unbound method.
            # We will explicitly pass the algorithm to bind it again.
            self.handle_data.__func__,
        ),
        prepend=True,
    )

    # If method not defined, NOOP
    if self._initialize is None:
        self._initialize = lambda x: None

    # Pop benchmark_sid before stashing kwargs so it is clearly not
    # forwarded to initialize (initialize_kwargs is the same dict object).
    self.benchmark_sid = kwargs.pop('benchmark_sid', None)

    # Prepare the algo for initialization
    self.initialized = False
    self.initialize_args = args
    self.initialize_kwargs = kwargs
315
|
|
|
|
316
|
|
|
def init_engine(self, get_loader):
    """
    Build the Pipeline engine for this algorithm and store it on
    ``self.engine``.

    Parameters
    ----------
    get_loader : callable or None
        Passed to SimplePipelineEngine together with the trading
        environment's trading days and the asset finder. When None, a
        NoOpPipelineEngine is stored instead.
    """
    if get_loader is None:
        self.engine = NoOpPipelineEngine()
        return

    env = self.trading_environment
    self.engine = SimplePipelineEngine(
        get_loader,
        env.trading_days,
        self.asset_finder,
    )
330
|
|
|
|
331
|
|
|
def initialize(self, *args, **kwargs):
    """
    Run the user-supplied initialize function.

    ``self._initialize`` is called as ``self._initialize(self, *args,
    **kwargs)`` inside a ``ZiplineAPI(self)`` block, so that ``self`` is
    made available to Zipline API functions while it runs.
    """
    with ZiplineAPI(self):
        # The algorithm instance is passed as the user's `context`.
        user_init = self._initialize
        user_init(self, *args, **kwargs)
338
|
|
|
|
339
|
|
|
def before_trading_start(self, data):
    """
    Call the user's before_trading_start function, if one was provided.

    The callback receives this algorithm as its context plus ``data``.
    Nothing happens when ``self._before_trading_start`` is None.
    """
    callback = self._before_trading_start
    if callback is not None:
        callback(self, data)
344
|
|
|
|
345
|
|
|
def handle_data(self, data):
    """
    Forward the current bar to the user's handle_data, then check the
    account controls.
    """
    user_handler = self._handle_data
    user_handler(self, data)

    # Account controls are not tied to order placement (unlike trading
    # controls), and account state can change every bar, so they are
    # validated on each bar whether or not an order was placed.
    self.validate_account_controls()
352
|
|
|
|
353
|
|
|
def analyze(self, perf):
    """
    Invoke the user's analyze function with the final performance results.

    Does nothing when no analyze function was supplied. Otherwise it is
    called as ``self._analyze(self, perf)`` inside a ``ZiplineAPI(self)``
    block.
    """
    if self._analyze is not None:
        with ZiplineAPI(self):
            self._analyze(self, perf)
359
|
|
|
|
360
|
|
|
def __repr__(self): |
361
|
|
|
""" |
362
|
|
|
N.B. this does not yet represent a string that can be used |
363
|
|
|
to instantiate an exact copy of an algorithm. |
364
|
|
|
|
365
|
|
|
However, it is getting close, and provides some value as something |
366
|
|
|
that can be inspected interactively. |
367
|
|
|
""" |
368
|
|
|
return """ |
369
|
|
|
{class_name}( |
370
|
|
|
capital_base={capital_base} |
371
|
|
|
sim_params={sim_params}, |
372
|
|
|
initialized={initialized}, |
373
|
|
|
slippage={slippage}, |
374
|
|
|
commission={commission}, |
375
|
|
|
blotter={blotter}, |
376
|
|
|
recorded_vars={recorded_vars}) |
377
|
|
|
""".strip().format(class_name=self.__class__.__name__, |
378
|
|
|
capital_base=self.capital_base, |
379
|
|
|
sim_params=repr(self.sim_params), |
380
|
|
|
initialized=self.initialized, |
381
|
|
|
slippage=repr(self.blotter.slippage_func), |
382
|
|
|
commission=repr(self.blotter.commission), |
383
|
|
|
blotter=repr(self.blotter), |
384
|
|
|
recorded_vars=repr(self.recorded_vars)) |
385
|
|
|
|
386
|
|
|
def _create_clock(self):
    """
    Build the simulation clock that matches ``sim_params.data_frequency``.

    For 'minute' data, a MinuteSimulationClock is driven by each session's
    market open and close, given as int64 nanoseconds since the epoch.
    Any other frequency gets a DailySimulationClock over the simulation's
    trading days.
    """
    params = self.sim_params
    if params.data_frequency != 'minute':
        return DailySimulationClock(params.trading_days)

    env = self.trading_environment
    sessions = env.open_and_closes.ix[params.trading_days]

    def _as_int_ns(column):
        # datetime column -> int64 nanoseconds since the epoch
        return sessions[column].values.astype(
            'datetime64[ns]').astype(np.int64)

    return MinuteSimulationClock(
        params.trading_days,
        _as_int_ns('market_open'),
        _as_int_ns('market_close'),
        env.trading_days,
        params.emission_rate == "minute",
    )
411
|
|
|
|
412
|
|
|
def _create_benchmark_source(self):
    """
    Construct the BenchmarkSource for the current run.

    It is built from the configured benchmark sid, the trading environment,
    the simulation's trading days and emission rate, and the data portal.
    """
    params = self.sim_params
    return BenchmarkSource(
        self.benchmark_sid,
        self.trading_environment,
        params.trading_days,
        self.data_portal,
        emission_rate=params.emission_rate,
    )
420
|
|
|
|
421
|
|
|
def _create_generator(self, sim_params): |
422
|
|
|
if sim_params is not None: |
423
|
|
|
self.sim_params = sim_params |
424
|
|
|
|
425
|
|
|
if self.perf_tracker is None: |
426
|
|
|
# HACK: When running with the `run` method, we set perf_tracker to |
427
|
|
|
# None so that it will be overwritten here. |
428
|
|
|
self.perf_tracker = PerformanceTracker( |
429
|
|
|
sim_params=self.sim_params, |
430
|
|
|
env=self.trading_environment, |
431
|
|
|
data_portal=self.data_portal |
432
|
|
|
) |
433
|
|
|
|
434
|
|
|
# Set the dt initially to the period start by forcing it to change. |
435
|
|
|
self.on_dt_changed(self.sim_params.period_start) |
436
|
|
|
|
437
|
|
|
if not self.initialized: |
438
|
|
|
self.initialize(*self.initialize_args, **self.initialize_kwargs) |
439
|
|
|
self.initialized = True |
440
|
|
|
|
441
|
|
|
self.trading_client = self._create_trading_client(sim_params) |
442
|
|
|
|
443
|
|
|
return self.trading_client.transform() |
444
|
|
|
|
445
|
|
|
def _create_trading_client(self, sim_params):
    """
    Build the AlgorithmSimulator that drives this algorithm.

    A fresh clock and a fresh benchmark source are created on every call.
    """
    clock = self._create_clock()
    benchmark_source = self._create_benchmark_source()
    return AlgorithmSimulator(
        self,
        sim_params,
        self.data_portal,
        clock,
        benchmark_source,
    )
453
|
|
|
|
454
|
|
|
def get_generator(self):
    """
    Return the generator that drives the simulation.

    Subclasses may override this to customize construction. The default
    delegates to ``_create_generator`` with the current ``sim_params``, and
    overrides can call ``_create_generator`` themselves for the standard
    generator.
    """
    current_params = self.sim_params
    return self._create_generator(current_params)
461
|
|
|
|
462
|
|
|
def run(self, data_portal=None):
    """Run the algorithm.

    :Arguments:
        data_portal : DataPortal, optional
            Data source for the simulation. It is stored on
            ``self.data_portal`` before the generator is built.

    :Returns:
        daily_stats : pandas.DataFrame
            Daily performance metrics such as returns, alpha etc.
            (built by ``_create_daily_stats``). The result is also passed
            to ``analyze`` before being returned.
    """
    self.data_portal = data_portal

    # Force a reset of the performance tracker, in case
    # this is a repeat run of the algorithm.
    self.perf_tracker = None

    # Drain the simulation; each item yielded is a perf packet (dict).
    perfs = list(self.get_generator())

    # convert perf dict to pandas dataframe
    daily_stats = self._create_daily_stats(perfs)

    self.analyze(daily_stats)

    return daily_stats
491
|
|
|
|
492
|
|
|
def _write_and_map_id_index_to_sids(self, identifiers, as_of_date): |
493
|
|
|
# Build new Assets for identifiers that can't be resolved as |
494
|
|
|
# sids/Assets |
495
|
|
|
identifiers_to_build = [] |
496
|
|
|
for identifier in identifiers: |
497
|
|
|
asset = None |
498
|
|
|
|
499
|
|
|
if isinstance(identifier, Asset): |
500
|
|
|
asset = self.asset_finder.retrieve_asset(sid=identifier.sid, |
501
|
|
|
default_none=True) |
502
|
|
|
elif isinstance(identifier, Integral): |
503
|
|
|
asset = self.asset_finder.retrieve_asset(sid=identifier, |
504
|
|
|
default_none=True) |
505
|
|
|
if asset is None: |
506
|
|
|
identifiers_to_build.append(identifier) |
507
|
|
|
|
508
|
|
|
self.trading_environment.write_data( |
509
|
|
|
equities_identifiers=identifiers_to_build) |
510
|
|
|
|
511
|
|
|
# We need to clear out any cache misses that were stored while trying |
512
|
|
|
# to do lookups. The real fix for this problem is to not construct an |
513
|
|
|
# AssetFinder until we `run()` when we actually have all the data we |
514
|
|
|
# need to so. |
515
|
|
|
self.asset_finder._reset_caches() |
516
|
|
|
|
517
|
|
|
return self.asset_finder.map_identifier_index_to_sids( |
518
|
|
|
identifiers, as_of_date, |
519
|
|
|
) |
520
|
|
|
|
521
|
|
|
def _create_daily_stats(self, perfs): |
522
|
|
|
# create daily and cumulative stats dataframe |
523
|
|
|
daily_perfs = [] |
524
|
|
|
# TODO: the loop here could overwrite expected properties |
525
|
|
|
# of daily_perf. Could potentially raise or log a |
526
|
|
|
# warning. |
527
|
|
|
for perf in perfs: |
528
|
|
|
if 'daily_perf' in perf: |
529
|
|
|
|
530
|
|
|
perf['daily_perf'].update( |
531
|
|
|
perf['daily_perf'].pop('recorded_vars') |
532
|
|
|
) |
533
|
|
|
perf['daily_perf'].update(perf['cumulative_risk_metrics']) |
534
|
|
|
daily_perfs.append(perf['daily_perf']) |
535
|
|
|
else: |
536
|
|
|
self.risk_report = perf |
537
|
|
|
|
538
|
|
|
daily_dts = [np.datetime64(perf['period_close'], utc=True) |
539
|
|
|
for perf in daily_perfs] |
540
|
|
|
daily_stats = pd.DataFrame(daily_perfs, index=daily_dts) |
541
|
|
|
|
542
|
|
|
return daily_stats |
543
|
|
|
|
544
|
|
|
@api_method
def get_environment(self, field='platform'):
    """
    Query details about the simulation environment.

    Parameters
    ----------
    field : str
        One of 'arena', 'data_frequency', 'start', 'end', 'capital_base',
        'platform', or '*' to get all of them as a dict.

    Returns
    -------
    The requested value, or the whole dict when ``field == '*'``.

    Raises
    ------
    KeyError
        If ``field`` is none of the names above.
    """
    params = self.sim_params
    env = dict(
        arena=params.arena,
        data_frequency=params.data_frequency,
        start=params.first_open,
        end=params.last_close,
        capital_base=params.capital_base,
        platform=self._platform,
    )
    return env if field == '*' else env[field]
558
|
|
|
|
559
|
|
|
@api_method
def fetch_csv(self, url,
              pre_func=None,
              post_func=None,
              date_column='date',
              date_format=None,
              timezone=pytz.utc.zone,
              symbol=None,
              mask=True,
              symbol_column=None,
              special_params_checker=None,
              **kwargs):
    """
    Fetch a CSV from ``url`` and register it as an extra data source.

    A PandasRequestsCSV is built for the simulation's period_start..
    period_end using the given parsing options. Its ``df`` is then handed
    to ``self.data_portal.handle_extra_source``.

    Returns
    -------
    PandasRequestsCSV
        The source object that was created.
    """
    params = self.sim_params
    source = PandasRequestsCSV(
        url,
        pre_func,
        post_func,
        self.trading_environment,
        params.period_start,
        params.period_end,
        date_column,
        date_format,
        timezone,
        symbol,
        mask,
        symbol_column,
        data_frequency=self.data_frequency,
        special_params_checker=special_params_checker,
        **kwargs
    )

    # Register the fetched frame with the data portal.
    self.data_portal.handle_extra_source(source.df, params)

    return source
596
|
|
|
|
597
|
|
|
def add_event(self, rule=None, callback=None):
    """
    Register ``callback`` to run when ``rule`` triggers.

    The pair is wrapped in a ``zipline.utils.events.Event`` and added to
    this algorithm's EventManager.
    """
    event = zipline.utils.events.Event(rule, callback)
    self.event_manager.add_event(event)
604
|
|
|
|
605
|
|
|
@api_method
def schedule_function(self,
                      func,
                      date_rule=None,
                      time_rule=None,
                      half_days=True):
    """
    Schedule ``func`` to be called according to date and time rules.

    Parameters
    ----------
    func : callable
        The function to schedule.
    date_rule : optional
        Defaults to ``DateRuleFactory.every_day()`` when falsy.
    time_rule : optional
        Only honoured for minute data, where it defaults to
        ``TimeRuleFactory.market_open()`` when falsy. For any other data
        frequency it is replaced by ``zipline.utils.events.Always()``.
    half_days : bool
        Forwarded to ``make_eventrule``.
    """
    if not date_rule:
        date_rule = DateRuleFactory.every_day()

    if self.sim_params.data_frequency == 'minute':
        if not time_rule:
            time_rule = TimeRuleFactory.market_open()
    else:
        # If we are in daily mode the time_rule is ignored.
        time_rule = zipline.utils.events.Always()

    self.add_event(make_eventrule(date_rule, time_rule, half_days), func)
624
|
|
|
|
625
|
|
|
@api_method
def record(self, *args, **kwargs):
    """
    Track and record values (e.g. local variables) each day.

    Positional arguments are read as alternating name/value pairs; an
    unpaired trailing positional is silently dropped. Keyword arguments
    are recorded as name=value. A later entry with the same name
    overwrites an earlier one.
    """
    # Even-indexed args are names, odd-indexed args are their values.
    # zip() stops at the shorter slice, which drops a dangling name.
    pairs = zip(args[::2], args[1::2])
    for name, value in chain(pairs, iteritems(kwargs)):
        self._recorded_vars[name] = value
640
|
|
|
|
641
|
|
|
@api_method
def set_benchmark(self, benchmark_sid):
    """
    Set the sid used as the performance benchmark.

    Raises
    ------
    SetBenchmarkOutsideInitialize
        If the algorithm has already been initialized.
    """
    if not self.initialized:
        self.benchmark_sid = benchmark_sid
        return

    raise SetBenchmarkOutsideInitialize()
647
|
|
|
|
648
|
|
|
@api_method
@preprocess(symbol_str=ensure_upper_case)
def symbol(self, symbol_str):
    """
    Default symbol lookup for any source that directly maps the
    symbol to the Asset (e.g. yahoo finance).

    ``symbol_str`` is upper-cased by the preprocess decorator. It is
    resolved as of the date set with set_symbol_lookup_date() if there is
    one, otherwise as of the simulation's period_end.
    """
    lookup_date = self._symbol_lookup_date
    if lookup_date is None:
        # No explicit lookup date: resolve symbol->sid at the period end.
        lookup_date = self.sim_params.period_end

    return self.asset_finder.lookup_symbol(
        symbol_str,
        as_of_date=lookup_date,
    )
664
|
|
|
|
665
|
|
|
@api_method
def symbols(self, *args):
    """
    Look up several Assets by symbol.

    Returns a list holding ``self.symbol(arg)`` for each argument, in the
    order given.
    """
    return list(map(self.symbol, args))
672
|
|
|
|
673
|
|
|
@api_method
def sid(self, a_sid):
    """
    Look up an Asset by its integer sid using the asset finder.
    """
    finder = self.asset_finder
    return finder.retrieve_asset(a_sid)
680
|
|
|
|
681
|
|
|
@api_method
@preprocess(symbol=ensure_upper_case)
def future_symbol(self, symbol):
    """Look up a futures contract by its (upper-cased) symbol.

    Parameters
    ----------
    symbol : str
        The symbol of the desired contract.

    Returns
    -------
    Future
        A Future object.

    Raises
    ------
    SymbolNotFound
        Raised when no contract named 'symbol' is found.
    """
    finder = self.asset_finder
    contract = finder.lookup_future_symbol(symbol)
    return contract
703
|
|
|
|
704
|
|
|
@api_method
@preprocess(root_symbol=ensure_upper_case)
def future_chain(self, root_symbol, as_of_date=None):
    """Look up a future chain with the specified parameters.

    Parameters
    ----------
    root_symbol : str
        The root symbol of a future chain (upper-cased by preprocess).
    as_of_date : datetime.datetime or pandas.Timestamp or str, optional
        Date at which the chain determination is rooted. I.e. the
        existing contract whose notice date is first after this date is
        the primary contract, etc. A truthy value is converted with
        ``pd.Timestamp(..., tz='UTC')``.

    Returns
    -------
    FutureChain
        The future chain matching the specified parameters.

    Raises
    ------
    UnsupportedDatetimeFormat
        If ``as_of_date`` cannot be converted to a Timestamp.
    RootSymbolNotFound
        If a future chain could not be found for the given root symbol.
    """
    rooted_at = as_of_date
    if rooted_at:
        try:
            rooted_at = pd.Timestamp(rooted_at, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=rooted_at,
                                            method='future_chain')

    return FutureChain(
        asset_finder=self.asset_finder,
        get_datetime=self.get_datetime,
        root_symbol=root_symbol,
        as_of_date=rooted_at,
    )
740
|
|
|
|
741
|
|
|
def _calculate_order_value_amount(self, asset, value):
    """
    Convert a target ``value`` into a share/contract count for ``asset``.

    The asset's price is read from ``self.trading_client.current_data``.
    Futures are scaled by their ``contract_multiplier``, and other assets
    use a multiplier of 1. If the price tolerantly equals 0, the method
    returns 0 (no order) and logs a debug message when a logger is set.
    """
    price = self.trading_client.current_data[asset].price

    if tolerant_equals(price, 0):
        message = "Price of 0 for {psid}; can't infer value".format(
            psid=asset
        )
        if self.logger:
            self.logger.debug(message)
        # Don't place any order
        return 0

    multiplier = asset.contract_multiplier if isinstance(asset, Future) else 1
    return value / (price * multiplier)
763
|
|
|
|
764
|
|
|
@api_method
def order(self, sid, amount,
          limit_price=None,
          stop_price=None,
          style=None):
    """
    Place an order for ``amount`` shares/contracts of ``sid``.

    The amount is converted to an int and the parameters are validated.
    The deprecated ``limit_price``/``stop_price`` arguments are converted
    into an ExecutionStyle, and the order is then handed to the blotter.

    Returns
    -------
    The value returned by ``self.blotter.order``.
    """
    # Truncate to the integer share count that's either within .0001 of
    # amount or closer to zero.
    # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
    share_count = int(round_if_near_integer(amount))

    # Raises a ZiplineError if invalid parameters are detected.
    self.validate_order_params(sid, share_count,
                               limit_price, stop_price, style)

    execution_style = self.__convert_order_params_for_blotter(
        limit_price, stop_price, style,
    )
    return self.blotter.order(sid, share_count, execution_style)
790
|
|
|
|
791
|
|
|
def validate_order_params(self,
                          asset,
                          amount,
                          limit_price,
                          stop_price,
                          style):
    """
    Helper method for validating parameters to the order API function.

    Raises
    ------
    OrderDuringInitialize
        If called before the algorithm has been initialized.
    UnsupportedOrderParameters
        If ``style`` is combined with a ``limit_price`` or ``stop_price``,
        or if ``asset`` is not an Asset.
    Whatever a registered trading control's ``validate`` raises.
    """
    if not self.initialized:
        raise OrderDuringInitialize(
            msg="order() can only be called from within handle_data()"
        )

    if style:
        # Compare against None rather than relying on truthiness: a price
        # of 0 is still a supplied price. __convert_order_params_for_blotter
        # asserts both prices are exactly None whenever a style is given,
        # so a 0 price used to slip past here and trip that assertion.
        if limit_price is not None:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if stop_price is not None:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

    if not isinstance(asset, Asset):
        raise UnsupportedOrderParameters(
            msg="Passing non-Asset argument to 'order()' is not supported."
                " Use 'sid()' or 'symbol()' methods to look up an Asset."
        )

    for control in self.trading_controls:
        control.validate(asset,
                         amount,
                         self.updated_portfolio(),
                         self.get_datetime(),
                         self.trading_client.current_data)
831
|
|
|
|
832
|
|
|
@staticmethod
def __convert_order_params_for_blotter(limit_price, stop_price, style):
    """
    Translate the deprecated limit/stop price arguments into an
    ExecutionStyle.

    Assumes that either ``style`` is None or ``(limit_price, stop_price) ==
    (None, None)``. A truthy ``style`` is returned unchanged. Otherwise the
    result is StopLimitOrder, LimitOrder, StopOrder or MarketOrder,
    depending on which prices are truthy.
    """
    # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
    if style:
        assert (limit_price, stop_price) == (None, None)
        return style

    if limit_price:
        if stop_price:
            return StopLimitOrder(limit_price, stop_price)
        return LimitOrder(limit_price)

    return StopOrder(stop_price) if stop_price else MarketOrder()
853
|
|
|
|
854
|
|
|
@api_method
def order_value(self, sid, value,
                limit_price=None, stop_price=None, style=None):
    """
    Place an order by desired value rather than desired number of shares.

    If the requested sid exists, the requested value is divided by its
    price to imply the number of shares to transact. The conversion is
    done by ``self._calculate_order_value_amount``, and the resulting
    amount is forwarded to ``self.order``.

    If the Asset being ordered is a Future, the 'value' calculated is
    actually the exposure, as Futures have no 'value'.

    value > 0 :: Buy/Cover
    value < 0 :: Sell/Short

    Market order:    order_value(sid, value)
    Limit order:     order_value(sid, value, limit_price)
    Stop order:      order_value(sid, value, None, stop_price)
    StopLimit order: order_value(sid, value, limit_price, stop_price)

    Returns
    -------
    The result of ``self.order`` for the computed share amount.
    """
    amount = self._calculate_order_value_amount(sid, value)
    return self.order(sid, amount,
                      limit_price=limit_price,
                      stop_price=stop_price,
                      style=style)
876
|
|
|
|
877
|
|
|
@property
def recorded_vars(self):
    """
    Shallow copy of the recorded variables.

    Rebinding keys on the returned mapping does not affect the
    algorithm's own record.
    """
    snapshot = copy(self._recorded_vars)
    return snapshot
880
|
|
|
|
881
|
|
|
@property
def portfolio(self):
    """The current portfolio, refreshed on demand via ``updated_portfolio``."""
    current = self.updated_portfolio()
    return current
884
|
|
|
|
885
|
|
|
def updated_portfolio(self):
    """
    Return the cached portfolio, first refreshing it from the perf tracker
    if ``portfolio_needs_update`` is set.

    A refresh passes ``performance_needs_update`` and ``self.datetime`` to
    ``perf_tracker.get_portfolio``. Afterwards both the portfolio flag and
    the performance flag are cleared.
    """
    if not self.portfolio_needs_update:
        return self._portfolio

    tracker = self.perf_tracker
    self._portfolio = tracker.get_portfolio(
        self.performance_needs_update,
        self.datetime,
    )
    self.portfolio_needs_update = False
    self.performance_needs_update = False
    return self._portfolio
893
|
|
|
|
894
|
|
|
@property
def account(self):
    """The current account, refreshed on demand via ``updated_account``."""
    current = self.updated_account()
    return current
897
|
|
|
|
898
|
|
|
def updated_account(self):
    """
    Return the cached account, first refreshing it from the perf tracker
    if ``account_needs_update`` is set.

    A refresh passes ``performance_needs_update`` and ``self.datetime`` to
    ``perf_tracker.get_account``. Afterwards both the account flag and the
    performance flag are cleared.
    """
    if not self.account_needs_update:
        return self._account

    tracker = self.perf_tracker
    self._account = tracker.get_account(
        self.performance_needs_update,
        self.datetime,
    )
    self.account_needs_update = False
    self.performance_needs_update = False
    return self._account
906
|
|
|
|
907
|
|
|
def set_logger(self, logger):
    """Attach ``logger`` to the algorithm as ``self.logger``."""
    setattr(self, 'logger', logger)
909
|
|
|
|
910
|
|
|
def on_dt_changed(self, dt):
    """
    Callback invoked by the simulation loop each time the current dt
    changes.

    Logic that must run exactly once at the start of each datetime group
    belongs here. The new time is recorded on the algorithm and pushed to
    the perf tracker and the blotter. The cached portfolio, account and
    performance views are then flagged stale.

    Parameters
    ----------
    dt : datetime
        The new simulation time. Its tzinfo must be ``pytz.utc``; this is
        checked with ``assert``.
    """
    assert isinstance(dt, datetime), \
        "Attempt to set algorithm's current time with non-datetime"
    assert dt.tzinfo == pytz.utc, \
        "Algorithm expects a utc datetime"

    self.datetime = dt
    self.perf_tracker.set_date(dt)
    self.blotter.set_date(dt)

    # Mark every cached view stale; updated_portfolio/updated_account
    # rebuild them lazily on the next access.
    for stale_flag in ('portfolio_needs_update',
                       'account_needs_update',
                       'performance_needs_update'):
        setattr(self, stale_flag, True)
930
|
|
|
|
931
|
|
|
@api_method
def get_datetime(self, tz=None):
    """
    Return the current simulation datetime.

    Parameters
    ----------
    tz : str or tzinfo, optional
        If given, the UTC simulation time is converted to this timezone.
        A string is resolved with ``pytz.timezone``.

    Returns
    -------
    datetime
        The simulation time, in UTC unless ``tz`` was supplied.
    """
    current = self.datetime
    assert current.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

    if tz is None:
        # datetime.datetime objects are immutable, so handing out the
        # stored instance is safe.
        return current

    # Accept either a timezone name or a tzinfo instance.
    target_tz = pytz.timezone(tz) if isinstance(tz, string_types) else tz
    return current.astimezone(target_tz)
946
|
|
|
|
947
|
|
|
def update_dividends(self, dividend_frame):
    """
    Hand the DataFrame used to process dividends to the performance
    tracker. Its columns should contain at least the entries in
    zp.DIVIDEND_FIELDS.
    """
    tracker = self.perf_tracker
    tracker.update_dividends(dividend_frame)
953
|
|
|
|
954
|
|
|
@api_method
def set_slippage(self, slippage):
    """
    Install ``slippage`` as the blotter's slippage model.

    Raises
    ------
    UnsupportedSlippageModel
        If ``slippage`` is not a SlippageModel instance.
    SetSlippagePostInit
        If ``self.initialized`` is already set.
    """
    is_supported = isinstance(slippage, SlippageModel)
    if not is_supported:
        raise UnsupportedSlippageModel()
    if self.initialized:
        raise SetSlippagePostInit()
    self.blotter.slippage_func = slippage
961
|
|
|
|
962
|
|
|
@api_method
def set_commission(self, commission):
    """
    Install ``commission`` as the blotter's commission model.

    Raises
    ------
    UnsupportedCommissionModel
        If ``commission`` is not a PerShare, PerTrade or PerDollar instance.
    SetCommissionPostInit
        If ``self.initialized`` is already set.
    """
    supported_models = (PerShare, PerTrade, PerDollar)
    if not isinstance(commission, supported_models):
        raise UnsupportedCommissionModel()

    if self.initialized:
        raise SetCommissionPostInit()
    self.blotter.commission = commission
970
|
|
|
|
971
|
|
|
@api_method
def set_symbol_lookup_date(self, dt):
    """
    Set the date for which symbols will be resolved to their sids.
    Symbols may map to different firms or underlying assets at different
    times.

    Parameters
    ----------
    dt : any value accepted by ``pd.Timestamp``
        Converted with ``tz='UTC'``.

    Raises
    ------
    UnsupportedDatetimeFormat
        If ``pd.Timestamp`` raises ValueError while parsing ``dt``.
    """
    try:
        lookup_date = pd.Timestamp(dt, tz='UTC')
    except ValueError:
        raise UnsupportedDatetimeFormat(input=dt,
                                        method='set_symbol_lookup_date')
    self._symbol_lookup_date = lookup_date
983
|
|
|
|
984
|
|
|
# Kept for backwards compatibility
985
|
|
|
@property
def data_frequency(self):
    """The simulation's data frequency, stored on ``self.sim_params``."""
    params = self.sim_params
    return params.data_frequency

@data_frequency.setter
def data_frequency(self, value):
    # Only the two supported frequencies may be assigned.
    assert value in ('daily', 'minute')
    params = self.sim_params
    params.data_frequency = value
993
|
|
|
|
994
|
|
|
@api_method
def order_percent(self, sid, percent,
                  limit_price=None, stop_price=None, style=None):
    """
    Place an order in the specified asset corresponding to the given
    percent of the current portfolio value.

    Note that percent must be expressed as a decimal (0.50 means 50%).

    The dollar value ``portfolio.portfolio_value * percent`` is forwarded
    to ``order_value`` together with the execution parameters.
    """
    value = self.portfolio.portfolio_value * percent
    return self.order_value(sid, value,
                            limit_price=limit_price,
                            stop_price=stop_price,
                            style=style)
1008
|
|
|
|
1009
|
|
|
@api_method
def order_target(self, sid, target,
                 limit_price=None, stop_price=None, style=None):
    """
    Place an order to adjust a position to a target number of shares.

    If no position in ``sid`` exists, this is equivalent to ordering
    ``target`` shares. Otherwise the order is for the difference between
    ``target`` and the current number of shares held.
    """
    positions = self.portfolio.positions
    if sid not in positions:
        amount = target
    else:
        amount = target - positions[sid].amount
    return self.order(sid, amount,
                      limit_price=limit_price,
                      stop_price=stop_price,
                      style=style)
1031
|
|
|
|
1032
|
|
|
@api_method
def order_target_value(self, sid, target,
                       limit_price=None, stop_price=None, style=None):
    """
    Place an order to adjust a position to a target value.

    The target value is converted to a share count with
    ``_calculate_order_value_amount`` and handed to ``order_target``. That
    means a missing position results in a fresh order, and an existing one
    is adjusted by the difference.

    If the Asset being ordered is a Future, the 'target value' calculated
    is actually the target exposure, as Futures have no 'value'.
    """
    shares_for_target = self._calculate_order_value_amount(sid, target)
    return self.order_target(
        sid,
        shares_for_target,
        limit_price=limit_price,
        stop_price=stop_price,
        style=style,
    )
1049
|
|
|
|
1050
|
|
|
@api_method
def order_target_percent(self, sid, target,
                         limit_price=None, stop_price=None, style=None):
    """
    Place an order to adjust a position to a target percent of the
    current portfolio value. If the position doesn't already exist, this is
    equivalent to placing a new order. If the position does exist, this is
    equivalent to placing an order for the difference between the target
    percent and the current percent.

    Note that target must be expressed as a decimal (0.50 means 50%).

    The target dollar value ``portfolio.portfolio_value * target`` is
    forwarded to ``order_target_value`` together with the execution
    parameters.
    """
    target_value = self.portfolio.portfolio_value * target
    return self.order_target_value(sid, target_value,
                                   limit_price=limit_price,
                                   stop_price=stop_price,
                                   style=style)
1067
|
|
|
|
1068
|
|
|
@api_method
def get_open_orders(self, sid=None):
    """
    Return open orders as API objects.

    Parameters
    ----------
    sid : optional
        If omitted, the result is a dict that maps every key of
        ``self.blotter.open_orders`` with a non-empty order list to the API
        objects of those orders. If given, the result is the list of API
        objects for ``sid`` alone, or ``[]`` when it has no entry.
    """
    open_orders = self.blotter.open_orders
    if sid is not None:
        if sid not in open_orders:
            return []
        return [o.to_api_obj() for o in open_orders[sid]]

    by_key = {}
    for key, orders in iteritems(open_orders):
        if orders:
            by_key[key] = [o.to_api_obj() for o in orders]
    return by_key
1080
|
|
|
|
1081
|
|
|
@api_method
def get_order(self, order_id):
    """
    Look up an order by id.

    Returns the order's API object, or None if the blotter has no order
    with ``order_id``.
    """
    known_orders = self.blotter.orders
    if order_id not in known_orders:
        return None
    return known_orders[order_id].to_api_obj()
1085
|
|
|
|
1086
|
|
|
@api_method
def cancel_order(self, order_param):
    """
    Cancel an order through the blotter.

    Parameters
    ----------
    order_param : zipline.protocol.Order or order id
        An Order object is reduced to its ``id``. Any other value is
        passed to ``blotter.cancel`` unchanged as the order id.
    """
    if isinstance(order_param, zipline.protocol.Order):
        order_id = order_param.id
    else:
        order_id = order_param
    self.blotter.cancel(order_id)
1093
|
|
|
|
1094
|
|
|
@api_method
@require_initialized(HistoryInInitialize())
def history(self, sids, bar_count, frequency, field, ffill=True):
    """
    Return a history window for ``sids`` from the data portal.

    The window is produced by ``self.data_portal.get_history_window``,
    anchored at ``self.datetime``, with ``bar_count`` bars of ``field`` at
    ``frequency``. ``ffill`` is passed through unchanged.

    Raises
    ------
    Exception
        If ``self.data_portal`` is None.
    HistoryInInitialize
        Presumably raised by the ``require_initialized`` decorator when
        called before initialization completes.
    """
    portal = self.data_portal
    if portal is None:
        raise Exception("no data portal!")

    return portal.get_history_window(
        sids,
        self.datetime,
        bar_count,
        frequency,
        field,
        ffill,
    )
1108
|
|
|
|
1109
|
|
|
#################### |
1110
|
|
|
# Account Controls # |
1111
|
|
|
#################### |
1112
|
|
|
|
1113
|
|
|
def register_account_control(self, control):
    """
    Register a new AccountControl to be checked on each bar.

    Raises
    ------
    RegisterAccountControlPostInit
        If ``self.initialized`` is already set.
    """
    if self.initialized:
        raise RegisterAccountControlPostInit()
    registered = self.account_controls
    registered.append(control)
1120
|
|
|
|
1121
|
|
|
def validate_account_controls(self):
    """
    Run every registered AccountControl against the current state.

    Each control's ``validate`` receives the refreshed portfolio, the
    refreshed account, the current datetime and
    ``self.trading_client.current_data``. Any exception a control raises
    propagates to the caller.
    """
    for account_control in self.account_controls:
        account_control.validate(
            self.updated_portfolio(),
            self.updated_account(),
            self.get_datetime(),
            self.trading_client.current_data,
        )
1127
|
|
|
|
1128
|
|
|
@api_method
def set_max_leverage(self, max_leverage=None):
    """
    Set a limit on the maximum leverage of the algorithm by registering a
    MaxLeverage account control built from ``max_leverage``.
    """
    self.register_account_control(MaxLeverage(max_leverage))
1135
|
|
|
|
1136
|
|
|
#################### |
1137
|
|
|
# Trading Controls # |
1138
|
|
|
#################### |
1139
|
|
|
|
1140
|
|
|
def register_trading_control(self, control):
    """
    Register a new TradingControl to be checked prior to order calls.

    Raises
    ------
    RegisterTradingControlPostInit
        If ``self.initialized`` is already set.
    """
    if self.initialized:
        raise RegisterTradingControlPostInit()
    registered = self.trading_controls
    registered.append(control)
1147
|
|
|
|
1148
|
|
|
@api_method
def set_max_position_size(self,
                          sid=None,
                          max_shares=None,
                          max_notional=None):
    """
    Set a limit on the number of shares and/or dollar value held for the
    given sid. Limits are treated as absolute values and are enforced at
    the time that the algo attempts to place an order for sid. This means
    that it's possible to end up with more than the max number of shares
    due to splits/dividends, and more than the max notional due to price
    improvement.

    If an algorithm attempts to place an order that would result in
    increasing the absolute value of shares/dollar value exceeding one of
    these limits, raise a TradingControlException.
    """
    self.register_trading_control(
        MaxPositionSize(asset=sid,
                        max_shares=max_shares,
                        max_notional=max_notional)
    )
1169
|
|
|
|
1170
|
|
|
@api_method
def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
    """
    Set a limit on the number of shares and/or dollar value of any single
    order placed for sid. Limits are treated as absolute values and are
    enforced at the time that the algo attempts to place an order for sid.

    If an algorithm attempts to place an order that would result in
    exceeding one of these limits, raise a TradingControlException.
    """
    self.register_trading_control(
        MaxOrderSize(asset=sid,
                     max_shares=max_shares,
                     max_notional=max_notional)
    )
1184
|
|
|
|
1185
|
|
|
@api_method
def set_max_order_count(self, max_count):
    """
    Set a limit on the number of orders that can be placed by registering
    a MaxOrderCount trading control.

    The counting window is decided by MaxOrderCount itself; it is
    presumably per trading day (TODO confirm).
    """
    self.register_trading_control(MaxOrderCount(max_count))
1193
|
|
|
|
1194
|
|
|
@api_method
def set_do_not_order_list(self, restricted_list):
    """
    Set a restriction on which sids can be ordered by registering a
    RestrictedListOrder trading control built from ``restricted_list``.
    """
    self.register_trading_control(RestrictedListOrder(restricted_list))
1201
|
|
|
|
1202
|
|
|
@api_method
def set_long_only(self):
    """
    Set a rule specifying that this algorithm cannot take short positions,
    by registering a LongOnly trading control.
    """
    long_only_control = LongOnly()
    self.register_trading_control(long_only_control)
1208
|
|
|
|
1209
|
|
|
############## |
1210
|
|
|
# Pipeline API |
1211
|
|
|
############## |
1212
|
|
|
@api_method
@require_not_initialized(AttachPipelineAfterInitialize())
def attach_pipeline(self, pipeline, name, chunksize=None):
    """
    Register a pipeline to be computed at the start of each day.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to compute.
    name : str
        Key under which the pipeline is stored; ``pipeline_output`` looks
        results up by this name.
    chunksize : int, optional
        Number of trading days computed per batch. By default the first
        batch is 5 days and every later batch is 126 days.

    Returns
    -------
    pipeline
        The argument itself, allowing ``p = attach_pipeline(Pipeline(),
        'name')``.

    Raises
    ------
    NotImplementedError
        If a pipeline is already attached; multiple pipelines are not
        supported.
    """
    if self._pipelines:
        raise NotImplementedError("Multiple pipelines are not supported.")

    if chunksize is not None:
        chunks = repeat(int(chunksize))
    else:
        # A smaller first chunk gives more immediate results: one week,
        # then roughly every half year. chain/repeat already return
        # iterators, so next() can be called on them directly.
        chunks = chain([5], repeat(126))
    self._pipelines[name] = pipeline, chunks
    return pipeline
1231
|
|
|
|
1232
|
|
|
@api_method
@require_initialized(PipelineOutputDuringInitialize())
def pipeline_output(self, name):
    """
    Get the results of pipeline with name `name`.

    Parameters
    ----------
    name : str
        Name of the pipeline for which results are requested.

    Returns
    -------
    results : pd.DataFrame
        DataFrame containing the results of the requested pipeline for
        the current simulation date.

    Raises
    ------
    NoSuchPipeline
        Raised when no pipeline with the name `name` has been registered.

    See Also
    --------
    :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
    """
    # NOTE: We don't currently support multiple pipelines, but we plan to
    # in the future.
    try:
        pipeline, chunks = self._pipelines[name]
    except KeyError:
        registered_names = list(self._pipelines.keys())
        raise NoSuchPipeline(name=name, valid=registered_names)
    return self._pipeline_output(pipeline, chunks)
1268
|
|
|
|
1269
|
|
|
def _pipeline_output(self, pipeline, chunks):
    """
    Internal implementation of `pipeline_output`.

    Serves today's rows from ``self._pipeline_cache``. If the cache has
    expired, ``pipeline`` is run for the next chunk, starting today with a
    size taken from ``chunks``, and the result is cached.

    Returns
    -------
    pd.DataFrame
        ``data.loc[today]``, or an empty frame with the pipeline's columns
        when today has no rows.
    """
    today = normalize_date(self.get_datetime())
    try:
        data = self._pipeline_cache.unwrap(today)
    except Expired:
        next_chunksize = next(chunks)
        data, valid_until = self._run_pipeline(
            pipeline, today, next_chunksize,
        )
        self._pipeline_cache = CachedObject(data, valid_until)

    # Now that we have a cached result, try to return the data for today.
    try:
        return data.loc[today]
    except KeyError:
        # This happens if no assets passed the pipeline screen on a given
        # day.
        return pd.DataFrame(index=[], columns=data.columns)
1289
|
|
|
|
1290
|
|
|
def _run_pipeline(self, pipeline, start_date, chunksize):
    """
    Compute `pipeline`, providing values for at least `start_date`.

    Produces a DataFrame containing data for days between `start_date` and
    `end_date`, where `end_date` is defined by:

        `end_date = min(start_date + chunksize trading days,
                        simulation_end)`

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to run.
    start_date : pd.Timestamp
        Located with ``trading_days.get_loc``, so it is expected to be a
        trading day.
    chunksize : int
        Number of trading days past `start_date` to compute.

    Returns
    -------
    (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)
        ``valid_until`` is `end_date`, the last date requested from the
        engine.

    See Also
    --------
    PipelineEngine.run_pipeline
    """
    days = self.trading_environment.trading_days

    # Load data starting from `start_date` itself...
    start_date_loc = days.get_loc(start_date)

    # ...continuing until either the last simulation day or `chunksize`
    # trading days past `start_date`, whichever comes first.
    sim_end = self.sim_params.last_close.normalize()
    end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))
    end_date = days[end_loc]

    data = self.engine.run_pipeline(pipeline, start_date, end_date)
    return data, end_date
1321
|
|
|
|
1322
|
|
|
################## |
1323
|
|
|
# End Pipeline API |
1324
|
|
|
################## |
1325
|
|
|
|
1326
|
|
|
@classmethod
def all_api_methods(cls):
    """
    Return a list of all the TradingAlgorithm API methods.

    Only attributes defined directly on ``cls`` (its own ``vars``) are
    considered. An attribute is included when its ``is_api_method``
    attribute is truthy.
    """
    api_methods = []
    for candidate in itervalues(vars(cls)):
        if getattr(candidate, 'is_api_method', False):
            api_methods.append(candidate)
    return api_methods
1335
|
|
|
|