Completed
Pull Request — master (#858)
by Eddie
01:55
created

tests.TestTradingControls.check_algo_succeeds()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 3
rs 10
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range
20
from testfixtures import TempDirectory
21
from textwrap import dedent
22
from unittest import TestCase
23
24
import numpy as np
25
import pandas as pd
26
27
from zipline.assets import Equity, Future
28
from zipline.utils.api_support import ZiplineAPI
29
from zipline.utils.control_flow import nullctx
30
from zipline.utils.test_utils import (
31
    setup_logger,
32
    teardown_logger,
33
    FakeDataPortal)
34
import zipline.utils.factory as factory
35
36
from zipline.errors import (
37
    OrderDuringInitialize,
38
    RegisterTradingControlPostInit,
39
    TradingControlViolation,
40
    AccountControlViolation,
41
    SymbolNotFound,
42
    RootSymbolNotFound,
43
    UnsupportedDatetimeFormat,
44
)
45
from zipline.test_algorithms import (
46
    access_account_in_init,
47
    access_portfolio_in_init,
48
    AmbitiousStopLimitAlgorithm,
49
    EmptyPositionsAlgorithm,
50
    InvalidOrderAlgorithm,
51
    RecordAlgorithm,
52
    FutureFlipAlgo,
53
    TestAlgorithm,
54
    TestOrderAlgorithm,
55
    TestOrderPercentAlgorithm,
56
    TestOrderStyleForwardingAlgorithm,
57
    TestOrderValueAlgorithm,
58
    TestRegisterTransformAlgorithm,
59
    TestTargetAlgorithm,
60
    TestTargetPercentAlgorithm,
61
    TestTargetValueAlgorithm,
62
    SetLongOnlyAlgorithm,
63
    SetAssetDateBoundsAlgorithm,
64
    SetMaxPositionSizeAlgorithm,
65
    SetMaxOrderCountAlgorithm,
66
    SetMaxOrderSizeAlgorithm,
67
    SetDoNotOrderListAlgorithm,
68
    SetMaxLeverageAlgorithm,
69
    api_algo,
70
    api_get_environment_algo,
71
    api_symbol_algo,
72
    call_all_order_methods,
73
    call_order_in_init,
74
    handle_data_api,
75
    handle_data_noop,
76
    initialize_api,
77
    initialize_noop,
78
    noop_algo,
79
    record_float_magic,
80
    record_variables,
81
)
82
from zipline.utils.context_tricks import CallbackManager
83
import zipline.utils.events
84
from zipline.utils.test_utils import to_utc
85
86
from zipline.finance.execution import LimitOrder
87
from zipline.finance.trading import SimulationParameters
88
from zipline.utils.api_support import set_algo_instance
89
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
90
from zipline.algorithm import TradingAlgorithm
91
from zipline.finance.trading import TradingEnvironment
92
from zipline.finance.commission import PerShare
93
94
from zipline.utils.test_utils import (
95
    create_data_portal,
96
    create_data_portal_from_trade_history
97
)
98
99
# Because test cases appear to reuse some resources.
100
101
102
# NOTE(review): presumably read by nose's multiprocess plugin; False keeps
# this module's test cases from being split across worker processes
# (they share class-level fixtures) — confirm against the runner config.
_multiprocess_can_split_ = False
103
104
105
class TestRecordAlgorithm(TestCase):
    """Checks the variables recorded by ``RecordAlgorithm`` during a run."""

    @classmethod
    def setUpClass(cls):
        # A single equity (sid 133) over a four-day simulation.
        cls.env = TradingEnvironment()
        cls.sids = [133]
        cls.env.write_data(equities_identifiers=cls.sids)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=4,
            env=cls.env,
        )

        cls.tempdir = TempDirectory()
        cls.data_portal = create_data_portal(
            cls.env, cls.tempdir, cls.sim_params, cls.sids,
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def test_record_incr(self):
        """Recorded columns count up per row, except ``name2`` (always 2)."""
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        output = algo.run(self.data_portal)

        row_count = len(output)
        # Expected sequence 1, 2, ..., len(output).
        counting_up = range(1, row_count + 1)
        expectations = [
            ('incr', counting_up),
            ('name', counting_up),
            ('name2', [2] * row_count),
            ('name3', counting_up),
        ]
        for column, expected in expectations:
            np.testing.assert_array_equal(output[column].values, expected)
144
145
146
class TestMiscellaneousAPI(TestCase):
    """Tests for assorted ``TradingAlgorithm`` API methods.

    The class fixture writes equities 1 and 2, two equities sharing the
    symbol ``PLAY`` (sid 3: 2002-2004, sid 4: 2005-2006) and four ``CL``
    futures contracts (sids 5-8), then builds a two-day, minute-frequency
    simulation that emits daily.
    """

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='daily',
            env=cls.env,
        )

        cls.temp_dir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.temp_dir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.temp_dir.cleanup()

    def test_zipline_api_resolves_dynamically(self):
        """``zipline.api`` functions dispatch to the active algorithm."""
        # Make a dummy algo. Reuse the class environment (as every other
        # test here does) rather than letting the algorithm build its own.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Verify that api methods get resolved dynamically by patching them out
        # and then calling them
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            def fake_method(*args, **kwargs):
                # Called immediately below, so the late-bound ``sentinel``
                # is always the one created in this iteration.
                return sentinel
            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_cannot_iterate_over_data(self):
        """Iterating over ``data`` inside ``handle_data`` raises ValueError."""
        def initialize(algo):
            pass

        def handle_data(algo, data):
            for asset in data:
                pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)

        with self.assertRaises(ValueError):
            algo.run(self.data_portal)

    def test_get_environment(self):
        """``get_environment()`` returns the platform; ``'*'`` returns all."""
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_get_open_orders(self):
        """``get_open_orders`` tracks open orders per sid across minutes."""
        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The order for sid 1 should have filled; the three limit
                # orders for sid 2 should still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_schedule_function(self):
        """A market-open, every-day schedule fires once per simulated day."""
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        """The event context wraps handle_data and every registered event."""
        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.data_portal)

        # Number of handle_data calls in this fixture (presumably two
        # 390-minute sessions).
        expected_bar_count = 780
        # pre, handle_data, f, g and post each run once per bar.
        expected_functions = [pre, handle_data, f, g, post] * expected_bar_count
        expected_call_count = len(expected_functions)

        self.assertEqual(len(expected_data), expected_bar_count)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        self.assertEqual(
            len(function_stack),
            expected_call_count,
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), expected_call_count),
        )
        # Use distinct names so the local callbacks ``f``/``g`` are not
        # rebound by the loop.
        call_pairs = zip(function_stack, expected_functions)
        for n, (actual, expected) in enumerate(call_pairs):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_schedule_function_rule_creation(self, mode):
        """schedule_function composes Always with the frequency's time rule.

        Each parameter is a 1-tuple so both cases are expanded the same way.
        """
        def nop(*args, **kwargs):
            return None

        # ``self.sim_params`` is shared by every test in this class; restore
        # its data_frequency afterwards so this test does not leak state.
        original_frequency = self.sim_params.data_frequency
        self.sim_params.data_frequency = mode
        try:
            algo = TradingAlgorithm(
                initialize=nop,
                handle_data=nop,
                sim_params=self.sim_params,
                env=self.env,
            )

            # Schedule something for NOT Always.
            algo.schedule_function(
                nop, time_rule=zipline.utils.events.Never())

            event_rule = algo.event_manager._events[1].rule

            self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

            inner_rule = event_rule.rule
            self.assertIsInstance(inner_rule,
                                  zipline.utils.events.ComposedRule)

            first = inner_rule.first
            second = inner_rule.second
            composer = inner_rule.composer

            self.assertIsInstance(first, zipline.utils.events.Always)

            if mode == 'daily':
                self.assertIsInstance(second, zipline.utils.events.Always)
            else:
                self.assertIsInstance(second, zipline.utils.events.Never)

            self.assertIs(composer,
                          zipline.utils.events.ComposedRule.lazy_and)
        finally:
            self.sim_params.data_frequency = original_frequency

    def test_asset_lookup(self):
        """symbol()/symbols() resolve PLAY by date; non-strings raise."""

        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.symbol(1)

        with self.assertRaises(TypeError):
            algo.symbol((1,))

        with self.assertRaises(TypeError):
            algo.symbol({1})

        with self.assertRaises(TypeError):
            algo.symbol([1])

        with self.assertRaises(TypeError):
            algo.symbol({'foo': 'bar'})

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('PLAY')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('FOOBAR')

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_symbol(1)

        with self.assertRaises(TypeError):
            algo.future_symbol((1,))

        with self.assertRaises(TypeError):
            algo.future_symbol({1})

        with self.assertRaises(TypeError):
            algo.future_symbol([1])

        with self.assertRaises(TypeError):
            algo.future_symbol({'foo': 'bar'})

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_chain(1)

        with self.assertRaises(TypeError):
            algo.future_chain((1,))

        with self.assertRaises(TypeError):
            algo.future_chain({1})

        with self.assertRaises(TypeError):
            algo.future_chain([1])

        with self.assertRaises(TypeError):
            algo.future_chain({'foo': 'bar'})

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Sids start at i + 3. This test writes them to its own fresh
        # TradingEnvironment, so they cannot clash with the class fixture.
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        # Create two assets with the same symbol but different
        # non-overlapping date ranges.
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the period end
        # dates for our assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end date is used as
        # the as_of_date; only the resolved symbol is checked here.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
