Completed
Pull Request — master (#858)
by Eddie
01:32
created

zipline.TradingAlgorithm.future_symbol()   A

Complexity

Conditions 1

Size

Total Lines 22

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 22
rs 9.2
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pytz
18
import pandas as pd
19
from pandas.tseries.tools import normalize_date
20
import numpy as np
21
22
from datetime import datetime
23
from itertools import chain, repeat
24
from numbers import Integral
25
26
from six import (
27
    exec_,
28
    iteritems,
29
    itervalues,
30
    string_types,
31
)
32
33
34
from zipline.errors import (
35
    AttachPipelineAfterInitialize,
36
    HistoryInInitialize,
37
    NoSuchPipeline,
38
    OrderDuringInitialize,
39
    OverrideCommissionPostInit,
40
    OverrideSlippagePostInit,
41
    PipelineOutputDuringInitialize,
42
    RegisterAccountControlPostInit,
43
    RegisterTradingControlPostInit,
44
    SetBenchmarkOutsideInitialize,
45
    UnsupportedCommissionModel,
46
    UnsupportedDatetimeFormat,
47
    UnsupportedOrderParameters,
48
    UnsupportedSlippageModel,
49
)
50
from zipline.finance.trading import TradingEnvironment
51
from zipline.finance.blotter import Blotter
52
from zipline.finance.commission import PerShare, PerTrade, PerDollar
53
from zipline.finance.controls import (
54
    LongOnly,
55
    MaxOrderCount,
56
    MaxOrderSize,
57
    MaxPositionSize,
58
    MaxLeverage,
59
    RestrictedListOrder
60
)
61
from zipline.finance.execution import (
62
    LimitOrder,
63
    MarketOrder,
64
    StopLimitOrder,
65
    StopOrder,
66
)
67
from zipline.finance.performance import PerformanceTracker
68
from zipline.finance.slippage import (
69
    VolumeShareSlippage,
70
    SlippageModel
71
)
72
from zipline.assets import Asset, Future
73
from zipline.assets.futures import FutureChain
74
from zipline.gens.tradesimulation import AlgorithmSimulator
75
from zipline.pipeline.engine import (
76
    NoOpPipelineEngine,
77
    SimplePipelineEngine,
78
)
79
from zipline.utils.api_support import (
80
    api_method,
81
    require_initialized,
82
    require_not_initialized,
83
    ZiplineAPI,
84
)
85
from zipline.utils.input_validation import ensure_upper_case
86
from zipline.utils.cache import CachedObject, Expired
87
import zipline.utils.events
88
from zipline.utils.events import (
89
    EventManager,
90
    make_eventrule,
91
    DateRuleFactory,
92
    TimeRuleFactory,
93
)
94
from zipline.utils.factory import create_simulation_parameters
95
from zipline.utils.math_utils import (
96
    tolerant_equals,
97
    round_if_near_integer
98
)
99
from zipline.utils.preprocess import preprocess
100
101
import zipline.protocol
102
from zipline.sources.requests_csv import PandasRequestsCSV
103
104
from zipline.gens.sim_engine import (
105
    MinuteSimulationClock,
106
    DailySimulationClock,
107
)
108
from zipline.sources.benchmark_source import BenchmarkSource
109
110
DEFAULT_CAPITAL_BASE = float("1.0e5")
111
112
113
class TradingAlgorithm(object):
114
    """
115
    Base class for trading algorithms. Inherit and overload
116
    initialize() and handle_data(data).
117
118
    A new algorithm could look like this:
119
    ```
120
    from zipline.api import order, symbol
121
122
    def initialize(context):
123
        context.sid = symbol('AAPL')
124
        context.amount = 100
125
126
    def handle_data(context, data):
127
        sid = context.sid
128
        amount = context.amount
129
        order(sid, amount)
130
    ```
131
    To then to run this algorithm pass these functions to
132
    TradingAlgorithm:
133
134
    my_algo = TradingAlgorithm(initialize, handle_data)
135
    stats = my_algo.run(data)
136
137
    """
138
139
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional hook called with (context, data). Only read when
                ``initialize`` and ``handle_data`` are also passed.
            analyze : function
                Optional hook called with (context, perf) at the end of
                ``run``. Only read when ``initialize`` and ``handle_data``
                are also passed.
            script : str
                Algoscript that contains initialize and
                handle_data function definition. May not be combined with
                the ``initialize`` / ``handle_data`` keyword arguments.
            algo_filename : str
                Filename used when compiling ``script`` (default
                '<string>').
            namespace : dict
                Globals dict in which ``script`` is executed.
            platform : str <default: 'zipline'>
                Value reported by ``get_environment('platform')``.
            env : TradingEnvironment
                Environment to use; a new one is created when omitted.
            sim_params : SimulationParameters
                When omitted, parameters are created from ``capital_base``,
                ``start`` and ``end``.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            blotter : Blotter
                When omitted, a Blotter with VolumeShareSlippage and
                PerShare commission is created.
            benchmark_sid :
                Benchmark identifier handed to BenchmarkSource.
            get_pipeline_loader : callable
                Passed to ``init_engine``.
            create_event_context : callable
                Passed to the EventManager as ``create_context``.
            asset_finder : An AssetFinder object
                NOTE(review): not consumed by this constructor; the asset
                finder is taken from the TradingEnvironment, so this keyword
                is currently forwarded to ``initialize``.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            futures_metadata : dict
                Futures metadata written to the TradingEnvironment.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        Positional arguments and any keyword arguments not consumed above
        are forwarded to ``initialize`` when the algorithm is first run.

        :Raises:
            ValueError
                If ``script`` does not define ``handle_data``, or if
                ``script`` is combined with the ``initialize`` /
                ``handle_data`` keyword arguments.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.data_portal = None

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        # Alternative way of setting data_frequency for backwards
        # compatibility. It is applied before the default Blotter is built
        # below, because that Blotter reads self.data_frequency; applying it
        # afterwards left the Blotter on the previous frequency.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        self.perf_tracker = None
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter(
                slippage_func=VolumeShareSlippage(),
                commission=PerShare(),
                data_frequency=self.data_frequency,
            )

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self._portfolio = None
        self._account = None

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        if self.algoscript is not None:
            # A script and explicit initialize/handle_data functions are
            # mutually exclusive. This check used to live inside the `elif`
            # branch below, where algoscript is always None, so it never ran.
            if kwargs.get('initialize') or kwargs.get('handle_data'):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )
            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Consume benchmark_sid before capturing the leftover kwargs so it is
        # clearly not forwarded to initialize (kwargs is the same dict object
        # either way).
        self.benchmark_sid = kwargs.pop('benchmark_sid', None)

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs
312
313
    def init_engine(self, get_loader):
        """
        Build the Pipeline engine for this algorithm and store it on
        ``self.engine``.

        Parameters
        ----------
        get_loader : callable or None
            Loader factory handed to ``SimplePipelineEngine`` together with
            the environment's trading days and ``self.asset_finder``. When
            None, a ``NoOpPipelineEngine`` is installed instead.
        """
        if get_loader is None:
            self.engine = NoOpPipelineEngine()
            return

        self.engine = SimplePipelineEngine(
            get_loader,
            self.trading_environment.trading_days,
            self.asset_finder,
        )
