Completed
Pull Request — master (#858)
by Eddie
05:34 queued 02:25
created

tests.TestClosePosAlgo.test_auto_close_future()   B

Complexity

Conditions 1

Size

Total Lines 29

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 29
rs 8.8571
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range
20
from testfixtures import TempDirectory
21
from textwrap import dedent
22
from unittest import TestCase
23
24
import numpy as np
25
import pandas as pd
26
27
from zipline.utils.api_support import ZiplineAPI
28
from zipline.utils.control_flow import nullctx
29
from zipline.utils.test_utils import (
30
    setup_logger,
31
    teardown_logger,
32
    FakeDataPortal)
33
import zipline.utils.factory as factory
34
35
from zipline.errors import (
36
    OrderDuringInitialize,
37
    RegisterTradingControlPostInit,
38
    TradingControlViolation,
39
    AccountControlViolation,
40
    SymbolNotFound,
41
    RootSymbolNotFound,
42
    UnsupportedDatetimeFormat,
43
)
44
from zipline.test_algorithms import (
45
    access_account_in_init,
46
    access_portfolio_in_init,
47
    AmbitiousStopLimitAlgorithm,
48
    EmptyPositionsAlgorithm,
49
    InvalidOrderAlgorithm,
50
    RecordAlgorithm,
51
    FutureFlipAlgo,
52
    TestAlgorithm,
53
    TestOrderAlgorithm,
54
    TestOrderPercentAlgorithm,
55
    TestOrderStyleForwardingAlgorithm,
56
    TestOrderValueAlgorithm,
57
    TestRegisterTransformAlgorithm,
58
    TestTargetAlgorithm,
59
    TestTargetPercentAlgorithm,
60
    TestTargetValueAlgorithm,
61
    SetLongOnlyAlgorithm,
62
    SetAssetDateBoundsAlgorithm,
63
    SetMaxPositionSizeAlgorithm,
64
    SetMaxOrderCountAlgorithm,
65
    SetMaxOrderSizeAlgorithm,
66
    SetDoNotOrderListAlgorithm,
67
    SetMaxLeverageAlgorithm,
68
    api_algo,
69
    api_get_environment_algo,
70
    api_symbol_algo,
71
    call_all_order_methods,
72
    call_order_in_init,
73
    handle_data_api,
74
    handle_data_noop,
75
    initialize_api,
76
    initialize_noop,
77
    noop_algo,
78
    record_float_magic,
79
    record_variables,
80
)
81
from zipline.utils.context_tricks import CallbackManager
82
import zipline.utils.events
83
from zipline.utils.test_utils import to_utc
84
from zipline.assets import Equity
85
86
from zipline.finance.execution import LimitOrder
87
from zipline.finance.trading import SimulationParameters
88
from zipline.utils.api_support import set_algo_instance
89
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
90
from zipline.algorithm import TradingAlgorithm
91
from zipline.finance.trading import TradingEnvironment
92
from zipline.finance.commission import PerShare
93
94
from zipline.utils.test_utils import (
95
    create_data_portal,
96
    create_data_portal_from_trade_history
97
)
98
99
# Tell nose's multiprocess plugin not to split this module's test fixtures
# across worker processes: several test cases appear to share class-level
# resources.
_multiprocess_can_split_ = False
103
104
105
class TestRecordAlgorithm(TestCase):
106
107
    @classmethod
108
    def setUpClass(cls):
109
        cls.env = TradingEnvironment()
110
        cls.sids = [133]
111
        cls.env.write_data(equities_identifiers=cls.sids)
112
113
        cls.sim_params = factory.create_simulation_parameters(
114
            num_days=4,
115
            env=cls.env
116
        )
117
118
        cls.tempdir = TempDirectory()
119
120
        cls.data_portal = create_data_portal(
121
            cls.env,
122
            cls.tempdir,
123
            cls.sim_params,
124
            cls.sids
125
        )
126
127
    @classmethod
128
    def tearDownClass(cls):
129
        del cls.env
130
        cls.tempdir.cleanup()
131
132
    def test_record_incr(self):
133
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
134
        output = algo.run(self.data_portal)
135
136
        np.testing.assert_array_equal(output['incr'].values,
137
                                      range(1, len(output) + 1))
138
        np.testing.assert_array_equal(output['name'].values,
139
                                      range(1, len(output) + 1))
140
        np.testing.assert_array_equal(output['name2'].values,
141
                                      [2] * len(output))
142
        np.testing.assert_array_equal(output['name3'].values,
143
                                      range(1, len(output) + 1))
144
145
146
class TestMiscellaneousAPI(TestCase):
    """Tests for assorted TradingAlgorithm API methods: environment queries,
    open orders, function scheduling, event contexts and symbol/future
    lookups.
    """

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        # Two equities share the symbol 'PLAY' over non-overlapping date
        # ranges so that test_asset_lookup can check as-of resolution.
        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        # Four contracts sharing the root symbol 'CL'; test_future_symbol
        # checks the fields of sid 5 (CLG06) in detail.
        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='daily',
            env=cls.env,
        )

        cls.temp_dir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.temp_dir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.temp_dir.cleanup()

    def test_zipline_api_resolves_dynamically(self):
        # Make a dummy algo.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
        )

        # Verify that api methods get resolved dynamically by patching them
        # out and then calling them.
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            def fake_method(*args, **kwargs):
                return sentinel
            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_get_environment(self):
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_get_open_orders(self):
        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The sid 1 market order should have filled.
                # The three sid 2 limit orders should still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_schedule_function(self):
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.data_portal)

        # Number of handle_data calls expected over the two-day minute
        # simulation configured in setUpClass.
        num_bars = 780
        self.assertEqual(len(expected_data), num_bars)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        # Every bar should run pre, handle_data, f, g and post, in order.
        expected_functions = [pre, handle_data, f, g, post] * num_bars
        self.assertEqual(
            len(function_stack),
            len(expected_functions),
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), len(expected_functions)),
        )
        # Distinct loop names so the local handlers f and g are not shadowed.
        for n, (actual, expected) in enumerate(
                zip(function_stack, expected_functions)):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        # Trailing comma required: ('minute') is a plain string, not a tuple.
        ('minute',),
    ])
    def test_schedule_function_rule_creation(self, mode):
        def nop(*args, **kwargs):
            return None

        # sim_params is shared by every test in this class, so put the
        # original data frequency back once this test finishes.
        self.addCleanup(setattr, self.sim_params, 'data_frequency',
                        self.sim_params.data_frequency)
        self.sim_params.data_frequency = mode
        algo = TradingAlgorithm(
            initialize=nop,
            handle_data=nop,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Schedule something for NOT Always.
        algo.schedule_function(nop, time_rule=zipline.utils.events.Never())

        event_rule = algo.event_manager._events[1].rule

        self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

        inner_rule = event_rule.rule
        self.assertIsInstance(inner_rule, zipline.utils.events.ComposedRule)

        first = inner_rule.first
        second = inner_rule.second
        composer = inner_rule.composer

        self.assertIsInstance(first, zipline.utils.events.Always)

        if mode == 'daily':
            self.assertIsInstance(second, zipline.utils.events.Always)
        else:
            self.assertIsInstance(second, zipline.utils.events.Never)

        self.assertIs(composer, zipline.utils.events.ComposedRule.lazy_and)

    def test_asset_lookup(self):

        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.symbol(1)

        with self.assertRaises(TypeError):
            algo.symbol((1,))

        with self.assertRaises(TypeError):
            algo.symbol({1})

        with self.assertRaises(TypeError):
            algo.symbol([1])

        with self.assertRaises(TypeError):
            algo.symbol({'foo': 'bar'})

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('PLAY')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('FOOBAR')

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_symbol(1)

        with self.assertRaises(TypeError):
            algo.future_symbol((1,))

        with self.assertRaises(TypeError):
            algo.future_symbol({1})

        with self.assertRaises(TypeError):
            algo.future_symbol([1])

        with self.assertRaises(TypeError):
            algo.future_symbol({'foo': 'bar'})

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_chain(1)

        with self.assertRaises(TypeError):
            algo.future_chain((1,))

        with self.assertRaises(TypeError):
            algo.future_chain({1})

        with self.assertRaises(TypeError):
            algo.future_chain([1])

        with self.assertRaises(TypeError):
            algo.future_chain({'foo': 'bar'})

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Note we start sid enumeration at i+3 so as not to
        # collide with sids [1, 2] added in setUpClass().
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        # Create two assets with the same symbol but different
        # non-overlapping date ranges.
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the end dates of both assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end date is used as
        # the as_of_date; only the resolved symbol is checked here.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
638
639
640
class TestTransformAlgorithm(TestCase):
    """Run the order-method and transform test algorithms end to end."""

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        futures_metadata = {3: {'contract_multiplier': 10}}
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)
        cls.sids = [0, 1, 133]
        cls.tempdir = TempDirectory()

        # Every equity is alive for the whole simulation window.
        equities_metadata = {
            sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
            for sid in cls.sids
        }

        cls.env.write_data(equities_data=equities_metadata,
                           futures_data=futures_metadata)

        # The same prices and volumes, spaced one day apart, for each equity.
        trades_by_sid = {
            sid: factory.create_trade_history(
                sid,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )
            for sid in cls.sids
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid,
        )

        cls.data_portal.current_day = cls.sim_params.trading_days[0]

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        del cls.env
        cls.tempdir.cleanup()

    def test_invalid_order_parameters(self):
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

    def test_run_twice(self):
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.data_portal)

        # Create a new trading algorithm, which will
        # use the newly instantiated environment.
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.data_portal)

        # FIXME I think we are getting Nans due to fixed benchmark,
        # so dropping them for now.
        res1 = res1.fillna(method='ffill')
        res2 = res2.fillna(method='ffill')

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        # Build fresh SimulationParameters for each frequency rather than
        # mutating the class-level self.sim_params shared by other tests.
        for frequency in ('daily', 'minute'):
            sim_params = factory.create_simulation_parameters(
                num_days=4, env=self.env, data_frequency=frequency)

            algo = TestRegisterTransformAlgorithm(
                sim_params=sim_params,
                env=self.env,
            )
            self.assertEqual(algo.sim_params.data_frequency, frequency)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        # NOTE(review): this runs the same algorithms, with the same
        # arguments and data portal, as test_order_methods; nothing here
        # selects the future (sid 3) written in setUpClass. Confirm whether
        # the algorithm classes order that future on their own.
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

    @parameterized.expand([
        ("order",),
        ("order_value",),
        ("order_percent",),
        ("order_target",),
        ("order_target_percent",),
        ("order_target_value",),
    ])
    def test_order_method_style_forwarding(self, order_style):
        algo = TestOrderStyleForwardingAlgorithm(
            sim_params=self.sim_params,
            instant_fill=False,
            method_name=order_style,
            env=self.env
        )
        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,)
    ])
    def test_minute_data(self, algo_class):
        tempdir = TempDirectory()

        try:
            env = TradingEnvironment()

            sim_params = SimulationParameters(
                period_start=pd.Timestamp('2002-1-2', tz='UTC'),
                period_end=pd.Timestamp('2002-1-4', tz='UTC'),
                capital_base=float("1.0e5"),
                data_frequency='minute',
                env=env
            )

            # Both equities outlive the simulation by one day.
            equities_metadata = {
                sid: {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }
                for sid in [0, 1]
            }

            env.write_data(equities_data=equities_metadata)

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [0, 1]
            )

            algo = algo_class(sim_params=sim_params, env=env)
            algo.run(data_portal)
        finally:
            tempdir.cleanup()
