Completed
Pull Request — master (#858)
by Eddie
01:43
created

zipline.TradingAlgorithm._create_clock()   B

Complexity

Conditions 2

Size

Total Lines 29

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 29
rs 8.8571
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pytz
18
import pandas as pd
19
from pandas.tseries.tools import normalize_date
20
import numpy as np
21
22
from datetime import datetime
23
from itertools import chain, repeat
24
from numbers import Integral
25
26
from six import (
27
    exec_,
28
    iteritems,
29
    itervalues,
30
    string_types,
31
)
32
33
from zipline.errors import (
34
    AttachPipelineAfterInitialize,
35
    HistoryInInitialize,
36
    NoSuchPipeline,
37
    OrderDuringInitialize,
38
    OverrideCommissionPostInit,
39
    OverrideSlippagePostInit,
40
    PipelineOutputDuringInitialize,
41
    RegisterAccountControlPostInit,
42
    RegisterTradingControlPostInit,
43
    SetBenchmarkOutsideInitialize,
44
    UnsupportedCommissionModel,
45
    UnsupportedDatetimeFormat,
46
    UnsupportedOrderParameters,
47
    UnsupportedSlippageModel,
48
)
49
from zipline.finance.trading import TradingEnvironment
50
from zipline.finance.blotter import Blotter
51
from zipline.finance.commission import PerShare, PerTrade, PerDollar
52
from zipline.finance.controls import (
53
    LongOnly,
54
    MaxOrderCount,
55
    MaxOrderSize,
56
    MaxPositionSize,
57
    MaxLeverage,
58
    RestrictedListOrder
59
)
60
from zipline.finance.execution import (
61
    LimitOrder,
62
    MarketOrder,
63
    StopLimitOrder,
64
    StopOrder,
65
)
66
from zipline.finance.performance import PerformanceTracker
67
from zipline.finance.slippage import (
68
    VolumeShareSlippage,
69
    SlippageModel
70
)
71
from zipline.assets import Asset, Future
72
from zipline.assets.futures import FutureChain
73
from zipline.gens.tradesimulation import AlgorithmSimulator
74
from zipline.pipeline.engine import (
75
    NoOpPipelineEngine,
76
    SimplePipelineEngine,
77
)
78
from zipline.utils.api_support import (
79
    api_method,
80
    require_initialized,
81
    require_not_initialized,
82
    ZiplineAPI,
83
)
84
from zipline.utils.input_validation import ensure_upper_case
85
from zipline.utils.cache import (
86
    CachedObject,
87
    Expired
88
)
89
import zipline.utils.events
90
from zipline.utils.events import (
91
    EventManager,
92
    make_eventrule,
93
    DateRuleFactory,
94
    TimeRuleFactory,
95
)
96
from zipline.utils.factory import create_simulation_parameters
97
from zipline.utils.math_utils import (
98
    tolerant_equals,
99
    round_if_near_integer
100
)
101
from zipline.utils.preprocess import preprocess
102
103
import zipline.protocol
104
from zipline.sources.requests_csv import PandasRequestsCSV
105
106
from zipline.gens.sim_engine import (
107
    MinuteSimulationClock,
108
    DailySimulationClock,
109
)
110
from zipline.sources.benchmark_source import BenchmarkSource
111
112
# Starting capital used when the caller does not supply ``capital_base``.
DEFAULT_CAPITAL_BASE = 1.0e5
113
114
115
class TradingAlgorithm(object):
116
    """
117
    Base class for trading algorithms. Inherit and overload
118
    initialize() and handle_data(data).
119
120
    A new algorithm could look like this:
121
    ```
122
    from zipline.api import order, symbol
123
124
    def initialize(context):
125
        context.sid = symbol('AAPL')
126
        context.amount = 100
127
128
    def handle_data(context, data):
129
        sid = context.sid
130
        amount = context.amount
131
        order(sid, amount)
132
    ```
133
    To then run this algorithm, pass these functions to
134
    TradingAlgorithm:
135
136
    my_algo = TradingAlgorithm(initialize, handle_data)
137
    stats = my_algo.run(data)
138
139
    """
140
141
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        Positional arguments, and any keyword arguments that are not
        consumed below, are stored and later forwarded to ``initialize``.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            before_trading_start : function
                Optional function. Read from the script namespace, or from
                the keyword arguments when initialize/handle_data are
                passed as functions.
            analyze : function
                Optional function, gets called after run.
            script : str
                Algoscript that contains initialize and
                handle_data function definition.
            algo_filename : str <default: '<string>'>
                Filename used when compiling ``script``.
            namespace : dict
                Namespace in which ``script`` is executed.
            platform : str <default: 'zipline'>
                Value reported by ``get_environment('platform')``.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            env : TradingEnvironment
                Environment to use; a new TradingEnvironment is created
                when omitted.
            sim_params : simulation parameters object
                When omitted, parameters are created from ``capital_base``,
                ``start`` and ``end``.
            blotter : Blotter
                When omitted (or falsy), a Blotter using
                VolumeShareSlippage and PerShare commission is created.
            get_pipeline_loader : callable
                Passed to ``init_engine``.
            create_event_context : callable
                Passed to the EventManager as ``create_context``.
            benchmark_sid
                Stored as ``self.benchmark_sid``.
            asset_finder : An AssetFinder object
                A new AssetFinder object to be used in this TradingEnvironment
                NOTE(review): this constructor does not consume an
                ``asset_finder`` keyword in the visible code -- confirm.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            futures_metadata
                Futures metadata written to the TradingEnvironment.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        :Raises:
            ValueError
                If ``script`` is combined with initialize/handle_data
                functions, or if ``script`` does not define ``handle_data``.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.data_portal = None

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        self.perf_tracker = None
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            self.blotter = Blotter(
                slippage_func=VolumeShareSlippage(),
                commission=PerShare(),
                data_frequency=self.data_frequency,
            )

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self._portfolio = None
        self._account = None

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        # True when the callbacks were supplied directly as functions.
        has_callbacks = bool(
            kwargs.get('initialize') and kwargs.get('handle_data')
        )

        if self.algoscript is not None:
            # A script and explicit callbacks are mutually exclusive.  This
            # check used to live in the branch below, where it could never
            # be reached.
            if has_callbacks:
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )

            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            # NOTE: this executes the supplied script text; only pass
            # trusted code.
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif has_callbacks:
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        # This is the same dict object as ``kwargs``, so the pop below also
        # removes 'benchmark_sid' from the arguments given to initialize().
        self.initialize_kwargs = kwargs

        self.benchmark_sid = kwargs.pop('benchmark_sid', None)
