Completed
Pull Request — master (#858)
by Eddie
10:07 queued 01:13
created

zipline.TradingAlgorithm.get_environment()   A

Complexity

Conditions 2

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 14
rs 9.4286
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pytz
18
import pandas as pd
19
from pandas.tseries.tools import normalize_date
20
import numpy as np
21
22
from datetime import datetime
23
from itertools import chain, repeat
24
from numbers import Integral
25
26
from six import (
27
    exec_,
28
    iteritems,
29
    itervalues,
30
    string_types,
31
)
32
33
34
from zipline.errors import (
35
    AttachPipelineAfterInitialize,
36
    HistoryInInitialize,
37
    NoSuchPipeline,
38
    OrderDuringInitialize,
39
    PipelineOutputDuringInitialize,
40
    RegisterAccountControlPostInit,
41
    RegisterTradingControlPostInit,
42
    SetBenchmarkOutsideInitialize,
43
    SetCommissionPostInit,
44
    SetSlippagePostInit,
45
    UnsupportedCommissionModel,
46
    UnsupportedDatetimeFormat,
47
    UnsupportedOrderParameters,
48
    UnsupportedSlippageModel,
49
)
50
from zipline.finance.trading import TradingEnvironment
51
from zipline.finance.blotter import Blotter
52
from zipline.finance.commission import PerShare, PerTrade, PerDollar
53
from zipline.finance.controls import (
54
    LongOnly,
55
    MaxOrderCount,
56
    MaxOrderSize,
57
    MaxPositionSize,
58
    MaxLeverage,
59
    RestrictedListOrder
60
)
61
from zipline.finance.execution import (
62
    LimitOrder,
63
    MarketOrder,
64
    StopLimitOrder,
65
    StopOrder,
66
)
67
from zipline.finance.performance import PerformanceTracker
68
from zipline.finance.slippage import (
69
    VolumeShareSlippage,
70
    SlippageModel
71
)
72
from zipline.assets import Asset, Future
73
from zipline.assets.futures import FutureChain
74
from zipline.gens.tradesimulation import AlgorithmSimulator
75
from zipline.pipeline.engine import (
76
    NoOpPipelineEngine,
77
    SimplePipelineEngine,
78
)
79
from zipline.utils.api_support import (
80
    api_method,
81
    require_initialized,
82
    require_not_initialized,
83
    ZiplineAPI,
84
)
85
from zipline.utils.input_validation import ensure_upper_case
86
from zipline.utils.cache import CachedObject, Expired
87
import zipline.utils.events
88
from zipline.utils.events import (
89
    EventManager,
90
    make_eventrule,
91
    DateRuleFactory,
92
    TimeRuleFactory,
93
)
94
from zipline.utils.factory import create_simulation_parameters
95
from zipline.utils.math_utils import (
96
    tolerant_equals,
97
    round_if_near_integer
98
)
99
from zipline.utils.preprocess import preprocess
100
101
import zipline.protocol
102
from zipline.sources.requests_csv import PandasRequestsCSV
103
104
from zipline.gens.sim_engine import (
105
    MinuteSimulationClock,
106
    DailySimulationClock,
107
)
108
from zipline.sources.benchmark_source import BenchmarkSource
109
110
# Starting capital used by TradingAlgorithm when no ``capital_base`` keyword
# argument is supplied.
DEFAULT_CAPITAL_BASE = 1.0e5
111
112
113
class TradingAlgorithm(object):
114
    """
115
    Base class for trading algorithms. Inherit and overload
116
    initialize() and handle_data(data).
117
118
    A new algorithm could look like this:
119
    ```
120
    from zipline.api import order, symbol
121
122
    def initialize(context):
123
        context.sid = symbol('AAPL')
124
        context.amount = 100
125
126
    def handle_data(context, data):
127
        sid = context.sid
128
        amount = context.amount
129
        order(sid, amount)
130
    ```
131
    To run this algorithm, pass these functions to
132
    TradingAlgorithm:
133
134
    my_algo = TradingAlgorithm(initialize, handle_data)
135
    stats = my_algo.run(data)
136
137
    """
138
139
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional; called as before_trading_start(context, data).
                Only read when initialize and handle_data are passed.
            analyze : function
                Optional; called as analyze(context, perf) after run().
                Only read when initialize and handle_data are passed.
            script : str
                Algoscript that contains initialize and
                handle_data function definition. Cannot be combined with
                the initialize/handle_data arguments.
            algo_filename : str <default: '<string>'>
                Filename used when compiling ``script``.
            namespace : dict <default: {}>
                Globals dict in which ``script`` is executed.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            sim_params : SimulationParameters
                If omitted, built from capital_base, start and end.
            start, end
                Simulation bounds; only consumed when sim_params is omitted.
            env : TradingEnvironment
                Environment to use (a new TradingEnvironment() if omitted).
                The algorithm's asset_finder is taken from this environment.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            futures_metadata
                Futures metadata written to the TradingEnvironment.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm
            blotter : Blotter
                Defaults to a Blotter using VolumeShareSlippage and PerShare
                commission.
            get_pipeline_loader : callable
                Passed to init_engine; None selects a NoOpPipelineEngine.
            create_event_context
                Passed to the EventManager as ``create_context``.
            benchmark_sid
                Benchmark sid handed to the BenchmarkSource.
            platform : str <default: 'zipline'>
                Value reported by get_environment('platform').

        Any remaining positional and keyword arguments are forwarded to
        initialize() when the algorithm is first run.

        :Raises:
            ValueError
                If ``script`` does not define handle_data, or if ``script``
                is combined with initialize/handle_data.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.data_portal = None

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        # Alternative way of setting data_frequency for backwards
        # compatibility. This must run before the default Blotter below is
        # built, because the Blotter reads self.data_frequency. When it ran
        # afterwards, the Blotter never saw the requested frequency.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        self.perf_tracker = None
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter(
                slippage_func=VolumeShareSlippage(),
                commission=PerShare(),
                data_frequency=self.data_frequency,
            )

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        if self.algoscript is not None:
            # A script and explicit callables are mutually exclusive. This
            # check used to sit in the `elif` branch below, where it could
            # never fire because that branch only runs when script is None.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )

            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Pop benchmark_sid before capturing the leftover kwargs so it is
        # clearly not forwarded to initialize().
        self.benchmark_sid = kwargs.pop('benchmark_sid', None)

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs
315
316
    def init_engine(self, get_loader):
        """
        Create the Pipeline API engine and store it as ``self.engine``.

        With no loader (``get_loader`` is None) a NoOpPipelineEngine is used.
        Otherwise a SimplePipelineEngine is built from ``get_loader``, the
        trading environment's trading days and ``self.asset_finder``.
        """
        if get_loader is None:
            self.engine = NoOpPipelineEngine()
            return

        calendar = self.trading_environment.trading_days
        self.engine = SimplePipelineEngine(
            get_loader,
            calendar,
            self.asset_finder,
        )
330
331
    def initialize(self, *args, **kwargs):
        """
        Invoke the user's initialize function as
        ``self._initialize(self, *args, **kwargs)``.

        ``self`` is installed as the active algorithm (via ZiplineAPI) for
        the duration of the call, so Zipline API functions resolve to it.
        """
        api_scope = ZiplineAPI(self)
        with api_scope:
            self._initialize(self, *args, **kwargs)
338
339
    def before_trading_start(self, data):
        """
        Call the user's before_trading_start function, if one was provided,
        as ``callback(self, data)``. Does nothing otherwise.
        """
        callback = self._before_trading_start
        if callback is not None:
            callback(self, data)
344
345
    def handle_data(self, data):
        """
        Run the user's handle_data function for the current bar, then check
        the account controls.
        """
        user_handler = self._handle_data
        user_handler(self, data)

        # Trading controls only matter when an order is placed, but account
        # controls can change each bar, so they are checked on every bar
        # whether or not the algorithm ordered anything.
        self.validate_account_controls()
352
353
    def analyze(self, perf):
        """
        Pass the final performance data to the user's analyze function.

        Skipped entirely when no analyze function was supplied. Otherwise it
        is called as ``self._analyze(self, perf)`` with ``self`` installed as
        the active algorithm via ZiplineAPI.
        """
        if self._analyze is not None:
            with ZiplineAPI(self):
                self._analyze(self, perf)
359
360
    def __repr__(self):
361
        """
362
        N.B. this does not yet represent a string that can be used
363
        to instantiate an exact copy of an algorithm.
364
365
        However, it is getting close, and provides some value as something
366
        that can be inspected interactively.
367
        """
368
        return """
369
{class_name}(
370
    capital_base={capital_base}
371
    sim_params={sim_params},
372
    initialized={initialized},
373
    slippage={slippage},
374
    commission={commission},
375
    blotter={blotter},
376
    recorded_vars={recorded_vars})
377
""".strip().format(class_name=self.__class__.__name__,
378
                   capital_base=self.capital_base,
379
                   sim_params=repr(self.sim_params),
380
                   initialized=self.initialized,
381
                   slippage=repr(self.blotter.slippage_func),
382
                   commission=repr(self.blotter.commission),
383
                   blotter=repr(self.blotter),
384
                   recorded_vars=repr(self.recorded_vars))
385
386
    def _create_clock(self):
        """
        Build the simulation clock for ``self.sim_params.data_frequency``.

        Minute data uses a MinuteSimulationClock fed with the market open
        and close times of each simulated day, as int64 nanoseconds since
        the epoch. Any other frequency uses a DailySimulationClock.
        """
        sim_params = self.sim_params
        if sim_params.data_frequency != 'minute':
            return DailySimulationClock(sim_params.trading_days)

        env = self.trading_environment
        session_times = env.open_and_closes.ix[sim_params.trading_days]

        def _as_int_nanos(column):
            # datetime column -> datetime64[ns] -> int64 nanoseconds.
            return session_times[column].values.astype(
                'datetime64[ns]').astype(np.int64)

        return MinuteSimulationClock(
            sim_params.trading_days,
            _as_int_nanos('market_open'),
            _as_int_nanos('market_close'),
            env.trading_days,
            sim_params.emission_rate == "minute",
        )
411
412
    def _create_benchmark_source(self):
        """
        Build the BenchmarkSource for this run.

        It is built from the benchmark sid, the trading environment, the
        simulated trading days, the data portal and the emission rate taken
        from ``self.sim_params``.
        """
        params = self.sim_params
        return BenchmarkSource(
            self.benchmark_sid,
            self.trading_environment,
            params.trading_days,
            self.data_portal,
            emission_rate=params.emission_rate,
        )
420
421
    def _create_generator(self, sim_params):
        """
        Build the generator that drives the simulation.

        Parameters
        ----------
        sim_params : SimulationParameters or None
            When not None, replaces ``self.sim_params``. When None, the
            parameters already stored on the algorithm are used.

        Returns
        -------
        The result of ``self.trading_client.transform()``, which ``run``
        iterates to collect perf messages.
        """
        if sim_params is not None:
            self.sim_params = sim_params

        if self.perf_tracker is None:
            # HACK: When running with the `run` method, we set perf_tracker to
            # None so that it will be overwritten here.
            self.perf_tracker = PerformanceTracker(
                sim_params=self.sim_params,
                env=self.trading_environment,
                data_portal=self.data_portal
            )

            # Set the dt initially to the period start by forcing it to change.
            self.on_dt_changed(self.sim_params.period_start)

        if not self.initialized:
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
            self.initialized = True

        # Use the resolved self.sim_params, not the argument. The argument
        # may be None, which previously reached AlgorithmSimulator as-is.
        self.trading_client = self._create_trading_client(self.sim_params)

        return self.trading_client.transform()
444
445
    def _create_trading_client(self, sim_params):
        """
        Assemble an AlgorithmSimulator for this algorithm.

        It is built from ``sim_params``, the data portal, a newly created
        clock and a newly created benchmark source.
        """
        clock = self._create_clock()
        benchmark_source = self._create_benchmark_source()
        return AlgorithmSimulator(
            self,
            sim_params,
            self.data_portal,
            clock,
            benchmark_source,
        )
453
454
    def get_generator(self):
        """
        Return the simulation generator for the current ``sim_params``.

        Subclasses may override this to customize how the generator is
        constructed; ``_create_generator`` provides the standard version.
        """
        current_params = self.sim_params
        return self._create_generator(current_params)
461
462
    def run(self, data_portal=None):
        """Run the algorithm.

        :Arguments:
            data_portal : DataPortal
                Market data source for the simulation. It is stored as
                ``self.data_portal`` before the simulation is built.

        :Returns:
            daily_stats : pandas.DataFrame
              Daily performance metrics such as returns, alpha etc.
        """
        self.data_portal = data_portal

        # Force a reset of the performance tracker, in case
        # this is a repeat run of the algorithm.
        self.perf_tracker = None

        # Drive the simulation to completion; each item the generator
        # yields is a perf dictionary.
        perfs = list(self.get_generator())

        # convert perf dicts to a pandas dataframe
        daily_stats = self._create_daily_stats(perfs)

        self.analyze(daily_stats)

        return daily_stats
491
492
    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Write unknown identifiers to the trading environment, then map
        ``identifiers`` to sids with the asset finder.

        An identifier is known when it is an Asset or an integer sid that
        ``asset_finder.retrieve_asset`` resolves. Every other identifier is
        written to the environment as an equity identifier.
        """
        def _resolves(identifier):
            # Only Assets and integer sids can be retrieved directly.
            if isinstance(identifier, Asset):
                lookup_sid = identifier.sid
            elif isinstance(identifier, Integral):
                lookup_sid = identifier
            else:
                return False
            found = self.asset_finder.retrieve_asset(sid=lookup_sid,
                                                     default_none=True)
            return found is not None

        unknown = [ident for ident in identifiers if not _resolves(ident)]

        self.trading_environment.write_data(
            equities_identifiers=unknown)

        # The lookups above may have cached misses for identifiers that were
        # only just written, so drop those caches. The real fix is to not
        # construct an AssetFinder until `run()`, when all the data it needs
        # is available.
        self.asset_finder._reset_caches()

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )
520
521
    def _create_daily_stats(self, perfs):
        """
        Convert the perf messages of a run into a DataFrame of daily stats.

        Each message containing 'daily_perf' contributes one row. The row is
        that 'daily_perf' dict (mutated in place) with its 'recorded_vars'
        entries and the message's 'cumulative_risk_metrics' merged in. The
        row is indexed by its 'period_close'. Any other message is stored as
        ``self.risk_report``; if there are several, the last one wins.
        """
        # TODO: merging recorded_vars and cumulative_risk_metrics can
        # overwrite existing keys of daily_perf. Could potentially raise or
        # log a warning.
        rows = []
        for message in perfs:
            if 'daily_perf' not in message:
                self.risk_report = message
                continue

            row = message['daily_perf']
            row.update(row.pop('recorded_vars'))
            row.update(message['cumulative_risk_metrics'])
            rows.append(row)

        index = [np.datetime64(row['period_close'], utc=True)
                 for row in rows]
        return pd.DataFrame(rows, index=index)
543
544
    @api_method
    def get_environment(self, field='platform'):
        """Query the execution environment.

        Parameters
        ----------
        field : str, optional
            One of 'arena', 'data_frequency', 'start', 'end',
            'capital_base' or 'platform' (the default). Pass '*' to get a
            dict with all of them.

        Returns
        -------
        The requested value, or the whole dict when ``field`` is '*'.
        'start' and 'end' come from ``sim_params.first_open`` and
        ``sim_params.last_close``.

        Raises
        ------
        KeyError
            If ``field`` is not a known name; the message lists valid names.
        """
        env = {
            'arena': self.sim_params.arena,
            'data_frequency': self.sim_params.data_frequency,
            'start': self.sim_params.first_open,
            'end': self.sim_params.last_close,
            'capital_base': self.sim_params.capital_base,
            'platform': self._platform,
        }
        if field == '*':
            return env
        if field not in env:
            # Keep raising KeyError (callers may catch it), but say which
            # fields exist instead of echoing only the bad key.
            raise KeyError(
                '%r is not a valid field for get_environment; valid fields '
                'are %s or %r' % (field, ', '.join(sorted(env)), '*')
            )
        return env[field]
558
559
    @api_method
    def fetch_csv(self, url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """
        Load CSV data from ``url`` and hand it to the data portal.

        A PandasRequestsCSV source is built for the simulation period
        (``sim_params.period_start`` to ``sim_params.period_end``) with the
        given parsing options and extra ``kwargs``. Its ``df`` is then
        passed to ``self.data_portal.handle_extra_source``.

        Returns
        -------
        PandasRequestsCSV
            The source object that was created.
        """
        params = self.sim_params
        source = PandasRequestsCSV(
            url,
            pre_func,
            post_func,
            self.trading_environment,
            params.period_start,
            params.period_end,
            date_column,
            date_format,
            timezone,
            symbol,
            mask,
            symbol_column,
            data_frequency=self.data_frequency,
            special_params_checker=special_params_checker,
            **kwargs
        )

        # Make the fetched frame available to the data portal.
        self.data_portal.handle_extra_source(source.df, params)

        return source
596
597
    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an Event and add it to this
        algorithm's EventManager.
        """
        new_event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(new_event)
604
605
    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedules a function to be called with some timed rules.

        Parameters
        ----------
        func : callable
            The function to schedule.
        date_rule : optional
            Falls back to ``DateRuleFactory.every_day()`` when falsy.
        time_rule : optional
            Only honoured for minute data, where it falls back to
            ``TimeRuleFactory.market_open()`` when falsy. For any other data
            frequency it is replaced by ``zipline.utils.events.Always()``.
        half_days : bool, optional
            Passed through to ``make_eventrule``.
        """
        if not date_rule:
            date_rule = DateRuleFactory.every_day()

        if self.sim_params.data_frequency == 'minute':
            if not time_rule:
                time_rule = TimeRuleFactory.market_open()
        else:
            # In daily mode the time_rule is ignored.
            time_rule = zipline.utils.events.Always()

        self.add_event(make_eventrule(date_rule, time_rule, half_days), func)
624
625
    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record values each day.

        Values may be given positionally as alternating name/value pairs,
        e.g. ``record('a', 1, 'b', 2)``, and/or as keyword arguments, e.g.
        ``record(a=1, b=2)``. A later entry with the same name overwrites an
        earlier one.
        """
        # zip() pulls from the same iterator twice per tuple, yielding
        # (a, b), (c, d), ... ; an unpaired trailing item is dropped.
        remaining = iter(args)
        pairs = zip(remaining, remaining)
        for name, value in chain(pairs, iteritems(kwargs)):
            self._recorded_vars[name] = value
640
641
    @api_method
    def set_benchmark(self, benchmark_sid):
        """
        Set the sid used as the benchmark for this algorithm.

        Raises
        ------
        SetBenchmarkOutsideInitialize
            If the algorithm has already been initialized.
        """
        if not self.initialized:
            self.benchmark_sid = benchmark_sid
            return
        raise SetBenchmarkOutsideInitialize()
647
648
    @api_method
    @preprocess(symbol_str=ensure_upper_case)
    def symbol(self, symbol_str):
        """
        Look up the Asset for ``symbol_str`` (passed through
        ``ensure_upper_case`` first) using the asset finder.

        The lookup is made as of the date set with set_symbol_lookup_date(),
        or as of ``sim_params.period_end`` if no such date has been set.
        """
        as_of = self._symbol_lookup_date
        if as_of is None:
            # No explicit lookup date: resolve as of the period end.
            as_of = self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=as_of,
        )
