Completed
Pull Request — master (#846)
by Warren
05:15 queued 03:42
created

tests.TestRemoveData.setUp()   B

Complexity

Conditions 1

Size

Total Lines 26

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 26
rs 8.8571
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range, map
20
from textwrap import dedent
21
from unittest import TestCase
22
23
import numpy as np
24
import pandas as pd
25
26
from zipline.assets import Equity, Future
27
from zipline.utils.api_support import ZiplineAPI
28
from zipline.utils.control_flow import nullctx
29
from zipline.utils.test_utils import (
30
    setup_logger,
31
    teardown_logger
32
)
33
import zipline.utils.factory as factory
34
import zipline.utils.simfactory as simfactory
35
36
from zipline.errors import (
37
    OrderDuringInitialize,
38
    RegisterTradingControlPostInit,
39
    TradingControlViolation,
40
    AccountControlViolation,
41
    SymbolNotFound,
42
    RootSymbolNotFound,
43
    UnsupportedDatetimeFormat,
44
)
45
from zipline.test_algorithms import (
46
    access_account_in_init,
47
    access_portfolio_in_init,
48
    AmbitiousStopLimitAlgorithm,
49
    EmptyPositionsAlgorithm,
50
    InvalidOrderAlgorithm,
51
    RecordAlgorithm,
52
    FutureFlipAlgo,
53
    TestAlgorithm,
54
    TestOrderAlgorithm,
55
    TestOrderInstantAlgorithm,
56
    TestOrderPercentAlgorithm,
57
    TestOrderStyleForwardingAlgorithm,
58
    TestOrderValueAlgorithm,
59
    TestRegisterTransformAlgorithm,
60
    TestTargetAlgorithm,
61
    TestTargetPercentAlgorithm,
62
    TestTargetValueAlgorithm,
63
    TestRemoveDataAlgo,
64
    SetLongOnlyAlgorithm,
65
    SetAssetDateBoundsAlgorithm,
66
    SetMaxPositionSizeAlgorithm,
67
    SetMaxOrderCountAlgorithm,
68
    SetMaxOrderSizeAlgorithm,
69
    SetDoNotOrderListAlgorithm,
70
    SetMaxLeverageAlgorithm,
71
    api_algo,
72
    api_get_environment_algo,
73
    api_symbol_algo,
74
    call_all_order_methods,
75
    call_order_in_init,
76
    handle_data_api,
77
    handle_data_noop,
78
    initialize_api,
79
    initialize_noop,
80
    noop_algo,
81
    record_float_magic,
82
    record_variables,
83
)
84
from zipline.utils.context_tricks import CallbackManager
85
import zipline.utils.events
86
from zipline.utils.test_utils import (
87
    assert_single_position,
88
    drain_zipline,
89
    to_utc,
90
)
91
92
from zipline.sources import (SpecificEquityTrades,
93
                             DataFrameSource,
94
                             DataPanelSource,
95
                             RandomWalkSource)
96
97
from zipline.finance.execution import LimitOrder
98
from zipline.finance.trading import SimulationParameters
99
from zipline.utils.api_support import set_algo_instance
100
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
101
from zipline.algorithm import TradingAlgorithm
102
from zipline.protocol import DATASOURCE_TYPE
103
from zipline.finance.trading import TradingEnvironment
104
from zipline.finance.commission import PerShare
105
106
# Because test cases appear to reuse some resources.
107
# Presumably read by nose's multiprocess plugin: keep this module's test
# cases in one process instead of splitting them across worker processes.
_multiprocess_can_split_ = False
108
109
110
class TestRecordAlgorithm(TestCase):
    """Checks the values RecordAlgorithm stores through ``record()``."""

    @classmethod
    def setUpClass(cls):
        # A single shared environment containing the one equity (sid 133)
        # that the per-test trade history is generated for.
        cls.env = env = TradingEnvironment()
        env.write_data(equities_identifiers=[133])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        env = self.env
        self.sim_params = sim_params = factory.create_simulation_parameters(
            num_days=4,
            env=env,
        )
        trades = factory.create_trade_history(
            133,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            sim_params,
            env,
        )

        self.source = SpecificEquityTrades(event_list=trades, env=env)
        self.df_source, self.df = factory.create_test_df_source(
            sim_params, env)

    def test_record_incr(self):
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        output = algo.run(self.source)

        row_count = len(output)
        counting_up = range(1, row_count + 1)
        # 'name2' is expected to hold the constant 2 on every row; the other
        # recorded columns are expected to count up from 1, one per row.
        expectations = (
            ('incr', counting_up),
            ('name', counting_up),
            ('name2', [2] * row_count),
            ('name3', counting_up),
        )
        for column, expected in expectations:
            np.testing.assert_array_equal(output[column].values, expected)
150
151
152
class TestMiscellaneousAPI(TestCase):
    """Tests for assorted TradingAlgorithm API methods: environment queries,
    open orders, scheduling, event contexts and asset/future lookups."""

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        # Two equities share the symbol 'PLAY' over non-overlapping date
        # ranges, so symbol lookups depend on the as-of date.
        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        # Four 'CL' futures contracts used by the future_symbol and
        # future_chain tests.
        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)
        # Two days of minute-frequency simulation over sids 1 and 2.
        self.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='minute',
            env=self.env,
        )
        self.source = factory.create_minutely_trade_source(
            self.sids,
            sim_params=self.sim_params,
            concurrent=True,
            env=self.env,
        )

    def tearDown(self):
        teardown_logger(self)

    def test_zipline_api_resolves_dynamically(self):
        # Make a dummy algo. Pass the class environment so it matches the
        # environment the sim_params were built from.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Verify that api methods get resolved dynamically by patching them
        # out on the instance and then calling them through zipline.api.
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            # Called within this same iteration, so the late-bound
            # `sentinel` is the one created just above.
            def fake_method(*args, **kwargs):
                return sentinel
            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_get_environment(self):
        """get_environment() returns the platform name; get_environment('*')
        returns the full environment dict."""
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.source)

    def test_get_open_orders(self):
        """get_open_orders() with and without a sid, before and after
        fills and cancellations."""

        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                # Compare sorted keys: dict iteration order is not
                # guaranteed on every supported Python version.
                self.assertEqual(sorted(all_orders), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The order for sid 1 should have filled.
                # The three limit orders for sid 2 should still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(sorted(all_orders), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.source)

    def test_schedule_function(self):
        """A function scheduled every day at market open runs once per
        simulated day, at 14:31."""
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            # Count distinct simulated days as handle_data sees them.
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        """Each bar runs pre, handle_data, f, g and post, in that order,
        inside the CallbackManager event context."""
        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.source)

        # Number of handle_data calls expected for this source.
        bar_count = 779
        expected_functions = [pre, handle_data, f, g, post] * bar_count

        self.assertEqual(len(expected_data), bar_count)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        self.assertEqual(
            len(function_stack),
            len(expected_functions),
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), len(expected_functions)),
        )
        # Distinct loop names so the local functions f and g above are not
        # shadowed.
        pairs = zip(function_stack, expected_functions)
        for n, (actual, expected) in enumerate(pairs):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_schedule_funtion_rule_creation(self, mode):
        """schedule_function wraps its rules in OncePerDay(ComposedRule);
        in daily mode the time rule is replaced by Always."""
        def nop(*args, **kwargs):
            return None

        self.sim_params.data_frequency = mode
        algo = TradingAlgorithm(
            initialize=nop,
            handle_data=nop,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Schedule something for NOT Always.
        algo.schedule_function(nop, time_rule=zipline.utils.events.Never())

        event_rule = algo.event_manager._events[1].rule

        self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

        inner_rule = event_rule.rule
        self.assertIsInstance(inner_rule, zipline.utils.events.ComposedRule)

        first = inner_rule.first
        second = inner_rule.second
        composer = inner_rule.composer

        self.assertIsInstance(first, zipline.utils.events.Always)

        if mode == 'daily':
            self.assertIsInstance(second, zipline.utils.events.Always)
        else:
            self.assertIsInstance(second, zipline.utils.events.Never)

        self.assertIs(composer, zipline.utils.events.ComposedRule.lazy_and)

    def test_asset_lookup(self):
        """symbol()/symbols() resolve 'PLAY' according to the period end."""

        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.symbol(1)

        with self.assertRaises(TypeError):
            algo.symbol((1,))

        with self.assertRaises(TypeError):
            algo.symbol({1})

        with self.assertRaises(TypeError):
            algo.symbol([1])

        with self.assertRaises(TypeError):
            algo.symbol({'foo': 'bar'})

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('PLAY')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('FOOBAR')

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_symbol(1)

        with self.assertRaises(TypeError):
            algo.future_symbol((1,))

        with self.assertRaises(TypeError):
            algo.future_symbol({1})

        with self.assertRaises(TypeError):
            algo.future_symbol([1])

        with self.assertRaises(TypeError):
            algo.future_symbol({'foo': 'bar'})

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_chain(1)

        with self.assertRaises(TypeError):
            algo.future_chain((1,))

        with self.assertRaises(TypeError):
            algo.future_chain({1})

        with self.assertRaises(TypeError):
            algo.future_chain([1])

        with self.assertRaises(TypeError):
            algo.future_chain({'foo': 'bar'})

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Sids start at i + 3, mirroring the class fixture where sids [1, 2]
        # are created in setUpClass(); this test uses its own environment.
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        # Create two assets with the same symbol but different
        # non-overlapping date ranges.
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the end dates of both assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end date is used as the
        # as_of_date. Both DUP assets have ended by then, so only the
        # resolved symbol is checked here.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
643
644
645
class TestTransformAlgorithm(TestCase):
    """Tests running algorithms against different input source types and
    exercising the order-placement API methods."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1, 133])

        # A separate environment in which sid 0 is a Future (contract
        # multiplier 10) instead of an Equity.
        futures_metadata = {0: {'contract_multiplier': 10}}
        cls.futures_env = TradingEnvironment()
        cls.futures_env.write_data(futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        # Release both class-level environments created in setUpClass.
        del cls.env
        del cls.futures_env

    def setUp(self):
        setup_logger(self)
        self.sim_params = factory.create_simulation_parameters(num_days=4,
                                                               env=self.env)

        trade_history = factory.create_trade_history(
            133,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=trade_history,
            env=self.env,
        )
        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

        self.panel_source, self.panel = \
            factory.create_test_panel_source(self.sim_params, self.env)

    def tearDown(self):
        teardown_logger(self)

    def test_source_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
            sids=[133]
        )
        algo.run(self.source)
        self.assertEqual(len(algo.sources), 1)
        self.assertIsInstance(algo.sources[0], SpecificEquityTrades)

    def test_invalid_order_parameters(self):
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

    def test_multi_source_as_input(self):
        sim_params = SimulationParameters(
            self.df.index[0],
            self.df.index[-1],
            env=self.env,
        )
        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            sids=[0, 1],
            env=self.env,
        )
        algo.run([self.source, self.df_source], overwrite_sim_params=False)
        self.assertEqual(len(algo.sources), 2)

    def test_df_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.df)
        self.assertIsInstance(algo.sources[0], DataFrameSource)

    def test_panel_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
            sids=[0, 1])
        panel = self.panel.copy()
        # Materialise the map: pd.Index expects array-like data, not a
        # one-shot iterator.
        panel.items = pd.Index(list(map(Equity, panel.items)))
        algo.run(panel)
        self.assertIsInstance(algo.sources[0], DataPanelSource)

    def test_df_of_assets_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=TradingEnvironment(),  # new env without assets
        )
        df = self.df.copy()
        # Materialise the map: pd.Index expects array-like data, not a
        # one-shot iterator.
        df.columns = pd.Index(list(map(Equity, df.columns)))
        algo.run(df)
        self.assertIsInstance(algo.sources[0], DataFrameSource)

    def test_panel_of_assets_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=TradingEnvironment(),  # new env without assets
            sids=[0, 1])
        algo.run(self.panel)
        self.assertIsInstance(algo.sources[0], DataPanelSource)

    def test_run_twice(self):
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.df)

        # Create a new trading algorithm; with no env supplied, it builds
        # its own environment.
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.df)

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        self.sim_params.data_frequency = 'daily'
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'daily')

        self.sim_params.data_frequency = 'minute'
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'minute')

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )

        # Ensure that the environment's asset 0 is an Equity
        asset_to_test = algo.sid(0)
        self.assertIsInstance(asset_to_test, Equity)

        algo.run(self.df)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.futures_env,
        )

        # Ensure that the environment's asset 0 is a Future
        asset_to_test = algo.sid(0)
        self.assertIsInstance(asset_to_test, Future)

        algo.run(self.df)

    def test_order_method_style_forwarding(self):

        method_names_to_test = ['order',
                                'order_value',
                                'order_percent',
                                'order_target',
                                'order_target_percent',
                                'order_target_value']

        for name in method_names_to_test:
            # Don't supply an env so the TradingAlgorithm builds a new one for
            # each method
            algo = TestOrderStyleForwardingAlgorithm(
                sim_params=self.sim_params,
                instant_fill=False,
                method_name=name
            )
            algo.run(self.df)

    def test_order_instant(self):
        algo = TestOrderInstantAlgorithm(sim_params=self.sim_params,
                                         env=self.env,
                                         instant_fill=True)
        algo.run(self.df)

    def test_minute_data(self):
        source = RandomWalkSource(freq='minute',
                                  start=pd.Timestamp('2000-1-3',
                                                     tz='UTC'),
                                  end=pd.Timestamp('2000-1-4',
                                                   tz='UTC'))
        self.sim_params.data_frequency = 'minute'
        algo = TestOrderInstantAlgorithm(sim_params=self.sim_params,
                                         env=self.env,
                                         instant_fill=True)
        algo.run(source)
863
864
865
class TestPositions(TestCase):
    """Position-count checks driven by a small, fixed trade history."""

    def setUp(self):
        setup_logger(self)
        self.env = TradingEnvironment()
        self.sim_params = factory.create_simulation_parameters(num_days=4,
                                                               env=self.env)
        self.env.write_data(equities_identifiers=[1, 133])

        # Four daily trades of sid 1 at the prices/volumes below.
        trade_history = factory.create_trade_history(
            1,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=trade_history,
            env=self.env,
        )

        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

    def tearDown(self):
        teardown_logger(self)

    def test_empty_portfolio(self):
        algo = EmptyPositionsAlgorithm(sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.df)

        expected_position_count = [
            0,  # Before entering the first position
            1,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        for i, expected in enumerate(expected_position_count):
            # Positional row lookup. daily_stats is presumably date-indexed,
            # so use .iloc instead of the deprecated, label/position-ambiguous
            # .ix accessor.
            self.assertEqual(daily_stats.iloc[i]['num_positions'],
                             expected)

    def test_noop_orders(self):
        algo = AmbitiousStopLimitAlgorithm(sid=1,
                                           sim_params=self.sim_params,
                                           env=self.env)
        daily_stats = algo.run(self.source)

        # Verify that positions are empty for all dates.
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
        self.assertTrue(empty_positions.all())
919
920
921
class TestAlgoScript(TestCase):
    """
    End-to-end tests for algorithms supplied either as script text or as
    initialize/handle_data callables.
    """

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(
            equities_identifiers=[0, 1, 133]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        days = 251
        # The shared class-level environment is passed explicitly, so the
        # simulation parameters are built against it.
        self.sim_params = factory.create_simulation_parameters(num_days=days,
                                                               env=self.env)

        setup_logger(self)
        trade_history = factory.create_trade_history(
            133,
            [10.0] * days,
            [100] * days,
            timedelta(days=1),
            self.sim_params,
            self.env
        )

        self.source = SpecificEquityTrades(
            sids=[133],
            event_list=trade_history,
            env=self.env,
        )

        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

        # Keyword arguments for simfactory.create_test_zipline; individual
        # tests add 'algorithm', 'trade_count', and other keys as needed.
        self.zipline_test_config = {
            'sid': 0,
        }

    def tearDown(self):
        teardown_logger(self)

    def _make_test_zipline(self, script, trade_count, **extra_config):
        """
        Build a TradingAlgorithm from ``script`` and wrap it in a test zipline.

        The algorithm is registered with set_algo_instance. Then the
        algorithm, ``trade_count``, and any ``extra_config`` entries are
        merged into self.zipline_test_config, and
        simfactory.create_test_zipline is called with that config.
        Returns the created zipline.
        """
        test_algo = TradingAlgorithm(
            script=script,
            sim_params=self.sim_params,
            env=self.env,
        )
        set_algo_instance(test_algo)

        self.zipline_test_config['algorithm'] = test_algo
        self.zipline_test_config['trade_count'] = trade_count
        self.zipline_test_config.update(extra_config)

        return simfactory.create_test_zipline(**self.zipline_test_config)

    def test_noop(self):
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.df)

    def test_noop_string(self):
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.df)

    def test_api_calls(self):
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api)
        algo.run(self.df)

    def test_api_calls_string(self):
        algo = TradingAlgorithm(script=api_algo)
        algo.run(self.df)

    def test_api_get_environment(self):
        platform = 'zipline'
        # Use sid not already in test database.
        metadata = {3: {'symbol': 'TEST'}}
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                equities_metadata=metadata,
                                platform=platform)
        algo.run(self.df)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        # Use sid not already in test database.
        metadata = {3: {'symbol': 'TEST'}}
        algo = TradingAlgorithm(script=api_symbol_algo,
                                equities_metadata=metadata)
        algo.run(self.df)

    def test_fixed_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        script = """
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1"""

        # order_count=1 matches context.count in the script's initialize.
        # assert_single_position uses it to confirm we have as many
        # transactions as orders we placed.
        zipline = self._make_test_zipline(script, 200, order_count=1)

        output, _ = assert_single_position(self, zipline)

        # confirm the slippage and commission on a sample
        # transaction
        recorded_price = output[1]['daily_perf']['recorded_vars']['price']
        transaction = output[1]['daily_perf']['transactions'][0]
        self.assertEqual(100.0, transaction['commission'])
        expected_spread = 0.05
        # NOTE(review): commission is PerTrade and is asserted separately
        # above, so this 0.10 price offset is not commission despite its
        # name -- confirm what it represents.
        expected_commish = 0.10
        expected_price = recorded_price - expected_spread - expected_commish
        self.assertEqual(expected_price, transaction['price'])

    def test_volshare_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        script = """
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    set_commission(commission.PerShare(0.02))
    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """

        # assert_single_position uses 67 to confirm we have as many
        # transactions as expected.
        # The algo places 2 trades of 5000 shares each. The trade
        # events have volume ranging from 100 to 950. The volume cap
        # of 0.3 limits the trade volume to a range of 30 - 316 shares.
        # The spreadsheet linked below calculates the total position
        # size over each bar, and predicts 67 txns will be required
        # to fill the two orders. The number of bars and transactions
        # differ because some bars result in multiple txns. See
        # spreadsheet for details:
# https://www.dropbox.com/s/ulrk2qt0nrtrigb/Volume%20Share%20Worksheet.xlsx
        zipline = self._make_test_zipline(script, 100,
                                          expected_transactions=67)
        output, _ = assert_single_position(self, zipline)

        # confirm the slippage and commission on a sample
        # transaction
        per_share_commish = 0.02
        perf = output[1]
        transaction = perf['daily_perf']['transactions'][0]
        commish = transaction['amount'] * per_share_commish
        self.assertEqual(commish, transaction['commission'])
        self.assertEqual(2.029, transaction['price'])

    def test_algo_record_vars(self):
        zipline = self._make_test_zipline(record_variables, 200)
        output, _ = drain_zipline(self, zipline)
        self.assertEqual(len(output), 252)
        incr = [o['daily_perf']['recorded_vars']['incr']
                for o in output[:200]]

        np.testing.assert_array_equal(incr, range(1, 201))

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
            # Use the environment that self.sim_params was built with, as the
            # other sim_params-based script tests in this class do.
            env=self.env,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def _algo_record_float_magic_should_pass(self, var_type):
        zipline = self._make_test_zipline(record_float_magic % var_type,
                                          200)
        output, _ = drain_zipline(self, zipline)
        self.assertEqual(len(output), 252)
        incr = [o['daily_perf']['recorded_vars']['data']
                for o in output[:200]]
        np.testing.assert_array_equal(incr, [np.nan] * 200)

    def test_algo_record_nan(self):
        self._algo_record_float_magic_should_pass('nan')

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        zipline = self._make_test_zipline(call_all_order_methods, 200)
        drain_zipline(self, zipline)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            test_algo.run(self.source)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        zipline = self._make_test_zipline(access_portfolio_in_init, 1)
        drain_zipline(self, zipline)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        zipline = self._make_test_zipline(access_account_in_init, 1)
        drain_zipline(self, zipline)
1248
1249
1250
class TestHistory(TestCase):
    """Smoke tests for history() and add_history() on minute data."""

    @classmethod
    def setUpClass(cls):
        cls._start = pd.Timestamp('1991-01-01', tz='UTC')
        cls._end = pd.Timestamp('1991-01-15', tz='UTC')
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=cls.env
        )
        cls.env.write_data(equities_identifiers=[0, 1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)

    def tearDown(self):
        teardown_logger(self)

    @property
    def source(self):
        # Each access builds a new RandomWalkSource over [_start, _end].
        return RandomWalkSource(start=self._start, end=self._end)

    def _check_one_bar_window(self, handle_data):
        """
        Run an algorithm with a no-op initialize and ``handle_data``, then
        assert that it ended up with a history container of window length 1.
        """
        algo = TradingAlgorithm(
            initialize=lambda _: None,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

        self.assertIsNotNone(algo.history_container)
        self.assertEqual(algo.history_container.buffer_panel.window_length, 1)

    def test_history(self):
        script = """
from zipline.api import history, add_history

def initialize(context):
    add_history(10, '1d', 'price')

def handle_data(context, data):
    df = history(10, '1d', 'price')
"""

        output = TradingAlgorithm(
            script=script,
            sim_params=self.sim_params,
            env=self.env,
        ).run(self.source)
        self.assertIsNotNone(output)

    def test_history_without_add(self):
        def handle_data(algo, data):
            algo.history(1, '1m', 'price')

        self._check_one_bar_window(handle_data)

    def test_add_history_in_handle_data(self):
        def handle_data(algo, data):
            algo.add_history(1, '1m', 'price')

        self._check_one_bar_window(handle_data)
1325
1326
1327
class TestGetDatetime(TestCase):
    """get_datetime() must return the requested timezone on minute bars."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)

    def tearDown(self):
        teardown_logger(self)

    @parameterized.expand([
        ('default', None),
        ('utc', 'UTC'),
        ('us_east', 'US/Eastern'),
    ])
    def test_get_datetime(self, name, tz):
        # On the first bar, the script raises if get_datetime(tz) reports the
        # wrong zone, or a time other than 9:31 US/Eastern. Then it clears
        # context.first_bar.
        script = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        minute_source = RandomWalkSource(
            start=to_utc('2014-01-02 9:31'),
            end=to_utc('2014-01-03 9:31'),
        )
        minute_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=self.env,
        )
        algo = TradingAlgorithm(
            script=script,
            sim_params=minute_params,
            env=self.env,
        )
        algo.run(minute_source)
        # The script resets first_bar after its first handle_data call.
        self.assertFalse(algo.first_bar)
