Completed
Pull Request — master (#858)
by Eddie
01:37
created

zipline.TradingAlgorithm.data_frequency()   A

Complexity

Conditions 2

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 3
rs 10
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pytz
18
import pandas as pd
19
from pandas.tseries.tools import normalize_date
20
import numpy as np
21
22
from datetime import datetime
23
from itertools import chain, repeat
24
from numbers import Integral
25
26
from six import (
27
    exec_,
28
    iteritems,
29
    itervalues,
30
    string_types,
31
)
32
33
from zipline.errors import (
34
    AttachPipelineAfterInitialize,
35
    HistoryInInitialize,
36
    NoSuchPipeline,
37
    OrderDuringInitialize,
38
    OverrideCommissionPostInit,
39
    OverrideSlippagePostInit,
40
    PipelineOutputDuringInitialize,
41
    RegisterAccountControlPostInit,
42
    RegisterTradingControlPostInit,
43
    SetBenchmarkOutsideInitialize,
44
    UnsupportedCommissionModel,
45
    UnsupportedDatetimeFormat,
46
    UnsupportedOrderParameters,
47
    UnsupportedSlippageModel,
48
)
49
from zipline.finance.trading import TradingEnvironment
50
from zipline.finance.blotter import Blotter
51
from zipline.finance.commission import PerShare, PerTrade, PerDollar
52
from zipline.finance.controls import (
53
    LongOnly,
54
    MaxOrderCount,
55
    MaxOrderSize,
56
    MaxPositionSize,
57
    MaxLeverage,
58
    RestrictedListOrder
59
)
60
from zipline.finance.execution import (
61
    LimitOrder,
62
    MarketOrder,
63
    StopLimitOrder,
64
    StopOrder,
65
)
66
from zipline.finance.performance import PerformanceTracker
67
from zipline.finance.slippage import (
68
    VolumeShareSlippage,
69
    SlippageModel
70
)
71
from zipline.assets import Asset, Future
72
from zipline.assets.futures import FutureChain
73
from zipline.gens.tradesimulation import AlgorithmSimulator
74
from zipline.pipeline.engine import (
75
    NoOpPipelineEngine,
76
    SimplePipelineEngine,
77
)
78
from zipline.utils.api_support import (
79
    api_method,
80
    require_initialized,
81
    require_not_initialized,
82
    ZiplineAPI,
83
)
84
from zipline.utils.input_validation import ensure_upper_case
85
from zipline.utils.cache import (
86
    CachedObject,
87
    Expired
88
)
89
import zipline.utils.events
90
from zipline.utils.events import (
91
    EventManager,
92
    make_eventrule,
93
    DateRuleFactory,
94
    TimeRuleFactory,
95
)
96
from zipline.utils.factory import create_simulation_parameters
97
from zipline.utils.math_utils import (
98
    tolerant_equals,
99
    round_if_near_integer
100
)
101
from zipline.utils.preprocess import preprocess
102
103
import zipline.protocol
104
from zipline.sources.requests_csv import PandasRequestsCSV
105
106
from zipline.gens.sim_engine import (
107
    MinuteSimulationClock,
108
    DailySimulationClock,
109
)
110
from zipline.sources.benchmark_source import BenchmarkSource
111
112
# Starting capital used when the caller does not supply ``capital_base``.
DEFAULT_CAPITAL_BASE = 1.0e5
113
114
115
class TradingAlgorithm(object):
116
    """
117
    Base class for trading algorithms. Inherit and overload
118
    initialize() and handle_data(data).
119
120
    A new algorithm could look like this:
121
    ```
122
    from zipline.api import order, symbol
123
124
    def initialize(context):
125
        context.sid = symbol('AAPL')
126
        context.amount = 100
127
128
    def handle_data(context, data):
129
        sid = context.sid
130
        amount = context.amount
131
        order(sid, amount)
132
    ```
133
    To then run this algorithm, pass these functions to
134
    TradingAlgorithm:
135
136
    my_algo = TradingAlgorithm(initialize, handle_data)
137
    stats = my_algo.run(data)
138
139
    """
140
141
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        Every keyword listed below is popped from ``kwargs``; whatever is
        left over (together with ``*args``) is stored and later forwarded
        to ``initialize``.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional callback; only read together with
                initialize/handle_data.
            analyze : function
                Optional callback invoked with the results after ``run``.
            script : str
                Algoscript that contains initialize and
                handle_data function definition. Takes precedence over the
                initialize/handle_data keyword arguments.
            algo_filename : str
                Filename used when compiling ``script``
                (defaults to '<string>').
            namespace : dict
                Namespace in which ``script`` is executed.
            platform : str <default: 'zipline'>
                Value reported by ``get_environment('platform')``.
            env : TradingEnvironment
                Environment to use; a new one is created when omitted.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            sim_params : simulation parameters object
                When omitted, parameters are created from ``capital_base``,
                ``start`` and ``end``.
            start, end :
                Only used when ``sim_params`` is not given.
            get_pipeline_loader : callable
                Passed to ``init_engine``.
            blotter : Blotter
                A default Blotter is built when omitted (or falsy).
            create_event_context : callable
                Passed to the EventManager as ``create_context``.
            benchmark_sid :
                Stored as ``self.benchmark_sid``.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            futures_metadata :
                Written to the environment as ``futures_data``.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        NOTE(review): an ``asset_finder`` keyword was previously documented
        here, but it is not consumed by this method; if passed it simply
        stays in the kwargs forwarded to ``initialize``.

        :Raises:
            ValueError
                If ``script`` is given but does not define ``handle_data``.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.data_portal = None

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        self.perf_tracker = None
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter(
                slippage_func=VolumeShareSlippage(),
                commission=PerShare(),
                data_frequency=self.data_frequency,
            )

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self._portfolio = None
        self._account = None

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        if self.algoscript is not None:
            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            # A script, when present, is always handled by the branch above,
            # so a "script and initialize/handle_data" conflict check placed
            # here could never run; it has been dropped as dead code.
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP. The no-op accepts any arguments
        # because ``initialize`` forwards the leftover args/kwargs to it.
        if self._initialize is None:
            self._initialize = lambda *_args, **_kwargs: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        self.benchmark_sid = kwargs.pop('benchmark_sid', None)

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        self.initialize_kwargs = kwargs
314
315
    def init_engine(self, get_loader):
        """
        Build the pipeline engine and store it on ``self.engine``.

        Without a loader (``get_loader is None``) a NoOpPipelineEngine is
        stored; otherwise a SimplePipelineEngine is created from the loader,
        the environment's trading days and the asset finder.
        """
        if get_loader is None:
            self.engine = NoOpPipelineEngine()
            return

        self.engine = SimplePipelineEngine(
            get_loader,
            self.trading_environment.trading_days,
            self.asset_finder,
        )
329
330
    def initialize(self, *args, **kwargs):
        """
        Run the stored ``_initialize`` callback inside a ZiplineAPI context
        for this algorithm, passing ``self`` followed by the given arguments.
        """
        api_context = ZiplineAPI(self)
        with api_context:
            self._initialize(self, *args, **kwargs)
337
338
    def before_trading_start(self, data):
339
        if self._before_trading_start is None:
340
            return
341
342
        self._before_trading_start(self, data)
343
344
    def handle_data(self, data):
345
        self._handle_data(self, data)
346
347
        # Unlike trading controls which remain constant unless placing an
348
        # order, account controls can change each bar. Thus, must check
349
        # every bar no matter if the algorithm places an order or not.
350
        self.validate_account_controls()
351
352
    def analyze(self, perf):
353
        if self._analyze is None:
354
            return
355
356
        with ZiplineAPI(self):
357
            self._analyze(self, perf)
358
359
    def __repr__(self):
360
        """
361
        N.B. this does not yet represent a string that can be used
362
        to instantiate an exact copy of an algorithm.
363
364
        However, it is getting close, and provides some value as something
365
        that can be inspected interactively.
366
        """
367
        return """
368
{class_name}(
369
    capital_base={capital_base}
370
    sim_params={sim_params},
371
    initialized={initialized},
372
    slippage={slippage},
373
    commission={commission},
374
    blotter={blotter},
375
    recorded_vars={recorded_vars})
376
""".strip().format(class_name=self.__class__.__name__,
377
                   capital_base=self.capital_base,
378
                   sim_params=repr(self.sim_params),
379
                   initialized=self.initialized,
380
                   slippage=repr(self.blotter.slippage_func),
381
                   commission=repr(self.blotter.commission),
382
                   blotter=repr(self.blotter),
383
                   recorded_vars=repr(self.recorded_vars))
384
385
    def _create_clock(self):
        """
        Create the simulation clock matching ``sim_params.data_frequency``.

        For ``'minute'`` data a MinuteSimulationClock is built from the
        market opens/closes of the simulation's trading days, and the data
        portal's offset cache is set up from that clock. For any other
        frequency a DailySimulationClock over the trading days is returned.
        """
        if self.sim_params.data_frequency == 'minute':
            env = self.trading_environment
            # Label-based row selection. ``.loc`` is used instead of the
            # deprecated ``.ix`` indexer, which newer pandas no longer has.
            trading_o_and_c = env.open_and_closes.loc[
                self.sim_params.trading_days]
            # The clock is given the opens/closes as int64 nanosecond values.
            market_opens = trading_o_and_c['market_open'].values.astype(
                'datetime64[ns]').astype(np.int64)
            market_closes = trading_o_and_c['market_close'].values.astype(
                'datetime64[ns]').astype(np.int64)

            minutely_emission = self.sim_params.emission_rate == "minute"

            clock = MinuteSimulationClock(
                self.sim_params.trading_days,
                market_opens,
                market_closes,
                env.trading_days,
                minutely_emission
            )
            self.data_portal.setup_offset_cache(
                clock.minutes_by_day,
                clock.minutes_to_day,
                self.sim_params.trading_days)
            return clock
        else:
            return DailySimulationClock(self.sim_params.trading_days)
414
415
    def _create_benchmark_source(self):
        """Build a BenchmarkSource for ``self.benchmark_sid`` from the
        current environment, trading days, data portal and emission rate."""
        benchmark_source = BenchmarkSource(
            self.benchmark_sid,
            self.trading_environment,
            self.sim_params.trading_days,
            self.data_portal,
            emission_rate=self.sim_params.emission_rate,
        )
        return benchmark_source
423
424
    def _create_generator(self, sim_params):
425
        if sim_params is not None:
426
            self.sim_params = sim_params
427
428
        if self.perf_tracker is None:
429
            # HACK: When running with the `run` method, we set perf_tracker to
430
            # None so that it will be overwritten here.
431
            self.perf_tracker = PerformanceTracker(
432
                sim_params=self.sim_params,
433
                env=self.trading_environment,
434
                data_portal=self.data_portal
435
            )
436
437
            # Set the dt initially to the period start by forcing it to change.
438
            self.on_dt_changed(self.sim_params.period_start)
439
440
        if not self.initialized:
441
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
442
            self.initialized = True
443
444
        self.trading_client = self._create_trading_client(sim_params)
445
446
        return self.trading_client.transform()
447
448
    def _create_trading_client(self, sim_params):
        """Build the AlgorithmSimulator for this algorithm from the given
        parameters, the data portal, a new clock and a new benchmark
        source."""
        portal = self.data_portal
        clock = self._create_clock()
        benchmark_source = self._create_benchmark_source()
        return AlgorithmSimulator(
            self, sim_params, portal, clock, benchmark_source
        )
456
457
    def get_generator(self):
458
        """
459
        Override this method to add new logic to the construction
460
        of the generator. Overrides can use the _create_generator
461
        method to get a standard construction generator.
462
        """
463
        return self._create_generator(self.sim_params)
464
465
    def run(self, data_portal=None):
466
        """Run the algorithm.
467
468
        :Arguments:
469
            source : DataPortal
470
471
        :Returns:
472
            daily_stats : pandas.DataFrame
473
              Daily performance metrics such as returns, alpha etc.
474
475
        """
476
        self.data_portal = data_portal
477
478
        # Force a reset of the performance tracker, in case
479
        # this is a repeat run of the algorithm.
480
        self.perf_tracker = None
481
482
        # Create zipline and loop through simulated_trading.
483
        # Each iteration returns a perf dictionary
484
        perfs = []
485
        for perf in self.get_generator():
486
            perfs.append(perf)
487
488
        # convert perf dict to pandas dataframe
489
        daily_stats = self._create_daily_stats(perfs)
490
491
        self.analyze(daily_stats)
492
493
        return daily_stats
494
495
    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
496
        # Build new Assets for identifiers that can't be resolved as
497
        # sids/Assets
498
        identifiers_to_build = []
499
        for identifier in identifiers:
500
            asset = None
501
502
            if isinstance(identifier, Asset):
503
                asset = self.asset_finder.retrieve_asset(sid=identifier.sid,
504
                                                         default_none=True)
505
            elif isinstance(identifier, Integral):
506
                asset = self.asset_finder.retrieve_asset(sid=identifier,
507
                                                         default_none=True)
508
            if asset is None:
509
                identifiers_to_build.append(identifier)
510
511
        self.trading_environment.write_data(
512
            equities_identifiers=identifiers_to_build)
513
514
        # We need to clear out any cache misses that were stored while trying
515
        # to do lookups.  The real fix for this problem is to not construct an
516
        # AssetFinder until we `run()` when we actually have all the data we
517
        # need to so.
518
        self.asset_finder._reset_caches()
519
520
        return self.asset_finder.map_identifier_index_to_sids(
521
            identifiers, as_of_date,
522
        )
523
524
    def _create_daily_stats(self, perfs):
525
        # create daily and cumulative stats dataframe
526
        daily_perfs = []
527
        # TODO: the loop here could overwrite expected properties
528
        # of daily_perf. Could potentially raise or log a
529
        # warning.
530
        for perf in perfs:
531
            if 'daily_perf' in perf:
532
533
                perf['daily_perf'].update(
534
                    perf['daily_perf'].pop('recorded_vars')
535
                )
536
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
537
                daily_perfs.append(perf['daily_perf'])
538
            else:
539
                self.risk_report = perf
540
541
        daily_dts = [np.datetime64(perf['period_close'], utc=True)
542
                     for perf in daily_perfs]
543
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)
544
545
        return daily_stats