654
655
656
class TestTransformAlgorithm(TestCase):
    """Order-method and run-behaviour tests over a small synthetic fixture.

    The fixture writes equities 0, 1 and 133 (live for the whole simulation)
    and future 3 (contract multiplier 10). It builds the data portal from
    four synthetic trades per equity, spaced one day apart.
    """

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)
        cls.sids = [0, 1, 133]
        cls.tempdir = TempDirectory()

        futures_metadata = {3: {'contract_multiplier': 10}}
        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

        cls.env.write_data(equities_data=equities_metadata,
                           futures_data=futures_metadata)

        trades_by_sid = {}
        for sid in cls.sids:
            trades_by_sid[sid] = factory.create_trade_history(
                sid,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

        cls.data_portal.current_day = cls.sim_params.trading_days[0]

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        del cls.env
        cls.tempdir.cleanup()

    def test_invalid_order_parameters(self):
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

    def test_run_twice(self):
        """Two fresh algorithms over the same data produce the same results."""
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.data_portal)

        # Create a new trading algorithm, which will
        # use the newly instantiated environment.
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.data_portal)

        # FIXME I think we are getting Nans due to fixed benchmark,
        # so dropping them for now.
        res1 = res1.fillna(method='ffill')
        res2 = res2.fillna(method='ffill')

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        """The algorithm keeps the data_frequency of the params it is given.

        Each case builds its own SimulationParameters, so the shared
        ``self.sim_params`` fixture is deliberately left untouched.
        """
        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='daily')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'daily')

        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='minute')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'minute')

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 0 is an Equity
        asset_to_test = algo.sid(0)
        self.assertIsInstance(asset_to_test, Equity)

        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 3 is a Future
        asset_to_test = algo.sid(3)
        self.assertIsInstance(asset_to_test, Future)

        algo.run(self.data_portal)

    @parameterized.expand([
        ("order",),
        ("order_value",),
        ("order_percent",),
        ("order_target",),
        ("order_target_percent",),
        ("order_target_value",),
    ])
    def test_order_method_style_forwarding(self, order_style):
        algo = TestOrderStyleForwardingAlgorithm(
            sim_params=self.sim_params,
            method_name=order_style,
            env=self.env
        )
        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,)
    ])
    def test_minute_data(self, algo_class):
        """Order algorithms run against a self-contained minute fixture."""
        tempdir = TempDirectory()

        try:
            env = TradingEnvironment()

            sim_params = SimulationParameters(
                period_start=pd.Timestamp('2002-1-2', tz='UTC'),
                period_end=pd.Timestamp('2002-1-4', tz='UTC'),
                capital_base=float("1.0e5"),
                data_frequency='minute',
                env=env
            )

            equities_metadata = {}

            for sid in [0, 1]:
                equities_metadata[sid] = {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }

            env.write_data(equities_data=equities_metadata)

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [0, 1]
            )

            algo = algo_class(sim_params=sim_params, env=env)
            algo.run(data_portal)
        finally:
            tempdir.cleanup()