1392
1393
1394
class TestTradingControls(TestCase):
1395
1396
    @classmethod
    def setUpClass(cls):
        # One shared environment that contains only the sid under test.
        sid = cls.sid = 133
        env = cls.env = TradingEnvironment()
        env.write_data(equities_identifiers=[sid])
1401
1402
    @classmethod
1403
    def tearDownClass(cls):
1404
        del cls.env
1405
1406
    def setUp(self):
        # Four daily trades of self.sid at the prices/volumes below; the
        # source is rebuilt for every test.
        sim_params = factory.create_simulation_parameters(num_days=4,
                                                          env=self.env)
        self.sim_params = sim_params
        self.trade_history = factory.create_trade_history(
            self.sid,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=self.trade_history,
            env=self.env,
        )
1422
1423
    def _check_algo(self,
1424
                    algo,
1425
                    handle_data,
1426
                    expected_order_count,
1427
                    expected_exc):
1428
1429
        algo._handle_data = handle_data
1430
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
1431
            algo.run(self.source)
1432
        self.assertEqual(algo.order_count, expected_order_count)
1433
        self.source.rewind()
1434
1435
    def check_algo_succeeds(self, algo, handle_data, order_count=4):
1436
        # Default for order_count assumes one order per handle_data call.
1437
        self._check_algo(algo, handle_data, order_count, None)
