Completed
Pull Request — master (#858)
by Eddie
01:55
created

zipline.TradingAlgorithm.attach_pipeline()   A

Complexity

Conditions 3

Size

Total Lines 19

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 19
rs 9.4286
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pytz
18
import pandas as pd
19
from pandas.tseries.tools import normalize_date
20
import numpy as np
21
22
from datetime import datetime
23
from itertools import chain, repeat
24
from numbers import Integral
25
26
from six import (
27
    exec_,
28
    iteritems,
29
    itervalues,
30
    string_types,
31
)
32
33
34
from zipline.errors import (
35
    AttachPipelineAfterInitialize,
36
    HistoryInInitialize,
37
    NoSuchPipeline,
38
    OrderDuringInitialize,
39
    OverrideCommissionPostInit,
40
    OverrideSlippagePostInit,
41
    PipelineOutputDuringInitialize,
42
    RegisterAccountControlPostInit,
43
    RegisterTradingControlPostInit,
44
    SetBenchmarkOutsideInitialize,
45
    UnsupportedCommissionModel,
46
    UnsupportedDatetimeFormat,
47
    UnsupportedOrderParameters,
48
    UnsupportedSlippageModel,
49
)
50
from zipline.finance.trading import TradingEnvironment
51
from zipline.finance.blotter import Blotter
52
from zipline.finance.commission import PerShare, PerTrade, PerDollar
53
from zipline.finance.controls import (
54
    LongOnly,
55
    MaxOrderCount,
56
    MaxOrderSize,
57
    MaxPositionSize,
58
    MaxLeverage,
59
    RestrictedListOrder
60
)
61
from zipline.finance.execution import (
62
    LimitOrder,
63
    MarketOrder,
64
    StopLimitOrder,
65
    StopOrder,
66
)
67
from zipline.finance.performance import PerformanceTracker
68
from zipline.finance.slippage import (
69
    VolumeShareSlippage,
70
    SlippageModel
71
)
72
from zipline.assets import Asset, Future
73
from zipline.assets.futures import FutureChain
74
from zipline.gens.tradesimulation import AlgorithmSimulator
75
from zipline.pipeline.engine import (
76
    NoOpPipelineEngine,
77
    SimplePipelineEngine,
78
)
79
from zipline.utils.api_support import (
80
    api_method,
81
    require_initialized,
82
    require_not_initialized,
83
    ZiplineAPI,
84
)
85
from zipline.utils.input_validation import ensure_upper_case
86
from zipline.utils.cache import CachedObject, Expired
87
import zipline.utils.events
88
from zipline.utils.events import (
89
    EventManager,
90
    make_eventrule,
91
    DateRuleFactory,
92
    TimeRuleFactory,
93
)
94
from zipline.utils.factory import create_simulation_parameters
95
from zipline.utils.math_utils import (
96
    tolerant_equals,
97
    round_if_near_integer
98
)
99
from zipline.utils.preprocess import preprocess
100
101
import zipline.protocol
102
from zipline.sources.requests_csv import PandasRequestsCSV
103
104
from zipline.gens.sim_engine import (
105
    MinuteSimulationClock,
106
    DailySimulationClock,
107
)
108
from zipline.sources.benchmark_source import BenchmarkSource
109
110
# Starting cash used by TradingAlgorithm when no ``capital_base`` is passed.
DEFAULT_CAPITAL_BASE = 1e5
111
112
113
class TradingAlgorithm(object):
114
    """
115
    Base class for trading algorithms. Inherit and overload
116
    initialize() and handle_data(data).
117
118
    A new algorithm could look like this:
119
    ```
120
    from zipline.api import order, symbol
121
122
    def initialize(context):
123
        context.sid = symbol('AAPL')
124
        context.amount = 100
125
126
    def handle_data(context, data):
127
        sid = context.sid
128
        amount = context.amount
129
        order(sid, amount)
130
    ```
131
    To run this algorithm, pass these functions to
132
    TradingAlgorithm:
133
134
    my_algo = TradingAlgorithm(initialize, handle_data)
135
    stats = my_algo.run(data)
136
137
    """
138
139
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition. May not be combined with
                the ``initialize``/``handle_data`` keyword arguments.
            algo_filename : str
                Filename reported for ``script`` in tracebacks
                (defaults to '<string>').
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            env : TradingEnvironment
                Environment to use; a new TradingEnvironment is created
                when omitted. ``self.asset_finder`` is taken from it.
            asset_finder : An AssetFinder object
                NOTE(review): this constructor does not pop this key; the
                AssetFinder actually used is ``env.asset_finder``.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        Keyword arguments not consumed here are stored in
        ``self.initialize_kwargs`` and forwarded to ``initialize``.

        :Raises:
            ValueError
                If ``script`` does not define ``handle_data``, or if
                ``script`` is combined with ``initialize``/``handle_data``
                keyword arguments.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.data_portal = None

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        self.perf_tracker = None
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter(
                slippage_func=VolumeShareSlippage(),
                commission=PerShare(),
                data_frequency=self.data_frequency,
            )

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        if self.algoscript is not None:
            # A script defines its own API functions, so also receiving them
            # as keyword arguments is ambiguous. This check previously lived
            # in the ``elif`` branch below, where it was unreachable because
            # that branch only runs when no script was given.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )
            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            # NOTE: the script is executed with full interpreter privileges;
            # only pass trusted algorithm code.
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        # Same dict object as ``kwargs``: the pop below also removes
        # 'benchmark_sid' from the kwargs forwarded to initialize().
        self.initialize_kwargs = kwargs

        self.benchmark_sid = kwargs.pop('benchmark_sid', None)
315
316
    def init_engine(self, get_loader):
        """
        Build the pipeline engine and store it on ``self.engine``.

        With no ``get_loader`` a NoOpPipelineEngine is used. Otherwise a
        SimplePipelineEngine is built from the loader, the environment's
        trading days and ``self.asset_finder``.
        """
        if get_loader is None:
            self.engine = NoOpPipelineEngine()
            return

        trading_days = self.trading_environment.trading_days
        self.engine = SimplePipelineEngine(
            get_loader,
            trading_days,
            self.asset_finder,
        )
330
331
    def initialize(self, *args, **kwargs):
        """
        Invoke the user's initialize function inside a ZiplineAPI context
        bound to this algorithm.

        The algorithm itself is passed as the first argument, followed by
        any extra positional and keyword arguments.
        """
        user_initialize = self._initialize
        with ZiplineAPI(self):
            user_initialize(self, *args, **kwargs)
338
339
    def before_trading_start(self, data):
340
        if self._before_trading_start is None:
341
            return
342
343
        self._before_trading_start(self, data)
344
345
    def handle_data(self, data):
346
        self._handle_data(self, data)
347
348
        # Unlike trading controls which remain constant unless placing an
349
        # order, account controls can change each bar. Thus, must check
350
        # every bar no matter if the algorithm places an order or not.
351
        self.validate_account_controls()
352
353
    def analyze(self, perf):
354
        if self._analyze is None:
355
            return
356
357
        with ZiplineAPI(self):
358
            self._analyze(self, perf)
359
360
    def __repr__(self):
361
        """
362
        N.B. this does not yet represent a string that can be used
363
        to instantiate an exact copy of an algorithm.
364
365
        However, it is getting close, and provides some value as something
366
        that can be inspected interactively.
367
        """
368
        return """
369
{class_name}(
370
    capital_base={capital_base}
371
    sim_params={sim_params},
372
    initialized={initialized},
373
    slippage={slippage},
374
    commission={commission},
375
    blotter={blotter},
376
    recorded_vars={recorded_vars})
377
""".strip().format(class_name=self.__class__.__name__,
378
                   capital_base=self.capital_base,
379
                   sim_params=repr(self.sim_params),
380
                   initialized=self.initialized,
381
                   slippage=repr(self.blotter.slippage_func),
382
                   commission=repr(self.blotter.commission),
383
                   blotter=repr(self.blotter),
384
                   recorded_vars=repr(self.recorded_vars))
385
386
    def _create_clock(self):
        """
        Build the simulation clock that matches ``sim_params.data_frequency``.

        For 'minute' data, returns a MinuteSimulationClock built from the
        market open/close times of each simulated trading day. It also primes
        the data portal's offset cache with the clock's minute/day mappings.
        For any other frequency, returns a DailySimulationClock over the
        simulated trading days.
        """
        if self.sim_params.data_frequency == 'minute':
            env = self.trading_environment
            # Label-based selection of the simulated trading days. ``.ix`` is
            # deprecated (pandas 0.20) and removed (pandas 1.0); ``.loc`` is
            # the equivalent label-based accessor in every pandas version.
            trading_o_and_c = env.open_and_closes.loc[
                self.sim_params.trading_days]
            # Opens/closes as int64 nanoseconds since the epoch.
            market_opens = trading_o_and_c['market_open'].values.astype(
                'datetime64[ns]').astype(np.int64)
            market_closes = trading_o_and_c['market_close'].values.astype(
                'datetime64[ns]').astype(np.int64)

            minutely_emission = self.sim_params.emission_rate == "minute"

            clock = MinuteSimulationClock(
                self.sim_params.trading_days,
                market_opens,
                market_closes,
                env.trading_days,
                minutely_emission
            )
            self.data_portal.setup_offset_cache(
                clock.minutes_by_day,
                clock.minutes_to_day,
                self.sim_params.trading_days)
            return clock
        else:
            return DailySimulationClock(self.sim_params.trading_days)
415
416
    def _create_benchmark_source(self):
        """
        Build the BenchmarkSource for this run.

        It is built from ``benchmark_sid``, the trading environment, the
        simulated trading days, the data portal and the simulation's
        emission rate.
        """
        params = self.sim_params
        return BenchmarkSource(
            self.benchmark_sid,
            self.trading_environment,
            params.trading_days,
            self.data_portal,
            emission_rate=params.emission_rate,
        )
424
425
    def _create_generator(self, sim_params):
426
        if sim_params is not None:
427
            self.sim_params = sim_params
428
429
        if self.perf_tracker is None:
430
            # HACK: When running with the `run` method, we set perf_tracker to
431
            # None so that it will be overwritten here.
432
            self.perf_tracker = PerformanceTracker(
433
                sim_params=self.sim_params,
434
                env=self.trading_environment,
435
                data_portal=self.data_portal
436
            )
437
438
            # Set the dt initially to the period start by forcing it to change.
439
            self.on_dt_changed(self.sim_params.period_start)
440
441
        if not self.initialized:
442
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
443
            self.initialized = True
444
445
        self.trading_client = self._create_trading_client(sim_params)
446
447
        return self.trading_client.transform()
448
449
    def _create_trading_client(self, sim_params):
        """
        Build the AlgorithmSimulator that drives this algorithm.

        It is wired to the data portal, a fresh simulation clock and a
        fresh benchmark source.
        """
        portal = self.data_portal
        clock = self._create_clock()
        benchmark_source = self._create_benchmark_source()
        return AlgorithmSimulator(
            self,
            sim_params,
            portal,
            clock,
            benchmark_source
        )
457
458
    def get_generator(self):
459
        """
460
        Override this method to add new logic to the construction
461
        of the generator. Overrides can use the _create_generator
462
        method to get a standard construction generator.
463
        """
464
        return self._create_generator(self.sim_params)
465
466
    def run(self, data_portal=None):
467
        """Run the algorithm.
468
469
        :Arguments:
470
            source : DataPortal
471
472
        :Returns:
473
            daily_stats : pandas.DataFrame
474
              Daily performance metrics such as returns, alpha etc.
475
476
        """
477
        self.data_portal = data_portal
478
479
        # Force a reset of the performance tracker, in case
480
        # this is a repeat run of the algorithm.
481
        self.perf_tracker = None
482
483
        # Create zipline and loop through simulated_trading.
484
        # Each iteration returns a perf dictionary
485
        perfs = []
486
        for perf in self.get_generator():
487
            perfs.append(perf)
488
489
        # convert perf dict to pandas dataframe
490
        daily_stats = self._create_daily_stats(perfs)
491
492
        self.analyze(daily_stats)
493
494
        return daily_stats
495
496
    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Register unknown identifiers as equities, then map them to sids.

        Assets and integers are looked up by sid. Every identifier that does
        not resolve that way is written to the trading environment as a new
        equity. This includes all identifiers that are neither Assets nor
        integers. Afterwards the asset finder's caches are reset and it maps
        ``identifiers`` to sids as of ``as_of_date``.
        """
        finder = self.asset_finder

        def _lookup_existing(identifier):
            # Return the already-known Asset for ``identifier``, or None.
            if isinstance(identifier, Asset):
                return finder.retrieve_asset(sid=identifier.sid,
                                             default_none=True)
            if isinstance(identifier, Integral):
                return finder.retrieve_asset(sid=identifier,
                                             default_none=True)
            return None

        unresolved = [
            identifier for identifier in identifiers
            if _lookup_existing(identifier) is None
        ]
        self.trading_environment.write_data(
            equities_identifiers=unresolved)

        # Lookups above may have cached misses for identifiers that now
        # exist. The real fix is to defer building the AssetFinder until
        # run(), when all data is available.
        self.asset_finder._reset_caches()

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )
524
525
    def _create_daily_stats(self, perfs):
526
        # create daily and cumulative stats dataframe
527
        daily_perfs = []
528
        # TODO: the loop here could overwrite expected properties
529
        # of daily_perf. Could potentially raise or log a
530
        # warning.
531
        for perf in perfs:
532
            if 'daily_perf' in perf:
533
534
                perf['daily_perf'].update(
535
                    perf['daily_perf'].pop('recorded_vars')
536
                )
537
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
538
                daily_perfs.append(perf['daily_perf'])
539
            else:
540
                self.risk_report = perf
541
542
        daily_dts = [np.datetime64(perf['period_close'], utc=True)
543
                     for perf in daily_perfs]
544
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)
545
546
        return daily_stats
547
548
    @api_method
    def get_environment(self, field='platform'):
        """
        Return one simulation-environment setting, or all of them.

        ``field`` may be 'arena', 'data_frequency', 'start', 'end',
        'capital_base' or 'platform'. '*' returns the whole mapping; any
        other value raises KeyError.
        """
        sim = self.sim_params
        env = dict(
            arena=sim.arena,
            data_frequency=sim.data_frequency,
            start=sim.first_open,
            end=sim.last_close,
            capital_base=sim.capital_base,
            platform=self._platform,
        )
        return env if field == '*' else env[field]
562
563
    @api_method
    def fetch_csv(self, url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """
        Fetch CSV data from ``url`` and register it with the data portal.

        The arguments are forwarded to PandasRequestsCSV together with the
        trading environment, the simulation's period_start/period_end and
        the algorithm's data frequency. The parsed frame (``.df``) is passed
        to ``self.data_portal.handle_extra_source``. The source object is
        returned.
        """
        sim_params = self.sim_params
        positional = (
            url,
            pre_func,
            post_func,
            self.trading_environment,
            sim_params.period_start,
            sim_params.period_end,
            date_column,
            date_format,
            timezone,
            symbol,
            mask,
            symbol_column,
        )
        csv_data_source = PandasRequestsCSV(
            *positional,
            data_frequency=self.data_frequency,
            special_params_checker=special_params_checker,
            **kwargs
        )

        # Hand the fetched frame to the data portal as an extra source.
        self.data_portal.handle_extra_source(csv_data_source.df,
                                             self.sim_params)

        return csv_data_source
600
601
    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an Event and register it with the
        algorithm's EventManager.
        """
        event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(event)
608
609
    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to run whenever the date and time rules match.

        A falsy ``date_rule`` defaults to every day. In minute mode a falsy
        ``time_rule`` defaults to market open. In any other data frequency
        the time rule is replaced by one that always matches.
        """
        if not date_rule:
            date_rule = DateRuleFactory.every_day()

        if self.sim_params.data_frequency != 'minute':
            # Daily bars have no intraday time, so the time_rule is ignored.
            time_rule = zipline.utils.events.Always()
        elif not time_rule:
            time_rule = TimeRuleFactory.market_open()

        rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(rule, func)
628
629
    @api_method
    def record(self, *args, **kwargs):
        """
        Store named values for the current bar.

        Positional arguments are read as alternating name/value pairs, e.g.
        ``record('a', 1, 'b', 2)``. A trailing unpaired positional argument
        is ignored. Keyword arguments are stored after the positional pairs,
        so a keyword wins when both supply the same name.
        """
        recorded = self._recorded_vars
        names, values = args[::2], args[1::2]
        for name, value in zip(names, values):
            recorded[name] = value
        for name in kwargs:
            recorded[name] = kwargs[name]
644
645
    @api_method
    def set_benchmark(self, benchmark_sid):
        """
        Set the benchmark asset id.

        Raises SetBenchmarkOutsideInitialize once the algorithm has been
        initialized.
        """
        if not self.initialized:
            self.benchmark_sid = benchmark_sid
            return

        raise SetBenchmarkOutsideInitialize()
651
652
    @api_method
    @preprocess(symbol_str=ensure_upper_case)
    def symbol(self, symbol_str):
        """
        Resolve a ticker symbol to an Asset.

        ``symbol_str`` is upper-cased before lookup. The lookup uses the
        symbol lookup date if one has been set (set_symbol_lookup_date()).
        Otherwise it falls back to the simulation's ``period_end``.
        """
        lookup_date = self._symbol_lookup_date
        if lookup_date is None:
            lookup_date = self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=lookup_date,
        )
668
669
    @api_method
    def symbols(self, *args):
        """
        Resolve several ticker symbols, returning a list of Assets.

        The list is in argument order; each symbol goes through ``symbol``.
        """
        return list(map(self.symbol, args))
676
677
    @api_method
    def sid(self, a_sid):
        """
        Return the Asset whose integer sid is ``a_sid``, via the asset
        finder.
        """
        finder = self.asset_finder
        return finder.retrieve_asset(a_sid)
684
685
    @api_method
    @preprocess(symbol=ensure_upper_case)
    def future_symbol(self, symbol):
        """ Look up a single futures contract by its symbol.

        Parameters
        ----------
        symbol : str
            Contract symbol; upper-cased before lookup.

        Returns
        -------
        Future
            The matching Future object.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.
        """
        finder = self.asset_finder
        return finder.lookup_future_symbol(symbol)
707
708
    @api_method
    @preprocess(root_symbol=ensure_upper_case)
    def future_chain(self, root_symbol, as_of_date=None):
        """ Build the FutureChain for ``root_symbol``.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain; upper-cased before lookup.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc. A truthy value is converted to a UTC
            pandas.Timestamp. A falsy value is passed through unchanged.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        UnsupportedDatetimeFormat
            If converting ``as_of_date`` to a Timestamp raises ValueError.
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        """
        chain_date = as_of_date
        if as_of_date:
            try:
                chain_date = pd.Timestamp(as_of_date, tz='UTC')
            except ValueError:
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')

        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=chain_date,
        )
744
745
    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a target ``value`` into a share/contract count for ``asset``.

        The current price comes from the trading client's current data.
        Futures are additionally scaled by their ``contract_multiplier``.
        When the price is (tolerantly) zero, a debug message is logged (if
        a logger is set) and 0 is returned.
        """
        price = self.trading_client.current_data[asset].price

        if tolerant_equals(price, 0):
            message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(message)
            # A zero price gives no meaningful size, so order nothing.
            return 0

        multiplier = (asset.contract_multiplier
                      if isinstance(asset, Future) else 1)
        return value / (price * multiplier)
767
768
    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` units of ``sid`` and return the
        blotter's result.
        """
        # Snap to the nearest integer when within tolerance, otherwise
        # truncate toward zero: 3.9999 -> 4, 5.5 -> 5, -5.5 -> -5.
        share_count = int(round_if_near_integer(amount))

        # Raises a ZiplineError for invalid parameter combinations.
        self.validate_order_params(sid, share_count, limit_price,
                                   stop_price, style)

        # Fold the legacy limit/stop arguments into an ExecutionStyle.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style)
        return self.blotter.order(sid, share_count, execution_style)
794
795
    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Check the arguments of an order() call before it reaches the blotter.

        Raises OrderDuringInitialize before initialization has finished.
        Raises UnsupportedOrderParameters when ``style`` is combined with a
        limit or stop price, or when ``asset`` is not an Asset. Finally, every
        registered trading control validates the order.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )
        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for trading_control in self.trading_controls:
            trading_control.validate(
                asset,
                amount,
                self.portfolio,
                self.get_datetime(),
                self.trading_client.current_data,
            )
835
836
    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Translate the legacy ``limit_price``/``stop_price`` arguments into an
        ExecutionStyle instance.

        A truthy ``style`` is returned unchanged; in that case both prices
        must be None. Otherwise the prices choose the style:

        * limit and stop  -> StopLimitOrder
        * limit only      -> LimitOrder
        * stop only       -> StopOrder
        * neither         -> MarketOrder
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        return MarketOrder()
857
858
    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.

        The number of shares is computed from ``value`` by
        ``_calculate_order_value_amount``. Per the original contract, the
        requested value is divided by the asset's price. For a Future the
        'value' is really the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        shares = self._calculate_order_value_amount(sid, value)
        return self.order(
            sid,
            shares,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
880
881
    @property
    def recorded_vars(self):
        """
        A shallow copy of ``self._recorded_vars``.

        Callers get their own container, so mutating it does not affect the
        algorithm's internal mapping.
        """
        snapshot = copy(self._recorded_vars)
        return snapshot
884
885
    @property
    def portfolio(self):
        """The current portfolio, refreshed lazily by ``updated_portfolio``."""
        current = self.updated_portfolio()
        return current
888
889
    def updated_portfolio(self):
        """
        Return the cached portfolio, rebuilding it first if it is stale.

        When ``portfolio_needs_update`` is set, the portfolio is fetched from
        the performance tracker for the current datetime. Both the portfolio
        and performance "needs update" flags are then cleared.
        """
        if not self.portfolio_needs_update:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio(
            self.performance_needs_update,
            self.datetime,
        )
        self.portfolio_needs_update = False
        self.performance_needs_update = False
        return self._portfolio
897
898
    @property
    def account(self):
        """The current account, refreshed lazily by ``updated_account``."""
        current = self.updated_account()
        return current
901
902
    def updated_account(self):
        """
        Return the cached account, rebuilding it first if it is stale.

        When ``account_needs_update`` is set, the account is fetched from the
        performance tracker for the current datetime. Both the account and
        performance "needs update" flags are then cleared.
        """
        if not self.account_needs_update:
            return self._account

        self._account = self.perf_tracker.get_account(
            self.performance_needs_update,
            self.datetime,
        )
        self.account_needs_update = False
        self.performance_needs_update = False
        return self._account
910
911
    def set_logger(self, logger):
        """Store ``logger`` on this algorithm as ``self.logger``."""
        setattr(self, 'logger', logger)
913
914
    def on_dt_changed(self, dt):
        """
        Callback run by the simulation loop whenever the current dt changes.

        Logic that must happen exactly once at the start of each datetime
        group belongs here. ``dt`` must be a UTC ``datetime``. It is pushed
        to the performance tracker and the blotter. The cached portfolio and
        account are discarded, and every "needs update" flag is set so they
        are rebuilt on next access.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        for component in (self.perf_tracker, self.blotter):
            component.set_date(dt)

        # Invalidate the cached snapshots for the new dt.
        self._portfolio = self._account = None

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
937
938
    @api_method
    def get_datetime(self, tz=None):
        """
        Returns the simulation datetime.

        The stored datetime is UTC. If ``tz`` is given, either as a timezone
        name string or as a tzinfo object, the datetime is converted to that
        zone before it is returned.
        """
        now = self.datetime
        assert now.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is None:
            return now  # datetime.datetime objects are immutable.

        if isinstance(tz, string_types):
            tz = pytz.timezone(tz)
        return now.astimezone(tz)
953
954
    def update_dividends(self, dividend_frame):
        """
        Hand the DataFrame used to process dividends to the perf tracker.

        The frame's columns should include at least the entries in
        zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)
960
961
    @api_method
    def set_slippage(self, slippage):
        """
        Install ``slippage`` as the blotter's slippage model.

        Raises
        ------
        UnsupportedSlippageModel
            If ``slippage`` is not a ``SlippageModel`` (checked first).
        OverrideSlippagePostInit
            If the algorithm has already been initialized.
        """
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()

        blotter = self.blotter
        blotter.slippage_func = slippage
968
969
    @api_method
    def set_commission(self, commission):
        """
        Install ``commission`` as the blotter's commission model.

        Raises
        ------
        UnsupportedCommissionModel
            If ``commission`` is not a PerShare, PerTrade or PerDollar
            instance (checked first).
        OverrideCommissionPostInit
            If the algorithm has already been initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()

        blotter = self.blotter
        blotter.commission = commission
977
978
    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date used to resolve symbols to sids.

        A symbol may map to different firms or underlying assets at
        different times. ``dt`` is converted to a UTC ``pd.Timestamp``.

        Raises
        ------
        UnsupportedDatetimeFormat
            If ``pd.Timestamp`` raises ValueError for ``dt``; the stored
            lookup date is left untouched in that case.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date
990
991
    # Kept for backwards compatibility: the value itself lives on sim_params.
    @property
    def data_frequency(self):
        """The simulation's data frequency, proxied from ``sim_params``."""
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only 'daily' and 'minute' are accepted.
        assert value in ('daily', 'minute')
        params = self.sim_params
        params.data_frequency = value
1000
1001
    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).

        The dollar value ``portfolio_value * percent`` is forwarded to
        ``order_value`` together with ``limit_price``, ``stop_price`` and
        ``style``. The return value is whatever ``order_value`` returns.
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)
1015
1016
    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        If no position in ``sid`` exists, this is the same as ordering
        ``target`` shares. Otherwise only the difference between ``target``
        and the currently held amount is ordered.
        """
        positions = self.portfolio.positions
        if sid in positions:
            # Trade only the gap between what we hold and what we want.
            amount = target - positions[sid].amount
        else:
            amount = target

        return self.order(sid, amount,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)
1038
1039
    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        The target value is converted into a target share count by
        ``_calculate_order_value_amount``, then delegated to
        ``order_target``. With no existing position this is a plain new
        order; otherwise only the difference is ordered. For a Future the
        'target value' is really the target exposure, as Futures have no
        'value'.
        """
        target_shares = self._calculate_order_value_amount(sid, target)
        return self.order_target(
            sid,
            target_shares,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
1056
1057
    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50%).

        The target dollar value ``portfolio_value * target`` is forwarded to
        ``order_target_value`` with the same order-style arguments.
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)
1074
1075
    @api_method
    def get_open_orders(self, sid=None):
        """
        Return open orders as API objects.

        With ``sid`` None, the result is a dict mapping each key of the
        blotter's ``open_orders`` to a list of its orders' API objects. Keys
        whose order list is empty are omitted. With a ``sid``, the result is
        a list of that sid's open orders, or ``[]`` if it has none.
        """
        open_orders = self.blotter.open_orders

        if sid is not None:
            if sid not in open_orders:
                return []
            return [order.to_api_obj() for order in open_orders[sid]]

        return {
            asset: [order.to_api_obj() for order in asset_orders]
            for asset, asset_orders in iteritems(open_orders)
            if asset_orders
        }
1087
1088
    @api_method
    def get_order(self, order_id):
        """
        Return the API object for the order with id ``order_id``.

        Returns None when the blotter has no such order.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()
1092
1093
    @api_method
    def cancel_order(self, order_param):
        """
        Ask the blotter to cancel an order.

        ``order_param`` may be a ``zipline.protocol.Order``, whose ``id`` is
        used, or the order id itself.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)
1100
1101
    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, sids, bar_count, frequency, field, ffill=True):
        """
        Fetch a history window from the data portal.

        The arguments are forwarded positionally to
        ``data_portal.get_history_window``, with the current simulation
        datetime inserted after ``sids``. A plain Exception is raised if no
        data portal is configured.
        """
        portal = self.data_portal
        if portal is None:
            raise Exception("no data portal!")

        return portal.get_history_window(sids, self.datetime, bar_count,
                                         frequency, field, ffill)
1115
1116
    ####################
1117
    # Account Controls #
1118
    ####################
1119
1120
    def register_account_control(self, control):
        """
        Add ``control`` to the AccountControls checked on each bar.

        Raises
        ------
        RegisterAccountControlPostInit
            If the algorithm has already been initialized.
        """
        if self.initialized:
            raise RegisterAccountControlPostInit()

        controls = self.account_controls
        controls.append(control)
1127
1128
    def validate_account_controls(self):
        """
        Run ``validate`` on every registered account control.

        Each control receives the portfolio, account, current datetime and
        current data. These are re-read for every control, and exceptions
        propagate to the caller.
        """
        for account_control in self.account_controls:
            account_control.validate(
                self.portfolio,
                self.account,
                self.get_datetime(),
                self.trading_client.current_data,
            )
1134
1135
    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Limit the maximum leverage of the algorithm.

        Registers a ``MaxLeverage(max_leverage)`` account control.
        """
        self.register_account_control(MaxLeverage(max_leverage))
1142
1143
    ####################
1144
    # Trading Controls #
1145
    ####################
1146
1147
    def register_trading_control(self, control):
        """
        Add ``control`` to the TradingControls checked before order calls.

        Raises
        ------
        RegisterTradingControlPostInit
            If the algorithm has already been initialized.
        """
        if self.initialized:
            raise RegisterTradingControlPostInit()

        controls = self.trading_controls
        controls.append(control)
1154
1155
    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Limit the number of shares and/or dollar value held in ``sid``.

        The limits are absolute values. They are enforced only when the
        algorithm attempts to place an order for ``sid``. As a result, a
        position can still exceed ``max_shares`` after splits/dividends, or
        exceed ``max_notional`` through price improvement.

        An order that would increase the absolute share count or dollar
        value past either limit raises a TradingControlException.
        """
        position_limit = MaxPositionSize(asset=sid,
                                         max_shares=max_shares,
                                         max_notional=max_notional)
        self.register_trading_control(position_limit)
1176
1177
    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Limit the number of shares and/or dollar value of any single order
        placed for ``sid``.

        The limits are absolute values and are enforced when the algorithm
        attempts to place an order for ``sid``. An order exceeding either
        limit raises a TradingControlException.
        """
        order_limit = MaxOrderSize(asset=sid,
                                   max_shares=max_shares,
                                   max_notional=max_notional)
        self.register_trading_control(order_limit)
1191
1192
    @api_method
    def set_max_order_count(self, max_count):
        """
        Limit how many orders can be placed, via ``MaxOrderCount(max_count)``.

        NOTE(review): no interval argument is passed here. The window over
        which orders are counted is defined by MaxOrderCount; confirm it
        there.
        """
        self.register_trading_control(MaxOrderCount(max_count))
1200
1201
    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Restrict which sids may be ordered.

        Registers a ``RestrictedListOrder`` trading control built from
        ``restricted_list``.
        """
        self.register_trading_control(RestrictedListOrder(restricted_list))
1208
1209
    @api_method
    def set_long_only(self):
        """
        Forbid this algorithm from taking short positions.

        Registers a ``LongOnly`` trading control.
        """
        long_only = LongOnly()
        self.register_trading_control(long_only)
1215
1216
    ##############
1217
    # Pipeline API
1218
    ##############
1219
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register a pipeline to be computed at the start of each day.

        Parameters
        ----------
        pipeline : Pipeline
            The pipeline to register.
        name : str
            Key under which the pipeline is stored; pass the same name to
            ``pipeline_output`` to retrieve its results.
        chunksize : int, optional
            Number of additional trading days computed per pipeline run.
            When omitted, the first run uses 5 and every later run uses 126,
            so that the first results arrive quickly.

        Returns
        -------
        pipeline : Pipeline
            The pipeline that was passed in.

        Raises
        ------
        NotImplementedError
            If a pipeline is already attached; only one is supported.
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")

        # ``chain`` and ``repeat`` already return iterators, so no ``iter()``
        # wrapper is needed. One size is pulled per run with ``next``.
        if chunksize is None:
            # Make the first chunk smaller to get more immediate results:
            # (one week, then every half year)
            chunks = chain([5], repeat(126))
        else:
            chunks = repeat(int(chunksize))
        self._pipelines[name] = pipeline, chunks

        # Return the pipeline to allow expressions like
        # p = attach_pipeline(Pipeline(), 'name')
        return pipeline
1238
1239
    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of the pipeline registered under ``name``.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: Only one pipeline is supported today, but lookup is by name
        # so that multiple pipelines can be added later.
        try:
            pipeline, chunks = self._pipelines[name]
        except KeyError:
            registered = list(self._pipelines.keys())
            raise NoSuchPipeline(name=name, valid=registered)
        return self._pipeline_output(pipeline, chunks)
1275
1276
    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.

        Results come from ``self._pipeline_cache``. When the cache reports
        ``Expired`` for today, the pipeline is recomputed for a new chunk,
        whose size is taken from ``chunks``, and the cache is replaced.
        Today's rows are then returned, or an empty frame with the same
        columns if there are none.
        """
        session = normalize_date(self.get_datetime())
        try:
            results = self._pipeline_cache.unwrap(session)
        except Expired:
            results, expires = self._run_pipeline(
                pipeline, session, next(chunks),
            )
            self._pipeline_cache = CachedObject(results, expires)

        try:
            return results.loc[session]
        except KeyError:
            # No row for today: no assets passed the pipeline screen.
            return pd.DataFrame(index=[], columns=results.columns)
1296
1297
    def _run_pipeline(self, pipeline, start_date, chunksize):
1298
        """
1299
        Compute `pipeline`, providing values for at least `start_date`.
1300
1301
        Produces a DataFrame containing data for days between `start_date` and
1302
        `end_date`, where `end_date` is defined by:
1303
1304
            `end_date = min(start_date + chunksize trading days,
1305
                            simulation_end)`
1306
1307
        Returns
1308
        -------
1309
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)
1310
1311
        See Also
1312
        --------
1313
        PipelineEngine.run_pipeline
1314
        """
1315
        days = self.trading_environment.trading_days
1316
1317
        # Load data starting from the previous trading day...
1318
        start_date_loc = days.get_loc(start_date)
1319
1320
        # ...continuing until either the day before the simulation end, or
1321
        # until chunksize days of data have been loaded.
1322
        sim_end = self.sim_params.last_close.normalize()
1323
        end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))
1324
        end_date = days[end_loc]
1325
1326
        return \
1327
            self.engine.run_pipeline(pipeline, start_date, end_date), end_date
1328
1329
    ##################
1330
    # End Pipeline API
1331
    ##################
1332
1333
    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        Only the class's own namespace (``vars(cls)``) is scanned. A member
        counts as an API method when it has a truthy ``is_api_method``
        attribute.
        """
        def is_api(member):
            return getattr(member, 'is_api_method', False)

        return list(filter(is_api, itervalues(vars(cls))))
1342