314
315
    def init_engine(self, get_loader):
        """
        Build the pipeline engine for this algorithm and store it on
        ``self.engine``.

        A NoOpPipelineEngine is used when ``get_loader`` is None; otherwise
        a SimplePipelineEngine is built from the loader, the environment's
        trading days and the asset finder.
        """
        if get_loader is None:
            self.engine = NoOpPipelineEngine()
            return

        self.engine = SimplePipelineEngine(
            get_loader,
            self.trading_environment.trading_days,
            self.asset_finder,
        )
329
330
    def initialize(self, *args, **kwargs):
        """
        Run the user-supplied initialize function inside a ZiplineAPI
        context for this algorithm, passing ``self`` as its first argument
        followed by ``args`` and ``kwargs``.
        """
        api_context = ZiplineAPI(self)
        with api_context:
            self._initialize(self, *args, **kwargs)
337
338
    def before_trading_start(self, data):
339
        if self._before_trading_start is None:
340
            return
341
342
        self._before_trading_start(self, data)
343
344
    def handle_data(self, data):
345
        self._handle_data(self, data)
346
347
        # Unlike trading controls which remain constant unless placing an
348
        # order, account controls can change each bar. Thus, must check
349
        # every bar no matter if the algorithm places an order or not.
350
        self.validate_account_controls()
351
352
    def analyze(self, perf):
        """
        Call the user's optional analyze function with ``(self, perf)``
        inside a ZiplineAPI context.

        Does nothing when no analyze function was registered.
        """
        if self._analyze is not None:
            with ZiplineAPI(self):
                self._analyze(self, perf)
358
359
    def __repr__(self):
360
        """
361
        N.B. this does not yet represent a string that can be used
362
        to instantiate an exact copy of an algorithm.
363
364
        However, it is getting close, and provides some value as something
365
        that can be inspected interactively.
366
        """
367
        return """
368
{class_name}(
369
    capital_base={capital_base}
370
    sim_params={sim_params},
371
    initialized={initialized},
372
    slippage={slippage},
373
    commission={commission},
374
    blotter={blotter},
375
    recorded_vars={recorded_vars})
376
""".strip().format(class_name=self.__class__.__name__,
377
                   capital_base=self.capital_base,
378
                   sim_params=repr(self.sim_params),
379
                   initialized=self.initialized,
380
                   slippage=repr(self.blotter.slippage_func),
381
                   commission=repr(self.blotter.commission),
382
                   blotter=repr(self.blotter),
383
                   recorded_vars=repr(self.recorded_vars))