546
547
    @api_method
    def get_environment(self, field='platform'):
        """Return one field of the run environment, or the whole mapping
        when ``field`` is '*'.

        Available fields: 'arena', 'data_frequency', 'start', 'end',
        'capital_base' and 'platform'. An unknown field raises KeyError.
        """
        params = self.sim_params
        environment = dict(
            arena=params.arena,
            data_frequency=params.data_frequency,
            start=params.first_open,
            end=params.last_close,
            capital_base=params.capital_base,
            platform=self._platform,
        )
        return environment if field == '*' else environment[field]
561
562
    @api_method
    def fetch_csv(self, url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """Load a CSV via PandasRequestsCSV, register its frame with the
        data portal and return the source object.

        Extra keyword arguments are passed through to PandasRequestsCSV.
        """
        # Show all the logs every time fetcher is used.
        source = PandasRequestsCSV(
            url, pre_func, post_func,
            self.trading_environment,
            self.sim_params.period_start,
            self.sim_params.period_end,
            date_column, date_format, timezone,
            symbol, mask, symbol_column,
            data_frequency=self.data_frequency,
            special_params_checker=special_params_checker,
            **kwargs
        )

        # Make the fetched frame available through the data portal.
        self.data_portal.handle_extra_source(source.df, self.sim_params)

        return source
599
600
    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an Event and register it with the
        algorithm's EventManager.
        """
        register = self.event_manager.add_event
        register(zipline.utils.events.Event(rule, callback))
607
608
    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Register ``func`` to be called according to a date rule and a time
        rule.

        A missing date rule defaults to every day. In minute mode a missing
        time rule defaults to market open; in any other mode the time rule
        is ignored and replaced by ``Always``.
        """
        if not date_rule:
            date_rule = DateRuleFactory.every_day()

        if self.sim_params.data_frequency == 'minute':
            if not time_rule:
                time_rule = TimeRuleFactory.market_open()
        else:
            # If we are in daily mode the time_rule is ignored.
            time_rule = zipline.utils.events.Always()

        event_rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(event_rule, func)
627
628
    @api_method
    def record(self, *args, **kwargs):
        """
        Store named values in the recorded variables.

        Positional arguments are read as alternating name, value pairs;
        keyword arguments are stored under their keyword. A trailing
        positional argument without a partner is dropped.
        """
        # Feeding the *same* iterator to zip twice makes each step consume
        # two consecutive items, yielding (a, b), (c, d), ... instead of
        # (a, a), (b, b), ...
        cursor = iter(args)
        for name, value in chain(zip(cursor, cursor), iteritems(kwargs)):
            self._recorded_vars[name] = value
643
644
    @api_method
    def set_benchmark(self, benchmark_sid):
        """Set the benchmark sid; only allowed before the algorithm is
        initialized, otherwise SetBenchmarkOutsideInitialize is raised."""
        if not self.initialized:
            self.benchmark_sid = benchmark_sid
            return

        raise SetBenchmarkOutsideInitialize()
650
651
    @api_method
    @preprocess(symbol_str=ensure_upper_case)
    def symbol(self, symbol_str):
        """
        Look up the Asset for ``symbol_str`` through the asset finder.

        The lookup is made as of the symbol lookup date when one has been
        set, and as of the simulation's period_end otherwise.
        """
        lookup_date = self._symbol_lookup_date
        if lookup_date is None:
            # No explicit lookup date: resolve as of the end of the period.
            lookup_date = self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str, as_of_date=lookup_date,
        )