664
665
    @api_method
    def symbols(self, *args):
        """
        Look up several symbols at once.

        Equivalent to calling ``symbol`` on each argument in order; the
        results are returned as a list.
        """
        return list(map(self.symbol, args))
672
673
    @api_method
    def sid(self, a_sid):
        """
        Return the Asset that the asset finder retrieves for the integer
        sid ``a_sid``.
        """
        finder = self.asset_finder
        return finder.retrieve_asset(a_sid)
680
681
    @api_method
    @preprocess(symbol=ensure_upper_case)
    def future_symbol(self, symbol):
        """Look up the futures contract with the given symbol.

        Parameters
        ----------
        symbol : str
            Contract symbol; passed through ``ensure_upper_case`` first.

        Returns
        -------
        Future
            The matching Future object.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.
        """
        finder = self.asset_finder
        return finder.lookup_future_symbol(symbol)
703
704
    @api_method
    @preprocess(root_symbol=ensure_upper_case)
    def future_chain(self, root_symbol, as_of_date=None):
        """Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc. A truthy value is converted with
            ``pd.Timestamp(..., tz='UTC')``.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        UnsupportedDatetimeFormat
            If ``as_of_date`` cannot be converted to a Timestamp.
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        """
        chain_date = as_of_date
        if chain_date:
            try:
                chain_date = pd.Timestamp(chain_date, tz='UTC')
            except ValueError:
                # Report the value exactly as the caller passed it.
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')

        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=chain_date,
        )
