Completed
Pull Request — master (#846)
by Warren
03:06
created

tests.TestRemoveData   A

Complexity

Total Complexity 3

Size/Duplication

Total Lines 39
Duplicated Lines 0 %
Metric Value
dl 0
loc 39
rs 10
wmc 3

2 Methods

Rating   Name   Duplication   Size   Complexity  
A test_remove_data() 0 8 2
B setUp() 0 26 1
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range, map
20
from textwrap import dedent
21
from unittest import TestCase
22
23
import numpy as np
24
import pandas as pd
25
26
from zipline.utils.api_support import ZiplineAPI
27
from zipline.utils.control_flow import nullctx
28
from zipline.utils.test_utils import (
29
    setup_logger,
30
    teardown_logger
31
)
32
import zipline.utils.factory as factory
33
import zipline.utils.simfactory as simfactory
34
35
from zipline.errors import (
36
    OrderDuringInitialize,
37
    RegisterTradingControlPostInit,
38
    TradingControlViolation,
39
    AccountControlViolation,
40
    SymbolNotFound,
41
    RootSymbolNotFound,
42
    UnsupportedDatetimeFormat,
43
)
44
from zipline.test_algorithms import (
45
    access_account_in_init,
46
    access_portfolio_in_init,
47
    AmbitiousStopLimitAlgorithm,
48
    EmptyPositionsAlgorithm,
49
    InvalidOrderAlgorithm,
50
    RecordAlgorithm,
51
    FutureFlipAlgo,
52
    TestAlgorithm,
53
    TestOrderAlgorithm,
54
    TestOrderInstantAlgorithm,
55
    TestOrderPercentAlgorithm,
56
    TestOrderStyleForwardingAlgorithm,
57
    TestOrderValueAlgorithm,
58
    TestRegisterTransformAlgorithm,
59
    TestTargetAlgorithm,
60
    TestTargetPercentAlgorithm,
61
    TestTargetValueAlgorithm,
62
    TestRemoveDataAlgo,
63
    SetLongOnlyAlgorithm,
64
    SetAssetDateBoundsAlgorithm,
65
    SetMaxPositionSizeAlgorithm,
66
    SetMaxOrderCountAlgorithm,
67
    SetMaxOrderSizeAlgorithm,
68
    SetDoNotOrderListAlgorithm,
69
    SetMaxLeverageAlgorithm,
70
    api_algo,
71
    api_get_environment_algo,
72
    api_symbol_algo,
73
    call_all_order_methods,
74
    call_order_in_init,
75
    handle_data_api,
76
    handle_data_noop,
77
    initialize_api,
78
    initialize_noop,
79
    noop_algo,
80
    record_float_magic,
81
    record_variables,
82
)
83
from zipline.utils.context_tricks import CallbackManager
84
import zipline.utils.events
85
from zipline.utils.test_utils import (
86
    assert_single_position,
87
    drain_zipline,
88
    to_utc,
89
)
90
91
from zipline.sources import (SpecificEquityTrades,
92
                             DataFrameSource,
93
                             DataPanelSource,
94
                             RandomWalkSource)
95
from zipline.assets import Equity
96
97
from zipline.finance.execution import LimitOrder
98
from zipline.finance.trading import SimulationParameters
99
from zipline.utils.api_support import set_algo_instance
100
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
101
from zipline.algorithm import TradingAlgorithm
102
from zipline.protocol import DATASOURCE_TYPE
103
from zipline.finance.trading import TradingEnvironment
104
from zipline.finance.commission import PerShare
105
106
# Because test cases appear to reuse some resources.
107
_multiprocess_can_split_ = False
108
109
110
class TestRecordAlgorithm(TestCase):
    """Checks that values passed to ``record`` appear in the run output."""

    @classmethod
    def setUpClass(cls):
        # A single shared environment holding the one equity (sid 133)
        # that the trade history built in setUp() trades.
        env = cls.env = TradingEnvironment()
        env.write_data(equities_identifiers=[133])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        env = self.env
        self.sim_params = factory.create_simulation_parameters(
            num_days=4,
            env=env,
        )
        # Four daily trades for sid 133: prices and volumes per day.
        history = factory.create_trade_history(
            133,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            env,
        )
        self.source = SpecificEquityTrades(event_list=history, env=env)
        self.df_source, self.df = factory.create_test_df_source(
            self.sim_params,
            env,
        )

    def test_record_incr(self):
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        output = algo.run(self.source)

        row_count = len(output)
        counter = range(1, row_count + 1)
        # (recorded column, expected per-row values), checked in order.
        expectations = (
            ('incr', counter),
            ('name', counter),
            ('name2', [2] * row_count),
            ('name3', counter),
        )
        for column, expected in expectations:
            np.testing.assert_array_equal(output[column].values, expected)
150
151
152
class TestMiscellaneousAPI(TestCase):
    """Tests for assorted TradingAlgorithm API methods.

    Covers dynamic resolution of ``zipline.api`` functions, environment
    introspection, open-order queries, scheduled functions, event contexts,
    and equity / future lookups by symbol.
    """

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        # Two equities share the symbol 'PLAY' over non-overlapping date
        # ranges, so symbol lookups depend on the as-of date.
        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        # Four futures contracts sharing the root symbol 'CL', used by the
        # future_symbol and future_chain tests.
        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)
        # A two-day, minute-frequency simulation with a minutely trade
        # source covering sids 1 and 2.
        self.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='minute',
            env=self.env,
        )
        self.source = factory.create_minutely_trade_source(
            self.sids,
            sim_params=self.sim_params,
            concurrent=True,
            env=self.env,
        )

    def tearDown(self):
        teardown_logger(self)

    def test_zipline_api_resolves_dynamically(self):
        # Make a dummy algo.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Verify that api methods get resolved dynamically by patching them
        # out on the algo instance and then calling them through zipline.api.
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            # Called within the same iteration it is defined in, so the
            # closure always sees this iteration's sentinel.
            def fake_method(*args, **kwargs):
                return sentinel
            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_get_environment(self):
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.source)

    def test_get_open_orders(self):

        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Three limit orders that won't be filled because the limit
                # price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                # Sort the keys: dict iteration order is not something the
                # test should depend on.
                self.assertEqual(sorted(all_orders), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The market order for sid 1 should have filled.
                # The three limit orders for sid 2 should still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.source)

    def test_schedule_function(self):
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

        # The scheduled function should have run exactly once per day seen.
        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        # Number of handle_data calls (bars) the run is expected to make.
        bar_count = 779
        # pre, handle_data, f, g and post are each called once per bar.
        calls_per_bar = 5

        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.source)

        self.assertEqual(len(expected_data), bar_count)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        expected_call_count = bar_count * calls_per_bar
        self.assertEqual(
            len(function_stack),
            expected_call_count,
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), expected_call_count),
        )
        expected_functions = [pre, handle_data, f, g, post] * bar_count
        # Distinct loop names so the callbacks f and g are not shadowed.
        for n, (actual, expected) in enumerate(
                zip(function_stack, expected_functions)):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_schedule_funtion_rule_creation(self, mode):
        def nop(*args, **kwargs):
            return None

        self.sim_params.data_frequency = mode
        algo = TradingAlgorithm(
            initialize=nop,
            handle_data=nop,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Schedule something for NOT Always.
        algo.schedule_function(nop, time_rule=zipline.utils.events.Never())

        event_rule = algo.event_manager._events[1].rule

        self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

        inner_rule = event_rule.rule
        self.assertIsInstance(inner_rule, zipline.utils.events.ComposedRule)

        first = inner_rule.first
        second = inner_rule.second
        composer = inner_rule.composer

        self.assertIsInstance(first, zipline.utils.events.Always)

        # In daily mode the time rule is expected to be replaced by Always;
        # in minute mode the Never() time rule is kept.
        if mode == 'daily':
            self.assertIsInstance(second, zipline.utils.events.Always)
        else:
            self.assertIsInstance(second, zipline.utils.events.Never)

        self.assertIs(composer, zipline.utils.events.ComposedRule.lazy_and)

    def test_asset_lookup(self):

        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.symbol(1)

        with self.assertRaises(TypeError):
            algo.symbol((1,))

        with self.assertRaises(TypeError):
            algo.symbol({1})

        with self.assertRaises(TypeError):
            algo.symbol([1])

        with self.assertRaises(TypeError):
            algo.symbol({'foo': 'bar'})

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('PLAY')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('FOOBAR')

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_symbol(1)

        with self.assertRaises(TypeError):
            algo.future_symbol((1,))

        with self.assertRaises(TypeError):
            algo.future_symbol({1})

        with self.assertRaises(TypeError):
            algo.future_symbol([1])

        with self.assertRaises(TypeError):
            algo.future_symbol({'foo': 'bar'})

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_chain(1)

        with self.assertRaises(TypeError):
            algo.future_chain((1,))

        with self.assertRaises(TypeError):
            algo.future_chain({1})

        with self.assertRaises(TypeError):
            algo.future_chain([1])

        with self.assertRaises(TypeError):
            algo.future_chain({'foo': 'bar'})

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Note we start sid enumeration at i+3 so as not to
        # collide with sids [1, 2] registered in setUpClass().
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        # Create two assets with the same symbol but different
        # non-overlapping date ranges.
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the end dates of both assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end date is used as the
        # as_of_date. Both DUP assets have ended by then, so only the symbol
        # of the returned asset is checked here.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
643
644
645
class TestTransformAlgorithm(TestCase):
    """Tests for running algorithms over different kinds of input.

    Exercises SpecificEquityTrades, DataFrame and Panel inputs (with and
    without Equity-keyed columns/items), multiple sources, repeated runs,
    data-frequency handling, and the order API entry points.
    """

    @classmethod
    def setUpClass(cls):
        # sid 3 is registered as a future with a contract multiplier of 10.
        futures_metadata = {3: {'contract_multiplier': 10}}
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1, 133],
                           futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)
        self.sim_params = factory.create_simulation_parameters(num_days=4,
                                                               env=self.env)

        # Four daily trades for sid 133: prices and volumes per day.
        trade_history = factory.create_trade_history(
            133,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=trade_history,
            env=self.env,
        )
        self.df_source, self.df = (
            factory.create_test_df_source(self.sim_params, self.env))

        self.panel_source, self.panel = (
            factory.create_test_panel_source(self.sim_params, self.env))

    def tearDown(self):
        teardown_logger(self)

    def test_source_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
            sids=[133]
        )
        algo.run(self.source)
        self.assertEqual(len(algo.sources), 1)
        self.assertIsInstance(algo.sources[0], SpecificEquityTrades)

    def test_invalid_order_parameters(self):
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

    def test_multi_source_as_input(self):
        sim_params = SimulationParameters(
            self.df.index[0],
            self.df.index[-1],
            env=self.env,
        )
        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            sids=[0, 1],
            env=self.env,
        )
        algo.run([self.source, self.df_source], overwrite_sim_params=False)
        self.assertEqual(len(algo.sources), 2)

    def test_df_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.df)
        self.assertIsInstance(algo.sources[0], DataFrameSource)

    def test_panel_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
            sids=[0, 1])
        panel = self.panel.copy()
        # Materialize the (possibly lazy) map before building the Index.
        panel.items = pd.Index(list(map(Equity, panel.items)))
        algo.run(panel)
        self.assertIsInstance(algo.sources[0], DataPanelSource)

    def test_df_of_assets_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=TradingEnvironment(),  # new env without assets
        )
        df = self.df.copy()
        # Materialize the (possibly lazy) map before building the Index.
        df.columns = pd.Index(list(map(Equity, df.columns)))
        algo.run(df)
        self.assertIsInstance(algo.sources[0], DataFrameSource)

    def test_panel_of_assets_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=TradingEnvironment(),  # new env without assets
            sids=[0, 1])
        algo.run(self.panel)
        self.assertIsInstance(algo.sources[0], DataPanelSource)

    def test_run_twice(self):
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.df)

        # Create a new trading algorithm, which will
        # use the newly instantiated environment.
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.df)

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        self.sim_params.data_frequency = 'daily'
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'daily')

        self.sim_params.data_frequency = 'minute'
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'minute')

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.df)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        # NOTE(review): this body is identical to test_order_methods and runs
        # the equity DataFrame (self.df); it never trades the future
        # registered as sid 3 in setUpClass. TODO: feed a futures source.
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.df)

    def test_order_method_style_forwarding(self):

        method_names_to_test = ['order',
                                'order_value',
                                'order_percent',
                                'order_target',
                                'order_target_percent',
                                'order_target_value']

        for name in method_names_to_test:
            # Don't supply an env so the TradingAlgorithm builds a new one for
            # each method
            algo = TestOrderStyleForwardingAlgorithm(
                sim_params=self.sim_params,
                instant_fill=False,
                method_name=name
            )
            algo.run(self.df)

    def test_order_instant(self):
        algo = TestOrderInstantAlgorithm(sim_params=self.sim_params,
                                         env=self.env,
                                         instant_fill=True)
        algo.run(self.df)

    def test_minute_data(self):
        source = RandomWalkSource(freq='minute',
                                  start=pd.Timestamp('2000-1-3',
                                                     tz='UTC'),
                                  end=pd.Timestamp('2000-1-4',
                                                   tz='UTC'))
        self.sim_params.data_frequency = 'minute'
        algo = TestOrderInstantAlgorithm(sim_params=self.sim_params,
                                         env=self.env,
                                         instant_fill=True)
        algo.run(source)
