Completed
Pull Request — master (#858)
by Eddie
02:21
created

zipline.TradingAlgorithm.future_chain()   B

Complexity

Conditions 3

Size

Total Lines 35

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 35
rs 8.8571
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pytz
18
import pandas as pd
19
from pandas.tseries.tools import normalize_date
20
import numpy as np
21
22
from datetime import datetime
23
from itertools import chain, repeat
24
from numbers import Integral
25
26
from six import (
27
    exec_,
28
    iteritems,
29
    itervalues,
30
    string_types,
31
)
32
33
34
from zipline.errors import (
35
    AttachPipelineAfterInitialize,
36
    HistoryInInitialize,
37
    NoSuchPipeline,
38
    OrderDuringInitialize,
39
    OverrideCommissionPostInit,
40
    OverrideSlippagePostInit,
41
    PipelineOutputDuringInitialize,
42
    RegisterAccountControlPostInit,
43
    RegisterTradingControlPostInit,
44
    SetBenchmarkOutsideInitialize,
45
    UnsupportedCommissionModel,
46
    UnsupportedDatetimeFormat,
47
    UnsupportedOrderParameters,
48
    UnsupportedSlippageModel,
49
)
50
from zipline.finance.trading import TradingEnvironment
51
from zipline.finance.blotter import Blotter
52
from zipline.finance.commission import PerShare, PerTrade, PerDollar
53
from zipline.finance.controls import (
54
    LongOnly,
55
    MaxOrderCount,
56
    MaxOrderSize,
57
    MaxPositionSize,
58
    MaxLeverage,
59
    RestrictedListOrder
60
)
61
from zipline.finance.execution import (
62
    LimitOrder,
63
    MarketOrder,
64
    StopLimitOrder,
65
    StopOrder,
66
)
67
from zipline.finance.performance import PerformanceTracker
68
from zipline.finance.slippage import (
69
    VolumeShareSlippage,
70
    SlippageModel
71
)
72
from zipline.assets import Asset, Future
73
from zipline.assets.futures import FutureChain
74
from zipline.gens.tradesimulation import AlgorithmSimulator
75
from zipline.pipeline.engine import (
76
    NoOpPipelineEngine,
77
    SimplePipelineEngine,
78
)
79
from zipline.utils.api_support import (
80
    api_method,
81
    require_initialized,
82
    require_not_initialized,
83
    ZiplineAPI,
84
)
85
from zipline.utils.input_validation import ensure_upper_case
86
from zipline.utils.cache import (
87
    CachedObject,
88
    Expired
89
)
90
import zipline.utils.events
91
from zipline.utils.events import (
92
    EventManager,
93
    make_eventrule,
94
    DateRuleFactory,
95
    TimeRuleFactory,
96
)
97
from zipline.utils.factory import create_simulation_parameters
98
from zipline.utils.math_utils import (
99
    tolerant_equals,
100
    round_if_near_integer
101
)
102
from zipline.utils.preprocess import preprocess
103
104
import zipline.protocol
105
from zipline.sources.requests_csv import PandasRequestsCSV
106
107
from zipline.gens.sim_engine import (
108
    MinuteSimulationClock,
109
    DailySimulationClock,
110
)
111
from zipline.sources.benchmark_source import BenchmarkSource
112
113
# Starting capital used when the caller does not supply `capital_base`.
DEFAULT_CAPITAL_BASE = 1.0e5
114
115
116
class TradingAlgorithm(object):
117
    """
118
    Base class for trading algorithms. Inherit and overload
119
    initialize() and handle_data(data).
120
121
    A new algorithm could look like this:
122
    ```
123
    from zipline.api import order, symbol
124
125
    def initialize(context):
126
        context.sid = symbol('AAPL')
127
        context.amount = 100
128
129
    def handle_data(context, data):
130
        sid = context.sid
131
        amount = context.amount
132
        order(sid, amount)
133
    ```
134
    To then run this algorithm, pass these functions to
135
    TradingAlgorithm:
136
137
    my_algo = TradingAlgorithm(initialize, handle_data)
138
    stats = my_algo.run(data)
139
140
    """
141
142
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        Recognized keyword arguments are popped from ``kwargs``; whatever
        remains (together with ``args``) is stored and later forwarded to
        ``initialize``.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional callback stored alongside initialize/handle_data.
            analyze : function
                Optional callback invoked after a run.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.  Mutually exclusive
                with passing initialize/handle_data as functions.
            algo_filename : str
                Filename used when compiling ``script``
                (defaults to '<string>').
            namespace : dict
                Namespace in which ``script`` is executed.
            platform : str <default: 'zipline'>
                Value stored as the platform name.
            env : TradingEnvironment
                Environment to use; a new one is created when omitted.
            sim_params : simulation parameters object
                When omitted, parameters are created from ``capital_base``,
                ``start`` and ``end``.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            get_pipeline_loader : callable
                Passed to ``init_engine`` to build the pipeline engine.
            blotter : Blotter
                Blotter to use; a default one is created when omitted.
            create_event_context : callable
                Passed to the EventManager as ``create_context``.
            benchmark_sid : object
                Stored as the benchmark identifier.
            asset_finder : An AssetFinder object
                A new AssetFinder object to be used in this TradingEnvironment
                (NOTE(review): not read by the visible constructor code --
                confirm whether it is still supported.)
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            futures_metadata : same kinds of input as equities_metadata
                Written to the environment as futures data.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        :Raises:
            ValueError
                If ``script`` is given together with initialize/handle_data
                functions, or if ``script`` does not define ``handle_data``.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.data_portal = None

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        self.perf_tracker = None
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            # NOTE(review): `data_frequency` has not been assigned by the
            # visible constructor code at this point; presumably it is a
            # property defined elsewhere on the class -- confirm.
            self.blotter = Blotter(
                slippage_func=VolumeShareSlippage(),
                commission=PerShare(),
                data_frequency=self.data_frequency,
            )

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self._portfolio = None
        self._account = None

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        if self.algoscript is not None:
            # A script and explicit callables are mutually exclusive.  This
            # check must live in the script branch: placed under the `elif`
            # below it could never be reached.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )

            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs

        # `initialize_kwargs` is the same dict object, so this pop also
        # removes 'benchmark_sid' from what is forwarded to initialize().
        self.benchmark_sid = kwargs.pop('benchmark_sid', None)
315
316
    def init_engine(self, get_loader):
        """
        Build the pipeline engine for this algorithm and store it on
        ``self.engine``.

        When ``get_loader`` is None a NoOpPipelineEngine is stored.
        Otherwise a SimplePipelineEngine is built from ``get_loader``, the
        environment's trading days and this algorithm's asset finder.
        """
        if get_loader is None:
            self.engine = NoOpPipelineEngine()
            return

        env_trading_days = self.trading_environment.trading_days
        self.engine = SimplePipelineEngine(
            get_loader,
            env_trading_days,
            self.asset_finder,
        )
330
331
    def initialize(self, *args, **kwargs):
        """
        Run the stored ``_initialize`` callable inside a ZiplineAPI context
        created for this algorithm, passing the algorithm itself as the
        first argument followed by ``args`` and ``kwargs``.
        """
        api_context = ZiplineAPI(self)
        with api_context:
            self._initialize(self, *args, **kwargs)
338
339
    def before_trading_start(self, data):
340
        if self._before_trading_start is None:
341
            return
342
343
        self._before_trading_start(self, data)
344
345
    def handle_data(self, data):
346
        self._handle_data(self, data)
347
348
        # Unlike trading controls which remain constant unless placing an
349
        # order, account controls can change each bar. Thus, must check
350
        # every bar no matter if the algorithm places an order or not.
351
        self.validate_account_controls()
352
353
    def analyze(self, perf):
354
        if self._analyze is None:
355
            return
356
357
        with ZiplineAPI(self):
358
            self._analyze(self, perf)
359
360
    def __repr__(self):
361
        """
362
        N.B. this does not yet represent a string that can be used
363
        to instantiate an exact copy of an algorithm.
364
365
        However, it is getting close, and provides some value as something
366
        that can be inspected interactively.
367
        """
368
        return """
369
{class_name}(
370
    capital_base={capital_base}
371
    sim_params={sim_params},
372
    initialized={initialized},
373
    slippage={slippage},
374
    commission={commission},
375
    blotter={blotter},
376
    recorded_vars={recorded_vars})
377
""".strip().format(class_name=self.__class__.__name__,
378
                   capital_base=self.capital_base,
379
                   sim_params=repr(self.sim_params),
380
                   initialized=self.initialized,
381
                   slippage=repr(self.blotter.slippage_func),
382
                   commission=repr(self.blotter.commission),
383
                   blotter=repr(self.blotter),
384
                   recorded_vars=repr(self.recorded_vars))
385
386
    def _create_clock(self):
        """
        Create the simulation clock matching ``sim_params.data_frequency``.

        For 'minute' data a MinuteSimulationClock is built from the market
        opens/closes of the simulation's trading days, and the data portal's
        offset cache is set up from the clock.  For any other frequency a
        DailySimulationClock over the simulation's trading days is returned.
        """
        if self.sim_params.data_frequency != 'minute':
            return DailySimulationClock(self.sim_params.trading_days)

        env = self.trading_environment

        # Label-based selection of the simulation's days.  `.loc` replaces
        # the deprecated `.ix` indexer, which newer pandas no longer has.
        trading_o_and_c = env.open_and_closes.loc[
            self.sim_params.trading_days]

        # Opens and closes as int64 nanosecond values.
        market_opens = trading_o_and_c['market_open'].values.astype(
            'datetime64[ns]').astype(np.int64)
        market_closes = trading_o_and_c['market_close'].values.astype(
            'datetime64[ns]').astype(np.int64)

        minutely_emission = self.sim_params.emission_rate == "minute"

        clock = MinuteSimulationClock(
            self.sim_params.trading_days,
            market_opens,
            market_closes,
            env.trading_days,
            minutely_emission
        )
        self.data_portal.setup_offset_cache(
            clock.minutes_by_day,
            clock.minutes_to_day,
            self.sim_params.trading_days)
        return clock
415
416
    def _create_benchmark_source(self):
        """
        Build a BenchmarkSource from the benchmark sid, the trading
        environment, the simulation's trading days, the data portal and the
        simulation's emission rate.
        """
        benchmark_sid = self.benchmark_sid
        env = self.trading_environment
        sim_days = self.sim_params.trading_days
        portal = self.data_portal
        return BenchmarkSource(
            benchmark_sid,
            env,
            sim_days,
            portal,
            emission_rate=self.sim_params.emission_rate,
        )
424
425
    def _create_generator(self, sim_params):
426
        if sim_params is not None:
427
            self.sim_params = sim_params
428
429
        if self.perf_tracker is None:
430
            # HACK: When running with the `run` method, we set perf_tracker to
431
            # None so that it will be overwritten here.
432
            self.perf_tracker = PerformanceTracker(
433
                sim_params=self.sim_params,
434
                env=self.trading_environment,
435
                data_portal=self.data_portal
436
            )
437
438
            # Set the dt initially to the period start by forcing it to change.
439
            self.on_dt_changed(self.sim_params.period_start)
440
441
        if not self.initialized:
442
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
443
            self.initialized = True
444
445
        self.trading_client = self._create_trading_client(sim_params)
446
447
        return self.trading_client.transform()
448
449
    def _create_trading_client(self, sim_params):
        """
        Build the AlgorithmSimulator for this algorithm from ``sim_params``,
        the data portal, a freshly created clock and a freshly created
        benchmark source.
        """
        portal = self.data_portal
        clock = self._create_clock()
        benchmark_source = self._create_benchmark_source()
        return AlgorithmSimulator(
            self,
            sim_params,
            portal,
            clock,
            benchmark_source
        )
457
458
    def get_generator(self):
459
        """
460
        Override this method to add new logic to the construction
461
        of the generator. Overrides can use the _create_generator
462
        method to get a standard construction generator.
463
        """
464
        return self._create_generator(self.sim_params)
465
466
    def run(self, data_portal=None):
467
        """Run the algorithm.
468
469
        :Arguments:
470
            source : DataPortal
471
472
        :Returns:
473
            daily_stats : pandas.DataFrame
474
              Daily performance metrics such as returns, alpha etc.
475
476
        """
477
        self.data_portal = data_portal
478
479
        # Force a reset of the performance tracker, in case
480
        # this is a repeat run of the algorithm.
481
        self.perf_tracker = None
482
483
        # Create zipline and loop through simulated_trading.
484
        # Each iteration returns a perf dictionary
485
        perfs = []
486
        for perf in self.get_generator():
487
            perfs.append(perf)
488
489
        # convert perf dict to pandas dataframe
490
        daily_stats = self._create_daily_stats(perfs)
491
492
        self.analyze(daily_stats)
493
494
        return daily_stats
495
496
    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Write any identifiers the asset finder cannot resolve to the trading
        environment, then map ``identifiers`` to sids as of ``as_of_date``.

        An identifier counts as resolved only when it is an Asset or an
        integer sid for which the asset finder returns an asset; everything
        else is written as a new equity identifier.
        """
        finder = self.asset_finder

        def lookup(identifier):
            # Assets and integer sids are looked up; any other kind of
            # identifier is treated as unresolved.
            if isinstance(identifier, Asset):
                return finder.retrieve_asset(sid=identifier.sid,
                                             default_none=True)
            if isinstance(identifier, Integral):
                return finder.retrieve_asset(sid=identifier,
                                             default_none=True)
            return None

        identifiers_to_build = [
            identifier for identifier in identifiers
            if lookup(identifier) is None
        ]

        self.trading_environment.write_data(
            equities_identifiers=identifiers_to_build)

        # The lookups above may have cached misses, so the caches are
        # cleared.  The real fix is to not construct an AssetFinder until
        # `run()`, when all the needed data is actually available.
        self.asset_finder._reset_caches()

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )
524
525
    def _create_daily_stats(self, perfs):
526
        # create daily and cumulative stats dataframe
527
        daily_perfs = []
528
        # TODO: the loop here could overwrite expected properties
529
        # of daily_perf. Could potentially raise or log a
530
        # warning.
531
        for perf in perfs:
532
            if 'daily_perf' in perf:
533
534
                perf['daily_perf'].update(
535
                    perf['daily_perf'].pop('recorded_vars')
536
                )
537
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
538
                daily_perfs.append(perf['daily_perf'])
539
            else:
540
                self.risk_report = perf
541
542
        daily_dts = [np.datetime64(perf['period_close'], utc=True)
543
                     for perf in daily_perfs]
544
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)
545
546
        return daily_stats
547
548
    @api_method
    def get_environment(self, field='platform'):
        """
        Return information about the running simulation.

        ``field`` selects one of 'arena', 'data_frequency', 'start', 'end',
        'capital_base' or 'platform'; passing '*' returns the whole dict.
        An unknown field raises KeyError.
        """
        params = self.sim_params
        env = dict(
            arena=params.arena,
            data_frequency=params.data_frequency,
            start=params.first_open,
            end=params.last_close,
            capital_base=params.capital_base,
            platform=self._platform,
        )
        return env if field == '*' else env[field]
562
563
    @api_method
    def fetch_csv(self, url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """
        Create a PandasRequestsCSV source for ``url``, register its
        DataFrame with the data portal as an extra source, and return the
        source object.

        All parameters, plus the trading environment, the simulation period
        bounds and the data frequency, are forwarded to PandasRequestsCSV;
        extra keyword arguments are passed through unchanged.
        """
        params = self.sim_params
        csv_data_source = PandasRequestsCSV(
            url,
            pre_func,
            post_func,
            self.trading_environment,
            params.period_start,
            params.period_end,
            date_column,
            date_format,
            timezone,
            symbol,
            mask,
            symbol_column,
            data_frequency=self.data_frequency,
            special_params_checker=special_params_checker,
            **kwargs
        )

        # Hand the fetched frame to the data portal.
        portal = self.data_portal
        portal.handle_extra_source(csv_data_source.df, self.sim_params)

        return csv_data_source
600
601
    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an Event and register it with the
        algorithm's EventManager.
        """
        manager = self.event_manager
        manager.add_event(
            zipline.utils.events.Event(rule, callback),
        )
608
609
    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to be called according to date and time rules.

        A falsy ``date_rule`` defaults to every day.  In minute mode a falsy
        ``time_rule`` defaults to market open; in any other mode the given
        ``time_rule`` is ignored and an always-true rule is used.
        """
        if not date_rule:
            date_rule = DateRuleFactory.every_day()

        if self.sim_params.data_frequency == 'minute':
            if not time_rule:
                time_rule = TimeRuleFactory.market_open()
        else:
            # The time_rule has no meaning outside minute mode.
            time_rule = zipline.utils.events.Always()

        rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(rule, func)
628
629
    @api_method
    def record(self, *args, **kwargs):
        """
        Store named values in the algorithm's recorded variables.

        Positional arguments are read as alternating name, value pairs
        (a trailing unpaired argument is ignored); keyword arguments are
        recorded under their keyword names.
        """
        # Even positions hold names, odd positions hold values, so
        # (a, b, c, d) pairs up as (a, b), (c, d).
        names = args[0::2]
        values = args[1::2]
        positionals = zip(names, values)

        for name, value in chain(positionals, iteritems(kwargs)):
            self._recorded_vars[name] = value
644
645
    @api_method
    def set_benchmark(self, benchmark_sid):
        """
        Store ``benchmark_sid`` as the benchmark for this algorithm.

        Raises SetBenchmarkOutsideInitialize if the algorithm has already
        been initialized.
        """
        if not self.initialized:
            self.benchmark_sid = benchmark_sid
            return

        raise SetBenchmarkOutsideInitialize()
651
652
    @api_method
    @preprocess(symbol_str=ensure_upper_case)
    def symbol(self, symbol_str):
        """
        Look up the asset for ``symbol_str`` through the asset finder.

        The lookup is done as of the symbol lookup date when one has been
        set; otherwise the simulation's period_end is used.
        """
        lookup_date = self._symbol_lookup_date
        if lookup_date is None:
            # No explicit lookup date: resolve symbol->sid as of period_end.
            lookup_date = self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=lookup_date,
        )
668
669
    @api_method
    def symbols(self, *args):
        """
        Look up several symbols at once; returns a list with the result of
        ``self.symbol`` for each argument, in order.
        """
        found = []
        for identifier in args:
            found.append(self.symbol(identifier))
        return found
676
677
    @api_method
    def sid(self, a_sid):
        """
        Return the asset the asset finder retrieves for the sid ``a_sid``.
        """
        finder = self.asset_finder
        return finder.retrieve_asset(a_sid)
684
685
    @api_method
    @preprocess(symbol=ensure_upper_case)
    def future_symbol(self, symbol):
        """ Look up a futures contract by its symbol.

        Parameters
        ----------
        symbol : str
            The symbol of the desired contract.

        Returns
        -------
        Future
            A Future object.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.

        """
        finder = self.asset_finder
        return finder.lookup_future_symbol(symbol)
707
708
    @api_method
    @preprocess(root_symbol=ensure_upper_case)
    def future_chain(self, root_symbol, as_of_date=None):
        """ Build the future chain for a root symbol.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc.  A truthy value is converted to a
            UTC pandas.Timestamp.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        UnsupportedDatetimeFormat
            If `as_of_date` cannot be converted to a Timestamp.
        """
        if as_of_date:
            requested_date = as_of_date
            try:
                as_of_date = pd.Timestamp(requested_date, tz='UTC')
            except ValueError:
                raise UnsupportedDatetimeFormat(input=requested_date,
                                                method='future_chain')

        chain_kwargs = dict(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=as_of_date,
        )
        return FutureChain(**chain_kwargs)
744
745
    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a desired ``value`` into a number of shares/contracts.

        The value is divided by the asset's last price, scaled by the
        contract multiplier when the asset is a Future.  Returns 0 (no
        order) when the last price is (tolerantly) zero.
        """
        current_data = self.trading_client.current_data
        last_price = current_data[asset].price

        if tolerant_equals(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            logger = self.logger
            if logger:
                logger.debug(zero_message)
            # A zero price gives nothing to divide by, so order nothing.
            return 0

        value_multiplier = (
            asset.contract_multiplier if isinstance(asset, Future) else 1
        )
        return value / (last_price * value_multiplier)
767
768
    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` of ``sid`` through the blotter and
        return the blotter's result.

        The order parameters are validated first, which raises on invalid
        input; the deprecated ``limit_price``/``stop_price`` arguments are
        converted to an ExecutionStyle.
        """
        # Snap to an integer share count: values within .0001 of an integer
        # round to it, everything else truncates toward zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        share_count = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid,
                                   share_count,
                                   limit_price,
                                   stop_price,
                                   style)

        execution_style = self.__convert_order_params_for_blotter(
            limit_price,
            stop_price,
            style,
        )
        return self.blotter.order(sid, share_count, execution_style)
794
795
    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Validate the arguments given to the order API function.

        Raises OrderDuringInitialize when the algorithm is not yet
        initialized, and UnsupportedOrderParameters when ``style`` is
        combined with ``limit_price`` or ``stop_price`` or when ``asset`` is
        not an Asset.  Finally runs every registered trading control's
        ``validate``.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        # Portfolio, datetime and current data are re-read for every
        # control, exactly as each control is invoked.
        for trading_control in self.trading_controls:
            trading_control.validate(asset,
                                     amount,
                                     self.portfolio,
                                     self.get_datetime(),
                                     self.trading_client.current_data)
835
836
    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Translate the deprecated limit_price / stop_price arguments into an
        ExecutionStyle instance.

        Callers are expected to guarantee that either ``style`` is None or
        both ``limit_price`` and ``stop_price`` are None.

        Prices are tested by truthiness, so a falsy price (None or 0) is
        treated as "not supplied".
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            # An explicit style wins; the legacy prices must be unset.
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        # Neither price supplied: plain market order.
        return MarketOrder()
857
858
    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.

        ``value`` is converted to an amount by
        ``_calculate_order_value_amount(sid, value)`` and that amount is
        passed on to ``order`` together with the pricing arguments.
        If the Asset being ordered is a Future, the 'value' calculated
        is actually the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)

        Returns whatever ``order`` returns.
        """
        shares = self._calculate_order_value_amount(sid, value)
        pricing = dict(limit_price=limit_price,
                       stop_price=stop_price,
                       style=style)
        return self.order(sid, shares, **pricing)
880
881
    @property
    def recorded_vars(self):
        """Shallow copy (``copy.copy``) of ``self._recorded_vars``."""
        snapshot = copy(self._recorded_vars)
        return snapshot
884
885
    @property
    def portfolio(self):
        """The portfolio as returned by ``updated_portfolio()``."""
        current = self.updated_portfolio()
        return current
888
889
    def updated_portfolio(self):
        """
        Return the cached portfolio, filling the cache first if it is empty.

        When ``self._portfolio`` is None and a performance tracker is
        available, the portfolio for ``self.datetime`` is fetched from the
        tracker and cached.  Returns None if nothing is cached and there is
        no tracker.
        """
        cache_empty = self._portfolio is None
        if cache_empty and self.perf_tracker is not None:
            self._portfolio = self.perf_tracker.get_portfolio(self.datetime)
        return self._portfolio
894
895
    @property
    def account(self):
        """The account as returned by ``updated_account()``."""
        current = self.updated_account()
        return current
898
899
    def updated_account(self):
        """
        Return the cached account, filling the cache first if it is empty.

        When ``self._account`` is None and a performance tracker is
        available, the account for ``self.datetime`` is fetched from the
        tracker and cached.  Returns None if nothing is cached and there is
        no tracker.
        """
        cache_empty = self._account is None
        if cache_empty and self.perf_tracker is not None:
            self._account = self.perf_tracker.get_account(self.datetime)
        return self._account
904
905
    def set_logger(self, logger):
        """Store ``logger`` on the algorithm as ``self.logger``."""
        setattr(self, 'logger', logger)
907
908
    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each
        datetime group belongs here.  ``dt`` must be a ``datetime`` whose
        tzinfo is ``pytz.utc``.
        """
        assert isinstance(dt, datetime), (
            "Attempt to set algorithm's current time with non-datetime"
        )
        assert dt.tzinfo == pytz.utc, "Algorithm expects a utc datetime"

        self.datetime = dt

        # Propagate the new time to the performance tracker and the blotter.
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

        # Drop the cached portfolio/account so they are re-fetched on next
        # access.
        self._portfolio = self._account = None
927
928
    @api_method
    def get_datetime(self, tz=None):
        """
        Returns the simulation datetime.

        Parameters
        ----------
        tz : str or tzinfo, optional
            When given, the datetime is converted to this timezone; a string
            is first resolved with ``pytz.timezone``.  When omitted, the UTC
            datetime is returned as is.
        """
        now = self.datetime
        assert now.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is None:
            # datetime.datetime objects are immutable, so handing out the
            # stored value directly is safe.
            return now

        target_zone = pytz.timezone(tz) if isinstance(tz, string_types) else tz
        return now.astimezone(target_zone)
943
944
    def update_dividends(self, dividend_frame):
        """
        Hand ``dividend_frame`` to the performance tracker.

        The DataFrame is used to process dividends; its columns should
        contain at least the entries in zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)
950
951
    @api_method
    def set_slippage(self, slippage):
        """
        Install ``slippage`` as the blotter's slippage function.

        Raises
        ------
        UnsupportedSlippageModel
            If ``slippage`` is not a ``SlippageModel`` instance.
        OverrideSlippagePostInit
            If the algorithm is already initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()

        if self.initialized:
            raise OverrideSlippagePostInit()

        self.blotter.slippage_func = slippage
958
959
    @api_method
    def set_commission(self, commission):
        """
        Install ``commission`` as the blotter's commission model.

        Raises
        ------
        UnsupportedCommissionModel
            If ``commission`` is not a PerShare, PerTrade or PerDollar
            instance.
        OverrideCommissionPostInit
            If the algorithm is already initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()

        self.blotter.commission = commission
967
968
    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date for which symbols will be resolved to their sids
        (symbols may map to different firms or underlying assets at
        different times).

        ``dt`` is converted with ``pd.Timestamp(dt, tz='UTC')``; a
        ValueError from that conversion is re-raised as
        UnsupportedDatetimeFormat.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date
980
981
    # Retained for backwards compatibility.
982
    @property
    def data_frequency(self):
        """Data frequency of the simulation, read from ``self.sim_params``."""
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only these two frequencies are accepted; the value is written
        # through to ``self.sim_params``.
        allowed = ('daily', 'minute')
        assert value in allowed
        self.sim_params.data_frequency = value
990
991
    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        The order value is ``self.portfolio.portfolio_value * percent`` and
        is placed through ``order_value``.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        """
        portfolio_value = self.portfolio.portfolio_value
        return self.order_value(sid, portfolio_value * percent,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)
1005
1006
    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        If ``sid`` has no entry in the portfolio's positions, the full
        ``target`` is ordered.  Otherwise the order is for the difference
        between ``target`` and the position's current amount.
        """
        positions = self.portfolio.positions
        if sid in positions:
            # Only order the gap between the target and what is held.
            shares_to_order = target - positions[sid].amount
        else:
            shares_to_order = target

        return self.order(sid, shares_to_order,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)
1028
1029
    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        ``target`` is converted to a target amount with
        ``_calculate_order_value_amount(sid, target)`` and then handled by
        ``order_target``, so an existing position is only adjusted by the
        difference.
        If the Asset being ordered is a Future, the 'target value' calculated
        is actually the target exposure, as Futures have no 'value'.
        """
        pricing = dict(limit_price=limit_price,
                       stop_price=stop_price,
                       style=style)
        target_shares = self._calculate_order_value_amount(sid, target)
        return self.order_target(sid, target_shares, **pricing)
1046
1047
    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value.

        The target value is ``self.portfolio.portfolio_value * target`` and
        is handled by ``order_target_value``, so an existing position is only
        adjusted by the difference.

        Note that target must be expressed as a decimal (0.50 means 50%).
        """
        portfolio_value = self.portfolio.portfolio_value
        return self.order_target_value(sid, portfolio_value * target,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)
1064
1065
    @api_method
    def get_open_orders(self, sid=None):
        """
        Return the blotter's open orders as API objects.

        With ``sid`` omitted, returns a dict mapping each key of
        ``blotter.open_orders`` that has a non-empty order list to the list
        of converted orders.  With ``sid`` given, returns the list of
        converted orders for that key, or an empty list if the key is absent.
        """
        if sid is not None:
            if sid in self.blotter.open_orders:
                return [
                    open_order.to_api_obj()
                    for open_order in self.blotter.open_orders[sid]
                ]
            return []

        by_key = {}
        for key, orders in iteritems(self.blotter.open_orders):
            if orders:
                by_key[key] = [
                    open_order.to_api_obj() for open_order in orders
                ]
        return by_key
1077
1078
    @api_method
    def get_order(self, order_id):
        """
        Return the API object for the order with id ``order_id``, or None
        when the blotter has no order with that id.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()
1082
1083
    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        ``order_param`` may be a ``zipline.protocol.Order``, in which case
        its ``id`` is used; any other value is passed to the blotter as the
        order id.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)
1090
1091
    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, sids, bar_count, frequency, field, ffill=True):
        """
        Return a history window from the data portal, as of the current
        simulation datetime.

        All arguments are forwarded to ``data_portal.get_history_window``
        with ``self.datetime`` inserted after ``sids``.  Raises a plain
        Exception when no data portal is set.
        """
        portal = self.data_portal
        if portal is None:
            raise Exception("no data portal!")

        window_args = (sids, self.datetime, bar_count, frequency, field, ffill)
        return portal.get_history_window(*window_args)
1105
1106
    ####################
1107
    # Account Controls #
1108
    ####################
1109
1110
    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.

        Raises RegisterAccountControlPostInit if the algorithm is already
        initialized.
        """
        if not self.initialized:
            self.account_controls.append(control)
            return
        raise RegisterAccountControlPostInit()
1117
1118
    def validate_account_controls(self):
        """
        Run ``validate`` on every registered account control, passing the
        portfolio, the account, the current datetime and the trading
        client's current data.
        """
        for account_control in self.account_controls:
            # Arguments are re-read for every control, not hoisted.
            check_args = (
                self.portfolio,
                self.account,
                self.get_datetime(),
                self.trading_client.current_data,
            )
            account_control.validate(*check_args)
1124
1125
    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Set a limit on the maximum leverage of the algorithm.

        Builds a ``MaxLeverage`` control from ``max_leverage`` and registers
        it as an account control.
        """
        leverage_limit = MaxLeverage(max_leverage)
        self.register_account_control(leverage_limit)
1132
1133
    ####################
1134
    # Trading Controls #
1135
    ####################
1136
1137
    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.

        Raises RegisterTradingControlPostInit if the algorithm is already
        initialized.
        """
        if not self.initialized:
            self.trading_controls.append(control)
            return
        raise RegisterTradingControlPostInit()
1144
1145
    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid.

        Limits are treated as absolute values and are enforced at the time
        that the algo attempts to place an order for sid.  It is therefore
        possible to end up with more than the max number of shares due to
        splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would increase the
        absolute value of shares/dollar value beyond one of these limits, a
        TradingControlException is raised.

        Implemented by registering a ``MaxPositionSize`` trading control.
        """
        position_limit = MaxPositionSize(
            asset=sid,
            max_shares=max_shares,
            max_notional=max_notional,
        )
        self.register_trading_control(position_limit)
1166
1167
    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.

        Limits are treated as absolute values and are enforced at the time
        that the algo attempts to place an order for sid.  An order that
        would exceed one of these limits raises a TradingControlException.

        Implemented by registering a ``MaxOrderSize`` trading control.
        """
        order_limit = MaxOrderSize(
            asset=sid,
            max_shares=max_shares,
            max_notional=max_notional,
        )
        self.register_trading_control(order_limit)
1181
1182
    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed within the
        given time interval.

        Implemented by registering a ``MaxOrderCount(max_count)`` trading
        control.
        """
        count_limit = MaxOrderCount(max_count)
        self.register_trading_control(count_limit)
1190
1191
    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Set a restriction on which sids can be ordered.

        Implemented by registering a ``RestrictedListOrder(restricted_list)``
        trading control.
        """
        restriction = RestrictedListOrder(restricted_list)
        self.register_trading_control(restriction)
1198
1199
    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short
        positions, by registering a ``LongOnly`` trading control.
        """
        long_only_rule = LongOnly()
        self.register_trading_control(long_only_rule)
1205
1206
    ##############
1207
    # Pipeline API
1208
    ##############
1209
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register a pipeline to be computed at the start of each day.

        Parameters
        ----------
        pipeline
            The pipeline to register; it is also the return value.
        name
            Key under which the pipeline is stored.
        chunksize : optional
            Number of days per chunk, coerced with ``int``.  When omitted the
            chunk sizes are 5 followed by 126 repeated indefinitely.

        Raises
        ------
        NotImplementedError
            If a pipeline is already registered.
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")

        if chunksize is not None:
            chunks = iter(repeat(int(chunksize)))
        else:
            # Make the first chunk smaller to get more immediate results:
            # (one week, then every half year)
            chunks = iter(chain([5], repeat(126)))

        self._pipelines[name] = (pipeline, chunks)

        # Returning the pipeline allows expressions like
        # p = attach_pipeline(Pipeline(), 'name')
        return pipeline
1228
1229
    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of pipeline with name `name`.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: We don't currently support multiple pipelines, but we plan to
        # in the future.
        try:
            registered = self._pipelines[name]
        except KeyError:
            raise NoSuchPipeline(
                name=name,
                valid=list(self._pipelines.keys()),
            )
        pipeline, chunks = registered
        return self._pipeline_output(pipeline, chunks)
1265
1266
    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.

        Reads the cached pipeline results for today's normalized date.  If
        the cache has expired, the pipeline is re-run with the next chunk
        size from ``chunks`` and the cache is replaced.  Returns today's
        slice of the results, or an empty DataFrame with the same columns
        when today has no rows.
        """
        today = normalize_date(self.get_datetime())

        try:
            results = self._pipeline_cache.unwrap(today)
        except Expired:
            results, expires_on = self._run_pipeline(
                pipeline, today, next(chunks),
            )
            self._pipeline_cache = CachedObject(results, expires_on)

        # With a valid result in hand, pick out the rows for today.
        try:
            todays_rows = results.loc[today]
        except KeyError:
            # This happens if no assets passed the pipeline screen on a given
            # day.
            todays_rows = pd.DataFrame(index=[], columns=results.columns)
        return todays_rows
1286
1287
    def _run_pipeline(self, pipeline, start_date, chunksize):
        """
        Compute `pipeline`, providing values for at least `start_date`.

        Produces a DataFrame containing data for days between `start_date` and
        `end_date`, where `end_date` is defined by:

            `end_date = min(start_date + chunksize trading days,
                            simulation_end)`

        Returns
        -------
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)
            ``valid_until`` is the computed ``end_date``.

        See Also
        --------
        PipelineEngine.run_pipeline
        """
        calendar = self.trading_environment.trading_days

        # Position of the requested start date in the trading calendar.
        first_loc = calendar.get_loc(start_date)

        # Stop at whichever comes first: `chunksize` trading days past the
        # start, or the normalized date of the simulation's last close.
        last_session = self.sim_params.last_close.normalize()
        last_loc = min(first_loc + chunksize, calendar.get_loc(last_session))
        end_date = calendar[last_loc]

        output = self.engine.run_pipeline(pipeline, start_date, end_date)
        return output, end_date
1318
1319
    ##################
1320
    # End Pipeline API
1321
    ##################
1322
1323
    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        Only attributes found in ``vars(cls)`` whose ``is_api_method``
        attribute is truthy are included.
        """
        api_methods = []
        for member in itervalues(vars(cls)):
            if getattr(member, 'is_api_method', False):
                api_methods.append(member)
        return api_methods
1332