740
741
    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a target order ``value`` into a share/contract count for
        ``asset`` using its current price.

        For a Future, the price is scaled by its ``contract_multiplier``.
        If the price is zero (per ``tolerant_equals``), the amount cannot be
        inferred: 0 is returned, and a debug message is logged when a logger
        is set.
        """
        last_price = self.trading_client.current_data[asset].price

        if not tolerant_equals(last_price, 0):
            multiplier = (asset.contract_multiplier
                          if isinstance(asset, Future) else 1)
            return value / (last_price * multiplier)

        message = "Price of 0 for {psid}; can't infer value".format(
            psid=asset
        )
        if self.logger:
            self.logger.debug(message)
        # Don't place any order
        return 0
763
764
    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` shares/contracts of ``sid``.

        The amount is truncated to an integer and the parameters are
        validated. Legacy ``limit_price``/``stop_price`` are converted into
        an ExecutionStyle when no ``style`` is given. The result of
        ``self.blotter.order`` is returned.
        """
        # Truncate to the integer share count that's either within .0001 of
        # amount or closer to zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        share_count = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid, share_count, limit_price,
                                   stop_price, style)

        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style)
        return self.blotter.order(sid, share_count, execution_style)
790
791
    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Check the arguments of an order() call.

        Raises
        ------
        OrderDuringInitialize
            If the algorithm has not been initialized yet.
        UnsupportedOrderParameters
            If ``style`` is combined with a truthy ``limit_price`` or
            ``stop_price``, or if ``asset`` is not an Asset.

        Finally, every registered trading control validates the order; those
        controls may raise their own errors.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )
        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for control in self.trading_controls:
            control.validate(asset,
                             amount,
                             self.updated_portfolio(),
                             self.get_datetime(),
                             self.trading_client.current_data)
