Completed
Pull Request — master (#858)
by Eddie
02:03
created

zipline.TradingAlgorithm.run()   B

Complexity

Conditions 3

Size

Total Lines 35

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 35
rs 8.8571
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pytz
18
import pandas as pd
19
from pandas.tseries.tools import normalize_date
20
import numpy as np
21
22
from datetime import datetime
23
from itertools import chain, repeat
24
from numbers import Integral
25
26
from six import (
27
    exec_,
28
    iteritems,
29
    itervalues,
30
    string_types,
31
)
32
33
from zipline.errors import (
34
    AttachPipelineAfterInitialize,
35
    HistoryInInitialize,
36
    NoSuchPipeline,
37
    OrderDuringInitialize,
38
    OverrideCommissionPostInit,
39
    OverrideSlippagePostInit,
40
    PipelineOutputDuringInitialize,
41
    RegisterAccountControlPostInit,
42
    RegisterTradingControlPostInit,
43
    SetBenchmarkOutsideInitialize,
44
    UnsupportedCommissionModel,
45
    UnsupportedDatetimeFormat,
46
    UnsupportedOrderParameters,
47
    UnsupportedSlippageModel,
48
)
49
from zipline.finance.trading import TradingEnvironment
50
from zipline.finance.blotter import Blotter
51
from zipline.finance.commission import PerShare, PerTrade, PerDollar
52
from zipline.finance.controls import (
53
    LongOnly,
54
    MaxOrderCount,
55
    MaxOrderSize,
56
    MaxPositionSize,
57
    MaxLeverage,
58
    RestrictedListOrder
59
)
60
from zipline.finance.execution import (
61
    LimitOrder,
62
    MarketOrder,
63
    StopLimitOrder,
64
    StopOrder,
65
)
66
from zipline.finance.performance import PerformanceTracker
67
from zipline.finance.slippage import (
68
    VolumeShareSlippage,
69
    SlippageModel
70
)
71
from zipline.assets import Asset, Future
72
from zipline.assets.futures import FutureChain
73
from zipline.gens.tradesimulation import AlgorithmSimulator
74
from zipline.pipeline.engine import (
75
    NoOpPipelineEngine,
76
    SimplePipelineEngine,
77
)
78
from zipline.utils.api_support import (
79
    api_method,
80
    require_initialized,
81
    require_not_initialized,
82
    ZiplineAPI,
83
)
84
from zipline.utils.input_validation import ensure_upper_case
85
from zipline.utils.cache import CachedObject, Expired
86
import zipline.utils.events
87
from zipline.utils.events import (
88
    EventManager,
89
    make_eventrule,
90
    DateRuleFactory,
91
    TimeRuleFactory,
92
)
93
from zipline.utils.factory import create_simulation_parameters
94
from zipline.utils.math_utils import tolerant_equals, round_if_near_integer
95
from zipline.utils.preprocess import preprocess
96
97
import zipline.protocol
98
from zipline.sources.requests_csv import PandasRequestsCSV
99
100
from zipline.gens.sim_engine import (
101
    MinuteSimulationClock,
102
    DailySimulationClock,
103
    MinuteEmissionClock)
104
from zipline.sources.benchmark_source import BenchmarkSource
105
106
# Starting capital used when TradingAlgorithm is built without a
# ``capital_base`` keyword argument (100,000 as a float).
DEFAULT_CAPITAL_BASE = 1.0e5
107
108
109
class TradingAlgorithm(object):
110
    """
111
    Base class for trading algorithms. Inherit and overload
112
    initialize() and handle_data(data).
113
114
    A new algorithm could look like this:
115
    ```
116
    from zipline.api import order, symbol
117
118
    def initialize(context):
119
        context.sid = symbol('AAPL')
120
        context.amount = 100
121
122
    def handle_data(context, data):
123
        sid = context.sid
124
        amount = context.amount
125
        order(sid, amount)
126
    ```
127
    To then to run this algorithm pass these functions to
128
    TradingAlgorithm:
129
130
    my_algo = TradingAlgorithm(initialize, handle_data)
131
    stats = my_algo.run(data)
132
133
    """
134
135
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional callback; only read when ``initialize`` and
                ``handle_data`` are passed as functions.
            analyze : function
                Optional callback called after ``run``; only read when
                ``initialize`` and ``handle_data`` are passed as functions.
            script : str
                Algoscript that contains initialize and
                handle_data function definition. Mutually exclusive with
                the ``initialize``/``handle_data`` keyword arguments.
            algo_filename : str <default: '<string>'>
                Filename used when compiling ``script``.
            namespace : dict <default: {}>
                Namespace in which ``script`` is executed.
            platform : str <default: 'zipline'>
                Value reported by ``get_environment('platform')``.
            env : TradingEnvironment
                Environment to use; a new one is created when omitted.
            sim_params : SimulationParameters
                When omitted, built from ``capital_base``, ``start`` and
                ``end``.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            instant_fill : bool <default: False>
               Whether to fill orders immediately or on next bar.
            blotter : Blotter
                Defaults to a Blotter with VolumeShareSlippage and PerShare
                commission.
            get_pipeline_loader : callable
                Passed to ``init_engine``.
            create_event_context : callable
                Passed to the EventManager.
            benchmark_sid : int
                Sid used by the benchmark source.
            asset_finder : An AssetFinder object
                NOTE(review): this constructor does not consume this key; it
                is left in kwargs and forwarded to ``initialize`` — confirm.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            futures_metadata : same formats as ``equities_metadata``
                Metadata written for futures.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        Any remaining positional and keyword arguments are stored and passed
        to ``initialize`` when the algorithm is first run.

        :Raises:
            ValueError
                If ``script`` is combined with ``initialize`` or
                ``handle_data``, or if ``script`` does not define a
                ``handle_data`` function.
        """
        self.sources = []
        self.clock = None

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.benchmark_source = None

        self.instant_fill = kwargs.pop('instant_fill', False)

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        self.data_portal = None

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        self.perf_tracker = None
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter(
                slippage_func=VolumeShareSlippage(),
                commission=PerShare()
            )

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self._portfolio = None
        self._account = None

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        if self.algoscript is not None:
            # A script and explicit callbacks are mutually exclusive. The
            # check has to live here: the elif branch below is only reached
            # when no script was given, so a check placed there never fires.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )
            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs

        # initialize_kwargs is the same dict object as kwargs, so popping
        # here also keeps benchmark_sid out of the initialize() call.
        self.benchmark_sid = kwargs.pop('benchmark_sid', None)
314
315
    def init_engine(self, get_loader):
        """
        Build the pipeline engine for this algorithm and store it on
        ``self.engine``.

        With no ``get_loader`` (None) a ``NoOpPipelineEngine`` is used.
        Otherwise a ``SimplePipelineEngine`` is created from ``get_loader``,
        the environment's trading days and ``self.asset_finder``.
        """
        if get_loader is None:
            self.engine = NoOpPipelineEngine()
            return

        self.engine = SimplePipelineEngine(
            get_loader,
            self.trading_environment.trading_days,
            self.asset_finder,
        )
329
330
    def initialize(self, *args, **kwargs):
        """
        Run the user-supplied initialize callback.

        ``self._initialize`` is called as ``_initialize(self, *args,
        **kwargs)`` inside a ``ZiplineAPI(self)`` context, so ``self`` is
        made available to Zipline API functions during the call.
        """
        api_context = ZiplineAPI(self)
        with api_context:
            self._initialize(self, *args, **kwargs)
337
338
    def before_trading_start(self, data):
339
        if self._before_trading_start is None:
340
            return
341
342
        self._before_trading_start(self, data)
343
344
    def handle_data(self, data):
345
        self._handle_data(self, data)
346
347
        # Unlike trading controls which remain constant unless placing an
348
        # order, account controls can change each bar. Thus, must check
349
        # every bar no matter if the algorithm places an order or not.
350
        self.validate_account_controls()
351
352
    def analyze(self, perf):
353
        if self._analyze is None:
354
            return
355
356
        with ZiplineAPI(self):
357
            self._analyze(self, perf)
358
359
    def __repr__(self):
360
        """
361
        N.B. this does not yet represent a string that can be used
362
        to instantiate an exact copy of an algorithm.
363
364
        However, it is getting close, and provides some value as something
365
        that can be inspected interactively.
366
        """
367
        return """
368
{class_name}(
369
    capital_base={capital_base}
370
    sim_params={sim_params},
371
    initialized={initialized},
372
    slippage={slippage},
373
    commission={commission},
374
    blotter={blotter},
375
    recorded_vars={recorded_vars})
376
""".strip().format(class_name=self.__class__.__name__,
377
                   capital_base=self.capital_base,
378
                   sim_params=repr(self.sim_params),
379
                   initialized=self.initialized,
380
                   slippage=repr(self.blotter.slippage_func),
381
                   commission=repr(self.blotter.commission),
382
                   blotter=repr(self.blotter),
383
                   recorded_vars=repr(self.recorded_vars))
384
385
    def ensure_clock(self):
386
        """
387
        If the clock property is not set, then create one based on frequency.
388
        """
389
        if self.clock is None:
390
            if self.sim_params.data_frequency == 'minute':
391
                env = self.trading_environment
392
                trading_o_and_c = env.open_and_closes.ix[
393
                    self.sim_params.trading_days]
394
                market_opens = trading_o_and_c['market_open'].values.astype(
395
                    'datetime64[ns]').astype(np.int64)
396
                market_closes = trading_o_and_c['market_close'].values.astype(
397
                    'datetime64[ns]').astype(np.int64)
398
                if self.sim_params.emission_rate == "daily":
399
                    self.clock = MinuteSimulationClock(
400
                        self.sim_params.trading_days,
401
                        market_opens,
402
                        market_closes,
403
                        self.data_portal,
404
                        env.trading_days
405
                    )
406
                else:
407
                    self.clock = MinuteEmissionClock(
408
                        self.sim_params.trading_days,
409
                        market_opens,
410
                        market_closes,
411
                        self.data_portal,
412
                        env.trading_days
413
                    )
414
415
            elif self.sim_params.data_frequency == 'daily':
416
                self.clock = DailySimulationClock(self.sim_params.trading_days)
417
418
    def create_benchmark_source(self):
        """
        Build a ``BenchmarkSource`` for ``self.benchmark_sid``.

        It spans the simulation's trading days and reads from this
        algorithm's data portal at the simulation's emission rate.
        """
        params = self.sim_params
        return BenchmarkSource(
            self.benchmark_sid,
            self.trading_environment,
            params.trading_days,
            self.data_portal,
            emission_rate=params.emission_rate,
        )
426
427
    def _create_generator(self, sim_params):
        """
        Build the simulation generator for this algorithm.

        Side effects: replaces ``self.sim_params`` when ``sim_params`` is not
        None, creates a ``PerformanceTracker`` if none exists, runs the user
        ``initialize`` once, ensures a clock exists, and stores the
        ``AlgorithmSimulator`` on ``self.trading_client``.

        Parameters
        ----------
        sim_params : SimulationParameters or None
            Parameters for this run. If None, the current
            ``self.sim_params`` is used.

        Returns
        -------
        The result of ``self.trading_client.transform()``.
        """
        if sim_params is not None:
            self.sim_params = sim_params
        # Use the effective parameters from here on. Previously a None
        # argument leaked into the second tracker and the simulator even
        # though self.sim_params was valid.
        sim_params = self.sim_params

        if self.perf_tracker is None:
            # Build a perf_tracker
            self.perf_tracker = PerformanceTracker(
                sim_params=sim_params,
                env=self.trading_environment,
                data_portal=self.data_portal)
            # Set the dt initially to the period start by forcing it to change
            self.on_dt_changed(sim_params.period_start)

            # HACK: When running with the `run` method, we set perf_tracker to
            # None so that it will be overwritten here.
            self.perf_tracker = PerformanceTracker(
                sim_params=sim_params, env=self.trading_environment,
                data_portal=self.data_portal
            )

        if not self.initialized:
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
            self.initialized = True

        self.ensure_clock()

        self.trading_client = AlgorithmSimulator(
            self,
            sim_params,
            self.data_portal,
            self.clock,
            self.create_benchmark_source()
        )

        return self.trading_client.transform()
470
471
    def get_generator(self):
472
        """
473
        Override this method to add new logic to the construction
474
        of the generator. Overrides can use the _create_generator
475
        method to get a standard construction generator.
476
        """
477
        return self._create_generator(self.sim_params)
478
479
    def run(self, data_portal=None):
480
        """Run the algorithm.
481
482
        :Arguments:
483
            source : DataPortal
484
485
        :Returns:
486
            daily_stats : pandas.DataFrame
487
              Daily performance metrics such as returns, alpha etc.
488
489
        """
490
        if self.data_portal is None:
491
            self.data_portal = data_portal
492
493
        self.ensure_clock()
494
495
        # force a reset of the performance tracker, in case
496
        # this is a repeat run of the algorithm.
497
        self.perf_tracker = None
498
499
        # create zipline
500
        self.gen = self.get_generator()
501
502
        # loop through simulated_trading, each iteration returns a
503
        # perf dictionary
504
        perfs = []
505
        for perf in self.gen:
506
            perfs.append(perf)
507
508
        # convert perf dict to pandas dataframe
509
        daily_stats = self._create_daily_stats(perfs)
510
511
        self.analyze(daily_stats)
512
513
        return daily_stats
514
515
    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Write any identifiers the asset finder cannot resolve, then map all
        ``identifiers`` to sids as of ``as_of_date``.

        Asset instances are resolved by their ``sid`` and integers are
        resolved as sids. Anything else, or anything the finder does not know
        about, is written to the trading environment as a new equity
        identifier.
        """
        def _already_known(identifier):
            # Only Assets and integer sids can be looked up directly.
            if isinstance(identifier, Asset):
                lookup_sid = identifier.sid
            elif isinstance(identifier, Integral):
                lookup_sid = identifier
            else:
                return False
            found = self.asset_finder.retrieve_asset(sid=lookup_sid,
                                                     default_none=True)
            return found is not None

        unresolved = [ident for ident in identifiers
                      if not _already_known(ident)]

        self.trading_environment.write_data(equities_identifiers=unresolved)

        # Lookups above may have cached misses; drop them so the freshly
        # written identifiers are visible. The real fix is to not construct
        # an AssetFinder until `run()`, when all the data is available.
        self.asset_finder._reset_caches()

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )
543
544
    def _create_daily_stats(self, perfs):