850
851
852
class TestPositions(TestCase):
    """
    Position bookkeeping checks over a four-day daily simulation of
    sids 1 and 133.
    """

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.sids = [1, 133]
        cls.tempdir = TempDirectory()

        # Every asset is alive for exactly the simulated period.
        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

        cls.env.write_data(equities_data=equities_metadata)

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        # Release the shared environment, as the other test classes do.
        del cls.env
        cls.tempdir.cleanup()

    def test_empty_portfolio(self):
        # The number of open positions reported for each day must follow
        # the enter-then-exit pattern of EmptyPositionsAlgorithm.
        algo = EmptyPositionsAlgorithm(self.sids,
                                       sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.data_portal)

        expected_position_count = [
            0,  # Before entering the first position
            2,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        for i, expected in enumerate(expected_position_count):
            # ``i`` is a row position, so use the positional indexer; the
            # deprecated ``.ix`` only fell back to positional lookup
            # because the index is not integer-valued.
            self.assertEqual(daily_stats.iloc[i]['num_positions'],
                             expected)

    def test_noop_orders(self):
        # AmbitiousStopLimitAlgorithm's orders are expected never to fill,
        # so no position should ever be opened.
        algo = AmbitiousStopLimitAlgorithm(sid=1,
                                           sim_params=self.sim_params,
                                           env=self.env)
        daily_stats = algo.run(self.data_portal)

        # Verify that positions are empty for all dates.
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
        self.assertTrue(empty_positions.all())

class TestAlgoScript(TestCase):
    """
    Tests for algorithms supplied either as initialize/handle_data
    functions or as script source strings, covering the zipline.api
    calls, slippage/commission models and record().
    """

    # Length of the simulation in days. It is used for the simulation
    # parameters, for the synthetic trade history and by the per-bar
    # record assertions, so it must stay a single value.
    NUM_DAYS = 251

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=cls.NUM_DAYS,
            env=cls.env,
        )

        cls.sids = [0, 1, 3, 133]
        cls.tempdir = TempDirectory()

        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

            # sid 3 gets a symbol so that scripts can resolve it by name.
            if sid == 3:
                equities_metadata[sid]["symbol"] = "TEST"
                equities_metadata[sid]["asset_type"] = "equity"

        cls.env.write_data(equities_data=equities_metadata)

        days = cls.NUM_DAYS

        # Flat price (10.0) and volume (100) trade history for sids 0 and 3.
        trades_by_sid = {
            0: factory.create_trade_history(
                0,
                [10.0] * days,
                [100] * days,
                timedelta(days=1),
                cls.sim_params,
                cls.env),
            3: factory.create_trade_history(
                3,
                [10.0] * days,
                [100] * days,
                timedelta(days=1),
                cls.sim_params,
                cls.env)
        }

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

        cls.zipline_test_config = {
            'sid': 0,
        }

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()
        teardown_logger(cls)

    def test_noop(self):
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.data_portal)

    def test_noop_string(self):
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.data_portal)

    def test_api_calls(self):
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api,
                                env=self.env)
        algo.run(self.data_portal)

    def test_api_calls_string(self):
        algo = TradingAlgorithm(script=api_algo, env=self.env)
        algo.run(self.data_portal)

    def test_api_get_environment(self):
        # The ``platform`` passed to TradingAlgorithm is exposed afterwards
        # as ``algo.environment``.
        platform = 'zipline'
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                platform=platform)
        algo.run(self.data_portal)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        algo = TradingAlgorithm(script=api_symbol_algo,
                                env=self.env,
                                sim_params=self.sim_params)
        algo.run(self.data_portal)

    def test_fixed_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        # A single 1000-share sell with FixedSlippage(spread=0.10) and a
        # PerTrade(100.00) commission.
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1""",
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        # flatten the list of txns
        all_txns = [val for sublist in results["transactions"].tolist()
                    for val in sublist]

        self.assertEqual(len(all_txns), 1)
        txn = all_txns[0]

        self.assertEqual(100.0, txn["commission"])
        # The test expects half of the 0.10 spread, plus the 100.00 trade
        # commission spread over 1000 shares, to come off the price.
        expected_spread = 0.05
        expected_commish = 0.10
        expected_price = test_algo.recorded_vars["price"] - expected_spread \
            - expected_commish

        self.assertEqual(expected_price, txn['price'])

    @parameterized.expand(
        [
            ('no_minimum_commission', 0,),
            ('default_minimum_commission', 1,),
            ('alternate_minimum_commission', 2,),
        ]
    )
    def test_volshare_slippage(self, name, minimum_commission):
        # With VolumeShareSlippage and a PerShare(0.02) commission, the
        # first transaction's commission is either 0.02 per share (when the
        # minimum is 0) or the configured minimum trade cost. The
        # 'default_minimum_commission' case omits min_trade_cost and
        # expects a commission of 1.
        tempdir = TempDirectory()
        try:
            if name == "default_minimum_commission":
                commission_line = "set_commission(commission.PerShare(0.02))"
            else:
                commission_line = \
                    "set_commission(commission.PerShare(0.02, " \
                    "min_trade_cost={0}))".format(minimum_commission)

            # verify order -> transaction -> portfolio position.
            # --------------
            test_algo = TradingAlgorithm(
                script="""
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    {0}

    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """.format(commission_line),
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            trades = factory.create_daily_trade_source(
                [0], self.sim_params, self.env)
            data_portal = create_data_portal_from_trade_history(
                self.env, tempdir, self.sim_params, {0: trades})
            results = test_algo.run(data_portal=data_portal)

            all_txns = [
                val for sublist in results["transactions"].tolist()
                for val in sublist]

            self.assertEqual(len(all_txns), 67)
            first_txn = all_txns[0]

            if minimum_commission == 0:
                commish = first_txn["amount"] * 0.02
                self.assertEqual(commish, first_txn["commission"])
            else:
                self.assertEqual(minimum_commission, first_txn["commission"])
        finally:
            tempdir.cleanup()

    def test_algo_record_vars(self):
        # The recorded 'incr' value on bar i (1-based) is expected to be i.
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, self.NUM_DAYS + 1):
            self.assertEqual(results.iloc[i-1]["incr"], i)

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def test_algo_record_nan(self):
        # A recorded NaN must come back as NaN on every bar.
        test_algo = TradingAlgorithm(
            script=record_float_magic % 'nan',
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, self.NUM_DAYS + 1):
            self.assertTrue(np.isnan(results.iloc[i-1]["data"]))

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        test_algo = TradingAlgorithm(
            script=call_all_order_methods,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            test_algo.run(self.data_portal)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_portfolio_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_account_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

class TestGetDatetime(TestCase):
    """
    get_datetime() called from a script algorithm returns a timestamp in
    the requested zone (UTC when none is given) that corresponds to
    9:31 US/Eastern on the first minute bar.
    """

    @classmethod
    def setUpClass(cls):
        env = TradingEnvironment()
        cls.env = env
        env.write_data(equities_identifiers=[0, 1])

        setup_logger(cls)

        first_minute = to_utc('2014-01-02 9:31')
        last_minute = to_utc('2014-01-03 9:31')
        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=env,
            start=first_minute,
            end=last_minute,
        )

        cls.tempdir = TempDirectory()
        cls.data_portal = create_data_portal(
            env,
            cls.tempdir,
            cls.sim_params,
            [1]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.tempdir.cleanup()

    @parameterized.expand(
        [
            ('default', None,),
            ('utc', 'UTC',),
            ('us_east', 'US/Eastern',),
        ]
    )
    def test_get_datetime(self, name, tz):
        # ``tz`` is substituted into the script as a Python literal via
        # repr(); the script raises ValueError on any mismatch and clears
        # ``context.first_bar`` after the first bar.
        script_source = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        datetime_algo = TradingAlgorithm(
            script=script_source,
            sim_params=self.sim_params,
            env=self.env,
        )
        datetime_algo.run(self.data_portal)
        # Confirms handle_data actually ran at least once.
        self.assertFalse(datetime_algo.first_bar)

class TestTradingControls(TestCase):
    """
    Tests for order-level trading controls: max position size, restricted
    ("do not order") list, max order size, max order count, long-only and
    asset date bounds. Also checks that controls cannot be registered
    after initialize.
    """

    @classmethod
    def setUpClass(cls):
        cls.sid = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.env.write_data(equities_data={
            cls.sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
        })

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            [cls.sid]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """
        Install ``handle_data`` on ``algo``, run it against the class data
        portal, and check the outcome.

        Parameters
        ----------
        algo : TradingAlgorithm
            Algorithm whose ``order_count`` attribute ``handle_data``
            increments.
        handle_data : callable(algo, data)
            Installed as ``algo._handle_data``.
        expected_order_count : int
            Required value of ``algo.order_count`` after the run.
        expected_exc : exception type or None
            If not None, the run must raise this exception type.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)
        self.assertEqual(algo.order_count, expected_order_count)

    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        # Default for order_count assumes one order per handle_data call.
        self._check_algo(algo, handle_data, order_count, None)

    def check_algo_fails(self, algo, handle_data, order_count):
        self._check_algo(algo,
                         handle_data,
                         order_count,
                         TradingControlViolation)

    def test_set_max_position_size(self):

        # Buy one share four times.  Should be fine.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 1)
            algo.order_count += 1
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 3)
            algo.order_count += 1

        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 3)
            algo.order_count += 1

        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, handle_data, 2)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Set the trading control sid to None, then BUY ALL THE THINGS!. Should
        # fail because setting sid to None makes the control apply to all sids.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, handle_data, 0)

    def test_set_do_not_order_list(self):
        # set the restricted list to be the sid, and fail.
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )

        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 100)
            algo.order_count += 1

        self.check_algo_fails(algo, handle_data, 0)

        # set the restricted list to exclude the sid, and succeed
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )

        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 100)
            algo.order_count += 1

        self.check_algo_succeeds(algo, handle_data)

    def test_set_max_order_size(self):

        # Buy one share.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 1)
            algo.order_count += 1
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed shares.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=3,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed notional.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=40.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
                                        max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Set the trading control sid to None, then BUY ALL THE THINGS!.
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxOrderSizeAlgorithm(max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, handle_data, 0)

    def test_set_max_order_count(self):
        # Uses its own minute-frequency environment for sid 1, so that
        # several bars per day are available.
        tempdir = TempDirectory()
        try:
            env = TradingEnvironment()
            sim_params = factory.create_simulation_parameters(
                num_days=4, env=env, data_frequency="minute")

            env.write_data(equities_data={
                1: {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }
            })

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [1]
            )

            # Five orders on the first bar with a limit of 3: the fourth
            # order is rejected.
            def handle_data(algo, data):
                for _ in range(5):
                    algo.order(algo.sid(1), 1)
                    algo.order_count += 1

            algo = SetMaxOrderCountAlgorithm(3, sim_params=sim_params,
                                             env=env)
            with self.assertRaises(TradingControlViolation):
                algo._handle_data = handle_data
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 3)

            # This time, order 5 times twice in a single day. The last order
            # of the second batch should fail.
            def handle_data2(algo, data):
                if algo.minute_count == 0 or algo.minute_count == 100:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            algo = SetMaxOrderCountAlgorithm(9, sim_params=sim_params,
                                             env=env)
            with self.assertRaises(TradingControlViolation):
                algo._handle_data = handle_data2
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 9)

            # Five orders every 390th bar.
            def handle_data3(algo, data):
                if (algo.minute_count % 390) == 0:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            # Only 5 orders are placed per day, so this should pass even
            # though in total more than 20 orders are placed.
            algo = SetMaxOrderCountAlgorithm(5, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data3
            algo.run(data_portal)
        finally:
            tempdir.cleanup()

    def test_long_only(self):
        # Sell immediately -> fail immediately.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, handle_data, 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        def handle_data(algo, data):
            if (algo.order_count % 2) == 0:
                algo.order(algo.sid(self.sid), 1)
            else:
                algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings.  Should succeed.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -3]
            algo.order(algo.sid(self.sid), amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -4]
            algo.order(algo.sid(self.sid), amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

    def test_register_post_init(self):
        # Every set_* trading-control method must raise
        # RegisterTradingControlPostInit when called from handle_data.

        def initialize(algo):
            algo.initialized = True

        def handle_data(algo, data):
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_position_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_count(1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_long_only()

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def _run_asset_date_bounds_algo(self, start_date, end_date,
                                    expected_exc=None):
        """
        Run SetAssetDateBoundsAlgorithm for sid 0 in a fresh
        TradingEnvironment and temporary directory.

        Parameters
        ----------
        start_date, end_date
            Written as sid 0's 'start_date' / 'end_date' equity metadata.
        expected_exc : exception type or None
            If not None, the run must raise this exception type.

        The temporary directory is always cleaned up, even on failure.
        """
        tempdir = TempDirectory()
        try:
            temp_env = TradingEnvironment()

            data_portal = create_data_portal(
                temp_env,
                tempdir,
                self.sim_params,
                [0]
            )

            metadata = {0: {'start_date': start_date,
                            'end_date': end_date}}

            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )

            expectation = (self.assertRaises(expected_exc) if expected_exc
                           else nullctx())
            with expectation:
                algo.run(data_portal)
        finally:
            tempdir.cleanup()

    def test_asset_date_bounds(self):
        # Run the algorithm with a sid that ends far in the future
        self._run_asset_date_bounds_algo(self.sim_params.period_start,
                                         '2020-01-01')

        # Run the algorithm with a sid that has already ended
        self._run_asset_date_bounds_algo('1989-01-01',
                                         '1990-01-01',
                                         TradingControlViolation)

        # Run the algorithm with a sid that has not started
        self._run_asset_date_bounds_algo('2020-01-01',
                                         '2021-01-01',
                                         TradingControlViolation)

class TestAccountControls(TestCase):
    """Tests for account-level controls installed on an algorithm.

    A single equity (``sidint``) with a short trade history is shared by
    every test in the class; the data portal is built once in
    ``setUpClass`` and torn down in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=4, env=cls.env
        )

        # Register the test equity; its end date is one day past the end of
        # the simulation period.
        cls.env.write_data(equities_data={
            cls.sidint: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end + timedelta(days=1)
            }
        })

        cls.tempdir = TempDirectory()

        # Prices and volumes for the test equity, one entry per data point.
        trades_by_sid = {
            cls.sidint: factory.create_trade_history(
                cls.sidint,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env,
            )
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid,
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_exc):
        """Install ``handle_data`` on ``algo`` and run it on the shared data.

        When ``expected_exc`` is truthy the run must raise that exception
        type; otherwise the run happens inside ``nullctx()`` (presumably a
        no-op context manager), so any exception propagates and errors
        the test.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)

    def check_algo_succeeds(self, algo, handle_data):
        """Run ``algo`` and require that it completes without raising."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Run ``algo`` and require that it raises AccountControlViolation."""
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):
        """A leverage cap of 0 rejects a one-share buy; a cap of 1 allows it.
        """
        # On every bar, order a single share of the test equity. The same
        # handler is reused for both runs below.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sidint), 1)

        # Max leverage of 0, so buying one share fails.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Max leverage of 1, so buying one share passes.
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)
1761
1762
1763
class TestClosePosAlgo(TestCase):
    """Tests that a future position is closed on its auto-close date."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-1]
        )

        trades_by_sid = {
            cls.sid: factory.create_trade_history(
                cls.sid,
                [1, 1, 2, 4],
                [1e9, 1e9, 1e9, 1e9],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            ),
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_auto_close_future(self):
        # The future's end date and auto-close date are both the trading
        # day after the last simulated day.
        next_day = self.env.next_trading_day(
            self.sim_params.trading_days[-1])
        self.env.write_data(futures_data={
            self.sid: {
                "start_date": self.sim_params.trading_days[0],
                "end_date": next_day,
                'symbol': 'TEST',
                'asset_type': 'future',
                'auto_close_date': next_day,
            }
        })

        algo = TestAlgorithm(sid=self.sid, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env,
                             sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        # Expected daily pnl and held amount; the flat final day reflects
        # the position being auto-closed.
        self.check_algo_pnl(results, [0, 0, 1, 2])
        self.check_algo_positions(results, [0, 1, 1, 0])

    def check_algo_pnl(self, results, expected_pnl):
        """Assert ``results.pnl`` matches ``expected_pnl`` (6 decimals)."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the amount held on each day matches ``expected_positions``.

        Each entry of ``results.positions`` is a sequence of position dicts;
        only the first dict's ``'amount'`` is compared, and an empty entry
        counts as a flat position of 0. The number of days must match too,
        so missing or extra days fail the assertion instead of passing
        silently or raising IndexError.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of days not equal, actual={0}, expected={1}".format(
                len(results.positions), len(expected_positions)))

        for i, day_positions in enumerate(results.positions):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected_positions[i],
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected_positions[i]))
1845
1846
1847
class TestFutureFlip(TestCase):
    """Tests a futures algorithm whose position flips from long to short."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        # The simulation stops one day before the last date in ``days``.
        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-2]
        )

        trades_by_sid = {
            cls.sid: factory.create_trade_history(
                cls.sid,
                [1, 2, 4],
                [1e9, 1e9, 1e9],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            ),
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_flip_algo(self):
        metadata = {self.sid: {'symbol': 'TEST',
                               'start_date': self.sim_params.trading_days[0],
                               'end_date': self.env.next_trading_day(
                                   self.sim_params.trading_days[-1]),
                               'contract_multiplier': 5}}
        self.env.write_data(futures_data=metadata)

        algo = FutureFlipAlgo(sid=self.sid, amount=1, env=self.env,
                              commission=PerShare(0),
                              order_count=0,  # not applicable but required
                              sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        # Flat, then long one contract, then short one contract.
        self.check_algo_positions(results, [0, 1, -1])

        self.check_algo_pnl(results, [0, 5, -10])

    def check_algo_pnl(self, results, expected_pnl):
        """Assert ``results.pnl`` matches ``expected_pnl`` (6 decimals)."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the amount held on each day matches ``expected_positions``.

        Each entry of ``results.positions`` is a sequence of position dicts;
        only the first dict's ``'amount'`` is compared, and an empty entry
        counts as a flat position of 0. The number of days must match too,
        so missing or extra days fail the assertion instead of passing
        silently or raising IndexError.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of days not equal, actual={0}, expected={1}".format(
                len(results.positions), len(expected_positions)))

        for i, day_positions in enumerate(results.positions):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected_positions[i],
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected_positions[i]))
1922
1923
1924
class TestTradingAlgorithm(TestCase):
1925
    def setUp(self):
        """Create a fresh environment and keep its first four trading days."""
        env = TradingEnvironment()
        self.env = env
        self.days = env.trading_days[:4]
1928
1929
    def test_analyze_called(self):
        """``run`` must return the same object that was passed to analyze."""
        self.perf_ref = None

        # Lifecycle callbacks with no behavior of their own.
        def initialize(context):
            pass

        def handle_data(context, data):
            pass

        def analyze(context, perf):
            # Remember exactly what the algorithm hands to analyze.
            self.perf_ref = perf

        data_portal = FakeDataPortal()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            analyze=analyze,
        )

        results = algo.run(data_portal)
        self.assertIs(results, self.perf_ref)
1948