Completed
Pull Request — master (#858)
by Eddie
01:32
created

zipline.TradingAlgorithm.get_open_orders()   B

Complexity

Conditions 7

Size

Total Lines 12

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 7
dl 0
loc 12
rs 7.3333
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from copy import copy
16
17
import pytz
18
import pandas as pd
19
from pandas.tseries.tools import normalize_date
20
import numpy as np
21
22
from datetime import datetime
23
from itertools import chain, repeat
24
from numbers import Integral
25
26
from six import (
27
    exec_,
28
    iteritems,
29
    itervalues,
30
    string_types,
31
)
32
33
34
from zipline.errors import (
35
    AttachPipelineAfterInitialize,
36
    HistoryInInitialize,
37
    NoSuchPipeline,
38
    OrderDuringInitialize,
39
    OverrideCommissionPostInit,
40
    OverrideSlippagePostInit,
41
    PipelineOutputDuringInitialize,
42
    RegisterAccountControlPostInit,
43
    RegisterTradingControlPostInit,
44
    SetBenchmarkOutsideInitialize,
45
    UnsupportedCommissionModel,
46
    UnsupportedDatetimeFormat,
47
    UnsupportedOrderParameters,
48
    UnsupportedSlippageModel,
49
)
50
from zipline.finance.trading import TradingEnvironment
51
from zipline.finance.blotter import Blotter
52
from zipline.finance.commission import PerShare, PerTrade, PerDollar
53
from zipline.finance.controls import (
54
    LongOnly,
55
    MaxOrderCount,
56
    MaxOrderSize,
57
    MaxPositionSize,
58
    MaxLeverage,
59
    RestrictedListOrder
60
)
61
from zipline.finance.execution import (
62
    LimitOrder,
63
    MarketOrder,
64
    StopLimitOrder,
65
    StopOrder,
66
)
67
from zipline.finance.performance import PerformanceTracker
68
from zipline.finance.slippage import (
69
    VolumeShareSlippage,
70
    SlippageModel
71
)
72
from zipline.assets import Asset, Future
73
from zipline.assets.futures import FutureChain
74
from zipline.gens.tradesimulation import AlgorithmSimulator
75
from zipline.pipeline.engine import (
76
    NoOpPipelineEngine,
77
    SimplePipelineEngine,
78
)
79
from zipline.utils.api_support import (
80
    api_method,
81
    require_initialized,
82
    require_not_initialized,
83
    ZiplineAPI,
84
)
85
from zipline.utils.input_validation import ensure_upper_case
86
from zipline.utils.cache import CachedObject, Expired
87
import zipline.utils.events
88
from zipline.utils.events import (
89
    EventManager,
90
    make_eventrule,
91
    DateRuleFactory,
92
    TimeRuleFactory,
93
)
94
from zipline.utils.factory import create_simulation_parameters
95
from zipline.utils.math_utils import (
96
    tolerant_equals,
97
    round_if_near_integer
98
)
99
from zipline.utils.preprocess import preprocess
100
101
import zipline.protocol
102
from zipline.sources.requests_csv import PandasRequestsCSV
103
104
from zipline.gens.sim_engine import (
105
    MinuteSimulationClock,
106
    DailySimulationClock,
107
)
108
from zipline.sources.benchmark_source import BenchmarkSource
109
110
# Starting capital used when the caller supplies no ``capital_base``.
DEFAULT_CAPITAL_BASE = 1.0e5
111
112
113
class TradingAlgorithm(object):
114
    """
115
    Base class for trading algorithms. Inherit and overload
116
    initialize() and handle_data(data).
117
118
    A new algorithm could look like this:
119
    ```
120
    from zipline.api import order, symbol
121
122
    def initialize(context):
123
        context.sid = symbol('AAPL')
124
        context.amount = 100
125
126
    def handle_data(context, data):
127
        sid = context.sid
128
        amount = context.amount
129
        order(sid, amount)
130
    ```
131
    To then run this algorithm, pass these functions to
132
    TradingAlgorithm:
133
134
    my_algo = TradingAlgorithm(initialize, handle_data)
135
    stats = my_algo.run(data)
136
137
    """
138
139
    def __init__(self, *args, **kwargs):
        """Initialize sids and other state variables.

        Recognised keyword arguments are popped from ``kwargs``; whatever
        remains (together with ``args``) is stored and later forwarded to the
        user's ``initialize`` function.

        :Arguments:
        :Optional:
            initialize : function
                Function that is called with a single
                argument at the beginning of the simulation.
            handle_data : function
                Function that is called with 2 arguments
                (context and data) on every bar.
            script : str
                Algoscript that contains initialize and
                handle_data function definition. Cannot be combined with
                the ``initialize``/``handle_data`` keyword arguments.
            data_frequency : {'daily', 'minute'}
               The duration of the bars.
            capital_base : float <default: 1.0e5>
               How much capital to start with.
            asset_finder : An AssetFinder object
                A new AssetFinder object to be used in this TradingEnvironment
                NOTE(review): this constructor does not pop an
                ``asset_finder`` keyword itself -- confirm where it is used.
            equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
                If dict is provided, it must have the following structure:
                * keys are the identifiers
                * values are dicts containing the metadata, with the metadata
                  field name as the key
                If pandas.DataFrame is provided, it must have the
                following structure:
                * column names must be the metadata fields
                * index must be the different asset identifiers
                * array contents should be the metadata value
                If an object with a 'read' property is provided, 'read' must
                return rows containing at least one of 'sid' or 'symbol' along
                with the other metadata fields.
            identifiers : List
                Any asset identifiers that are not provided in the
                equities_metadata, but will be traded by this TradingAlgorithm

        :Raises:
            ValueError
                If ``script`` does not define ``handle_data``, or if
                ``script`` is passed together with ``initialize`` or
                ``handle_data``.
        """
        self.sources = []

        # List of trading controls to be used to validate orders.
        self.trading_controls = []

        # List of account controls to be checked on each bar.
        self.account_controls = []

        self._recorded_vars = {}
        self.namespace = kwargs.pop('namespace', {})

        self._platform = kwargs.pop('platform', 'zipline')

        self.logger = None

        self.data_portal = None

        # If an env has been provided, pop it
        self.trading_environment = kwargs.pop('env', None)

        if self.trading_environment is None:
            self.trading_environment = TradingEnvironment()

        # Update the TradingEnvironment with the provided asset metadata
        self.trading_environment.write_data(
            equities_data=kwargs.pop('equities_metadata', {}),
            equities_identifiers=kwargs.pop('identifiers', []),
            futures_data=kwargs.pop('futures_metadata', {}),
        )

        # set the capital base
        self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
        self.sim_params = kwargs.pop('sim_params', None)
        if self.sim_params is None:
            self.sim_params = create_simulation_parameters(
                capital_base=self.capital_base,
                start=kwargs.pop('start', None),
                end=kwargs.pop('end', None),
                env=self.trading_environment,
            )
        else:
            self.sim_params.update_internal_from_env(self.trading_environment)

        self.perf_tracker = None
        # Pull in the environment's new AssetFinder for quick reference
        self.asset_finder = self.trading_environment.asset_finder

        # Initialize Pipeline API data.
        self.init_engine(kwargs.pop('get_pipeline_loader', None))
        self._pipelines = {}
        # Create an always-expired cache so that we compute the first time data
        # is requested.
        self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

        self.blotter = kwargs.pop('blotter', None)
        if not self.blotter:
            # NOTE(review): ``self.data_frequency`` is read here before the
            # backwards-compatibility assignment further down; presumably it
            # is a property defined elsewhere on the class -- confirm.
            self.blotter = Blotter(
                slippage_func=VolumeShareSlippage(),
                commission=PerShare(),
                data_frequency=self.data_frequency,
            )

        # The symbol lookup date specifies the date to use when resolving
        # symbols to sids, and can be set using set_symbol_lookup_date()
        self._symbol_lookup_date = None

        self.portfolio_needs_update = True
        self.account_needs_update = True
        self.performance_needs_update = True
        self._portfolio = None
        self._account = None

        # If string is passed in, execute and get reference to
        # functions.
        self.algoscript = kwargs.pop('script', None)

        self._initialize = None
        self._before_trading_start = None
        self._analyze = None

        self.event_manager = EventManager(
            create_context=kwargs.pop('create_event_context', None),
        )

        if self.algoscript is not None:
            # A script and explicit callables are mutually exclusive. This
            # check must live in the script branch: the callable branch below
            # is only reached when no script was given.
            if (kwargs.get('initialize') is not None or
                    kwargs.get('handle_data') is not None):
                raise ValueError(
                    'You can not set script and initialize/handle_data.'
                )

            filename = kwargs.pop('algo_filename', None)
            if filename is None:
                filename = '<string>'
            code = compile(self.algoscript, filename, 'exec')
            exec_(code, self.namespace)
            self._initialize = self.namespace.get('initialize')
            if 'handle_data' not in self.namespace:
                raise ValueError('You must define a handle_data function.')
            else:
                self._handle_data = self.namespace['handle_data']

            self._before_trading_start = \
                self.namespace.get('before_trading_start')
            # Optional analyze function, gets called after run
            self._analyze = self.namespace.get('analyze')

        elif kwargs.get('initialize') and kwargs.get('handle_data'):
            self._initialize = kwargs.pop('initialize')
            self._handle_data = kwargs.pop('handle_data')
            self._before_trading_start = kwargs.pop('before_trading_start',
                                                    None)
            self._analyze = kwargs.pop('analyze', None)

        self.event_manager.add_event(
            zipline.utils.events.Event(
                zipline.utils.events.Always(),
                # We pass handle_data.__func__ to get the unbound method.
                # We will explicitly pass the algorithm to bind it again.
                self.handle_data.__func__,
            ),
            prepend=True,
        )

        # If method not defined, NOOP
        if self._initialize is None:
            self._initialize = lambda x: None

        # Alternative way of setting data_frequency for backwards
        # compatibility.
        if 'data_frequency' in kwargs:
            self.data_frequency = kwargs.pop('data_frequency')

        # Prepare the algo for initialization
        self.initialized = False
        self.initialize_args = args
        # ``initialize_kwargs`` aliases ``kwargs``, so the pop below also
        # removes ``benchmark_sid`` from what is forwarded to initialize().
        self.initialize_kwargs = kwargs

        self.benchmark_sid = kwargs.pop('benchmark_sid', None)
315
316
    def init_engine(self, get_loader):
        """
        Build the pipeline engine and store it on ``self.engine``.

        With a loader getter, a SimplePipelineEngine is created from it, the
        trading environment's trading days and the asset finder. When
        ``get_loader`` is None a NoOpPipelineEngine is stored instead.
        """
        if get_loader is None:
            self.engine = NoOpPipelineEngine()
            return

        self.engine = SimplePipelineEngine(
            get_loader,
            self.trading_environment.trading_days,
            self.asset_finder,
        )
330
331
    def initialize(self, *args, **kwargs):
        """
        Run the stored ``_initialize`` callable inside a ZiplineAPI context
        for this algorithm, passing the algorithm itself as first argument
        followed by ``args`` and ``kwargs``.
        """
        with ZiplineAPI(self):
            user_initialize = self._initialize
            user_initialize(self, *args, **kwargs)
338
339
    def before_trading_start(self, data):
340
        if self._before_trading_start is None:
341
            return
342
343
        self._before_trading_start(self, data)
344
345
    def handle_data(self, data):
346
        self._handle_data(self, data)
347
348
        # Unlike trading controls which remain constant unless placing an
349
        # order, account controls can change each bar. Thus, must check
350
        # every bar no matter if the algorithm places an order or not.
351
        self.validate_account_controls()
352
353
    def analyze(self, perf):
354
        if self._analyze is None:
355
            return
356
357
        with ZiplineAPI(self):
358
            self._analyze(self, perf)
359
360
    def __repr__(self):
361
        """
362
        N.B. this does not yet represent a string that can be used
363
        to instantiate an exact copy of an algorithm.
364
365
        However, it is getting close, and provides some value as something
366
        that can be inspected interactively.
367
        """
368
        return """
369
{class_name}(
370
    capital_base={capital_base}
371
    sim_params={sim_params},
372
    initialized={initialized},
373
    slippage={slippage},
374
    commission={commission},
375
    blotter={blotter},
376
    recorded_vars={recorded_vars})
377
""".strip().format(class_name=self.__class__.__name__,
378
                   capital_base=self.capital_base,
379
                   sim_params=repr(self.sim_params),
380
                   initialized=self.initialized,
381
                   slippage=repr(self.blotter.slippage_func),
382
                   commission=repr(self.blotter.commission),
383
                   blotter=repr(self.blotter),
384
                   recorded_vars=repr(self.recorded_vars))
385
386
    def _create_clock(self):
        """
        Build the simulation clock matching ``sim_params.data_frequency``.

        For 'minute' data a MinuteSimulationClock is built from the market
        opens/closes of the simulation's trading days, and the data portal's
        offset cache is primed from it. Any other frequency yields a
        DailySimulationClock over the simulation's trading days.
        """
        sim_days = self.sim_params.trading_days
        if self.sim_params.data_frequency != 'minute':
            return DailySimulationClock(sim_days)

        env = self.trading_environment
        sessions = env.open_and_closes.ix[sim_days]

        def as_int64_ns(column):
            # Express the column's timestamps as int64 nanosecond values.
            return sessions[column].values.astype(
                'datetime64[ns]').astype(np.int64)

        market_opens = as_int64_ns('market_open')
        market_closes = as_int64_ns('market_close')

        minutely_emission = self.sim_params.emission_rate == "minute"

        clock = MinuteSimulationClock(
            sim_days,
            market_opens,
            market_closes,
            env.trading_days,
            minutely_emission
        )
        self.data_portal.setup_offset_cache(
            clock.minutes_by_day,
            clock.minutes_to_day,
            sim_days)
        return clock
415
416
    def _create_benchmark_source(self):
        """
        Build a BenchmarkSource from the benchmark sid, the trading
        environment, the simulation's trading days and emission rate, and
        the data portal.
        """
        params = self.sim_params
        return BenchmarkSource(
            self.benchmark_sid,
            self.trading_environment,
            params.trading_days,
            self.data_portal,
            emission_rate=params.emission_rate,
        )
424
425
    def _create_generator(self, sim_params):
426
        if sim_params is not None:
427
            self.sim_params = sim_params
428
429
        if self.perf_tracker is None:
430
            # HACK: When running with the `run` method, we set perf_tracker to
431
            # None so that it will be overwritten here.
432
            self.perf_tracker = PerformanceTracker(
433
                sim_params=self.sim_params,
434
                env=self.trading_environment,
435
                data_portal=self.data_portal
436
            )
437
438
            # Set the dt initially to the period start by forcing it to change.
439
            self.on_dt_changed(self.sim_params.period_start)
440
441
        if not self.initialized:
442
            self.initialize(*self.initialize_args, **self.initialize_kwargs)
443
            self.initialized = True
444
445
        self.trading_client = self._create_trading_client(sim_params)
446
447
        return self.trading_client.transform()
448
449
    def _create_trading_client(self, sim_params):
        """
        Build the AlgorithmSimulator for this algorithm from ``sim_params``,
        the data portal, a freshly created clock and a freshly created
        benchmark source.
        """
        portal = self.data_portal
        clock = self._create_clock()
        benchmark_source = self._create_benchmark_source()
        return AlgorithmSimulator(
            self, sim_params, portal, clock, benchmark_source
        )
457
458
    def get_generator(self):
459
        """
460
        Override this method to add new logic to the construction
461
        of the generator. Overrides can use the _create_generator
462
        method to get a standard construction generator.
463
        """
464
        return self._create_generator(self.sim_params)
465
466
    def run(self, data_portal=None):
467
        """Run the algorithm.
468
469
        :Arguments:
470
            source : DataPortal
471
472
        :Returns:
473
            daily_stats : pandas.DataFrame
474
              Daily performance metrics such as returns, alpha etc.
475
476
        """
477
        self.data_portal = data_portal
478
479
        # Force a reset of the performance tracker, in case
480
        # this is a repeat run of the algorithm.
481
        self.perf_tracker = None
482
483
        # Create zipline and loop through simulated_trading.
484
        # Each iteration returns a perf dictionary
485
        perfs = []
486
        for perf in self.get_generator():
487
            perfs.append(perf)
488
489
        # convert perf dict to pandas dataframe
490
        daily_stats = self._create_daily_stats(perfs)
491
492
        self.analyze(daily_stats)
493
494
        return daily_stats
495
496
    def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
        """
        Write any identifiers the asset finder cannot resolve to the trading
        environment, then map ``identifiers`` to sids as of ``as_of_date``.
        """
        def resolves(identifier):
            # Only Asset instances and integers are looked up; anything else
            # is treated as unresolved.
            if isinstance(identifier, Asset):
                lookup_sid = identifier.sid
            elif isinstance(identifier, Integral):
                lookup_sid = identifier
            else:
                return False
            found = self.asset_finder.retrieve_asset(sid=lookup_sid,
                                                     default_none=True)
            return found is not None

        # Build new Assets for identifiers that can't be resolved as
        # sids/Assets
        unresolved = []
        for identifier in identifiers:
            if not resolves(identifier):
                unresolved.append(identifier)

        self.trading_environment.write_data(
            equities_identifiers=unresolved)

        # The lookups above may have cached misses, so the caches are reset
        # before mapping. The real fix would be to not construct an
        # AssetFinder until `run()`, when all the data is available.
        self.asset_finder._reset_caches()

        return self.asset_finder.map_identifier_index_to_sids(
            identifiers, as_of_date,
        )
524
525
    def _create_daily_stats(self, perfs):
526
        # create daily and cumulative stats dataframe
527
        daily_perfs = []
528
        # TODO: the loop here could overwrite expected properties
529
        # of daily_perf. Could potentially raise or log a
530
        # warning.
531
        for perf in perfs:
532
            if 'daily_perf' in perf:
533
534
                perf['daily_perf'].update(
535
                    perf['daily_perf'].pop('recorded_vars')
536
                )
537
                perf['daily_perf'].update(perf['cumulative_risk_metrics'])
538
                daily_perfs.append(perf['daily_perf'])
539
            else:
540
                self.risk_report = perf
541
542
        daily_dts = [np.datetime64(perf['period_close'], utc=True)
543
                     for perf in daily_perfs]
544
        daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)
