Completed
Pull Request — master (#858)
by Eddie
01:43
created

tests.TestTradingControls._check_algo()   A

Complexity

Conditions 3

Size

Total Lines 10

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 10
rs 9.4286
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range
20
from testfixtures import TempDirectory
21
from textwrap import dedent
22
from unittest import TestCase
23
24
import numpy as np
25
import pandas as pd
26
27
from zipline.assets import Equity, Future
28
from zipline.utils.api_support import ZiplineAPI
29
from zipline.utils.control_flow import nullctx
30
from zipline.utils.test_utils import (
31
    setup_logger,
32
    teardown_logger,
33
    FakeDataPortal)
34
import zipline.utils.factory as factory
35
36
from zipline.errors import (
37
    OrderDuringInitialize,
38
    RegisterTradingControlPostInit,
39
    TradingControlViolation,
40
    AccountControlViolation,
41
    SymbolNotFound,
42
    RootSymbolNotFound,
43
    UnsupportedDatetimeFormat,
44
)
45
from zipline.test_algorithms import (
46
    access_account_in_init,
47
    access_portfolio_in_init,
48
    AmbitiousStopLimitAlgorithm,
49
    EmptyPositionsAlgorithm,
50
    InvalidOrderAlgorithm,
51
    RecordAlgorithm,
52
    FutureFlipAlgo,
53
    TestAlgorithm,
54
    TestOrderAlgorithm,
55
    TestOrderPercentAlgorithm,
56
    TestOrderStyleForwardingAlgorithm,
57
    TestOrderValueAlgorithm,
58
    TestRegisterTransformAlgorithm,
59
    TestTargetAlgorithm,
60
    TestTargetPercentAlgorithm,
61
    TestTargetValueAlgorithm,
62
    SetLongOnlyAlgorithm,
63
    SetAssetDateBoundsAlgorithm,
64
    SetMaxPositionSizeAlgorithm,
65
    SetMaxOrderCountAlgorithm,
66
    SetMaxOrderSizeAlgorithm,
67
    SetDoNotOrderListAlgorithm,
68
    SetMaxLeverageAlgorithm,
69
    api_algo,
70
    api_get_environment_algo,
71
    api_symbol_algo,
72
    call_all_order_methods,
73
    call_order_in_init,
74
    handle_data_api,
75
    handle_data_noop,
76
    initialize_api,
77
    initialize_noop,
78
    noop_algo,
79
    record_float_magic,
80
    record_variables,
81
)
82
from zipline.utils.context_tricks import CallbackManager
83
import zipline.utils.events
84
from zipline.utils.test_utils import to_utc
85
86
from zipline.finance.execution import LimitOrder
87
from zipline.finance.trading import SimulationParameters
88
from zipline.utils.api_support import set_algo_instance
89
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
90
from zipline.algorithm import TradingAlgorithm
91
from zipline.finance.trading import TradingEnvironment
92
from zipline.finance.commission import PerShare
93
94
from zipline.utils.test_utils import (
95
    create_data_portal,
96
    create_data_portal_from_trade_history
97
)
98
99
# Because test cases appear to reuse some resources.
100
101
102
_multiprocess_can_split_ = False
103
104
105
class TestRecordAlgorithm(TestCase):
106
107
    @classmethod
    def setUpClass(cls):
        """Build the fixtures shared by every test in this class.

        Creates a trading environment holding the single equity sid 133,
        simulation parameters requested with ``num_days=4``, a temporary
        directory, and a data portal built from all of them.
        """
        cls.env = env = TradingEnvironment()
        cls.sids = sids = [133]
        env.write_data(equities_identifiers=sids)

        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=env)

        # The temporary directory is handed to create_data_portal and is
        # cleaned up again in tearDownClass.
        cls.tempdir = TempDirectory()
        cls.data_portal = create_data_portal(
            env, cls.tempdir, cls.sim_params, sids)
126
127
    @classmethod
    def tearDownClass(cls):
        """Release the class-level fixtures created in setUpClass."""
        # Drop the environment reference first, then remove the temporary
        # directory.
        delattr(cls, 'env')
        cls.tempdir.cleanup()
131
132
    def test_record_incr(self):
        """Run RecordAlgorithm and check the values of each recorded column."""
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        results = algo.run(self.data_portal)

        row_count = len(results)
        counting_up = range(1, row_count + 1)

        # 'incr', 'name' and 'name3' must count up from 1, one step per
        # output row; 'name2' must be 2 on every row.
        expected_by_column = (
            ('incr', counting_up),
            ('name', counting_up),
            ('name2', [2] * row_count),
            ('name3', counting_up),
        )
        for column, expected in expected_by_column:
            np.testing.assert_array_equal(results[column].values, expected)
144
145
146
class TestMiscellaneousAPI(TestCase):
147
148
    @classmethod
    def setUpClass(cls):
        """Create the shared environment, asset metadata and data portal.

        Passes equities 1 and 2, two 'PLAY' equities (sids 3 and 4) with
        different date ranges, and four 'CL' futures contracts (sids 5-8)
        to write_data on a new TradingEnvironment.  Then builds two-day,
        minute-frequency simulation parameters and a data portal for
        sids 1 and 2.
        """
        def utc(day):
            return pd.Timestamp(day, tz='UTC')

        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        play_equities = {
            3: {'symbol': 'PLAY',
                'start_date': '2002-01-01',
                'end_date': '2004-01-01'},
            4: {'symbol': 'PLAY',
                'start_date': '2005-01-01',
                'end_date': '2006-01-01'},
        }

        cl_futures = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': utc('2005-12-01'),
                'notice_date': utc('2005-12-20'),
                'expiration_date': utc('2006-01-20'),
            },
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': utc('2005-12-01'),
                'notice_date': utc('2006-03-20'),
                'expiration_date': utc('2006-04-20'),
            },
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': utc('2005-12-01'),
                'notice_date': utc('2006-06-20'),
                'expiration_date': utc('2006-07-20'),
            },
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': utc('2006-02-01'),
                'notice_date': utc('2006-09-20'),
                'expiration_date': utc('2006-10-20'),
            },
        }

        cls.env.write_data(
            equities_identifiers=cls.sids,
            equities_data=play_equities,
            futures_data=cl_futures,
        )

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='daily',
            env=cls.env,
        )

        # Cleaned up again in tearDownClass.
        cls.temp_dir = TempDirectory()
        cls.data_portal = create_data_portal(
            cls.env, cls.temp_dir, cls.sim_params, cls.sids)
207
208
    @classmethod
    def tearDownClass(cls):
        """Release the environment, the logger and the temporary directory."""
        # Same order as the fixtures are released everywhere in this class:
        # environment reference, then logger, then the temp directory.
        delattr(cls, 'env')
        teardown_logger(cls)
        cls.temp_dir.cleanup()
