Completed
Pull Request — master (#858)
by Eddie
02:03
created

tests.TestTradingControls.handle_data2()   A

Complexity

Conditions 4

Size

Total Lines 7

Duplication

Lines 0
Ratio 0 %
| Metric | Value |
|--------|-------|
| cc     | 4     |
| dl     | 0     |
| loc    | 7     |
| rs     | 9.2   |
#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range
20
from testfixtures import TempDirectory
21
from textwrap import dedent
22
from unittest import TestCase
23
24
import numpy as np
25
import pandas as pd
26
27
from zipline.assets import Equity, Future
28
from zipline.utils.api_support import ZiplineAPI
29
from zipline.utils.control_flow import nullctx
30
from zipline.utils.test_utils import (
31
    setup_logger,
32
    teardown_logger,
33
    FakeDataPortal)
34
import zipline.utils.factory as factory
35
36
from zipline.errors import (
37
    OrderDuringInitialize,
38
    RegisterTradingControlPostInit,
39
    TradingControlViolation,
40
    AccountControlViolation,
41
    SymbolNotFound,
42
    RootSymbolNotFound,
43
    UnsupportedDatetimeFormat,
44
)
45
from zipline.test_algorithms import (
46
    access_account_in_init,
47
    access_portfolio_in_init,
48
    AmbitiousStopLimitAlgorithm,
49
    EmptyPositionsAlgorithm,
50
    InvalidOrderAlgorithm,
51
    RecordAlgorithm,
52
    FutureFlipAlgo,
53
    TestAlgorithm,
54
    TestOrderAlgorithm,
55
    TestOrderPercentAlgorithm,
56
    TestOrderStyleForwardingAlgorithm,
57
    TestOrderValueAlgorithm,
58
    TestRegisterTransformAlgorithm,
59
    TestTargetAlgorithm,
60
    TestTargetPercentAlgorithm,
61
    TestTargetValueAlgorithm,
62
    SetLongOnlyAlgorithm,
63
    SetAssetDateBoundsAlgorithm,
64
    SetMaxPositionSizeAlgorithm,
65
    SetMaxOrderCountAlgorithm,
66
    SetMaxOrderSizeAlgorithm,
67
    SetDoNotOrderListAlgorithm,
68
    SetMaxLeverageAlgorithm,
69
    api_algo,
70
    api_get_environment_algo,
71
    api_symbol_algo,
72
    call_all_order_methods,
73
    call_order_in_init,
74
    handle_data_api,
75
    handle_data_noop,
76
    initialize_api,
77
    initialize_noop,
78
    noop_algo,
79
    record_float_magic,
80
    record_variables,
81
)
82
from zipline.utils.context_tricks import CallbackManager
83
import zipline.utils.events
84
from zipline.utils.test_utils import to_utc
85
86
from zipline.finance.execution import LimitOrder
87
from zipline.finance.trading import SimulationParameters
88
from zipline.utils.api_support import set_algo_instance
89
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
90
from zipline.algorithm import TradingAlgorithm
91
from zipline.finance.trading import TradingEnvironment
92
from zipline.finance.commission import PerShare
93
94
from zipline.utils.test_utils import (
95
    create_data_portal,
96
    create_data_portal_from_trade_history
97
)
98
99
# Because test cases appear to reuse some resources.
100
101
102
_multiprocess_can_split_ = False
103
104
105
class TestRecordAlgorithm(TestCase):
    """Tests for values recorded by ``RecordAlgorithm`` during a run."""

    @classmethod
    def setUpClass(cls):
        # A single equity (sid 133) over a 4-day simulation.
        cls.env = TradingEnvironment()
        cls.sids = [133]
        cls.env.write_data(equities_identifiers=cls.sids)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=4,
            env=cls.env
        )

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def test_record_incr(self):
        """Recorded columns contain the per-bar values emitted by the algo.

        'incr', 'name' and 'name3' are expected to count up from 1 on each
        output row, while 'name2' is expected to stay constant at 2.
        """
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        output = algo.run(self.data_portal)

        np.testing.assert_array_equal(output['incr'].values,
                                      range(1, len(output) + 1))
        np.testing.assert_array_equal(output['name'].values,
                                      range(1, len(output) + 1))
        np.testing.assert_array_equal(output['name2'].values,
                                      [2] * len(output))
        np.testing.assert_array_equal(output['name3'].values,
                                      range(1, len(output) + 1))
