Completed
Pull Request — master (#858)
by Eddie
10:07 queued 01:13
created

tests.TestGetDatetime.tearDownClass()   A

Complexity

Conditions 1

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %
Metric | Value
------ | ------
cc     | 1
dl     | 0
loc    | 5
rs     | 9.4286
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range
20
from testfixtures import TempDirectory
21
from textwrap import dedent
22
from unittest import TestCase, skip
23
24
import numpy as np
25
import pandas as pd
26
27
from zipline.assets import Equity, Future
28
from zipline.utils.api_support import ZiplineAPI
29
from zipline.utils.control_flow import nullctx
30
from zipline.utils.test_utils import (
31
    setup_logger,
32
    teardown_logger,
33
    FakeDataPortal)
34
import zipline.utils.factory as factory
35
36
from zipline.errors import (
37
    OrderDuringInitialize,
38
    RegisterTradingControlPostInit,
39
    TradingControlViolation,
40
    AccountControlViolation,
41
    SymbolNotFound,
42
    RootSymbolNotFound,
43
    UnsupportedDatetimeFormat,
44
)
45
from zipline.test_algorithms import (
46
    access_account_in_init,
47
    access_portfolio_in_init,
48
    AmbitiousStopLimitAlgorithm,
49
    EmptyPositionsAlgorithm,
50
    InvalidOrderAlgorithm,
51
    RecordAlgorithm,
52
    FutureFlipAlgo,
53
    TestAlgorithm,
54
    TestOrderAlgorithm,
55
    TestOrderPercentAlgorithm,
56
    TestOrderStyleForwardingAlgorithm,
57
    TestOrderValueAlgorithm,
58
    TestRegisterTransformAlgorithm,
59
    TestTargetAlgorithm,
60
    TestTargetPercentAlgorithm,
61
    TestTargetValueAlgorithm,
62
    SetLongOnlyAlgorithm,
63
    SetAssetDateBoundsAlgorithm,
64
    SetMaxPositionSizeAlgorithm,
65
    SetMaxOrderCountAlgorithm,
66
    SetMaxOrderSizeAlgorithm,
67
    SetDoNotOrderListAlgorithm,
68
    SetMaxLeverageAlgorithm,
69
    api_algo,
70
    api_get_environment_algo,
71
    api_symbol_algo,
72
    call_all_order_methods,
73
    call_order_in_init,
74
    handle_data_api,
75
    handle_data_noop,
76
    initialize_api,
77
    initialize_noop,
78
    noop_algo,
79
    record_float_magic,
80
    record_variables,
81
)
82
from zipline.utils.context_tricks import CallbackManager
83
import zipline.utils.events
84
from zipline.utils.test_utils import to_utc
85
86
from zipline.finance.execution import LimitOrder
87
from zipline.finance.trading import SimulationParameters
88
from zipline.utils.api_support import set_algo_instance
89
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
90
from zipline.algorithm import TradingAlgorithm
91
from zipline.finance.trading import TradingEnvironment
92
from zipline.finance.commission import PerShare
93
94
from zipline.utils.test_utils import (
95
    create_data_portal,
96
    create_data_portal_from_trade_history
97
)
98
99
# The test cases in this module appear to share resources, so nose's
# multiprocess plugin must not split their fixtures across worker processes.
_multiprocess_can_split_ = False
103
104
105
class TestRecordAlgorithm(TestCase):
    """Check the series that RecordAlgorithm emits through ``record``."""

    @classmethod
    def setUpClass(cls):
        # A single equity (sid 133) simulated over four days.
        cls.env = TradingEnvironment()
        cls.sids = [133]
        cls.env.write_data(equities_identifiers=cls.sids)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=4,
            env=cls.env,
        )

        cls.tempdir = TempDirectory()
        cls.data_portal = create_data_portal(
            cls.env, cls.tempdir, cls.sim_params, cls.sids,
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def test_record_incr(self):
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        output = algo.run(self.data_portal)

        row_count = len(output)
        counting_up = range(1, row_count + 1)

        # (column name, expected values), checked in this order.
        expectations = [
            ('incr', counting_up),
            ('name', counting_up),
            ('name2', [2] * row_count),
            ('name3', counting_up),
        ]
        for column, expected in expectations:
            np.testing.assert_array_equal(output[column].values, expected)
144
145
146
class TestMiscellaneousAPI(TestCase):
    """Exercise assorted TradingAlgorithm API methods.

    Covered areas are asset and future lookups, open orders, scheduling,
    event contexts and environment queries. All tests share a two-day,
    minute-frequency simulation built once in setUpClass.
    """

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        # Two equities share the symbol 'PLAY' over non-overlapping date
        # ranges; test_asset_lookup depends on this layout.
        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        # Four 'CL' futures contracts used by test_future_symbol and
        # test_future_chain.
        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='daily',
            env=cls.env,
        )

        cls.temp_dir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.temp_dir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.temp_dir.cleanup()

    def test_zipline_api_resolves_dynamically(self):
        # Make a dummy algo. Pass the shared env like every other test so
        # the algorithm does not build a TradingEnvironment of its own.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Verify that api methods get resolved dynamically by patching them
        # out on the instance and then calling them through zipline.api.
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            def fake_method(*args, **kwargs):
                return sentinel

            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_cannot_iterate_over_data(self):
        def initialize(algo):
            pass

        def handle_data(algo, data):
            for asset in data:
                pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)

        with self.assertRaises(ValueError):
            algo.run(self.data_portal)

    def test_get_sid(self):
        algo_text = """
from zipline.api import sid

def initialize(context):
    pass

def handle_data(context, data):
    aapl = data[sid(3)]
    assert aapl.sid == 3
"""

        algo = TradingAlgorithm(script=algo_text,
                                sim_params=self.sim_params,
                                env=self.env)

        algo.run(self.data_portal)

    def test_sid_datetime(self):
        algo_text = """
from zipline.api import sid, get_datetime

def initialize(context):
    pass

def handle_data(context, data):
    aapl = data[sid(1)]
    assert aapl.datetime == get_datetime()
"""

        algo = TradingAlgorithm(script=algo_text,
                                sim_params=self.sim_params,
                                env=self.env)

        algo.run(self.data_portal)

    def test_get_environment(self):
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_get_open_orders(self):
        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The market order for sid 1 should have filled; the three
                # limit orders for sid 2 should all still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_schedule_function(self):
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

        # The scheduled function must have fired exactly once per day.
        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.data_portal)

        # 780 bars: two simulated days (see setUpClass) of 390 minutes each.
        bar_count = 780
        self.assertEqual(len(expected_data), bar_count)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        # Each bar calls pre, handle_data, f, g and post: five calls per bar.
        expected_call_count = bar_count * 5
        self.assertEqual(
            len(function_stack),
            expected_call_count,
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), expected_call_count),
        )
        expected_functions = [pre, handle_data, f, g, post] * bar_count
        # Use distinct loop names so the local functions f and g are not
        # rebound by the loop.
        for n, (actual, expected) in enumerate(
                zip(function_stack, expected_functions)):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_schedule_function_rule_creation(self, mode):
        def nop(*args, **kwargs):
            return None

        # self.sim_params is shared by every test in this class. Restore its
        # original frequency afterwards so this test does not leak `mode`
        # into the others.
        original_frequency = self.sim_params.data_frequency
        self.sim_params.data_frequency = mode
        try:
            algo = TradingAlgorithm(
                initialize=nop,
                handle_data=nop,
                sim_params=self.sim_params,
                env=self.env,
            )

            # Schedule something for NOT Always.
            algo.schedule_function(
                nop, time_rule=zipline.utils.events.Never())

            event_rule = algo.event_manager._events[1].rule

            self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

            inner_rule = event_rule.rule
            self.assertIsInstance(inner_rule,
                                  zipline.utils.events.ComposedRule)

            first = inner_rule.first
            second = inner_rule.second
            composer = inner_rule.composer

            self.assertIsInstance(first, zipline.utils.events.Always)

            # In daily mode the time rule is expected to be replaced by
            # Always; in minute mode the supplied Never rule is kept.
            if mode == 'daily':
                self.assertIsInstance(second, zipline.utils.events.Always)
            else:
                self.assertIsInstance(second, zipline.utils.events.Never)

            self.assertIs(composer,
                          zipline.utils.events.ComposedRule.lazy_and)
        finally:
            self.sim_params.data_frequency = original_frequency

    def test_asset_lookup(self):

        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol() should result in a
        # TypeError.
        for bad_symbol in (1, (1,), {1}, [1], {'foo': 'bar'}):
            with self.assertRaises(TypeError):
                algo.symbol(bad_symbol)

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        # Empty, equity-only and unknown symbols are all rejected.
        for missing in ('', 'PLAY', 'FOOBAR'):
            with self.assertRaises(SymbolNotFound):
                algo.future_symbol(missing)

        # Supplying a non-string argument to future_symbol() should result
        # in a TypeError.
        for bad_symbol in (1, (1,), {1}, [1], {'foo': 'bar'}):
            with self.assertRaises(TypeError):
                algo.future_symbol(bad_symbol)

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # A string as_of_date is parsed to the same timestamp.
        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain() should result
        # in a TypeError.
        for bad_root in (1, (1,), {1}, [1], {'foo': 'bar'}):
            with self.assertRaises(TypeError):
                algo.future_chain(bad_root)

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Sids start at i + 3. The assets are written into a fresh
        # TradingEnvironment below, not the class-level one from setUpClass.
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        # Create two assets with the same symbol but different
        # non-overlapping date ranges.
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the end dates of both assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end date is used as the
        # as_of_date. Both assets have ended by then, so only the symbol is
        # checked here.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
