Completed
Pull Request — master (#858)
by Eddie
01:35
created

tests.TestTradingControls.check_algo_succeeds()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 3
rs 10
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range
20
from testfixtures import TempDirectory
21
from textwrap import dedent
22
from unittest import TestCase
23
24
import numpy as np
25
import pandas as pd
26
27
from zipline.assets import Equity, Future
28
from zipline.utils.api_support import ZiplineAPI
29
from zipline.utils.control_flow import nullctx
30
from zipline.utils.test_utils import (
31
    setup_logger,
32
    teardown_logger,
33
    FakeDataPortal)
34
import zipline.utils.factory as factory
35
36
from zipline.errors import (
37
    OrderDuringInitialize,
38
    RegisterTradingControlPostInit,
39
    TradingControlViolation,
40
    AccountControlViolation,
41
    SymbolNotFound,
42
    RootSymbolNotFound,
43
    UnsupportedDatetimeFormat,
44
)
45
from zipline.test_algorithms import (
46
    access_account_in_init,
47
    access_portfolio_in_init,
48
    AmbitiousStopLimitAlgorithm,
49
    EmptyPositionsAlgorithm,
50
    InvalidOrderAlgorithm,
51
    RecordAlgorithm,
52
    FutureFlipAlgo,
53
    TestAlgorithm,
54
    TestOrderAlgorithm,
55
    TestOrderPercentAlgorithm,
56
    TestOrderStyleForwardingAlgorithm,
57
    TestOrderValueAlgorithm,
58
    TestRegisterTransformAlgorithm,
59
    TestTargetAlgorithm,
60
    TestTargetPercentAlgorithm,
61
    TestTargetValueAlgorithm,
62
    SetLongOnlyAlgorithm,
63
    SetAssetDateBoundsAlgorithm,
64
    SetMaxPositionSizeAlgorithm,
65
    SetMaxOrderCountAlgorithm,
66
    SetMaxOrderSizeAlgorithm,
67
    SetDoNotOrderListAlgorithm,
68
    SetMaxLeverageAlgorithm,
69
    api_algo,
70
    api_get_environment_algo,
71
    api_symbol_algo,
72
    call_all_order_methods,
73
    call_order_in_init,
74
    handle_data_api,
75
    handle_data_noop,
76
    initialize_api,
77
    initialize_noop,
78
    noop_algo,
79
    record_float_magic,
80
    record_variables,
81
)
82
from zipline.utils.context_tricks import CallbackManager
83
import zipline.utils.events
84
from zipline.assets import Equity
85
from zipline.utils.test_utils import (
86
    assert_single_position,
87
    drain_zipline,
88
    to_utc,
89
)
90
91
92
from zipline.finance.execution import LimitOrder
93
from zipline.finance.trading import SimulationParameters
94
from zipline.utils.api_support import set_algo_instance
95
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
96
from zipline.algorithm import TradingAlgorithm
97
from zipline.finance.trading import TradingEnvironment
98
from zipline.finance.commission import PerShare
99
100
from zipline.utils.test_utils import (
101
    create_data_portal,
102
    create_data_portal_from_trade_history
103
)
104
105
# Because test cases appear to reuse some resources.
106
107
108
_multiprocess_can_split_ = False
109
110
111
class TestRecordAlgorithm(TestCase):
    """Checks that the variables recorded by RecordAlgorithm show up as
    columns of the run output."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.sids = [133]
        cls.env.write_data(equities_identifiers=cls.sids)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=4,
            env=cls.env
        )

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def test_record_incr(self):
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        output = algo.run(self.data_portal)

        # 'incr', 'name' and 'name3' are expected to count 1, 2, ..., N over
        # the N output rows, while 'name2' is expected to stay constant at 2.
        counting_up = range(1, len(output) + 1)

        np.testing.assert_array_equal(output['incr'].values, counting_up)
        np.testing.assert_array_equal(output['name'].values, counting_up)
        np.testing.assert_array_equal(output['name2'].values,
                                      [2] * len(output))
        np.testing.assert_array_equal(output['name3'].values, counting_up)
150
151
152
class TestMiscellaneousAPI(TestCase):
    """Exercises assorted TradingAlgorithm API methods: environment
    reporting, open orders, scheduling, event contexts and asset/future
    lookups."""

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        # Two equities share the symbol 'PLAY' over non-overlapping date
        # ranges, so symbol lookups depend on the as-of date.
        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='daily',
            env=cls.env,
        )

        cls.temp_dir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.temp_dir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.temp_dir.cleanup()

    def test_zipline_api_resolves_dynamically(self):
        # Make a dummy algo.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
        )

        # Verify that api methods get resolved dynamically by patching them
        # out and then calling them.
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            # Called within this same iteration, so the late-bound
            # ``sentinel`` is always the current one.
            def fake_method(*args, **kwargs):
                return sentinel

            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_get_environment(self):
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_get_open_orders(self):
        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The market order for sid 1 should have filled; the three
                # limit orders for sid 2 should all still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_schedule_function(self):
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.data_portal)

        # num_days=2 in minute mode: presumably 2 sessions x 390 minute bars.
        bar_count = 780
        self.assertEqual(len(expected_data), bar_count)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        # Each bar runs pre, handle_data, f, g and post exactly once.
        expected_functions = [pre, handle_data, f, g, post] * bar_count
        self.assertEqual(
            len(function_stack),
            len(expected_functions),
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), len(expected_functions)),
        )
        # Use distinct loop names so the local functions f and g are not
        # shadowed.
        for n, (actual, expected) in enumerate(
                zip(function_stack, expected_functions)):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_schedule_function_rule_creation(self, mode):
        def nop(*args, **kwargs):
            return None

        # self.sim_params is shared by every test in this class, and the
        # algorithm below keeps reading it, so restore the original
        # data_frequency only once this test has finished.
        self.addCleanup(setattr, self.sim_params, 'data_frequency',
                        self.sim_params.data_frequency)
        self.sim_params.data_frequency = mode
        algo = TradingAlgorithm(
            initialize=nop,
            handle_data=nop,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Schedule something for NOT Always.
        algo.schedule_function(nop, time_rule=zipline.utils.events.Never())

        event_rule = algo.event_manager._events[1].rule

        self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

        inner_rule = event_rule.rule
        self.assertIsInstance(inner_rule, zipline.utils.events.ComposedRule)

        first = inner_rule.first
        second = inner_rule.second
        composer = inner_rule.composer

        self.assertIsInstance(first, zipline.utils.events.Always)

        if mode == 'daily':
            self.assertIsInstance(second, zipline.utils.events.Always)
        else:
            self.assertIsInstance(second, zipline.utils.events.Never)

        self.assertIs(composer, zipline.utils.events.ComposedRule.lazy_and)

    def test_asset_lookup(self):

        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        for bad_input in (1, (1,), {1}, [1], {'foo': 'bar'}):
            with self.assertRaises(TypeError):
                algo.symbol(bad_input)

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        # Empty, equity and unknown symbols are not futures.
        for missing_symbol in ('', 'PLAY', 'FOOBAR'):
            with self.assertRaises(SymbolNotFound):
                algo.future_symbol(missing_symbol)

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        for bad_input in (1, (1,), {1}, [1], {'foo': 'bar'}):
            with self.assertRaises(TypeError):
                algo.future_symbol(bad_input)

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        for bad_input in (1, (1,), {1}, [1], {'foo': 'bar'}):
            with self.assertRaises(TypeError):
                algo.future_chain(bad_input)

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Sids are numbered i + 3. This test writes them into its own fresh
        # TradingEnvironment, so they cannot collide with the class-level
        # assets created in setUpClass.
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        # Create two assets with the same symbol but different
        # non-overlapping date ranges.
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the period end
        # dates for our assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end date is used as
        # the as_of_date; only the symbol of the result is checked here.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
644
645
646
class TestTransformAlgorithm(TestCase):
    """Runs the order-method and transform test algorithms against a small
    daily data portal for equities 0, 1, 133 and future 3."""

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)
        cls.sids = [0, 1, 133]
        cls.tempdir = TempDirectory()

        futures_metadata = {3: {'contract_multiplier': 10}}
        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

        cls.env.write_data(equities_data=equities_metadata,
                           futures_data=futures_metadata)

        trades_by_sid = {}
        for sid in cls.sids:
            trades_by_sid[sid] = factory.create_trade_history(
                sid,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

        cls.data_portal.current_day = cls.sim_params.trading_days[0]

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        del cls.env
        cls.tempdir.cleanup()

    def test_invalid_order_parameters(self):
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

    def test_run_twice(self):
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.data_portal)

        # Create a second, independent trading algorithm. Neither algorithm
        # is given an env, so presumably each builds its own environment.
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.data_portal)

        # FIXME I think we are getting NaNs due to the fixed benchmark, so
        # forward-fill them for now before comparing.
        res1 = res1.fillna(method='ffill')
        res2 = res2.fillna(method='ffill')

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        # Each case builds its own sim_params, so the class-shared
        # self.sim_params is deliberately left untouched.
        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='daily')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'daily')

        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='minute')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'minute')

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 0 is an Equity
        asset_to_test = algo.sid(0)
        self.assertIsInstance(asset_to_test, Equity)

        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 3 is a Future
        asset_to_test = algo.sid(3)
        self.assertIsInstance(asset_to_test, Future)

        algo.run(self.data_portal)

    @parameterized.expand([
        ("order",),
        ("order_value",),
        ("order_percent",),
        ("order_target",),
        ("order_target_percent",),
        ("order_target_value",),
    ])
    def test_order_method_style_forwarding(self, order_style):
        algo = TestOrderStyleForwardingAlgorithm(
            sim_params=self.sim_params,
            method_name=order_style,
            env=self.env
        )
        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,)
    ])
    def test_minute_data(self, algo_class):
        tempdir = TempDirectory()

        try:
            env = TradingEnvironment()

            sim_params = SimulationParameters(
                period_start=pd.Timestamp('2002-1-2', tz='UTC'),
                period_end=pd.Timestamp('2002-1-4', tz='UTC'),
                capital_base=float("1.0e5"),
                data_frequency='minute',
                env=env
            )

            equities_metadata = {}

            for sid in [0, 1]:
                equities_metadata[sid] = {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }

            env.write_data(equities_data=equities_metadata)

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [0, 1]
            )

            algo = algo_class(sim_params=sim_params, env=env)
            algo.run(data_portal)
        finally:
            tempdir.cleanup()
840
841
842
class TestPositions(TestCase):
    """Checks the position counts reported in daily stats."""

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.sids = [1, 133]
        cls.tempdir = TempDirectory()

        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

        cls.env.write_data(equities_data=equities_metadata)

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        # Release the shared environment, as the other test classes do.
        del cls.env
        cls.tempdir.cleanup()

    def test_empty_portfolio(self):
        algo = EmptyPositionsAlgorithm(self.sids,
                                       sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.data_portal)

        expected_position_count = [
            0,  # Before entering the first position
            2,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        # Rows are looked up by position; ``.iloc`` does this explicitly
        # instead of relying on the deprecated ``.ix`` fallback.
        for i, expected in enumerate(expected_position_count):
            self.assertEqual(daily_stats.iloc[i]['num_positions'],
                             expected)

    def test_noop_orders(self):
        algo = AmbitiousStopLimitAlgorithm(sid=1,
                                           sim_params=self.sim_params,
                                           env=self.env)
        daily_stats = algo.run(self.data_portal)

        # Verify that positions are empty for all dates.
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
        self.assertTrue(empty_positions.all())
901
902
903
class TestAlgoScript(TestCase):
    """Run algorithms defined via callbacks or script source end to end.

    One TradingEnvironment and one data portal (trade history for sids 0
    and 3) are built in ``setUpClass`` and shared by every test.
    """

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        # Length of the simulation in trading days.  The fake trade history
        # below and the per-day checks in the ``record`` tests are sized
        # from this single value so they cannot drift apart.
        cls.num_days = 251
        cls.sim_params = factory.create_simulation_parameters(
            num_days=cls.num_days, env=cls.env)

        cls.sids = [0, 1, 3, 133]
        cls.tempdir = TempDirectory()

        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

            # sid 3 additionally gets a symbol; presumably so symbol-based
            # lookups (e.g. ``api_symbol_algo``) can resolve it.
            if sid == 3:
                equities_metadata[sid]["symbol"] = "TEST"
                equities_metadata[sid]["asset_type"] = "equity"

        cls.env.write_data(equities_data=equities_metadata)

        # Flat history (10.0 / 100 repeated every day) for sids 0 and 3 only.
        days = cls.num_days
        trades_by_sid = {
            sid: factory.create_trade_history(
                sid,
                [10.0] * days,
                [100] * days,
                timedelta(days=1),
                cls.sim_params,
                cls.env)
            for sid in (0, 3)
        }

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

        cls.zipline_test_config = {
            'sid': 0,
        }

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()
        teardown_logger(cls)

    def test_noop(self):
        """Run an algorithm built from the no-op callback functions."""
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.data_portal)

    def test_noop_string(self):
        """Run an algorithm built from the ``noop_algo`` script source."""
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.data_portal)

    def test_api_calls(self):
        """Run the ``initialize_api``/``handle_data_api`` callbacks."""
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api,
                                env=self.env)
        algo.run(self.data_portal)

    def test_api_calls_string(self):
        """Run the ``api_algo`` script source."""
        algo = TradingAlgorithm(script=api_algo, env=self.env)
        algo.run(self.data_portal)

    def test_api_get_environment(self):
        """The ``platform`` argument is exposed as ``algo.environment``."""
        platform = 'zipline'
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                platform=platform)
        algo.run(self.data_portal)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        """Run the ``api_symbol_algo`` script source."""
        algo = TradingAlgorithm(script=api_symbol_algo,
                                env=self.env,
                                sim_params=self.sim_params)
        algo.run(self.data_portal)

    def test_fixed_slippage(self):
        """One sell of 1000 shares of sid 0 under FixedSlippage(0.10) and a
        100.00 PerTrade commission yields exactly one transaction, whose
        commission and price are checked below."""
        # verify order -> transaction -> portfolio position.
        # --------------
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1""",
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        # flatten the list of txns
        all_txns = [val for sublist in results["transactions"].tolist()
                    for val in sublist]

        self.assertEqual(len(all_txns), 1)
        txn = all_txns[0]

        self.assertEqual(100.0, txn["commission"])
        # Half of the 0.10 spread is applied against the (sell) fill.
        expected_spread = 0.05
        # NOTE(review): an additional 0.10 is subtracted from the fill price
        # even though the PerTrade commission is charged separately above;
        # confirm what this offset represents.
        expected_commish = 0.10
        expected_price = (test_algo.recorded_vars["price"] -
                          expected_spread - expected_commish)

        self.assertEqual(expected_price, txn['price'])

    def test_volshare_slippage(self):
        """Two 5000-share orders under VolumeShareSlippage(volume_limit=.3,
        price_impact=0.05) with a 0.02 PerShare commission fill across 67
        transactions; the first transaction's commission and price are
        checked."""
        # The data portal for this test lives in its own temporary
        # directory, removed when the ``with`` block exits (even on failure).
        with TempDirectory() as tempdir:
            # verify order -> transaction -> portfolio position.
            # --------------
            test_algo = TradingAlgorithm(
                script="""
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    set_commission(commission.PerShare(0.02))
    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """,
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            trades = factory.create_daily_trade_source(
                [0], self.sim_params, self.env)
            data_portal = create_data_portal_from_trade_history(
                self.env, tempdir, self.sim_params, {0: trades})
            results = test_algo.run(data_portal=data_portal)

            all_txns = [
                val for sublist in results["transactions"].tolist()
                for val in sublist]

            self.assertEqual(len(all_txns), 67)

            per_share_commish = 0.02
            first_txn = all_txns[0]
            commish = first_txn["amount"] * per_share_commish
            self.assertEqual(commish, first_txn["commission"])
            self.assertEqual(2.029, first_txn["price"])

    def test_algo_record_vars(self):
        """The ``incr`` value recorded on day i (1-based) equals i."""
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, self.num_days + 1):
            self.assertEqual(results.iloc[i-1]["incr"], i)

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def test_algo_record_nan(self):
        """Recording NaN every day yields a NaN ``data`` column."""
        test_algo = TradingAlgorithm(
            script=record_float_magic % 'nan',
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, self.num_days + 1):
            self.assertTrue(np.isnan(results.iloc[i-1]["data"]))

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        test_algo = TradingAlgorithm(
            script=call_all_order_methods,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            test_algo.run(self.data_portal)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_portfolio_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_account_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)
1177
1178
1179
class TestGetDatetime(TestCase):
    """``get_datetime`` inside a minute-frequency script returns the bar time
    in the requested timezone (UTC when no timezone is given)."""

    @classmethod
    def setUpClass(cls):
        env = TradingEnvironment()
        cls.env = env
        env.write_data(equities_identifiers=[0, 1])

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=env,
            start=to_utc('2014-01-02 9:31'),
            end=to_utc('2014-01-03 9:31'),
        )

        cls.tempdir = TempDirectory()
        cls.data_portal = create_data_portal(
            env, cls.tempdir, cls.sim_params, [1]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.tempdir.cleanup()

    @parameterized.expand([
        ('default', None),
        ('utc', 'UTC'),
        ('us_east', 'US/Eastern'),
    ])
    def test_get_datetime(self, name, tz):
        # The timezone is spliced into the script as a Python literal
        # (``None`` or a quoted zone name); the script raises on the first
        # bar if the zone, hour or minute of ``get_datetime`` is wrong.
        script_source = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        trading_algo = TradingAlgorithm(
            script=script_source,
            sim_params=self.sim_params,
            env=self.env,
        )
        trading_algo.run(self.data_portal)
        # handle_data cleared the flag, i.e. at least one bar was processed.
        self.assertFalse(trading_algo.first_bar)
1247
1248
1249
class TestTradingControls(TestCase):
    """Tests for order-level trading controls (position/order size limits,
    restricted lists, order counts, long-only, asset date bounds)."""

    @classmethod
    def setUpClass(cls):
        cls.sid = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.env.write_data(equities_data={
            cls.sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
        })

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            [cls.sid]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """Install ``handle_data`` on ``algo``, run it, and check the outcome.

        When ``expected_exc`` is truthy the run must raise it; otherwise the
        run must complete. In both cases ``algo.order_count`` must equal
        ``expected_order_count`` afterwards.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)
        self.assertEqual(algo.order_count, expected_order_count)

    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        """Assert the run completes and ``order_count`` orders were counted."""
        # Default for order_count assumes one order per handle_data call.
        self._check_algo(algo, handle_data, order_count, None)

    def check_algo_fails(self, algo, handle_data, order_count):
        """Assert the run raises TradingControlViolation after exactly
        ``order_count`` orders were counted."""
        self._check_algo(algo,
                         handle_data,
                         order_count,
                         TradingControlViolation)

    def test_set_max_position_size(self):

        # Buy one share four times.  Should be fine.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 1)
            algo.order_count += 1
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 3)
            algo.order_count += 1

        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 3)
            algo.order_count += 1

        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, handle_data, 2)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Set the trading control sid to None, then BUY ALL THE THINGS!. Should
        # fail because setting sid to None makes the control apply to all sids.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, handle_data, 0)

    def test_set_do_not_order_list(self):
        # set the restricted list to be the sid, and fail.
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )

        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 100)
            algo.order_count += 1

        self.check_algo_fails(algo, handle_data, 0)

        # set the restricted list to exclude the sid, and succeed
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )

        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 100)
            algo.order_count += 1

        self.check_algo_succeeds(algo, handle_data)

    def test_set_max_order_size(self):

        # Buy one share.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 1)
            algo.order_count += 1
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed shares.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=3,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed notional.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=40.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
                                        max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Set the trading control sid to None, then BUY ALL THE THINGS!.
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxOrderSizeAlgorithm(max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, handle_data, 0)

    def test_set_max_order_count(self):
        """Exceeding the per-day order count raises TradingControlViolation.

        Uses a dedicated minute-frequency environment and data portal
        (sid 1) living in a temporary directory that is removed when the
        ``with`` block exits.
        """
        with TempDirectory() as tempdir:
            env = TradingEnvironment()
            sim_params = factory.create_simulation_parameters(
                num_days=4, env=env, data_frequency="minute")

            env.write_data(equities_data={
                1: {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }
            })

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [1]
            )

            # Five orders on the very first bar against a limit of 3: the
            # fourth order must be rejected.
            def handle_data(algo, data):
                for _ in range(5):
                    algo.order(algo.sid(1), 1)
                    algo.order_count += 1

            algo = SetMaxOrderCountAlgorithm(3, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data
            with self.assertRaises(TradingControlViolation):
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 3)

            # This time, order 5 times twice in a single day. The last order
            # of the second batch should fail.
            def handle_data2(algo, data):
                if algo.minute_count == 0 or algo.minute_count == 100:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            algo = SetMaxOrderCountAlgorithm(9, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data2
            with self.assertRaises(TradingControlViolation):
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 9)

            def handle_data3(algo, data):
                if (algo.minute_count % 390) == 0:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            # One batch of 5 orders every 390 minutes (once per full
            # session), so no single day exceeds the limit of 5 and this
            # should pass even though the run's total order count does.
            algo = SetMaxOrderCountAlgorithm(5, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data3
            algo.run(data_portal)

    def test_long_only(self):
        # Sell immediately -> fail immediately.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, handle_data, 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        def handle_data(algo, data):
            if (algo.order_count % 2) == 0:
                algo.order(algo.sid(self.sid), 1)
            else:
                algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings.  Should succeed.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -3]
            algo.order(algo.sid(self.sid), amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -4]
            algo.order(algo.sid(self.sid), amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

    def test_register_post_init(self):

        def initialize(algo):
            algo.initialized = True

        # Every trading-control setter must refuse to register once
        # initialize has finished.
        def handle_data(algo, data):
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_position_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_count(1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_long_only()

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def _run_asset_date_bounds_algo(self, metadata, expected_exc=None):
        """Run SetAssetDateBoundsAlgorithm for sid 0 with ``metadata``.

        Each call builds its own TradingEnvironment and a data portal for
        sid 0 in a temporary directory, removed when the ``with`` block
        exits. If ``expected_exc`` is given the run must raise it;
        otherwise it must complete.
        """
        with TempDirectory() as tempdir:
            temp_env = TradingEnvironment()

            data_portal = create_data_portal(
                temp_env,
                tempdir,
                self.sim_params,
                [0]
            )

            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )

            expectation = (self.assertRaises(expected_exc)
                           if expected_exc else nullctx())
            with expectation:
                algo.run(data_portal)

    def test_asset_date_bounds(self):
        # Run the algorithm with a sid that ends far in the future
        self._run_asset_date_bounds_algo(
            {0: {'start_date': self.sim_params.period_start,
                 'end_date': '2020-01-01'}},
        )

        # Run the algorithm with a sid that has already ended
        self._run_asset_date_bounds_algo(
            {0: {'start_date': '1989-01-01',
                 'end_date': '1990-01-01'}},
            TradingControlViolation,
        )

        # Run the algorithm with a sid that has not started
        self._run_asset_date_bounds_algo(
            {0: {'start_date': '2020-01-01',
                 'end_date': '2021-01-01'}},
            TradingControlViolation,
        )
1653
1654
1655
class TestAccountControls(TestCase):
    """Tests for account-level controls such as ``set_max_leverage``."""

    @classmethod
    def setUpClass(cls):
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=4, env=cls.env
        )

        cls.env.write_data(equities_data={
            cls.sidint: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end + timedelta(days=1)
            }
        })

        cls.tempdir = TempDirectory()

        trades_by_sid = {
            cls.sidint: factory.create_trade_history(
                cls.sidint,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env,
            )
        }

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_exc):
        """Install ``handle_data`` on ``algo`` and run it.

        When ``expected_exc`` is truthy the run must raise it; otherwise the
        run must complete without raising.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)

    def check_algo_succeeds(self, algo, handle_data):
        """Assert that ``algo`` runs to completion with ``handle_data``."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Assert that running ``algo`` raises AccountControlViolation."""
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):

        # Both scenarios buy one share of the test sid on every bar.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sidint), 1)

        # Set max leverage to 0 so buying one share fails.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Set max leverage to 1 so buying one share passes
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)
1730
1731
1732
class TestClosePosAlgo(TestCase):
    """Test that a future position is closed out ahead of its auto_close_date.

    The asset's ``auto_close_date`` is set to the trading day after the
    simulation ends. The expected daily positions therefore end at 0.
    """

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-1]
        )

        # One trade per day for the single test sid (the two lists are
        # presumably prices and volumes -- confirm against
        # factory.create_trade_history).
        trades_by_sid = {}
        trades_by_sid[cls.sid] = factory.create_trade_history(
            cls.sid,
            [1, 1, 2, 4],
            [1e9, 1e9, 1e9, 1e9],
            timedelta(days=1),
            cls.sim_params,
            cls.env
        )

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_auto_close_future(self):
        # Register the sid as a future whose end and auto-close dates both
        # fall on the first trading day after the simulation.
        self.env.write_data(futures_data={
            self.sid: {
                "start_date": self.sim_params.trading_days[0],
                "end_date": self.env.next_trading_day(
                    self.sim_params.trading_days[-1]),
                'symbol': 'TEST',
                'asset_type': 'future',
                'auto_close_date': self.env.next_trading_day(
                    self.sim_params.trading_days[-1])
            }
        })

        algo = TestAlgorithm(sid=self.sid, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env,
                             sim_params=self.sim_params)

        # Check results
        results = algo.run(self.data_portal)

        expected_pnl = [0, 0, 1, 2]
        self.check_algo_pnl(results, expected_pnl)

        expected_positions = [0, 1, 1, 0]
        self.check_algo_positions(results, expected_positions)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the daily ``results.pnl`` is almost equal to ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the daily position amount matches ``expected_positions``.

        Each entry of ``results.positions`` is read as a sequence of position
        records. An empty entry counts as a flat (0) position; otherwise the
        first record's ``'amount'`` is compared.
        """
        # Without this check, results with fewer days than expected would be
        # silently accepted, because the loop below only walks the results.
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of result days={0} != number of expected positions={1}".
            format(len(results.positions), len(expected_positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))
1814
1815
1816
class TestFutureFlip(TestCase):
    """Test FutureFlipAlgo moving a single future from long (1) to short (-1)."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        # The simulation ends on the second-to-last day of ``cls.days``,
        # giving three days to match the three trades below.
        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-2]
        )

        trades = factory.create_trade_history(
            cls.sid,
            [1, 2, 4],
            [1e9, 1e9, 1e9],
            timedelta(days=1),
            cls.sim_params,
            cls.env
        )

        trades_by_sid = {
            cls.sid: trades
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_flip_algo(self):
        metadata = {self.sid: {'symbol': 'TEST',
                               'start_date': self.sim_params.trading_days[0],
                               'end_date': self.env.next_trading_day(
                                   self.sim_params.trading_days[-1]),
                               'contract_multiplier': 5}}
        self.env.write_data(futures_data=metadata)

        algo = FutureFlipAlgo(sid=self.sid, amount=1, env=self.env,
                              commission=PerShare(0),
                              order_count=0,  # not applicable but required
                              sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        expected_positions = [0, 1, -1]
        self.check_algo_positions(results, expected_positions)

        expected_pnl = [0, 5, -10]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the daily ``results.pnl`` is almost equal to ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the daily position amount matches ``expected_positions``.

        Each entry of ``results.positions`` is read as a sequence of position
        records. An empty entry counts as a flat (0) position; otherwise the
        first record's ``'amount'`` is compared.
        """
        # Without this check, results with fewer days than expected would be
        # silently accepted, because the loop below only walks the results.
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of result days={0} != number of expected positions={1}".
            format(len(results.positions), len(expected_positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))
1891
1892
1893
class TestTradingAlgorithm(TestCase):
1894
    def setUp(self):
        """Create a fresh TradingEnvironment and keep its first four days."""
        env = TradingEnvironment()
        self.env = env
        self.days = env.trading_days[:4]
1897
1898
    def test_analyze_called(self):
        """Check that run() returns the same object it passes to ``analyze``."""
        self.perf_ref = None

        def initialize(context):
            # Nothing to set up for this check.
            pass

        def handle_data(context, data):
            # Deliberately a no-op; only the analyze hook matters here.
            pass

        def analyze(context, perf):
            # Capture whatever the algorithm hands to its analyze hook.
            self.perf_ref = perf

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            analyze=analyze,
        )

        results = algo.run(FakeDataPortal())
        self.assertIs(results, self.perf_ref)
1917