144
145
146
class TestMiscellaneousAPI(TestCase):
    """Tests for assorted ``TradingAlgorithm`` API methods: asset and futures
    lookups, open-order bookkeeping, scheduling, event contexts and
    environment introspection."""

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        # Two equities that share the symbol 'PLAY' over non-overlapping
        # date ranges; used by the date-aware symbol lookup tests.
        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        # Four futures contracts under the root symbol 'CL'; used by the
        # future_symbol / future_chain tests.
        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

        setup_logger(cls)

        # Two days of minute data with daily emission. test_get_environment
        # and test_event_context depend on these exact settings.
        cls.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='daily',
            env=cls.env,
        )

        cls.temp_dir = TempDirectory()

        # The data portal is built for sids 1 and 2 only.
        cls.data_portal = create_data_portal(
            cls.env,
            cls.temp_dir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.temp_dir.cleanup()

    def test_zipline_api_resolves_dynamically(self):
        """``zipline.api`` functions dispatch to the current algo instance."""
        # Make a dummy algo.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
        )

        # Verify that api methods get resolved dynamically by patching them
        # out and then calling them through ``zipline.api``.
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            # fake_method is called within the same iteration that defines
            # it, so the late binding of ``sentinel`` is harmless.
            def fake_method(*args, **kwargs):
                return sentinel
            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_cannot_iterate_over_data(self):
        """Iterating over ``data`` in handle_data raises ValueError."""
        def initialize(algo):
            pass

        def handle_data(algo, data):
            for asset in data:
                pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)

        with self.assertRaises(ValueError):
            algo.run(self.data_portal)

    def test_get_sid(self):
        """``data[sid(3)]`` yields an object whose ``sid`` is 3."""
        algo_text = """
from zipline.api import sid


def initialize(context):
    pass


def handle_data(context, data):
    aapl = data[sid(3)]
    assert aapl.sid == 3
"""

        algo = TradingAlgorithm(script=algo_text,
                                sim_params=self.sim_params,
                                env=self.env)

        algo.run(self.data_portal)

    def test_sid_datetime(self):
        """A bar's ``datetime`` matches ``get_datetime()`` in handle_data."""
        algo_text = """
from zipline.api import sid, get_datetime


def initialize(context):
    pass


def handle_data(context, data):
    aapl = data[sid(1)]
    assert aapl.datetime == get_datetime()
"""

        algo = TradingAlgorithm(script=algo_text,
                                sim_params=self.sim_params,
                                env=self.env)

        algo.run(self.data_portal)

    def test_get_environment(self):
        """get_environment() returns the platform name, or a dict of all
        environment fields when called with '*'."""
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_get_open_orders(self):
        """get_open_orders tracks unfilled orders per sid and drops
        cancelled ones."""
        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The market order for sid 1 should have filled; the three
                # limit orders for sid 2 should all still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_schedule_function(self):
        """A function scheduled at market open runs once per simulated day,
        at 14:31 UTC."""
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            # Count the distinct calendar days seen by handle_data.
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        """The event context wraps handle_data and every added event, and
        receives the same ``data`` object that handle_data does."""
        # Number of handle_data calls expected over the class fixture's
        # two-day minute simulation.
        bar_count = 780

        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.data_portal)

        self.assertEqual(len(expected_data), bar_count)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        # Every bar should call pre, handle_data, f, g and post, in order.
        expected_functions = [pre, handle_data, f, g, post] * bar_count
        self.assertEqual(
            len(function_stack),
            len(expected_functions),
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), len(expected_functions)),
        )
        # Distinct loop names so the local functions f and g are not
        # shadowed.
        for n, (actual, expected) in enumerate(
                zip(function_stack, expected_functions)):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_schedule_function_rule_creation(self, mode):
        """schedule_function wraps its rules in a OncePerDay(ComposedRule)
        whose time component depends on the data frequency."""
        def nop(*args, **kwargs):
            return None

        # The rule built below depends on the data frequency (see the
        # mode-dependent assertions). self.sim_params is a shared class
        # fixture, so switch it only for this test and restore it afterwards
        # to avoid leaking state into other tests.
        original_frequency = self.sim_params.data_frequency
        self.sim_params.data_frequency = mode
        try:
            algo = TradingAlgorithm(
                initialize=nop,
                handle_data=nop,
                sim_params=self.sim_params,
                env=self.env,
            )

            # Schedule something for NOT Always.
            algo.schedule_function(nop,
                                   time_rule=zipline.utils.events.Never())

            # Index 1 is the event just scheduled; index 0 is presumably the
            # handle_data event registered at construction (TODO confirm).
            event_rule = algo.event_manager._events[1].rule

            self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

            inner_rule = event_rule.rule
            self.assertIsInstance(inner_rule,
                                  zipline.utils.events.ComposedRule)

            first = inner_rule.first
            second = inner_rule.second
            composer = inner_rule.composer

            self.assertIsInstance(first, zipline.utils.events.Always)

            if mode == 'daily':
                self.assertIsInstance(second, zipline.utils.events.Always)
            else:
                self.assertIsInstance(second, zipline.utils.events.Never)

            self.assertIs(composer,
                          zipline.utils.events.ComposedRule.lazy_and)
        finally:
            self.sim_params.data_frequency = original_frequency

    def test_asset_lookup(self):
        """symbol()/symbols() resolve 'PLAY' according to the simulation's
        period end; sid() returns Equity objects; non-string symbols raise
        TypeError."""
        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.symbol(1)

        with self.assertRaises(TypeError):
            algo.symbol((1,))

        with self.assertRaises(TypeError):
            algo.symbol({1})

        with self.assertRaises(TypeError):
            algo.symbol([1])

        with self.assertRaises(TypeError):
            algo.symbol({'foo': 'bar'})

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('')

        # 'PLAY' is an equity symbol, not a future.
        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('PLAY')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('FOOBAR')

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_symbol(1)

        with self.assertRaises(TypeError):
            algo.future_symbol((1,))

        with self.assertRaises(TypeError):
            algo.future_symbol({1})

        with self.assertRaises(TypeError):
            algo.future_symbol([1])

        with self.assertRaises(TypeError):
            algo.future_symbol({'foo': 'bar'})

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # A date string is accepted for as_of_date as well.
        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_chain(1)

        with self.assertRaises(TypeError):
            algo.future_chain((1,))

        with self.assertRaises(TypeError):
            algo.future_chain({1})

        with self.assertRaises(TypeError):
            algo.future_chain([1])

        with self.assertRaises(TypeError):
            algo.future_chain({'foo': 'bar'})

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Create two assets with the same symbol but different
        # non-overlapping date ranges. They get sids 3 and 4 (i + 3) and are
        # written to a fresh TradingEnvironment, separate from the class
        # fixture's environment.
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the end dates of both
        # assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end is used as the
        # as_of_date. Which of the two 'DUP' assets is chosen then is up to
        # the asset finder, so only the symbol is checked here.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