667
668
    @api_method
    def symbols(self, *args):
        """
        Look up several symbols at once; returns a list with the result of
        ``self.symbol`` for each argument, in order.
        """
        return list(map(self.symbol, args))
675
676
    @api_method
    def sid(self, a_sid):
        """
        Return the Asset that the asset finder retrieves for ``a_sid``.
        """
        asset = self.asset_finder.retrieve_asset(a_sid)
        return asset
683
684
    @api_method
    @preprocess(symbol=ensure_upper_case)
    def future_symbol(self, symbol):
        """ Look up the futures contract that has the given symbol.

        Parameters
        ----------
        symbol : str
            Symbol of the contract to find.

        Returns
        -------
        Future
            The matching Future object.

        Raises
        ------
        SymbolNotFound
            If there is no contract named 'symbol'.

        """
        contract = self.asset_finder.lookup_future_symbol(symbol)
        return contract
706
707
    @api_method
    @preprocess(root_symbol=ensure_upper_case)
    def future_chain(self, root_symbol, as_of_date=None):
        """ Build the future chain for a root symbol.

        Parameters
        ----------
        root_symbol : str
            Root symbol of the chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted, i.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc. When given, it is converted to a UTC
            pandas.Timestamp.

        Returns
        -------
        FutureChain
            The chain for the given parameters.

        Raises
        ------
        RootSymbolNotFound
            If no future chain exists for the given root symbol.
        UnsupportedDatetimeFormat
            If ``as_of_date`` cannot be converted to a Timestamp.
        """
        chain_date = as_of_date
        if chain_date:
            try:
                chain_date = pd.Timestamp(chain_date, tz='UTC')
            except ValueError:
                # chain_date still holds the caller's original value here.
                raise UnsupportedDatetimeFormat(input=chain_date,
                                                method='future_chain')

        chain_kwargs = dict(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=chain_date,
        )
        return FutureChain(**chain_kwargs)
