Completed
Pull Request — master (#858)
by Eddie
01:43
created

zipline.TradingAlgorithm.run()   B

Complexity

Conditions 2

Size

Total Lines 29

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 29
rs 8.8571
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pytz
18
import pandas as pd
19
from pandas.tseries.tools import normalize_date
20
import numpy as np
21
22
from datetime import datetime
23
from itertools import chain, repeat
24
from numbers import Integral
25
26
from six import (
27
    exec_,
28
    iteritems,
29
    itervalues,
30
    string_types,
31
)
32
33
from zipline.errors import (
34
    AttachPipelineAfterInitialize,
35
    HistoryInInitialize,
36
    NoSuchPipeline,
37
    OrderDuringInitialize,
38
    OverrideCommissionPostInit,
39
    OverrideSlippagePostInit,
40
    PipelineOutputDuringInitialize,
41
    RegisterAccountControlPostInit,
42
    RegisterTradingControlPostInit,
43
    SetBenchmarkOutsideInitialize,
44
    UnsupportedCommissionModel,
45
    UnsupportedDatetimeFormat,
46
    UnsupportedOrderParameters,
47
    UnsupportedSlippageModel,
48
)
49
from zipline.finance.trading import TradingEnvironment
50
from zipline.finance.blotter import Blotter
51
from zipline.finance.commission import PerShare, PerTrade, PerDollar
52
from zipline.finance.controls import (
53
    LongOnly,
54
    MaxOrderCount,
55
    MaxOrderSize,
56
    MaxPositionSize,
57
    MaxLeverage,
58
    RestrictedListOrder
59
)
60
from zipline.finance.execution import (
61
    LimitOrder,
62
    MarketOrder,
63
    StopLimitOrder,
64
    StopOrder,
65
)
66
from zipline.finance.performance import PerformanceTracker
67
from zipline.finance.slippage import (
68
    VolumeShareSlippage,
69
    SlippageModel
70
)
71
from zipline.assets import Asset, Future
72
from zipline.assets.futures import FutureChain
73
from zipline.gens.tradesimulation import AlgorithmSimulator
74
from zipline.pipeline.engine import (
75
    NoOpPipelineEngine,
76
    SimplePipelineEngine,
77
)
78
from zipline.utils.api_support import (
79
    api_method,
80
    require_initialized,
81
    require_not_initialized,
82
    ZiplineAPI,
83
)
84
from zipline.utils.input_validation import ensure_upper_case
85
from zipline.utils.cache import (
86
    CachedObject,
87
    Expired
88
)
89
import zipline.utils.events
90
from zipline.utils.events import (
91
    EventManager,
92
    make_eventrule,
93
    DateRuleFactory,
94
    TimeRuleFactory,
95
)
96
from zipline.utils.factory import create_simulation_parameters
97
from zipline.utils.math_utils import (
98
    tolerant_equals,
99
    round_if_near_integer
100
)
101
from zipline.utils.preprocess import preprocess
102
103
import zipline.protocol
104
from zipline.sources.requests_csv import PandasRequestsCSV
105
106
from zipline.gens.sim_engine import (
107
    MinuteSimulationClock,
108
    DailySimulationClock,
109
)
110
from zipline.sources.benchmark_source import BenchmarkSource
111
112
DEFAULT_CAPITAL_BASE = float("1.0e5")
113
114
115
class TradingAlgorithm(object):
116
    """
117
    Base class for trading algorithms. Inherit and overload
118
    initialize() and handle_data(data).
119
120
    A new algorithm could look like this:
121
    ```
122
    from zipline.api import order, symbol
123
124
    def initialize(context):
125
        context.sid = symbol('AAPL')
126
        context.amount = 100
127
128
    def handle_data(context, data):
129
        sid = context.sid
130
        amount = context.amount
131
        order(sid, amount)
132
    ```
133
    To then to run this algorithm pass these functions to
134
    TradingAlgorithm:
135
136
    my_algo = TradingAlgorithm(initialize, handle_data)
137
    stats = my_algo.run(data)
138
139
    """
140
141
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional function stored as the algorithm's
                before_trading_start hook; called with (context, data).
            analyze : function
                Optional function called with (context, perf) after run().
            script : str
                Algoscript that contains initialize and
                handle_data function definitions. May not be combined with
                the ``initialize``/``handle_data`` arguments.
            algo_filename : str
                Filename used when compiling ``script``
                (default ``'<string>'``).
            namespace : dict
                Globals dict in which ``script`` is executed.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            sim_params : SimulationParameters
                Used as-is when given; otherwise built from ``capital_base``,
                ``start`` and ``end``.
            env : TradingEnvironment
                Environment to use; a new TradingEnvironment is created when
                not supplied. ``self.asset_finder`` is taken from it (there
                is no separate ``asset_finder`` argument).
            blotter : Blotter
                Defaults to a Blotter using VolumeShareSlippage and PerShare
                commission.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            futures_metadata : same forms as equities_metadata
                Written to the environment as futures data.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm
            benchmark_sid : optional
                Sid of the benchmark asset.

        Any keyword arguments not consumed above are forwarded to the
        initialize function via ``self.initialize_kwargs``.

        :Raises:
            ValueError
                If ``script`` is combined with ``initialize``/``handle_data``,
                or if ``script`` does not define ``handle_data``.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.data_portal = None

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        self.perf_tracker = None
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if self.blotter is None:
            self.blotter = Blotter(
                slippage_func=VolumeShareSlippage(),
                commission=PerShare()
            )

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self._portfolio = None
        self._account = None

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        if self.algoscript is not None:
            # A script and explicit initialize/handle_data functions are
            # mutually exclusive. Check before executing the script so a
            # conflicting call fails fast instead of silently preferring the
            # script and forwarding the functions to initialize() as kwargs.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )

            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        # Consume benchmark_sid before capturing kwargs so it is not
        # forwarded to the user's initialize function.
        self.benchmark_sid = kwargs.pop('benchmark_sid', None)
        # Whatever remains is forwarded to initialize() by _create_generator.
        self.initialize_kwargs = kwargs
313
314
    def init_engine(self, get_loader):
        """
        Build the pipeline engine and store it on ``self.engine``.

        With a ``get_loader`` a SimplePipelineEngine is wired to the
        environment's trading days and this algorithm's asset finder;
        without one, a NoOpPipelineEngine stands in.
        """
        if get_loader is None:
            self.engine = NoOpPipelineEngine()
            return

        self.engine = SimplePipelineEngine(
            get_loader,
            self.trading_environment.trading_days,
            self.asset_finder,
        )
328
329
    def initialize(self, *args, **kwargs):
        """
        Invoke the user-supplied initialize function.

        The algorithm itself is passed first, followed by ``args`` and
        ``kwargs``. The call runs inside ``ZiplineAPI(self)`` so Zipline API
        functions used by the user code act on this algorithm.
        """
        api_context = ZiplineAPI(self)
        with api_context:
            self._initialize(self, *args, **kwargs)
336
337
    def before_trading_start(self, data):
338
        if self._before_trading_start is None:
339
            return
340
341
        self._before_trading_start(self, data)
342
343
    def handle_data(self, data):
344
        self._handle_data(self, data)
345
346
        # Unlike trading controls which remain constant unless placing an
347
        # order, account controls can change each bar. Thus, must check
348
        # every bar no matter if the algorithm places an order or not.
349
        self.validate_account_controls()
350
351
    def analyze(self, perf):
352
        if self._analyze is None:
353
            return
354
355
        with ZiplineAPI(self):
356
            self._analyze(self, perf)
357
358
    def __repr__(self):
359
        """
360
        N.B. this does not yet represent a string that can be used
361
        to instantiate an exact copy of an algorithm.
362
363
        However, it is getting close, and provides some value as something
364
        that can be inspected interactively.
365
        """
366
        return """
367
{class_name}(
368
    capital_base={capital_base}
369
    sim_params={sim_params},
370
    initialized={initialized},
371
    slippage={slippage},
372
    commission={commission},
373
    blotter={blotter},
374
    recorded_vars={recorded_vars})
375
""".strip().format(class_name=self.__class__.__name__,
376
                   capital_base=self.capital_base,
377
                   sim_params=repr(self.sim_params),
378
                   initialized=self.initialized,
379
                   slippage=repr(self.blotter.slippage_func),
380
                   commission=repr(self.blotter.commission),
381
                   blotter=repr(self.blotter),
382
                   recorded_vars=repr(self.recorded_vars))
383
384
    def _create_clock(self):
        """
        Build the simulation clock for ``sim_params.data_frequency``.

        In 'minute' mode a MinuteSimulationClock is built from the market
        opens/closes of ``sim_params.trading_days`` and the data portal's
        offset cache is primed from it. Any other frequency gets a
        DailySimulationClock over ``sim_params.trading_days``.
        """
        if self.sim_params.data_frequency == 'minute':
            env = self.trading_environment
            # Label-based selection of the simulated sessions. ``.loc`` is
            # the supported replacement for the deprecated ``.ix`` indexer
            # (removed in pandas 1.0), which did a label lookup here because
            # trading_days presumably holds Timestamp labels of this index.
            trading_o_and_c = env.open_and_closes.loc[
                self.sim_params.trading_days]
            market_opens = trading_o_and_c['market_open'].values.astype(
                'datetime64[ns]').astype(np.int64)
            market_closes = trading_o_and_c['market_close'].values.astype(
                'datetime64[ns]').astype(np.int64)

            minutely_emission = self.sim_params.emission_rate == "minute"

            clock = MinuteSimulationClock(
                self.sim_params.trading_days,
                market_opens,
                market_closes,
                env.trading_days,
                minutely_emission
            )
            self.data_portal.setup_offset_cache(
                clock.minutes_by_day,
                clock.minutes_to_day)
            return clock
        else:
            return DailySimulationClock(self.sim_params.trading_days)
412
413
    def _create_benchmark_source(self):
        """
        Construct the BenchmarkSource for this run's trading days, emitting
        at the simulation's emission rate.
        """
        params = self.sim_params
        return BenchmarkSource(
            self.benchmark_sid,
            self.trading_environment,
            params.trading_days,
            self.data_portal,
            emission_rate=params.emission_rate,
        )
421
422
    def _create_generator(self, sim_params):
        """
        Prepare a simulation and return its generator.

        Parameters
        ----------
        sim_params : SimulationParameters or None
            Replaces ``self.sim_params`` when not None; otherwise the stored
            parameters are used.

        Creates a PerformanceTracker when none exists (``run`` clears it so
        repeated runs start fresh), calls initialize() once, and returns the
        ``transform()`` generator of a new AlgorithmSimulator.
        """
        if sim_params is not None:
            self.sim_params = sim_params

        if self.perf_tracker is None:
            # HACK: When running with the `run` method, we set perf_tracker to
            # None so that it will be overwritten here.
            self.perf_tracker = PerformanceTracker(
                sim_params=self.sim_params,
                env=self.trading_environment,
                data_portal=self.data_portal
            )

            # Set the dt initially to the period start by forcing it to change.
            self.on_dt_changed(self.sim_params.period_start)

        if not self.initialized:
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
            self.initialized = True

        # Pass self.sim_params (not the argument): the argument may be None,
        # and the clock and benchmark source below are built from
        # self.sim_params, so the simulator must see the same parameters.
        self.trading_client = AlgorithmSimulator(
            self,
            self.sim_params,
            self.data_portal,
            self._create_clock(),
            self._create_benchmark_source()
        )

        return self.trading_client.transform()
451
452
    def get_generator(self):
453
        """
454
        Override this method to add new logic to the construction
455
        of the generator. Overrides can use the _create_generator
456
        method to get a standard construction generator.
457
        """
458
        return self._create_generator(self.sim_params)
459
460
    def run(self, data_portal=None):
461
        """Run the algorithm.
462
463
        :Arguments:
464
            source : DataPortal
465
466
        :Returns:
467
            daily_stats : pandas.DataFrame
468
              Daily performance metrics such as returns, alpha etc.
469
470
        """
471
        self.data_portal = data_portal
472
473
        # Force a reset of the performance tracker, in case
474
        # this is a repeat run of the algorithm.
475
        self.perf_tracker = None
476
477
        # Create zipline and loop through simulated_trading.
478
        # Each iteration returns a perf dictionary
479
        perfs = []
480
        for perf in self.get_generator():
481
            perfs.append(perf)
482
483
        # convert perf dict to pandas dataframe
484
        daily_stats = self._create_daily_stats(perfs)
485
486
        self.analyze(daily_stats)
487
488
        return daily_stats
489
490
    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Register unknown identifiers as equities, then map ``identifiers``
        to sids as of ``as_of_date``.

        Assets and integer sids already known to the asset finder are left
        alone; every other identifier is written to the trading environment
        as a new equity identifier before the mapping is performed.
        """
        finder = self.asset_finder

        def _existing_asset(identifier):
            # Only Assets and integer sids can be resolved directly; any
            # other identifier is treated as unknown.
            if isinstance(identifier, Asset):
                return finder.retrieve_asset(sid=identifier.sid,
                                             default_none=True)
            if isinstance(identifier, Integral):
                return finder.retrieve_asset(sid=identifier,
                                             default_none=True)
            return None

        identifiers_to_build = [
            identifier for identifier in identifiers
            if _existing_asset(identifier) is None
        ]

        self.trading_environment.write_data(
            equities_identifiers=identifiers_to_build)

        # The lookups above may have cached misses for identifiers that were
        # just written; clear them so the mapping below sees the new assets.
        # The real fix is to not construct an AssetFinder until `run()`,
        # when all the data needed to do so is available.
        self.asset_finder._reset_caches()

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )
518
519
    def _create_daily_stats(self, perfs):
520
        # create daily and cumulative stats dataframe
521
        daily_perfs = []
522
        # TODO: the loop here could overwrite expected properties
523
        # of daily_perf. Could potentially raise or log a
524
        # warning.
525
        for perf in perfs:
526
            if 'daily_perf' in perf:
527
528
                perf['daily_perf'].update(
529
                    perf['daily_perf'].pop('recorded_vars')
530
                )
531
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
532
                daily_perfs.append(perf['daily_perf'])
533
            else:
534
                self.risk_report = perf
535
536
        daily_dts = [np.datetime64(perf['period_close'], utc=True)
537
                     for perf in daily_perfs]
538
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)
539
540
        return daily_stats
541
542
    @api_method
    def get_environment(self, field='platform'):
        """
        Query settings of the current run.

        ``field`` is one of 'arena', 'data_frequency', 'start', 'end',
        'capital_base' or 'platform'; '*' returns all of them as a dict.
        Any other value raises KeyError.
        """
        params = self.sim_params
        env = dict(
            arena=params.arena,
            data_frequency=params.data_frequency,
            start=params.first_open,
            end=params.last_close,
            capital_base=params.capital_base,
            platform=self._platform,
        )
        return env if field == '*' else env[field]
556
557
    @api_method
    def fetch_csv(self, url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """
        Load external CSV data from ``url`` and register it with the data
        portal.

        The data is read through a PandasRequestsCSV source bounded by the
        simulation's period start/end; its ``df`` is handed to
        ``self.data_portal.handle_extra_source``. Returns the source.
        """
        params = self.sim_params
        source = PandasRequestsCSV(
            url,
            pre_func,
            post_func,
            self.trading_environment,
            params.period_start,
            params.period_end,
            date_column,
            date_format,
            timezone,
            symbol,
            mask,
            symbol_column,
            data_frequency=self.data_frequency,
            special_params_checker=special_params_checker,
            **kwargs
        )

        # Make the fetched frame available to the data portal.
        self.data_portal.handle_extra_source(source.df)
        return source
593
594
    def add_event(self, rule=None, callback=None):
        """
        Register ``callback`` with the algorithm's EventManager, to be
        triggered according to ``rule``.
        """
        event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(event)
601
602
    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to be called according to a date and a time rule.

        ``date_rule`` defaults to every day. With minute data ``time_rule``
        defaults to market open; with any other data frequency the time
        rule is ignored and ``Always()`` is used. ``half_days`` is passed
        through to ``make_eventrule``.
        """
        if not date_rule:
            date_rule = DateRuleFactory.every_day()

        if self.sim_params.data_frequency != 'minute':
            # Daily bars have no intraday time, so the time rule is ignored.
            time_rule = zipline.utils.events.Always()
        elif not time_rule:
            time_rule = TimeRuleFactory.market_open()

        rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(rule, func)
621
622
    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record values each day.

        Positional arguments are read as alternating name/value pairs, e.g.
        ``record('a', 1, 'b', 2)``; an unpaired trailing positional is
        ignored. Keyword arguments are recorded under their keyword, after
        the positional pairs.
        """
        names = args[::2]
        values = args[1::2]
        for name, value in chain(zip(names, values), iteritems(kwargs)):
            self._recorded_vars[name] = value
637
638
    @api_method
    def set_benchmark(self, benchmark_sid):
        """
        Set the sid used as the benchmark.

        Allowed only while ``self.initialized`` is False (i.e. before or
        during initialize); afterwards SetBenchmarkOutsideInitialize is
        raised.
        """
        if not self.initialized:
            self.benchmark_sid = benchmark_sid
            return
        raise SetBenchmarkOutsideInitialize()
644
645
    @api_method
    @preprocess(symbol_str=ensure_upper_case)
    def symbol(self, symbol_str):
        """
        Look up the Asset for ``symbol_str`` via the asset finder.

        The lookup is as of the date set with set_symbol_lookup_date(), or
        ``sim_params.period_end`` when no lookup date has been set.
        """
        lookup_date = self._symbol_lookup_date
        if lookup_date is None:
            lookup_date = self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=lookup_date,
        )
661
662
    @api_method
    def symbols(self, *args):
        """
        Look up several symbols with symbol(); returns a list of Assets in
        the same order as ``args``.
        """
        return list(map(self.symbol, args))
669
670
    @api_method
    def sid(self, a_sid):
        """
        Return the Asset for the integer sid ``a_sid``, as resolved by the
        asset finder.
        """
        finder = self.asset_finder
        return finder.retrieve_asset(a_sid)
677
678
    @api_method
    @preprocess(symbol=ensure_upper_case)
    def future_symbol(self, symbol):
        """Look up a futures contract by its symbol.

        Parameters
        ----------
        symbol : str
            Symbol of the contract; normalized by ``ensure_upper_case``
            before the lookup.

        Returns
        -------
        Future
            The matching contract.

        Raises
        ------
        SymbolNotFound
            When no contract named ``symbol`` exists.
        """
        finder = self.asset_finder
        return finder.lookup_future_symbol(symbol)
700
701
    @api_method
    @preprocess(root_symbol=ensure_upper_case)
    def future_chain(self, root_symbol, as_of_date=None):
        """ Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc. Naive values are interpreted as UTC;
            timezone-aware values are converted to UTC.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        UnsupportedDatetimeFormat
            If ``as_of_date`` cannot be interpreted as a timestamp.
        """
        if as_of_date:
            try:
                parsed = pd.Timestamp(as_of_date)
            except (ValueError, TypeError):
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')
            # pd.Timestamp(value, tz='UTC') is rejected by newer pandas when
            # value is already tz-aware, which previously surfaced as a
            # misleading UnsupportedDatetimeFormat. Localize naive values and
            # convert aware ones instead.
            if parsed.tzinfo is None:
                as_of_date = parsed.tz_localize('UTC')
            else:
                as_of_date = parsed.tz_convert('UTC')
        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=as_of_date
        )
737
738
    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a target ``value`` into a share/contract count for ``asset``.

        The result is ``value / (last_price * multiplier)``, where the
        multiplier is ``asset.contract_multiplier`` for Futures and 1
        otherwise. When the last price is (tolerantly) zero, 0 is returned
        and a debug message is logged if a logger is configured.
        """
        last_price = self.trading_client.current_data[asset].price

        if tolerant_equals(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Don't place any order
            return 0

        multiplier = (asset.contract_multiplier
                      if isinstance(asset, Future) else 1)
        return value / (last_price * multiplier)
760
761
    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` shares/contracts of ``sid``.

        ``amount`` is snapped to a nearby integer when within .0001 of one
        and otherwise truncated toward zero (3.9999 -> 4; 5.5 -> 5;
        -5.5 -> -5). Pass either an ExecutionStyle via ``style`` or the
        legacy ``limit_price``/``stop_price`` arguments, not both. Returns
        the result of ``self.blotter.order``.
        """
        share_count = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid, share_count,
                                   limit_price, stop_price, style)

        # Translate the deprecated limit/stop arguments into the
        # ExecutionStyle object the blotter expects.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style)
        return self.blotter.order(sid, share_count, execution_style)
787
788
    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Validate the arguments of an order() call.

        Raises OrderDuringInitialize if the algorithm is not yet initialized,
        and UnsupportedOrderParameters if ``style`` is combined with a limit
        or stop price or if ``asset`` is not an Asset. Finally every
        registered trading control validates the order.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        # Portfolio and datetime are re-read per control, as each control
        # receives the current state at the moment it validates.
        for control in self.trading_controls:
            control.validate(asset,
                             amount,
                             self.portfolio,
                             self.get_datetime(),
                             self.trading_client.current_data)
828
829
    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Map the legacy ``limit_price``/``stop_price`` arguments to an
        ExecutionStyle.

        An explicit ``style`` is returned unchanged; this function assumes
        both price arguments are then None. Otherwise both prices give a
        StopLimitOrder, only a limit price a LimitOrder, only a stop price a
        StopOrder, and neither a MarketOrder.
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)
        return MarketOrder()
850
851
    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.

        If the requested sid exists, the requested value is divided by its
        price to imply the number of shares to transact (the conversion is
        performed by ``_calculate_order_value_amount``), and the resulting
        amount is forwarded to ``order``.
        If the Asset being ordered is a Future, the 'value' calculated
        is actually the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short

        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)

        Returns
        -------
        Whatever ``order`` returns for the computed amount.
        """
        amount = self._calculate_order_value_amount(sid, value)
        return self.order(sid, amount,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @property
    def recorded_vars(self):
        """
        A shallow copy of ``self._recorded_vars``.

        A copy is returned so callers cannot mutate the algorithm's internal
        mapping through the returned object.
        """
        snapshot = copy(self._recorded_vars)
        return snapshot

    @property
    def portfolio(self):
        """
        The current portfolio, as produced by ``updated_portfolio()``
        (cached until the next ``on_dt_changed`` resets it).
        """
        current_portfolio = self.updated_portfolio()
        return current_portfolio

    def updated_portfolio(self):
        """
        Return the cached portfolio, populating the cache from the
        performance tracker when it is empty.

        If there is no performance tracker, the cached value (possibly None)
        is returned unchanged.
        """
        if self._portfolio is None:
            tracker = self.perf_tracker
            if tracker is not None:
                self._portfolio = tracker.get_portfolio(self.datetime)
        return self._portfolio

    @property
    def account(self):
        """
        The current account, as produced by ``updated_account()``
        (cached until the next ``on_dt_changed`` resets it).
        """
        current_account = self.updated_account()
        return current_account

    def updated_account(self):
        """
        Return the cached account, populating the cache from the performance
        tracker when it is empty.

        If there is no performance tracker, the cached value (possibly None)
        is returned unchanged.
        """
        if self._account is None:
            tracker = self.perf_tracker
            if tracker is not None:
                self._account = tracker.get_account(self.datetime)
        return self._account

    def set_logger(self, logger):
        """Store ``logger`` on the algorithm as ``self.logger``."""
        setattr(self, 'logger', logger)

    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each datetime
        group should happen here: the new dt is propagated to the performance
        tracker and the blotter, and the cached portfolio/account are cleared
        so they are re-fetched for the new dt on next access.

        Parameters
        ----------
        dt : datetime.datetime
            The new simulation time; must be tz-aware UTC (asserted).
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

        # Invalidate the caches used by updated_portfolio/updated_account.
        self._portfolio = self._account = None

    @api_method
    def get_datetime(self, tz=None):
        """
        Returns the simulation datetime.

        Parameters
        ----------
        tz : tzinfo or str, optional
            Timezone to convert the result into. Strings are resolved with
            ``pytz.timezone``. When omitted, the UTC datetime is returned.
        """
        current = self.datetime
        assert current.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is None:
            # datetime.datetime objects are immutable, so no copy is needed.
            return current

        if isinstance(tz, string_types):
            tz = pytz.timezone(tz)
        return current.astimezone(tz)

    def update_dividends(self, dividend_frame):
        """
        Hand the DataFrame used to process dividends to the performance
        tracker. DataFrame columns should contain at least the entries in
        zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)

    @api_method
    def set_slippage(self, slippage):
        """
        Install ``slippage`` as the blotter's slippage model.

        Raises
        ------
        UnsupportedSlippageModel
            If ``slippage`` is not a SlippageModel instance (checked first).
        OverrideSlippagePostInit
            If the algorithm has already been initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()
        self.blotter.slippage_func = slippage

    @api_method
    def set_commission(self, commission):
        """
        Install ``commission`` as the blotter's commission model.

        Raises
        ------
        UnsupportedCommissionModel
            If ``commission`` is not a PerShare, PerTrade or PerDollar
            instance (checked first).
        OverrideCommissionPostInit
            If the algorithm has already been initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()
        self.blotter.commission = commission

    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date for which symbols will be resolved to their sids
        (symbols may map to different firms or underlying assets at
        different times).

        Raises
        ------
        UnsupportedDatetimeFormat
            If ``pd.Timestamp(dt, tz='UTC')`` raises ValueError.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date

    # Retained for backwards compatibility: the value is stored on
    # ``sim_params`` and merely proxied through this property.
    @property
    def data_frequency(self):
        """The simulation's data frequency, read from ``sim_params``."""
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only 'daily' and 'minute' are accepted (enforced by assertion).
        assert value in ('daily', 'minute')
        params = self.sim_params
        params.data_frequency = value

    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).

        The value ``portfolio.portfolio_value * percent`` is forwarded to
        ``order_value`` together with the execution parameters, and its
        result is returned.
        """
        # The previous docstring contained the invalid escape sequence "\%",
        # which Python flags with a DeprecationWarning/SyntaxWarning.
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)

    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        With no existing position in ``sid`` this orders ``target`` shares
        outright; otherwise it orders the difference between ``target`` and
        the currently held amount. Execution parameters are passed through
        to ``order`` unchanged.
        """
        positions = self.portfolio.positions
        if sid in positions:
            req_shares = target - positions[sid].amount
        else:
            req_shares = target
        return self.order(sid, req_shares,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)

    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        The target value is converted into a target share amount by
        ``_calculate_order_value_amount`` and then handed to
        ``order_target``, which orders either the full amount (no existing
        position) or the difference from the current holding.
        If the Asset being ordered is a Future, the 'target value' calculated
        is actually the target exposure, as Futures have no 'value'.
        """
        return self.order_target(
            sid,
            self._calculate_order_value_amount(sid, target),
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )

    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50%).

        The value ``portfolio.portfolio_value * target`` is forwarded to
        ``order_target_value`` together with the execution parameters, and
        its result is returned.
        """
        # The previous docstring contained the invalid escape sequence "\%",
        # which Python flags with a DeprecationWarning/SyntaxWarning.
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)

    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders as API objects.

        With ``sid`` None, returns a dict mapping each key of the blotter's
        ``open_orders`` to its list of API orders, skipping keys whose order
        list is empty. With a ``sid``, returns the list of API orders for
        that sid (an empty list if it has none).
        """
        open_orders = self.blotter.open_orders
        if sid is not None:
            if sid not in open_orders:
                return []
            return [o.to_api_obj() for o in open_orders[sid]]

        return dict(
            (key, [o.to_api_obj() for o in orders])
            for key, orders in iteritems(open_orders)
            if orders
        )

    @api_method
    def get_order(self, order_id):
        """
        Return the API object for the blotter order ``order_id``, or None if
        the blotter has no such order.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()

    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        ``order_param`` may be a ``zipline.protocol.Order`` (its ``id`` is
        used) or an order id passed through as-is.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)

    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, sids, bar_count, frequency, field, ffill=True):
        """
        Return a history window from the data portal, ending at the current
        simulation datetime.

        Raises
        ------
        Exception
            If no data portal is configured.
        """
        portal = self.data_portal
        if portal is None:
            raise Exception("no data portal!")

        return portal.get_history_window(
            sids, self.datetime, bar_count, frequency, field, ffill,
        )

    ####################
    # Account Controls #
    ####################

    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.

        Raises
        ------
        RegisterAccountControlPostInit
            If the algorithm has already been initialized.
        """
        if not self.initialized:
            self.account_controls.append(control)
            return
        raise RegisterAccountControlPostInit()

    def validate_account_controls(self):
        """
        Run every registered account control against the current portfolio,
        account, datetime and trading-client data.

        The inputs are re-read for each control. Exceptions raised by a
        control propagate to the caller.
        """
        for account_control in self.account_controls:
            account_control.validate(
                self.portfolio,
                self.account,
                self.get_datetime(),
                self.trading_client.current_data,
            )

    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Set a limit on the maximum leverage of the algorithm.

        Registers a MaxLeverage account control, so it must be called before
        initialization has completed.
        """
        self.register_account_control(MaxLeverage(max_leverage))

    ####################
    # Trading Controls #
    ####################

    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.

        Raises
        ------
        RegisterTradingControlPostInit
            If the algorithm has already been initialized.
        """
        if not self.initialized:
            self.trading_controls.append(control)
            return
        raise RegisterTradingControlPostInit()

    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid.

        Limits are treated as absolute values and are enforced at the time
        the algo attempts to place an order for sid. It is therefore possible
        to end up with more than the max number of shares due to
        splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would increase the
        absolute value of shares/dollar value beyond one of these limits, a
        TradingControlException is raised.
        """
        self.register_trading_control(
            MaxPositionSize(asset=sid,
                            max_shares=max_shares,
                            max_notional=max_notional)
        )

    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.

        Limits are treated as absolute values and are enforced at the time
        the algo attempts to place an order for sid. An order that would
        exceed one of these limits raises a TradingControlException.
        """
        self.register_trading_control(
            MaxOrderSize(asset=sid,
                         max_shares=max_shares,
                         max_notional=max_notional)
        )

    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed within a time
        interval, by registering a MaxOrderCount trading control.

        NOTE(review): the interval is not a parameter here; its length is
        determined inside MaxOrderCount. Confirm there.
        """
        self.register_trading_control(MaxOrderCount(max_count))

    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Set a restriction on which sids can be ordered, by registering a
        RestrictedListOrder trading control built from ``restricted_list``.
        """
        self.register_trading_control(RestrictedListOrder(restricted_list))

    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions
        (registers a LongOnly trading control).
        """
        long_only = LongOnly()
        self.register_trading_control(long_only)

    ##############
    # Pipeline API
    ##############
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register a pipeline to be computed at the start of each day.

        Parameters
        ----------
        pipeline : Pipeline
            The pipeline to register.
        name : str
            Key under which the pipeline is stored.
        chunksize : int, optional
            Number of days to compute per chunk. When omitted, the first
            chunk is 5 days and every subsequent chunk is 126 days.

        Returns
        -------
        The ``pipeline`` argument, so that expressions like
        ``p = attach_pipeline(Pipeline(), 'name')`` work.

        Raises
        ------
        NotImplementedError
            If a pipeline has already been attached.
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")

        if chunksize is not None:
            chunks = repeat(int(chunksize))
        else:
            # Make the first chunk smaller to get more immediate results:
            # (one week, then every half year)
            chunks = chain([5], repeat(126))

        self._pipelines[name] = pipeline, chunks
        return pipeline

    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of pipeline with name `name`.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: We don't currently support multiple pipelines, but we plan to
        # in the future.
        registered = self._pipelines
        try:
            pipeline, chunks = registered[name]
        except KeyError:
            raise NoSuchPipeline(
                name=name,
                valid=list(registered.keys()),
            )
        return self._pipeline_output(pipeline, chunks)

    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.

        Serves today's rows from ``self._pipeline_cache``. When the cache has
        expired, the next chunk of results is computed with
        ``_run_pipeline`` using the next size drawn from ``chunks``, and the
        cache is replaced with the new results.

        Returns
        -------
        pd.DataFrame
            ``data.loc[today]``, or an empty frame with the same columns if
            today has no rows.
        """
        today = normalize_date(self.get_datetime())
        try:
            results = self._pipeline_cache.unwrap(today)
        except Expired:
            results, expires_on = self._run_pipeline(
                pipeline, today, next(chunks),
            )
            self._pipeline_cache = CachedObject(results, expires_on)

        # With a cached result in hand, slice out today's rows.
        try:
            return results.loc[today]
        except KeyError:
            # Raised when no assets passed the pipeline screen on this day.
            return pd.DataFrame(index=[], columns=results.columns)

    def _run_pipeline(self, pipeline, start_date, chunksize):
        """
        Compute `pipeline`, providing values for at least `start_date`.

        Produces a DataFrame containing data for days between `start_date` and
        `end_date`, where `end_date` is defined by:

            `end_date = min(start_date + chunksize trading days,
                            simulation_end)`

        Parameters
        ----------
        pipeline : Pipeline
            The pipeline to compute.
        start_date : pd.Timestamp
            First day to compute. It is looked up with
            ``trading_days.get_loc``, so it presumably must be a trading day.
        chunksize : int
            Number of trading days beyond `start_date` to include, capped at
            the simulation's last trading day.

        Returns
        -------
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)
            ``valid_until`` is the `end_date` passed to the engine.

        See Also
        --------
        PipelineEngine.run_pipeline
        """
        days = self.trading_environment.trading_days

        # Load data starting from start_date itself...
        start_date_loc = days.get_loc(start_date)

        # ...continuing until either the simulation's last trading day or
        # until `chunksize` further days of data have been loaded, whichever
        # comes first.
        sim_end = self.sim_params.last_close.normalize()
        end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))
        end_date = days[end_loc]

        return \
            self.engine.run_pipeline(pipeline, start_date, end_date), end_date

    ##################
    # End Pipeline API
    ##################

    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        Only attributes defined directly on ``cls`` are scanned (via
        ``vars(cls)``); an attribute counts as an API method when its
        ``is_api_method`` attribute is truthy.
        """
        api_methods = []
        for member in itervalues(vars(cls)):
            if getattr(member, 'is_api_method', False):
                api_methods.append(member)
        return api_methods
1325