691
692
class TestTransformAlgorithm(TestCase):
    """End-to-end runs of the ordering / transform test algorithms against a
    small shared data portal."""

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)
        cls.sids = [0, 1, 133]
        cls.tempdir = TempDirectory()

        # Sid 3 is a future (contract_multiplier 10); the sids in cls.sids
        # are equities alive for the whole simulation window.
        futures_metadata = {3: {'contract_multiplier': 10}}
        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

        cls.env.write_data(equities_data=equities_metadata,
                           futures_data=futures_metadata)

        # Every equity gets the same four-entry trade history, spaced one day
        # apart.
        trades_by_sid = {}
        for sid in cls.sids:
            trades_by_sid[sid] = factory.create_trade_history(
                sid,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

        cls.data_portal.current_day = cls.sim_params.trading_days[0]

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        del cls.env
        cls.tempdir.cleanup()

    def test_invalid_order_parameters(self):
        """InvalidOrderAlgorithm runs to completion without raising."""
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

    def test_run_twice(self):
        """Two independent runs of the same algorithm give equal results."""
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.data_portal)

        # Create a second, independent trading algorithm. No env is passed,
        # so it presumably instantiates its own environment (TODO confirm).
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.data_portal)

        # FIXME I think we are getting Nans due to fixed benchmark,
        # so dropping them for now.
        res1 = res1.fillna(method='ffill')
        res2 = res2.fillna(method='ffill')

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        """An algorithm keeps the data_frequency of the sim_params it is
        given."""
        # Each case builds its own SimulationParameters, so the shared class
        # fixture is deliberately left untouched.
        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='daily')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'daily')

        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='minute')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'minute')

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        """Each order-method test algorithm runs against equities."""
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 0 is an Equity
        asset_to_test = algo.sid(0)
        self.assertIsInstance(asset_to_test, Equity)

        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        """Each order-method test algorithm runs when a future is present."""
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 3 is a Future
        asset_to_test = algo.sid(3)
        self.assertIsInstance(asset_to_test, Future)

        algo.run(self.data_portal)

    @parameterized.expand([
        ("order",),
        ("order_value",),
        ("order_percent",),
        ("order_target",),
        ("order_target_percent",),
        ("order_target_value",),
    ])
    def test_order_method_style_forwarding(self, order_style):
        """Each order method is exercised via
        TestOrderStyleForwardingAlgorithm."""
        algo = TestOrderStyleForwardingAlgorithm(
            sim_params=self.sim_params,
            method_name=order_style,
            env=self.env
        )
        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,)
    ])
    def test_minute_data(self, algo_class):
        """Order-method test algorithms run over minute data in a private
        environment and temporary directory."""
        tempdir = TempDirectory()

        try:
            env = TradingEnvironment()

            sim_params = SimulationParameters(
                period_start=pd.Timestamp('2002-1-2', tz='UTC'),
                period_end=pd.Timestamp('2002-1-4', tz='UTC'),
                capital_base=float("1.0e5"),
                data_frequency='minute',
                env=env
            )

            # Both equities stay alive one day past the simulation's end.
            equities_metadata = {}

            for sid in [0, 1]:
                equities_metadata[sid] = {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }

            env.write_data(equities_data=equities_metadata)

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [0, 1]
            )

            algo = algo_class(sim_params=sim_params, env=env)
            algo.run(data_portal)
        finally:
            tempdir.cleanup()