743
744
    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a target ``value`` into a number of shares/contracts using
        the asset's current price (and contract multiplier for Futures).

        Returns 0 when the current price is (tolerantly) zero.
        """
        price = self.trading_client.current_data[asset].price

        if tolerant_equals(price, 0):
            message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(message)
            # No sensible amount can be derived, so order nothing.
            return 0

        multiplier = (
            asset.contract_multiplier if isinstance(asset, Future) else 1
        )
        return value / (price * multiplier)
766
767
    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Validate the order parameters and place the order with the blotter.

        Returns whatever ``self.blotter.order`` returns.
        """
        # Snap to an integer share count: round when within .0001 of an
        # integer, otherwise truncate toward zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        share_count = int(round_if_near_integer(amount))

        # A ZiplineError is raised here when the parameters are invalid.
        self.validate_order_params(
            sid, share_count, limit_price, stop_price, style,
        )

        # The deprecated limit_price/stop_price arguments are translated
        # into an ExecutionStyle object for the blotter.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style,
        )
        return self.blotter.order(sid, share_count, execution_style)
793
794
    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Check the arguments given to the order API function.

        Raises OrderDuringInitialize when the algorithm is not yet
        initialized, and UnsupportedOrderParameters when a style is combined
        with limit_price/stop_price or when ``asset`` is not an Asset. Every
        registered trading control is then asked to validate the order.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for trading_control in self.trading_controls:
            trading_control.validate(
                asset,
                amount,
                self.portfolio,
                self.get_datetime(),
                self.trading_client.current_data,
            )