545
546
        return daily_stats
547
548
    @api_method
    def get_environment(self, field='platform'):
        """
        Return one field of the algorithm's environment, or the whole
        environment dict when ``field`` is '*'.

        Available fields: 'arena', 'data_frequency', 'start', 'end',
        'capital_base' and 'platform'. An unknown field raises KeyError.
        """
        params = self.sim_params
        env = {
            'arena': params.arena,
            'data_frequency': params.data_frequency,
            'start': params.first_open,
            'end': params.last_close,
            'capital_base': params.capital_base,
            'platform': self._platform
        }
        return env if field == '*' else env[field]
562
563
    @api_method
    def fetch_csv(self, url,
                  pre_func=None,
                  post_func=None,
                  date_column='date',
                  date_format=None,
                  timezone=pytz.utc.zone,
                  symbol=None,
                  mask=True,
                  symbol_column=None,
                  special_params_checker=None,
                  **kwargs):
        """
        Build a PandasRequestsCSV source for ``url`` and register its data
        with the data portal.

        The source is constructed from the given arguments together with the
        trading environment, the simulation's period start/end and the
        algorithm's data frequency; extra ``kwargs`` are forwarded to it.
        Its ``df`` is then handed to ``data_portal.handle_extra_source``.

        Returns
        -------
        PandasRequestsCSV
            The source that was created.
        """
        params = self.sim_params
        fetched_source = PandasRequestsCSV(
            url,
            pre_func,
            post_func,
            self.trading_environment,
            params.period_start,
            params.period_end,
            date_column,
            date_format,
            timezone,
            symbol,
            mask,
            symbol_column,
            data_frequency=self.data_frequency,
            special_params_checker=special_params_checker,
            **kwargs
        )

        # ingest this into dataportal
        self.data_portal.handle_extra_source(fetched_source.df,
                                             self.sim_params)

        return fetched_source
