Completed
Pull Request — master (#858)
by Eddie
01:41 queued 11s
created

tests.TestGetDatetime.test_get_datetime()   B

Complexity

Conditions 1

Size

Total Lines 37

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 37
rs 8.8571
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range
20
from testfixtures import TempDirectory
21
from textwrap import dedent
22
from unittest import TestCase
23
24
import numpy as np
25
import pandas as pd
26
27
from zipline.assets import Equity, Future
28
from zipline.utils.api_support import ZiplineAPI
29
from zipline.utils.control_flow import nullctx
30
from zipline.utils.test_utils import (
31
    setup_logger,
32
    teardown_logger,
33
    FakeDataPortal)
34
import zipline.utils.factory as factory
35
36
from zipline.errors import (
37
    OrderDuringInitialize,
38
    RegisterTradingControlPostInit,
39
    TradingControlViolation,
40
    AccountControlViolation,
41
    SymbolNotFound,
42
    RootSymbolNotFound,
43
    UnsupportedDatetimeFormat,
44
)
45
from zipline.test_algorithms import (
46
    access_account_in_init,
47
    access_portfolio_in_init,
48
    AmbitiousStopLimitAlgorithm,
49
    EmptyPositionsAlgorithm,
50
    InvalidOrderAlgorithm,
51
    RecordAlgorithm,
52
    FutureFlipAlgo,
53
    TestAlgorithm,
54
    TestOrderAlgorithm,
55
    TestOrderPercentAlgorithm,
56
    TestOrderStyleForwardingAlgorithm,
57
    TestOrderValueAlgorithm,
58
    TestRegisterTransformAlgorithm,
59
    TestTargetAlgorithm,
60
    TestTargetPercentAlgorithm,
61
    TestTargetValueAlgorithm,
62
    SetLongOnlyAlgorithm,
63
    SetAssetDateBoundsAlgorithm,
64
    SetMaxPositionSizeAlgorithm,
65
    SetMaxOrderCountAlgorithm,
66
    SetMaxOrderSizeAlgorithm,
67
    SetDoNotOrderListAlgorithm,
68
    SetMaxLeverageAlgorithm,
69
    api_algo,
70
    api_get_environment_algo,
71
    api_symbol_algo,
72
    call_all_order_methods,
73
    call_order_in_init,
74
    handle_data_api,
75
    handle_data_noop,
76
    initialize_api,
77
    initialize_noop,
78
    noop_algo,
79
    record_float_magic,
80
    record_variables,
81
)
82
from zipline.utils.context_tricks import CallbackManager
83
import zipline.utils.events
84
from zipline.utils.test_utils import to_utc
85
86
from zipline.finance.execution import LimitOrder
87
from zipline.finance.trading import SimulationParameters
88
from zipline.utils.api_support import set_algo_instance
89
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
90
from zipline.algorithm import TradingAlgorithm
91
from zipline.finance.trading import TradingEnvironment
92
from zipline.finance.commission import PerShare
93
94
from zipline.utils.test_utils import (
95
    create_data_portal,
96
    create_data_portal_from_trade_history
97
)
98
99
# nose's multiprocess plugin must not split these test cases across worker
# processes: the cases appear to share some resources with one another.
_multiprocess_can_split_ = False
103
104
105
class TestRecordAlgorithm(TestCase):
    """Check the values that ``RecordAlgorithm`` records on every bar."""

    @classmethod
    def setUpClass(cls):
        # A single equity (sid 133) over a four-day simulation.
        cls.env = TradingEnvironment()
        cls.sids = [133]
        cls.env.write_data(equities_identifiers=cls.sids)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=4,
            env=cls.env
        )

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def test_record_incr(self):
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        output = algo.run(self.data_portal)

        # 'incr', 'name' and 'name3' are expected to count 1, 2, 3, ... (one
        # step per output row), while 'name2' is expected to stay at 2.
        counting = range(1, len(output) + 1)
        np.testing.assert_array_equal(output['incr'].values, counting)
        np.testing.assert_array_equal(output['name'].values, counting)
        np.testing.assert_array_equal(output['name2'].values,
                                      [2] * len(output))
        np.testing.assert_array_equal(output['name3'].values, counting)