852
853
class TestPositions(TestCase):
    """Position bookkeeping checks over a four-day simulation."""

    def setUp(self):
        setup_logger(self)
        # A fresh environment per test, seeded with the equities used below.
        self.env = TradingEnvironment()
        self.sim_params = factory.create_simulation_parameters(num_days=4,
                                                               env=self.env)
        self.env.write_data(equities_identifiers=[1, 133])

        # Four daily trade events for sid 1 (used by test_noop_orders).
        trade_history = factory.create_trade_history(
            1,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=trade_history,
            env=self.env,
        )

        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

    def tearDown(self):
        teardown_logger(self)

    def test_empty_portfolio(self):
        """num_positions should rise to 1 and fall back to 0 after exiting."""
        algo = EmptyPositionsAlgorithm(sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.df)

        expected_position_count = [
            0,  # Before entering the first position
            1,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        # ``.ix`` is deprecated (removed in pandas 1.0) and its handling of
        # integer keys depends on the index type; ``.iloc`` is always
        # positional, which is what the enumerate() index means here.
        for i, expected in enumerate(expected_position_count):
            self.assertEqual(daily_stats.iloc[i]['num_positions'],
                             expected)

    def test_noop_orders(self):
        """Orders that never fill must leave every day's positions empty."""
        algo = AmbitiousStopLimitAlgorithm(sid=1,
                                           sim_params=self.sim_params,
                                           env=self.env)
        daily_stats = algo.run(self.source)

        # Verify that positions are empty for all dates.
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
        self.assertTrue(empty_positions.all())
907
908
909
class TestAlgoScript(TestCase):
    """Run algorithms given as functions or as script source text."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(
            equities_identifiers=[0, 1, 133]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        days = 251
        # Simulation parameters are built against the shared class-level
        # environment from setUpClass.
        self.sim_params = factory.create_simulation_parameters(num_days=days,
                                                               env=self.env)

        setup_logger(self)
        trade_history = factory.create_trade_history(
            133,
            [10.0] * days,
            [100] * days,
            timedelta(days=1),
            self.sim_params,
            self.env
        )

        self.source = SpecificEquityTrades(
            sids=[133],
            event_list=trade_history,
            env=self.env,
        )

        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

        # Base keyword arguments for simfactory.create_test_zipline; tests
        # add the algorithm, trade_count and any expectations on top.
        self.zipline_test_config = {
            'sid': 0,
        }

    def tearDown(self):
        teardown_logger(self)

    def _create_script_zipline(self, script, trade_count, **extra_config):
        """Build a test zipline around an algorithm given as script text.

        Constructs a TradingAlgorithm from ``script`` with this test's
        ``sim_params`` and ``env``, registers it via ``set_algo_instance``,
        stores it together with ``trade_count`` and any ``extra_config``
        entries in ``self.zipline_test_config``, and returns the result of
        ``simfactory.create_test_zipline`` called with that config.
        """
        test_algo = TradingAlgorithm(
            script=script,
            sim_params=self.sim_params,
            env=self.env,
        )
        set_algo_instance(test_algo)

        self.zipline_test_config['algorithm'] = test_algo
        self.zipline_test_config['trade_count'] = trade_count
        self.zipline_test_config.update(extra_config)

        return simfactory.create_test_zipline(**self.zipline_test_config)

    def test_noop(self):
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.df)

    def test_noop_string(self):
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.df)

    def test_api_calls(self):
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api)
        algo.run(self.df)

    def test_api_calls_string(self):
        algo = TradingAlgorithm(script=api_algo)
        algo.run(self.df)

    def test_api_get_environment(self):
        platform = 'zipline'
        # Use sid not already in test database.
        metadata = {3: {'symbol': 'TEST'}}
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                equities_metadata=metadata,
                                platform=platform)
        algo.run(self.df)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        # Use sid not already in test database.
        metadata = {3: {'symbol': 'TEST'}}
        algo = TradingAlgorithm(script=api_symbol_algo,
                                equities_metadata=metadata)
        algo.run(self.df)

    def test_fixed_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        # order_count matches context.count in the script's initialize and
        # is used inside assert_single_position to confirm we have as many
        # transactions as orders we placed.
        zipline = self._create_script_zipline(
            script="""
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1""",
            trade_count=200,
            order_count=1,
        )

        output, _ = assert_single_position(self, zipline)

        # confirm the slippage and commission on a sample
        # transaction
        recorded_price = output[1]['daily_perf']['recorded_vars']['price']
        transaction = output[1]['daily_perf']['transactions'][0]
        self.assertEqual(100.0, transaction['commission'])
        # Half of the 0.10 FixedSlippage spread set in the script.
        expected_spread = 0.05
        # PerTrade(100.00) divided over the 1000-share order.
        expected_commish = 0.10
        expected_price = recorded_price - expected_spread - expected_commish
        self.assertEqual(expected_price, transaction['price'])

    def test_volshare_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        # 67 will be used inside assert_single_position
        # to confirm we have as many transactions as expected.
        # The algo places 2 trades of 5000 shares each. The trade
        # events have volume ranging from 100 to 950. The volume cap
        # of 0.3 limits the trade volume to a range of 30 - 316 shares.
        # The spreadsheet linked below calculates the total position
        # size over each bar, and predicts 67 txns will be required
        # to fill the two orders. The number of bars and transactions
        # differ because some bars result in multiple txns. See
        # spreadsheet for details:
# https://www.dropbox.com/s/ulrk2qt0nrtrigb/Volume%20Share%20Worksheet.xlsx
        zipline = self._create_script_zipline(
            script="""
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    set_commission(commission.PerShare(0.02))
    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """,
            trade_count=100,
            expected_transactions=67,
        )
        output, _ = assert_single_position(self, zipline)

        # confirm the slippage and commission on a sample
        # transaction
        per_share_commish = 0.02
        perf = output[1]
        transaction = perf['daily_perf']['transactions'][0]
        commish = transaction['amount'] * per_share_commish
        self.assertEqual(commish, transaction['commission'])
        self.assertEqual(2.029, transaction['price'])

    def test_algo_record_vars(self):
        zipline = self._create_script_zipline(record_variables,
                                              trade_count=200)
        output, _ = drain_zipline(self, zipline)
        self.assertEqual(len(output), 252)
        incr = [o['daily_perf']['recorded_vars']['incr']
                for o in output[:200]]

        np.testing.assert_array_equal(incr, range(1, 201))

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def _algo_record_float_magic_should_pass(self, var_type):
        """Run ``record_float_magic % var_type`` and expect NaN recordings.

        The recorded ``data`` variable must be NaN on each of the first 200
        bars of the 252 produced.
        """
        zipline = self._create_script_zipline(record_float_magic % var_type,
                                              trade_count=200)
        output, _ = drain_zipline(self, zipline)
        self.assertEqual(len(output), 252)
        recorded = [o['daily_perf']['recorded_vars']['data']
                    for o in output[:200]]
        np.testing.assert_array_equal(recorded, [np.nan] * 200)

    def test_algo_record_nan(self):
        self._algo_record_float_magic_should_pass('nan')

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        zipline = self._create_script_zipline(call_all_order_methods,
                                              trade_count=200)
        drain_zipline(self, zipline)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            test_algo.run(self.source)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        zipline = self._create_script_zipline(access_portfolio_in_init,
                                              trade_count=1)
        drain_zipline(self, zipline)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        zipline = self._create_script_zipline(access_account_in_init,
                                              trade_count=1)
        drain_zipline(self, zipline)
1236
1237
1238
class TestHistory(TestCase):
    """Exercise the history / add_history API on minute-frequency data."""

    @classmethod
    def setUpClass(cls):
        cls._start = pd.Timestamp('1991-01-01', tz='UTC')
        cls._end = pd.Timestamp('1991-01-15', tz='UTC')
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=cls.env
        )
        cls.env.write_data(equities_identifiers=[0, 1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)

    def tearDown(self):
        teardown_logger(self)

    @property
    def source(self):
        # A new RandomWalkSource on every access, so each run gets its own.
        return RandomWalkSource(start=self._start, end=self._end)

    def _run_handle_data(self, handle_data):
        """Run an algorithm with a no-op initialize and ``handle_data``.

        Returns the TradingAlgorithm after it has run over ``self.source``.
        """
        algo = TradingAlgorithm(
            initialize=lambda _: None,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)
        return algo

    def test_history(self):
        history_algo = """
from zipline.api import history, add_history

def initialize(context):
    add_history(10, '1d', 'price')

def handle_data(context, data):
    df = history(10, '1d', 'price')
"""

        algo = TradingAlgorithm(
            script=history_algo,
            sim_params=self.sim_params,
            env=self.env,
        )
        output = algo.run(self.source)
        self.assertIsNotNone(output)

    def test_history_without_add(self):
        def handle_data(algo, data):
            algo.history(1, '1m', 'price')

        algo = self._run_handle_data(handle_data)

        # Calling history() without add_history() must still leave a
        # container sized to the requested window.
        self.assertIsNotNone(algo.history_container)
        self.assertEqual(algo.history_container.buffer_panel.window_length, 1)

    def test_add_history_in_handle_data(self):
        def handle_data(algo, data):
            algo.add_history(1, '1m', 'price')

        algo = self._run_handle_data(handle_data)

        # add_history() from handle_data (not initialize) must still
        # produce a container sized to the requested window.
        self.assertIsNotNone(algo.history_container)
        self.assertEqual(algo.history_container.buffer_panel.window_length, 1)
1313
1314
1315
class TestGetDatetime(TestCase):
    """get_datetime() inside a script should honour the requested zone."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)

    def tearDown(self):
        teardown_logger(self)

    @parameterized.expand([
        ('default', None),
        ('utc', 'UTC'),
        ('us_east', 'US/Eastern'),
    ])
    def test_get_datetime(self, name, tz):
        # ``name`` is unused in the body; it labels the generated case.
        # ``tz`` is spliced into the script via repr(), so None becomes the
        # literal None and the script then expects 'UTC'.
        script = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        source = RandomWalkSource(
            start=to_utc('2014-01-02 9:31'),
            end=to_utc('2014-01-03 9:31'),
        )
        sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=self.env,
        )
        algo = TradingAlgorithm(
            script=script,
            sim_params=sim_params,
            env=self.env,
        )
        algo.run(source)
        # The script clears first_bar at the end of its first handle_data
        # call, so this confirms the zone/hour/minute checks actually ran.
        self.assertFalse(algo.first_bar)