327
328
    def initialize(self, *args, **kwargs):
        """
        Run the user-supplied initialize function.

        ``self`` is passed as the first argument, followed by ``args`` and
        ``kwargs``; while it runs, ``self`` is made available to the Zipline
        API functions.
        """
        with ZiplineAPI(self):
            user_initialize = self._initialize
            user_initialize(self, *args, **kwargs)
335
336
    def before_trading_start(self, data):
337
        if self._before_trading_start is None:
338
            return
339
340
        self._before_trading_start(self, data)
341
342
    def handle_data(self, data):
343
        self._handle_data(self, data)
344
345
        # Unlike trading controls which remain constant unless placing an
346
        # order, account controls can change each bar. Thus, must check
347
        # every bar no matter if the algorithm places an order or not.
348
        self.validate_account_controls()
349
350
    def analyze(self, perf):
351
        if self._analyze is None:
352
            return
353
354
        with ZiplineAPI(self):
355
            self._analyze(self, perf)
356
357
    def __repr__(self):
358
        """
359
        N.B. this does not yet represent a string that can be used
360
        to instantiate an exact copy of an algorithm.
361
362
        However, it is getting close, and provides some value as something
363
        that can be inspected interactively.
364
        """
365
        return """
366
{class_name}(
367
    capital_base={capital_base}
368
    sim_params={sim_params},
369
    initialized={initialized},
370
    slippage={slippage},
371
    commission={commission},
372
    blotter={blotter},
373
    recorded_vars={recorded_vars})
374
""".strip().format(class_name=self.__class__.__name__,
375
                   capital_base=self.capital_base,
376
                   sim_params=repr(self.sim_params),
377
                   initialized=self.initialized,
378
                   slippage=repr(self.blotter.slippage_func),
379
                   commission=repr(self.blotter.commission),
380
                   blotter=repr(self.blotter),
381
                   recorded_vars=repr(self.recorded_vars))