144
145
146
class TestMiscellaneousAPI(TestCase):
    """Tests for assorted ``TradingAlgorithm`` API methods.

    The class-level environment contains:

    * equities 1 and 2, which back the minute data portal;
    * equities 3 and 4, which share the symbol 'PLAY' over non-overlapping
      date ranges (used to exercise date-aware symbol lookup);
    * CL futures contracts 5-8.

    ``cls.sim_params`` is shared by every test, so tests must not leave it
    modified.
    """

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        # Two distinct equities with the same symbol; lookups must be
        # disambiguated by date.
        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='daily',
            env=cls.env,
        )

        cls.temp_dir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.temp_dir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.temp_dir.cleanup()

    def test_zipline_api_resolves_dynamically(self):
        # Make a dummy algo.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
        )

        # Verify that api methods get resolved dynamically by patching them
        # out on the instance and then calling them through zipline.api.
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            # fake_method is called within the same iteration, so the
            # closure over the loop-local `sentinel` is safe.
            def fake_method(*args, **kwargs):
                return sentinel
            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_cannot_iterate_over_data(self):
        def initialize(algo):
            pass

        def handle_data(algo, data):
            # Iterating over the data object is expected to raise.
            for asset in data:
                pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)

        with self.assertRaises(ValueError):
            algo.run(self.data_portal)

    def test_get_sid(self):
        algo_text = """
from zipline.api import sid


def initialize(context):
    pass


def handle_data(context, data):
    aapl = data[sid(3)]
    assert aapl.sid == 3
"""

        algo = TradingAlgorithm(script=algo_text,
                                sim_params=self.sim_params,
                                env=self.env)

        algo.run(self.data_portal)

    def test_get_environment(self):
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            # No argument returns the platform; '*' returns every field.
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_get_open_orders(self):
        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The market order for sid 1 should have filled.
                # The three limit orders for sid 2 should still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_schedule_function(self):
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            # Count distinct simulation days so they can be compared with
            # the number of scheduled calls.
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        # Number of handle_data calls expected for this simulation
        # (presumably 2 sessions x 390 minute bars).
        n_bars = 780
        # `data` objects seen by handle_data, and by the pre/post callbacks.
        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.data_portal)

        self.assertEqual(len(expected_data), n_bars)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        # Each bar runs pre, handle_data, f, g, post in that order.
        expected_functions = [pre, handle_data, f, g, post] * n_bars
        self.assertEqual(
            len(function_stack),
            len(expected_functions),
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), len(expected_functions)),
        )
        # Use distinct loop names so the local functions f and g are not
        # rebound.
        for n, (actual, expected) in enumerate(
                zip(function_stack, expected_functions)):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_schedule_function_rule_creation(self, mode):
        def nop(*args, **kwargs):
            return None

        # sim_params is shared with every other test in this class, so put
        # its data_frequency back once this test is finished.
        original_frequency = self.sim_params.data_frequency
        self.sim_params.data_frequency = mode
        try:
            algo = TradingAlgorithm(
                initialize=nop,
                handle_data=nop,
                sim_params=self.sim_params,
                env=self.env,
            )

            # Schedule something for NOT Always.
            algo.schedule_function(nop,
                                   time_rule=zipline.utils.events.Never())

            event_rule = algo.event_manager._events[1].rule

            self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

            inner_rule = event_rule.rule
            self.assertIsInstance(inner_rule,
                                  zipline.utils.events.ComposedRule)

            first = inner_rule.first
            second = inner_rule.second
            composer = inner_rule.composer

            self.assertIsInstance(first, zipline.utils.events.Always)

            # In daily mode the time rule is replaced by Always.
            if mode == 'daily':
                self.assertIsInstance(second, zipline.utils.events.Always)
            else:
                self.assertIsInstance(second, zipline.utils.events.Never)

            self.assertIs(composer,
                          zipline.utils.events.ComposedRule.lazy_and)
        finally:
            self.sim_params.data_frequency = original_frequency

    def test_asset_lookup(self):

        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.symbol(1)

        with self.assertRaises(TypeError):
            algo.symbol((1,))

        with self.assertRaises(TypeError):
            algo.symbol({1})

        with self.assertRaises(TypeError):
            algo.symbol([1])

        with self.assertRaises(TypeError):
            algo.symbol({'foo': 'bar'})

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('')

        # 'PLAY' exists, but only as an equity symbol.
        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('PLAY')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('FOOBAR')

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_symbol(1)

        with self.assertRaises(TypeError):
            algo.future_symbol((1,))

        with self.assertRaises(TypeError):
            algo.future_symbol({1})

        with self.assertRaises(TypeError):
            algo.future_symbol([1])

        with self.assertRaises(TypeError):
            algo.future_symbol({'foo': 'bar'})

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # A date string is accepted in place of a Timestamp.
        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_chain(1)

        with self.assertRaises(TypeError):
            algo.future_chain((1,))

        with self.assertRaises(TypeError):
            algo.future_chain({1})

        with self.assertRaises(TypeError):
            algo.future_chain([1])

        with self.assertRaises(TypeError):
            algo.future_chain({'foo': 'bar'})

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Sids are numbered i + 3 to keep clear of sids 1 and 2 used by
        # setUpClass (this test builds its own, separate environment).
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        # Create two assets with the same symbol but different
        # non-overlapping date ranges.
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the period end
        # dates for our assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end date is used as the
        # as_of_date. Only the symbol is checked here, not which sid wins.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