384
385
    def _create_clock(self):
        """
        Build the simulation clock that matches the data frequency.

        For 'minute' data a MinuteSimulationClock is built from the market
        opens and closes of the simulated trading days, and the data
        portal's ``setup_offset_cache`` is called with the clock's minute
        lookups.  For any other frequency a DailySimulationClock over the
        simulated trading days is returned.
        """
        if self.sim_params.data_frequency != 'minute':
            return DailySimulationClock(self.sim_params.trading_days)

        env = self.trading_environment
        session_bounds = env.open_and_closes.ix[
            self.sim_params.trading_days]

        def _as_int64_ns(column):
            # Column values as datetime64[ns], reinterpreted as int64.
            raw = session_bounds[column].values
            return raw.astype('datetime64[ns]').astype(np.int64)

        market_opens = _as_int64_ns('market_open')
        market_closes = _as_int64_ns('market_close')

        minutely_emission = self.sim_params.emission_rate == "minute"

        clock = MinuteSimulationClock(
            self.sim_params.trading_days,
            market_opens,
            market_closes,
            env.trading_days,
            minutely_emission
        )
        self.data_portal.setup_offset_cache(
            clock.minutes_by_day,
            clock.minutes_to_day,
            self.sim_params.trading_days)
        return clock
414
415
    def _create_benchmark_source(self):
        """
        Build a BenchmarkSource for the configured benchmark sid over the
        simulated trading days, using this algorithm's environment, data
        portal and emission rate.
        """
        benchmark_source = BenchmarkSource(
            self.benchmark_sid,
            self.trading_environment,
            self.sim_params.trading_days,
            self.data_portal,
            emission_rate=self.sim_params.emission_rate,
        )
        return benchmark_source
423
424
    def _create_generator(self, sim_params):
        """
        Build the AlgorithmSimulator for this run and return its generator.

        :Arguments:
            sim_params : simulation parameters object or None
                When not None it replaces ``self.sim_params``; when None the
                parameters already stored on the algorithm are used.

        :Returns:
            The result of ``self.trading_client.transform()``.

        Side effects: creates the PerformanceTracker if there is none, runs
        ``initialize`` if it has not run yet, and stores the simulator on
        ``self.trading_client``.
        """
        if sim_params is not None:
            self.sim_params = sim_params

        if self.perf_tracker is None:
            # HACK: When running with the `run` method, we set perf_tracker to
            # None so that it will be overwritten here.
            self.perf_tracker = PerformanceTracker(
                sim_params=self.sim_params,
                env=self.trading_environment,
                data_portal=self.data_portal
            )

            # Set the dt initially to the period start by forcing it to change.
            self.on_dt_changed(self.sim_params.period_start)

        if not self.initialized:
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
            self.initialized = True

        self.trading_client = AlgorithmSimulator(
            self,
            # Pass the stored parameters, not the raw argument: the argument
            # may be None, in which case self.sim_params is the fallback.
            self.sim_params,
            self.data_portal,
            self._create_clock(),
            self._create_benchmark_source()
        )

        return self.trading_client.transform()
453
454
    def get_generator(self):
455
        """
456
        Override this method to add new logic to the construction
457
        of the generator. Overrides can use the _create_generator
458
        method to get a standard construction generator.
459
        """
460
        return self._create_generator(self.sim_params)
461
462
    def run(self, data_portal=None):
463
        """Run the algorithm.
464
465
        :Arguments:
466
            source : DataPortal
467
468
        :Returns:
469
            daily_stats : pandas.DataFrame
470
              Daily performance metrics such as returns, alpha etc.
471
472
        """
473
        self.data_portal = data_portal
474
475
        # Force a reset of the performance tracker, in case
476
        # this is a repeat run of the algorithm.
477
        self.perf_tracker = None
478
479
        # Create zipline and loop through simulated_trading.
480
        # Each iteration returns a perf dictionary
481
        perfs = []
482
        for perf in self.get_generator():
483
            perfs.append(perf)
484
485
        # convert perf dict to pandas dataframe
486
        daily_stats = self._create_daily_stats(perfs)
487
488
        self.analyze(daily_stats)
489
490
        return daily_stats