831
832
    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Translate the deprecated ``limit_price``/``stop_price`` arguments
        into an ExecutionStyle.

        A truthy ``style`` is returned unchanged, in which case both prices
        must be None. Otherwise the truthy prices choose the type:
        both -> StopLimitOrder, limit only -> LimitOrder,
        stop only -> StopOrder, neither -> MarketOrder.
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        return StopOrder(stop_price) if stop_price else MarketOrder()
853
854
    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order sized by a desired value rather than a share count.

        The share quantity is derived from ``value`` by
        ``self._calculate_order_value_amount`` (the value divided by the
        asset's price). For a Future, ``value`` is interpreted as exposure,
        since futures have no 'value' of their own.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short

        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)

        Returns whatever ``self.order`` returns for the derived quantity.
        """
        share_count = self._calculate_order_value_amount(sid, value)
        return self.order(
            sid,
            share_count,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
876
877
    @property
    def recorded_vars(self):
        """
        Shallow copy of the algorithm's recorded variables.

        A copy is returned so callers cannot mutate the internal
        ``_recorded_vars`` container in place.
        """
        snapshot = copy(self._recorded_vars)
        return snapshot
880
881
    @property
    def portfolio(self):
        """
        Current portfolio, refreshed lazily via ``updated_portfolio``.
        """
        current = self.updated_portfolio()
        return current
884
885
    def updated_portfolio(self):
        """
        Return the portfolio, re-fetching it from the performance tracker
        only when ``portfolio_needs_update`` is set.

        On a refresh, the current ``performance_needs_update`` flag and the
        current ``datetime`` are passed to ``perf_tracker.get_portfolio``;
        afterwards both ``portfolio_needs_update`` and
        ``performance_needs_update`` are cleared. Otherwise the cached
        ``_portfolio`` is returned untouched.
        """
        if not self.portfolio_needs_update:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio(
            self.performance_needs_update,
            self.datetime,
        )
        self.portfolio_needs_update = False
        self.performance_needs_update = False
        return self._portfolio
893
894
    @property
    def account(self):
        """
        Current account, refreshed lazily via ``updated_account``.
        """
        current = self.updated_account()
        return current
897
898
    def updated_account(self):
        """
        Return the account, re-fetching it from the performance tracker
        only when ``account_needs_update`` is set.

        On a refresh, the current ``performance_needs_update`` flag and the
        current ``datetime`` are passed to ``perf_tracker.get_account``;
        afterwards both ``account_needs_update`` and
        ``performance_needs_update`` are cleared. Otherwise the cached
        ``_account`` is returned untouched.
        """
        if not self.account_needs_update:
            return self._account

        self._account = self.perf_tracker.get_account(
            self.performance_needs_update,
            self.datetime,
        )
        self.account_needs_update = False
        self.performance_needs_update = False
        return self._account
906
907
    def set_logger(self, logger):
        """Store ``logger`` on the algorithm as its ``logger`` attribute."""
        self.logger = logger
909
910
    def on_dt_changed(self, dt):
        """
        React to the simulation clock moving to ``dt``.

        Invoked by the simulation loop whenever the current dt changes, so
        any logic that must run exactly once at the start of each datetime
        group belongs here.

        Parameters
        ----------
        dt : datetime.datetime
            The new simulation time; must have ``pytz.utc`` as its tzinfo.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        # Propagate the new clock, performance tracker first, then blotter.
        self.datetime = dt
        for component in (self.perf_tracker, self.blotter):
            component.set_date(dt)

        # Any cached portfolio/account/performance snapshot is now stale.
        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