382
383
    def _create_clock(self):
        """
        Build the simulation clock that matches ``sim_params.data_frequency``.

        For minute data, a ``MinuteSimulationClock`` is built from the market
        open/close times of each simulated trading day, and the data portal's
        offset cache is set up from that clock before it is returned. For
        any other frequency, a ``DailySimulationClock`` is returned.
        """
        if self.sim_params.data_frequency != 'minute':
            return DailySimulationClock(self.sim_params.trading_days)

        env = self.trading_environment
        sessions = env.open_and_closes.ix[self.sim_params.trading_days]

        def _as_int_nanos(column):
            # Cast the column to datetime64[ns], then reinterpret as int64.
            return sessions[column].values.astype(
                'datetime64[ns]').astype(np.int64)

        opens_ns = _as_int_nanos('market_open')
        closes_ns = _as_int_nanos('market_close')
        emit_every_minute = self.sim_params.emission_rate == "minute"

        clock = MinuteSimulationClock(
            self.sim_params.trading_days,
            opens_ns,
            closes_ns,
            env.trading_days,
            emit_every_minute
        )
        self.data_portal.setup_offset_cache(
            clock.minutes_by_day,
            clock.minutes_to_day,
            self.sim_params.trading_days)
        return clock
412
413
    def _create_benchmark_source(self):
        """
        Build a BenchmarkSource for ``self.benchmark_sid`` over the simulated
        trading days, using this run's data portal and emission rate.
        """
        params = self.sim_params
        return BenchmarkSource(
            self.benchmark_sid,
            self.trading_environment,
            params.trading_days,
            self.data_portal,
            emission_rate=params.emission_rate,
        )
421
422
    def _create_generator(self, sim_params):
423
        if sim_params is not None:
424
            self.sim_params = sim_params
425
426
        if self.perf_tracker is None:
427
            # HACK: When running with the `run` method, we set perf_tracker to
428
            # None so that it will be overwritten here.
429
            self.perf_tracker = PerformanceTracker(
430
                sim_params=self.sim_params,
431
                env=self.trading_environment,
432
                data_portal=self.data_portal
433
            )
434
435
            # Set the dt initially to the period start by forcing it to change.
436
            self.on_dt_changed(self.sim_params.period_start)
437
438
        if not self.initialized:
439
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
440
            self.initialized = True
441
442
        self.trading_client = self._create_trading_client(sim_params)
443
444
        return self.trading_client.transform()
445
446
    def _create_trading_client(self, sim_params):
        """
        Construct the AlgorithmSimulator that runs this algorithm.

        A new clock and benchmark source are built for every client.
        """
        clock = self._create_clock()
        benchmark_source = self._create_benchmark_source()
        return AlgorithmSimulator(
            self,
            sim_params,
            self.data_portal,
            clock,
            benchmark_source
        )
454
455
    def get_generator(self):
456
        """
457
        Override this method to add new logic to the construction
458
        of the generator. Overrides can use the _create_generator
459
        method to get a standard construction generator.
460
        """
461
        return self._create_generator(self.sim_params)
462
463
    def run(self, data_portal=None):
464
        """Run the algorithm.
465
466
        :Arguments:
467
            source : DataPortal
468
469
        :Returns:
470
            daily_stats : pandas.DataFrame
471
              Daily performance metrics such as returns, alpha etc.
472
473
        """
474
        self.data_portal = data_portal
475
476
        # Force a reset of the performance tracker, in case
477
        # this is a repeat run of the algorithm.
478
        self.perf_tracker = None
479
480
        # Create zipline and loop through simulated_trading.
481
        # Each iteration returns a perf dictionary
482
        perfs = []
483
        for perf in self.get_generator():
484
            perfs.append(perf)
485
486
        # convert perf dict to pandas dataframe
487
        daily_stats = self._create_daily_stats(perfs)
488
489
        self.analyze(daily_stats)
490
491
        return daily_stats