600
601
    def add_event(self, rule=None, callback=None):
        """
        Wrap ``rule`` and ``callback`` in an Event and register it with the
        algorithm's EventManager.
        """
        event = zipline.utils.events.Event(rule, callback)
        self.event_manager.add_event(event)
608
609
    @api_method
    def schedule_function(self,
                          func,
                          date_rule=None,
                          time_rule=None,
                          half_days=True):
        """
        Schedule ``func`` to be called according to date and time rules.

        ``date_rule`` defaults to every day. In minute mode ``time_rule``
        defaults to market open; in any other mode it is ignored and the
        function fires whenever the date rule does.
        """
        date_rule = date_rule or DateRuleFactory.every_day()

        if self.sim_params.data_frequency == 'minute':
            time_rule = time_rule or TimeRuleFactory.market_open()
        else:
            # If we are in daily mode the time_rule is ignored.
            time_rule = zipline.utils.events.Always()

        event_rule = make_eventrule(date_rule, time_rule, half_days)
        self.add_event(event_rule, func)
628
629
    @api_method
    def record(self, *args, **kwargs):
        """
        Track and record local variable (i.e. attributes) each day.

        Positional arguments are consumed as alternating name/value pairs
        (a trailing unpaired argument is ignored); keyword arguments are
        recorded under their keyword names.
        """
        # Even-indexed positionals are names, odd-indexed ones their values:
        # (a, b, c, d) -> (a, b), (c, d).
        positionals = zip(args[::2], args[1::2])
        for name, value in chain(positionals, iteritems(kwargs)):
            self._recorded_vars[name] = value