213
214
    def test_zipline_api_resolves_dynamically(self):
        """Check that zipline.api functions look up the algo's current attribute.

        Each API method of a dummy algorithm is replaced by a stub returning
        a fresh sentinel object; calling the same-named function on
        zipline.api inside a ZiplineAPI context must return that sentinel.
        """
        # A dummy algo whose callbacks do nothing.
        dummy_algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
        )

        for api_method in dummy_algo.all_api_methods():
            api_name = api_method.__name__
            marker = object()

            def stub(*args, **kwargs):
                return marker

            # Patch the method on the instance, then call it through the
            # module-level API to prove the lookup happens at call time.
            setattr(dummy_algo, api_name, stub)
            with ZiplineAPI(dummy_algo):
                api_function = getattr(zipline.api, api_name)
                self.assertIs(marker, api_function())
233
234
    def test_get_environment(self):
        """Check get_environment() with no argument and with '*'.

        The assertions run inside ``initialize``: the no-argument call must
        return 'zipline' and the '*' call must return the full dictionary
        below.
        """
        expected_env = dict(
            arena='backtest',
            data_frequency='minute',
            start=pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            end=pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            capital_base=100000.0,
            platform='zipline',
        )

        def initialize(algo):
            platform_only = algo.get_environment()
            everything = algo.get_environment('*')
            self.assertEqual('zipline', platform_only)
            self.assertEqual(expected_env, everything)

        def handle_data(algo, data):
            # Nothing to do per bar; all checks happen in initialize.
            pass

        TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)
256
257
    def test_get_open_orders(self):
        """Exercise get_open_orders() across the first two bars of the run.

        On the first bar one plain order is placed for sid 1 and three limit
        orders at 0.01 for sid 2.  On the second bar the sid 1 order is
        expected to have filled, the three sid 2 orders must still be open,
        and cancelling them must leave no open orders at all.
        """
        def initialize(algo):
            # Counts the bars seen so far by handle_data.
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                for _ in range(3):
                    algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                # Compare as sets so the check does not depend on the
                # iteration order of the returned mapping.
                self.assertEqual(set(all_orders.keys()), {1, 2})

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The sid 1 order should have filled.
                # The three limit orders for sid 2 should still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                # After cancelling everything nothing may remain open.
                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)
306
307
    def test_schedule_function(self):
        """Check that a function scheduled at market open runs once per day.

        ``handle_data`` counts the distinct calendar dates it sees, and the
        scheduled function counts its own invocations (asserting each one
        happens at 14:31).  After the run the two counts must be equal.
        """
        def incrementer(algo, data):
            algo.func_called += 1
            fired_at = algo.get_datetime().time()
            self.assertEqual(fired_at, datetime.time(hour=14, minute=31))

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=DateRuleFactory.every_day(),
                time_rule=TimeRuleFactory.market_open(),
            )

        def handle_data(algo, data):
            today = algo.get_datetime().date()

            # First bar: remember the starting date.
            if not algo.date:
                algo.date = today

            # A later date than the remembered one means a new day began.
            if algo.date < today:
                algo.days += 1
                algo.date = today

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

        self.assertEqual(algo.func_called, algo.days)
345
346
    def test_event_context(self):
        """Check that the event-context callbacks wrap every bar's events.

        A CallbackManager built from ``pre`` and ``post`` is passed as
        ``create_event_context``.  For each of the 780 expected bars the
        call order must be pre, handle_data, f, g, post, and pre/post must
        receive the same data object that handle_data received.
        """
        expected_bars = 780
        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.data_portal)

        self.assertEqual(len(expected_data), expected_bars)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        expected_functions = [pre, handle_data, f, g, post] * expected_bars
        # Report the real expected count (five calls per bar) on failure.
        self.assertEqual(
            len(function_stack),
            len(expected_functions),
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), len(expected_functions)),
        )
        # Distinct loop names so the f/g callbacks above are not rebound.
        pairs = zip(function_stack, expected_functions)
        for n, (actual, expected) in enumerate(pairs):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )
401
402
    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_schedule_function_rule_creation(self, mode):
        """Check the rule built by schedule_function for each data frequency.

        After scheduling a function with a ``Never`` time rule, the event at
        index 1 of the event manager must be a ``OncePerDay`` wrapping a
        ``ComposedRule`` composed with ``lazy_and``.  Its first rule must be
        ``Always``; its second must be ``Always`` in daily mode and ``Never``
        otherwise.

        :param mode: data frequency to run with, 'daily' or 'minute'.
        """
        def nop(*args, **kwargs):
            return None

        # sim_params is shared by every test in this class, so put the
        # original data frequency back once this test has finished.
        self.addCleanup(
            setattr,
            self.sim_params,
            'data_frequency',
            self.sim_params.data_frequency,
        )
        self.sim_params.data_frequency = mode

        algo = TradingAlgorithm(
            initialize=nop,
            handle_data=nop,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Schedule something for NOT Always.
        algo.schedule_function(nop, time_rule=zipline.utils.events.Never())

        event_rule = algo.event_manager._events[1].rule

        self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

        inner_rule = event_rule.rule
        self.assertIsInstance(inner_rule, zipline.utils.events.ComposedRule)

        first = inner_rule.first
        second = inner_rule.second
        composer = inner_rule.composer

        self.assertIsInstance(first, zipline.utils.events.Always)

        if mode == 'daily':
            self.assertIsInstance(second, zipline.utils.events.Always)
        else:
            self.assertIsInstance(second, zipline.utils.events.Never)

        self.assertIs(composer, zipline.utils.events.ComposedRule.lazy_and)
440
441
    def test_asset_lookup(self):
        """Check symbol(), symbols() and sid() lookups at several dates.

        The simulation's period_end is moved around the lifetimes of the
        two 'PLAY' equities (sids 3 and 4) to see which one is resolved.
        """
        algo = TradingAlgorithm(env=self.env)

        def move_period_end(day):
            algo.sim_params.period_end = pd.Timestamp(day, tz='UTC')

        # Test before either PLAY existed
        move_period_end('2001-12-01')
        for lookup in (algo.symbol, algo.symbols):
            with self.assertRaises(SymbolNotFound):
                lookup('PLAY')

        # Test when first PLAY exists
        move_period_end('2002-12-01')
        found = algo.symbols('PLAY')
        self.assertEqual(3, found[0])

        # Test after first PLAY ends
        move_period_end('2004-12-01')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        move_period_end('2005-12-01')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        move_period_end('2006-12-01')
        self.assertEqual(4, algo.symbol('PLAY'))
        found = algo.symbols('PLAY')
        self.assertEqual(4, found[0])

        # Test lookup SID
        for known_sid in (3, 4):
            self.assertIsInstance(algo.sid(known_sid), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        for non_string in (1, (1,), {1}, [1], {'foo': 'bar'}):
            with self.assertRaises(TypeError):
                algo.symbol(non_string)
491
492
    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        def utc(day):
            return pd.Timestamp(day, tz='UTC')

        algo = TradingAlgorithm(env=self.env)
        algo.datetime = utc('2006-12-01')

        # Check that we get the correct fields for the CLG06 symbol
        contract = algo.future_symbol('CLG06')
        self.assertEqual(contract.sid, 5)
        self.assertEqual(contract.symbol, 'CLG06')
        self.assertEqual(contract.root_symbol, 'CL')
        self.assertEqual(contract.start_date, utc('2005-12-01'))
        self.assertEqual(contract.notice_date, utc('2005-12-20'))
        self.assertEqual(contract.expiration_date, utc('2006-01-20'))

        # The empty string, an equity symbol and an unknown symbol must all
        # fail the lookup.
        for missing in ('', 'PLAY', 'FOOBAR'):
            with self.assertRaises(SymbolNotFound):
                algo.future_symbol(missing)

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        for non_string in (1, (1,), {1}, [1], {'foo': 'bar'}):
            with self.assertRaises(TypeError):
                algo.future_symbol(non_string)
533
534
    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        chain = algo.future_chain('CL')
        self.assertEqual(chain.root_symbol, 'CL')
        self.assertEqual(chain.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied,
        # first as a Timestamp and then as a date string.
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')
        for supplied in (as_of_date, '1952-08-11'):
            chain = algo.future_chain('CL', as_of_date=supplied)
            self.assertEqual(chain.root_symbol, 'CL')
            self.assertEqual(chain.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        for oddly_cased in ('cL', 'cl'):
            chain = algo.future_chain(oddly_cased)
            self.assertEqual(chain.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        for unknown_root in ('CLZ', ''):
            with self.assertRaises(RootSymbolNotFound):
                algo.future_chain(unknown_root)

        # Check that invalid dates raise UnsupportedDatetimeFormat
        for bad_date in ('my_finger_slipped', '2015-09-'):
            with self.assertRaises(UnsupportedDatetimeFormat):
                algo.future_chain('CL', bad_date)

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        for non_string in (1, (1,), {1}, [1], {'foo': 'bar'}):
            with self.assertRaises(TypeError):
                algo.future_chain(non_string)
593
594
    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        first_sid = 3
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')

        # Two assets share the symbol 'DUP'.  Each one starts on one of the
        # dates above and ends a day later, so their ranges do not overlap.
        records = []
        for sid, date in enumerate(dates, first_sid):
            records.append({
                'sid': sid,
                'symbol': 'DUP',
                'start_date': date.value,
                'end_date': (date + timedelta(days=1)).value,
            })
        metadata = pd.DataFrame.from_records(records)

        # This test uses its own environment rather than the class-level one.
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the period end
        # dates for our assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the lookup must still resolve
        # to an asset carrying the requested symbol.
        self.assertEqual(algo.symbol('DUP').symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for sid, date in enumerate(dates, first_sid):
            algo.set_symbol_lookup_date(date)
            asset = algo.symbol('DUP')
            self.assertEqual(asset.symbol, 'DUP')
            self.assertEqual(asset.sid, sid)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
638
639
640
class TestTransformAlgorithm(TestCase):
641
642
    @classmethod
    def setUpClass(cls):
        """Create the environment, trade histories and data portal.

        Equities 0, 1 and 133 span the whole simulation period and future
        sid 3 is written with a contract multiplier of 10.  Every equity
        gets the same four-entry trade history, from which the data portal
        is built.
        """
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)
        cls.sids = [0, 1, 133]
        cls.tempdir = TempDirectory()

        futures_metadata = {3: {'contract_multiplier': 10}}
        equities_metadata = {
            sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
            for sid in cls.sids
        }
        cls.env.write_data(equities_data=equities_metadata,
                           futures_data=futures_metadata)

        # Fresh value lists are built for every sid.
        trades_by_sid = {
            sid: factory.create_trade_history(
                sid,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )
            for sid in cls.sids
        }

        portal = create_data_portal_from_trade_history(
            cls.env, cls.tempdir, cls.sim_params, trades_by_sid)
        cls.data_portal = portal

        # Start the portal on the first trading day of the simulation.
        portal.current_day = cls.sim_params.trading_days[0]
680
681
    @classmethod
    def tearDownClass(cls):
        """Tear down the logger, the environment and the temp directory."""
        teardown_logger(cls)
        # Drop the environment reference before removing the files.
        delattr(cls, 'env')
        cls.tempdir.cleanup()
686
687
    def test_invalid_order_parameters(self):
        """Run InvalidOrderAlgorithm on sid 133.

        The test has no assertions of its own; it passes when the run
        completes without raising.
        """
        InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)
694
695
    def test_run_twice(self):
        """Run two identically configured algorithms and compare results.

        Both algorithms use the class-level simulation parameters and the
        same data portal; their outputs must be equal after forward-filling.
        """
        def make_algo():
            # No env is passed here, exactly as in both original
            # constructions.
            return TestRegisterTransformAlgorithm(
                sim_params=self.sim_params,
                sids=[0, 1]
            )

        first_algo = make_algo()
        first_result = first_algo.run(self.data_portal)

        # A second, separately constructed algorithm over the same data.
        second_algo = make_algo()
        second_result = second_algo.run(self.data_portal)

        # FIXME I think we are getting Nans due to fixed benchmark,
        # so forward-filling them for now.
        first_result = first_result.fillna(method='ffill')
        second_result = second_result.fillna(method='ffill')

        np.testing.assert_array_equal(first_result, second_result)
718
719
    def test_data_frequency_setting(self):
        """Check that an algorithm keeps the data frequency of its sim_params.

        For both 'daily' and 'minute', simulation parameters are created
        with that frequency and the algorithm built from them must report
        the same value.  The class-level ``self.sim_params`` is deliberately
        left untouched: it is shared with the other tests and is not used
        by this one.
        """
        for frequency in ('daily', 'minute'):
            sim_params = factory.create_simulation_parameters(
                num_days=4, env=self.env, data_frequency=frequency)

            algo = TestRegisterTransformAlgorithm(
                sim_params=sim_params,
                env=self.env,
            )
            self.assertEqual(algo.sim_params.data_frequency, frequency)
739
740
    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        """Run each ordering test algorithm against the equity fixtures.

        :param algo_class: the test algorithm class to instantiate and run.
        """
        algo = algo_class(sim_params=self.sim_params, env=self.env)

        # Ensure that the environment's asset 0 is an Equity
        self.assertIsInstance(algo.sid(0), Equity)

        algo.run(self.data_portal)
758
759
    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        """Run each ordering test algorithm after checking the future asset.

        :param algo_class: the test algorithm class to instantiate and run.
        """
        algo = algo_class(sim_params=self.sim_params, env=self.env)

        # Ensure that the environment's asset 3 is a Future
        self.assertIsInstance(algo.sid(3), Future)

        algo.run(self.data_portal)
776
777
    @parameterized.expand([
        ("order",),
        ("order_value",),
        ("order_percent",),
        ("order_target",),
        ("order_target_percent",),
        ("order_target_value",),
    ])
    def test_order_method_style_forwarding(self, order_style):
        """Run TestOrderStyleForwardingAlgorithm for one ordering method.

        :param order_style: name of the order method, passed to the
            algorithm as ``method_name``.
        """
        forwarding_algo = TestOrderStyleForwardingAlgorithm(
            sim_params=self.sim_params,
            method_name=order_style,
            env=self.env
        )
        forwarding_algo.run(self.data_portal)
792
793
    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,)
    ])
    def test_minute_data(self, algo_class):
        """Run an ordering test algorithm on minute-frequency data.

        Everything is built locally: a new environment, minute simulation
        parameters from 2002-1-2 to 2002-1-4, equities 0 and 1, and a data
        portal in a temporary directory that is always cleaned up.

        :param algo_class: the test algorithm class to instantiate and run.
        """
        scratch_dir = TempDirectory()

        try:
            env = TradingEnvironment()

            sim_params = SimulationParameters(
                period_start=pd.Timestamp('2002-1-2', tz='UTC'),
                period_end=pd.Timestamp('2002-1-4', tz='UTC'),
                capital_base=float("1.0e5"),
                data_frequency='minute',
                env=env
            )

            # The equities end one day after the simulation period does.
            equities_metadata = {
                sid: {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }
                for sid in (0, 1)
            }
            env.write_data(equities_data=equities_metadata)

            data_portal = create_data_portal(
                env, scratch_dir, sim_params, [0, 1])

            algo_class(sim_params=sim_params, env=env).run(data_portal)
        finally:
            scratch_dir.cleanup()
834
835
836
class TestPositions(TestCase):
837
    @classmethod
    def setUpClass(cls):
        """Create the environment and data portal for equities 1 and 133.

        Both equities span the whole simulation period, which is requested
        with ``num_days=4``.
        """
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.sids = [1, 133]
        cls.tempdir = TempDirectory()

        equities_metadata = {
            sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
            for sid in cls.sids
        }
        cls.env.write_data(equities_data=equities_metadata)

        cls.data_portal = create_data_portal(
            cls.env, cls.tempdir, cls.sim_params, cls.sids)
863
864
    @classmethod
    def tearDownClass(cls):
        """Tear down the logger, the environment and the temp directory."""
        teardown_logger(cls)
        # Drop the environment created in setUpClass, as the other test
        # classes in this module do, so it does not outlive this class.
        del cls.env
        cls.tempdir.cleanup()
868
869
    def test_empty_portfolio(self):
        """Check the per-day position count reported by EmptyPositionsAlgorithm.

        The 'num_positions' value of the first four result rows must match
        ``expected_position_count``.
        """
        algo = EmptyPositionsAlgorithm(self.sids,
                                       sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.data_portal)

        expected_position_count = [
            0,  # Before entering the first position
            2,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        # Rows are addressed by position; .iloc replaces the deprecated .ix
        # indexer, which newer pandas releases no longer provide.
        for i, expected in enumerate(expected_position_count):
            self.assertEqual(daily_stats.iloc[i]['num_positions'],
                             expected)
885
886
    def test_noop_orders(self):
        """Check that AmbitiousStopLimitAlgorithm never holds a position."""
        daily_stats = AmbitiousStopLimitAlgorithm(
            sid=1,
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)

        # Verify that positions are empty for all dates.
        self.assertTrue(
            all(len(held) == 0 for held in daily_stats.positions)
        )
895
896
897
class TestAlgoScript(TestCase):
898
899
    @classmethod
    def setUpClass(cls):
        """
        Build the fixtures shared by every test in this class.

        Creates a trading environment, 251-day simulation parameters,
        asset metadata for sids 0, 1, 3 and 133 (sid 3 additionally gets
        the symbol "TEST"), and a data portal backed by a constant
        price/volume trade history for sids 0 and 3.
        """
        days = 251

        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=days,
            env=cls.env,
        )

        cls.sids = [0, 1, 3, 133]
        cls.tempdir = TempDirectory()

        # Every asset is alive for exactly the simulation period.
        equities_metadata = {
            sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end,
            }
            for sid in cls.sids
        }
        # Only sid 3 carries a ticker symbol and an explicit asset type.
        equities_metadata[3]["symbol"] = "TEST"
        equities_metadata[3]["asset_type"] = "equity"

        cls.env.write_data(equities_data=equities_metadata)

        # One bar per day at a fixed price of 10.0 and volume of 100.
        trades_by_sid = {
            sid: factory.create_trade_history(
                sid,
                [10.0] * days,
                [100] * days,
                timedelta(days=1),
                cls.sim_params,
                cls.env,
            )
            for sid in (0, 3)
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid,
        )

        cls.zipline_test_config = {'sid': 0}