827
828
829
class TestPositions(TestCase):
830
    @classmethod
831
    def setUpClass(cls):
832
        setup_logger(cls)
833
        cls.env = TradingEnvironment()
834
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
835
                                                              env=cls.env)
836
837
        cls.sids = [1, 133]
838
        cls.tempdir = TempDirectory()
839
840
        equities_metadata = {}
841
842
        for sid in cls.sids:
843
            equities_metadata[sid] = {
844
                'start_date': cls.sim_params.period_start,
845
                'end_date': cls.sim_params.period_end
846
            }
847
848
        cls.env.write_data(equities_data=equities_metadata)
849
850
        cls.data_portal = create_data_portal(
851
            cls.env,
852
            cls.tempdir,
853
            cls.sim_params,
854
            cls.sids
855
        )
856
857
    @classmethod
858
    def tearDownClass(cls):
859
        teardown_logger(cls)
860
        cls.tempdir.cleanup()
861
862
    def test_empty_portfolio(self):
863
        algo = EmptyPositionsAlgorithm(self.sids,
864
                                       sim_params=self.sim_params,
865
                                       env=self.env)
866
        daily_stats = algo.run(self.data_portal)
867
868
        expected_position_count = [
869
            0,  # Before entering the first position
870
            2,  # After entering, exiting on this date
871
            0,  # After exiting
872
            0,
873
        ]