644
645
    @api_method
    def set_benchmark(self, benchmark_sid):
        """
        Store ``benchmark_sid`` as the algorithm's benchmark.

        Raises SetBenchmarkOutsideInitialize once the algorithm has been
        marked as initialized.
        """
        already_initialized = self.initialized
        if already_initialized:
            raise SetBenchmarkOutsideInitialize()

        self.benchmark_sid = benchmark_sid
651
652
    @api_method
    @preprocess(symbol_str=ensure_upper_case)
    def symbol(self, symbol_str):
        """
        Default symbol lookup for any source that directly maps the
        symbol to the Asset (e.g. yahoo finance).
        """
        # Fall back to the simulation's period_end for symbol->sid
        # resolution when the user has not set a symbol lookup date.
        lookup_date = self._symbol_lookup_date
        if lookup_date is None:
            lookup_date = self.sim_params.period_end

        return self.asset_finder.lookup_symbol(
            symbol_str,
            as_of_date=lookup_date,
        )
668
669
    @api_method
    def symbols(self, *args):
        """
        Look up each given symbol with ``self.symbol`` and return the
        results as a list, in argument order.
        """
        assets = []
        for identifier in args:
            assets.append(self.symbol(identifier))
        return assets
676
677
    @api_method
    def sid(self, a_sid):
        """
        Default sid lookup for any source that directly maps the integer sid
        to the Asset: delegates to the asset finder's ``retrieve_asset``.
        """
        finder = self.asset_finder
        return finder.retrieve_asset(a_sid)
684
685
    @api_method
    @preprocess(symbol=ensure_upper_case)
    def future_symbol(self, symbol):
        """ Look up a futures contract by its symbol.

        Parameters
        ----------
        symbol : str
            The symbol of the desired contract.

        Returns
        -------
        Future
            A Future object.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.

        """
        finder = self.asset_finder
        return finder.lookup_future_symbol(symbol)