950
951
    @classmethod
    def tearDownClass(cls):
        """
        Drop the shared trading environment, delete the temporary
        directory and tear down the logger, in that order.
        """
        del cls.env

        scratch_dir = cls.tempdir
        scratch_dir.cleanup()

        teardown_logger(cls)
956
957
    def test_noop(self):
        """
        Run an algorithm built from the ``initialize_noop`` and
        ``handle_data_noop`` callables; the test passes if the run
        completes without raising.
        """
        callbacks = {
            'initialize': initialize_noop,
            'handle_data': handle_data_noop,
        }
        TradingAlgorithm(**callbacks).run(self.data_portal)
961
962
    def test_noop_string(self):
        """
        Run an algorithm defined by the ``noop_algo`` script; the test
        passes if the run completes without raising.
        """
        TradingAlgorithm(script=noop_algo).run(self.data_portal)
965
966
    def test_api_calls(self):
        """
        Run an algorithm built from the ``initialize_api`` and
        ``handle_data_api`` callables against the shared environment; the
        test passes if the run completes without raising.
        """
        algorithm = TradingAlgorithm(
            env=self.env,
            handle_data=handle_data_api,
            initialize=initialize_api,
        )
        algorithm.run(self.data_portal)
971
972
    def test_api_calls_string(self):
        """
        Run an algorithm defined by the ``api_algo`` script against the
        shared environment; the test passes if the run completes without
        raising.
        """
        algorithm = TradingAlgorithm(
            env=self.env,
            script=api_algo,
        )
        algorithm.run(self.data_portal)
975
976
    def test_api_get_environment(self):
        """
        Check that the ``platform`` passed to the constructor is what the
        algorithm exposes as ``environment`` after a run.
        """
        expected_platform = 'zipline'

        algorithm = TradingAlgorithm(
            platform=expected_platform,
            script=api_get_environment_algo,
        )
        algorithm.run(self.data_portal)

        self.assertEqual(algorithm.environment, expected_platform)