834
835
    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Turn the deprecated limit_price and stop_price arguments into an
        ExecutionStyle instance.

        Callers are expected to pass either style == None or (limit_price,
        stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        return MarketOrder()
856
857
    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.

        If the requested sid exists, the requested value is divided by its
        price to imply the number of shares to transact. If the Asset being
        ordered is a Future, the 'value' calculated is actually the exposure,
        as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short

        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)

        Returns
        -------
        Whatever ``self.order`` returns for the computed share amount.
        """
        share_amount = self._calculate_order_value_amount(sid, value)
        execution_kwargs = dict(
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
        return self.order(sid, share_amount, **execution_kwargs)
879
880
    @property
    def recorded_vars(self):
        """
        Shallow copy of ``self._recorded_vars``.

        A copy is handed out so that callers cannot rebind entries on the
        algorithm's own container.
        """
        snapshot = copy(self._recorded_vars)
        return snapshot
883
884
    @property
    def portfolio(self):
        """Current portfolio, as returned by ``updated_portfolio()``."""
        current_portfolio = self.updated_portfolio()
        return current_portfolio
887
888
    def updated_portfolio(self):
        """
        Return the cached portfolio, fetching it first if necessary.

        The portfolio is requested from ``self.perf_tracker`` for the current
        ``self.datetime`` only when nothing is cached and a perf tracker
        exists; otherwise the cached value (possibly None) is returned as is.
        """
        if self._portfolio is not None or self.perf_tracker is None:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio(self.datetime)
        return self._portfolio
893
894
    @property
    def account(self):
        """Current account, as returned by ``updated_account()``."""
        current_account = self.updated_account()
        return current_account
897
898
    def updated_account(self):
        """
        Return the cached account, fetching it first if necessary.

        The account is requested from ``self.perf_tracker`` for the current
        ``self.datetime`` only when nothing is cached and a perf tracker
        exists; otherwise the cached value (possibly None) is returned as is.
        """
        if self._account is not None or self.perf_tracker is None:
            return self._account

        self._account = self.perf_tracker.get_account(self.datetime)
        return self._account
903
904
    def set_logger(self, logger):
        """Store ``logger`` on the algorithm as ``self.logger``."""
        self.logger = logger
906
907
    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each
        datetime group belongs here.

        Parameters
        ----------
        dt : datetime
            New simulation time; must carry ``pytz.utc`` as its tzinfo.
        """
        assert isinstance(dt, datetime), (
            "Attempt to set algorithm's current time with non-datetime"
        )
        assert dt.tzinfo == pytz.utc, (
            "Algorithm expects a utc datetime"
        )

        self.datetime = dt

        # Propagate the new time to the components that track it.
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

        # Drop the cached snapshots so they are re-fetched for the new dt.
        self._portfolio = self._account = None