492
493
    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Register unknown identifiers as equities, then map all identifiers to
        sids as of ``as_of_date``.

        Assets and integers that the asset finder already knows are left
        alone. Every other identifier is written to the trading environment
        before the mapping is done.
        """
        finder = self.asset_finder
        unresolved = []
        for ident in identifiers:
            if isinstance(ident, Asset):
                known = finder.retrieve_asset(sid=ident.sid,
                                              default_none=True)
            elif isinstance(ident, Integral):
                known = finder.retrieve_asset(sid=ident, default_none=True)
            else:
                known = None

            if known is None:
                unresolved.append(ident)

        self.trading_environment.write_data(
            equities_identifiers=unresolved)

        # The lookups above may have cached misses for identifiers that were
        # just written, so drop those caches. The real fix is to not build
        # an AssetFinder until `run()`, when all the data is available.
        finder._reset_caches()

        return finder.map_identifier_index_to_sids(identifiers, as_of_date)
521
522
    def _create_daily_stats(self, perfs):
        """
        Assemble daily performance packets into a DataFrame.

        Each packet with a ``'daily_perf'`` entry is flattened in place: its
        ``recorded_vars`` and the packet's ``cumulative_risk_metrics`` are
        merged into the daily dict. A packet without ``'daily_perf'`` is
        stored on ``self.risk_report``; if there are several, the last wins.

        Returns a DataFrame with one row per daily packet, indexed by each
        packet's ``period_close``.
        """
        # TODO: these merges could overwrite expected keys of daily_perf.
        # Consider raising or logging a warning when that happens.
        rows = []
        for packet in perfs:
            if 'daily_perf' not in packet:
                self.risk_report = packet
                continue
            day = packet['daily_perf']
            day.update(day.pop('recorded_vars'))
            day.update(packet['cumulative_risk_metrics'])
            rows.append(day)

        index = [np.datetime64(row['period_close'], utc=True)
                 for row in rows]
        return pd.DataFrame(rows, index=index)
544
545
    @api_method
    def get_environment(self, field='platform'):
        """
        Query the execution environment.

        Parameters
        ----------
        field : str, optional
            One of 'arena', 'data_frequency', 'start', 'end', 'capital_base'
            or 'platform'. Pass '*' to get a dict of all of them. Defaults to
            'platform'.

        Raises
        ------
        KeyError
            If ``field`` is neither '*' nor one of the names above.
        """
        params = self.sim_params
        env = {
            'arena': params.arena,
            'data_frequency': params.data_frequency,
            'start': params.first_open,
            'end': params.last_close,
            'capital_base': params.capital_base,
            'platform': self._platform
        }
        return env if field == '*' else env[field]
559
560
    @api_method
    def fetch_csv(self, url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """
        Load external CSV data from ``url`` and register it with the data
        portal.

        A ``PandasRequestsCSV`` source is built with the simulation's
        ``period_start``/``period_end`` and the given parsing options.
        Its ``df`` is passed to ``self.data_portal.handle_extra_source``.
        Returns the source object.
        """
        params = self.sim_params
        csv_data_source = PandasRequestsCSV(
            url, pre_func, post_func, self.trading_environment,
            params.period_start, params.period_end,
            date_column, date_format, timezone,
            symbol, mask, symbol_column,
            data_frequency=self.data_frequency,
            special_params_checker=special_params_checker,
            **kwargs
        )

        # Ingest the parsed frame into the data portal.
        self.data_portal.handle_extra_source(csv_data_source.df, params)

        return csv_data_source
597
598
    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an Event and add it to the
        algorithm's EventManager.
        """
        event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(event)
