#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy
import warnings

import pytz
import pandas as pd
from pandas.tseries.tools import normalize_date
import numpy as np

from datetime import datetime
from itertools import groupby, chain, repeat
from numbers import Integral
from operator import attrgetter

from six.moves import filter
from six import (
    exec_,
    iteritems,
    itervalues,
    string_types,
)


from zipline.errors import (
    AttachPipelineAfterInitialize,
    HistoryInInitialize,
    NoSuchPipeline,
    OrderDuringInitialize,
    PipelineOutputDuringInitialize,
    RegisterAccountControlPostInit,
    RegisterTradingControlPostInit,
    SetCommissionPostInit,
    SetSlippagePostInit,
    UnsupportedCommissionModel,
    UnsupportedDatetimeFormat,
    UnsupportedOrderParameters,
    UnsupportedSlippageModel,
)
from zipline.finance.trading import TradingEnvironment
from zipline.finance.blotter import Blotter
from zipline.finance.commission import PerShare, PerTrade, PerDollar
from zipline.finance.controls import (
    LongOnly,
    MaxOrderCount,
    MaxOrderSize,
    MaxPositionSize,
    MaxLeverage,
    RestrictedListOrder
)
from zipline.finance.execution import (
    LimitOrder,
    MarketOrder,
    StopLimitOrder,
    StopOrder,
)
from zipline.finance.performance import PerformanceTracker
from zipline.finance.slippage import (
    VolumeShareSlippage,
    SlippageModel,
    transact_partial
)
from zipline.assets import Asset, Future
from zipline.assets.futures import FutureChain
from zipline.gens.composites import date_sorted_sources
from zipline.gens.tradesimulation import AlgorithmSimulator
from zipline.pipeline.engine import (
    NoOpPipelineEngine,
    SimplePipelineEngine,
)
from zipline.sources import DataFrameSource, DataPanelSource
from zipline.utils.api_support import (
    api_method,
    require_initialized,
    require_not_initialized,
    ZiplineAPI,
)
from zipline.utils.input_validation import ensure_upper_case
from zipline.utils.cache import CachedObject, Expired
import zipline.utils.events
from zipline.utils.events import (
    EventManager,
    make_eventrule,
    DateRuleFactory,
    TimeRuleFactory,
)
from zipline.utils.factory import create_simulation_parameters
from zipline.utils.math_utils import tolerant_equals
from zipline.utils.preprocess import preprocess

import zipline.protocol
from zipline.protocol import Event

from zipline.history import HistorySpec
from zipline.history.history_container import HistoryContainer

DEFAULT_CAPITAL_BASE = float("1.0e5")


class TradingAlgorithm(object):
    """
    Base class for trading algorithms. Inherit and overload
    initialize() and handle_data(data).

    A new algorithm could look like this:
    ```
    from zipline.api import order, symbol

    def initialize(context):
        context.sid = symbol('AAPL')
        context.amount = 100

    def handle_data(context, data):
        sid = context.sid
        amount = context.amount
        order(sid, amount)
    ```
    To run this algorithm, pass these functions to
    TradingAlgorithm:

    my_algo = TradingAlgorithm(initialize, handle_data)
    stats = my_algo.run(data)

    """

    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            data_frequency : {'daily', 'minute'}
                The duration of the bars.
            capital_base : float <default: 1.0e5>
                How much capital to start with.
            instant_fill : bool <default: False>
                Whether to fill orders immediately or on next bar.
            asset_finder : An AssetFinder object
                A new AssetFinder object to be used in this TradingEnvironment
            equities_metadata : can be either:
                - dict
                - pandas.DataFrame
                - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.benchmark_return_source = None

        # default components for transact
        self.slippage = VolumeShareSlippage()
        self.commission = PerShare()

        self.instant_fill = kwargs.pop('instant_fill', False)

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        # Build a perf_tracker
        self.perf_tracker = PerformanceTracker(sim_params=self.sim_params,
                                               env=self.trading_environment)

        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter()

        # Set the dt initially to the period start by forcing it to change
        self.on_dt_changed(self.sim_params.period_start)

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        self.history_container_class = kwargs.pop(
            'history_container_class', HistoryContainer,
        )
        self.history_container = None
        self.history_specs = {}

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        if self.algoscript is not None:
            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            if self.algoscript is not None:
                raise ValueError('You can not set script and \
                initialize/handle_data.')
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        self._most_recent_data = None

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs

    def init_engine(self, get_loader):
        """
        Construct and store a PipelineEngine from loader.

        If get_loader is None, constructs a NoOpPipelineEngine.
        """
        if get_loader is not None:
            self.engine = SimplePipelineEngine(
                get_loader,
                self.trading_environment.trading_days,
                self.asset_finder,
            )
        else:
            self.engine = NoOpPipelineEngine()

    def initialize(self, *args, **kwargs):
        """
        Call self._initialize with `self` made available to Zipline API
        functions.
        """
        with ZiplineAPI(self):
            self._initialize(self, *args, **kwargs)

    def before_trading_start(self, data):
        if self._before_trading_start is None:
            return

        self._before_trading_start(self, data)

    def handle_data(self, data):
        self._most_recent_data = data
        if self.history_container:
            self.history_container.update(data, self.datetime)

        self._handle_data(self, data)

        # Unlike trading controls which remain constant unless placing an
        # order, account controls can change each bar. Thus, must check
        # every bar no matter if the algorithm places an order or not.
        self.validate_account_controls()

    def analyze(self, perf):
        if self._analyze is None:
            return

        with ZiplineAPI(self):
            self._analyze(self, perf)

    def __repr__(self):
        """
        N.B. this does not yet represent a string that can be used
        to instantiate an exact copy of an algorithm.

        However, it is getting close, and provides some value as something
        that can be inspected interactively.
        """
        return """
{class_name}(
    capital_base={capital_base}
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.slippage),
                   commission=repr(self.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))

    def _create_data_generator(self, source_filter, sim_params=None):
        """
        Create a merged data generator using the sources attached to this
        algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """
        if sim_params is None:
            sim_params = self.sim_params

        if self.benchmark_return_source is None:
            if sim_params.data_frequency == 'minute' or \
               sim_params.emission_rate == 'minute':
                def update_time(date):
                    return self.trading_environment.get_open_and_close(date)[1]
            else:
                def update_time(date):
                    return date
            benchmark_return_source = [
                Event({'dt': update_time(dt),
                       'returns': ret,
                       'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
                       'source_id': 'benchmarks'})
                for dt, ret in
                self.trading_environment.benchmark_returns.iteritems()
                if dt.date() >= sim_params.period_start.date() and
                dt.date() <= sim_params.period_end.date()
            ]
        else:
            benchmark_return_source = self.benchmark_return_source

        date_sorted = date_sorted_sources(*self.sources)

        if source_filter:
            date_sorted = filter(source_filter, date_sorted)

        with_benchmarks = date_sorted_sources(benchmark_return_source,
                                              date_sorted)

        # Group together events with the same dt field. This depends on the
        # events already being sorted.
        return groupby(with_benchmarks, attrgetter('dt'))

    def _create_generator(self, sim_params, source_filter=None):
        """
        Create a basic generator setup using the sources to this algorithm.

        ::source_filter:: is a method that receives events in date
        sorted order, and returns True for those events that should be
        processed by the zipline, and False for those that should be
        skipped.
        """

        if not self.initialized:
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
            self.initialized = True

        if self.perf_tracker is None:
            # HACK: When running with the `run` method, we set perf_tracker to
            # None so that it will be overwritten here.
            self.perf_tracker = PerformanceTracker(
                sim_params=sim_params, env=self.trading_environment
            )

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True

        self.data_gen = self._create_data_generator(source_filter, sim_params)

        self.trading_client = AlgorithmSimulator(self, sim_params)

        transact_method = transact_partial(self.slippage, self.commission)
        self.set_transact(transact_method)

        return self.trading_client.transform(self.data_gen)

    def get_generator(self):
        """
        Override this method to add new logic to the construction
        of the generator. Overrides can use the _create_generator
        method to get a standard construction generator.
        """
        return self._create_generator(self.sim_params)

    # TODO: make a new subclass, e.g. BatchAlgorithm, and move
    # the run method to the subclass, and refactor to put the
    # generator creation logic into get_generator.
    def run(self, source, overwrite_sim_params=True,
            benchmark_return_source=None):
        """Run the algorithm.

        :Arguments:
            source : can be either:
                     - pandas.DataFrame
                     - zipline source
                     - list of sources

            If pandas.DataFrame is provided, it must have the
            following structure:
            * column names must be the different asset identifiers
            * index must be DatetimeIndex
            * array contents should be price info.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.

        """

        # Ensure that source is a DataSource object
        if isinstance(source, list):
            if overwrite_sim_params:
                warnings.warn("""List of sources passed, will not attempt to extract start and end
dates. Make sure to set the correct fields in sim_params passed to
__init__().""", UserWarning)
                overwrite_sim_params = False
        elif isinstance(source, pd.DataFrame):
            # if DataFrame provided, map columns to sids and wrap
            # in DataFrameSource
            copy_frame = source.copy()
            copy_frame.columns = self._write_and_map_id_index_to_sids(
                source.columns, source.index[0],
            )
            source = DataFrameSource(copy_frame)

        elif isinstance(source, pd.Panel):
            # If Panel provided, map items to sids and wrap
            # in DataPanelSource
            copy_panel = source.copy()
            copy_panel.items = self._write_and_map_id_index_to_sids(
                source.items, source.major_axis[0],
            )
            source = DataPanelSource(copy_panel)

        if isinstance(source, list):
            self.set_sources(source)
        else:
            self.set_sources([source])

        # Override sim_params if params are provided by the source.
        if overwrite_sim_params:
            if hasattr(source, 'start'):
                self.sim_params.period_start = source.start
            if hasattr(source, 'end'):
                self.sim_params.period_end = source.end
            # Changing period_start and period_close might require updating
            # of first_open and last_close.
            self.sim_params.update_internal_from_env(
                env=self.trading_environment
            )

        # The sids field of the source is the reference for the universe at
        # the start of the run
        self._current_universe = set()
        for source in self.sources:
            for sid in source.sids:
                self._current_universe.add(sid)
        # Check that all sids from the source are accounted for in
        # the AssetFinder. This retrieve call will raise an exception if the
        # sid is not found.
        for sid in self._current_universe:
            self.asset_finder.retrieve_asset(sid)

        # force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # create zipline
        self.gen = self._create_generator(self.sim_params)

        # Create history containers
        if self.history_specs:
            self.history_container = self.history_container_class(
                self.history_specs,
                self.current_universe(),
                self.sim_params.first_open,
                self.sim_params.data_frequency,
                self.trading_environment,
            )

        # loop through simulated_trading, each iteration returns a
        # perf dictionary
        perfs = []
        for perf in self.gen:
            perfs.append(perf)

        # convert perf dict to pandas dataframe
        daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats
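
    # Illustrative sketch (not part of the original module): a minimal driver
    # for `run` using a price DataFrame, as described in the docstring above.
    # The 'AAPL' column name and the date range are arbitrary assumptions.
    #
    #   import pandas as pd
    #   from zipline.algorithm import TradingAlgorithm
    #   from zipline.api import order, symbol
    #
    #   def initialize(context):
    #       context.asset = symbol('AAPL')
    #
    #   def handle_data(context, data):
    #       order(context.asset, 10)
    #
    #   index = pd.date_range('2013-01-01', '2013-06-28', tz='UTC')
    #   prices = pd.DataFrame({'AAPL': 100.0}, index=index)
    #   algo = TradingAlgorithm(initialize=initialize,
    #                           handle_data=handle_data)
    #   daily_stats = algo.run(prices)  # one row of metrics per trading day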

    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        # Build new Assets for identifiers that can't be resolved as
        # sids/Assets
        identifiers_to_build = []
        for identifier in identifiers:
            asset = None

            if isinstance(identifier, Asset):
                asset = self.asset_finder.retrieve_asset(sid=identifier.sid,
                                                         default_none=True)
            elif isinstance(identifier, Integral):
                asset = self.asset_finder.retrieve_asset(sid=identifier,
                                                         default_none=True)
            if asset is None:
                identifiers_to_build.append(identifier)

        self.trading_environment.write_data(
            equities_identifiers=identifiers_to_build)

        # We need to clear out any cache misses that were stored while trying
        # to do lookups. The real fix for this problem is to not construct an
        # AssetFinder until we `run()` when we actually have all the data we
        # need to do so.
        self.asset_finder._reset_caches()

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )

    def _create_daily_stats(self, perfs):
        # create daily and cumulative stats dataframe
        daily_perfs = []
        # TODO: the loop here could overwrite expected properties
        # of daily_perf. Could potentially raise or log a
        # warning.
        for perf in perfs:
            if 'daily_perf' in perf:

                perf['daily_perf'].update(
                    perf['daily_perf'].pop('recorded_vars')
                )
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
                daily_perfs.append(perf['daily_perf'])
            else:
                self.risk_report = perf

        daily_dts = [np.datetime64(perf['period_close'], utc=True)
                     for perf in daily_perfs]
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

        return daily_stats

    @api_method
    def add_transform(self, transform, days=None):
        """
        Ensures that the history container will have enough size to service
        a simple transform.

        :Arguments:
            transform : string
                The transform to add. Must be an element of:
                {'mavg', 'stddev', 'vwap', 'returns'}.
            days : int <default=None>
                The maximum amount of days you will want for this transform.
                This is not needed for 'returns'.
        """
        if transform not in {'mavg', 'stddev', 'vwap', 'returns'}:
            raise ValueError('Invalid transform')

        if transform == 'returns':
            if days is not None:
                raise ValueError('returns does not use days')

            self.add_history(2, '1d', 'price')
            return
        elif days is None:
            raise ValueError('no number of days specified')

        if self.sim_params.data_frequency == 'daily':
            mult = 1
            freq = '1d'
        else:
            mult = 390
            freq = '1m'

        bars = mult * days
        self.add_history(bars, freq, 'price')

        if transform == 'vwap':
            self.add_history(bars, freq, 'volume')

    @api_method
    def get_environment(self, field='platform'):
        env = {
            'arena': self.sim_params.arena,
            'data_frequency': self.sim_params.data_frequency,
            'start': self.sim_params.first_open,
            'end': self.sim_params.last_close,
            'capital_base': self.sim_params.capital_base,
            'platform': self._platform
        }
        if field == '*':
            return env
        else:
            return env[field]
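
    # Illustrative sketch (not in the original source): querying the fields
    # exposed by `get_environment` from inside an algorithm. The field names
    # are exactly the keys of the dict built above.
    #
    #   freq = get_environment('data_frequency')  # 'daily' or 'minute'
    #   everything = get_environment('*')          # full dict of settings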

    def add_event(self, rule=None, callback=None):
        """
        Adds an event to the algorithm's EventManager.
        """
        self.event_manager.add_event(
            zipline.utils.events.Event(rule, callback),
        )

    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedules a function to be called with some timed rules.
        """
        date_rule = date_rule or DateRuleFactory.every_day()
        time_rule = ((time_rule or TimeRuleFactory.market_open())
                     if self.sim_params.data_frequency == 'minute' else
                     # If we are in daily mode the time_rule is ignored.
                     zipline.utils.events.Always())

        self.add_event(
            make_eventrule(date_rule, time_rule, half_days),
            func,
        )
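
    # Illustrative sketch (not in the original source): scheduling a weekly
    # rebalance via the rule factories imported above. `rebalance` is an
    # assumed user-defined function taking (context, data).
    #
    #   schedule_function(
    #       rebalance,
    #       date_rule=DateRuleFactory.week_start(),
    #       time_rule=TimeRuleFactory.market_open(minutes=30),
    #   )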

    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record local variables (i.e. attributes) each day.
        """
        # Make 2 objects both referencing the same iterator
        args = [iter(args)] * 2

        # Zip generates list entries by calling `next` on each iterator it
        # receives. In this case the two iterators are the same object, so the
        # call to next on args[0] will also advance args[1], resulting in zip
        # returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc.
        positionals = zip(*args)
        for name, value in chain(positionals, iteritems(kwargs)):
            self._recorded_vars[name] = value
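
    # Illustrative sketch (not in the original source): the two calling
    # conventions accepted by `record`. The variable names are arbitrary.
    #
    #   record(leverage=context.account.leverage)         # keyword form
    #   record('aapl_price', data[context.asset].price)   # name/value pairs
    #
    # Recorded values appear as columns of the daily stats DataFrame
    # returned by `run`.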

    @api_method
    @preprocess(symbol_str=ensure_upper_case)
    def symbol(self, symbol_str):
        """
        Default symbol lookup for any source that directly maps the
        symbol to the Asset (e.g. yahoo finance).
        """
        # If the user has not set the symbol lookup date,
        # use the period_end as the date for symbol->sid resolution.
        _lookup_date = self._symbol_lookup_date if self._symbol_lookup_date is not None \
            else self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=_lookup_date,
        )

    @api_method
    def symbols(self, *args):
        """
        Default symbols lookup for any source that directly maps the
        symbol to the Asset (e.g. yahoo finance).
        """
        return [self.symbol(identifier) for identifier in args]

    @api_method
    def sid(self, a_sid):
        """
        Default sid lookup for any source that directly maps the integer sid
        to the Asset.
        """
        return self.asset_finder.retrieve_asset(a_sid)

    @api_method
    @preprocess(symbol=ensure_upper_case)
    def future_symbol(self, symbol):
        """ Lookup a futures contract with a given symbol.

        Parameters
        ----------
        symbol : str
            The symbol of the desired contract.

        Returns
        -------
        Future
            A Future object.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.

        """
        return self.asset_finder.lookup_future_symbol(symbol)

    @api_method
    @preprocess(root_symbol=ensure_upper_case)
    def future_chain(self, root_symbol, as_of_date=None):
        """ Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        """
        if as_of_date:
            try:
                as_of_date = pd.Timestamp(as_of_date, tz='UTC')
            except ValueError:
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')
        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=as_of_date
        )

    def _calculate_order_value_amount(self, asset, value):
        """
        Calculates how many shares/contracts to order based on the type of
        asset being ordered.
        """
        last_price = self.trading_client.current_data[asset].price

        if tolerant_equals(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Don't place any order
            return 0

        if isinstance(asset, Future):
            value_multiplier = asset.contract_multiplier
        else:
            value_multiplier = 1

        return value / (last_price * value_multiplier)
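
    # Illustrative sketch (not in the original source): the arithmetic above
    # for an equity vs. a future, with made-up prices and multiplier.
    #
    #   equity at $25.00, value=10000  ->  10000 / (25.00 * 1)    = 400 shares
    #   future at $50.00, contract_multiplier=1000, value=10000
    #                                  ->  10000 / (50.00 * 1000) = 0.2 contracts
    #
    # `order` later truncates the result to a whole number of shares or
    # contracts (see `round_if_near_integer` below).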

    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order using the specified parameters.
        """

        def round_if_near_integer(a, epsilon=1e-4):
            """
            Round a to the nearest integer if that integer is within an epsilon
            of a.
            """
            if abs(a - round(a)) <= epsilon:
                return round(a)
            else:
                return a

        # Truncate to the integer share count that's either within .0001 of
        # amount or closer to zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        amount = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid,
                                   amount,
                                   limit_price,
                                   stop_price,
                                   style)

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        style = self.__convert_order_params_for_blotter(limit_price,
                                                        stop_price,
                                                        style)
        return self.blotter.order(sid, amount, style)

    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Helper method for validating parameters to the order API function.

        Raises an UnsupportedOrderParameters if invalid arguments are found.
        """

        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style:
            if limit_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported."
                )

            if stop_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported."
                )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for control in self.trading_controls:
            control.validate(asset,
                             amount,
                             self.updated_portfolio(),
                             self.get_datetime(),
                             self.trading_client.current_data)

    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Helper method for converting deprecated limit_price and stop_price
        arguments into ExecutionStyle instances.

        This function assumes that either style == None or (limit_price,
        stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style
        if limit_price and stop_price:
            return StopLimitOrder(limit_price, stop_price)
        if limit_price:
            return LimitOrder(limit_price)
        if stop_price:
            return StopOrder(stop_price)
        else:
            return MarketOrder()
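
    # Illustrative sketch (not in the original source): how the deprecated
    # price arguments above map onto ExecutionStyle objects.
    #
    #   order(asset, 100)                                   -> MarketOrder()
    #   order(asset, 100, limit_price=10.0)                 -> LimitOrder(10.0)
    #   order(asset, 100, stop_price=9.0)                   -> StopOrder(9.0)
    #   order(asset, 100, limit_price=10.0, stop_price=9.0) -> StopLimitOrder(10.0, 9.0)
    #   order(asset, 100, style=LimitOrder(10.0))           -> style used as-is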

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.
        If the requested sid is found in the universe, the requested value is
        divided by its price to imply the number of shares to transact.
        If the Asset being ordered is a Future, the 'value' calculated
        is actually the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order: order(sid, value)
        Limit order: order(sid, value, limit_price)
        Stop order: order(sid, value, None, stop_price)
        StopLimit order: order(sid, value, limit_price, stop_price)
        """
        amount = self._calculate_order_value_amount(sid, value)
        return self.order(sid, amount,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @property
    def recorded_vars(self):
        return copy(self._recorded_vars)

    @property
    def portfolio(self):
        return self.updated_portfolio()

    def updated_portfolio(self):
        if self.portfolio_needs_update:
            self._portfolio = \
                self.perf_tracker.get_portfolio(self.performance_needs_update)
            self.portfolio_needs_update = False
            self.performance_needs_update = False
        return self._portfolio

    @property
    def account(self):
        return self.updated_account()

    def updated_account(self):
        if self.account_needs_update:
            self._account = \
                self.perf_tracker.get_account(self.performance_needs_update)
            self.account_needs_update = False
            self.performance_needs_update = False
        return self._account

    def set_logger(self, logger):
        self.logger = logger

    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each datetime
        group should happen here.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

    @api_method
    def get_datetime(self, tz=None):
        """
        Returns the simulation datetime.
        """
        dt = self.datetime
        assert dt.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is not None:
            # Convert to the given timezone passed as a string or tzinfo.
            if isinstance(tz, string_types):
                tz = pytz.timezone(tz)
            dt = dt.astimezone(tz)

        return dt  # datetime.datetime objects are immutable.

    def set_transact(self, transact):
        """
        Set the method that will be called to create a
        transaction from open orders and trade events.
        """
        self.blotter.transact = transact

    def update_dividends(self, dividend_frame):
        """
        Set DataFrame used to process dividends. DataFrame columns should
        contain at least the entries in zp.DIVIDEND_FIELDS.
        """
        self.perf_tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise SetSlippagePostInit()
        self.slippage = slippage

    @api_method
    def set_commission(self, commission):
        if not isinstance(commission, (PerShare, PerTrade, PerDollar)):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise SetCommissionPostInit()
        self.commission = commission

    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date for which symbols will be resolved to their sids
        (symbols may map to different firms or underlying assets at
        different times)
        """
        try:
            self._symbol_lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')

    def set_sources(self, sources):
        assert isinstance(sources, list)
        self.sources = sources

    # Remains here for backwards compatibility.
    @property
    def data_frequency(self):
        return self.sim_params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        assert value in ('daily', 'minute')
        self.sim_params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50\%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target number of shares and the
        current number of shares.
        """
        if sid in self.portfolio.positions:
            current_position = self.portfolio.positions[sid].amount
            req_shares = target - current_position
            return self.order(sid, req_shares,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)
        else:
            return self.order(sid, target,
                              limit_price=limit_price,
                              stop_price=stop_price,
                              style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target value and the
        current value.
        If the Asset being ordered is a Future, the 'target value' calculated
        is actually the target exposure, as Futures have no 'value'.
        """
        target_amount = self._calculate_order_value_amount(sid, target)
        return self.order_target(sid, target_amount,
                                 limit_price=limit_price,
                                 stop_price=stop_price,
                                 style=style)

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50\%).
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        if sid is None:
            return {
                key: [order.to_api_obj() for order in orders]
                for key, orders in iteritems(self.blotter.open_orders)
                if orders
            }
        if sid in self.blotter.open_orders:
            orders = self.blotter.open_orders[sid]
            return [order.to_api_obj() for order in orders]
        return []

    @api_method
    def get_order(self, order_id):
        if order_id in self.blotter.orders:
            return self.blotter.orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        order_id = order_param
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id

        self.blotter.cancel(order_id)

    @api_method
    def add_history(self, bar_count, frequency, field, ffill=True):
        data_frequency = self.sim_params.data_frequency
        history_spec = HistorySpec(bar_count, frequency, field, ffill,
                                   data_frequency=data_frequency,
                                   env=self.trading_environment)
        self.history_specs[history_spec.key_str] = history_spec
        if self.initialized:
            if self.history_container:
                self.history_container.ensure_spec(
                    history_spec, self.datetime, self._most_recent_data,
                )
            else:
                self.history_container = self.history_container_class(
                    self.history_specs,
                    self.current_universe(),
                    self.sim_params.first_open,
                    self.sim_params.data_frequency,
                    env=self.trading_environment,
                )

    def get_history_spec(self, bar_count, frequency, field, ffill):
        spec_key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
        if spec_key not in self.history_specs:
            data_freq = self.sim_params.data_frequency
            spec = HistorySpec(
                bar_count,
                frequency,
                field,
                ffill,
                data_frequency=data_freq,
                env=self.trading_environment,
            )
            self.history_specs[spec_key] = spec
            if not self.history_container:
                self.history_container = self.history_container_class(
                    self.history_specs,
                    self.current_universe(),
                    self.datetime,
                    self.sim_params.data_frequency,
                    bar_data=self._most_recent_data,
                    env=self.trading_environment,
                )
            self.history_container.ensure_spec(
                spec, self.datetime, self._most_recent_data,
            )
        return self.history_specs[spec_key]

    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, bar_count, frequency, field, ffill=True):
        history_spec = self.get_history_spec(
            bar_count,
            frequency,
            field,
            ffill,
        )
        return self.history_container.get_history(history_spec, self.datetime)
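
    # Illustrative sketch (not in the original source): requesting a trailing
    # window from inside handle_data. The window length and field are
    # arbitrary; the result is a DataFrame indexed by timestamp with one
    # column per sid in the current universe.
    #
    #   prices = history(bar_count=20, frequency='1d', field='price')
    #   short_mavg = prices.mean()   # Series indexed by sid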

    ####################
    # Account Controls #
    ####################

    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.
        """
        if self.initialized:
            raise RegisterAccountControlPostInit()
        self.account_controls.append(control)

    def validate_account_controls(self):
        for control in self.account_controls:
            control.validate(self.updated_portfolio(),
                             self.updated_account(),
                             self.get_datetime(),
                             self.trading_client.current_data)

    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Set a limit on the maximum leverage of the algorithm.
        """
        control = MaxLeverage(max_leverage)
        self.register_account_control(control)

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.
        """
        if self.initialized:
            raise RegisterTradingControlPostInit()
        self.trading_controls.append(control)

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid. Limits are treated as absolute values and are enforced at
        the time that the algo attempts to place an order for sid. This means
        that it's possible to end up with more than the max number of shares
        due to splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one of
        these limits, raise a TradingControlException.
        """
        control = MaxPositionSize(asset=sid,
                                  max_shares=max_shares,
                                  max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid. Limits are treated as absolute values and are
        enforced at the time that the algo attempts to place an order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, raise a TradingControlException.
        """
        control = MaxOrderSize(asset=sid,
                               max_shares=max_shares,
                               max_notional=max_notional)
        self.register_trading_control(control)

    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed within the given
        time interval.
        """
        control = MaxOrderCount(max_count)
        self.register_trading_control(control)

    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Set a restriction on which sids can be ordered.
        """
        control = RestrictedListOrder(restricted_list)
        self.register_trading_control(control)

    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions.
        """
        self.register_trading_control(LongOnly())

    ##############
    # Pipeline API
    ##############
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register a pipeline to be computed at the start of each day.
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")
        if chunksize is None:
            # Make the first chunk smaller to get more immediate results:
            # (one week, then every half year)
            chunks = iter(chain([5], repeat(126)))
        else:
            chunks = iter(repeat(int(chunksize)))
        self._pipelines[name] = pipeline, chunks

        # Return the pipeline to allow expressions like
        # p = attach_pipeline(Pipeline(), 'name')
        return pipeline

    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of pipeline with name `name`.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: We don't currently support multiple pipelines, but we plan to
        # in the future.
        try:
            p, chunks = self._pipelines[name]
        except KeyError:
            raise NoSuchPipeline(
                name=name,
                valid=list(self._pipelines.keys()),
            )
        return self._pipeline_output(p, chunks)
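
    # Illustrative sketch (not in the original source): wiring the two
    # Pipeline API calls together inside an algorithm. `Pipeline` and the
    # `my_factor` term are assumed to come from zipline.pipeline.
    #
    #   def initialize(context):
    #       pipe = Pipeline(columns={'factor': my_factor})
    #       attach_pipeline(pipe, 'my_pipeline')
    #
    #   def before_trading_start(context, data):
    #       # One row per asset that passed the (optional) screen today.
    #       context.pipeline_data = pipeline_output('my_pipeline')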

    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.
        """
        today = normalize_date(self.get_datetime())
        try:
            data = self._pipeline_cache.unwrap(today)
        except Expired:
            data, valid_until = self._run_pipeline(
                pipeline, today, next(chunks),
            )
            self._pipeline_cache = CachedObject(data, valid_until)

        # Now that we have a cached result, try to return the data for today.
        try:
            return data.loc[today]
        except KeyError:
            # This happens if no assets passed the pipeline screen on a given
            # day.
            return pd.DataFrame(index=[], columns=data.columns)

    def _run_pipeline(self, pipeline, start_date, chunksize):
        """
        Compute `pipeline`, providing values for at least `start_date`.

        Produces a DataFrame containing data for days between `start_date` and
        `end_date`, where `end_date` is defined by:

            `end_date = min(start_date + chunksize trading days,
                            simulation_end)`

        Returns
        -------
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)

        See Also
        --------
        PipelineEngine.run_pipeline
        """
        days = self.trading_environment.trading_days

        # Load data starting from the previous trading day...
        start_date_loc = days.get_loc(start_date)

        # ...continuing until either the day before the simulation end, or
        # until chunksize days of data have been loaded.
        sim_end = self.sim_params.last_close.normalize()
        end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))
        end_date = days[end_loc]

        return \
            self.engine.run_pipeline(pipeline, start_date, end_date), end_date

    ##################
    # End Pipeline API
    ##################

    def current_universe(self):
        return self._current_universe

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.
        """
        return [
            fn for fn in itervalues(vars(cls))
            if getattr(fn, 'is_api_method', False)
        ]