926
927
    @api_method
    def get_datetime(self, tz=None):
        """
        Return the simulation datetime.

        Parameters
        ----------
        tz : str or tzinfo, optional
            Timezone to convert the result to. Strings are resolved with
            ``pytz.timezone``. When omitted, the UTC datetime is returned.

        Returns
        -------
        datetime
            ``self.datetime``, converted to ``tz`` when one is given.
        """
        current = self.datetime
        assert current.tzinfo == pytz.utc, \
            "Algorithm should have a utc datetime"

        if tz is None:
            # datetime.datetime objects are immutable, so no copy is needed.
            return current

        target_tz = pytz.timezone(tz) if isinstance(tz, string_types) else tz
        return current.astimezone(target_tz)
942
943
    def update_dividends(self, dividend_frame):
        """
        Set the DataFrame used to process dividends.

        The frame is handed straight to ``self.perf_tracker``. Its columns
        should contain at least the entries in zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)
949
950
    @api_method
    def set_slippage(self, slippage):
        """
        Install ``slippage`` as the blotter's slippage function.

        Raises
        ------
        UnsupportedSlippageModel
            If ``slippage`` is not a ``SlippageModel`` instance.
        OverrideSlippagePostInit
            If the algorithm has already been initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()

        if self.initialized:
            raise OverrideSlippagePostInit()

        self.blotter.slippage_func = slippage