707
708
    @api_method
    @preprocess(root_symbol=ensure_upper_case)
    def future_chain(self, root_symbol, as_of_date=None):
        """ Look up a future chain with the specified parameters.

        Parameters
        ----------
        root_symbol : str
            The root symbol of a future chain.
        as_of_date : datetime.datetime or pandas.Timestamp or str, optional
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date is first after this date is
            the primary contract, etc. A truthy value is converted to a UTC
            pandas.Timestamp.

        Returns
        -------
        FutureChain
            The future chain matching the specified parameters.

        Raises
        ------
        RootSymbolNotFound
            If a future chain could not be found for the given root symbol.
        UnsupportedDatetimeFormat
            If ``as_of_date`` cannot be converted to a pandas.Timestamp.
        """
        if as_of_date:
            raw_date = as_of_date
            try:
                as_of_date = pd.Timestamp(raw_date, tz='UTC')
            except ValueError:
                raise UnsupportedDatetimeFormat(input=raw_date,
                                                method='future_chain')

        chain_kwargs = dict(
            asset_finder=self.asset_finder,
            get_datetime=self.get_datetime,
            root_symbol=root_symbol,
            as_of_date=as_of_date,
        )
        return FutureChain(**chain_kwargs)
744
745
    def _calculate_order_value_amount(self, asset, value):
        """
        Convert a desired order ``value`` into a share/contract amount.

        The value is divided by the asset's last price, scaled by the
        contract multiplier when the asset is a Future. Returns 0 (no order)
        when the last price is tolerantly equal to zero.
        """
        last_price = self.trading_client.current_data[asset].price

        if tolerant_equals(last_price, 0):
            zero_message = "Price of 0 for {psid}; can't infer value".format(
                psid=asset
            )
            if self.logger:
                self.logger.debug(zero_message)
            # Don't place any order
            return 0

        value_multiplier = (
            asset.contract_multiplier if isinstance(asset, Future) else 1
        )
        unit_cost = last_price * value_multiplier
        return value / unit_cost
767
768
    @api_method
    def order(self, sid, amount,
              limit_price=None,
              stop_price=None,
              style=None):
        """
        Place an order for ``amount`` shares/contracts of ``sid``.

        The amount is converted to an integer, the parameters are validated,
        and the order is submitted to the blotter with an execution style
        derived from ``limit_price``/``stop_price``/``style``. Returns
        whatever the blotter's ``order`` returns.
        """
        # Snap to an integer share count: values within .0001 of an integer
        # round to it, everything else truncates toward zero.
        # E.g. 3.9999 -> 4.0; 5.5 -> 5.0; -5.5 -> -5.0
        share_count = int(round_if_near_integer(amount))

        # Raises a ZiplineError if invalid parameters are detected.
        self.validate_order_params(
            sid, share_count, limit_price, stop_price, style
        )

        # Convert deprecated limit_price and stop_price parameters to use
        # ExecutionStyle objects.
        execution_style = self.__convert_order_params_for_blotter(
            limit_price, stop_price, style
        )
        return self.blotter.order(sid, share_count, execution_style)
794
795
    def validate_order_params(self,
                              asset,
                              amount,
                              limit_price,
                              stop_price,
                              style):
        """
        Validate the arguments of an ``order`` call.

        Raises OrderDuringInitialize when the algorithm is not yet
        initialized, and UnsupportedOrderParameters when ``style`` is
        combined with ``limit_price`` or ``stop_price`` or when ``asset`` is
        not an Asset. Finally every registered trading control validates the
        order.
        """
        if not self.initialized:
            raise OrderDuringInitialize(
                msg="order() can only be called from within handle_data()"
            )

        if style and limit_price:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if style and stop_price:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

        if not isinstance(asset, Asset):
            raise UnsupportedOrderParameters(
                msg="Passing non-Asset argument to 'order()' is not supported."
                    " Use 'sid()' or 'symbol()' methods to look up an Asset."
            )

        for trading_control in self.trading_controls:
            trading_control.validate(asset,
                                     amount,
                                     self.updated_portfolio(),
                                     self.get_datetime(),
                                     self.trading_client.current_data)
835
836
    @staticmethod
    def __convert_order_params_for_blotter(limit_price, stop_price, style):
        """
        Translate the deprecated limit_price / stop_price arguments into an
        ExecutionStyle instance.

        A truthy `style` is returned unchanged; in that case both prices are
        expected to be None (enforced only by an assert). Otherwise the result
        depends on which prices are truthy:

        both prices      -> StopLimitOrder(limit_price, stop_price)
        limit_price only -> LimitOrder(limit_price)
        stop_price only  -> StopOrder(stop_price)
        neither          -> MarketOrder()
        """
        # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
        if style:
            assert (limit_price, stop_price) == (None, None)
            return style

        if limit_price:
            if stop_price:
                return StopLimitOrder(limit_price, stop_price)
            return LimitOrder(limit_price)

        if stop_price:
            return StopOrder(stop_price)

        return MarketOrder()
857
858
    @api_method
    def order_value(self, sid, value,
                    limit_price=None, stop_price=None, style=None):
        """
        Place an order by desired value rather than desired number of shares.

        The number of shares is obtained from
        `self._calculate_order_value_amount(sid, value)` and then passed to
        `self.order` together with the execution parameters.
        If the Asset being ordered is a Future, the 'value' calculated
        is actually the exposure, as Futures have no 'value'.

        value > 0 :: Buy/Cover
        value < 0 :: Sell/Short
        Market order:    order_value(sid, value)
        Limit order:     order_value(sid, value, limit_price)
        Stop order:      order_value(sid, value, None, stop_price)
        StopLimit order: order_value(sid, value, limit_price, stop_price)
        """
        share_count = self._calculate_order_value_amount(sid, value)
        execution_kwargs = {
            'limit_price': limit_price,
            'stop_price': stop_price,
            'style': style,
        }
        return self.order(sid, share_count, **execution_kwargs)
880
881
    @property
    def recorded_vars(self):
        """
        Return a shallow copy (via `copy.copy`) of `self._recorded_vars`.
        """
        snapshot = copy(self._recorded_vars)
        return snapshot
884
885
    @property
    def portfolio(self):
        """
        The portfolio as returned by `updated_portfolio()`.
        """
        current_portfolio = self.updated_portfolio()
        return current_portfolio
888
889
    def updated_portfolio(self):
        """
        Return the portfolio, refreshing the cached copy when it is stale.

        When `portfolio_needs_update` is set, a new portfolio is fetched from
        the performance tracker (which is told whether performance itself
        needs updating, and the current datetime), and both the portfolio and
        performance "needs update" flags are cleared. Otherwise the cached
        `self._portfolio` is returned as-is.
        """
        if not self.portfolio_needs_update:
            return self._portfolio

        self._portfolio = self.perf_tracker.get_portfolio(
            self.performance_needs_update,
            self.datetime,
        )
        self.portfolio_needs_update = self.performance_needs_update = False
        return self._portfolio