545
        # create daily and cumulative stats dataframe
546
        daily_perfs = []
547
        # TODO: the loop here could overwrite expected properties
548
        # of daily_perf. Could potentially raise or log a
549
        # warning.
550
        for perf in perfs:
551
            if 'daily_perf' in perf:
552
553
                perf['daily_perf'].update(
554
                    perf['daily_perf'].pop('recorded_vars')
555
                )
556
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
557
                daily_perfs.append(perf['daily_perf'])
558
            else:
559
                self.risk_report = perf
560
561
        daily_dts = [np.datetime64(perf['period_close'], utc=True)
562
                     for perf in daily_perfs]
563
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)
564
565
        return daily_stats
566
567
    @api_method
568
    def get_environment(self, field='platform'):
569
        env = {
570
            'arena': self.sim_params.arena,
571
            'data_frequency': self.sim_params.data_frequency,
572
            'start': self.sim_params.first_open,
573
            'end': self.sim_params.last_close,
574
            'capital_base': self.sim_params.capital_base,
575
            'platform': self._platform
576
        }
577
        if field == '*':
578
            return env
579
        else:
580
            return env[field]
581
582
    @api_method
583
    def fetch_csv(self, url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """
        Load CSV data from ``url`` and register it with the data portal.

        A ``PandasRequestsCSV`` source is built for the simulation's
        ``period_start``..``period_end`` window. Its ``df`` is handed to
        ``self.data_portal.handle_extra_source``, and the source object is
        returned. Extra keyword arguments are forwarded to
        ``PandasRequestsCSV``.

        NOTE(review): reads ``self.data_frequency``, which ``__init__`` only
        sets when a ``data_frequency`` kwarg is given — presumably provided
        elsewhere in the class otherwise; confirm.
        """
        params = self.sim_params
        csv_source = PandasRequestsCSV(
            url, pre_func, post_func,
            self.trading_environment,
            params.period_start, params.period_end,
            date_column, date_format, timezone,
            symbol, mask, symbol_column,
            data_frequency=self.data_frequency,
            special_params_checker=special_params_checker,
            **kwargs
        )

        self.data_portal.handle_extra_source(csv_source.df)
        return csv_source
618
619
    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an ``Event`` and add it to this
        algorithm's ``EventManager``.
        """
        new_event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(new_event)
626
627
    @api_method
628
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to run according to a date rule and a time rule.

        A falsy ``date_rule`` defaults to ``DateRuleFactory.every_day()``.
        With minute data a falsy ``time_rule`` defaults to
        ``TimeRuleFactory.market_open()``. With any other data frequency the
        time rule is ignored and replaced by ``Always()``.
        """
        if not date_rule:
            date_rule = DateRuleFactory.every_day()

        if self.sim_params.data_frequency == 'minute':
            if not time_rule:
                time_rule = TimeRuleFactory.market_open()
        else:
            # If we are in daily mode the time_rule is ignored.
            time_rule = zipline.utils.events.Always()

        self.add_event(make_eventrule(date_rule, time_rule, half_days), func)
646
647
    @api_method
648
    def record(self, *args, **kwargs):
        """
        Record named values for the current bar into ``self._recorded_vars``.

        Positional arguments are read as alternating name/value pairs
        (``record('a', 1, 'b', 2)``); keyword arguments add ``name=value``
        pairs. A later value overwrites an earlier one with the same name.
        NOTE(review): with an odd number of positional arguments the
        trailing name is silently dropped, because zip stops at the
        shortest input.
        """
        # zip() pulls alternately from a single shared iterator, which turns
        # (a, b, c, d) into (a, b), (c, d).
        flat = iter(args)
        positional_pairs = zip(flat, flat)
        for name, value in chain(positional_pairs, iteritems(kwargs)):
            self._recorded_vars[name] = value
662
663
    @api_method
664
    def set_benchmark(self, benchmark_sid):
665
        if self.initialized:
666
            raise SetBenchmarkOutsideInitialize()
667
668
        self.benchmark_sid = benchmark_sid
669
670
    @api_method
671
    @preprocess(symbol_str=ensure_upper_case)
672
    def symbol(self, symbol_str):
673
        """
674
        Default symbol lookup for any source that directly maps the
675
        symbol to the Asset (e.g. yahoo finance).
676
        """
677
        # If the user has not set the symbol lookup date,
678
        # use the period_end as the date for sybmol->sid resolution.
679
        _lookup_date = self._symbol_lookup_date if self._symbol_lookup_date is not None \
680
            else self.sim_params.period_end
681
682
        return self.asset_finder.lookup_symbol(
683
            symbol_str,
684
            as_of_date=_lookup_date,
685
        )
686
687
    @api_method
688
    def symbols(self, *args):
689
        """
690
        Default symbols lookup for any source that directly maps the
691
        symbol to the Asset (e.g. yahoo finance).
692
        """
693
        return [self.symbol(identifier) for identifier in args]
694
695
    @api_method
696
    def sid(self, a_sid):
697
        """
698
        Default sid lookup for any source that directly maps the integer sid
699
        to the Asset.
700
        """
701
        return self.asset_finder.retrieve_asset(a_sid)
702
703
    @api_method
704
    @preprocess(symbol=ensure_upper_case)
705
    def future_symbol(self, symbol):
706
        """ Lookup a futures contract with a given symbol.
707
708
        Parameters
709
        ----------
710
        symbol : str
711
            The symbol of the desired contract.
712
713
        Returns
714
        -------
715
        Future
716
            A Future object.
717
718
        Raises
719
        ------
720
        SymbolNotFound
721
            Raised when no contract named 'symbol' is found.
722
723
        """
724
        return self.asset_finder.lookup_future_symbol(symbol)
725
726
    @api_method
727
    @preprocess(root_symbol=ensure_upper_case)
728
    def future_chain(self, root_symbol, as_of_date=None):
        """Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc. A truthy value is converted with
            ``pd.Timestamp(as_of_date, tz='UTC')``; a falsy value is passed
            through unchanged.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        UnsupportedDatetimeFormat
            If converting ``as_of_date`` to a Timestamp raises ValueError.
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        """
        chain_date = as_of_date
        if chain_date:
            try:
                chain_date = pd.Timestamp(chain_date, tz='UTC')
            except ValueError:
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')

        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=chain_date,
        )