957
958
    @api_method
    def set_commission(self, commission):
        """
        Install ``commission`` as the blotter's commission model.

        Raises
        ------
        UnsupportedCommissionModel
            If ``commission`` is not a ``PerShare``, ``PerTrade`` or
            ``PerDollar`` instance.
        OverrideCommissionPostInit
            If the algorithm has already been initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()

        self.blotter.commission = commission
966
967
    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date for which symbols will be resolved to their sids
        (symbols may map to different firms or underlying assets at
        different times).

        Parameters
        ----------
        dt
            Anything ``pd.Timestamp`` accepts; it is stored as a UTC
            timestamp in ``self._symbol_lookup_date``.

        Raises
        ------
        UnsupportedDatetimeFormat
            If ``pd.Timestamp`` rejects ``dt`` with a ValueError.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date
979
980
    # Retained for backwards compatibility.
981
    @property
    def data_frequency(self):
        """Data frequency, read from ``self.sim_params``."""
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only these two frequencies are accepted; the value is written
        # through to the simulation parameters.
        valid_frequencies = ('daily', 'minute')
        assert value in valid_frequencies
        self.sim_params.data_frequency = value
989
990
    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        Note that percent must be expressed as a decimal (0.50 means 50%).

        Parameters
        ----------
        sid
            The asset to order.
        percent : float
            Fraction of ``self.portfolio.portfolio_value`` to order.
        limit_price, stop_price, style : optional
            Forwarded unchanged to ``order_value``.

        Returns
        -------
        Whatever ``self.order_value`` returns.
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)
1004
1005
    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares.

        If no position exists for ``sid``, the full ``target`` is ordered.
        Otherwise the order is for the difference between ``target`` and the
        amount currently held.

        Returns
        -------
        Whatever ``self.order`` returns.
        """
        positions = self.portfolio.positions
        if sid in positions:
            # Only trade the gap between the holding and the target.
            shares_to_order = target - positions[sid].amount
        else:
            shares_to_order = target

        return self.order(
            sid,
            shares_to_order,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
1027
1028
    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value.

        The target value is converted to a share amount with
        ``_calculate_order_value_amount`` and handed to ``order_target``,
        which orders only the difference from any existing position.

        If the Asset being ordered is a Future, the 'target value' calculated
        is actually the target exposure, as Futures have no 'value'.

        Returns
        -------
        Whatever ``self.order_target`` returns.
        """
        target_shares = self._calculate_order_value_amount(sid, target)
        execution_kwargs = dict(
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
        return self.order_target(sid, target_shares, **execution_kwargs)
1045
1046
    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        Note that target must be expressed as a decimal (0.50 means 50%).

        Parameters
        ----------
        sid
            The asset to order.
        target : float
            Target fraction of ``self.portfolio.portfolio_value``.
        limit_price, stop_price, style : optional
            Forwarded unchanged to ``order_target_value``.

        Returns
        -------
        Whatever ``self.order_target_value`` returns.
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)
1063
1064
    @api_method
    def get_open_orders(self, sid=None):
        """
        Return the blotter's open orders as API objects.

        Parameters
        ----------
        sid : optional
            Key to look up in ``self.blotter.open_orders``. When omitted,
            orders for every key are returned.

        Returns
        -------
        dict or list
            Without ``sid``: a dict mapping each key that has at least one
            open order to a list of ``to_api_obj()`` results. With ``sid``:
            the list for that key, or an empty list if the key is absent.
        """
        open_orders = self.blotter.open_orders

        if sid is None:
            orders_by_key = {}
            for key, orders in iteritems(open_orders):
                # Keys whose order list is empty are left out.
                if orders:
                    orders_by_key[key] = [
                        order.to_api_obj() for order in orders
                    ]
            return orders_by_key

        # Membership is tested explicitly instead of indexing blindly, so a
        # missing key is never looked up.
        if sid not in open_orders:
            return []
        return [order.to_api_obj() for order in open_orders[sid]]
1076
1077
    @api_method
    def get_order(self, order_id):
        """
        Look up an order by id in ``self.blotter.orders``.

        Returns
        -------
        The order's ``to_api_obj()`` result, or None when the id is unknown.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()
1081
1082
    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        Parameters
        ----------
        order_param
            Either an order id, or a ``zipline.protocol.Order`` whose ``id``
            attribute is used.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)
1089
1090
    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, sids, bar_count, frequency, field, ffill=True):
        """
        Fetch a history window from the data portal.

        All arguments are forwarded to
        ``self.data_portal.get_history_window`` together with the current
        ``self.datetime``.

        Raises
        ------
        Exception
            If no data portal is attached to the algorithm.
        """
        portal = self.data_portal
        if portal is None:
            raise Exception("no data portal!")

        window = portal.get_history_window(
            sids,
            self.datetime,
            bar_count,
            frequency,
            field,
            ffill,
        )
        return window
1104
1105
    ####################
1106
    # Account Controls #
1107
    ####################
1108
1109
    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.

        Raises
        ------
        RegisterAccountControlPostInit
            If the algorithm has already been initialized.
        """
        if self.initialized:
            raise RegisterAccountControlPostInit()

        registered = self.account_controls
        registered.append(control)
1116
1117
    def validate_account_controls(self):
        """
        Run every registered account control against the current state.

        Each control receives the portfolio, the account, the simulation
        datetime and the trading client's current data, all re-read for
        every control. Anything a control raises propagates to the caller.
        """
        for account_control in self.account_controls:
            account_control.validate(
                self.portfolio,
                self.account,
                self.get_datetime(),
                self.trading_client.current_data,
            )
1123
1124
    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Set a limit on the maximum leverage of the algorithm.

        Builds a ``MaxLeverage`` control from ``max_leverage`` and registers
        it as an account control.
        """
        self.register_account_control(MaxLeverage(max_leverage))
1131
1132
    ####################
1133
    # Trading Controls #
1134
    ####################
1135
1136
    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.

        Raises
        ------
        RegisterTradingControlPostInit
            If the algorithm has already been initialized.
        """
        if self.initialized:
            raise RegisterTradingControlPostInit()

        registered = self.trading_controls
        registered.append(control)
1143
1144
    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid.

        Limits are treated as absolute values and are enforced at the time
        that the algo attempts to place an order for sid. This means that
        it's possible to end up with more than the max number of shares due
        to splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one of
        these limits, a TradingControlException is raised.
        """
        position_limit = MaxPositionSize(
            asset=sid,
            max_shares=max_shares,
            max_notional=max_notional,
        )
        self.register_trading_control(position_limit)