897
898
    @property
    def account(self):
        """
        The account as returned by `updated_account()`.
        """
        current_account = self.updated_account()
        return current_account
901
902
    def updated_account(self):
        """
        Return the account, refreshing the cached copy when it is stale.

        When `account_needs_update` is set, a new account is fetched from the
        performance tracker (which is told whether performance itself needs
        updating, and the current datetime), and both the account and
        performance "needs update" flags are cleared. Otherwise the cached
        `self._account` is returned as-is.
        """
        if not self.account_needs_update:
            return self._account

        self._account = self.perf_tracker.get_account(
            self.performance_needs_update,
            self.datetime,
        )
        self.account_needs_update = self.performance_needs_update = False
        return self._account
910
911
    def set_logger(self, logger):
        """
        Store `logger` on the algorithm as `self.logger`.
        """
        self.logger = logger
913
914
    def on_dt_changed(self, dt):
        """
        Callback triggered by the simulation loop whenever the current dt
        changes.

        Any logic that should happen exactly once at the start of each datetime
        group should happen here.

        `dt` must be a `datetime` whose tzinfo is UTC (checked with asserts).
        The new time is stored on the algorithm, forwarded to the performance
        tracker and the blotter, and the cached portfolio, account and
        performance are all marked as needing an update.
        """
        assert isinstance(dt, datetime), (
            "Attempt to set algorithm's current time with non-datetime"
        )
        assert dt.tzinfo == pytz.utc, (
            "Algorithm expects a utc datetime"
        )

        self.datetime = dt
        self.perf_tracker.set_date(dt)
        self.blotter.set_date(dt)

        # Everything derived from the previous dt is now stale.
        self.portfolio_needs_update = \
            self.account_needs_update = \
            self.performance_needs_update = True
934
935
    @api_method
    def get_datetime(self, tz=None):
        """
        Returns the simulation datetime.

        When `tz` is None the stored UTC datetime is returned. Otherwise the
        datetime is converted to `tz`, which may be given either as a timezone
        name (looked up with `pytz.timezone`) or as a tzinfo object.
        """
        current = self.datetime
        assert current.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

        if tz is None:
            return current  # datetime.datetime objects are immutable.

        # Accept a timezone name as well as a tzinfo instance.
        target_tz = pytz.timezone(tz) if isinstance(tz, string_types) else tz
        return current.astimezone(target_tz)
950
951
    def update_dividends(self, dividend_frame):
        """
        Set DataFrame used to process dividends.  DataFrame columns should
        contain at least the entries in zp.DIVIDEND_FIELDS.

        The frame is handed straight to the performance tracker.
        """
        tracker = self.perf_tracker
        tracker.update_dividends(dividend_frame)
957
958
    @api_method
    def set_slippage(self, slippage):
        """
        Install `slippage` as the blotter's slippage function.

        Raises
        ------
        UnsupportedSlippageModel
            If `slippage` is not a SlippageModel instance.
        OverrideSlippagePostInit
            If the algorithm has already been initialized.
        """
        is_supported = isinstance(slippage, SlippageModel)
        if not is_supported:
            raise UnsupportedSlippageModel()

        if self.initialized:
            raise OverrideSlippagePostInit()

        self.blotter.slippage_func = slippage
965
966
    @api_method
    def set_commission(self, commission):
        """
        Install `commission` as the blotter's commission model.

        Raises
        ------
        UnsupportedCommissionModel
            If `commission` is not a PerShare, PerTrade or PerDollar instance.
        OverrideCommissionPostInit
            If the algorithm has already been initialized.
        """
        supported_models = (PerShare, PerTrade, PerDollar)
        if not isinstance(commission, supported_models):
            raise UnsupportedCommissionModel()

        if self.initialized:
            raise OverrideCommissionPostInit()

        self.blotter.commission = commission