491
492
    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Write any identifiers the asset finder cannot resolve to the
        trading environment, then map ``identifiers`` to sids as of
        ``as_of_date``.
        """
        # Collect the identifiers that do not resolve to an existing asset.
        # Only Asset and integer identifiers are looked up; anything else is
        # always treated as unresolved.
        unresolved = []
        for identifier in identifiers:
            if isinstance(identifier, Asset):
                known = self.asset_finder.retrieve_asset(sid=identifier.sid,
                                                         default_none=True)
            elif isinstance(identifier, Integral):
                known = self.asset_finder.retrieve_asset(sid=identifier,
                                                         default_none=True)
            else:
                known = None

            if known is None:
                unresolved.append(identifier)

        self.trading_environment.write_data(
            equities_identifiers=unresolved)

        # The lookups above may have stored cache misses, so the caches are
        # cleared.  The real fix for this problem is to not construct an
        # AssetFinder until we `run()`, when all the data is available.
        self.asset_finder._reset_caches()

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )
520
521
    def _create_daily_stats(self, perfs):
522
        # create daily and cumulative stats dataframe
523
        daily_perfs = []
524
        # TODO: the loop here could overwrite expected properties
525
        # of daily_perf. Could potentially raise or log a
526
        # warning.
527
        for perf in perfs:
528
            if 'daily_perf' in perf:
529
530
                perf['daily_perf'].update(
531
                    perf['daily_perf'].pop('recorded_vars')
532
                )
533
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
534
                daily_perfs.append(perf['daily_perf'])
535
            else:
536
                self.risk_report = perf
537
538
        daily_dts = [np.datetime64(perf['period_close'], utc=True)
539
                     for perf in daily_perfs]
540
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)
541
542
        return daily_stats
543
544
    @api_method
    def get_environment(self, field='platform'):
        """
        Return one field of the simulation environment, or the whole
        environment dict when ``field`` is '*'.

        Available fields: 'arena', 'data_frequency', 'start', 'end',
        'capital_base' and 'platform'.  An unknown field raises KeyError.
        """
        params = self.sim_params
        environment = {
            'arena': params.arena,
            'data_frequency': params.data_frequency,
            'start': params.first_open,
            'end': params.last_close,
            'capital_base': params.capital_base,
            'platform': self._platform
        }
        return environment if field == '*' else environment[field]
558
559
    @api_method
    def fetch_csv(self, url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """
        Build a PandasRequestsCSV source for ``url`` and register its
        DataFrame with the data portal.

        All arguments are forwarded to PandasRequestsCSV, together with the
        trading environment, the simulation period start/end and the data
        frequency.  Returns the PandasRequestsCSV instance.
        """
        # Show all the logs every time fetcher is used.
        fetcher_source = PandasRequestsCSV(
            url,
            pre_func,
            post_func,
            self.trading_environment,
            self.sim_params.period_start,
            self.sim_params.period_end,
            date_column,
            date_format,
            timezone,
            symbol,
            mask,
            symbol_column,
            data_frequency=self.data_frequency,
            special_params_checker=special_params_checker,
            **kwargs
        )

        # Hand the fetched frame to the data portal.
        fetched_frame = fetcher_source.df
        self.data_portal.handle_extra_source(fetched_frame, self.sim_params)

        return fetcher_source
596
597
    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an Event and register it with the
        algorithm's EventManager.
        """
        event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(event)