930
931
    @api_method
    def get_datetime(self, tz=None):
        """
        Return the current simulation datetime.

        Parameters
        ----------
        tz : tzinfo or str, optional
            When given, the UTC simulation time is converted to this
            timezone. A string is resolved with ``pytz.timezone``.

        Returns
        -------
        datetime.datetime
            The simulation time, in UTC unless ``tz`` was supplied.
            Datetime objects are immutable, so no copy is needed.
        """
        now = self.datetime
        assert now.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is None:
            return now

        zone = pytz.timezone(tz) if isinstance(tz, string_types) else tz
        return now.astimezone(zone)
946
947
    def update_dividends(self, dividend_frame):
        """
        Hand a dividend DataFrame to the performance tracker.

        The frame's columns should include at least the entries in
        ``zp.DIVIDEND_FIELDS``; no validation of that happens here.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)
953
954
    @api_method
    def set_slippage(self, slippage):
        """
        Install ``slippage`` as the blotter's slippage function.

        Raises
        ------
        UnsupportedSlippageModel
            If ``slippage`` is not a ``SlippageModel`` instance (checked
            first).
        SetSlippagePostInit
            If the algorithm has already been initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise SetSlippagePostInit()
        self.blotter.slippage_func = slippage
961
962
    @api_method
    def set_commission(self, commission):
        """
        Install ``commission`` as the blotter's commission model.

        Raises
        ------
        UnsupportedCommissionModel
            If ``commission`` is not a PerShare, PerTrade or PerDollar
            instance (checked first).
        SetCommissionPostInit
            If the algorithm has already been initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise SetCommissionPostInit()
        self.blotter.commission = commission
970
971
    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date used when resolving symbols to sids.

        Symbols can map to different firms or underlying assets at
        different times, so lookups need a reference date. ``dt`` is
        converted with ``pd.Timestamp(dt, tz='UTC')``.

        Raises
        ------
        UnsupportedDatetimeFormat
            If ``pd.Timestamp`` rejects ``dt`` with a ValueError.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date
983
984
    # Kept for backwards compatibility: proxies sim_params.data_frequency.
    @property
    def data_frequency(self):
        """
        Bar frequency stored on ``sim_params``; the setter only accepts
        'daily' or 'minute'.
        """
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only the two supported bar frequencies may be stored.
        assert value in ('daily', 'minute')
        params = self.sim_params
        params.data_frequency = value
993
994
    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).

        The percentage is turned into a value via
        ``self.portfolio.portfolio_value * percent`` and then placed with
        ``order_value``; the price/style arguments are forwarded unchanged.
        """
        # (The previous docstring contained "\%", an invalid escape sequence
        # in a non-raw string that triggers Deprecation/SyntaxWarnings.)
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)
1008
1009
    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order that moves the position in ``sid`` to ``target``
        shares.

        With no existing position this is an order for ``target`` shares.
        Otherwise it is an order for the difference between ``target`` and
        the currently held amount. Price/style arguments are forwarded to
        ``order`` unchanged.
        """
        positions = self.portfolio.positions
        # Check membership before indexing so that only an existing position
        # is consulted for its current amount.
        if sid in positions:
            delta = target - positions[sid].amount
        else:
            delta = target
        return self.order(sid, delta,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)
1031
1032
    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order that moves the position in ``sid`` to a target value.

        ``target`` is converted to a share quantity with
        ``self._calculate_order_value_amount`` and handed to
        ``order_target``, which orders only the difference from any
        existing position. For a Future the target is interpreted as
        exposure, since futures have no 'value'.
        """
        share_target = self._calculate_order_value_amount(sid, target)
        return self.order_target(
            sid,
            share_target,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
1049
1050
    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50%).

        The target percentage is turned into a target value via
        ``self.portfolio.portfolio_value * target`` and delegated to
        ``order_target_value``.
        """
        # (The previous docstring contained "\%", an invalid escape sequence
        # in a non-raw string that triggers Deprecation/SyntaxWarnings.)
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)
1067
1068
    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders as API objects.

        With ``sid`` given, return a list of that asset's open orders (an
        empty list if it has none). Without it, return a dict keyed like
        ``blotter.open_orders``, containing only keys with a non-empty
        order collection.
        """
        open_orders = self.blotter.open_orders

        if sid is not None:
            if sid not in open_orders:
                return []
            return [o.to_api_obj() for o in open_orders[sid]]

        return dict(
            (key, [o.to_api_obj() for o in pending])
            for key, pending in iteritems(open_orders)
            if pending
        )