672
673
674
class TestTransformAlgorithm(TestCase):
    """Run the order-placing test algorithms against canned trade history.

    Equities 0, 1 and 133 trade for the whole four-day simulation. Sid 3 is
    registered as a future with a contract multiplier of 10.
    ``cls.sim_params`` is shared by every test and must not be modified.
    """

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)
        cls.sids = [0, 1, 133]
        cls.tempdir = TempDirectory()

        futures_metadata = {3: {'contract_multiplier': 10}}
        equities_metadata = {
            sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
            for sid in cls.sids
        }

        cls.env.write_data(equities_data=equities_metadata,
                           futures_data=futures_metadata)

        # Same price/volume history for every equity, one bar per day.
        trades_by_sid = {
            sid: factory.create_trade_history(
                sid,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )
            for sid in cls.sids
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

        cls.data_portal.current_day = cls.sim_params.trading_days[0]

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        del cls.env
        cls.tempdir.cleanup()

    def test_invalid_order_parameters(self):
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)

    def test_run_twice(self):
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.data_portal)

        # Create a second, independent algorithm and run it over the same
        # data. No env is passed here, so presumably each algorithm builds
        # its own environment — TODO confirm.
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.data_portal)

        # FIXME I think we are getting Nans due to fixed benchmark,
        # so forward-filling them for now.
        res1 = res1.fillna(method='ffill')
        res2 = res2.fillna(method='ffill')

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        # Each algorithm gets its own freshly built sim_params, so the
        # shared class-level sim_params is left untouched.
        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='daily')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'daily')

        sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env, data_frequency='minute')

        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'minute')

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 0 is an Equity
        asset_to_test = algo.sid(0)
        self.assertIsInstance(asset_to_test, Equity)

        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        # Ensure that the environment's asset 3 is a Future
        asset_to_test = algo.sid(3)
        self.assertIsInstance(asset_to_test, Future)

        algo.run(self.data_portal)

    @parameterized.expand([
        ("order",),
        ("order_value",),
        ("order_percent",),
        ("order_target",),
        ("order_target_percent",),
        ("order_target_value",),
    ])
    def test_order_method_style_forwarding(self, order_style):
        algo = TestOrderStyleForwardingAlgorithm(
            sim_params=self.sim_params,
            method_name=order_style,
            env=self.env
        )
        algo.run(self.data_portal)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,)
    ])
    def test_minute_data(self, algo_class):
        # Uses its own environment, sim_params and data portal so that it
        # can run at minute frequency without touching class state.
        tempdir = TempDirectory()

        try:
            env = TradingEnvironment()

            sim_params = SimulationParameters(
                period_start=pd.Timestamp('2002-1-2', tz='UTC'),
                period_end=pd.Timestamp('2002-1-4', tz='UTC'),
                capital_base=float("1.0e5"),
                data_frequency='minute',
                env=env
            )

            # The assets' end dates run one day past the simulation end.
            equities_metadata = {
                sid: {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }
                for sid in [0, 1]
            }

            env.write_data(equities_data=equities_metadata)

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [0, 1]
            )

            algo = algo_class(sim_params=sim_params, env=env)
            algo.run(data_portal)
        finally:
            tempdir.cleanup()