605
606
    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to be called according to date and time rules.

        Parameters
        ----------
        func : callable
            The function to schedule.
        date_rule : optional
            Defaults to ``DateRuleFactory.every_day()`` when falsy.
        time_rule : optional
            In minute mode, defaults to ``TimeRuleFactory.market_open()``
            when falsy. In daily mode it is ignored and replaced with
            ``Always()``.
        half_days : bool
            Passed through to ``make_eventrule``.
        """
        if not date_rule:
            date_rule = DateRuleFactory.every_day()

        if self.sim_params.data_frequency == 'minute':
            if not time_rule:
                time_rule = TimeRuleFactory.market_open()
        else:
            # If we are in daily mode the time_rule is ignored.
            time_rule = zipline.utils.events.Always()

        rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(rule, func)
625
626
    @api_method
    def record(self, *args, **kwargs):
        """
        Record values under the given names.

        Positional arguments are read as alternating name/value pairs, e.g.
        ``record('a', 1, 'b', 2)``. Keyword arguments are also accepted,
        e.g. ``record(a=1)``. Positional pairs are stored first, then
        keywords. A trailing unpaired positional name is silently ignored.
        """
        # Even-indexed positionals are names, odd-indexed are their values;
        # zip stops at the shorter slice, dropping any unpaired name.
        names = args[0::2]
        values = args[1::2]
        for name, value in chain(zip(names, values), iteritems(kwargs)):
            self._recorded_vars[name] = value
641
642
    @api_method
    def set_benchmark(self, benchmark_sid):
        """
        Set the benchmark identifier. Only allowed before the algorithm has
        been initialized.

        Raises
        ------
        SetBenchmarkOutsideInitialize
            If ``self.initialized`` is already True.
        """
        if not self.initialized:
            self.benchmark_sid = benchmark_sid
            return
        raise SetBenchmarkOutsideInitialize()
648
649
    @api_method
    @preprocess(symbol_str=ensure_upper_case)
    def symbol(self, symbol_str):
        """
        Look up an Asset by ticker symbol through the asset finder.

        The symbol is passed through ``ensure_upper_case`` first. It is
        resolved as of the date set via ``set_symbol_lookup_date()``, or as
        of the simulation's ``period_end`` if no date has been set.
        """
        lookup_date = self._symbol_lookup_date
        if lookup_date is None:
            lookup_date = self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=lookup_date,
        )
665
666
    @api_method
    def symbols(self, *args):
        """
        Look up several Assets by symbol.

        ``symbols('A', 'B')`` returns the same list as
        ``[symbol('A'), symbol('B')]``.
        """
        return list(map(self.symbol, args))
673
674
    @api_method
    def sid(self, a_sid):
        """
        Look up an Asset by its sid via ``asset_finder.retrieve_asset``.
        """
        finder = self.asset_finder
        return finder.retrieve_asset(a_sid)
681
682
    @api_method
    @preprocess(symbol=ensure_upper_case)
    def future_symbol(self, symbol):
        """Look up a futures contract by symbol.

        Parameters
        ----------
        symbol : str
            Symbol of the desired contract. It is passed through
            ``ensure_upper_case`` before the lookup.

        Returns
        -------
        Future
            The matching futures contract.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.
        """
        finder = self.asset_finder
        return finder.lookup_future_symbol(symbol)
704
705
    @api_method
    @preprocess(root_symbol=ensure_upper_case)
    def future_chain(self, root_symbol, as_of_date=None):
        """ Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc. Naive values are interpreted as UTC;
            timezone-aware values are converted to UTC.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        UnsupportedDatetimeFormat
            If ``as_of_date`` cannot be parsed into a timestamp.
        """
        if as_of_date:
            try:
                as_of_date = pd.Timestamp(as_of_date)
            except ValueError:
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')
            # Newer pandas raises ValueError when `tz=` is combined with a
            # tz-aware input, so localize naive values and convert aware
            # ones explicitly instead of calling pd.Timestamp(x, tz='UTC').
            if as_of_date.tzinfo is None:
                as_of_date = as_of_date.tz_localize('UTC')
            else:
                as_of_date = as_of_date.tz_convert('UTC')
        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=as_of_date
        )
741
742
    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a target ``value`` into a share or contract count for
        ``asset``.

        The count is ``value / (last_price * multiplier)``. The multiplier is
        ``asset.contract_multiplier`` for Futures and 1 otherwise. If the
        last price is (tolerantly) zero, no count can be inferred: a debug
        message is logged when a logger is set, and 0 is returned.
        """
        last_price = self.trading_client.current_data[asset].price

        if tolerant_equals(last_price, 0):
            message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset)
            if self.logger:
                self.logger.debug(message)
            # Don't place any order.
            return 0

        multiplier = (asset.contract_multiplier
                      if isinstance(asset, Future) else 1)
        return value / (last_price * multiplier)