762
763
    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a target order ``value`` into a share/contract count.

        The count is ``value / (last_price * multiplier)``, where the
        multiplier is ``asset.contract_multiplier`` for Futures and 1
        otherwise. Returns 0 when the last price is (tolerantly) zero, since
        no count can be inferred; in that case a debug message is logged if
        a logger is set.
        """
        last_price = self.trading_client.current_data[asset].price

        if tolerant_equals(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Don't place any order
            return 0

        multiplier = (asset.contract_multiplier if isinstance(asset, Future)
                      else 1)
        return value / (last_price * multiplier)
785
786
    @api_method
787
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` units of ``sid``.

        ``amount`` is converted with ``int(round_if_near_integer(amount))``.
        The parameters are then validated, the deprecated
        ``limit_price``/``stop_price`` arguments are converted into an
        ExecutionStyle, and the order is handed to the blotter. Returns
        whatever ``self.blotter.order`` returns.
        """
        # Truncate to the integer share count that's either within .0001 of
        # amount or closer to zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        whole_amount = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(
            sid, whole_amount, limit_price, stop_price, style
        )

        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style
        )
        return self.blotter.order(sid, whole_amount, execution_style)
812
813
    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Check the arguments of an ``order`` call before it is placed.

        Raises
        ------
        OrderDuringInitialize
            If the algorithm has not been initialized yet.
        UnsupportedOrderParameters
            If ``style`` is combined with ``limit_price`` or ``stop_price``,
            or if ``asset`` is not an Asset.

        Every registered trading control is then asked to ``validate`` the
        order; any exception it raises propagates.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style:
            if limit_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported."
                )
            if stop_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported."
                )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg=("Passing non-Asset argument to 'order()' is not "
                     "supported. Use 'sid()' or 'symbol()' methods to look "
                     "up an Asset.")
            )

        for trading_control in self.trading_controls:
            trading_control.validate(
                asset,
                amount,
                self.portfolio,
                self.get_datetime(),
                self.trading_client.current_data,
            )