874
875
        for i, expected in enumerate(expected_position_count):
876
            self.assertEqual(daily_stats.ix[i]['num_positions'],
877
                             expected)
878
879
    def test_noop_orders(self):
880
        algo = AmbitiousStopLimitAlgorithm(sid=1,
881
                                           sim_params=self.sim_params,
882
                                           env=self.env)
883
        daily_stats = algo.run(self.data_portal)
884
885
        # Verify that positions are empty for all dates.
886
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
887
        self.assertTrue(empty_positions.all())
888
889
890
class TestAlgoScript(TestCase):
    """End-to-end runs of algorithms given as callables or script strings.

    Every test shares the TradingEnvironment, simulation parameters and
    data portal built once in ``setUpClass``.
    """

    # Number of simulated days. The simulation parameters, the synthetic
    # trade history and the per-day assertions on recorded values must all
    # agree on this length, so it is defined exactly once.
    NUM_DAYS = 251

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=cls.NUM_DAYS, env=cls.env)

        cls.sids = [0, 1, 3, 133]
        cls.tempdir = TempDirectory()

        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

            # sid 3 is additionally registered under the symbol "TEST".
            if sid == 3:
                equities_metadata[sid]["symbol"] = "TEST"
                equities_metadata[sid]["asset_type"] = "equity"

        cls.env.write_data(equities_data=equities_metadata)

        # Flat daily history (price 10.0, volume 100) for sids 0 and 3 only;
        # sids 1 and 133 get asset metadata but no trade history.
        days = cls.NUM_DAYS
        trades_by_sid = {
            sid: factory.create_trade_history(
                sid,
                [10.0] * days,
                [100] * days,
                timedelta(days=1),
                cls.sim_params,
                cls.env)
            for sid in (0, 3)
        }

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

        cls.zipline_test_config = {
            'sid': 0,
        }

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()
        teardown_logger(cls)

    def test_noop(self):
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.data_portal)

    def test_noop_string(self):
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.data_portal)

    def test_api_calls(self):
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api,
                                env=self.env)
        algo.run(self.data_portal)

    def test_api_calls_string(self):
        algo = TradingAlgorithm(script=api_algo, env=self.env)
        algo.run(self.data_portal)

    def test_api_get_environment(self):
        platform = 'zipline'
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                platform=platform)
        algo.run(self.data_portal)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        algo = TradingAlgorithm(script=api_symbol_algo,
                                env=self.env,
                                sim_params=self.sim_params)
        algo.run(self.data_portal)

    def test_fixed_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1""",
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        # flatten the list of txns
        all_txns = [val for sublist in results["transactions"].tolist()
                    for val in sublist]

        self.assertEqual(len(all_txns), 1)
        txn = all_txns[0]

        self.assertEqual(100.0, txn["commission"])
        expected_spread = 0.05
        expected_commish = 0.10
        expected_price = test_algo.recorded_vars["price"] - expected_spread \
            - expected_commish

        # Both sides are floats derived along different arithmetic paths;
        # compare with a tolerance instead of exact equality.
        self.assertAlmostEqual(expected_price, txn['price'])

    def test_volshare_slippage(self):
        tempdir = TempDirectory()
        try:
            # verify order -> transaction -> portfolio position.
            # --------------
            test_algo = TradingAlgorithm(
                script="""
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    set_commission(commission.PerShare(0.02))
    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """,
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            trades = factory.create_daily_trade_source(
                [0], self.sim_params, self.env)
            data_portal = create_data_portal_from_trade_history(
                self.env, tempdir, self.sim_params, {0: trades})
            results = test_algo.run(data_portal=data_portal)

            all_txns = [
                val for sublist in results["transactions"].tolist()
                for val in sublist]

            self.assertEqual(len(all_txns), 67)

            per_share_commish = 0.02
            first_txn = all_txns[0]
            commish = first_txn["amount"] * per_share_commish
            # Float results: compare with a tolerance.
            self.assertAlmostEqual(commish, first_txn["commission"])
            self.assertAlmostEqual(2.029, first_txn["price"])
        finally:
            tempdir.cleanup()

    def test_algo_record_vars(self):
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, self.NUM_DAYS + 1):
            self.assertEqual(results.iloc[i-1]["incr"], i)

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def test_algo_record_nan(self):
        test_algo = TradingAlgorithm(
            script=record_float_magic % 'nan',
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, self.NUM_DAYS + 1):
            self.assertTrue(np.isnan(results.iloc[i-1]["data"]))

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        test_algo = TradingAlgorithm(
            script=call_all_order_methods,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            test_algo.run(self.data_portal)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_portfolio_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_account_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)
1165
1166
class TestGetDatetime(TestCase):
    """``get_datetime`` inside an algorithm reports the bar time in the
    requested timezone (UTC when none is given)."""

    @classmethod
    def setUpClass(cls):
        env = TradingEnvironment()
        env.write_data(equities_identifiers=[0, 1])
        cls.env = env

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=env,
            start=to_utc('2014-01-02 9:31'),
            end=to_utc('2014-01-03 9:31'),
        )

        cls.tempdir = TempDirectory()
        cls.data_portal = create_data_portal(
            env, cls.tempdir, cls.sim_params, [1]
        )

    @classmethod
    def tearDownClass(cls):
        delattr(cls, 'env')
        teardown_logger(cls)
        cls.tempdir.cleanup()

    @parameterized.expand([
        ('default', None),
        ('utc', 'UTC'),
        ('us_east', 'US/Eastern'),
    ])
    def test_get_datetime(self, name, tz):
        # The script raises from handle_data on the first bar if the zone,
        # hour or minute reported by get_datetime is wrong.
        script_source = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        trading_algo = TradingAlgorithm(
            script=script_source,
            sim_params=self.sim_params,
            env=self.env,
        )
        trading_algo.run(self.data_portal)
        # handle_data must actually have run at least once.
        self.assertFalse(trading_algo.first_bar)
1234
1235
1236
class TestTradingControls(TestCase):
    """Trading controls (position/order limits, restricted lists, long-only,
    asset date bounds) must stop offending orders with
    TradingControlViolation and let compliant ones through."""

    @classmethod
    def setUpClass(cls):
        cls.sid = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.env.write_data(equities_data={
            cls.sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
        })

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            [cls.sid]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` installed and check the result.

        ``expected_exc`` (when not None) must be raised by the run, and the
        algorithm's ``order_count`` must equal ``expected_order_count``.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)
        self.assertEqual(algo.order_count, expected_order_count)

    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        # Default for order_count assumes one order per handle_data call
        # over the 4 simulated days.
        self._check_algo(algo, handle_data, order_count, None)

    def check_algo_fails(self, algo, handle_data, order_count):
        self._check_algo(algo,
                         handle_data,
                         order_count,
                         TradingControlViolation)

    def _order_each_bar(self, amount):
        """Return a handle_data that orders ``amount`` shares of self.sid on
        every call and counts the order."""
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), amount)
            algo.order_count += 1
        return handle_data

    def _order_sequence(self, amounts):
        """Return a handle_data whose n-th call orders ``amounts[n]`` shares
        of self.sid and counts the order."""
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), amounts[algo.order_count])
            algo.order_count += 1
        return handle_data

    def test_set_max_position_size(self):
        # Buy one share four times.  Should be fine.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, self._order_each_bar(1))

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(3), 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(3), 2)

        # Set the trading control to a different sid, then BUY ALL THE
        # THINGS! Should continue normally.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, self._order_each_bar(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!
        # Should fail because setting sid to None makes the control apply to
        # all sids.
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(10000), 0)

    def test_set_do_not_order_list(self):
        # set the restricted list to be the sid, and fail.
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_fails(algo, self._order_each_bar(100), 0)

        # set the restricted list to exclude the sid, and succeed
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_succeeds(algo, self._order_each_bar(100))

    def test_set_max_order_size(self):
        # Buy one share.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, self._order_each_bar(1))

        # Buy 1, then 2, then 3, then 4 shares.
        def order_increasing(algo, data):
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        # Bail on the last attempt because we exceed shares.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=3,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, order_increasing, 3)

        # Bail on the last attempt because we exceed notional.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=40.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, order_increasing, 3)

        # Set the trading control to a different sid, then BUY ALL THE
        # THINGS! Should continue normally.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
                                        max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, self._order_each_bar(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        algo = SetMaxOrderSizeAlgorithm(max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(10000), 0)

    def test_set_max_order_count(self):
        tempdir = TempDirectory()
        try:
            env = TradingEnvironment()
            sim_params = factory.create_simulation_parameters(
                num_days=4, env=env, data_frequency="minute")

            env.write_data(equities_data={
                1: {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }
            })

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [1]
            )

            # Five orders on every bar; with a limit of 3 the fourth order
            # must be rejected.
            def handle_data(algo, data):
                for _ in range(5):
                    algo.order(algo.sid(1), 1)
                    algo.order_count += 1

            algo = SetMaxOrderCountAlgorithm(3, sim_params=sim_params,
                                             env=env)
            # Install the handler outside assertRaises so only the run
            # itself is expected to raise.
            algo._handle_data = handle_data
            with self.assertRaises(TradingControlViolation):
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 3)

            # This time, order 5 times twice in a single day. The last order
            # of the second batch should fail.
            def handle_data2(algo, data):
                if algo.minute_count == 0 or algo.minute_count == 100:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            algo = SetMaxOrderCountAlgorithm(9, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data2
            with self.assertRaises(TradingControlViolation):
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 9)

            def handle_data3(algo, data):
                if (algo.minute_count % 390) == 0:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            # Only 5 orders are placed per day, so this should pass even
            # though in total more than 20 orders are placed.
            algo = SetMaxOrderCountAlgorithm(5, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data3
            algo.run(data_portal)
        finally:
            tempdir.cleanup()

    def test_long_only(self):
        # Sell immediately -> fail immediately.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(-1), 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        def handle_data(algo, data):
            if (algo.order_count % 2) == 0:
                algo.order(algo.sid(self.sid), 1)
            else:
                algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings.  Should succeed.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, self._order_sequence([1, 1, 1, -3]))

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, self._order_sequence([1, 1, 1, -4]), 3)

    def test_register_post_init(self):

        def initialize(algo):
            algo.initialized = True

        def handle_data(algo, data):
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_position_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_count(1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_long_only()

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def _run_asset_date_bounds_algo(self, start_date, end_date, expected_exc):
        """Run SetAssetDateBoundsAlgorithm in a fresh environment where sid 0
        has the given lifetime; ``expected_exc`` must be raised by the run,
        or the run must succeed when it is None."""
        tempdir = TempDirectory()
        try:
            temp_env = TradingEnvironment()

            data_portal = create_data_portal(
                temp_env,
                tempdir,
                self.sim_params,
                [0]
            )

            metadata = {0: {'start_date': start_date,
                            'end_date': end_date}}

            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )
            with self.assertRaises(expected_exc) if expected_exc \
                    else nullctx():
                algo.run(data_portal)
        finally:
            tempdir.cleanup()

    def test_asset_date_bounds(self):
        # Run the algorithm with a sid that ends far in the future
        self._run_asset_date_bounds_algo(self.sim_params.period_start,
                                         '2020-01-01',
                                         None)

        # Run the algorithm with a sid that has already ended
        self._run_asset_date_bounds_algo('1989-01-01',
                                         '1990-01-01',
                                         TradingControlViolation)

        # Run the algorithm with a sid that has not started
        self._run_asset_date_bounds_algo('2020-01-01',
                                         '2021-01-01',
                                         TradingControlViolation)
1640
1641
1642
class TestAccountControls(TestCase):
    """Account controls raise AccountControlViolation when breached."""

    @classmethod
    def setUpClass(cls):
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=4, env=cls.env
        )

        cls.env.write_data(equities_data={
            cls.sidint: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end + timedelta(days=1)
            }
        })

        cls.tempdir = TempDirectory()

        trades_by_sid = {
            cls.sidint: factory.create_trade_history(
                cls.sidint,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env,
            )
        }

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` installed; the run must raise
        ``expected_exc``, or complete normally when it is None."""
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)

    def check_algo_succeeds(self, algo, handle_data):
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):
        # Buy one share on every bar.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sidint), 1)

        # Set max leverage to 0 so buying one share fails.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Set max leverage to 1 so buying one share passes
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)
1717
1718
1719
class TestClosePosAlgo(TestCase):
    """Check positions and pnl for a future that has an ``auto_close_date``.

    The position is expected to be flat (0) again on the last simulated day.
    """

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-1]
        )

        # One trade per day for the test sid. The first list presumably holds
        # prices (consistent with expected_pnl in test_auto_close_future) and
        # the second holds volumes -- confirm against factory.
        trades_by_sid = {}
        trades_by_sid[cls.sid] = factory.create_trade_history(
            cls.sid,
            [1, 1, 2, 4],
            [1e9, 1e9, 1e9, 1e9],
            timedelta(days=1),
            cls.sim_params,
            cls.env
        )

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_auto_close_future(self):
        # Register sid 1 as a future whose end_date and auto_close_date are
        # both the trading day after the simulation's last day.
        self.env.write_data(futures_data={
            1: {
                "start_date": self.sim_params.trading_days[0],
                "end_date": self.env.next_trading_day(
                    self.sim_params.trading_days[-1]),
                'symbol': 'TEST',
                'asset_type': 'future',
                'auto_close_date': self.env.next_trading_day(
                    self.sim_params.trading_days[-1])
            }
        })

        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env,
                             sim_params=self.sim_params)

        # Check results
        results = algo.run(self.data_portal)

        expected_pnl = [0, 0, 1, 2]
        self.check_algo_pnl(results, expected_pnl)

        expected_positions = [0, 1, 1, 0]
        self.check_algo_positions(results, expected_positions)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the daily ``results.pnl`` matches ``expected_pnl``.

        Uses numpy's almost-equal comparison (default 6 decimal places);
        a length mismatch also fails the assertion.
        """
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the position amount for each simulated day.

        Each entry of ``results.positions`` is treated as a sequence of
        position mappings. An empty entry counts as a flat (0) position;
        otherwise the first mapping's ``'amount'`` is compared.

        The number of days must match ``expected_positions`` exactly.
        Without this check, missing days would leave expectations unchecked
        and extra days would surface as a bare IndexError.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "expected positions for {0} days, got {1}".format(
                len(expected_positions), len(results.positions)))

        for day, day_positions in enumerate(results.positions):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected_positions[day],
                "position for day={0} not equal, actual={1}, expected={2}".
                format(day, actual_position, expected_positions[day]))
1801
1802
1803
class TestFutureFlip(TestCase):
    """Check positions and pnl for an algorithm that flips a futures
    position from long (+1) to short (-1)."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        # The simulation stops one day before the last date in cls.days.
        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-2]
        )

        # One trade per day; first list presumably prices, second volumes --
        # confirm against factory.create_trade_history.
        trades = factory.create_trade_history(
            cls.sid,
            [1, 2, 4],
            [1e9, 1e9, 1e9],
            timedelta(days=1),
            cls.sim_params,
            cls.env
        )

        trades_by_sid = {
            cls.sid: trades
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_flip_algo(self):
        # NOTE(review): contract_multiplier=5 is presumably why the expected
        # pnl values below are multiples of 5 -- confirm.
        metadata = {1: {'symbol': 'TEST',
                        'start_date': self.sim_params.trading_days[0],
                        'end_date': self.env.next_trading_day(
                            self.sim_params.trading_days[-1]),
                        'contract_multiplier': 5}}
        self.env.write_data(futures_data=metadata)

        algo = FutureFlipAlgo(sid=1, amount=1, env=self.env,
                              commission=PerShare(0),
                              order_count=0,  # not applicable but required
                              sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        expected_positions = [0, 1, -1]
        self.check_algo_positions(results, expected_positions)

        expected_pnl = [0, 5, -10]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the daily ``results.pnl`` matches ``expected_pnl``.

        Uses numpy's almost-equal comparison (default 6 decimal places);
        a length mismatch also fails the assertion.
        """
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the position amount for each simulated day.

        Each entry of ``results.positions`` is treated as a sequence of
        position mappings. An empty entry counts as a flat (0) position;
        otherwise the first mapping's ``'amount'`` is compared.

        The number of days must match ``expected_positions`` exactly.
        Without this check, missing days would leave expectations unchecked
        and extra days would surface as a bare IndexError.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "expected positions for {0} days, got {1}".format(
                len(expected_positions), len(results.positions)))

        for day, day_positions in enumerate(results.positions):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected_positions[day],
                "position for day={0} not equal, actual={1}, expected={2}".
                format(day, actual_position, expected_positions[day]))
1878
1879
1880
class TestTradingAlgorithm(TestCase):
1881
    def setUp(self):
        """Build a fresh TradingEnvironment and keep its first four trading days."""
        env = TradingEnvironment()
        self.env = env
        self.days = env.trading_days[:4]
1884
1885
    def test_analyze_called(self):
        """``analyze`` receives the very object that ``algo.run`` returns."""
        self.perf_ref = None

        def initialize(context):
            pass

        def handle_data(context, data):
            pass

        def analyze(context, perf):
            # Stash whatever the algorithm passes to analyze so it can be
            # compared by identity with run()'s return value below.
            self.perf_ref = perf

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            analyze=analyze,
        )

        results = algo.run(FakeDataPortal())
        self.assertIs(results, self.perf_ref)
1904