974
975
    @api_method
    def set_symbol_lookup_date(self, dt):
        """
        Set the date for which symbols will be resolved to their sids
        (symbols may map to different firms or underlying assets at
        different times)

        `dt` is converted with `pd.Timestamp(dt, tz='UTC')`.

        Raises
        ------
        UnsupportedDatetimeFormat
            If that conversion raises a ValueError.
        """
        try:
            lookup_date = pd.Timestamp(dt, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(
                input=dt,
                method='set_symbol_lookup_date',
            )
        self._symbol_lookup_date = lookup_date
987
988
    # Retained for backwards compatibility
989
    @property
    def data_frequency(self):
        """
        The data frequency stored on `self.sim_params`.
        """
        params = self.sim_params
        return params.data_frequency

    @data_frequency.setter
    def data_frequency(self, value):
        # Only these two frequencies are accepted (checked with an assert).
        allowed_frequencies = ('daily', 'minute')
        assert value in allowed_frequencies
        self.sim_params.data_frequency = value
997
998
    @api_method
    def order_percent(self, sid, percent,
                      limit_price=None, stop_price=None, style=None):
        """
        Place an order in the specified asset corresponding to the given
        percent of the current portfolio value.

        The order value is `self.portfolio.portfolio_value * percent`; it is
        forwarded to `order_value` together with the execution parameters,
        and that call's result is returned.

        Note that percent must be expressed as a decimal (0.50 means 50%).
        """
        value = self.portfolio.portfolio_value * percent
        return self.order_value(sid, value,
                                limit_price=limit_price,
                                stop_price=stop_price,
                                style=style)
1012
1013
    @api_method
    def order_target(self, sid, target,
                     limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target number of shares. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target number of shares and the
        current number of shares.
        """
        positions = self.portfolio.positions
        if sid in positions:
            # Only order the gap between the target and what is already held.
            shares_to_order = target - positions[sid].amount
        else:
            shares_to_order = target

        return self.order(sid, shares_to_order,
                          limit_price=limit_price,
                          stop_price=stop_price,
                          style=style)
1035
1036
    @api_method
    def order_target_value(self, sid, target,
                           limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target value. If
        the position doesn't already exist, this is equivalent to placing a new
        order. If the position does exist, this is equivalent to placing an
        order for the difference between the target value and the
        current value.
        If the Asset being ordered is a Future, the 'target value' calculated
        is actually the target exposure, as Futures have no 'value'.

        The target value is converted to a share count with
        `self._calculate_order_value_amount` and handed to `order_target`.
        """
        target_shares = self._calculate_order_value_amount(sid, target)
        execution_kwargs = {
            'limit_price': limit_price,
            'stop_price': stop_price,
            'style': style,
        }
        return self.order_target(sid, target_shares, **execution_kwargs)
1053
1054
    @api_method
    def order_target_percent(self, sid, target,
                             limit_price=None, stop_price=None, style=None):
        """
        Place an order to adjust a position to a target percent of the
        current portfolio value. If the position doesn't already exist, this is
        equivalent to placing a new order. If the position does exist, this is
        equivalent to placing an order for the difference between the target
        percent and the current percent.

        The target value is `self.portfolio.portfolio_value * target`; it is
        forwarded to `order_target_value` together with the execution
        parameters, and that call's result is returned.

        Note that target must be expressed as a decimal (0.50 means 50%).
        """
        target_value = self.portfolio.portfolio_value * target
        return self.order_target_value(sid, target_value,
                                       limit_price=limit_price,
                                       stop_price=stop_price,
                                       style=style)
1071
1072
    @api_method
    def get_open_orders(self, sid=None):
        """
        Return the blotter's open orders, converted with `to_api_obj()`.

        With no `sid`, returns a dict mapping each key of
        `self.blotter.open_orders` that has at least one order to the list of
        its converted orders. With a `sid`, returns the list of converted
        orders for that key, or an empty list when the blotter has no entry
        for it.
        """
        if sid is None:
            open_by_key = {}
            for key, pending in iteritems(self.blotter.open_orders):
                # Keys whose order list is empty are left out entirely.
                if pending:
                    open_by_key[key] = [
                        open_order.to_api_obj() for open_order in pending
                    ]
            return open_by_key

        if sid not in self.blotter.open_orders:
            return []

        pending = self.blotter.open_orders[sid]
        return [open_order.to_api_obj() for open_order in pending]
1084
1085
    @api_method
    def get_order(self, order_id):
        """
        Look up `order_id` in `self.blotter.orders`.

        Returns the matching order converted with `to_api_obj()`, or None when
        the blotter has no order with that id.
        """
        if order_id not in self.blotter.orders:
            return None
        return self.blotter.orders[order_id].to_api_obj()
1089
1090
    @api_method
    def cancel_order(self, order_param):
        """
        Cancel an order through the blotter.

        `order_param` may be a `zipline.protocol.Order`, in which case its
        `id` is used, or any other value, which is passed to the blotter
        unchanged as the order id.
        """
        if isinstance(order_param, zipline.protocol.Order):
            target_id = order_param.id
        else:
            target_id = order_param

        self.blotter.cancel(target_id)
1097
1098
    @api_method
    @require_initialized(HistoryInInitialize())
    def history(self, sids, bar_count, frequency, field, ffill=True):
        """
        Return `self.data_portal.get_history_window(...)` for the given
        arguments, using the algorithm's current datetime.

        The arguments are forwarded unchanged, in the order
        (sids, self.datetime, bar_count, frequency, field, ffill).

        Raises
        ------
        Exception
            If the algorithm has no data portal.
        """
        if self.data_portal is None:
            raise Exception("no data portal!")

        window_args = (
            sids,
            self.datetime,
            bar_count,
            frequency,
            field,
            ffill,
        )
        return self.data_portal.get_history_window(*window_args)
1112
1113
    ####################
1114
    # Account Controls #
1115
    ####################
1116
1117
    def register_account_control(self, control):
        """
        Register a new AccountControl to be checked on each bar.

        Raises
        ------
        RegisterAccountControlPostInit
            If the algorithm has already been initialized.
        """
        already_initialized = self.initialized
        if already_initialized:
            raise RegisterAccountControlPostInit()

        self.account_controls.append(control)
1124
1125
    def validate_account_controls(self):
        """
        Ask every registered account control to validate the current state.

        Each control's `validate` receives the updated portfolio, the updated
        account, the current datetime and the trading client's current data.
        These are re-read for every control.
        """
        for account_control in self.account_controls:
            portfolio = self.updated_portfolio()
            account = self.updated_account()
            now = self.get_datetime()
            current_data = self.trading_client.current_data
            account_control.validate(portfolio, account, now, current_data)
1131
1132
    @api_method
    def set_max_leverage(self, max_leverage=None):
        """
        Set a limit on the maximum leverage of the algorithm.

        The limit is enforced by a MaxLeverage account control built from
        `max_leverage` and registered via `register_account_control`.
        """
        self.register_account_control(MaxLeverage(max_leverage))
1139
1140
    ####################
1141
    # Trading Controls #
1142
    ####################
1143
1144
    def register_trading_control(self, control):
        """
        Register a new TradingControl to be checked prior to order calls.

        Raises
        ------
        RegisterTradingControlPostInit
            If the algorithm has already been initialized.
        """
        already_initialized = self.initialized
        if already_initialized:
            raise RegisterTradingControlPostInit()

        self.trading_controls.append(control)
1151
1152
    @api_method
    def set_max_position_size(self,
                              sid=None,
                              max_shares=None,
                              max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value held for the
        given sid. Limits are treated as absolute values and are enforced at
        the time that the algo attempts to place an order for sid. This means
        that it's possible to end up with more than the max number of shares
        due to splits/dividends, and more than the max notional due to price
        improvement.

        If an algorithm attempts to place an order that would result in
        increasing the absolute value of shares/dollar value exceeding one of
        these limits, raise a TradingControlException.

        The limit is enforced by a MaxPositionSize trading control registered
        via `register_trading_control`.
        """
        limits = {
            'asset': sid,
            'max_shares': max_shares,
            'max_notional': max_notional,
        }
        self.register_trading_control(MaxPositionSize(**limits))
1173
1174
    @api_method
    def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
        """
        Set a limit on the number of shares and/or dollar value of any single
        order placed for sid.  Limits are treated as absolute values and are
        enforced at the time that the algo attempts to place an order for sid.

        If an algorithm attempts to place an order that would result in
        exceeding one of these limits, raise a TradingControlException.

        The limit is enforced by a MaxOrderSize trading control registered
        via `register_trading_control`.
        """
        limits = {
            'asset': sid,
            'max_shares': max_shares,
            'max_notional': max_notional,
        }
        self.register_trading_control(MaxOrderSize(**limits))
1188
1189
    @api_method
    def set_max_order_count(self, max_count):
        """
        Set a limit on the number of orders that can be placed.

        The limit is enforced by a MaxOrderCount trading control built from
        `max_count` and registered via `register_trading_control`.
        NOTE(review): the time window over which orders are counted is
        decided by MaxOrderCount, not by this method - confirm there.
        """
        self.register_trading_control(MaxOrderCount(max_count))
1197
1198
    @api_method
    def set_do_not_order_list(self, restricted_list):
        """
        Set a restriction on which sids can be ordered.

        The restriction is enforced by a RestrictedListOrder trading control
        built from `restricted_list` and registered via
        `register_trading_control`.
        """
        self.register_trading_control(RestrictedListOrder(restricted_list))
1205
1206
    @api_method
    def set_long_only(self):
        """
        Set a rule specifying that this algorithm cannot take short positions.

        The rule is enforced by a LongOnly trading control registered via
        `register_trading_control`.
        """
        long_only_control = LongOnly()
        self.register_trading_control(long_only_control)
1212
1213
    ##############
1214
    # Pipeline API
1215
    ##############
1216
    @api_method
    @require_not_initialized(AttachPipelineAfterInitialize())
    def attach_pipeline(self, pipeline, name, chunksize=None):
        """
        Register a pipeline to be computed at the start of each day.

        The pipeline is stored under `name` together with an endless iterator
        of chunk sizes. When `chunksize` is given, every chunk has that size
        (converted with `int`); otherwise the first chunk is 5 and every later
        chunk is 126.

        Returns the pipeline itself, to allow expressions like
        p = attach_pipeline(Pipeline(), 'name')

        Raises
        ------
        NotImplementedError
            If a pipeline has already been attached.
        """
        if self._pipelines:
            raise NotImplementedError("Multiple pipelines are not supported.")

        if chunksize is not None:
            chunk_sizes = repeat(int(chunksize))
        else:
            # Make the first chunk smaller to get more immediate results:
            # (one week, then every half year)
            chunk_sizes = chain([5], repeat(126))

        self._pipelines[name] = (pipeline, chunk_sizes)
        return pipeline
1235
1236
    @api_method
    @require_initialized(PipelineOutputDuringInitialize())
    def pipeline_output(self, name):
        """
        Get the results of pipeline with name `name`.

        Parameters
        ----------
        name : str
            Name of the pipeline for which results are requested.

        Returns
        -------
        results : pd.DataFrame
            DataFrame containing the results of the requested pipeline for
            the current simulation date.

        Raises
        ------
        NoSuchPipeline
            Raised when no pipeline with the name `name` has been registered.

        See Also
        --------
        :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
        """
        # NOTE: We don't currently support multiple pipelines, but we plan to
        # in the future.
        try:
            registered = self._pipelines[name]
        except KeyError:
            known_names = list(self._pipelines.keys())
            raise NoSuchPipeline(name=name, valid=known_names)

        pipeline, chunks = registered
        return self._pipeline_output(pipeline, chunks)
1272
1273
    def _pipeline_output(self, pipeline, chunks):
        """
        Internal implementation of `pipeline_output`.

        Reads the cached pipeline results for today's (normalized) date. When
        the cache reports `Expired`, the pipeline is re-run starting today
        with the next size drawn from `chunks`, and the cache is replaced by
        the new results and their expiry date.

        Returns the rows of the results for today, or an empty DataFrame with
        the same columns when today is missing from the results.
        """
        today = normalize_date(self.get_datetime())

        try:
            results = self._pipeline_cache.unwrap(today)
        except Expired:
            chunk_size = next(chunks)
            results, valid_until = self._run_pipeline(
                pipeline, today, chunk_size,
            )
            self._pipeline_cache = CachedObject(results, valid_until)

        # Now that we have a cached result, try to return the data for today.
        try:
            todays_rows = results.loc[today]
        except KeyError:
            # This happens if no assets passed the pipeline screen on a given
            # day.
            return pd.DataFrame(index=[], columns=results.columns)
        return todays_rows
1293
1294
    def _run_pipeline(self, pipeline, start_date, chunksize):
        """
        Compute `pipeline`, providing values for at least `start_date`.

        Produces a DataFrame containing data for days between `start_date` and
        `end_date`, where `end_date` is defined by:

            `end_date = min(start_date + chunksize trading days,
                            simulation_end)`

        Returns
        -------
        (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)

        See Also
        --------
        PipelineEngine.run_pipeline
        """
        trading_days = self.trading_environment.trading_days

        # Position of the first requested day in the trading calendar.
        first_loc = trading_days.get_loc(start_date)

        # Position of the (normalized) last simulation day; the chunk never
        # extends past it.
        sim_end = self.sim_params.last_close.normalize()
        last_sim_loc = trading_days.get_loc(sim_end)

        end_date = trading_days[min(first_loc + chunksize, last_sim_loc)]

        results = self.engine.run_pipeline(pipeline, start_date, end_date)
        return results, end_date
1325
1326
    ##################
1327
    # End Pipeline API
1328
    ##################
1329
1330
    @classmethod
    def all_api_methods(cls):
        """
        Return a list of all the TradingAlgorithm API methods.

        Only attributes defined directly on `cls` (its own `vars`) are
        considered; an attribute counts as an API method when it has a truthy
        `is_api_method` attribute.
        """
        api_methods = []
        for member in itervalues(vars(cls)):
            if getattr(member, 'is_api_method', False):
                api_methods.append(member)
        return api_methods
1339