1438
1439
    def check_algo_fails(self, algo, handle_data, order_count):
        """
        Run ``algo`` expecting a TradingControlViolation, with
        ``order_count`` orders counted by the time it is raised.
        """
        self._check_algo(algo, handle_data, order_count,
                         expected_exc=TradingControlViolation)
1444
1445
    def test_set_max_position_size(self):
        """
        SetMaxPositionSizeAlgorithm should stop orders that push the position
        past ``max_shares`` or ``max_notional`` for its sid (or for every sid
        when no sid is given).
        """

        def buying(amount):
            # Build a handle_data that orders `amount` shares of self.sid on
            # every bar and counts the order.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amount)
                algo.order_count += 1
            return handle_data

        # Buy one share four times.  Should be fine.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, buying(1))

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, buying(3), 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt (three orders of 3 is 9 shares, still within
        # max_shares=10).
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, buying(3), 2)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, buying(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!. Should
        # fail because setting sid to None makes the control apply to all sids.
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, buying(10000), 0)
1505
1506
    def test_set_do_not_order_list(self):
        """Ordering a restricted sid fails; restrictions on other sids don't."""

        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 100)
            algo.order_count += 1

        # The sid under test is restricted, so the very first order fails.
        restricted_algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_fails(restricted_algo, handle_data, 0)

        # Only other sids are restricted, so every order goes through.
        unrestricted_algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_succeeds(unrestricted_algo, handle_data)