764
765
    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` shares/contracts of ``sid``.

        ``amount`` is first truncated to an integer count that is either
        within .0001 of ``amount`` or closer to zero (e.g. 3.9999 -> 4.0,
        5.5 -> 5.0, -5.5 -> -5.0). Pass either an explicit ``style`` or the
        legacy ``limit_price``/``stop_price`` arguments, not both.

        Returns whatever ``self.blotter.order`` returns.
        """
        amount = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid, amount, limit_price, stop_price,
                                   style)

        # Turn deprecated limit_price/stop_price into an ExecutionStyle.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style)
        return self.blotter.order(sid, amount, execution_style)
791
792
    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Check the arguments of an ``order()`` call and raise on any problem.

        After the checks below pass, each registered trading control's
        ``validate`` is called with the order, the portfolio, the current
        datetime and the trading client's current data.

        Raises
        ------
        OrderDuringInitialize
            If the algorithm has not finished initializing.
        UnsupportedOrderParameters
            If a truthy ``style`` is combined with a truthy ``limit_price``
            or ``stop_price``, or if ``asset`` is not an Asset.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style:
            if limit_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both limit_price and style is not supported."
                )
            if stop_price:
                raise UnsupportedOrderParameters(
                    msg="Passing both stop_price and style is not supported."
                )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for trading_control in self.trading_controls:
            trading_control.validate(
                asset,
                amount,
                self.portfolio,
                self.get_datetime(),
                self.trading_client.current_data,
            )
832
833
    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Translate legacy ``limit_price``/``stop_price`` arguments into an
        ExecutionStyle instance.

        A truthy ``style`` is returned unchanged, and both prices must then
        be None. Otherwise the result depends on which prices are truthy:
        StopLimitOrder (both), LimitOrder (limit only), StopOrder (stop only)
        or MarketOrder (neither).

        NOTE(review): the None-check is an ``assert``, so it is skipped under
        ``python -O``. It also fails for falsy non-None prices such as 0,
        which ``validate_order_params`` does not reject.
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        return MarketOrder()
854
855
    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.

        If the requested sid exists, the requested value is divided by its
        price to imply the number of shares to transact (see
        ``_calculate_order_value_amount``); that amount is then placed via
        ``order``. If the Asset being ordered is a Future, the 'value'
        calculated is actually the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)

        Returns
        -------
        Whatever ``self.order`` returns for the computed share amount.
        """
        amount = self._calculate_order_value_amount(sid, value)
        return self.order(sid, amount,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)
877
878
    @property
    def recorded_vars(self):
        """
        A shallow copy of ``self._recorded_vars``.

        Returning a copy keeps callers from mutating the algorithm's
        internal mapping directly.
        """
        snapshot = copy(self._recorded_vars)
        return snapshot
881
882
    @property
    def portfolio(self):
        """The current portfolio, as returned by ``updated_portfolio()``."""
        current = self.updated_portfolio()
        return current
885
886
    def updated_portfolio(self):
        """
        Return the portfolio for the current datetime, building it lazily.

        The result is cached in ``self._portfolio`` and is only rebuilt from
        ``self.perf_tracker`` while that cache is None (``on_dt_changed``
        resets it). If there is no perf tracker, the cached value (possibly
        None) is returned as-is.
        """
        if self._portfolio is not None or self.perf_tracker is None:
            return self._portfolio
        self._portfolio = self.perf_tracker.get_portfolio(self.datetime)
        return self._portfolio
891
892
    @property
    def account(self):
        """The current account, as returned by ``updated_account()``."""
        current = self.updated_account()
        return current
895
896
    def updated_account(self):
        """
        Return the account for the current datetime, building it lazily.

        The result is cached in ``self._account`` and is only rebuilt from
        ``self.perf_tracker`` while that cache is None (``on_dt_changed``
        resets it). If there is no perf tracker, the cached value (possibly
        None) is returned as-is.
        """
        if self._account is not None or self.perf_tracker is None:
            return self._account
        self._account = self.perf_tracker.get_account(self.datetime)
        return self._account
901
902
    def set_logger(self, logger):
        """Store ``logger`` on this algorithm as ``self.logger``."""
        self.logger = logger
904
905
    def on_dt_changed(self, dt):
        """
        Callback invoked by the simulation loop whenever the current dt
        changes.

        Logic that must run exactly once at the start of each datetime group
        belongs here. The new ``dt`` is stored on the algorithm and pushed to
        the perf tracker and the blotter. The cached portfolio and account
        are then cleared so they are rebuilt on next access.

        ``dt`` must be a UTC ``datetime``; this is enforced with assertions.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

        # Invalidate the lazily-built snapshots for the new dt.
        self._portfolio = None
        self._account = None
        self._account = None
924
925
    @api_method
    def get_datetime(self, tz=None):
        """
        Returns the simulation datetime.

        Parameters
        ----------
        tz : str or tzinfo, optional
            Timezone to convert the (UTC) simulation datetime into. String
            names are resolved with ``pytz.timezone``.

        Returns
        -------
        dt : datetime
            The current simulation datetime, in UTC unless ``tz`` is given.
        """
        current = self.datetime
        assert current.tzinfo == pytz.utc, \
            "Algorithm should have a utc datetime"

        if tz is None:
            # datetime.datetime objects are immutable, so no copy is needed.
            return current

        if isinstance(tz, string_types):
            tz = pytz.timezone(tz)
        return current.astimezone(tz)
940
941
    def update_dividends(self, dividend_frame):
        """
        Set the DataFrame used to process dividends.

        The frame is handed straight to the perf tracker. Its columns should
        contain at least the entries in zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)
947
948
    @api_method
    def set_slippage(self, slippage):
        """
        Install ``slippage`` as the blotter's slippage model.

        Raises
        ------
        UnsupportedSlippageModel
            If ``slippage`` is not a SlippageModel instance.
        OverrideSlippagePostInit
            If called after the algorithm has been initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()
        if self.initialized:
            raise OverrideSlippagePostInit()

        self.blotter.slippage_func = slippage
955
956
    @api_method
    def set_commission(self, commission):
        """
        Install ``commission`` as the blotter's commission model.

        Raises
        ------
        UnsupportedCommissionModel
            If ``commission`` is not a PerShare, PerTrade or PerDollar model.
        OverrideCommissionPostInit
            If called after the algorithm has been initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()

        self.blotter.commission = commission
964
965
    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date for which symbols will be resolved to their sids
        (symbols may map to different firms or underlying assets at
        different times).

        Parameters
        ----------
        dt : anything accepted by ``pd.Timestamp``
            The lookup date; it is stored as a UTC Timestamp.

        Raises
        ------
        UnsupportedDatetimeFormat
            If ``dt`` cannot be converted to a UTC Timestamp.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except (ValueError, TypeError):
            # pd.Timestamp raises ValueError for unparseable strings but
            # TypeError for unsupported input types (lists, arbitrary
            # objects); both mean the caller passed a bad datetime.
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date
977
978
    # Kept for backwards compatibility: ``data_frequency`` proxies to
    # ``self.sim_params.data_frequency``.
    @property
    def data_frequency(self):
        """The simulation's data frequency, read from ``self.sim_params``."""
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only 'daily' and 'minute' are accepted.
        assert value in ('daily', 'minute')
        params = self.sim_params
        params.data_frequency = value
987
988
    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).

        The percent is converted to a value using
        ``self.portfolio.portfolio_value`` and then placed via
        ``order_value`` with the same price/style arguments.
        """
        # The docstring previously contained "\%", an invalid escape sequence
        # in a non-raw string (SyntaxWarning on Python 3.12+).
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)
1002
1003
    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        If no position in ``sid`` exists, this orders ``target`` shares
        outright, which is the same as placing a new order. Otherwise it
        orders the difference between ``target`` and the current amount.
        """
        positions = self.portfolio.positions
        if sid in positions:
            delta = target - positions[sid].amount
        else:
            delta = target

        return self.order(sid, delta,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)
1025
1026
    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        The target value is converted to a target share amount with
        ``_calculate_order_value_amount`` and handed to ``order_target``.
        That call orders the full amount for a new position, or the
        difference from the current amount for an existing one. If the Asset
        being ordered is a Future, the 'target value' calculated is actually
        the target exposure, as Futures have no 'value'.
        """
        amount_goal = self._calculate_order_value_amount(sid, target)
        price_args = dict(limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)
        return self.order_target(sid, amount_goal, **price_args)
1043
1044
    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50%).

        The target percent is converted to a target value using
        ``self.portfolio.portfolio_value`` and handed to
        ``order_target_value``.
        """
        # The docstring previously contained "\%", an invalid escape sequence
        # in a non-raw string (SyntaxWarning on Python 3.12+).
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)
1061
1062
    @api_method
    def get_open_orders(self, sid=None):
        """
        Retrieve open orders from the blotter as API objects.

        Parameters
        ----------
        sid : optional
            If given, restrict the result to orders for this key.

        Returns
        -------
        dict or list
            Without ``sid``, a dict mapping each key of
            ``self.blotter.open_orders`` that has at least one order to a
            list of ``to_api_obj()`` results. With ``sid``, the list for that
            key, or an empty list if it has no entry.
        """
        open_orders = self.blotter.open_orders

        if sid is not None:
            # Membership test first so a defaultdict-like mapping is not
            # populated as a side effect of the lookup.
            if sid not in open_orders:
                return []
            return [order.to_api_obj() for order in open_orders[sid]]

        by_key = {}
        for key, orders in iteritems(open_orders):
            if orders:
                by_key[key] = [order.to_api_obj() for order in orders]
        return by_key
1074
1075
    @api_method
    def get_order(self, order_id):
        """
        Return the API object for order ``order_id``, or None if the blotter
        has no such order.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()
1079
1080
    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        ``order_param`` may be either an order id or a
        ``zipline.protocol.Order``, in which case its ``id`` is used.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param
        self.blotter.cancel(order_id)
1087
1088
    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, sids, bar_count, frequency, field, ffill=True):
        """
        Return a trailing window of ``field`` data for ``sids``, ending at
        the current simulation datetime, as produced by
        ``self.data_portal.get_history_window``.

        Raises
        ------
        Exception
            If no data portal has been configured.
        """
        portal = self.data_portal
        if portal is None:
            raise Exception("no data portal!")

        return portal.get_history_window(
            sids, self.datetime, bar_count, frequency, field, ffill,
        )
1102
1103
    ####################
1104
    # Account Controls #
1105
    ####################
1106
1107
    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.

        Raises RegisterAccountControlPostInit if the algorithm has already
        been initialized.
        """
        if not self.initialized:
            self.account_controls.append(control)
            return
        raise RegisterAccountControlPostInit()
1114
1115
    def validate_account_controls(self):
        """
        Run every registered account control against the current portfolio,
        account, datetime and current data.

        The state is re-read for each control, and any exception a control
        raises propagates to the caller.
        """
        for account_control in self.account_controls:
            account_control.validate(
                self.portfolio,
                self.account,
                self.get_datetime(),
                self.trading_client.current_data,
            )
1121
1122
    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Set a limit on the maximum leverage of the algorithm by registering
        a MaxLeverage account control.
        """
        self.register_account_control(MaxLeverage(max_leverage))
1129
1130
    ####################
1131
    # Trading Controls #
1132
    ####################
1133
1134
    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.

        Raises RegisterTradingControlPostInit if the algorithm has already
        been initialized.
        """
        if not self.initialized:
            self.trading_controls.append(control)
            return
        raise RegisterTradingControlPostInit()
1141
1142
    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid.

        Limits are treated as absolute values. They are enforced when the
        algo attempts to place an order for sid, so it is possible to end up
        with more than the max number of shares due to splits/dividends, and
        more than the max notional due to price improvement.

        If an algorithm attempts to place an order that would increase the
        absolute value of shares/dollar value beyond one of these limits, a
        TradingControlException is raised.
        """
        position_limit = MaxPositionSize(asset=sid,
                                         max_shares=max_shares,
                                         max_notional=max_notional)
        self.register_trading_control(position_limit)
1163
1164
    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.

        Limits are treated as absolute values and are enforced when the algo
        attempts to place an order for sid. An order that would exceed one of
        these limits raises a TradingControlException.
        """
        self.register_trading_control(
            MaxOrderSize(asset=sid,
                         max_shares=max_shares,
                         max_notional=max_notional)
        )
1178
1179
    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed within the given
        time interval, by registering a MaxOrderCount trading control.
        """
        self.register_trading_control(MaxOrderCount(max_count))
1187
1188
    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Set a restriction on which sids can be ordered, by registering a
        RestrictedListOrder trading control built from ``restricted_list``.
        """
        self.register_trading_control(RestrictedListOrder(restricted_list))
1195
1196
    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions
        (registers a LongOnly trading control).
        """
        long_only_rule = LongOnly()
        self.register_trading_control(long_only_rule)
1202
1203
    ##############
1204
    # Pipeline API
1205
    ##############
1206
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register a pipeline to be computed at the start of each day.

        Parameters
        ----------
        pipeline : Pipeline
            The pipeline to register.
        name : str
            Name under which ``pipeline_output`` can retrieve the results.
        chunksize : int, optional
            Number of trading days to compute per pipeline run. If omitted,
            the first run covers 5 days and later runs cover 126 days each.

        Returns
        -------
        pipeline : Pipeline
            The same pipeline, to allow expressions like
            ``p = attach_pipeline(Pipeline(), 'name')``.

        Raises
        ------
        NotImplementedError
            If a pipeline has already been attached.
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")

        if chunksize is not None:
            chunks = repeat(int(chunksize))
        else:
            # Make the first chunk smaller to get more immediate results:
            # (one week, then every half year)
            chunks = chain([5], repeat(126))

        self._pipelines[name] = pipeline, chunks
        return pipeline
1225
1226
    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of pipeline with name `name`.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: We don't currently support multiple pipelines, but we plan to
        # in the future.
        try:
            registered, chunk_sizes = self._pipelines[name]
        except KeyError:
            known_names = list(self._pipelines.keys())
            raise NoSuchPipeline(name=name, valid=known_names)

        return self._pipeline_output(registered, chunk_sizes)
1262
1263
    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.

        Results are served from ``self._pipeline_cache``. When the cache has
        expired, the pipeline is re-run starting at today for the next
        chunk size drawn from ``chunks``, and the cache is replaced. The rows
        for today are then returned, or an empty frame with the same columns
        if there are none.
        """
        today = normalize_date(self.get_datetime())

        try:
            data = self._pipeline_cache.unwrap(today)
        except Expired:
            chunksize = next(chunks)
            data, valid_until = self._run_pipeline(pipeline, today, chunksize)
            self._pipeline_cache = CachedObject(data, valid_until)

        try:
            return data.loc[today]
        except KeyError:
            # No assets passed the pipeline screen on this day.
            return pd.DataFrame(index=[], columns=data.columns)
1283
1284
    def _run_pipeline(self, pipeline, start_date, chunksize):
1285
        """
1286
        Compute `pipeline`, providing values for at least `start_date`.
1287
1288
        Produces a DataFrame containing data for days between `start_date` and
1289
        `end_date`, where `end_date` is defined by:
1290
1291
            `end_date = min(start_date + chunksize trading days,
1292
                            simulation_end)`
1293
1294
        Returns
1295
        -------
1296
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)
1297
1298
        See Also
1299
        --------
1300
        PipelineEngine.run_pipeline
1301
        """
1302
        days = self.trading_environment.trading_days
1303
1304
        # Load data starting from the previous trading day...
1305
        start_date_loc = days.get_loc(start_date)
1306
1307
        # ...continuing until either the day before the simulation end, or
1308
        # until chunksize days of data have been loaded.
1309
        sim_end = self.sim_params.last_close.normalize()
1310
        end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))
1311
        end_date = days[end_loc]
1312
1313
        return \
1314
            self.engine.run_pipeline(pipeline, start_date, end_date), end_date
1315
1316
    ##################
1317
    # End Pipeline API
1318
    ##################
1319
1320
    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        An entry of ``vars(cls)`` counts as an API method when it has a
        truthy ``is_api_method`` attribute. Only attributes defined directly
        on ``cls`` are inspected, not inherited ones.
        """
        api_methods = []
        for attr in itervalues(vars(cls)):
            if getattr(attr, 'is_api_method', False):
                api_methods.append(attr)
        return api_methods
1329