1080
1081
    @api_method
    def get_order(self, order_id):
        """
        Return the API object for ``order_id``, or None if the blotter has
        no order with that id.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()
1085
1086
    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        ``order_param`` may be a ``zipline.protocol.Order`` (its ``id`` is
        used) or the order id itself.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)
1093
1094
    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, sids, bar_count, frequency, field, ffill=True):
        """
        Fetch a history window from the data portal.

        Delegates to ``self.data_portal.get_history_window`` with the
        current simulation ``datetime`` as the anchor; every other argument
        is forwarded unchanged. The ``require_initialized`` decorator
        presumably raises HistoryInInitialize when called during
        initialization.

        Raises
        ------
        Exception
            If no data portal is configured.
        """
        portal = self.data_portal
        if portal is None:
            raise Exception("no data portal!")

        return portal.get_history_window(
            sids, self.datetime, bar_count, frequency, field, ffill,
        )
1108
1109
    ####################
1110
    # Account Controls #
1111
    ####################
1112
1113
    def register_account_control(self, control):
        """
        Add ``control`` to the AccountControls checked on each bar.

        Raises
        ------
        RegisterAccountControlPostInit
            If the algorithm has already been initialized.
        """
        if self.initialized:
            raise RegisterAccountControlPostInit()
        controls = self.account_controls
        controls.append(control)
1120
1121
    def validate_account_controls(self):
        """
        Run every registered AccountControl against the current state.

        Each control's ``validate`` receives the (refreshed) portfolio, the
        (refreshed) account, the current datetime and the trading client's
        current data; any exception it raises propagates.
        """
        for account_control in self.account_controls:
            account_control.validate(
                self.updated_portfolio(),
                self.updated_account(),
                self.get_datetime(),
                self.trading_client.current_data,
            )
1127
1128
    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Limit the algorithm's maximum leverage by registering a
        ``MaxLeverage`` account control.
        """
        self.register_account_control(MaxLeverage(max_leverage))