1534
1535
    def test_set_max_order_size(self):
        """
        SetMaxOrderSizeAlgorithm rejects a single order above its share or
        notional limit; a limit placed on another sid does not apply.
        """

        def order_fixed(amount):
            # handle_data that orders `amount` shares of self.sid every bar.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amount)
                algo.order_count += 1
            return handle_data

        def order_increasing(algo, data):
            # Orders 1, then 2, then 3, ... shares on successive bars.
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        def max_order_size_algo(**limits):
            return SetMaxOrderSizeAlgorithm(sim_params=self.sim_params,
                                            env=self.env,
                                            **limits)

        # Buy one share.
        self.check_algo_succeeds(
            max_order_size_algo(sid=self.sid, max_shares=10,
                                max_notional=500.0),
            order_fixed(1),
        )

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed shares.
        self.check_algo_fails(
            max_order_size_algo(sid=self.sid, max_shares=3,
                                max_notional=500.0),
            order_increasing,
            3,
        )

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed notional.
        self.check_algo_fails(
            max_order_size_algo(sid=self.sid, max_shares=10,
                                max_notional=40.0),
            order_increasing,
            3,
        )

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        self.check_algo_succeeds(
            max_order_size_algo(sid=self.sid + 1, max_shares=1,
                                max_notional=1.0),
            order_fixed(10000),
        )

        # Set the trading control sid to None, then BUY ALL THE THINGS!.
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        self.check_algo_fails(
            max_order_size_algo(max_shares=1, max_notional=1.0),
            order_fixed(10000),
            0,
        )