886
887
888
class TestPositions(TestCase):
    """Position-bookkeeping tests over a four-day daily simulation."""

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.sids = [1, 133]
        cls.tempdir = TempDirectory()

        # Every sid is alive for exactly the simulated period.
        equities_metadata = {
            sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end,
            }
            for sid in cls.sids
        }

        cls.env.write_data(equities_data=equities_metadata)

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        cls.tempdir.cleanup()

    def test_empty_portfolio(self):
        algo = EmptyPositionsAlgorithm(self.sids,
                                       sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.data_portal)

        expected_position_count = [
            0,  # Before entering the first position
            2,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        # Positional row lookup: ``.iloc`` replaces the deprecated (and
        # since removed) ``.ix``, which fell back to positions here because
        # the stats index is not integer-valued.
        for i, expected in enumerate(expected_position_count):
            self.assertEqual(daily_stats.iloc[i]['num_positions'],
                             expected)

    def test_noop_orders(self):
        algo = AmbitiousStopLimitAlgorithm(sid=1,
                                           sim_params=self.sim_params,
                                           env=self.env)
        daily_stats = algo.run(self.data_portal)

        # Verify that positions are empty for all dates.
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
        self.assertTrue(empty_positions.all())
947
948
949
class TestAlgoScript(TestCase):
    """Algorithms built from functions or script strings, run over a shared
    251-day daily data portal."""

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        # Single source of truth for the simulation length; the trade
        # histories below must cover the same number of days.
        days = 251
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=days,
                                                              env=cls.env)

        cls.sids = [0, 1, 3, 133]
        cls.tempdir = TempDirectory()

        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

            # sid 3 additionally carries the symbol "TEST".
            if sid == 3:
                equities_metadata[sid]["symbol"] = "TEST"
                equities_metadata[sid]["asset_type"] = "equity"

        cls.env.write_data(equities_data=equities_metadata)

        # Only sids 0 and 3 get trade history, built from constant
        # 10.0 / 100 series at a one-day spacing.
        trades_by_sid = {
            sid: factory.create_trade_history(
                sid,
                [10.0] * days,
                [100] * days,
                timedelta(days=1),
                cls.sim_params,
                cls.env)
            for sid in (0, 3)
        }

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

        cls.zipline_test_config = {
            'sid': 0,
        }

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()
        teardown_logger(cls)

    def test_noop(self):
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.data_portal)

    def test_noop_string(self):
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.data_portal)

    def test_api_calls(self):
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api,
                                env=self.env)
        algo.run(self.data_portal)

    def test_api_calls_string(self):
        algo = TradingAlgorithm(script=api_algo, env=self.env)
        algo.run(self.data_portal)

    def test_api_get_environment(self):
        platform = 'zipline'
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                platform=platform)
        algo.run(self.data_portal)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        algo = TradingAlgorithm(script=api_symbol_algo,
                                env=self.env,
                                sim_params=self.sim_params)
        algo.run(self.data_portal)

    def test_fixed_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1""",
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        # flatten the list of txns
        all_txns = [val for sublist in results["transactions"].tolist()
                    for val in sublist]

        self.assertEqual(len(all_txns), 1)
        txn = all_txns[0]

        self.assertEqual(100.0, txn["commission"])
        expected_spread = 0.05
        expected_commish = 0.10
        expected_price = test_algo.recorded_vars["price"] - expected_spread \
            - expected_commish

        self.assertEqual(expected_price, txn['price'])

    @parameterized.expand(
        [
            ('no_minimum_commission', 0,),
            ('default_minimum_commission', 1,),
            ('alternate_minimum_commission', 2,),
        ]
    )
    def test_volshare_slippage(self, name, minimum_commission):
        tempdir = TempDirectory()
        try:
            if name == "default_minimum_commission":
                # Leave min_trade_cost unset; this case expects the
                # resulting minimum to be 1 (see the parameters above).
                commission_line = "set_commission(commission.PerShare(0.02))"
            else:
                commission_line = \
                    "set_commission(commission.PerShare(0.02, " \
                    "min_trade_cost={0}))".format(minimum_commission)

            # verify order -> transaction -> portfolio position.
            # --------------
            test_algo = TradingAlgorithm(
                script="""
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    {0}

    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """.format(commission_line),
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            trades = factory.create_daily_trade_source(
                [0], self.sim_params, self.env)
            data_portal = create_data_portal_from_trade_history(
                self.env, tempdir, self.sim_params, {0: trades})
            results = test_algo.run(data_portal=data_portal)

            all_txns = [
                val for sublist in results["transactions"].tolist()
                for val in sublist]

            self.assertEqual(len(all_txns), 67)
            first_txn = all_txns[0]

            if minimum_commission == 0:
                commish = first_txn["amount"] * 0.02
                self.assertEqual(commish, first_txn["commission"])
            else:
                self.assertEqual(minimum_commission, first_txn["commission"])
        finally:
            tempdir.cleanup()

    def test_algo_record_vars(self):
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, 252):
            self.assertEqual(results.iloc[i-1]["incr"], i)

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def test_algo_record_nan(self):
        test_algo = TradingAlgorithm(
            script=record_float_magic % 'nan',
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, 252):
            self.assertTrue(np.isnan(results.iloc[i-1]["data"]))

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        test_algo = TradingAlgorithm(
            script=call_all_order_methods,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            test_algo.run(self.data_portal)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_portfolio_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_account_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)
1244
1245
1246
class TestGetDatetime(TestCase):
    """Checks ``get_datetime`` with and without an explicit timezone."""

    @classmethod
    def setUpClass(cls):
        env = TradingEnvironment()
        env.write_data(equities_identifiers=[0, 1])
        cls.env = env

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=env,
            start=to_utc('2014-01-02 9:31'),
            end=to_utc('2014-01-03 9:31')
        )

        cls.tempdir = TempDirectory()
        cls.data_portal = create_data_portal(
            env, cls.tempdir, cls.sim_params, [1]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.tempdir.cleanup()

    @parameterized.expand([
        ('default', None),
        ('utc', 'UTC'),
        ('us_east', 'US/Eastern'),
    ])
    def test_get_datetime(self, name, tz):
        # On its first bar the script checks that get_datetime(tz) reports
        # the requested zone (UTC when tz is None) and 09:31 US/Eastern,
        # raising ValueError otherwise; it then clears first_bar.
        script_source = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        algo = TradingAlgorithm(
            script=script_source,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)
        # first_bar is False only if handle_data actually ran.
        self.assertFalse(algo.first_bar)
1314
1315
1316
class TestTradingControls(TestCase):
    """Tests for trading controls registered on an algorithm.

    Most cases install a ``handle_data`` on a control-configured algorithm,
    run it over a shared four-day daily data portal for ``cls.sid``, and
    check how many orders were counted before (or without) a
    ``TradingControlViolation``.
    """

    @classmethod
    def setUpClass(cls):
        cls.sid = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.env.write_data(equities_data={
            cls.sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
        })

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            [cls.sid]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` installed and check the outcome.

        If ``expected_exc`` is truthy the run must raise it. In all cases
        ``algo.order_count`` must equal ``expected_order_count`` afterwards.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)
        self.assertEqual(algo.order_count, expected_order_count)

    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        # Default for order_count assumes one order per handle_data call.
        self._check_algo(algo, handle_data, order_count, None)

    def check_algo_fails(self, algo, handle_data, order_count):
        self._check_algo(algo,
                         handle_data,
                         order_count,
                         TradingControlViolation)

    def _fixed_order_handler(self, amount):
        """Return a handle_data that orders ``amount`` shares of ``self.sid``.

        ``algo.order_count`` is incremented only after the order call
        returns, so a rejected order is not counted.
        """
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), amount)
            algo.order_count += 1
        return handle_data

    def test_set_max_position_size(self):
        buy_one = self._fixed_order_handler(1)
        buy_three = self._fixed_order_handler(3)
        buy_everything = self._fixed_order_handler(10000)

        # Buy one share four times.  Should be fine.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, buy_one)

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, buy_three, 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, buy_three, 2)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, buy_everything)

        # Set the trading control sid to None, then BUY ALL THE THINGS!. Should
        # fail because setting sid to None makes the control apply to all sids.
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, buy_everything, 0)

    def test_set_do_not_order_list(self):
        buy_hundred = self._fixed_order_handler(100)

        # set the restricted list to be the sid, and fail.
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_fails(algo, buy_hundred, 0)

        # set the restricted list to exclude the sid, and succeed
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_succeeds(algo, buy_hundred)

    def test_set_max_order_size(self):
        buy_everything = self._fixed_order_handler(10000)

        # Buy 1, then 2, then 3, then 4 shares (one more each bar).
        def buy_increasing(algo, data):
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        # Buy one share.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, self._fixed_order_handler(1))

        # Bail on the last attempt because we exceed shares.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=3,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, buy_increasing, 3)

        # Bail on the last attempt because we exceed notional.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=40.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, buy_increasing, 3)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
                                        max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, buy_everything)

        # Set the trading control sid to None, then BUY ALL THE THINGS!.
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        algo = SetMaxOrderSizeAlgorithm(max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, buy_everything, 0)

    def test_set_max_order_count(self):
        tempdir = TempDirectory()
        try:
            env = TradingEnvironment()
            sim_params = factory.create_simulation_parameters(
                num_days=4, env=env, data_frequency="minute")

            env.write_data(equities_data={
                1: {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }
            })

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [1]
            )

            def handle_data(algo, data):
                for _ in range(5):
                    algo.order(algo.sid(1), 1)
                    algo.order_count += 1

            algo = SetMaxOrderCountAlgorithm(3, sim_params=sim_params,
                                             env=env)
            with self.assertRaises(TradingControlViolation):
                algo._handle_data = handle_data
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 3)

            # This time, order 5 times twice in a single day. The last order
            # of the second batch should fail.
            def handle_data2(algo, data):
                if algo.minute_count == 0 or algo.minute_count == 100:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            algo = SetMaxOrderCountAlgorithm(9, sim_params=sim_params,
                                             env=env)
            with self.assertRaises(TradingControlViolation):
                algo._handle_data = handle_data2
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 9)

            def handle_data3(algo, data):
                if (algo.minute_count % 390) == 0:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            # Only 5 orders are placed per day, so this should pass even
            # though in total more than 20 orders are placed.
            algo = SetMaxOrderCountAlgorithm(5, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data3
            algo.run(data_portal)
        finally:
            tempdir.cleanup()

    def test_long_only(self):
        # Sell immediately -> fail immediately.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, self._fixed_order_handler(-1), 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        def handle_data(algo, data):
            if (algo.order_count % 2) == 0:
                algo.order(algo.sid(self.sid), 1)
            else:
                algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        def scheduled_orders(amounts):
            # Order amounts[k] shares on the k-th handle_data call.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amounts[algo.order_count])
                algo.order_count += 1
            return handle_data

        # Buy on first three days, then sell off holdings.  Should succeed.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, scheduled_orders([1, 1, 1, -3]))

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, scheduled_orders([1, 1, 1, -4]), 3)

    def test_register_post_init(self):

        def initialize(algo):
            algo.initialized = True

        def handle_data(algo, data):
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_position_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_count(1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_long_only()

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def _run_asset_date_bounds_algo(self, start_date, end_date,
                                    expected_exc=None):
        """Run SetAssetDateBoundsAlgorithm on sid 0 with the given lifetime.

        A throwaway TradingEnvironment and data portal are built for each
        call, and the temporary directory is removed even if the run fails.
        If ``expected_exc`` is given, the run must raise it.
        """
        tempdir = TempDirectory()
        try:
            temp_env = TradingEnvironment()

            data_portal = create_data_portal(
                temp_env,
                tempdir,
                self.sim_params,
                [0]
            )

            metadata = {0: {'start_date': start_date,
                            'end_date': end_date}}

            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )
            expectation = (self.assertRaises(expected_exc) if expected_exc
                           else nullctx())
            with expectation:
                algo.run(data_portal)
        finally:
            tempdir.cleanup()

    def test_asset_date_bounds(self):
        # Run the algorithm with a sid that ends far in the future
        self._run_asset_date_bounds_algo(self.sim_params.period_start,
                                         '2020-01-01')

        # Run the algorithm with a sid that has already ended
        self._run_asset_date_bounds_algo('1989-01-01', '1990-01-01',
                                         TradingControlViolation)

        # Run the algorithm with a sid that has not started
        self._run_asset_date_bounds_algo('2020-01-01', '2021-01-01',
                                         TradingControlViolation)
1720
1721
1722
class TestAccountControls(TestCase):
    """Tests for account-level controls (e.g. ``set_max_leverage``)."""

    @classmethod
    def setUpClass(cls):
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=4, env=cls.env
        )

        # Register the equity under ``cls.sidint`` (the same sid used for the
        # trade history and the orders below) rather than a duplicated
        # literal, so the metadata cannot drift out of sync with the tests.
        cls.env.write_data(equities_data={
            cls.sidint: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end + timedelta(days=1)
            }
        })

        cls.tempdir = TempDirectory()

        trades_by_sid = {
            cls.sidint: factory.create_trade_history(
                cls.sidint,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env,
            )
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid,
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self, algo, handle_data, expected_exc):
        """Install ``handle_data`` on ``algo`` and run it.

        :param algo: the algorithm instance to run.
        :param handle_data: callable ``(algo, data)`` used as the per-bar hook.
        :param expected_exc: exception type the run must raise, or a falsy
            value if the run must complete without raising.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)

    def check_algo_succeeds(self, algo, handle_data):
        """Assert that running ``algo`` with ``handle_data`` does not raise."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Assert that running ``algo`` raises ``AccountControlViolation``."""
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):
        # On every handle_data call, order one share of the test equity.
        # The hook is stateless, so both algorithms below can share it.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sidint), 1)

        # Max leverage 0: buying even one share must violate the control.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Max leverage 1: buying one share stays within the limit.
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)
1797
1798
1799
class TestClosePosAlgo(TestCase):
    """Check positions and P&L for a future that has an auto-close date."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-1]
        )

        trades_by_sid = {
            cls.sid: factory.create_trade_history(
                cls.sid,
                [1, 1, 2, 4],
                [1e9, 1e9, 1e9, 1e9],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_auto_close_future(self):
        # Both end_date and auto_close_date are the trading day after the
        # last simulated day.
        day_after_sim = self.env.next_trading_day(
            self.sim_params.trading_days[-1])
        self.env.write_data(futures_data={
            self.sid: {
                "start_date": self.sim_params.trading_days[0],
                "end_date": day_after_sim,
                'symbol': 'TEST',
                'asset_type': 'future',
                'auto_close_date': day_after_sim,
            }
        })

        algo = TestAlgorithm(sid=self.sid, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env,
                             sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        expected_pnl = [0, 0, 1, 2]
        self.check_algo_pnl(results, expected_pnl)

        expected_positions = [0, 1, 1, 0]
        self.check_algo_positions(results, expected_positions)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the per-day ``pnl`` column matches ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the per-day position amount matches ``expected_positions``.

        A day with no open positions counts as an amount of 0. Only the
        first position of each day is inspected.
        """
        # Check lengths first so a mismatch reports clearly instead of
        # raising IndexError or silently skipping trailing days.
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of days in results ({0}) != expected ({1})".format(
                len(results.positions), len(expected_positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            actual_position = day_positions[0]['amount'] if day_positions \
                else 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))