982
983
    def test_api_symbol(self):
        """
        Run an algorithm defined by the ``api_symbol_algo`` script; the
        test passes if the run completes without raising.
        """
        TradingAlgorithm(
            script=api_symbol_algo,
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)
988
989
    def test_fixed_slippage(self):
990
        # verify order -> transaction -> portfolio position.
991
        # --------------
992
        test_algo = TradingAlgorithm(
993
            script="""
994
from zipline.api import (slippage,
995
                         commission,
996
                         set_slippage,
997
                         set_commission,
998
                         order,
999
                         record,
1000
                         sid)
1001
1002
def initialize(context):
1003
    model = slippage.FixedSlippage(spread=0.10)
1004
    set_slippage(model)
1005
    set_commission(commission.PerTrade(100.00))
1006
    context.count = 1
1007
    context.incr = 0
1008
1009
def handle_data(context, data):
1010
    if context.incr < context.count:
1011
        order(sid(0), -1000)
1012
    record(price=data[0].price)
1013
1014
    context.incr += 1""",
1015
            sim_params=self.sim_params,
1016
            env=self.env,
1017
        )
1018
        results = test_algo.run(self.data_portal)
1019
1020
        # flatten the list of txns
1021
        all_txns = [val for sublist in results["transactions"].tolist()
1022
                    for val in sublist]
1023
1024
        self.assertEqual(len(all_txns), 1)
1025
        txn = all_txns[0]
1026
1027
        self.assertEqual(100.0, txn["commission"])
1028
        expected_spread = 0.05
1029
        expected_commish = 0.10
1030
        expected_price = test_algo.recorded_vars["price"] - expected_spread \
1031
            - expected_commish
1032
1033
        self.assertEqual(expected_price, txn['price'])
1034
1035
    @parameterized.expand(
1036
        [
1037
            ('no_minimum_commission', 0,),
1038
            ('default_minimum_commission', 1,),
1039
            ('alternate_minimum_commission', 2,),
1040
        ]
1041
    )
1042
    def test_volshare_slippage(self, name, minimum_commission):
1043
        tempdir = TempDirectory()
1044
        try:
1045
            if name == "default_minimum_commission":
1046
                commission_line = "set_commission(commission.PerShare(0.02))"
1047
            else:
1048
                commission_line = \
1049
                    "set_commission(commission.PerShare(0.02, " \
1050
                    "min_trade_cost={0}))".format(minimum_commission)
1051
1052
            # verify order -> transaction -> portfolio position.
1053
            # --------------
1054
            test_algo = TradingAlgorithm(
1055
                script="""
1056
from zipline.api import *
1057
1058
def initialize(context):
1059
    model = slippage.VolumeShareSlippage(
1060
                            volume_limit=.3,
1061
                            price_impact=0.05
1062
                       )
1063
    set_slippage(model)
1064
    {0}
1065
1066
    context.count = 2
1067
    context.incr = 0
1068
1069
def handle_data(context, data):
1070
    if context.incr < context.count:
1071
        # order small lots to be sure the
1072
        # order will fill in a single transaction
1073
        order(sid(0), 5000)
1074
    record(price=data[0].price)
1075
    record(volume=data[0].volume)
1076
    record(incr=context.incr)
1077
    context.incr += 1
1078
    """.format(commission_line),
1079
                sim_params=self.sim_params,
1080
                env=self.env,
1081
            )
1082
            set_algo_instance(test_algo)
1083
            trades = factory.create_daily_trade_source(
1084
                [0], self.sim_params, self.env)
1085
            data_portal = create_data_portal_from_trade_history(
1086
                self.env, tempdir, self.sim_params, {0: trades})
1087
            results = test_algo.run(data_portal=data_portal)
1088
1089
            all_txns = [
1090
                val for sublist in results["transactions"].tolist()
1091
                for val in sublist]
1092
1093
            self.assertEqual(len(all_txns), 67)
1094
            first_txn = all_txns[0]
1095
1096
            if minimum_commission == 0:
1097
                commish = first_txn["amount"] * 0.02
1098
                self.assertEqual(commish, first_txn["commission"])
1099
            else:
1100
                self.assertEqual(minimum_commission, first_txn["commission"])
1101
1102
            # import pdb; pdb.set_trace()
1103
            # commish = first_txn["amount"] * per_share_commish
1104
            # self.assertEqual(commish, first_txn["commission"])
1105
            # self.assertEqual(2.029, first_txn["price"])
1106
        finally:
1107
            tempdir.cleanup()
1108
1109
    def test_algo_record_vars(self):
        """
        Run the ``record_variables`` script and check that the recorded
        ``incr`` value on each of the 251 result rows equals its 1-based
        row number.
        """
        results = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)

        for row, expected_incr in enumerate(range(1, 252)):
            self.assertEqual(results.iloc[row]["incr"], expected_incr)