690
691
692
class TestTransformAlgorithm(TestCase):
    """Run the order-method test algorithms over four days of synthetic
    daily trades.

    The equities are sids 0, 1 and 133, plus one future (sid 3).
    """

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)
        cls.sids = [0, 1, 133]
        cls.tempdir = TempDirectory()

        futures_metadata = {3: {'contract_multiplier': 10}}
        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

        cls.env.write_data(equities_data=equities_metadata,
                           futures_data=futures_metadata)

        # One trade per day for each equity: prices and volumes per day.
        trades_by_sid = {}
        for sid in cls.sids:
            trades_by_sid[sid] = factory.create_trade_history(
                sid,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        del cls.env
        cls.tempdir.cleanup()

    def test_invalid_order_parameters(self):
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

    def test_run_twice(self):
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.data_portal)

        # Create a new trading algorithm, which will
        # use the newly instantiated environment.
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.data_portal)

        # FIXME I think we are getting Nans due to fixed benchmark,
        # so dropping them for now.
        res1 = res1.fillna(method='ffill')
        res2 = res2.fillna(method='ffill')

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        # Each case builds its own SimulationParameters. The class-level
        # self.sim_params is shared with other tests and must not be touched.
        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='daily')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'daily')

        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='minute')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'minute')

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 0 is an Equity
        asset_to_test = algo.sid(0)
        self.assertIsInstance(asset_to_test, Equity)

        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 3 is a Future
        asset_to_test = algo.sid(3)
        self.assertIsInstance(asset_to_test, Future)

        algo.run(self.data_portal)

    @parameterized.expand([
        ("order",),
        ("order_value",),
        ("order_percent",),
        ("order_target",),
        ("order_target_percent",),
        ("order_target_value",),
    ])
    def test_order_method_style_forwarding(self, order_style):
        algo = TestOrderStyleForwardingAlgorithm(
            sim_params=self.sim_params,
            method_name=order_style,
            env=self.env
        )
        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,)
    ])
    def test_minute_data(self, algo_class):
        tempdir = TempDirectory()

        try:
            # Self-contained minute-frequency environment, separate from the
            # class-level daily fixtures.
            env = TradingEnvironment()

            sim_params = SimulationParameters(
                period_start=pd.Timestamp('2002-1-2', tz='UTC'),
                period_end=pd.Timestamp('2002-1-4', tz='UTC'),
                capital_base=float("1.0e5"),
                data_frequency='minute',
                env=env
            )

            equities_metadata = {}

            for sid in [0, 1]:
                equities_metadata[sid] = {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }

            env.write_data(equities_data=equities_metadata)

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [0, 1]
            )

            algo = algo_class(sim_params=sim_params, env=env)
            algo.run(data_portal)
        finally:
            tempdir.cleanup()