1165
1166
    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.

        Limits are treated as absolute values and are enforced at the time
        that the algo attempts to place an order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, a TradingControlException is raised.
        """
        order_limit = MaxOrderSize(
            asset=sid,
            max_shares=max_shares,
            max_notional=max_notional,
        )
        self.register_trading_control(order_limit)
1180
1181
    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed.

        Registers a ``MaxOrderCount`` trading control built from
        ``max_count``; the interval over which orders are counted is defined
        by that control.
        """
        self.register_trading_control(MaxOrderCount(max_count))
1189
1190
    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Set a restriction on which sids can be ordered.

        Registers a ``RestrictedListOrder`` trading control built from
        ``restricted_list``.
        """
        self.register_trading_control(RestrictedListOrder(restricted_list))
1197
1198
    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short
        positions, by registering a ``LongOnly`` trading control.
        """
        long_only_rule = LongOnly()
        self.register_trading_control(long_only_rule)
1204
1205
    ##############
1206
    # Pipeline API
1207
    ##############
1208
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register a pipeline to be computed at the start of each day.

        Parameters
        ----------
        pipeline
            The pipeline to register.
        name
            Key under which the pipeline is stored in ``self._pipelines``.
        chunksize : int, optional
            Fixed chunk size to use for every pipeline run. When omitted,
            the first chunk is 5 and every later chunk is 126.

        Returns
        -------
        The ``pipeline`` argument, to allow expressions like
        ``p = attach_pipeline(Pipeline(), 'name')``.

        Raises
        ------
        NotImplementedError
            If a pipeline has already been attached.
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")

        if chunksize is None:
            # Make the first chunk smaller to get more immediate results:
            # (one week, then every half year)
            chunk_sizes = chain([5], repeat(126))
        else:
            chunk_sizes = repeat(int(chunksize))

        self._pipelines[name] = (pipeline, iter(chunk_sizes))
        return pipeline
1227
1228
    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of pipeline with name `name`.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: We don't currently support multiple pipelines, but we plan to
        # in the future.
        registered = self._pipelines
        try:
            pipeline, chunks = registered[name]
        except KeyError:
            raise NoSuchPipeline(
                name=name,
                valid=list(registered.keys()),
            )
        return self._pipeline_output(pipeline, chunks)
1264
1265
    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.

        Serves today's rows from ``self._pipeline_cache``. When the cache
        has expired, the pipeline is re-run for the next chunk size taken
        from ``chunks`` and the cache is replaced with the new results.
        """
        today = normalize_date(self.get_datetime())

        try:
            results = self._pipeline_cache.unwrap(today)
        except Expired:
            chunksize = next(chunks)
            results, expires = self._run_pipeline(pipeline, today, chunksize)
            self._pipeline_cache = CachedObject(results, expires)

        # With a valid result in hand, slice out the rows for today.
        try:
            todays_rows = results.loc[today]
        except KeyError:
            # No assets passed the pipeline screen on this day: hand back an
            # empty frame with the same columns.
            return pd.DataFrame(index=[], columns=results.columns)
        return todays_rows
1285
1286
    def _run_pipeline(self, pipeline, start_date, chunksize):
        """
        Compute `pipeline`, providing values for at least `start_date`.

        Produces a DataFrame containing data for days between `start_date` and
        `end_date`, where `end_date` is defined by:

            `end_date = min(start_date + chunksize trading days,
                            simulation_end)`

        Returns
        -------
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)

        See Also
        --------
        PipelineEngine.run_pipeline
        """
        trading_days = self.trading_environment.trading_days

        # Position of the requested start date in the trading calendar.
        first_loc = trading_days.get_loc(start_date)

        # The chunk ends `chunksize` sessions later, but never past the
        # session of the simulation's (normalized) last close.
        last_sim_day = self.sim_params.last_close.normalize()
        chunk_end_loc = first_loc + chunksize
        sim_end_loc = trading_days.get_loc(last_sim_day)
        end_date = trading_days[min(chunk_end_loc, sim_end_loc)]

        results = self.engine.run_pipeline(pipeline, start_date, end_date)
        return results, end_date
1317
1318
    ##################
1319
    # End Pipeline API
1320
    ##################
1321
1322
    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        Only attributes found in ``vars(cls)`` are inspected; an attribute
        is included when its ``is_api_method`` attribute is truthy.
        """
        api_methods = []
        for attribute in itervalues(vars(cls)):
            if getattr(attribute, 'is_api_method', False):
                api_methods.append(attribute)
        return api_methods
1331