1135
1136
    ####################
1137
    # Trading Controls #
1138
    ####################
1139
1140
    def register_trading_control(self, control):
        """
        Add ``control`` to the TradingControls checked before each order.

        Raises
        ------
        RegisterTradingControlPostInit
            If the algorithm has already been initialized.
        """
        if self.initialized:
            raise RegisterTradingControlPostInit()
        controls = self.trading_controls
        controls.append(control)
1147
1148
    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Limit the shares and/or dollar value held in ``sid``.

        Limits are absolute values and are enforced only when the algorithm
        tries to place an order for ``sid``. A position can therefore end
        up above the share limit after splits/dividends, or above the
        notional limit through price improvement.

        An order that would push the absolute shares or dollar value past
        either limit raises a TradingControlException.
        """
        self.register_trading_control(
            MaxPositionSize(
                asset=sid,
                max_shares=max_shares,
                max_notional=max_notional,
            )
        )
1169
1170
    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Limit the shares and/or dollar value of any single order for ``sid``.

        Limits are absolute values, enforced when the algorithm tries to
        place an order for ``sid``. An order exceeding either limit raises a
        TradingControlException.
        """
        self.register_trading_control(
            MaxOrderSize(
                asset=sid,
                max_shares=max_shares,
                max_notional=max_notional,
            )
        )