868
869
870
class TestPositions(TestCase):
    """Check the positions reported in the daily stats of algorithms run
    over a four-day daily simulation of sids 1 and 133."""

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        cls.sids = [1, 133]
        cls.tempdir = TempDirectory()

        # Every sid is listed for exactly the simulation window.
        cls.env.write_data(equities_data={
            sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
            for sid in cls.sids
        })

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            cls.sids
        )

    @classmethod
    def tearDownClass(cls):
        teardown_logger(cls)
        cls.tempdir.cleanup()

    def test_empty_portfolio(self):
        """The position count should be 0, then 2, then back to 0."""
        algo = EmptyPositionsAlgorithm(self.sids,
                                       sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.data_portal)

        expected_position_count = [
            0,  # Before entering the first position
            2,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        for i, expected in enumerate(expected_position_count):
            # Positional row lookup. ``.ix`` is deprecated (removed in
            # pandas 1.0); assumes daily_stats has a non-integer index so the
            # former ``.ix[i]`` was positional as well -- TODO confirm.
            self.assertEqual(daily_stats.iloc[i]['num_positions'],
                             expected)

    def test_noop_orders(self):
        """Orders placed by AmbitiousStopLimitAlgorithm must never leave a
        position open on any day."""
        algo = AmbitiousStopLimitAlgorithm(sid=1,
                                           sim_params=self.sim_params,
                                           env=self.env)
        daily_stats = algo.run(self.data_portal)

        # Verify that positions are empty for all dates.
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
        self.assertTrue(empty_positions.all())
929
930
931
class TestAlgoScript(TestCase):
    """Run TradingAlgorithm instances built from Python callables or from
    script strings against a shared daily data portal."""

    # Length of the shared simulation in trading days. The trade histories
    # and the per-bar ``record`` checks below are all sized from this value.
    days = 251

    @classmethod
    def setUpClass(cls):
        setup_logger(cls)
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=cls.days,
            env=cls.env,
        )

        cls.sids = [0, 1, 3, 133]
        cls.tempdir = TempDirectory()

        equities_metadata = {}
        for sid in cls.sids:
            equities_metadata[sid] = {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
            # sid 3 is the only asset registered with a symbol ("TEST").
            if sid == 3:
                equities_metadata[sid]["symbol"] = "TEST"
                equities_metadata[sid]["asset_type"] = "equity"

        cls.env.write_data(equities_data=equities_metadata)

        # Constant price (10.0) and volume (100) every day for sids 0 and 3;
        # sids 1 and 133 get metadata but no trades.
        trades_by_sid = {
            sid: factory.create_trade_history(
                sid,
                [10.0] * cls.days,
                [100] * cls.days,
                timedelta(days=1),
                cls.sim_params,
                cls.env)
            for sid in (0, 3)
        }

        cls.data_portal = create_data_portal_from_trade_history(cls.env,
                                                                cls.tempdir,
                                                                cls.sim_params,
                                                                trades_by_sid)

        cls.zipline_test_config = {
            'sid': 0,
        }

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()
        teardown_logger(cls)

    def test_noop(self):
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.data_portal)

    def test_noop_string(self):
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.data_portal)

    def test_api_calls(self):
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api,
                                env=self.env)
        algo.run(self.data_portal)

    def test_api_calls_string(self):
        algo = TradingAlgorithm(script=api_algo, env=self.env)
        algo.run(self.data_portal)

    def test_api_get_environment(self):
        platform = 'zipline'
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                platform=platform)
        algo.run(self.data_portal)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        algo = TradingAlgorithm(script=api_symbol_algo,
                                env=self.env,
                                sim_params=self.sim_params)
        algo.run(self.data_portal)

    def test_fixed_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1""",
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        # flatten the list of txns
        all_txns = [val for sublist in results["transactions"].tolist()
                    for val in sublist]

        self.assertEqual(len(all_txns), 1)
        txn = all_txns[0]

        self.assertEqual(100.0, txn["commission"])
        expected_spread = 0.05
        expected_commish = 0.10
        expected_price = (test_algo.recorded_vars["price"] -
                          expected_spread - expected_commish)

        self.assertEqual(expected_price, txn['price'])

    @parameterized.expand(
        [
            ('no_minimum_commission', 0,),
            ('default_minimum_commission', 1,),
            ('alternate_minimum_commission', 2,),
        ]
    )
    def test_volshare_slippage(self, name, minimum_commission):
        """Fill orders through VolumeShareSlippage with a 0.02 per-share
        commission and check the fill count and the first commission.

        The 'default_minimum_commission' case omits ``min_trade_cost`` and
        expects a commission of 1, i.e. it relies on PerShare's default.
        """
        tempdir = TempDirectory()
        try:
            if name == "default_minimum_commission":
                commission_line = "set_commission(commission.PerShare(0.02))"
            else:
                commission_line = \
                    "set_commission(commission.PerShare(0.02, " \
                    "min_trade_cost={0}))".format(minimum_commission)

            # verify order -> transaction -> portfolio position.
            # --------------
            test_algo = TradingAlgorithm(
                script="""
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    {0}

    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """.format(commission_line),
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            trades = factory.create_daily_trade_source(
                [0], self.sim_params, self.env)
            data_portal = create_data_portal_from_trade_history(
                self.env, tempdir, self.sim_params, {0: trades})
            results = test_algo.run(data_portal=data_portal)

            all_txns = [
                val for sublist in results["transactions"].tolist()
                for val in sublist]

            self.assertEqual(len(all_txns), 67)
            first_txn = all_txns[0]

            if minimum_commission == 0:
                commish = first_txn["amount"] * 0.02
                self.assertEqual(commish, first_txn["commission"])
            else:
                self.assertEqual(minimum_commission, first_txn["commission"])
        finally:
            tempdir.cleanup()

    def test_algo_record_vars(self):
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        # One recorded row per simulated day; ``incr`` counts from 1.
        for i in range(1, self.days + 1):
            self.assertEqual(results.iloc[i-1]["incr"], i)

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def test_algo_record_nan(self):
        test_algo = TradingAlgorithm(
            script=record_float_magic % 'nan',
            sim_params=self.sim_params,
            env=self.env,
        )
        results = test_algo.run(self.data_portal)

        for i in range(1, self.days + 1):
            self.assertTrue(np.isnan(results.iloc[i-1]["data"]))

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        test_algo = TradingAlgorithm(
            script=call_all_order_methods,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            test_algo.run(self.data_portal)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_portfolio_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        test_algo = TradingAlgorithm(
            script=access_account_in_init,
            sim_params=self.sim_params,
            env=self.env,
        )
        test_algo.run(self.data_portal)
1226
1227
1228
class TestGetDatetime(TestCase):
    """Check get_datetime()'s optional timezone argument during a
    minute-frequency simulation that starts at 2014-01-02 9:31."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1])

        setup_logger(cls)

        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=cls.env,
            start=to_utc('2014-01-02 9:31'),
            end=to_utc('2014-01-03 9:31')
        )

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            [1]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        teardown_logger(cls)
        cls.tempdir.cleanup()

    @parameterized.expand(
        [
            ('default', None,),
            ('utc', 'UTC',),
            ('us_east', 'US/Eastern',),
        ]
    )
    def test_get_datetime(self, name, tz):
        """On the first bar, get_datetime(tz) must return a timestamp in the
        requested zone (UTC when ``tz`` is None) that reads 9:31 in
        US/Eastern. ``name`` only labels the parameterized case.
        """
        # ``repr(tz)`` renders None as the literal ``None``, so the script's
        # ``{tz} or 'UTC'`` falls back to UTC. The literal is dedented, so
        # every script line must share the same leading indent.
        algo_script = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        algo = TradingAlgorithm(
            script=algo_script,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.data_portal)
        # handle_data clears first_bar, so this proves at least one bar ran
        # and the script's checks were actually executed.
        self.assertFalse(algo.first_bar)