853
854
    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Helper method for converting deprecated limit_price and stop_price
        arguments into ExecutionStyle instances.

        Resolution order:
          * a truthy ``style`` is returned unchanged;
          * both prices set  -> StopLimitOrder(limit_price, stop_price);
          * only limit_price -> LimitOrder(limit_price);
          * only stop_price  -> StopOrder(stop_price);
          * neither          -> MarketOrder().

        Prices are tested for truthiness, so a price of 0 counts as unset.
        This function assumes that either ``style`` is falsy or
        (limit_price, stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        return MarketOrder()

    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.

        The requested value is converted to a share amount by
        ``_calculate_order_value_amount``, which divides it by the asset's
        last price. If the Asset being ordered is a Future, the 'value'
        calculated is actually the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short

        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        amount = self._calculate_order_value_amount(sid, value)
        return self.order(sid, amount,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @property
    def recorded_vars(self):
        """A shallow copy of the variables held in ``_recorded_vars``."""
        snapshot = copy(self._recorded_vars)
        return snapshot

    @property
    def portfolio(self):
        """The current portfolio, as returned by ``updated_portfolio()``."""
        current = self.updated_portfolio()
        return current

    def updated_portfolio(self):
        """
        Return the cached portfolio, building it on first access.

        When no portfolio is cached and a perf tracker is available, the
        portfolio is fetched from ``perf_tracker.get_portfolio`` for the
        current ``self.datetime`` and cached. Otherwise the cached value
        (possibly None) is returned unchanged.
        """
        if self._portfolio is not None or self.perf_tracker is None:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio(self.datetime)
        return self._portfolio

    @property
    def account(self):
        """The current account, as returned by ``updated_account()``."""
        current = self.updated_account()
        return current

    def updated_account(self):
        """
        Return the cached account, building it on first access.

        When no account is cached and a perf tracker is available, the
        account is fetched from ``perf_tracker.get_account`` for the current
        ``self.datetime`` and cached. Otherwise the cached value (possibly
        None) is returned unchanged.
        """
        if self._account is not None or self.perf_tracker is None:
            return self._account

        self._account = self.perf_tracker.get_account(self.datetime)
        return self._account

    def set_logger(self, logger):
        """Store ``logger`` on the algorithm as ``self.logger``."""
        self.logger = logger

    def on_dt_changed(self, dt):
        """
        Callback invoked by the simulation loop each time the current
        datetime changes.

        Logic that must run exactly once at the start of every datetime
        group belongs here. ``dt`` must be a ``datetime`` whose tzinfo is
        ``pytz.utc``.
        """
        assert isinstance(dt, datetime), (
            "Attempt to set algorithm's current time with non-datetime"
        )
        assert dt.tzinfo == pytz.utc, (
            "Algorithm expects a utc datetime"
        )

        # Propagate the new datetime to the algorithm, perf tracker and
        # blotter, in that order.
        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

        # Invalidate the cached snapshots so that updated_portfolio() and
        # updated_account() rebuild them on next access.
        self._portfolio = self._account = None

    @api_method
    def get_datetime(self, tz=None):
        """
        Returns the simulation datetime.

        Parameters
        ----------
        tz : str or tzinfo, optional
            If given, the UTC simulation datetime is converted to this
            timezone; strings are resolved with ``pytz.timezone``.
        """
        now = self.datetime
        assert now.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is None:
            return now

        if isinstance(tz, string_types):
            tz = pytz.timezone(tz)
        # datetime.datetime objects are immutable, so returning ``now`` (or a
        # converted copy of it) is safe.
        return now.astimezone(tz)

    def update_dividends(self, dividend_frame):
        """
        Hand the DataFrame used to process dividends to the perf tracker.

        The frame's columns should contain at least the entries in
        zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        """
        Install ``slippage`` as the blotter's slippage model.

        Raises
        ------
        UnsupportedSlippageModel
            If ``slippage`` is not a SlippageModel instance.
        OverrideSlippagePostInit
            If the algorithm has already been initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.blotter.slippage_func = slippage

    @api_method
    def set_commission(self, commission):
        """
        Install ``commission`` as the blotter's commission model.

        Raises
        ------
        UnsupportedCommissionModel
            If ``commission`` is not a PerShare, PerTrade or PerDollar.
        OverrideCommissionPostInit
            If the algorithm has already been initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()
        self.blotter.commission = commission

    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date for which symbols will be resolved to their sids
        (symbols may map to different firms or underlying assets at
        different times).

        ``dt`` is converted with ``pd.Timestamp(dt, tz='UTC')``; a ValueError
        from that conversion is reported as UnsupportedDatetimeFormat.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date

    # Kept for backwards compatibility.
    @property
    def data_frequency(self):
        """The simulation's data frequency, stored on ``sim_params``."""
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only these two frequencies are accepted.
        allowed = ('daily', 'minute')
        assert value in allowed
        self.sim_params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        The resulting value is passed on to ``order_value``.
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        With no existing position in ``sid`` this orders ``target`` shares
        outright; otherwise it orders the difference between ``target`` and
        the currently held amount.
        """
        positions = self.portfolio.positions
        if sid in positions:
            delta = target - positions[sid].amount
        else:
            delta = target

        return self.order(sid, delta,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        The target value is first converted into a target share amount via
        ``_calculate_order_value_amount`` and then handed to
        ``order_target``. With no existing position this amounts to a new
        order; otherwise only the difference is ordered. If the Asset being
        ordered is a Future, the 'target value' calculated is actually the
        target exposure, as Futures have no 'value'.
        """
        shares_wanted = self._calculate_order_value_amount(sid, target)
        return self.order_target(
            sid,
            shares_wanted,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50%).
        The resulting target value is passed on to ``order_target_value``.
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders, each converted with ``to_api_obj()``.

        With ``sid=None``, returns a dict mapping every key of
        ``blotter.open_orders`` that has at least one order to the list of
        its orders. Otherwise returns the list of open orders for ``sid``,
        or an empty list when there are none.
        """
        open_orders = self.blotter.open_orders
        if sid is not None:
            if sid not in open_orders:
                return []
            return [o.to_api_obj() for o in open_orders[sid]]

        return {
            key: [o.to_api_obj() for o in pending]
            for key, pending in iteritems(open_orders)
            if pending
        }

    @api_method
    def get_order(self, order_id):
        """
        Look up the order with id ``order_id`` in the blotter.

        Returns its ``to_api_obj()`` representation, or None when the blotter
        has no order with that id.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order given either its id or a ``zipline.protocol.Order``
        (whose ``id`` attribute is used).
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param
        self.blotter.cancel(order_id)

    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, sids, bar_count, frequency, field, ffill=True):
        """
        Fetch a history window from the data portal, anchored at the current
        simulation datetime; see ``get_history_window`` for the exact
        semantics of each argument.

        Raises a plain Exception when no data portal is configured.
        """
        portal = self.data_portal
        if portal is None:
            raise Exception("no data portal!")

        now = self.get_datetime()
        return portal.get_history_window(
            sids, now, bar_count, frequency, field, ffill,
        )

    ####################
    # Account Controls #
    ####################

    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.

        Raises RegisterAccountControlPostInit if the algorithm has already
        been initialized.
        """
        if self.initialized:
            # Controls may only be added before initialization completes.
            raise RegisterAccountControlPostInit()
        controls = self.account_controls
        controls.append(control)

    def validate_account_controls(self):
        """
        Run every registered account control against the current state.

        Each control's ``validate`` receives the portfolio, the account, the
        current datetime and the trading client's current data. These are
        re-read for every control (and not at all when none are registered).
        """
        for account_control in self.account_controls:
            account_control.validate(
                self.portfolio,
                self.account,
                self.get_datetime(),
                self.trading_client.current_data,
            )

    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Set a limit on the maximum leverage of the algorithm by registering
        a MaxLeverage account control.
        """
        self.register_account_control(MaxLeverage(max_leverage))

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.

        Raises RegisterTradingControlPostInit if the algorithm has already
        been initialized.
        """
        if self.initialized:
            # Controls may only be added before initialization completes.
            raise RegisterTradingControlPostInit()
        controls = self.trading_controls
        controls.append(control)

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid.

        Limits are treated as absolute values and are enforced at the time
        the algo attempts to place an order for ``sid``. Splits/dividends can
        therefore still push the share count above the max, and price
        improvement can push the notional above the max.

        If an order would increase the absolute shares/dollar value beyond
        one of these limits, a TradingControlException is raised.
        """
        self.register_trading_control(
            MaxPositionSize(asset=sid,
                            max_shares=max_shares,
                            max_notional=max_notional)
        )

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.

        Limits are treated as absolute values and are enforced when the algo
        attempts to place an order for ``sid``; an order exceeding one of
        them raises a TradingControlException.
        """
        self.register_trading_control(
            MaxOrderSize(asset=sid,
                         max_shares=max_shares,
                         max_notional=max_notional)
        )

    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed within a time
        interval. The interval itself is defined by MaxOrderCount, not here.
        """
        self.register_trading_control(MaxOrderCount(max_count))

    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Set a restriction on which sids can be ordered, via a
        RestrictedListOrder trading control built from ``restricted_list``.
        """
        self.register_trading_control(RestrictedListOrder(restricted_list))

    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions
        (registers a LongOnly trading control).
        """
        long_only = LongOnly()
        self.register_trading_control(long_only)

    ##############
    # Pipeline API
    ##############
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register a pipeline to be computed at the start of each day.

        Parameters
        ----------
        pipeline : Pipeline
            The pipeline to attach.
        name : str
            Name under which results are retrieved with ``pipeline_output``.
        chunksize : int, optional
            Number of trading days each computation batch advances by (see
            ``_run_pipeline``). When omitted, the first batch is 5 days
            (about a week) and every later batch 126 days (about half a
            year), to get more immediate results.

        Returns
        -------
        pipeline : Pipeline
            The pipeline passed in, allowing expressions like
            ``p = attach_pipeline(Pipeline(), 'name')``.

        Raises
        ------
        NotImplementedError
            If a pipeline has already been attached.
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")

        if chunksize is not None:
            batch_sizes = repeat(int(chunksize))
        else:
            batch_sizes = chain([5], repeat(126))

        self._pipelines[name] = pipeline, batch_sizes
        return pipeline

    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of pipeline with name `name`.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: Only one pipeline can be attached today, but lookups are
        # keyed by name so that multiple pipelines can be supported later.
        try:
            registered, batch_sizes = self._pipelines[name]
        except KeyError:
            raise NoSuchPipeline(
                name=name,
                valid=list(self._pipelines.keys()),
            )
        return self._pipeline_output(registered, batch_sizes)

    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.

        Serves today's rows from the cached pipeline result. When the cache
        has expired, the next chunk (sized by ``next(chunks)``) is computed
        with ``_run_pipeline`` and cached until its end date. If the result
        has no rows for today, an empty DataFrame with the same columns is
        returned.
        """
        today = normalize_date(self.get_datetime())
        try:
            results = self._pipeline_cache.unwrap(today)
        except Expired:
            results, expires = self._run_pipeline(pipeline, today,
                                                  next(chunks))
            self._pipeline_cache = CachedObject(results, expires)

        # With a (possibly freshly computed) result in hand, select today.
        try:
            return results.loc[today]
        except KeyError:
            # No assets passed the pipeline screen on this day.
            return pd.DataFrame(index=[], columns=results.columns)

    def _run_pipeline(self, pipeline, start_date, chunksize):
1305
        """
1306
        Compute `pipeline`, providing values for at least `start_date`.
1307
1308
        Produces a DataFrame containing data for days between `start_date` and
1309
        `end_date`, where `end_date` is defined by:
1310
1311
            `end_date = min(start_date + chunksize trading days,
1312
                            simulation_end)`
1313
1314
        Returns
1315
        -------
1316
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)
1317
1318
        See Also
1319
        --------
1320
        PipelineEngine.run_pipeline
1321
        """
1322
        days = self.trading_environment.trading_days
1323
1324
        # Load data starting from the previous trading day...
1325
        start_date_loc = days.get_loc(start_date)
1326
1327
        # ...continuing until either the day before the simulation end, or
1328
        # until chunksize days of data have been loaded.
1329
        sim_end = self.sim_params.last_close.normalize()
1330
        end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))
1331
        end_date = days[end_loc]
1332
1333
        return \
1334
            self.engine.run_pipeline(pipeline, start_date, end_date), end_date
1335
    ##################
    # End Pipeline API
    ##################

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods: every value in
        the class's own ``vars()`` whose ``is_api_method`` attribute is
        truthy.
        """
        members = itervalues(vars(cls))
        return [member for member in members
                if getattr(member, 'is_api_method', False)]