884
885
886
class TestPositions(TestCase):
    """Position bookkeeping tests over a shared four-day simulation.

    Sids 1 and 133 are registered for exactly the simulation window and a
    single data portal is shared by every test in the class.
    """

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.sids = [1, 133]
        cls.tempdir = TempDirectory()

        # Every sid's lifetime matches the simulation window.
        equities_metadata = {
            sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
            for sid in cls.sids
        }

        cls.env.write_data(equities_data=equities_metadata)

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        # Drop the shared environment as the other fixtures in this module do.
        del cls.env
        teardown_logger(cls)
        cls.tempdir.cleanup()

    def test_empty_portfolio(self):
        """num_positions goes 0 -> 2 -> 0 as positions are entered and exited."""
        algo = EmptyPositionsAlgorithm(self.sids,
                                       sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.data_portal)

        expected_position_count = [
            0,  # Before entering the first position
            2,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        # Positional access: ``.ix`` is deprecated (removed in pandas 1.0);
        # ``.iloc`` is the explicit positional indexer.
        num_positions = daily_stats['num_positions']
        for i, expected in enumerate(expected_position_count):
            self.assertEqual(num_positions.iloc[i], expected)

    def test_noop_orders(self):
        """Stop/limit orders that never trigger leave every day flat."""
        algo = AmbitiousStopLimitAlgorithm(sid=1,
                                           sim_params=self.sim_params,
                                           env=self.env)
        daily_stats = algo.run(self.data_portal)

        # Verify that positions are empty for all dates.
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
        self.assertTrue(empty_positions.all())
945
946
947
class TestAlgoScript(TestCase):
    """Run TradingAlgorithm through callables and script strings.

    The shared fixture registers sids 0, 1, 3 and 133 (sid 3 under the
    symbol ``"TEST"``) and provides trade history for sids 0 and 3 over a
    ``num_days``-day simulation.
    """

    # Length of the shared simulation.  The recorded-variable tests index one
    # result row per simulated day, so they derive their ranges from this.
    num_days = 251

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=cls.num_days,
            env=cls.env,
        )

        cls.sids = [0, 1, 3, 133]
        cls.tempdir = TempDirectory()

        equities_metadata = {}

        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }

            if sid == 3:
                equities_metadata[sid]["symbol"] = "TEST"
                equities_metadata[sid]["asset_type"] = "equity"

        cls.env.write_data(equities_data=equities_metadata)

        days = cls.num_days

        # Sids 0 and 3 get identical constant series ([10.0] and [100] per
        # day); the other registered sids have no trade history.
        trades_by_sid = {
            sid: factory.create_trade_history(
                sid,
                [10.0] * days,
                [100] * days,
                timedelta(days=1),
                cls.sim_params,
                cls.env)
            for sid in (0, 3)
        }

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

        cls.zipline_test_config = {
            'sid': 0,
        }

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()
        teardown_logger(cls)

    def test_noop(self):
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.data_portal)

    def test_noop_string(self):
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.data_portal)

    def test_api_calls(self):
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api,
                                env=self.env)
        algo.run(self.data_portal)

    def test_api_calls_string(self):
        algo = TradingAlgorithm(script=api_algo, env=self.env)
        algo.run(self.data_portal)

    def test_api_get_environment(self):
        platform = 'zipline'
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                platform=platform)
        algo.run(self.data_portal)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        algo = TradingAlgorithm(script=api_symbol_algo,
                                env=self.env,
                                sim_params=self.sim_params)
        algo.run(self.data_portal)

    def test_fixed_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1""",
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        # flatten the list of txns
        all_txns = [val for sublist in results["transactions"].tolist()
                    for val in sublist]

        self.assertEqual(len(all_txns), 1)
        txn = all_txns[0]

        self.assertEqual(100.0, txn["commission"])
        # Half of the 0.10 FixedSlippage spread is charged on the sell.
        expected_spread = 0.05
        # PerTrade(100.00) divided over the 1000-share order; this assumes
        # the transaction price absorbs the per-share commission.
        expected_commish = 0.10
        expected_price = test_algo.recorded_vars["price"] - expected_spread \
            - expected_commish

        self.assertEqual(expected_price, txn['price'])

    @parameterized.expand(
        [
            ('no_minimum_commission', 0,),
            ('default_minimum_commission', 1,),
            ('alternate_minimum_commission', 2,),
        ]
    )
    def test_volshare_slippage(self, name, minimum_commission):
        """Check VolumeShareSlippage fills and the PerShare minimum cost.

        ``minimum_commission`` is the commission expected on the first
        transaction.  When it is 0, the expected value is the per-share rate
        times the amount instead.
        """
        with TempDirectory() as tempdir:
            if name == "default_minimum_commission":
                # No explicit min_trade_cost: the expected value of 1 relies
                # on PerShare's default minimum.
                commission_line = "set_commission(commission.PerShare(0.02))"
            else:
                commission_line = \
                    "set_commission(commission.PerShare(0.02, " \
                    "min_trade_cost={0}))".format(minimum_commission)

            # verify order -> transaction -> portfolio position.
            # --------------
            test_algo = TradingAlgorithm(
                script="""
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    {0}

    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """.format(commission_line),
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            trades = factory.create_daily_trade_source(
                [0], self.sim_params, self.env)
            data_portal = create_data_portal_from_trade_history(
                self.env, tempdir, self.sim_params, {0: trades})
            results = test_algo.run(data_portal=data_portal)

            all_txns = [
                val for sublist in results["transactions"].tolist()
                for val in sublist]

            self.assertEqual(len(all_txns), 67)
            first_txn = all_txns[0]

            if minimum_commission == 0:
                commish = first_txn["amount"] * 0.02
                self.assertEqual(commish, first_txn["commission"])
            else:
                self.assertEqual(minimum_commission, first_txn["commission"])

    def test_algo_record_vars(self):
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        # One row per simulated day; row i-1 records incr == i.
        for i in range(1, self.num_days + 1):
            self.assertEqual(results.iloc[i-1]["incr"], i)

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def test_algo_record_nan(self):
        test_algo = TradingAlgorithm(
            script=record_float_magic % 'nan',
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, self.num_days + 1):
            self.assertTrue(np.isnan(results.iloc[i-1]["data"]))

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        test_algo = TradingAlgorithm(
            script=call_all_order_methods,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            test_algo.run(self.data_portal)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_portfolio_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_account_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)
1242
1243
1244
class TestGetDatetime(TestCase):
    """Check ``get_datetime`` with no timezone, 'UTC' and 'US/Eastern'.

    The generated script verifies the zone of the returned timestamp and
    that the first bar falls at 9:31 US/Eastern.
    """

    @classmethod
    def setUpClass(cls):
        env = cls.env = TradingEnvironment()
        env.write_data(equities_identifiers=[0, 1])

        setup_logger(cls)

        window_start = to_utc('2014-01-02 9:31')
        window_end = to_utc('2014-01-03 9:31')
        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=env,
            start=window_start,
            end=window_end,
        )

        cls.tempdir = TempDirectory()
        cls.data_portal = create_data_portal(
            env, cls.tempdir, cls.sim_params, [1],
        )

    @classmethod
    def tearDownClass(cls):
        # Release the environment, the logger and the on-disk portal data.
        del cls.env
        teardown_logger(cls)
        cls.tempdir.cleanup()

    @parameterized.expand([
        ('default', None),
        ('utc', 'UTC'),
        ('us_east', 'US/Eastern'),
    ])
    def test_get_datetime(self, name, tz):
        """Run a script that raises ValueError on the wrong zone or time.

        ``tz`` is spliced into the script as a Python literal; ``None``
        means the script expects the default zone, 'UTC'.
        """
        tz_literal = repr(tz)
        script = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=tz_literal)
        )

        trading_algo = TradingAlgorithm(
            script=script,
            sim_params=self.sim_params,
            env=self.env,
        )
        trading_algo.run(self.data_portal)
        # handle_data clears first_bar, so this proves at least one bar ran.
        self.assertFalse(trading_algo.first_bar)
1312
1313
1314
class TestTradingControls(TestCase):
    """Tests for registered trading controls.

    Covers max position/order size, max order count, the do-not-order list,
    long-only, and asset date bounds.  Most tests share a four-day
    simulation over ``cls.sid``.
    """

    @classmethod
    def setUpClass(cls):
        cls.sid = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.env.write_data(equities_data={
            cls.sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
        })

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            [cls.sid]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` and check ``algo.order_count``.

        ``expected_exc`` of ``None`` means the run must not raise; otherwise
        the run must raise that exception type.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)
        self.assertEqual(algo.order_count, expected_order_count)

    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        # Default for order_count assumes one order per handle_data call.
        self._check_algo(algo, handle_data, order_count, None)

    def check_algo_fails(self, algo, handle_data, order_count):
        self._check_algo(algo,
                         handle_data,
                         order_count,
                         TradingControlViolation)

    def _order_each_bar(self, amount):
        """Return a handle_data that orders ``amount`` of ``self.sid``.

        The callable bumps ``algo.order_count`` only after ``order`` returns,
        so the count reflects orders that were not rejected.
        """
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), amount)
            algo.order_count += 1
        return handle_data

    def test_set_max_position_size(self):

        # Buy one share four times.  Should be fine.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, self._order_each_bar(1))

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(3), 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(3), 2)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, self._order_each_bar(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!. Should
        # fail because setting sid to None makes the control apply to all sids.
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(10000), 0)

    def test_set_do_not_order_list(self):
        # set the restricted list to be the sid, and fail.
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_fails(algo, self._order_each_bar(100), 0)

        # set the restricted list to exclude the sid, and succeed
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_succeeds(algo, self._order_each_bar(100))

    def test_set_max_order_size(self):

        # Buy one share.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, self._order_each_bar(1))

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed shares.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=3,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed notional.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=40.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
                                        max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, self._order_each_bar(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!.
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        algo = SetMaxOrderSizeAlgorithm(max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(10000), 0)

    def test_set_max_order_count(self):
        with TempDirectory() as tempdir:
            env = TradingEnvironment()
            sim_params = factory.create_simulation_parameters(
                num_days=4, env=env, data_frequency="minute")

            env.write_data(equities_data={
                1: {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }
            })

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [1]
            )

            # Five orders on the first bar: the fourth breaches the limit of 3.
            def handle_data(algo, data):
                for _ in range(5):
                    algo.order(algo.sid(1), 1)
                    algo.order_count += 1

            algo = SetMaxOrderCountAlgorithm(3, sim_params=sim_params,
                                             env=env)
            with self.assertRaises(TradingControlViolation):
                algo._handle_data = handle_data
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 3)

            # This time, order 5 times twice in a single day. The last order
            # of the second batch should fail.
            def handle_data2(algo, data):
                if algo.minute_count == 0 or algo.minute_count == 100:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            algo = SetMaxOrderCountAlgorithm(9, sim_params=sim_params,
                                             env=env)
            with self.assertRaises(TradingControlViolation):
                algo._handle_data = handle_data2
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 9)

            def handle_data3(algo, data):
                if (algo.minute_count % 390) == 0:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            # Only 5 orders are placed per day, so this should pass even
            # though the total across the whole run exceeds the limit of 5.
            algo = SetMaxOrderCountAlgorithm(5, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data3
            algo.run(data_portal)

    def test_long_only(self):
        # Sell immediately -> fail immediately.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, self._order_each_bar(-1), 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        def handle_data(algo, data):
            if (algo.order_count % 2) == 0:
                algo.order(algo.sid(self.sid), 1)
            else:
                algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings.  Should succeed.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -3]
            algo.order(algo.sid(self.sid), amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -4]
            algo.order(algo.sid(self.sid), amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

    def test_register_post_init(self):

        def initialize(algo):
            algo.initialized = True

        def handle_data(algo, data):
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_position_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_count(1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_long_only()

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def _run_asset_date_bounds_algo(self, start_date, end_date,
                                    expected_exc=None):
        """Run SetAssetDateBoundsAlgorithm for sid 0 with the given lifetime.

        A fresh TradingEnvironment and data portal are built for each call;
        the temporary directory is removed when the run finishes or raises.

        Parameters
        ----------
        start_date, end_date
            Lifetime bounds written into the algorithm's equities metadata.
        expected_exc : type, optional
            Exception type the run must raise; ``None`` means it must succeed.
        """
        with TempDirectory() as tempdir:
            temp_env = TradingEnvironment()

            data_portal = create_data_portal(
                temp_env,
                tempdir,
                self.sim_params,
                [0]
            )

            metadata = {0: {'start_date': start_date,
                            'end_date': end_date}}

            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )
            with self.assertRaises(expected_exc) if expected_exc \
                    else nullctx():
                algo.run(data_portal)

    def test_asset_date_bounds(self):
        # Run the algorithm with a sid that ends far in the future
        self._run_asset_date_bounds_algo(self.sim_params.period_start,
                                         '2020-01-01')

        # Run the algorithm with a sid that has already ended
        self._run_asset_date_bounds_algo('1989-01-01',
                                         '1990-01-01',
                                         TradingControlViolation)

        # Run the algorithm with a sid that has not started
        self._run_asset_date_bounds_algo('2020-01-01',
                                         '2021-01-01',
                                         TradingControlViolation)
1718
1719
1720
class TestAccountControls(TestCase):
    """Tests for account-level trading controls (here: max leverage).

    A single equity (``sidint``) is registered for a four-day simulation and
    a data portal is built from a synthetic trade history for it.
    """

    @classmethod
    def setUpClass(cls):
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=4, env=cls.env
        )

        # Register the asset under the same sid the tests order, with an end
        # date one day past the simulation period end.
        cls.env.write_data(equities_data={
            cls.sidint: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end + timedelta(days=1)
            }
        })

        cls.tempdir = TempDirectory()
        try:
            trades_by_sid = {
                cls.sidint: factory.create_trade_history(
                    cls.sidint,
                    [10.0, 10.0, 11.0, 11.0],
                    [100, 100, 100, 300],
                    timedelta(days=1),
                    cls.sim_params,
                    cls.env,
                )
            }

            cls.data_portal = create_data_portal_from_trade_history(
                cls.env,
                cls.tempdir,
                cls.sim_params,
                trades_by_sid,
            )
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so release the temporary directory before propagating.
            cls.tempdir.cleanup()
            raise

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` installed as its per-bar hook.

        If ``expected_exc`` is truthy, the run must raise that exception;
        otherwise the run is executed inside a no-op context manager.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)

    def check_algo_succeeds(self, algo, handle_data):
        """Run ``algo`` and let any exception it raises fail the test."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Run ``algo`` and require it to raise AccountControlViolation."""
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):
        """Buying one share violates max leverage 0 but not max leverage 1."""

        def handle_data(algo, data):
            algo.order(algo.sid(self.sidint), 1)

        # Max leverage 0: buying a single share must trip the control.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Max leverage 1: buying a single share stays within the limit.
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)
1795
1796
1797
class TestClosePosAlgo(TestCase):
    """Tests that a futures position returns to zero around the asset's
    ``auto_close_date`` (the only test is currently skipped)."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()
        try:
            cls.env = TradingEnvironment()
            cls.days = pd.date_range(
                start=pd.Timestamp("2006-01-09", tz='UTC'),
                end=pd.Timestamp("2006-01-12", tz='UTC'),
            )

            cls.sid = 1

            cls.sim_params = factory.create_simulation_parameters(
                start=cls.days[0],
                end=cls.days[-1]
            )

            # One price/volume entry per simulated day for the single sid.
            trades_by_sid = {
                cls.sid: factory.create_trade_history(
                    cls.sid,
                    [1, 1, 2, 4],
                    [1e9, 1e9, 1e9, 1e9],
                    timedelta(days=1),
                    cls.sim_params,
                    cls.env
                )
            }

            cls.data_portal = create_data_portal_from_trade_history(
                cls.env,
                cls.tempdir,
                cls.sim_params,
                trades_by_sid
            )
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so release the temporary directory before propagating.
            cls.tempdir.cleanup()
            raise

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    # A bare ``@skip`` passes the test function itself as the "reason",
    # which misbehaves on older Pythons. Always give an explicit reason.
    # NOTE(review): the original skip reason is not recorded here; confirm
    # and expand this message.
    @skip("temporarily disabled")
    def test_auto_close_future(self):
        """Open one future contract and expect it closed on the last day."""
        last_day = self.sim_params.trading_days[-1]
        self.env.write_data(futures_data={
            1: {
                "start_date": self.sim_params.trading_days[0],
                "end_date": self.env.next_trading_day(last_day),
                'symbol': 'TEST',
                'asset_type': 'future',
                'auto_close_date': self.env.next_trading_day(last_day)
            }
        })

        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env,
                             sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        expected_positions = [0, 1, 1, 0]
        self.check_algo_positions(results, expected_positions)

        expected_pnl = [0, 0, 1, 2]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the daily ``pnl`` column approximately equals the expected values."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the first position's amount on each day matches.

        A day with no positions counts as an amount of 0. The number of
        days in ``results`` must equal ``len(expected_positions)``.
        Otherwise missing days would be silently ignored.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "expected {0} days of positions, got {1}".format(
                len(expected_positions), len(results.positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            actual_position = day_positions[0]['amount'] if day_positions else 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))
1877
1878
1879
class TestFutureFlip(TestCase):
    """Tests positions and PnL of an algorithm that flips a futures position
    (long then short) over a three-day simulation."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()
        try:
            cls.env = TradingEnvironment()
            cls.days = pd.date_range(
                start=pd.Timestamp("2006-01-09", tz='UTC'),
                end=pd.Timestamp("2006-01-12", tz='UTC'),
            )

            cls.sid = 1

            # The simulation stops one day before the last generated date.
            cls.sim_params = factory.create_simulation_parameters(
                start=cls.days[0],
                end=cls.days[-2]
            )

            trades = factory.create_trade_history(
                cls.sid,
                [1, 2, 4],
                [1e9, 1e9, 1e9],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )

            cls.data_portal = create_data_portal_from_trade_history(
                cls.env,
                cls.tempdir,
                cls.sim_params,
                {cls.sid: trades}
            )
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so release the temporary directory before propagating.
            cls.tempdir.cleanup()
            raise

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def test_flip_algo(self):
        """Positions go 0 -> 1 -> -1; PnL reflects a contract multiplier of 5."""
        metadata = {1: {'symbol': 'TEST',
                        'start_date': self.sim_params.trading_days[0],
                        'end_date': self.env.next_trading_day(
                            self.sim_params.trading_days[-1]),
                        'contract_multiplier': 5}}
        self.env.write_data(futures_data=metadata)

        algo = FutureFlipAlgo(sid=1, amount=1, env=self.env,
                              commission=PerShare(0),
                              order_count=0,  # not applicable but required
                              sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        expected_positions = [0, 1, -1]
        self.check_algo_positions(results, expected_positions)

        expected_pnl = [0, 5, -10]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the daily ``pnl`` column approximately equals the expected values."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the first position's amount on each day matches.

        A day with no positions counts as an amount of 0. The number of
        days in ``results`` must equal ``len(expected_positions)``.
        Otherwise missing days would be silently ignored.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "expected {0} days of positions, got {1}".format(
                len(expected_positions), len(results.positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            actual_position = day_positions[0]['amount'] if day_positions else 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))
1954
1955
1956
class TestTradingAlgorithm(TestCase):
1957
    def setUp(self):
        """Create a fresh TradingEnvironment and keep its first four trading days."""
        env = TradingEnvironment()
        self.env = env
        self.days = env.trading_days[:4]
1960
1961
    def test_analyze_called(self):
        """The ``analyze`` hook must receive the same object ``run`` returns."""
        self.perf_ref = None

        def noop_initialize(context):
            pass

        def noop_handle_data(context, data):
            pass

        def capture_perf(context, perf):
            # Remember what analyze was handed so it can be compared with
            # the return value of run() below.
            self.perf_ref = perf

        algo = TradingAlgorithm(
            initialize=noop_initialize,
            handle_data=noop_handle_data,
            analyze=capture_perf,
        )

        results = algo.run(FakeDataPortal())
        self.assertIs(results, self.perf_ref)
1980