1296
1297
1298
class TestTradingControls(TestCase):
    """Exercise the algorithm trading controls (max position size, max order
    size, max order count, do-not-order list, long-only, asset date bounds)
    over a four-day daily simulation of a single sid."""

    @classmethod
    def setUpClass(cls):
        cls.sid = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(num_days=4,
                                                              env=cls.env)

        # Register the sid the tests trade for exactly the simulation window.
        cls.env.write_data(equities_data={
            cls.sid: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end
            }
        })

        cls.tempdir = TempDirectory()

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            [cls.sid]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """Install ``handle_data`` on ``algo`` and run it. If ``expected_exc``
        is truthy the run must raise it. Either way, ``algo.order_count``
        must end at ``expected_order_count``."""
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)
        self.assertEqual(algo.order_count, expected_order_count)

    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        # Default for order_count assumes one order per handle_data call.
        self._check_algo(algo, handle_data, order_count, None)

    def check_algo_fails(self, algo, handle_data, order_count):
        self._check_algo(algo,
                         handle_data,
                         order_count,
                         TradingControlViolation)

    def _order_fixed_amount(self, amount):
        """Return a handle_data that orders ``amount`` shares of ``self.sid``
        and then increments ``algo.order_count`` (skipped if the order
        raises)."""
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), amount)
            algo.order_count += 1
        return handle_data

    def _run_asset_date_bounds_algo(self, metadata, expected_exc=None):
        """Run SetAssetDateBoundsAlgorithm with sid 0 described by
        ``metadata``, in a fresh environment and temp directory. If
        ``expected_exc`` is given, the run must raise it."""
        tempdir = TempDirectory()
        try:
            temp_env = TradingEnvironment()
            data_portal = create_data_portal(
                temp_env,
                tempdir,
                self.sim_params,
                [0]
            )
            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )
            with self.assertRaises(expected_exc) if expected_exc \
                    else nullctx():
                algo.run(data_portal)
        finally:
            tempdir.cleanup()

    def test_set_max_position_size(self):
        # Buy one share four times.  Should be fine.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, self._order_fixed_amount(1))

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, self._order_fixed_amount(3), 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, self._order_fixed_amount(3), 2)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=67.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, self._order_fixed_amount(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!. Should
        # fail because setting sid to None makes the control apply to all sids.
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, self._order_fixed_amount(10000), 0)

    def test_set_do_not_order_list(self):
        # set the restricted list to be the sid, and fail.
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_fails(algo, self._order_fixed_amount(100), 0)

        # set the restricted list to exclude the sid, and succeed
        algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_succeeds(algo, self._order_fixed_amount(100))

    def test_set_max_order_size(self):
        # Buy one share.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, self._order_fixed_amount(1))

        # Orders 1, then 2, then 3, then 4 shares (one more than the number
        # of orders already placed).
        def order_increasing(algo, data):
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed shares.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=3,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, order_increasing, 3)

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed notional.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=40.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, order_increasing, 3)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
                                        max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, self._order_fixed_amount(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!.
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        algo = SetMaxOrderSizeAlgorithm(max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, self._order_fixed_amount(10000), 0)

    def test_set_max_order_count(self):
        tempdir = TempDirectory()
        try:
            env = TradingEnvironment()
            sim_params = factory.create_simulation_parameters(
                num_days=4, env=env, data_frequency="minute")

            env.write_data(equities_data={
                1: {
                    'start_date': sim_params.period_start,
                    'end_date': sim_params.period_end + timedelta(days=1)
                }
            })

            data_portal = create_data_portal(
                env,
                tempdir,
                sim_params,
                [1]
            )

            # Five orders on every bar: the fourth order overall breaches a
            # limit of 3.
            def handle_data(algo, data):
                for _ in range(5):
                    algo.order(algo.sid(1), 1)
                    algo.order_count += 1

            algo = SetMaxOrderCountAlgorithm(3, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data
            with self.assertRaises(TradingControlViolation):
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 3)

            # This time, order 5 times twice in a single day. The last order
            # of the second batch should fail.
            def handle_data2(algo, data):
                if algo.minute_count == 0 or algo.minute_count == 100:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            algo = SetMaxOrderCountAlgorithm(9, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data2
            with self.assertRaises(TradingControlViolation):
                algo.run(data_portal)

            self.assertEqual(algo.order_count, 9)

            # Five orders every 390th bar, i.e. at most one batch per
            # session (assuming full 390-minute sessions).
            def handle_data3(algo, data):
                if (algo.minute_count % 390) == 0:
                    for _ in range(5):
                        algo.order(algo.sid(1), 1)
                        algo.order_count += 1

                algo.minute_count += 1

            # Only 5 orders are placed per day, so this should pass even
            # though the total over the whole run exceeds the limit of 5.
            algo = SetMaxOrderCountAlgorithm(5, sim_params=sim_params,
                                             env=env)
            algo._handle_data = handle_data3
            algo.run(data_portal)
        finally:
            tempdir.cleanup()

    def test_long_only(self):
        # Sell immediately -> fail immediately.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, self._order_fixed_amount(-1), 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        def handle_data(algo, data):
            if (algo.order_count % 2) == 0:
                algo.order(algo.sid(self.sid), 1)
            else:
                algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings.  Should succeed.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -3]
            algo.order(algo.sid(self.sid), amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        def handle_data(algo, data):
            amounts = [1, 1, 1, -4]
            algo.order(algo.sid(self.sid), amounts[algo.order_count])
            algo.order_count += 1
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

    def test_register_post_init(self):

        def initialize(algo):
            algo.initialized = True

        # Every control registration attempted after initialize must raise.
        def handle_data(algo, data):
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_position_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_count(1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_long_only()

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.data_portal)

    def test_asset_date_bounds(self):
        # Run the algorithm with a sid that ends far in the future
        self._run_asset_date_bounds_algo(
            {0: {'start_date': self.sim_params.period_start,
                 'end_date': '2020-01-01'}})

        # Run the algorithm with a sid that has already ended
        self._run_asset_date_bounds_algo(
            {0: {'start_date': '1989-01-01',
                 'end_date': '1990-01-01'}},
            expected_exc=TradingControlViolation)

        # Run the algorithm with a sid that has not started
        self._run_asset_date_bounds_algo(
            {0: {'start_date': '2020-01-01',
                 'end_date': '2021-01-01'}},
            expected_exc=TradingControlViolation)
1702
1703
1704
class TestAccountControls(TestCase):
    """Exercise account-level controls (max leverage) over a 4-day run."""

    @classmethod
    def setUpClass(cls):
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            num_days=4, env=cls.env
        )

        # Register the equity under the same sid the tests trade, with a
        # lifetime covering the whole simulation plus one extra day.
        cls.env.write_data(equities_data={
            cls.sidint: {
                'start_date': cls.sim_params.period_start,
                'end_date': cls.sim_params.period_end + timedelta(days=1)
            }
        })

        cls.tempdir = TempDirectory()

        trades_by_sid = {
            cls.sidint: factory.create_trade_history(
                cls.sidint,
                [10.0, 10.0, 11.0, 11.0],
                [100, 100, 100, 300],
                timedelta(days=1),
                cls.sim_params,
                cls.env,
            )
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid,
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def _check_algo(self, algo, handle_data, expected_exc):
        """Run ``algo`` against the shared data portal.

        ``handle_data`` is installed as the algorithm's per-bar callback.
        If ``expected_exc`` is truthy the run must raise it; otherwise the
        run must complete without raising.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.data_portal)

    def check_algo_succeeds(self, algo, handle_data):
        """Assert that running ``algo`` raises no exception."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Assert that running ``algo`` raises AccountControlViolation."""
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):
        # On every bar, try to buy one share of the test equity.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sidint), 1)

        # Max leverage of 0: buying one share must violate the control.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Max leverage of 1: buying one share stays within the control.
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)
1779
1780
1781
class TestClosePosAlgo(TestCase):
    """Test that a held future position is closed automatically around its
    ``auto_close_date``."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-1]
        )

        trades_by_sid = {
            cls.sid: factory.create_trade_history(
                cls.sid,
                [1, 1, 2, 4],
                [1e9, 1e9, 1e9, 1e9],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_auto_close_future(self):
        # The contract ends, and is auto-closed, on the first trading day
        # after the simulation's last day.
        day_after_end = self.env.next_trading_day(
            self.sim_params.trading_days[-1])
        self.env.write_data(futures_data={
            self.sid: {
                "start_date": self.sim_params.trading_days[0],
                "end_date": day_after_end,
                'symbol': 'TEST',
                'asset_type': 'future',
                'auto_close_date': day_after_end,
            }
        })

        algo = TestAlgorithm(sid=self.sid, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env,
                             sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        expected_pnl = [0, 0, 1, 2]
        self.check_algo_pnl(results, expected_pnl)

        # Flat before the first fill, long one contract while held, and
        # flat again on the last day once the position has been closed.
        expected_positions = [0, 1, 1, 0]
        self.check_algo_positions(results, expected_positions)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the per-day ``results.pnl`` matches ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the held amount on each day of ``results.positions``.

        An empty day entry counts as a position of 0; otherwise the first
        position's ``'amount'`` is compared. The number of days must match
        ``expected_positions`` exactly, so a short result cannot pass
        silently.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "expected positions for {0} days, got {1} days".format(
                len(expected_positions), len(results.positions)))

        for i, day_positions in enumerate(results.positions):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected_positions[i],
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected_positions[i]))
1863
1864
1865
class TestFutureFlip(TestCase):
    """Test an algorithm that flips a future position from long to short."""

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

        cls.env = TradingEnvironment()
        cls.days = pd.date_range(start=pd.Timestamp("2006-01-09", tz='UTC'),
                                 end=pd.Timestamp("2006-01-12", tz='UTC'))

        cls.sid = 1

        # Simulate all but the last of ``days`` (three trading days).
        cls.sim_params = factory.create_simulation_parameters(
            start=cls.days[0],
            end=cls.days[-2]
        )

        trades_by_sid = {
            cls.sid: factory.create_trade_history(
                cls.sid,
                [1, 2, 4],
                [1e9, 1e9, 1e9],
                timedelta(days=1),
                cls.sim_params,
                cls.env
            )
        }

        cls.data_portal = create_data_portal_from_trade_history(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            trades_by_sid
        )

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def test_flip_algo(self):
        metadata = {self.sid: {'symbol': 'TEST',
                               'start_date': self.sim_params.trading_days[0],
                               'end_date': self.env.next_trading_day(
                                   self.sim_params.trading_days[-1]),
                               'contract_multiplier': 5}}
        self.env.write_data(futures_data=metadata)

        algo = FutureFlipAlgo(sid=self.sid, amount=1, env=self.env,
                              commission=PerShare(0),
                              order_count=0,  # not applicable but required
                              sim_params=self.sim_params)

        results = algo.run(self.data_portal)

        # Flat, then long one contract, then flipped to short one contract.
        expected_positions = [0, 1, -1]
        self.check_algo_positions(results, expected_positions)

        # PnL is scaled by the contract_multiplier of 5.
        expected_pnl = [0, 5, -10]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the per-day ``results.pnl`` matches ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the held amount on each day of ``results.positions``.

        An empty day entry counts as a position of 0; otherwise the first
        position's ``'amount'`` is compared. The number of days must match
        ``expected_positions`` exactly, so a short result cannot pass
        silently.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "expected positions for {0} days, got {1} days".format(
                len(expected_positions), len(results.positions)))

        for i, day_positions in enumerate(results.positions):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected_positions[i],
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected_positions[i]))
1940
1941
1942
class TestTradingAlgorithm(TestCase):
1943
    def setUp(self):
        # Build a fresh environment per test and keep its first four
        # trading days handy.
        env = TradingEnvironment()
        self.env = env
        self.days = env.trading_days[:4]
1946
1947
    def test_analyze_called(self):
        # The ``analyze`` callback should receive exactly the object that
        # ``run`` returns.
        self.perf_ref = None

        def initialize(context):
            pass

        def handle_data(context, data):
            pass

        def analyze(context, perf):
            # Capture whatever the algorithm passes to ``analyze``.
            self.perf_ref = perf

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            analyze=analyze,
        )
        results = algo.run(FakeDataPortal())

        self.assertIs(results, self.perf_ref)
1966