Completed
Pull Request — master (#846)
by Warren
05:15 queued 03:42
created

zipline.TestRemoveDataAlgo.initialize()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 3
rs 10
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
17
"""
18
Algorithm Protocol
19
===================
20
21
For a class to be passed as a trading algorithm to the
22
:py:class:`zipline.lines.SimulatedTrading` class, it must follow an
23
implementation protocol. Examples of this algorithm protocol are provided
24
below.
25
26
The algorithm must expose methods:
27
28
  - initialize: method that takes no args, no returns. Simply called to
29
    enable the algorithm to set any internal state needed.
30
31
  - get_sid_filter: method that takes no args, and returns a list of valid
32
    sids. List must have a length between 1 and 10. If None is returned the
33
    filter will block all events.
34
35
  - handle_data: method that accepts a :py:class:`zipline.protocol.BarData`
36
    of the current state of the simulation universe. An example data object:
37
38
        ..  This outputs the table as an HTML table but for some reason there
39
            is no bounding box. Make the previous paragraph's ending colon a
40
            double-colon to turn this back into blockquoted table in ASCII art.
41
42
        +-----------------+--------------+----------------+-------------------+
43
        |                 | sid(133)     |  sid(134)      | sid(135)          |
44
        +=================+==============+================+===================+
45
        | price           | $10.10       | $22.50         | $13.37            |
46
        +-----------------+--------------+----------------+-------------------+
47
        | volume          | 10,000       | 5,000          | 50,000            |
48
        +-----------------+--------------+----------------+-------------------+
49
        | mvg_avg_30      | $9.97        | $22.61         | $13.37            |
50
        +-----------------+--------------+----------------+-------------------+
51
        | dt              | 6/30/2012    | 6/30/2011      | 6/29/2012         |
52
        +-----------------+--------------+----------------+-------------------+
53
54
  - set_order: method that accepts a callable. Will be set as the value of the
55
    order method of trading_client. An algorithm can then place orders with a
56
    valid sid and a number of shares::
57
58
        self.order(sid(133), share_count)
59
60
  - set_performance: property which can be set equal to the
61
    cumulative_trading_performance property of the trading_client. An
62
    algorithm can then check position information with the
63
    Portfolio object::
64
65
        self.Portfolio[sid(133)]['cost_basis']
66
67
  - set_transact_setter: method that accepts a callable. Will
68
    be set as the value of the set_transact_setter method of
69
    the trading_client. This allows an algorithm to change the
70
    slippage model used to predict transactions based on orders
71
    and trade events.
72
73
"""
74
from copy import deepcopy
75
import numpy as np
76
77
from nose.tools import assert_raises
78
79
from six.moves import range
80
from six import itervalues
81
82
from zipline.algorithm import TradingAlgorithm
83
from zipline.api import (
84
    FixedSlippage,
85
    order,
86
    set_slippage,
87
    record,
88
    sid,
89
)
90
from zipline.errors import UnsupportedOrderParameters
91
from zipline.assets import Future, Equity
92
from zipline.finance.execution import (
93
    LimitOrder,
94
    MarketOrder,
95
    StopLimitOrder,
96
    StopOrder,
97
)
98
from zipline.finance.controls import AssetDateBounds
99
from zipline.transforms import BatchTransform, batch_transform
100
101
102
class TestAlgorithm(TradingAlgorithm):
    """
    Send a fixed number of orders for a single asset.

    Unit tests use this to check the orders that were sent and received,
    the transactions produced, and the positions held when the simulation
    closes.
    """

    def initialize(self,
                   sid,
                   amount,
                   order_count,
                   sid_filter=None,
                   slippage=None,
                   commission=None):
        """
        Store the order parameters and optionally configure the simulation.

        :param sid: identifier resolved to an asset via ``self.sid``.
        :param amount: number of shares in each order.
        :param order_count: total number of orders to place.
        :param sid_filter: optional list of sids; any falsy value falls back
            to a one-element list holding the ordered asset's sid.
        :param slippage: optional model handed to ``set_slippage``.
        :param commission: optional model handed to ``set_commission``.
        """
        self.count = order_count
        self.asset = self.sid(sid)
        self.amount = amount
        self.incr = 0

        # None and an empty list both fall back to the ordered asset.
        self.sid_filter = sid_filter if sid_filter else [self.asset.sid]

        if slippage is not None:
            self.set_slippage(slippage)
        if commission is not None:
            self.set_commission(commission)

    def handle_data(self, data):
        """Order ``amount`` shares of the asset until ``count`` orders exist."""
        if self.incr >= self.count:
            return
        self.order(self.asset, self.amount)
        self.incr += 1
137
138
139
class HeavyBuyAlgorithm(TradingAlgorithm):
    """
    Order ``amount`` shares of a single asset on every bar.

    Unlike TestAlgorithm there is no cap on the number of orders. One order
    is placed per call to ``handle_data``, and ``incr`` counts how many have
    been placed.
    """

    def initialize(self, sid, amount):
        """
        :param sid: identifier resolved to an asset via ``self.sid``.
        :param amount: number of shares in each order.
        """
        self.asset = self.sid(sid)
        self.amount = amount
        self.incr = 0

    def handle_data(self, data):
        """Place one order for ``amount`` shares of the asset."""
        self.order(self.asset, self.amount)
        self.incr += 1
155
156
157
class NoopAlgorithm(TradingAlgorithm):
    """
    An algorithm that does nothing: every hook is a no-op.
    """

    def initialize(self):
        """Nothing to set up."""

    def get_sid_filter(self):
        """Return an empty sid filter."""
        return []

    def set_transact_setter(self, txn_sim_callable):
        """Ignore the supplied transaction-simulation callable."""

    def handle_data(self, data):
        """Ignore every bar."""
172
173
174
class ExceptionAlgorithm(TradingAlgorithm):
    """
    Raise an exception from whichever hook is named by ``throw_from``.

    ``throw_from`` may be "initialize", "set_portfolio", "handle_data" or
    "get_sid_filter"; any other value makes every hook a no-op (apart from
    ``get_sid_filter`` returning the asset).
    """

    def initialize(self, throw_from, sid):
        self.throw_from = throw_from
        self.asset = self.sid(sid)

        if self.throw_from == "initialize":
            raise Exception("Algo exception in initialize")

    def set_portfolio(self, portfolio):
        if self.throw_from == "set_portfolio":
            raise Exception("Algo exception in set_portfolio")

    def handle_data(self, data):
        if self.throw_from == "handle_data":
            raise Exception("Algo exception in handle_data")

    def get_sid_filter(self):
        if self.throw_from == "get_sid_filter":
            raise Exception("Algo exception in get_sid_filter")
        return [self.asset]

    def set_transact_setter(self, txn_sim_callable):
        """Ignore the supplied transaction-simulation callable."""
210
211
212
class DivByZeroAlgorithm(TradingAlgorithm):
    """Raise ZeroDivisionError once more than four bars have been handled."""

    def initialize(self, sid):
        self.asset = self.sid(sid)
        self.incr = 0

    def handle_data(self, data):
        self.incr += 1
        if self.incr <= 4:
            return
        # Deliberately divide by zero to raise ZeroDivisionError.
        5 / 0
223
224
225
class TooMuchProcessingAlgorithm(TradingAlgorithm):
    """Spin through a very long loop on every bar."""

    def initialize(self, sid):
        self.asset = self.sid(sid)

    def handle_data(self, data):
        # One billion iterations: on ordinary hardware this is expected to
        # exceed any processing-time limit.
        for counter in range(1000000000):
            self.foo = counter
235
236
237
class TimeoutAlgorithm(TradingAlgorithm):
    """Intended to stall ``handle_data`` by sleeping once ``incr`` exceeds 4."""

    def initialize(self, sid):
        self.asset = self.sid(sid)
        self.incr = 0

    def handle_data(self, data):
        # NOTE(review): nothing in this class increments ``incr``, so the
        # sleep below is unreachable unless ``incr`` is changed externally.
        # Not "fixed" here because enabling it would block any caller for
        # 100 seconds per bar.
        if self.incr <= 4:
            return
        import time
        time.sleep(100)
248
249
250
class RecordAlgorithm(TradingAlgorithm):
    """Exercise the different call forms of ``record`` on every bar."""

    def initialize(self):
        self.incr = 0

    def handle_data(self, data):
        self.incr += 1
        count = self.incr
        label = 'name'

        # Keyword form through the instance method.
        self.record(incr=count)
        # Positional name/value pair through the instance method.
        self.record(label, count)
        # Module-level API function mixing positional pairs and a keyword.
        record(label, count, 'name2', 2, name3=count)
260
261
262
class TestOrderAlgorithm(TradingAlgorithm):
    """
    Order one share of sid 0 per bar and verify, on every later bar, that
    all earlier orders were filled at the current bar's price.
    """

    def initialize(self):
        self.incr = 0

    def handle_data(self, data):
        if self.incr == 0:
            assert 0 not in self.portfolio.positions
        else:
            held = self.portfolio.positions[0]['amount']
            assert held == self.incr, "Orders not filled immediately."
            fill_price = self.portfolio.positions[0]['last_sale_price']
            assert fill_price == data[0].price, \
                "Orders not filled at current price."
        self.incr += 1
        self.order(self.sid(0), 1)
276
277
278
class TestOrderInstantAlgorithm(TradingAlgorithm):
    """
    Each bar, order two shares' worth of sid 0 via ``order_value``.

    Verifies that the position matches the running share count and that
    the last sale price equals the price remembered from the previous bar.
    """

    def initialize(self):
        self.incr = 0
        self.last_price = None

    def handle_data(self, data):
        if self.incr == 0:
            assert 0 not in self.portfolio.positions
        else:
            held = self.portfolio.positions[0]['amount']
            assert held == self.incr, "Orders not filled immediately."
            fill_price = self.portfolio.positions[0]['last_sale_price']
            assert fill_price == self.last_price, \
                "Orders was not filled at last price."
        self.incr += 2
        price_now = data[0].price
        self.order_value(self.sid(0), price_now * 2.)
        self.last_price = price_now
294
295
296
class TestOrderStyleForwardingAlgorithm(TradingAlgorithm):
    """
    Verify that order helper methods forward an ExecutionStyle unchanged.

    The helper to exercise is named by the ``method_name`` keyword argument
    of the constructor. On the first bar it is called with a
    ``StopLimitOrder(10, 10)`` style, and the resulting open order must
    carry that limit and stop.
    """

    def __init__(self, *args, **kwargs):
        # Remove our own keyword before the base class sees the kwargs.
        self.method_name = kwargs.pop('method_name')
        super(TestOrderStyleForwardingAlgorithm, self).__init__(
            *args, **kwargs)

    def initialize(self):
        self.incr = 0
        self.last_price = None

    def handle_data(self, data):
        if self.incr != 0:
            return

        assert len(self.portfolio.positions.keys()) == 0

        helper = getattr(self, self.method_name)
        helper(self.sid(0), data[0].price, style=StopLimitOrder(10, 10))

        pending = self.blotter.open_orders[0]
        assert len(pending) == 1
        placed = pending[0]
        assert placed.limit == 10
        assert placed.stop == 10

        self.incr += 1
327
328
329
class TestOrderValueAlgorithm(TradingAlgorithm):
    """
    Each bar, order two units' worth of sid 0 via ``order_value``.

    A unit's value is the price, scaled by ``contract_multiplier`` when the
    asset is a Future. Verifies the running amount and the fill price.
    """

    def initialize(self):
        self.incr = 0
        self.sale_price = None

    def handle_data(self, data):
        if self.incr == 0:
            assert 0 not in self.portfolio.positions
        else:
            held = self.portfolio.positions[0]['amount']
            assert held == self.incr, "Orders not filled immediately."
            fill_price = self.portfolio.positions[0]['last_sale_price']
            assert fill_price == data[0].price, \
                "Orders not filled at current price."
        self.incr += 2

        asset = self.sid(0)
        multiplier = 2.
        if isinstance(asset, Future):
            multiplier *= asset.contract_multiplier

        self.order_value(asset, data[0].price * multiplier)
349
350
351
class TestTargetAlgorithm(TradingAlgorithm):
    """
    Each bar, pick a random share target for sid 0 via ``order_target`` and,
    on the following bar, verify the position reached it at the current
    price.
    """

    def initialize(self):
        self.target_shares = 0
        self.sale_price = None

    def handle_data(self, data):
        if self.target_shares != 0:
            held = self.portfolio.positions[0]['amount']
            assert held == self.target_shares, \
                "Orders not filled immediately."
            fill_price = self.portfolio.positions[0]['last_sale_price']
            assert fill_price == data[0].price, \
                "Orders not filled at current price."
        else:
            assert 0 not in self.portfolio.positions

        # NOTE: unseeded here, so the sequence of targets varies between
        # runs unless the caller seeds numpy's global RNG.
        self.target_shares = np.random.randint(1, 30)
        self.order_target(self.sid(0), self.target_shares)
366
367
368
class TestOrderPercentAlgorithm(TradingAlgorithm):
    """
    Buy 10 shares of sid 0 on the first bar, then add 0.1% of portfolio
    value per bar via ``order_percent``, tracking the expected share count.
    """

    def initialize(self):
        self.target_shares = 0
        self.sale_price = None

    def handle_data(self, data):
        if self.target_shares == 0:
            # Seed the position with a plain share order.
            assert 0 not in self.portfolio.positions
            self.order(self.sid(0), 10)
            self.target_shares = 10
            return

        held = self.portfolio.positions[0]['amount']
        assert held == self.target_shares, "Orders not filled immediately."
        fill_price = self.portfolio.positions[0]['last_sale_price']
        assert fill_price == data[0].price, \
            "Orders not filled at current price."

        self.order_percent(self.sid(0), .001)

        # Whole units affordable with 0.1% of portfolio value; a future's
        # unit value includes its contract multiplier.
        asset = self.sid(0)
        if isinstance(asset, Equity):
            self.target_shares += np.floor(
                (.001 * self.portfolio.portfolio_value) / data[0].price
            )
        if isinstance(asset, Future):
            unit_value = data[0].price * asset.contract_multiplier
            self.target_shares += np.floor(
                (.001 * self.portfolio.portfolio_value) / unit_value
            )
396
397
398
class TestTargetPercentAlgorithm(TradingAlgorithm):
    """
    Target 0.2% of portfolio value in sid 0 every bar and verify, from the
    second bar on, that the held value (at the previous price) matches it.
    """

    def initialize(self):
        self.target_shares = 0
        self.sale_price = None

    def handle_data(self, data):
        if self.target_shares == 0:
            assert 0 not in self.portfolio.positions
            self.target_shares = 1
        else:
            target_value = np.round(self.portfolio.portfolio_value * 0.002)
            held_value = (self.portfolio.positions[0]['amount'] *
                          self.sale_price)
            assert target_value == held_value, "Orders not filled correctly."
            fill_price = self.portfolio.positions[0]['last_sale_price']
            assert fill_price == data[0].price, \
                "Orders not filled at current price."
        self.sale_price = data[0].price
        self.order_target_percent(self.sid(0), .002)
415
416
417
class TestTargetValueAlgorithm(TradingAlgorithm):
    """
    Buy 10 shares of sid 0 on the first bar, then target a $20 position via
    ``order_target_value`` on every later bar.

    Each bar verifies that the position equals the share count expected
    from the previous bar's target.
    """

    def initialize(self):
        self.target_shares = 0
        self.sale_price = None

    def handle_data(self, data):
        if self.target_shares == 0:
            # Seed the position with a plain share order.
            assert 0 not in self.portfolio.positions
            self.order(self.sid(0), 10)
            self.target_shares = 10
            return

        assert self.portfolio.positions[0]['amount'] == \
            self.target_shares, "Orders not filled immediately."
        assert self.portfolio.positions[0]['last_sale_price'] == \
            data[0].price, "Orders not filled at current price."

        self.order_target_value(self.sid(0), 20)

        # Expected share count for the $20 target. A future is sized by its
        # notional unit value (price * contract_multiplier); any other asset
        # by price alone. Previously this was computed three times, with a
        # redundant Equity branch and a leftover debug print of the
        # portfolio.
        asset = self.sid(0)
        if isinstance(asset, Future):
            unit_value = data[0].price * asset.contract_multiplier
        else:
            unit_value = data[0].price
        self.target_shares = np.round(20 / unit_value)
443
444
445
class FutureFlipAlgo(TestAlgorithm):
    """
    Cycle the asset's position: flat -> long ``amount`` -> short ``amount``
    -> flat, advancing one step per bar.
    """

    def handle_data(self, data):
        positions = self.portfolio.positions
        if len(positions) == 0:
            target = self.amount
        elif positions[self.asset.sid]["amount"] > 0:
            target = -self.amount
        else:
            target = 0
        self.order_target(self.asset, target)
454
455
############################
456
# AccountControl Test Algos#
457
############################
458
459
460
class SetMaxLeverageAlgorithm(TradingAlgorithm):
    """Install a maximum-leverage account control at initialization."""

    def initialize(self, max_leverage=None):
        """:param max_leverage: forwarded to ``set_max_leverage``."""
        self.set_max_leverage(max_leverage=max_leverage)
463
464
465
############################
466
# TradingControl Test Algos#
467
############################
468
469
470
class SetMaxPositionSizeAlgorithm(TradingAlgorithm):
    """Install a maximum-position-size trading control at initialization."""

    def initialize(self, sid=None, max_shares=None, max_notional=None):
        """All arguments are forwarded to ``set_max_position_size``."""
        self.order_count = 0
        self.set_max_position_size(
            max_notional=max_notional,
            max_shares=max_shares,
            sid=sid,
        )
476
477
478
class SetMaxOrderSizeAlgorithm(TradingAlgorithm):
    """Install a maximum-order-size trading control at initialization."""

    def initialize(self, sid=None, max_shares=None, max_notional=None):
        """All arguments are forwarded to ``set_max_order_size``."""
        self.order_count = 0
        self.set_max_order_size(
            max_notional=max_notional,
            max_shares=max_shares,
            sid=sid,
        )
484
485
486
class SetDoNotOrderListAlgorithm(TradingAlgorithm):
    """Install a do-not-order (restricted list) control at initialization."""

    def initialize(self, sid=None, restricted_list=None):
        """
        :param sid: accepted for signature compatibility; not used here.
        :param restricted_list: forwarded to ``set_do_not_order_list``.
        """
        self.order_count = 0
        self.set_do_not_order_list(restricted_list)
490
491
492
class SetMaxOrderCountAlgorithm(TradingAlgorithm):
    """Install a maximum-order-count trading control at initialization."""

    def initialize(self, count):
        """:param count: forwarded to ``set_max_order_count``."""
        self.order_count = 0
        self.set_max_order_count(count)
496
497
498
class SetLongOnlyAlgorithm(TradingAlgorithm):
    """Install the long-only trading control at initialization."""

    def initialize(self):
        """Reset the order counter and enable the long-only control."""
        self.order_count = 0
        self.set_long_only()
502
503
504
class SetAssetDateBoundsAlgorithm(TradingAlgorithm):
    """
    Algorithm that tries to order 1 share of sid 0 on every bar and has an
    AssetDateBounds() trading control in place.
    """

    def initialize(self):
        self.register_trading_control(AssetDateBounds())

    def handle_data(self, data):
        # The instance parameter was previously named ``algo``; it is now
        # ``self`` per convention. Callers are unaffected because it is
        # always bound positionally.
        self.order(self.sid(0), 1)
514
515
516
class TestRegisterTransformAlgorithm(TradingAlgorithm):
    """Configure fixed slippage and otherwise do nothing."""

    def initialize(self, *args, **kwargs):
        """Any arguments are accepted and ignored."""
        self.set_slippage(FixedSlippage())

    def handle_data(self, data):
        """Ignore every bar."""
522
523
524
class AmbitiousStopLimitAlgorithm(TradingAlgorithm):
    """
    Algorithm that tries to buy with extremely low stops/limits and tries to
    sell with extremely high versions of same. Should not end up with any
    positions for reasonable data.

    Semantics assumed by the comments below (standard stop/limit rules):
    a buy limit fills at or below the limit and a buy stop triggers at or
    above the stop; a sell limit fills at or above the limit and a sell stop
    triggers at or below the stop.
    """

    def initialize(self, *args, **kwargs):
        """:param sid: required keyword; the asset every order targets."""
        self.asset = self.sid(kwargs.pop('sid'))

    def handle_data(self, data):

        ########
        # Buys #
        ########

        # Buy with low limit, shouldn't trigger.
        self.order(self.asset, 100, limit_price=1)

        # Buy with high stop, shouldn't trigger.
        self.order(self.asset, 100, stop_price=10000000)

        # Buy with high limit (should trigger) but also high stop (should
        # prevent trigger).
        self.order(self.asset, 100, limit_price=10000000, stop_price=10000000)

        # Buy with low stop (should trigger), but also low limit (should
        # prevent trigger).
        self.order(self.asset, 100, limit_price=1, stop_price=1)

        #########
        # Sells #
        #########

        # Sell with high limit, shouldn't trigger.
        self.order(self.asset, -100, limit_price=1000000)

        # Sell with low stop, shouldn't trigger.
        self.order(self.asset, -100, stop_price=1)

        # Sell with high stop (should trigger), but also high limit (should
        # prevent the fill).
        self.order(self.asset, -100, limit_price=1000000, stop_price=1000000)

        # Sell with low limit (should trigger), but also low stop (should
        # prevent trigger).
        self.order(self.asset, -100, limit_price=1, stop_price=1)

        ###################
        # Rounding Checks #
        ###################
        # Tiny prices that must survive price rounding without triggering.
        self.order(self.asset, 100, limit_price=.00000001)
        self.order(self.asset, -100, stop_price=.00000001)
577
578
579
##########################################
580
# Algorithm using simple batch transforms
581
582
class ReturnPriceBatchTransform(BatchTransform):
    """Class-form batch transform returning the window's ``price`` field."""

    def get_value(self, data):
        """
        Check that the window is exactly ``window_length`` wide along axis 1
        and return its ``price`` field.
        """
        width = data.shape[1]
        assert width == self.window_length, (
            "data shape={0} does not equal window_length={1} for data={2}"
            .format(width, self.window_length, data))
        return data.price
588
589
590
@batch_transform
def return_price_batch_decorator(data):
    """Decorator-form batch transform returning the ``price`` field."""
    return data.price
593
594
595
@batch_transform
def return_args_batch_decorator(data, *args, **kwargs):
    """Echo back the extra positional and keyword arguments as a pair."""
    return args, kwargs
598
599
600
@batch_transform
def return_data(data, *args, **kwargs):
    """Return the window data unchanged; extra arguments are ignored."""
    return data
603
604
605
@batch_transform
def uses_ufunc(data, *args, **kwargs):
    """
    Apply ``np.log`` to the window data.

    Exists to check that numpy ufuncs can be applied to batch-transform
    data without crashing.
    """
    return np.log(data)
609
610
611
@batch_transform
def price_multiple(data, multiplier, extra_arg=1):
    """
    Return ``price`` scaled by ``multiplier`` and ``extra_arg``.

    The positional and keyword arguments let callers check that changing
    either one produces a fresh result.
    """
    return data.price * multiplier * extra_arg
614
615
616
class BatchTransformAlgorithm(TradingAlgorithm):
    """
    Drive a set of batch transforms and keep a per-bar history of outputs.

    Each transform is configured to exercise one feature: class vs.
    decorator form, argument forwarding, arbitrary extra fields, NaN
    cleaning, sid and field filters, partial windows, ufuncs, and
    idempotence of repeated calls with identical arguments.
    """

    def initialize(self, *args, **kwargs):
        """
        Build the transforms under test.

        :param args: forwarded to ``return_args_batch`` on every bar.
        :param kwargs: ``refresh_period`` (default 1) and ``window_length``
            (default 3) are consumed here; the remainder is forwarded to
            ``return_args_batch`` on every bar.
        """
        self.refresh_period = kwargs.pop('refresh_period', 1)
        self.window_length = kwargs.pop('window_length', 3)

        self.args = args
        self.kwargs = kwargs

        self.history_return_price_class = []
        self.history_return_price_decorator = []
        self.history_return_args = []
        self.history_return_arbitrary_fields = []
        self.history_return_nan = []
        self.history_return_sid_filter = []
        self.history_return_field_filter = []
        self.history_return_field_no_filter = []
        self.history_return_ticks = []
        self.history_return_not_full = []

        # Window settings shared by every transform except return_not_full.
        window = dict(refresh_period=self.refresh_period,
                      window_length=self.window_length)

        self.return_price_class = ReturnPriceBatchTransform(
            clean_nans=False, **window)
        self.return_price_decorator = return_price_batch_decorator(
            clean_nans=False, **window)
        self.return_args_batch = return_args_batch_decorator(
            clean_nans=False, **window)
        self.return_arbitrary_fields = return_data(
            clean_nans=False, **window)
        self.return_nan = return_price_batch_decorator(
            clean_nans=True, **window)
        self.return_sid_filter = return_price_batch_decorator(
            clean_nans=True, sids=[0], **window)
        self.return_field_filter = return_data(
            clean_nans=True, fields=['price'], **window)
        self.return_field_no_filter = return_data(
            clean_nans=True, **window)
        # Refreshes every bar and also reports before the window is full.
        self.return_not_full = return_data(
            refresh_period=1,
            window_length=self.window_length,
            compute_only_full=False
        )
        self.uses_ufunc = uses_ufunc(clean_nans=False, **window)
        self.price_multiple = price_multiple(clean_nans=False, **window)

        self.iter = 0

        self.set_slippage(FixedSlippage())

    def handle_data(self, data):
        """Feed ``data`` and several modified copies to every transform."""
        self.history_return_price_class.append(
            self.return_price_class.handle_data(data))
        self.history_return_price_decorator.append(
            self.return_price_decorator.handle_data(data))
        self.history_return_args.append(
            self.return_args_batch.handle_data(
                data, *self.args, **self.kwargs))
        self.history_return_not_full.append(
            self.return_not_full.handle_data(data))
        self.uses_ufunc.handle_data(data)

        self._check_price_multiple(data)

        # Add an extra 'arbitrary' field to every sid's bar.
        with_arbitrary = deepcopy(data)
        for key in with_arbitrary:
            with_arbitrary[key]['arbitrary'] = 123
        self.history_return_arbitrary_fields.append(
            self.return_arbitrary_fields.handle_data(with_arbitrary))

        # On odd iterations the price is replaced by NaN.
        if self.iter % 2 == 0:
            nan_input = data
        else:
            nan_input = deepcopy(data)
            nan_input.price = np.nan
        self.history_return_nan.append(self.return_nan.handle_data(nan_input))
        self.iter += 1

        # Duplicate sid 0's bar under sid 1; the sids=[0] filter is expected
        # to keep it out of the result.
        extra_sid_data = deepcopy(data)
        extra_sid_data[1] = extra_sid_data[0]
        self.history_return_sid_filter.append(
            self.return_sid_filter.handle_data(extra_sid_data)
        )

        # Add an 'ignore' field; the fields=['price'] filter is expected to
        # drop it, while the unfiltered transform receives the same input.
        extra_field_data = deepcopy(data)
        extra_field_data[0]['ignore'] = extra_sid_data[0]['price']
        self.history_return_field_filter.append(
            self.return_field_filter.handle_data(extra_field_data)
        )
        self.history_return_field_no_filter.append(
            self.return_field_no_filter.handle_data(extra_field_data)
        )

    def _check_price_multiple(self, data):
        """
        Feed ``price_multiple`` one bar and, once its window is full, assert
        that identical repeat calls are idempotent while changed positional
        or keyword arguments produce a new result.
        """
        transform = self.price_multiple
        transform.handle_data(data, 1, extra_arg=1)
        if not transform.full:
            return

        pre = transform.rolling_panel.get_current().shape[0]
        result1 = transform.handle_data(data, 1, extra_arg=1)
        post = transform.rolling_panel.get_current().shape[0]
        assert pre == post, "batch transform is appending redundant events"
        result2 = transform.handle_data(data, 1, extra_arg=1)
        assert result1 is result2, "batch transform is not idempotent"

        result3 = transform.handle_data(data, 2, extra_arg=1)
        assert result1 is not result3, \
            "batch transform is not updating for new args"

        result4 = transform.handle_data(data, 1, extra_arg=2)
        assert result1 is not result4, \
            "batch transform is not updating for new kwargs"
777
778
779
class BatchTransformAlgorithmMinute(TradingAlgorithm):
    """Record a minute-bar price batch transform's output on every bar."""

    def initialize(self, *args, **kwargs):
        """
        :param kwargs: ``refresh_period`` (default 1) and ``window_length``
            (default 3) are consumed; the remainder is stored on
            ``self.kwargs`` together with ``args``.
        """
        self.refresh_period = kwargs.pop('refresh_period', 1)
        self.window_length = kwargs.pop('window_length', 3)

        self.args = args
        self.kwargs = kwargs

        self.history = []

        self.batch_transform = return_price_batch_decorator(
            bars='minute',
            clean_nans=False,
            refresh_period=self.refresh_period,
            window_length=self.window_length,
        )

    def handle_data(self, data):
        """Append this bar's transform output to ``history``."""
        output = self.batch_transform.handle_data(data)
        self.history.append(output)
798
799
800
class SetPortfolioAlgorithm(TradingAlgorithm):
    """
    Attempt to overwrite ``self.portfolio`` from inside ``handle_data``.

    The portfolio is meant to be read-only within an algorithm, so this
    assignment is the behavior under test.
    """

    def initialize(self, *args, **kwargs):
        """Any arguments are accepted and ignored."""

    def handle_data(self, data):
        self.portfolio = 3
813
814
815
class TALIBAlgorithm(TradingAlgorithm):
    """
    Apply one or more TA-Lib transforms on every bar.

    The transform(s) are supplied via the ``talib`` keyword argument, either
    as a single object or as a list/tuple. Per-bar outputs are stored in
    ``talib_results``, keyed by transform.
    """

    def initialize(self, *args, **kwargs):
        """
        :param talib: required keyword; a transform or a list/tuple of them.
        :raises KeyError: if ``talib`` is not supplied.
        """
        if 'talib' not in kwargs:
            raise KeyError('No TA-LIB transform specified '
                           '(use keyword \'talib\').')

        supplied = kwargs['talib']
        if isinstance(supplied, (list, tuple)):
            self.talib_transforms = supplied
        else:
            self.talib_transforms = (supplied,)

        self.talib_results = {t: [] for t in self.talib_transforms}

    def handle_data(self, data):
        """
        Record each transform's output for this bar.

        When a transform returns None, a NaN placeholder is recorded instead:
        a scalar for single-output functions, otherwise a tuple with one NaN
        per output name.
        """
        for transform in self.talib_transforms:
            value = transform.handle_data(data)
            if value is None:
                n_outputs = len(transform.talib_fn.output_names)
                value = np.nan if n_outputs == 1 else (np.nan,) * n_outputs
            self.talib_results[transform].append(value)
842
843
844
class EmptyPositionsAlgorithm(TradingAlgorithm):
    """
    An algorithm that ensures that 'phantom' positions do not appear in
    portfolio.positions in the case that a position has been entered
    and fully exited.

    On the first bar it orders 100 shares of every sid in ``data``. Once
    every held position is exactly 100 shares and there is one position per
    sid in ``data``, it orders -100 of each. ``num_positions`` is recorded
    on every bar and should be 0 once all positions are exited.
    """

    def initialize(self, *args, **kwargs):
        self.ordered = False  # entry orders have been placed
        self.exited = False   # exit orders have been placed

    def handle_data(self, data):
        if not self.ordered:
            for s in data:
                self.order(self.sid(s), 100)
            self.ordered = True

        if not self.exited:
            amounts = [pos.amount for pos
                       in itervalues(self.portfolio.positions)]
            # Exit only after every entry order is reflected in full.
            if (
                all(amount == 100 for amount in amounts) and
                len(amounts) == len(data.keys())
            ):
                for stock in self.portfolio.positions:
                    self.order(self.sid(stock), -100)
                self.exited = True

        # Should be 0 when all positions are exited.
        self.record(num_positions=len(self.portfolio.positions))
873
874
875
class InvalidOrderAlgorithm(TradingAlgorithm):
    """
    Calls every order function with conflicting arguments and checks that
    each call raises ``UnsupportedOrderParameters``.

    Every order function is given an explicit ``style`` together with either
    a ``limit_price`` or a ``stop_price``, for each of four order styles.
    """

    def initialize(self, *args, **kwargs):
        # Only the first supplied sid is used; note 'sids' is popped from
        # kwargs.
        self.asset = self.sid(kwargs.pop('sids')[0])

    def handle_data(self, data):
        from zipline.api import (
            order_percent,
            order_target,
            order_target_percent,
            order_target_value,
            order_value,
        )

        styles = [MarketOrder(), LimitOrder(10),
                  StopOrder(10), StopLimitOrder(10, 10)]

        # (order function, amount) pairs, in the order they are exercised.
        order_calls = (
            (order, 10),
            (order_value, 300),
            (order_percent, .1),
            (order_target, 100),
            (order_target_value, 100),
            (order_target_percent, .2),
        )

        for style in styles:
            for order_func, amount in order_calls:
                for price_kwarg in ('limit_price', 'stop_price'):
                    with assert_raises(UnsupportedOrderParameters):
                        order_func(self.asset, amount, style=style,
                                   **{price_kwarg: 10})
938
939
940
class TestRemoveDataAlgo(TradingAlgorithm):
    """
    Records ``len(data)`` for each bar into ``self.data``.

    ``self.data`` has room for 7 bars; an 8th ``handle_data`` call would
    raise ``IndexError``.
    """

    def initialize(self, *args, **kwargs):
        self.data = np.zeros(7)
        self.i = 0  # index of the next slot to fill

    def handle_data(self, data):
        idx = self.i
        self.data[idx] = len(data)
        self.i = idx + 1
948
949
950
##############################
951
# Quantopian style algorithms
952
953
# Noop algo
954
def initialize_noop(context):
    """Do nothing; the no-op algorithm needs no setup."""
956
957
958
def handle_data_noop(context, data):
    """Ignore the bar entirely; the no-op algorithm never acts."""
960
961
962
# API functions
963
def initialize_api(context):
    """Prepare the context for ``handle_data_api``.

    Resets the bar counter ``incr`` to 0, clears ``sale_price`` and installs
    a ``FixedSlippage`` model (presumably so orders fill at the bar's price,
    which ``handle_data_api`` asserts — confirm).
    """
    context.incr = 0
    context.sale_price = None
    slippage_model = FixedSlippage()
    set_slippage(slippage_model)
967
968
969
def handle_data_api(context, data):
    """Verify the previous bar's fill, then order one more share of sid 0.

    ``context.incr`` counts the bars handled so far. On the first bar no
    sid 0 position may exist; on later bars the sid 0 position must hold
    exactly ``context.incr`` shares and its last sale price must equal the
    current bar's price. The updated count is recorded as ``incr``.

    The checks use ``assert`` and are therefore skipped under ``python -O``.
    """
    if context.incr != 0:
        assert (context.portfolio.positions[0]['amount'] ==
                context.incr), "Orders not filled immediately."
        assert (context.portfolio.positions[0]['last_sale_price'] ==
                data[0].price), "Orders not filled at current price."
    else:
        assert 0 not in context.portfolio.positions

    context.incr += 1
    order(sid(0), 1)

    record(incr=context.incr)
981
982
###########################
983
# AlgoScripts as strings
984
# Algorithm script whose initialize and handle_data do nothing.
noop_algo = """
# Noop algo
def initialize(context):
    pass

def handle_data(context, data):
    pass
"""
992
993
# Script form of initialize_api / handle_data_api: orders one share of sid 0
# per bar and asserts each earlier order filled immediately at the bar price.
api_algo = """
from zipline.api import (order,
                         set_slippage,
                         FixedSlippage,
                         record,
                         sid)

def initialize(context):
    context.incr = 0
    context.sale_price = None
    set_slippage(FixedSlippage())

def handle_data(context, data):
    if context.incr == 0:
        assert 0 not in context.portfolio.positions
    else:
        assert context.portfolio.positions[0]['amount'] == \
                context.incr, "Orders not filled immediately."
        assert context.portfolio.positions[0]['last_sale_price'] == \
                data[0].price, "Orders not filled at current price."
    context.incr += 1
    order(sid(0), 1)

    record(incr=context.incr)
"""
1018
1019
# Script that stores get_environment() on the context in initialize and
# orders one share of symbol 'TEST' on every bar.
api_get_environment_algo = """
from zipline.api import get_environment, order, symbol


def initialize(context):
    context.environment = get_environment()

handle_data = lambda context, data: order(symbol('TEST'), 1)
"""
1028
1029
# Script that orders one share of symbol 'TEST' on every bar.
api_symbol_algo = """
from zipline.api import (order,
                         symbol)

def initialize(context):
    pass

def handle_data(context, data):
    order(symbol('TEST'), 1)
"""
1039
1040
# Script that calls order() from inside initialize (presumably to check that
# ordering during initialize is rejected — confirm against the caller).
call_order_in_init = """
from zipline.api import (order)

def initialize(context):
    order(0, 10)
    pass

def handle_data(context, data):
    pass
"""
1050
1051
# Script that reads context.portfolio.cash inside initialize.
access_portfolio_in_init = """
def initialize(context):
    var = context.portfolio.cash
    pass

def handle_data(context, data):
    pass
"""
1059
1060
# Script that reads context.account.settled_cash inside initialize.
access_account_in_init = """
def initialize(context):
    var = context.account.settled_cash
    pass

def handle_data(context, data):
    pass
"""
1068
1069
# Script that calls each zipline.api order function once per bar on sid 0.
call_all_order_methods = """
from zipline.api import (order,
                         order_value,
                         order_percent,
                         order_target,
                         order_target_value,
                         order_target_percent,
                         sid)

def initialize(context):
    pass

def handle_data(context, data):
    order(sid(0), 10)
    order_value(sid(0), 300)
    order_percent(sid(0), .1)
    order_target(sid(0), 100)
    order_target_value(sid(0), 100)
    order_target_percent(sid(0), .2)
"""
1089
1090
# Script that records a per-bar counter ``incr`` via record().
record_variables = """
from zipline.api import record

def initialize(context):
    context.stocks = [0, 1]
    context.incr = 0

def handle_data(context, data):
    context.incr += 1
    record(incr=context.incr)
"""
1101
1102
# Script template: the '%s' placeholder is presumably filled via %-formatting
# with a float literal string (e.g. 'nan', 'inf') before use — confirm.
record_float_magic = """
from zipline.api import record

def initialize(context):
    context.stocks = [0, 1]
    context.incr = 0

def handle_data(context, data):
    context.incr += 1
    record(data=float('%s'))
"""
1113