1881
1882
1883
class TestFutureFlip(TestCase):
    """Check positions and P&L for a futures position that goes long then
    short (see ``expected_positions`` in ``test_flip_algo``)."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        # Simulate only the first three days (end=days[-2]), matching the
        # three prices in the trade history below.
        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-2]
        )

        trades = factory.create_trade_history(
            cls.sid,
            [1, 2, 4],
            [1e9, 1e9, 1e9],
            timedelta(days=1),
            cls.sim_params,
            cls.env
        )

        trades_by_sid = {
            cls.sid: trades
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_flip_algo(self):
        # Key the future by ``self.sid`` so metadata, trade history and the
        # algorithm all refer to the same asset.
        metadata = {self.sid: {'symbol': 'TEST',
                               'start_date': self.sim_params.trading_days[0],
                               'end_date': self.env.next_trading_day(
                                   self.sim_params.trading_days[-1]),
                               'contract_multiplier': 5}}
        self.env.write_data(futures_data=metadata)

        algo = FutureFlipAlgo(sid=self.sid, amount=1, env=self.env,
                              commission=PerShare(0),
                              order_count=0,  # not applicable but required
                              sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        expected_positions = [0, 1, -1]
        self.check_algo_positions(results, expected_positions)

        expected_pnl = [0, 5, -10]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the per-day ``pnl`` column matches ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the per-day position amount matches ``expected_positions``.

        A day with no open positions counts as an amount of 0. Only the
        first position of each day is inspected.
        """
        # Check lengths first so a mismatch reports clearly instead of
        # raising IndexError or silently skipping trailing days.
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of days in results ({0}) != expected ({1})".format(
                len(results.positions), len(expected_positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            actual_position = day_positions[0]['amount'] if day_positions \
                else 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))
1958
1959
1960
class TestTradingAlgorithm(TestCase):
1961
    def setUp(self):
        """Build a fresh trading environment for each test and keep the
        first four trading days of its calendar on ``self.days``."""
        env = TradingEnvironment()
        self.env = env
        self.days = env.trading_days[:4]
1964
1965
    def test_analyze_called(self):
        """The ``analyze`` hook must receive the same object ``run`` returns."""
        self.perf_ref = None

        def noop_initialize(context):
            pass

        def noop_handle_data(context, data):
            pass

        def capture_perf(context, perf):
            # Stash whatever analyze() is handed so it can be compared, by
            # identity, against run()'s return value below.
            self.perf_ref = perf

        algorithm = TradingAlgorithm(
            initialize=noop_initialize,
            handle_data=noop_handle_data,
            analyze=capture_perf,
        )

        run_output = algorithm.run(FakeDataPortal())
        self.assertIs(run_output, self.perf_ref)
1984