1597
1598
    def test_set_max_order_count(self):
        """
        SetMaxOrderCountAlgorithm limits how many orders are placed per
        trading day; the count resets when a new session starts.
        """
        # Override the default setUp to use six-hour intervals instead of full
        # days so we can exercise trading-session rollover logic.
        trade_history = factory.create_trade_history(
            self.sid,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(hours=6),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(event_list=trade_history,
                                           env=self.env)

        def handle_data(algo, data):
            # Five one-share orders per bar.
            for _ in range(5):
                algo.order(algo.sid(self.sid), 1)
                algo.order_count += 1

        # Limit of 3: the fourth order of the first bar fails.
        algo = SetMaxOrderCountAlgorithm(3, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Second call to handle_data is the same day as the first, so the last
        # order of the second call should fail.
        algo = SetMaxOrderCountAlgorithm(9, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_fails(algo, handle_data, 9)

        # Only ten orders are placed per day, so a limit of 10 passes even
        # though 20 orders are placed in total.
        algo = SetMaxOrderCountAlgorithm(10, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_succeeds(algo, handle_data, order_count=20)
1633
1634
    def test_long_only(self):
        """SetLongOnlyAlgorithm fails any order that would open a short."""

        def long_only_algo():
            return SetLongOnlyAlgorithm(sim_params=self.sim_params,
                                        env=self.env)

        def order_scripted(amounts):
            # handle_data that orders amounts[algo.order_count] shares.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amounts[algo.order_count])
                algo.order_count += 1
            return handle_data

        # Sell immediately -> fail immediately.
        def sell_one(algo, data):
            algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1

        self.check_algo_fails(long_only_algo(), sell_one, 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        def alternate(algo, data):
            amount = 1 if algo.order_count % 2 == 0 else -1
            algo.order(algo.sid(self.sid), amount)
            algo.order_count += 1

        self.check_algo_succeeds(long_only_algo(), alternate)

        # Buy on first three days, then sell off holdings.  Should succeed.
        self.check_algo_succeeds(long_only_algo(),
                                 order_scripted([1, 1, 1, -3]))

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        self.check_algo_fails(long_only_algo(),
                              order_scripted([1, 1, 1, -4]),
                              3)
1669
1670
    def test_register_post_init(self):
        """Registering trading controls outside ``initialize`` must raise.

        Every ``set_*`` control registration is attempted from inside
        ``handle_data`` and must raise ``RegisterTradingControlPostInit``.
        """
        # The assertions live inside the handle_data callback, so record each
        # invocation; otherwise the test would pass vacuously if the
        # simulation never delivered a bar.
        calls = []

        def initialize(algo):
            algo.initialized = True

        def handle_data(algo, data):
            calls.append(data)

            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_position_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_count(1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_long_only()

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        try:
            algo.run(self.source)
        finally:
            # Leave the shared source replayable even if the run failed.
            self.source.rewind()

        self.assertGreater(len(calls), 0,
                           "handle_data was never called; assertions skipped")

    def test_asset_date_bounds(self):
        """Trading an asset outside its start/end dates must be rejected."""

        def make_algo(start_date, end_date):
            # Build a fresh environment and DataFrame source, and an algorithm
            # whose sid 0 has the given lifetime. Returns (algo, source).
            temp_env = TradingEnvironment()
            df_source, _ = factory.create_test_df_source(self.sim_params,
                                                         temp_env)
            metadata = {0: {'start_date': start_date,
                            'end_date': end_date}}
            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )
            return algo, df_source

        # Run the algorithm with a sid that ends far in the future
        algo, df_source = make_algo('1990-01-01', '2020-01-01')
        algo.run(df_source)

        # Run the algorithm with a sid that has already ended
        algo, df_source = make_algo('1989-01-01', '1990-01-01')
        with self.assertRaises(TradingControlViolation):
            algo.run(df_source)

        # Run the algorithm with a sid that has not started
        algo, df_source = make_algo('2020-01-01', '2021-01-01')
        with self.assertRaises(TradingControlViolation):
            algo.run(df_source)

        # Run the algorithm with a sid that starts on the first day and
        # ends on the last day of the algorithm's parameters (*not* an error).
        algo, df_source = make_algo('2006-01-03', '2006-01-06')
        algo.run(df_source)


class TestAccountControls(TestCase):
    """Tests for account-level controls such as ``set_max_leverage``."""

    @classmethod
    def setUpClass(cls):
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.env.write_data(
            equities_identifiers=[cls.sidint]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        # Four days of daily trades for the single equity written in
        # setUpClass.
        self.sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env
        )
        self.trade_history = factory.create_trade_history(
            self.sidint,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env,
        )

        self.source = SpecificEquityTrades(
            event_list=self.trade_history,
            env=self.env,
        )

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` installed.

        If ``expected_exc`` is falsy the run must not raise; otherwise it must
        raise ``expected_exc``. The shared source is rewound afterwards so the
        next check can replay it.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.source)
        self.source.rewind()

    def check_algo_succeeds(self, algo, handle_data):
        """Assert that running ``algo`` completes without raising."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Assert that running ``algo`` raises ``AccountControlViolation``."""
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):
        def handle_data(algo, data):
            # Buy one share on every bar.
            algo.order(algo.sid(self.sidint), 1)

        # Set max leverage to 0 so buying one share fails.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Set max leverage to 1 so buying one share passes.
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)


class TestClosePosAlgo(TestCase):
    """Positions are closed by CLOSE_POSITION events or auto-close dates."""

    def setUp(self):
        self.env = TradingEnvironment()
        self.days = self.env.trading_days[:4]
        # Sid 1 trades for three days, then gets a CLOSE_POSITION event (with
        # zero volume) on the fourth.
        self.panel = pd.Panel({1: pd.DataFrame({
            'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 0],
            'type': [DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.CLOSE_POSITION]},
            index=self.days)
        })
        # Same prices, but only TRADE events: any close must come from the
        # asset metadata (e.g. auto_close_date).
        self.no_close_panel = pd.Panel({1: pd.DataFrame({
            'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 1e9],
            'type': [DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE]},
            index=self.days)
        })

    def _run_test_algo(self, panel):
        """Run a commission-free ``TestAlgorithm`` on sid 1 over ``panel``.

        The algorithm is built with ``sid=1, amount=1, order_count=1``.
        Returns the performance results of the run.
        """
        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env)
        data = DataPanelSource(panel)
        return algo.run(data)

    def test_close_position_equity(self):
        metadata = {1: {'symbol': 'TEST',
                        'end_date': self.days[3]}}
        self.env.write_data(equities_data=metadata)
        results = self._run_test_algo(self.panel)

        # Check results
        expected_positions = [0, 1, 1, 0]
        expected_pnl = [0, 0, 1, 2]
        self.check_algo_positions(results, expected_positions)
        self.check_algo_pnl(results, expected_pnl)

    def test_close_position_future(self):
        metadata = {1: {'symbol': 'TEST'}}
        self.env.write_data(futures_data=metadata)
        results = self._run_test_algo(self.panel)

        # Check results
        expected_positions = [0, 1, 1, 0]
        expected_pnl = [0, 0, 1, 2]
        self.check_algo_pnl(results, expected_pnl)
        self.check_algo_positions(results, expected_positions)

    def test_auto_close_future(self):
        metadata = {1: {'symbol': 'TEST',
                        'auto_close_date': self.env.trading_days[4]}}
        self.env.write_data(futures_data=metadata)
        results = self._run_test_algo(self.no_close_panel)

        # Check results
        expected_positions = [0, 1, 1, 0]
        self.check_algo_positions(results, expected_positions)

        expected_pnl = [0, 0, 1, 2]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the per-day pnl matches ``expected_pnl`` (almost-equal)."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the per-day position amount matches ``expected_positions``.

        Each entry of ``results.positions`` is treated as a sequence of
        position mappings: an empty entry counts as a flat (0) position,
        otherwise the first mapping's ``'amount'`` is compared.
        """
        # Without this, a result with fewer days than expected would have its
        # missing days silently skipped by the loop below.
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of days not equal, actual={0}, expected={1}".
            format(len(results.positions), len(expected_positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))


class TestFutureFlip(TestCase):
    """A futures position can be flipped from long to short and closed."""

    def setUp(self):
        self.env = TradingEnvironment()
        self.days = self.env.trading_days[:4]
        # Three days of trades for sid 1; the fourth day is only used as the
        # future's end_date in test_flip_algo.
        self.trades_panel = pd.Panel({1: pd.DataFrame({
            'price': [1, 2, 4], 'volume': [1e9, 1e9, 1e9],
            'type': [DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE]},
            index=self.days[:3])
        })

    def test_flip_algo(self):
        metadata = {1: {'symbol': 'TEST',
                        'end_date': self.days[3],
                        'contract_multiplier': 5}}
        self.env.write_data(futures_data=metadata)

        algo = FutureFlipAlgo(sid=1, amount=1, env=self.env,
                              commission=PerShare(0),
                              order_count=0,  # not applicable but required
                              instant_fill=True)
        data = DataPanelSource(self.trades_panel)

        results = algo.run(data)

        expected_positions = [1, -1, 0]
        self.check_algo_positions(results, expected_positions)

        # Consistent with the multiplier of 5: +1 contract from 1 -> 2 earns
        # 5, then -1 contract from 2 -> 4 loses 10.
        expected_pnl = [0, 5, -10]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the per-day pnl matches ``expected_pnl`` (almost-equal)."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the per-day position amount matches ``expected_positions``.

        Each entry of ``results.positions`` is treated as a sequence of
        position mappings: an empty entry counts as a flat (0) position,
        otherwise the first mapping's ``'amount'`` is compared.
        """
        # Without this, a result with fewer days than expected would have its
        # missing days silently skipped by the loop below.
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "number of days not equal, actual={0}, expected={1}".
            format(len(results.positions), len(expected_positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))


class TestTradingAlgorithm(TestCase):
    """Checks on the ``TradingAlgorithm`` lifecycle hooks."""

    def setUp(self):
        self.env = TradingEnvironment()
        self.days = self.env.trading_days[:4]

        trade = DATASOURCE_TYPE.TRADE
        frame = pd.DataFrame(
            {
                'price': [1, 1, 2, 4],
                'volume': [1e9, 1e9, 1e9, 0],
                'type': [trade, trade, trade,
                         DATASOURCE_TYPE.CLOSE_POSITION],
            },
            index=self.days,
        )
        self.panel = pd.Panel({1: frame})

    def test_analyze_called(self):
        """``analyze`` receives the same perf object that ``run`` returns."""
        self.perf_ref = None

        def analyze(context, perf):
            # Capture what analyze is handed so it can be compared below.
            self.perf_ref = perf

        def initialize(context):
            pass

        def handle_data(context, data):
            pass

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            analyze=analyze,
        )
        perf = algo.run(self.panel)
        self.assertIs(perf, self.perf_ref)


class TestRemoveData(TestCase):
    """
    tests if futures data is removed after expiry
    """
    def setUp(self):
        dt = pd.Timestamp('2015-01-02', tz='UTC')
        env = TradingEnvironment()
        trading_days = env.trading_days
        ix = trading_days.get_loc(dt)

        # Future X expires and ends two sessions before future Y.
        metadata = {0: {'symbol': 'X',
                        'expiration_date': trading_days[ix + 5],
                        'end_date': trading_days[ix + 6]},
                    1: {'symbol': 'Y',
                        'expiration_date': trading_days[ix + 7],
                        'end_date': trading_days[ix + 8]}}

        env.write_data(futures_data=metadata)

        # X trades on the first five sessions starting at ``dt``.
        index_x = trading_days[ix:ix + 5]
        data_x = pd.DataFrame([[1, 100], [2, 100], [3, 100], [4, 100],
                               [5, 100]],
                              index=index_x, columns=['price', 'volume'])
        # Y trades on five sessions starting two sessions later. Slicing the
        # calendar directly (instead of DatetimeIndex.shift, which depends on
        # the index's ``freq``) guarantees these are actual trading sessions.
        index_y = trading_days[ix + 2:ix + 7]
        data_y = pd.DataFrame([[6, 100], [7, 100], [8, 100], [9, 100],
                               [10, 100]],
                              index=index_y, columns=['price', 'volume'])

        pan = pd.Panel({0: data_x, 1: data_y})
        self.source = DataPanelSource(pan)
        self.algo = TestRemoveDataAlgo(env=env)

    def test_remove_data(self):
        """Expired futures drop out of the data passed to handle_data."""
        self.algo.run(self.source)

        expected_lengths = [1, 1, 2, 2, 2, 2, 1]
        # initially only data for X should be sent and on the last day only
        # data for Y should be sent since X is expired
        np.testing.assert_array_equal(self.algo.data, expected_lengths)