1184
1185
    @api_method
    def set_max_order_count(self, max_count):
        """
        Limit how many orders can be placed within a time interval.

        No interval is passed here; the window is whatever ``MaxOrderCount``
        applies (presumably per trading day -- confirm in that class).
        """
        self.register_trading_control(MaxOrderCount(max_count))
1193
1194
    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Restrict which sids may be ordered.

        ``restricted_list`` is passed unchanged to ``RestrictedListOrder``,
        which is registered as a trading control.
        """
        self.register_trading_control(RestrictedListOrder(restricted_list))
1201
1202
    @api_method
    def set_long_only(self):
        """
        Forbid short positions by registering a ``LongOnly`` trading control.
        """
        long_only_rule = LongOnly()
        self.register_trading_control(long_only_rule)
1208
1209
    ##############
1210
    # Pipeline API
1211
    ##############
1212
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register ``pipeline`` under ``name`` to be computed at the start of
        each day.

        Parameters
        ----------
        pipeline : Pipeline
            The pipeline to register.
        name : str
            Key used later with ``pipeline_output``.
        chunksize : int, optional
            Trading days computed per batch. By default the first batch is 5
            days (about one week) and later batches are 126 days (about half
            a year).

        Returns
        -------
        The same ``pipeline``, enabling
        ``p = attach_pipeline(Pipeline(), 'name')``.

        Raises
        ------
        NotImplementedError
            If a pipeline is already attached (only one is supported).
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")

        if chunksize is not None:
            chunks = repeat(int(chunksize))
        else:
            # Small first chunk for more immediate results, then half-year
            # chunks. ``chain``/``repeat`` already return iterators.
            chunks = chain([5], repeat(126))

        self._pipelines[name] = pipeline, chunks
        return pipeline
1231
1232
    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of pipeline with name `name`.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: Only one pipeline is currently supported, but lookup is by
        # name so that multiple pipelines can be added later.
        registry = self._pipelines
        try:
            pipeline, chunks = registry[name]
        except KeyError:
            raise NoSuchPipeline(name=name, valid=list(registry.keys()))
        return self._pipeline_output(pipeline, chunks)
1268
1269
    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.

        Serves today's rows from ``self._pipeline_cache`` when it is still
        valid. Otherwise it computes the next chunk (its size taken from the
        ``chunks`` iterator) starting today, and caches the result until
        that chunk's end date.
        """
        today = normalize_date(self.get_datetime())

        try:
            data = self._pipeline_cache.unwrap(today)
        except Expired:
            chunk_days = next(chunks)
            data, valid_until = self._run_pipeline(pipeline, today, chunk_days)
            self._pipeline_cache = CachedObject(data, valid_until)

        try:
            return data.loc[today]
        except KeyError:
            # ``today`` is missing from the index when no asset passed the
            # pipeline screen; hand back an empty frame with the same columns.
            return pd.DataFrame(index=[], columns=data.columns)
1289
1290
    def _run_pipeline(self, pipeline, start_date, chunksize):
1291
        """
1292
        Compute `pipeline`, providing values for at least `start_date`.
1293
1294
        Produces a DataFrame containing data for days between `start_date` and
1295
        `end_date`, where `end_date` is defined by:
1296
1297
            `end_date = min(start_date + chunksize trading days,
1298
                            simulation_end)`
1299
1300
        Returns
1301
        -------
1302
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)
1303
1304
        See Also
1305
        --------
1306
        PipelineEngine.run_pipeline
1307
        """
1308
        days = self.trading_environment.trading_days
1309
1310
        # Load data starting from the previous trading day...
1311
        start_date_loc = days.get_loc(start_date)
1312
1313
        # ...continuing until either the day before the simulation end, or
1314
        # until chunksize days of data have been loaded.
1315
        sim_end = self.sim_params.last_close.normalize()
1316
        end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))
1317
        end_date = days[end_loc]
1318
1319
        return \
1320
            self.engine.run_pipeline(pipeline, start_date, end_date), end_date
1321
1322
    ##################
1323
    # End Pipeline API
1324
    ##################
1325
1326
    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        Only attributes defined directly on ``cls`` (``vars(cls)``) are
        inspected. One is included when its ``is_api_method`` attribute
        exists and is truthy.
        """
        members = itervalues(vars(cls))
        return [member for member in members
                if getattr(member, 'is_api_method', False)]
1335