1380
1381
1382
class TestTradingControls(TestCase):
1383
1384
    @classmethod
    def setUpClass(cls):
        # One shared environment that contains only the sid these tests trade.
        traded_sid = 133
        cls.sid = traded_sid
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[traded_sid])
1389
1390
    @classmethod
    def tearDownClass(cls):
        # Drop the class-level environment created in setUpClass.
        delattr(cls, 'env')
1393
1394
    def setUp(self):
        """Build four-day sim params and a daily trade source for self.sid."""
        sim_params = factory.create_simulation_parameters(num_days=4,
                                                          env=self.env)
        self.sim_params = sim_params

        self.trade_history = factory.create_trade_history(
            self.sid,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=self.trade_history,
            env=self.env,
        )
1410
1411
    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` installed and check the outcome.

        When ``expected_exc`` is truthy the run must raise it; otherwise it
        must complete. Either way ``algo.order_count`` must then equal
        ``expected_order_count``. ``self.source`` is rewound afterwards so a
        later check in the same test can replay it.
        """
        algo._handle_data = handle_data
        if expected_exc:
            guard = self.assertRaises(expected_exc)
        else:
            guard = nullctx()
        with guard:
            algo.run(self.source)
        self.assertEqual(algo.order_count, expected_order_count)
        self.source.rewind()
1422
1423
    def check_algo_succeeds(self, algo, handle_data, order_count=4):
1424
        # Default for order_count assumes one order per handle_data call.
1425
        self._check_algo(algo, handle_data, order_count, None)
1426
1427
    def check_algo_fails(self, algo, handle_data, order_count):
        """Assert ``algo`` raises TradingControlViolation.

        ``algo.order_count`` must equal ``order_count`` once the run aborts.
        """
        violation = TradingControlViolation
        self._check_algo(algo, handle_data, order_count, violation)
1432
1433
    def test_set_max_position_size(self):
        """SetMaxPositionSizeAlgorithm enforces share and notional caps."""

        def order_each_bar(amount):
            # handle_data that orders ``amount`` shares of self.sid on every
            # call; order_count is only bumped once the order call returns.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amount)
                algo.order_count += 1
            return handle_data

        # Buy one share four times.  Should be fine.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, order_each_bar(1))

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, order_each_bar(3), 3)

        # Buy three shares four times.  Should bail due to max_notional on
        # the third attempt.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, order_each_bar(3), 2)

        # Set the trading control to a different sid, then BUY ALL THE
        # THINGS!  Should continue normally.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, order_each_bar(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!
        # Should fail because setting sid to None makes the control apply to
        # all sids.
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, order_each_bar(10000), 0)
1493
1494
    def test_set_do_not_order_list(self):
        """A restricted list containing self.sid must block ordering it."""

        def buy_hundred(algo, data):
            algo.order(algo.sid(self.sid), 100)
            algo.order_count += 1

        # self.sid is restricted, so the very first order fails.
        restricted = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_fails(restricted, buy_hundred, 0)

        # Only other sids are restricted, so every order goes through.
        unrestricted = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_succeeds(unrestricted, buy_hundred)
1522
1523
    def test_set_max_order_size(self):
        """SetMaxOrderSizeAlgorithm caps the size of each individual order."""

        def order_fixed(amount):
            # handle_data ordering the same ``amount`` of self.sid each call.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amount)
                algo.order_count += 1
            return handle_data

        def order_increasing(algo, data):
            # Orders 1, 2, 3, 4, ... shares on successive calls.
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        # Buy one share.
        single_share = order_fixed(1)
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, single_share)

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed shares.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=3,
                                        max_notional=500.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, order_increasing, 3)

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed notional.
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
                                        max_shares=10,
                                        max_notional=40.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, order_increasing, 3)

        # Control on a different sid: a huge order of self.sid is untouched.
        huge_order = order_fixed(10000)
        algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
                                        max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_succeeds(algo, huge_order)

        # No sid given: the control applies to every sid, so the first huge
        # order fails.
        algo = SetMaxOrderSizeAlgorithm(max_shares=1,
                                        max_notional=1.0,
                                        sim_params=self.sim_params,
                                        env=self.env)
        self.check_algo_fails(algo, huge_order, 0)
1585
1586
    def test_set_max_order_count(self):
        """SetMaxOrderCountAlgorithm caps the number of orders per day."""

        # Override the default setUp to use six-hour intervals instead of full
        # days so we can exercise trading-session rollover logic.
        trade_history = factory.create_trade_history(
            self.sid,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(hours=6),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(event_list=trade_history,
                                           env=self.env)

        def handle_data(algo, data):
            # Five single-share orders per call.
            for _ in range(5):
                algo.order(algo.sid(self.sid), 1)
                algo.order_count += 1

        # The fourth order of the very first call exceeds a limit of 3.
        algo = SetMaxOrderCountAlgorithm(3, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Second call to handle_data is the same day as the first, so the last
        # order of the second call should fail.
        algo = SetMaxOrderCountAlgorithm(9, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_fails(algo, handle_data, 9)

        # Only ten orders are placed per day, so a limit of 10 is never
        # exceeded even though 20 orders (4 calls x 5) are placed in total.
        algo = SetMaxOrderCountAlgorithm(10, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_succeeds(algo, handle_data, order_count=20)
1621
1622
    def test_long_only(self):
        """SetLongOnlyAlgorithm must reject any order that would go short."""

        def sell_one(algo, data):
            algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1

        def alternate(algo, data):
            # Buy on even-numbered calls, sell on odd ones: never net short.
            amount = 1 if algo.order_count % 2 == 0 else -1
            algo.order(algo.sid(self.sid), amount)
            algo.order_count += 1

        def scripted(amounts):
            # handle_data ordering amounts[k] shares on the k-th call.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amounts[algo.order_count])
                algo.order_count += 1
            return handle_data

        # Sell immediately -> fail immediately.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, sell_one, 0)

        # Alternating buy/sell never takes a short position, so succeed.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, alternate)

        # Buy on first three days, then sell off holdings.  Should succeed.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_succeeds(algo, scripted([1, 1, 1, -3]))

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env)
        self.check_algo_fails(algo, scripted([1, 1, 1, -4]), 3)
1657
1658
    def test_register_post_init(self):
        """Registering trading controls outside initialize must raise."""
        def initialize(algo):
            algo.initialized = True

        def handle_data(algo, data):
            # Each registration below is attempted from handle_data, i.e.
            # after initialization, and must be refused.
            late_registrations = (
                (algo.set_max_position_size, (self.sid, 1, 1)),
                (algo.set_max_order_size, (self.sid, 1, 1)),
                (algo.set_max_order_count, (1,)),
                (algo.set_long_only, ()),
            )
            for register, args in late_registrations:
                with self.assertRaises(RegisterTradingControlPostInit):
                    register(*args)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)
        self.source.rewind()
1680
1681
    def test_asset_date_bounds(self):
        """Trading an asset outside its start/end dates must be rejected.

        Each scenario gives sid 0 a different lifetime and runs
        ``SetAssetDateBoundsAlgorithm`` over the same ``self.sim_params``.
        """
        def make_algo_and_source(start_date, end_date):
            # Build a fresh environment per scenario (presumably so one
            # scenario's asset metadata cannot leak into the next).
            temp_env = TradingEnvironment()
            df_source, _ = factory.create_test_df_source(self.sim_params,
                                                         temp_env)
            metadata = {0: {'start_date': start_date,
                            'end_date': end_date}}
            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )
            return algo, df_source

        # Run the algorithm with a sid that ends far in the future.
        algo, df_source = make_algo_and_source('1990-01-01', '2020-01-01')
        algo.run(df_source)

        # Run the algorithm with a sid that has already ended.
        algo, df_source = make_algo_and_source('1989-01-01', '1990-01-01')
        with self.assertRaises(TradingControlViolation):
            algo.run(df_source)

        # Run the algorithm with a sid that has not started.
        algo, df_source = make_algo_and_source('2020-01-01', '2021-01-01')
        with self.assertRaises(TradingControlViolation):
            algo.run(df_source)

        # Run the algorithm with a sid that starts on the first day and
        # ends on the last day of the algorithm's parameters (*not* an error).
        algo, df_source = make_algo_and_source('2006-01-03', '2006-01-06')
        algo.run(df_source)
1733
1734
1735
class TestAccountControls(TestCase):
    """Tests for account-level controls (maximum leverage)."""

    @classmethod
    def setUpClass(cls):
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.env.write_data(
            equities_identifiers=[cls.sidint]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        # A four-day simulation with one trade event per day for ``sidint``.
        self.sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env
        )
        self.trade_history = factory.create_trade_history(
            self.sidint,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env,
        )

        self.source = SpecificEquityTrades(
            event_list=self.trade_history,
            env=self.env,
        )

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` over ``self.source``.

        If ``expected_exc`` is truthy the run must raise it; otherwise the
        run must complete without raising. The source is rewound afterwards
        (also when the check itself fails) so the next run starts from the
        beginning of the event list.
        """
        algo._handle_data = handle_data
        expectation = (self.assertRaises(expected_exc) if expected_exc
                       else nullctx())
        try:
            with expectation:
                algo.run(self.source)
        finally:
            self.source.rewind()

    def check_algo_succeeds(self, algo, handle_data):
        """Run ``algo`` and require that it completes without raising."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Run ``algo`` and require an ``AccountControlViolation``."""
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):
        def handle_data(algo, data):
            # Every call buys one share of ``sidint``.
            algo.order(algo.sid(self.sidint), 1)

        # Set max leverage to 0 so buying one share fails.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Set max leverage to 1 so buying one share passes.
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)
1803
1804
1805
class TestClosePosAlgo(TestCase):
    """Check that a one-share position in sid 1 is back to zero on the last
    bar, via a ``CLOSE_POSITION`` event or a future's ``auto_close_date``."""

    def setUp(self):
        self.env = TradingEnvironment()
        self.days = self.env.trading_days[:4]
        # Three trades followed by a zero-volume CLOSE_POSITION event.
        self.panel = pd.Panel({1: pd.DataFrame({
            'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 0],
            'type': [DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.CLOSE_POSITION]},
            index=self.days)
        })
        # Same prices, but every event is an ordinary trade.
        self.no_close_panel = pd.Panel({1: pd.DataFrame({
            'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 1e9],
            'type': [DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE]},
            index=self.days)
        })

    def test_close_position_equity(self):
        metadata = {1: {'symbol': 'TEST',
                        'end_date': self.days[3]}}
        self.env.write_data(equities_data=metadata)
        self._run_and_check(self.panel,
                            expected_positions=[0, 1, 1, 0],
                            expected_pnl=[0, 0, 1, 2])

    def test_close_position_future(self):
        metadata = {1: {'symbol': 'TEST'}}
        self.env.write_data(futures_data=metadata)
        self._run_and_check(self.panel,
                            expected_positions=[0, 1, 1, 0],
                            expected_pnl=[0, 0, 1, 2])

    def test_auto_close_future(self):
        # The auto-close date is the trading day after the panel's last bar.
        metadata = {1: {'symbol': 'TEST',
                        'auto_close_date': self.env.trading_days[4]}}
        self.env.write_data(futures_data=metadata)
        self._run_and_check(self.no_close_panel,
                            expected_positions=[0, 1, 1, 0],
                            expected_pnl=[0, 0, 1, 2])

    def _run_and_check(self, panel, expected_positions, expected_pnl):
        """Run a one-share, zero-commission ``TestAlgorithm`` on sid 1 over
        ``panel`` and check its daily positions and pnl."""
        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env)
        results = algo.run(DataPanelSource(panel))
        self.check_algo_positions(results, expected_positions)
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert ``results.pnl`` is almost equal to ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert each day's held amount matches ``expected_positions``.

        A day with no positions counts as 0; otherwise the ``'amount'`` of
        the first position entry is compared.
        """
        for i, day_positions in enumerate(results.positions):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected_positions[i],
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected_positions[i]))
1890
1891
1892
class TestFutureFlip(TestCase):
    """A future position that goes long, flips short, then goes flat."""

    def setUp(self):
        self.env = TradingEnvironment()
        self.days = self.env.trading_days[:4]
        # Three ordinary trades on the first three of those days.
        trade = DATASOURCE_TYPE.TRADE
        frame = pd.DataFrame(
            {'price': [1, 2, 4],
             'volume': [1e9] * 3,
             'type': [trade] * 3},
            index=self.days[:3],
        )
        self.trades_panel = pd.Panel({1: frame})

    def test_flip_algo(self):
        self.env.write_data(futures_data={
            1: {'symbol': 'TEST',
                'end_date': self.days[3],
                'contract_multiplier': 5},
        })

        algo = FutureFlipAlgo(
            sid=1,
            amount=1,
            env=self.env,
            commission=PerShare(0),
            order_count=0,  # not applicable but required
            instant_fill=True,
        )
        results = algo.run(DataPanelSource(self.trades_panel))

        # Expected pnl is expressed with the contract_multiplier of 5 applied.
        self.check_algo_positions(results, [1, -1, 0])
        self.check_algo_pnl(results, [0, 5, -10])

    def check_algo_pnl(self, results, expected_pnl):
        """Assert ``results.pnl`` is almost equal to ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert each day's held amount matches ``expected_positions``.

        Days with no positions count as 0; otherwise the first entry's
        ``'amount'`` is used.
        """
        template = "position for day={0} not equal, actual={1}, expected={2}"
        for day, day_positions in enumerate(results.positions):
            actual = day_positions[0]['amount'] if day_positions else 0
            expected = expected_positions[day]
            self.assertEqual(actual, expected,
                             template.format(day, actual, expected))
1938
1939
1940
class TestTradingAlgorithm(TestCase):
    """Tests for ``TradingAlgorithm`` user-callback hooks."""

    def setUp(self):
        self.env = TradingEnvironment()
        self.days = self.env.trading_days[:4]
        trade = DATASOURCE_TYPE.TRADE
        frame = pd.DataFrame(
            {'price': [1, 1, 2, 4],
             'volume': [1e9, 1e9, 1e9, 0],
             'type': [trade, trade, trade,
                      DATASOURCE_TYPE.CLOSE_POSITION]},
            index=self.days,
        )
        self.panel = pd.Panel({1: frame})

    def test_analyze_called(self):
        """``analyze`` receives the same object that ``run`` returns."""
        self.perf_ref = None

        def analyze(context, perf):
            # Capture what analyze was handed for comparison with run().
            self.perf_ref = perf

        def handle_data(context, data):
            pass

        def initialize(context):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                analyze=analyze)
        self.assertIs(algo.run(self.panel), self.perf_ref)
1969
1970
1971
class TestRemoveData(TestCase):
    """
    Tests that futures data is removed after expiry.
    """
    def setUp(self):
        dt = pd.Timestamp('2015-01-02', tz='UTC')
        env = TradingEnvironment()
        ix = env.trading_days.get_loc(dt)

        # X expires before Y; both end one trading day after expiring.
        metadata = {0: {'symbol': 'X',
                        'expiration_date': env.trading_days[ix+5],
                        'end_date': env.trading_days[ix+6]},
                    1: {'symbol': 'Y',
                        'expiration_date': env.trading_days[ix+7],
                        'end_date': env.trading_days[ix+8]}}

        env.write_data(futures_data=metadata)

        # X has five bars starting at ``dt``; Y has five bars on X's index
        # shifted two periods later, so the two series partially overlap.
        index_x = env.trading_days[ix:ix+5]
        data_x = pd.DataFrame([[1, 100], [2, 100], [3, 100], [4, 100],
                               [5, 100]],
                              index=index_x, columns=['price', 'volume'])
        index_y = env.trading_days[ix:ix+5].shift(2)
        data_y = pd.DataFrame([[6, 100], [7, 100], [8, 100], [9, 100],
                               [10, 100]],
                              index=index_y, columns=['price', 'volume'])

        pan = pd.Panel({0: data_x, 1: data_y})
        self.source = DataPanelSource(pan)
        self.algo = TestRemoveDataAlgo(env=env)

    def test_remove_data(self):
        self.algo.run(self.source)

        expected_length = [1, 1, 2, 2, 2, 1, 1]
        # Initially only data for X should be sent, and on the final bars
        # only data for Y should be sent since X is expired.
        # Check the number of recorded bars first: otherwise a run that
        # records fewer bars would pass the per-bar loop vacuously.
        self.assertEqual(len(self.algo.data), len(expected_length))
        for i, length in enumerate(self.algo.data):
            self.assertEqual(expected_length[i], length, i)
2010