1119
1120
    def test_algo_record_allow_mock(self):
        """
        Test that a ``MagicMock`` value can be passed to ``record``.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        algorithm = TradingAlgorithm(
            sim_params=self.sim_params,
            script=record_variables,
        )
        set_algo_instance(algorithm)

        mocked_value = MagicMock()
        algorithm.record(foo=mocked_value)
1134
1135
    def test_algo_record_nan(self):
        """
        Run the ``record_float_magic`` script formatted with 'nan' and
        check that the recorded ``data`` value is NaN on each of the 251
        result rows.
        """
        results = TradingAlgorithm(
            script=record_float_magic % 'nan',
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)

        for row in range(251):
            recorded = results.iloc[row]["data"]
            self.assertTrue(np.isnan(recorded))
1145
1146
    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        TradingAlgorithm(
            script=call_all_order_methods,
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)
1157
1158
    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        algo_kwargs = {
            'script': call_order_in_init,
            'sim_params': self.sim_params,
            'env': self.env,
        }

        # Construction stays inside the context manager so the expected
        # error is accepted from either the constructor or the run.
        with self.assertRaises(OrderDuringInitialize):
            TradingAlgorithm(**algo_kwargs).run(self.data_portal)
1170
1171
    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        TradingAlgorithm(
            script=access_portfolio_in_init,
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)
1181
1182
    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        TradingAlgorithm(
            script=access_account_in_init,
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)
1192
1193
1194
class TestGetDatetime(TestCase):
    """
    Checks the ``get_datetime`` API function in a minute-frequency
    simulation, with and without an explicit timezone argument.
    """

    @classmethod
    def setUpClass(cls):
        """
        Create the environment, minute-frequency simulation parameters,
        temporary directory and data portal shared by the tests.
        """
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1])

        setup_logger(cls)

        period_start = to_utc('2014-01-02 9:31')
        period_end = to_utc('2014-01-03 9:31')
        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=cls.env,
            start=period_start,
            end=period_end,
        )

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env, cls.tempdir, cls.sim_params, [1]
        )

    @classmethod
    def tearDownClass(cls):
        """
        Drop the environment, tear down the logger and delete the
        temporary directory, in that order.
        """
        del cls.env
        teardown_logger(cls)

        scratch_dir = cls.tempdir
        scratch_dir.cleanup()

    @parameterized.expand(
        [
            ('default', None,),
            ('utc', 'UTC',),
            ('us_east', 'US/Eastern',),
        ]
    )
    def test_get_datetime(self, name, tz):
        """
        Run a script that calls ``get_datetime(tz)`` on its first bar.

        The script raises ValueError if the returned timestamp's zone
        differs from ``tz`` (or from 'UTC' when ``tz`` is None), or if the
        timestamp converted to US/Eastern is not at 9:31.  Afterwards the
        algorithm's ``first_bar`` flag must be False.
        """
        # repr(tz) splices either the quoted zone name or a bare None
        # into the script source.
        script_source = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        algorithm = TradingAlgorithm(
            script=script_source,
            sim_params=self.sim_params,
            env=self.env,
        )
        algorithm.run(self.data_portal)

        self.assertFalse(algorithm.first_bar)
1262
1263
1264
class TestTradingControls(TestCase):
1265
1266
    @classmethod
    def setUpClass(cls):
        """
        Create the fixtures shared by the trading-control tests: a
        four-day simulation, a single asset (sid 133) alive for exactly
        that period, and a data portal for that sid.
        """
        cls.sid = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=4,
            env=cls.env,
        )

        asset_lifetime = {
            'start_date': cls.sim_params.period_start,
            'end_date': cls.sim_params.period_end,
        }
        cls.env.write_data(equities_data={cls.sid: asset_lifetime})

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env, cls.tempdir, cls.sim_params, [cls.sid]
        )
1288
1289
    @classmethod
    def tearDownClass(cls):
        """
        Drop the shared trading environment, then delete the temporary
        directory.
        """
        del cls.env

        scratch_dir = cls.tempdir
        scratch_dir.cleanup()
1293
1294
    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """
        Install ``handle_data`` on ``algo``, run it against the shared
        data portal and check the outcome.

        If ``expected_exc`` is truthy the run must raise it; otherwise the
        run executes without an exception expectation.  In both cases
        ``algo.order_count`` must equal ``expected_order_count`` afterwards.
        """
        algo._handle_data = handle_data

        if expected_exc:
            run_context = self.assertRaises(expected_exc)
        else:
            run_context = nullctx()

        with run_context:
            algo.run(self.data_portal)

        self.assertEqual(algo.order_count, expected_order_count)
1304
1305
    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        """
        Run ``algo`` with ``handle_data`` expecting no exception and
        ``order_count`` counted orders.

        The default of 4 assumes one order per handle_data call.
        """
        self._check_algo(
            algo,
            handle_data,
            expected_order_count=order_count,
            expected_exc=None,
        )
1308
1309
    def check_algo_fails(self, algo, handle_data, order_count):
        """
        Run ``algo`` with ``handle_data`` expecting a
        TradingControlViolation, with ``order_count`` orders counted
        before the violation.
        """
        self._check_algo(
            algo,
            handle_data,
            expected_order_count=order_count,
            expected_exc=TradingControlViolation,
        )
1314
1315
    def test_set_max_position_size(self):
        """
        Exercise SetMaxPositionSizeAlgorithm with share and notional
        limits, with the control bound to the traded sid, to a different
        sid, and to no sid at all.
        """
        def order_every_bar(amount):
            # Build a handle_data that orders `amount` shares of self.sid
            # and then bumps the counter; if the order raises, the counter
            # is left untouched for that call.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amount)
                algo.order_count += 1
            return handle_data

        shared_kwargs = {'sim_params': self.sim_params, 'env': self.env}

        # Buy one share four times.  Should be fine.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           **shared_kwargs)
        self.check_algo_succeeds(algo, order_every_bar(1))

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           **shared_kwargs)
        self.check_algo_fails(algo, order_every_bar(3), 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=67.0,
                                           **shared_kwargs)
        self.check_algo_fails(algo, order_every_bar(3), 2)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=67.0,
                                           **shared_kwargs)
        self.check_algo_succeeds(algo, order_every_bar(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!. Should
        # fail because setting sid to None makes the control apply to all sids.
        algo = SetMaxPositionSizeAlgorithm(max_shares=10,
                                           max_notional=61.0,
                                           **shared_kwargs)
        self.check_algo_fails(algo, order_every_bar(10000), 0)
1375
1376
    def test_set_do_not_order_list(self):
        """
        Exercise SetDoNotOrderListAlgorithm with a restricted list that
        contains the traded sid (must fail before any order is counted)
        and with one that does not (must succeed).
        """
        # Both scenarios place the same order on every call.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 100)
            algo.order_count += 1

        def algo_restricting(restricted_sids):
            return SetDoNotOrderListAlgorithm(
                sid=self.sid,
                restricted_list=restricted_sids,
                sim_params=self.sim_params,
                env=self.env,
            )

        # set the restricted list to be the sid, and fail.
        self.check_algo_fails(algo_restricting([self.sid]), handle_data, 0)

        # set the restricted list to exclude the sid, and succeed
        self.check_algo_succeeds(algo_restricting([134, 135, 136]),
                                 handle_data)
1404
1405
    def test_set_max_order_size(self):
        """
        Exercise SetMaxOrderSizeAlgorithm with share and notional limits,
        with the control bound to the traded sid, to a different sid, and
        to no sid at all.
        """
        def order_fixed(amount):
            # handle_data that orders the same number of shares each call.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amount)
                algo.order_count += 1
            return handle_data

        def order_growing(algo, data):
            # Orders 1, 2, 3, ... shares: one more than the number of
            # orders counted so far.
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        shared_kwargs = {'sim_params': self.sim_params, 'env': self.env}

        # Buy one share.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=500.0,
                                        **shared_kwargs)
        self.check_algo_succeeds(algo, order_fixed(1))

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed shares.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=3,
                                        max_notional=500.0,
                                        **shared_kwargs)
        self.check_algo_fails(algo, order_growing, 3)

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed notional.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=40.0,
                                        **shared_kwargs)
        self.check_algo_fails(algo, order_growing, 3)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
                                        max_shares=1,
                                        max_notional=1.0,
                                        **shared_kwargs)
        self.check_algo_succeeds(algo, order_fixed(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!.
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        algo = SetMaxOrderSizeAlgorithm(max_shares=1,
                                        max_notional=1.0,
                                        **shared_kwargs)
        self.check_algo_fails(algo, order_fixed(10000), 0)
1467
1468
    def test_set_max_order_count(self):
        """
        Exercise SetMaxOrderCountAlgorithm in a minute-frequency
        simulation built from its own environment, simulation parameters
        and data portal rather than the class-level fixtures.
        """
        tempdir = TempDirectory()
        try:
            env = TradingEnvironment()
            sim_params = factory.create_simulation_parameters(
                num_days=4, env=env, data_frequency="minute")

            asset_lifetime = {
                'start_date': sim_params.period_start,
                'end_date': sim_params.period_end + timedelta(days=1),
            }
            env.write_data(equities_data={1: asset_lifetime})

            data_portal = create_data_portal(env, tempdir, sim_params, [1])

            def place_five_orders(algo):
                # Order one share of sid 1 five times, counting each
                # order only after it has been placed.
                for _ in range(5):
                    algo.order(algo.sid(1), 1)
                    algo.order_count += 1

            def run_expecting_violation(max_count, handle_data):
                # Run a fresh algorithm that must hit the order-count
                # control, and hand it back for inspection.
                algo = SetMaxOrderCountAlgorithm(max_count,
                                                 sim_params=sim_params,
                                                 env=env)
                algo._handle_data = handle_data
                with self.assertRaises(TradingControlViolation):
                    algo.run(data_portal)
                return algo

            # Five orders on every bar against a limit of three.
            def handle_data(algo, data):
                place_five_orders(algo)

            algo = run_expecting_violation(3, handle_data)
            self.assertEqual(algo.order_count, 3)

            # This time, order 5 times twice in a single day. The last order
            # of the second batch should fail.
            def handle_data2(algo, data):
                if algo.minute_count in (0, 100):
                    place_five_orders(algo)

                algo.minute_count += 1

            algo = run_expecting_violation(9, handle_data2)
            self.assertEqual(algo.order_count, 9)

            # Order only on every 390th bar (presumably once per trading
            # day of minute bars).
            def handle_data3(algo, data):
                if (algo.minute_count % 390) == 0:
                    place_five_orders(algo)

                algo.minute_count += 1

            # Only 5 orders are placed per day, so this should pass even
            # though in total more than 20 orders are placed.
            algo = SetMaxOrderCountAlgorithm(5, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data3
            algo.run(data_portal)
        finally:
            tempdir.cleanup()
1536
1537
    def test_long_only(self):
        """
        Exercise SetLongOnlyAlgorithm with four order patterns: an
        immediate sale, alternating buys and sells, a build-up followed by
        a full sell-off, and a build-up followed by an oversized sale.
        """
        def new_algo():
            return SetLongOnlyAlgorithm(sim_params=self.sim_params,
                                        env=self.env)

        def order_in_sequence(amounts):
            # handle_data that places amounts[0], amounts[1], ... on
            # successive calls, indexed by the orders counted so far.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amounts[algo.order_count])
                algo.order_count += 1
            return handle_data

        # Sell immediately -> fail immediately.
        def sell_one(algo, data):
            algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1

        self.check_algo_fails(new_algo(), sell_one, 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        def alternate(algo, data):
            amount = 1 if (algo.order_count % 2) == 0 else -1
            algo.order(algo.sid(self.sid), amount)
            algo.order_count += 1

        self.check_algo_succeeds(new_algo(), alternate)

        # Buy on first three days, then sell off holdings.  Should succeed.
        self.check_algo_succeeds(new_algo(),
                                 order_in_sequence([1, 1, 1, -3]))

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        self.check_algo_fails(new_algo(),
                              order_in_sequence([1, 1, 1, -4]),
                              3)
1572
1573
    def test_register_post_init(self):
        """
        Check that each trading-control setter raises
        RegisterTradingControlPostInit when it is called from
        ``handle_data`` instead of ``initialize``.
        """
        def initialize(algo):
            algo.initialized = True

        def handle_data(algo, data):
            late_registrations = [
                lambda: algo.set_max_position_size(self.sid, 1, 1),
                lambda: algo.set_max_order_size(self.sid, 1, 1),
                lambda: algo.set_max_order_count(1),
                lambda: algo.set_long_only(),
            ]
            for register in late_registrations:
                with self.assertRaises(RegisterTradingControlPostInit):
                    register()

        TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.data_portal)
1593
1594
    def _run_with_asset_dates(self, start_date, end_date, expected_exc=None):
        """
        Run SetAssetDateBoundsAlgorithm for sid 0 with the given asset
        start and end dates.

        Each call gets its own TradingEnvironment, temporary directory and
        data portal; the directory is removed even if the run fails.  When
        ``expected_exc`` is given the run must raise it.
        """
        tempdir = TempDirectory()
        try:
            temp_env = TradingEnvironment()

            data_portal = create_data_portal(
                temp_env,
                tempdir,
                self.sim_params,
                [0]
            )

            metadata = {0: {'start_date': start_date,
                            'end_date': end_date}}

            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )

            with self.assertRaises(expected_exc) if expected_exc \
                    else nullctx():
                algo.run(data_portal)
        finally:
            tempdir.cleanup()

    def test_asset_date_bounds(self):
        """
        Check that ordering an asset is only allowed while the simulation
        falls inside the asset's start/end dates.
        """
        # Run the algorithm with a sid that ends far in the future
        self._run_with_asset_dates(self.sim_params.period_start,
                                   '2020-01-01')

        # Run the algorithm with a sid that has already ended
        self._run_with_asset_dates('1989-01-01',
                                   '1990-01-01',
                                   TradingControlViolation)

        # Run the algorithm with a sid that has not started
        self._run_with_asset_dates('2020-01-01',
                                   '2021-01-01',
                                   TradingControlViolation)
1668
1669
1670
class TestAccountControls(TestCase):
    """Tests for account-level controls, run against a single equity.

    The class-level fixtures hold one equity (``sidint``), simulation
    parameters created with ``num_days=4`` and a data portal built from a
    four-bar trade history for that equity.
    """

    @classmethod
    def setUpClass(cls):
        """Create the environment, asset metadata and shared data portal."""
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=4, env=cls.env
        )

        # Key the metadata on cls.sidint rather than repeating the literal,
        # so the asset written here is always the one the tests order.
        cls.env.write_data(equities_data={
            cls.sidint: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end + timedelta(days=1)
            }
        })

        cls.tempdir = TempDirectory()
        try:
            trades_by_sid = {
                cls.sidint: factory.create_trade_history(
                    cls.sidint,
                    [10.0, 10.0, 11.0, 11.0],
                    [100, 100, 100, 300],
                    timedelta(days=1),
                    cls.sim_params,
                    cls.env,
                )
            }

            cls.data_portal = create_data_portal_from_trade_history(
                cls.env,
                cls.tempdir,
                cls.sim_params,
                trades_by_sid)
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so the
            # temporary directory has to be released here.
            cls.tempdir.cleanup()
            raise

    @classmethod
    def tearDownClass(cls):
        """Drop the shared environment and remove the temporary directory."""
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` installed and check the outcome.

        :param algo: algorithm instance to run against the class data portal.
        :param handle_data: callable assigned to ``algo._handle_data``.
        :param expected_exc: exception type the run must raise, or a falsy
            value when the run is expected to complete without raising.
        """
        algo._handle_data = handle_data

        if expected_exc:
            outcome = self.assertRaises(expected_exc)
        else:
            # No exception expected: use a do-nothing context manager so
            # any error from the run propagates and fails the test.
            outcome = nullctx()

        with outcome:
            algo.run(self.data_portal)

    def check_algo_succeeds(self, algo, handle_data):
        """Assert that running ``algo`` raises nothing."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Assert that running ``algo`` raises AccountControlViolation."""
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):
        """Ordering one share fails at max leverage 0 and passes at 1."""

        # Both scenarios place the same order; only the limit differs.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sidint), 1)

        # Set max leverage to 0 so buying one share fails.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Set max leverage to 1 so buying one share passes
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)
1745
1746
1747
class TestClosePosAlgo(TestCase):
    """Check pnl and positions for a future that is auto-closed.

    The fixtures cover 2006-01-09 through 2006-01-12 with a four-bar trade
    history (prices 1, 1, 2, 4) for a single sid.
    """

    @classmethod
    def setUpClass(cls):
        """Build the environment, simulation parameters and data portal."""
        cls.tempdir = TempDirectory()
        try:
            cls.env = TradingEnvironment()
            cls.days = pd.date_range(
                start=pd.Timestamp("2006-01-09", tz='UTC'),
                end=pd.Timestamp("2006-01-12", tz='UTC'))

            cls.sid = 1

            cls.sim_params = factory.create_simulation_parameters(
                start=cls.days[0],
                end=cls.days[-1]
            )

            trades_by_sid = {
                cls.sid: factory.create_trade_history(
                    cls.sid,
                    [1, 1, 2, 4],
                    [1e9, 1e9, 1e9, 1e9],
                    timedelta(days=1),
                    cls.sim_params,
                    cls.env
                )
            }

            cls.data_portal = create_data_portal_from_trade_history(
                cls.env,
                cls.tempdir,
                cls.sim_params,
                trades_by_sid
            )
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so the
            # temporary directory has to be released here.
            cls.tempdir.cleanup()
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in setUpClass."""
        cls.tempdir.cleanup()

    def test_auto_close_future(self):
        """A held future shows no position on the final simulation day."""
        # End date and auto-close date are both the trading day after the
        # last simulated day.
        day_after_sim = self.env.next_trading_day(
            self.sim_params.trading_days[-1])

        self.env.write_data(futures_data={
            self.sid: {
                "start_date": self.sim_params.trading_days[0],
                "end_date": day_after_sim,
                'symbol': 'TEST',
                'asset_type': 'future',
                'auto_close_date': day_after_sim
            }
        })

        algo = TestAlgorithm(sid=self.sid, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env,
                             sim_params=self.sim_params)

        # Check results
        results = algo.run(self.data_portal)

        expected_pnl = [0, 0, 1, 2]
        self.check_algo_pnl(results, expected_pnl)

        expected_positions = [0, 1, 1, 0]
        self.check_algo_positions(results, expected_positions)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert ``results.pnl`` matches ``expected_pnl`` element-wise."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the per-day position amounts match ``expected_positions``.

        Each entry of ``results.positions`` is treated as one day. A falsy
        entry counts as a position of 0; otherwise the ``'amount'`` of the
        entry's first element is compared.
        """
        # Without this, a result set shorter than the expectation would
        # pass because the loop below only visits the days that exist.
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of days not equal, actual={0}, expected={1}".
            format(len(results.positions), len(expected_positions)))

        for i, amount in enumerate(results.positions):
            if amount:
                actual_position = amount[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected_positions[i],
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected_positions[i]))
1829
1830
1831
class TestFutureFlip(TestCase):
    """Check pnl and positions when an algo flips a futures position.

    The fixtures build a 2006-01-09..2006-01-12 date range, simulate up to
    its second-to-last day, and use a three-bar trade history (prices
    1, 2, 4) for a single sid.
    """

    @classmethod
    def setUpClass(cls):
        """Build the environment, simulation parameters and data portal."""
        cls.tempdir = TempDirectory()
        try:
            cls.env = TradingEnvironment()
            cls.days = pd.date_range(
                start=pd.Timestamp("2006-01-09", tz='UTC'),
                end=pd.Timestamp("2006-01-12", tz='UTC'))

            cls.sid = 1

            cls.sim_params = factory.create_simulation_parameters(
                start=cls.days[0],
                end=cls.days[-2]
            )

            trades = factory.create_trade_history(
                cls.sid,
                [1, 2, 4],
                [1e9, 1e9, 1e9],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )

            trades_by_sid = {
                cls.sid: trades
            }

            cls.data_portal = create_data_portal_from_trade_history(
                cls.env,
                cls.tempdir,
                cls.sim_params,
                trades_by_sid
            )
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so the
            # temporary directory has to be released here.
            cls.tempdir.cleanup()
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in setUpClass."""
        cls.tempdir.cleanup()

    def test_flip_algo(self):
        """Positions go 0, 1, -1 and pnl goes 0, 5, -10 across the run."""
        metadata = {self.sid: {'symbol': 'TEST',
                               'start_date': self.sim_params.trading_days[0],
                               'end_date': self.env.next_trading_day(
                                   self.sim_params.trading_days[-1]),
                               'contract_multiplier': 5}}
        self.env.write_data(futures_data=metadata)

        algo = FutureFlipAlgo(sid=self.sid, amount=1, env=self.env,
                              commission=PerShare(0),
                              order_count=0,  # not applicable but required
                              sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        expected_positions = [0, 1, -1]
        self.check_algo_positions(results, expected_positions)

        expected_pnl = [0, 5, -10]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert ``results.pnl`` matches ``expected_pnl`` element-wise."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the per-day position amounts match ``expected_positions``.

        Each entry of ``results.positions`` is treated as one day. A falsy
        entry counts as a position of 0; otherwise the ``'amount'`` of the
        entry's first element is compared.
        """
        # Without this, a result set shorter than the expectation would
        # pass because the loop below only visits the days that exist.
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of days not equal, actual={0}, expected={1}".
            format(len(results.positions), len(expected_positions)))

        for i, amount in enumerate(results.positions):
            if amount:
                actual_position = amount[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected_positions[i],
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected_positions[i]))
1906
1907
1908
class TestTradingAlgorithm(TestCase):
1909
    def setUp(self):
        """Create a trading environment and keep its first four days."""
        environment = TradingEnvironment()
        self.env = environment
        self.days = environment.trading_days[:4]
1912
1913
    def test_analyze_called(self):
        """``analyze`` is handed the same object that ``run`` returns."""
        self.perf_ref = None

        def initialize(context):
            pass

        def handle_data(context, data):
            pass

        def analyze(context, perf):
            # Remember what the algorithm passed in, for the identity check.
            self.perf_ref = perf

        algorithm = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            analyze=analyze,
        )

        returned_perf = algorithm.run(FakeDataPortal())
        self.assertIs(returned_perf, self.perf_ref)
1932