604
605
    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to be called according to date and time rules.

        ``date_rule`` defaults to every day.  In minute mode ``time_rule``
        defaults to market open; in any other mode it is ignored and the
        function is eligible at every event.
        """
        if not date_rule:
            date_rule = DateRuleFactory.every_day()

        if self.sim_params.data_frequency == 'minute':
            time_rule = time_rule or TimeRuleFactory.market_open()
        else:
            # If we are in daily mode the time_rule is ignored.
            time_rule = zipline.utils.events.Always()

        event_rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(event_rule, func)
624
625
    @api_method
    def record(self, *args, **kwargs):
        """
        Store named values in the algorithm's recorded variables.

        Positional arguments are read as alternating ``name, value`` pairs;
        keyword arguments are recorded as ``name=value``.
        """
        # Feeding the same iterator to zip twice makes each step pull two
        # consecutive items, yielding (a, b), (c, d), ... rather than
        # (a, a), (b, b), ...
        pair_source = iter(args)
        positionals = zip(pair_source, pair_source)

        for name, value in chain(positionals, iteritems(kwargs)):
            self._recorded_vars[name] = value
640
641
    @api_method
    def set_benchmark(self, benchmark_sid):
        """
        Set the benchmark sid for the simulation.

        Raises SetBenchmarkOutsideInitialize once the algorithm has been
        initialized.
        """
        already_initialized = self.initialized
        if already_initialized:
            raise SetBenchmarkOutsideInitialize()

        self.benchmark_sid = benchmark_sid
647
648
    @api_method
    @preprocess(symbol_str=ensure_upper_case)
    def symbol(self, symbol_str):
        """
        Default symbol lookup for any source that directly maps the
        symbol to the Asset (e.g. yahoo finance).
        """
        # Resolve symbol->sid as of the user-set lookup date, falling back
        # to the end of the simulation period when none has been set.
        lookup_date = self._symbol_lookup_date
        if lookup_date is None:
            lookup_date = self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=lookup_date,
        )
664
665
    @api_method
    def symbols(self, *args):
        """
        Look up several symbols at once, returning a list with the result
        of ``self.symbol`` for each argument, in order.
        """
        resolved = []
        for identifier in args:
            resolved.append(self.symbol(identifier))
        return resolved
672
673
    @api_method
    def sid(self, a_sid):
        """
        Default sid lookup: return the asset finder's asset for the given
        integer sid.
        """
        asset = self.asset_finder.retrieve_asset(a_sid)
        return asset
680
681
    @api_method
    @preprocess(symbol=ensure_upper_case)
    def future_symbol(self, symbol):
        """ Look up the futures contract with the given symbol.

        Parameters
        ----------
        symbol : str
            The symbol of the desired contract.

        Returns
        -------
        Future
            A Future object.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.

        """
        contract = self.asset_finder.lookup_future_symbol(symbol)
        return contract
703
704
    @api_method
    @preprocess(root_symbol=ensure_upper_case)
    def future_chain(self, root_symbol, as_of_date=None):
        """ Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        UnsupportedDatetimeFormat
            If ``as_of_date`` cannot be converted to a pandas Timestamp.
        """
        chain_date = as_of_date
        if chain_date:
            # Normalise whatever was supplied to a UTC Timestamp.
            try:
                chain_date = pd.Timestamp(chain_date, tz='UTC')
            except ValueError:
                raise UnsupportedDatetimeFormat(input=as_of_date,
                                                method='future_chain')

        return FutureChain(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=chain_date
        )
740
741
    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a desired order value into a number of shares/contracts.

        The value is divided by the asset's current price, scaled by the
        contract multiplier when the asset is a Future.  Returns 0 when the
        current price is (tolerantly) zero.
        """
        price = self.trading_client.current_data[asset].price

        if tolerant_equals(price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Without a price no amount can be derived: place no order.
            return 0

        multiplier = asset.contract_multiplier \
            if isinstance(asset, Future) else 1

        return value / (price * multiplier)
763
764
    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` shares of ``sid`` and return the
        blotter's result.

        ``limit_price`` and ``stop_price`` are the deprecated alternative to
        passing an execution ``style``.
        """
        # Truncate to the integer share count that's either within .0001 of
        # amount or closer to zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        share_count = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(sid,
                                   share_count,
                                   limit_price,
                                   stop_price,
                                   style)

        # Turn the deprecated price arguments into an ExecutionStyle object.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price,
            stop_price,
            style,
        )
        return self.blotter.order(sid, share_count, execution_style)
790
791
    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Check the arguments passed to the order API function.

        Raises OrderDuringInitialize if the algorithm is not initialized
        yet, and UnsupportedOrderParameters if ``style`` is combined with a
        limit or stop price or if ``asset`` is not an Asset.  Each
        registered trading control is then asked to validate the order.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for trading_control in self.trading_controls:
            trading_control.validate(asset,
                                     amount,
                                     self.portfolio,
                                     self.get_datetime(),
                                     self.trading_client.current_data)
831
832
    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Turn the deprecated limit_price and stop_price arguments into an
        ExecutionStyle instance.

        An explicit ``style`` is returned unchanged.  Otherwise the result
        is a StopLimitOrder, LimitOrder, StopOrder or MarketOrder depending
        on which prices are set.

        This function assumes that either style == None or (limit_price,
        stop_price) == (None, None).
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        return MarketOrder()
853
854
    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order sized by a desired value instead of a share count.

        ``value`` is converted to a number of shares by
        ``_calculate_order_value_amount`` and the order is then placed through
        ``order``, whose result is returned.
        If the Asset being ordered is a Future, the 'value' calculated
        is actually the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        share_count = self._calculate_order_value_amount(sid, value)
        return self.order(
            sid,
            share_count,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
876
877
    @property
    def recorded_vars(self):
        """
        Shallow copy of ``self._recorded_vars``.

        Rebinding keys on the returned object does not affect the
        algorithm's own mapping; the contained values are shared.
        """
        snapshot = copy(self._recorded_vars)
        return snapshot
880
881
    @property
    def portfolio(self):
        """
        The current portfolio, as returned by ``updated_portfolio()``.
        """
        current_portfolio = self.updated_portfolio()
        return current_portfolio
884
885
    def updated_portfolio(self):
        """
        Return the cached portfolio, fetching it first if necessary.

        When nothing is cached and a performance tracker exists, the
        portfolio is requested from the tracker for ``self.datetime`` and
        stored in ``self._portfolio``. With no tracker the cached value
        (possibly None) is returned unchanged.
        """
        if self._portfolio is not None or self.perf_tracker is None:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio(self.datetime)
        return self._portfolio
890
891
    @property
    def account(self):
        """
        The current account, as returned by ``updated_account()``.
        """
        current_account = self.updated_account()
        return current_account
894
895
    def updated_account(self):
        """
        Return the cached account, fetching it first if necessary.

        When nothing is cached and a performance tracker exists, the account
        is requested from the tracker for ``self.datetime`` and stored in
        ``self._account``. With no tracker the cached value (possibly None)
        is returned unchanged.
        """
        if self._account is not None or self.perf_tracker is None:
            return self._account

        self._account = self.perf_tracker.get_account(self.datetime)
        return self._account
900
901
    def set_logger(self, logger):
        """
        Store ``logger`` on the algorithm as ``self.logger``.
        """
        self.logger = logger
903
904
    def on_dt_changed(self, dt):
        """
        Callback invoked by the simulation loop each time the current dt
        changes.

        Work that must run exactly once at the start of every datetime group
        belongs here. The new ``dt`` must be a ``datetime`` whose tzinfo is
        ``pytz.utc``; it is stored on the algorithm and pushed to the
        performance tracker and the blotter.
        """
        assert isinstance(dt, datetime), \
            "Attempt to set algorithm's current time with non-datetime"
        assert dt.tzinfo == pytz.utc, \
            "Algorithm expects a utc datetime"

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

        # Drop the cached snapshots so the next portfolio/account access
        # fetches fresh ones for the new dt.
        self._portfolio = None
        self._account = None
923
924
    @api_method
    def get_datetime(self, tz=None):
        """
        Return the current simulation datetime.

        Parameters
        ----------
        tz : str or tzinfo, optional
            Timezone to convert the result to. A string is resolved with
            ``pytz.timezone``. When omitted the UTC datetime is returned.
        """
        current = self.datetime
        assert current.tzinfo == pytz.utc, \
            "Algorithm should have a utc datetime"

        if tz is None:
            # datetime.datetime objects are immutable, so handing out the
            # stored instance is safe.
            return current

        # Accept either a timezone name or a ready-made tzinfo.
        zone = pytz.timezone(tz) if isinstance(tz, string_types) else tz
        return current.astimezone(zone)
939
940
    def update_dividends(self, dividend_frame):
        """
        Hand the DataFrame used to process dividends to the performance
        tracker.

        The frame's columns should contain at least the entries in
        zp.DIVIDEND_FIELDS.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)
946
947
    @api_method
    def set_slippage(self, slippage):
        """
        Install ``slippage`` as the blotter's slippage function.

        Raises
        ------
        UnsupportedSlippageModel
            If ``slippage`` is not a ``SlippageModel`` instance.
        OverrideSlippagePostInit
            If the algorithm is already initialized.
        """
        # The type check comes first, so a bad model is reported even when
        # the call is also made too late.
        if not isinstance(slippage, SlippageModel):
            raise UnsupportedSlippageModel()
        elif self.initialized:
            raise OverrideSlippagePostInit()

        self.blotter.slippage_func = slippage
954
955
    @api_method
    def set_commission(self, commission):
        """
        Install ``commission`` as the blotter's commission model.

        Raises
        ------
        UnsupportedCommissionModel
            If ``commission`` is not a PerShare, PerTrade or PerDollar
            instance.
        OverrideCommissionPostInit
            If the algorithm is already initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()
        elif self.initialized:
            raise OverrideCommissionPostInit()

        self.blotter.commission = commission
963
964
    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date for which symbols will be resolved to their sids
        (symbols may map to different firms or underlying assets at
        different times).

        ``dt`` is converted to a UTC ``pd.Timestamp`` and stored in
        ``self._symbol_lookup_date``.

        Raises
        ------
        UnsupportedDatetimeFormat
            If the conversion of ``dt`` raises a ValueError.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=dt,
                                            method='set_symbol_lookup_date')
        self._symbol_lookup_date = lookup_date
976
977
    # Retained for backwards compatibility
978
    @property
    def data_frequency(self):
        """
        Data frequency of the simulation, read from ``self.sim_params``.
        """
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only these two frequencies are accepted; the value is written
        # through to ``self.sim_params``.
        assert value in ('daily', 'minute')
        params = self.sim_params
        params.data_frequency = value
986
987
    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        The percent is multiplied by ``self.portfolio.portfolio_value`` and
        the resulting value is ordered through ``order_value``, whose result
        is returned. ``limit_price``, ``stop_price`` and ``style`` are passed
        through unchanged.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)
1001
1002
    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order that moves the position in ``sid`` to ``target``
        shares.

        With no existing position this is the same as ordering ``target``
        shares. With an existing position, the order is for the difference
        between ``target`` and the amount currently held. The result of the
        underlying ``order`` call is returned.
        """
        positions = self.portfolio.positions
        if sid in positions:
            # Only the gap between the holding and the target is ordered.
            shares_to_order = target - positions[sid].amount
        else:
            shares_to_order = target

        return self.order(
            sid,
            shares_to_order,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
1024
1025
    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order that moves the position in ``sid`` to a target value.

        ``target`` is converted to a share count by
        ``_calculate_order_value_amount`` and that count is passed to
        ``order_target``, so an existing position is adjusted by the
        difference and a missing one is opened.
        If the Asset being ordered is a Future, the 'target value' calculated
        is actually the target exposure, as Futures have no 'value'.
        """
        target_shares = self._calculate_order_value_amount(sid, target)
        return self.order_target(
            sid,
            target_shares,
            limit_price=limit_price,
            stop_price=stop_price,
            style=style,
        )
1042
1043
    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        The target is multiplied by ``self.portfolio.portfolio_value`` and the
        resulting value is passed to ``order_target_value``, whose result is
        returned.

        Note that target must be expressed as a decimal (0.50 means 50%).
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)
1060
1061
    @api_method
    def get_open_orders(self, sid=None):
        """
        Return the blotter's open orders, converted with ``to_api_obj()``.

        Parameters
        ----------
        sid : optional
            Key into the blotter's open-order mapping. When omitted, orders
            for every key are returned.

        Returns
        -------
        dict or list
            Without ``sid``: a dict mapping each key that has at least one
            open order to its list of orders. With ``sid``: the list of
            orders for that key, empty if the key is unknown.
        """
        open_orders = self.blotter.open_orders

        if sid is not None:
            if sid not in open_orders:
                return []
            return [entry.to_api_obj() for entry in open_orders[sid]]

        orders_by_key = {}
        for key, entries in iteritems(open_orders):
            # Keys whose order list is empty are left out of the result.
            if entries:
                orders_by_key[key] = [entry.to_api_obj() for entry in entries]
        return orders_by_key
1073
1074
    @api_method
    def get_order(self, order_id):
        """
        Look up an order by id in the blotter.

        Returns the order converted with ``to_api_obj()``, or None when the
        blotter has no order with that id.
        """
        known_orders = self.blotter.orders
        if order_id not in known_orders:
            return None
        return known_orders[order_id].to_api_obj()
1078
1079
    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        ``order_param`` may be a ``zipline.protocol.Order``, in which case
        its ``id`` is used, or any other value, which is passed to the
        blotter as the order id.
        """
        if isinstance(order_param, zipline.protocol.Order):
            order_id = order_param.id
        else:
            order_id = order_param

        self.blotter.cancel(order_id)
1086
1087
    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, sids, bar_count, frequency, field, ffill=True):
        """
        Fetch a history window from the data portal, ending at the current
        simulation datetime.

        All arguments are forwarded to ``data_portal.get_history_window``
        together with ``self.datetime``; its result is returned.

        Raises
        ------
        Exception
            If the algorithm has no data portal.
        """
        portal = self.data_portal
        if portal is None:
            raise Exception("no data portal!")

        return portal.get_history_window(
            sids, self.datetime, bar_count, frequency, field, ffill,
        )
1101
1102
    ####################
1103
    # Account Controls #
1104
    ####################
1105
1106
    def register_account_control(self, control):
        """
        Add an AccountControl to the list checked by
        ``validate_account_controls``.

        Raises
        ------
        RegisterAccountControlPostInit
            If the algorithm is already initialized.
        """
        already_initialized = self.initialized
        if already_initialized:
            raise RegisterAccountControlPostInit()

        self.account_controls.append(control)
1113
1114
    def validate_account_controls(self):
        """
        Run every registered account control against the current state.

        Each control's ``validate`` receives the portfolio, the account, the
        current datetime and the trading client's current data.
        """
        for account_control in self.account_controls:
            # Arguments are re-read per control, and not at all when no
            # controls are registered.
            account_control.validate(
                self.portfolio,
                self.account,
                self.get_datetime(),
                self.trading_client.current_data,
            )
1120
1121
    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Limit the maximum leverage of the algorithm.

        Builds a ``MaxLeverage`` control from ``max_leverage`` and registers
        it as an account control.
        """
        self.register_account_control(MaxLeverage(max_leverage))
1128
1129
    ####################
1130
    # Trading Controls #
1131
    ####################
1132
1133
    def register_trading_control(self, control):
        """
        Add a TradingControl to the list checked before order calls.

        Raises
        ------
        RegisterTradingControlPostInit
            If the algorithm is already initialized.
        """
        already_initialized = self.initialized
        if already_initialized:
            raise RegisterTradingControlPostInit()

        self.trading_controls.append(control)
1140
1141
    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Limit the number of shares and/or the dollar value held for the
        given sid.

        The limits are treated as absolute values and are enforced when the
        algo attempts to place an order for sid. As a result the position can
        still exceed the share limit through splits/dividends, or the
        notional limit through price improvement.

        An order that would push the absolute shares/dollar value past one of
        the limits raises a TradingControlException.

        Implemented by registering a ``MaxPositionSize`` trading control.
        """
        self.register_trading_control(
            MaxPositionSize(
                asset=sid,
                max_shares=max_shares,
                max_notional=max_notional,
            )
        )
1162
1163
    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Limit the number of shares and/or the dollar value of any single
        order placed for sid.

        The limits are treated as absolute values and are enforced when the
        algo attempts to place an order for sid. An order that would exceed
        one of them raises a TradingControlException.

        Implemented by registering a ``MaxOrderSize`` trading control.
        """
        self.register_trading_control(
            MaxOrderSize(
                asset=sid,
                max_shares=max_shares,
                max_notional=max_notional,
            )
        )
1177
1178
    @api_method
    def set_max_order_count(self, max_count):
        """
        Limit the number of orders that can be placed.

        Builds a ``MaxOrderCount`` control from ``max_count`` and registers
        it as a trading control.

        NOTE(review): the time interval over which orders are counted is
        decided inside MaxOrderCount and is not visible here -- confirm.
        """
        self.register_trading_control(MaxOrderCount(max_count))
1186
1187
    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Restrict which sids can be ordered.

        Builds a ``RestrictedListOrder`` control from ``restricted_list`` and
        registers it as a trading control.
        """
        self.register_trading_control(RestrictedListOrder(restricted_list))
1194
1195
    @api_method
    def set_long_only(self):
        """
        Forbid this algorithm from taking short positions.

        Implemented by registering a ``LongOnly`` trading control.
        """
        long_only_control = LongOnly()
        self.register_trading_control(long_only_control)
1201
1202
    ##############
1203
    # Pipeline API
1204
    ##############
1205
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register a pipeline to be computed at the start of each day.

        Parameters
        ----------
        pipeline
            The pipeline to register.
        name
            Key under which the pipeline is stored; used later by
            ``pipeline_output``.
        chunksize : optional
            Number of days to compute per chunk, converted with ``int``.
            When omitted the first chunk is 5 days and every later chunk is
            126 days.

        Returns
        -------
        The ``pipeline`` that was passed in, which allows expressions like
        ``p = attach_pipeline(Pipeline(), 'name')``.

        Raises
        ------
        NotImplementedError
            If a pipeline has already been attached.
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")

        # By default the first chunk is small to get more immediate results
        # (one week, then every half year).
        chunk_sizes = (
            chain([5], repeat(126)) if chunksize is None
            else repeat(int(chunksize))
        )
        self._pipelines[name] = (pipeline, iter(chunk_sizes))

        return pipeline
1224
1225
    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of the pipeline registered as `name`.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: We don't currently support multiple pipelines, but we plan to
        # in the future.
        try:
            registered = self._pipelines[name]
        except KeyError:
            raise NoSuchPipeline(
                name=name,
                valid=list(self._pipelines.keys()),
            )

        pipeline, chunks = registered
        return self._pipeline_output(pipeline, chunks)
1261
1262
    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.

        Serves today's rows from the cached pipeline result. When the cache
        has expired, a new chunk is computed with `_run_pipeline` (its size
        taken from the `chunks` iterator) and cached until the date
        `_run_pipeline` reports.
        """
        today = normalize_date(self.get_datetime())

        try:
            data = self._pipeline_cache.unwrap(today)
        except Expired:
            fresh = self._run_pipeline(pipeline, today, next(chunks))
            data, valid_until = fresh
            self._pipeline_cache = CachedObject(data, valid_until)

        # With a valid result in hand, pick out the rows for today.
        try:
            todays_rows = data.loc[today]
        except KeyError:
            # No assets passed the pipeline screen on this day: hand back an
            # empty frame that keeps the expected columns.
            todays_rows = pd.DataFrame(index=[], columns=data.columns)
        return todays_rows
1282
1283
    def _run_pipeline(self, pipeline, start_date, chunksize):
1284
        """
1285
        Compute `pipeline`, providing values for at least `start_date`.
1286
1287
        Produces a DataFrame containing data for days between `start_date` and
1288
        `end_date`, where `end_date` is defined by:
1289
1290
            `end_date = min(start_date + chunksize trading days,
1291
                            simulation_end)`
1292
1293
        Returns
1294
        -------
1295
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)
1296
1297
        See Also
1298
        --------
1299
        PipelineEngine.run_pipeline
1300
        """
1301
        days = self.trading_environment.trading_days
1302
1303
        # Load data starting from the previous trading day...
1304
        start_date_loc = days.get_loc(start_date)
1305
1306
        # ...continuing until either the day before the simulation end, or
1307
        # until chunksize days of data have been loaded.
1308
        sim_end = self.sim_params.last_close.normalize()
1309
        end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))
1310
        end_date = days[end_loc]
1311
1312
        return \
1313
            self.engine.run_pipeline(pipeline, start_date, end_date), end_date
1314
1315
    ##################
1316
    # End Pipeline API
1317
    ##################
1318
1319
    @classmethod
    def all_api_methods(cls):
        """
        Return a list of the API methods defined on this class.

        Only attributes found in the class's own namespace (``vars(cls)``)
        are considered; an attribute counts as an API method when its
        ``is_api_method`` attribute is truthy.
        """
        own_attributes = vars(cls)
        return [
            attribute for attribute in itervalues(own_attributes)
            if getattr(attribute, 'is_api_method', False)
        ]
1328