Completed
Push — master ( 5c3ca1...d3d362 )
by Joe
01:27
created

tests.TestTradingControls.check_algo_succeeds()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 3
rs 10
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range, map
20
from textwrap import dedent
21
from unittest import TestCase
22
23
import numpy as np
24
import pandas as pd
25
26
from zipline.assets import Equity, Future
27
from zipline.utils.api_support import ZiplineAPI
28
from zipline.utils.control_flow import nullctx
29
from zipline.utils.test_utils import (
30
    setup_logger,
31
    teardown_logger
32
)
33
import zipline.utils.factory as factory
34
import zipline.utils.simfactory as simfactory
35
36
from zipline.errors import (
37
    OrderDuringInitialize,
38
    RegisterTradingControlPostInit,
39
    TradingControlViolation,
40
    AccountControlViolation,
41
    SymbolNotFound,
42
    RootSymbolNotFound,
43
    UnsupportedDatetimeFormat,
44
)
45
from zipline.test_algorithms import (
46
    access_account_in_init,
47
    access_portfolio_in_init,
48
    AmbitiousStopLimitAlgorithm,
49
    EmptyPositionsAlgorithm,
50
    InvalidOrderAlgorithm,
51
    RecordAlgorithm,
52
    FutureFlipAlgo,
53
    TestAlgorithm,
54
    TestOrderAlgorithm,
55
    TestOrderInstantAlgorithm,
56
    TestOrderPercentAlgorithm,
57
    TestOrderStyleForwardingAlgorithm,
58
    TestOrderValueAlgorithm,
59
    TestRegisterTransformAlgorithm,
60
    TestTargetAlgorithm,
61
    TestTargetPercentAlgorithm,
62
    TestTargetValueAlgorithm,
63
    SetLongOnlyAlgorithm,
64
    SetAssetDateBoundsAlgorithm,
65
    SetMaxPositionSizeAlgorithm,
66
    SetMaxOrderCountAlgorithm,
67
    SetMaxOrderSizeAlgorithm,
68
    SetDoNotOrderListAlgorithm,
69
    SetMaxLeverageAlgorithm,
70
    api_algo,
71
    api_get_environment_algo,
72
    api_symbol_algo,
73
    call_all_order_methods,
74
    call_order_in_init,
75
    handle_data_api,
76
    handle_data_noop,
77
    initialize_api,
78
    initialize_noop,
79
    noop_algo,
80
    record_float_magic,
81
    record_variables,
82
)
83
from zipline.utils.context_tricks import CallbackManager
84
import zipline.utils.events
85
from zipline.utils.test_utils import (
86
    assert_single_position,
87
    drain_zipline,
88
    to_utc,
89
)
90
91
from zipline.sources import (SpecificEquityTrades,
92
                             DataFrameSource,
93
                             DataPanelSource,
94
                             RandomWalkSource)
95
96
from zipline.finance.execution import LimitOrder
97
from zipline.finance.trading import SimulationParameters
98
from zipline.utils.api_support import set_algo_instance
99
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
100
from zipline.algorithm import TradingAlgorithm
101
from zipline.protocol import DATASOURCE_TYPE
102
from zipline.finance.trading import TradingEnvironment
103
from zipline.finance.commission import PerShare
104
105
# Because test cases appear to reuse some resources.
106
_multiprocess_can_split_ = False
107
108
109
class TestRecordAlgorithm(TestCase):
    """Checks the per-day values recorded by RecordAlgorithm."""

    @classmethod
    def setUpClass(cls):
        # One shared environment holding the single equity (sid 133)
        # that the trade history below is built for.
        cls.env = env = TradingEnvironment()
        env.write_data(equities_identifiers=[133])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        env = self.env
        self.sim_params = factory.create_simulation_parameters(
            num_days=4, env=env)

        prices = [10.0, 10.0, 11.0, 11.0]
        volumes = [100, 100, 100, 300]
        trade_history = factory.create_trade_history(
            133, prices, volumes, timedelta(days=1), self.sim_params, env)

        self.source = SpecificEquityTrades(event_list=trade_history, env=env)
        self.df_source, self.df = factory.create_test_df_source(
            self.sim_params, env)

    def test_record_incr(self):
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        output = algo.run(self.source)

        num_rows = len(output)
        counting = range(1, num_rows + 1)
        # Column -> expected values, checked in this order.
        expectations = [
            ('incr', counting),
            ('name', counting),
            ('name2', [2] * num_rows),
            ('name3', counting),
        ]
        for column, expected in expectations:
            np.testing.assert_array_equal(output[column].values, expected)
149
150
151
class TestMiscellaneousAPI(TestCase):
    """Tests for assorted TradingAlgorithm API methods.

    Covers dynamic API resolution, get_environment, get_open_orders,
    schedule_function, event contexts, and symbol/sid/future lookups against
    a TradingEnvironment built once in setUpClass.
    """

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        # Two equities share the symbol 'PLAY' over non-overlapping date
        # ranges, so symbol lookups depend on the as-of date.
        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        # CL futures contracts used by the future_symbol/future_chain tests.
        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)
        self.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='minute',
            env=self.env,
        )
        self.source = factory.create_minutely_trade_source(
            self.sids,
            sim_params=self.sim_params,
            concurrent=True,
            env=self.env,
        )

    def tearDown(self):
        teardown_logger(self)

    def test_zipline_api_resolves_dynamically(self):
        # Make a dummy algo.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
        )

        # Verify that api methods get resolved dynamically by patching them out
        # and then calling them
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            def fake_method(*args, **kwargs):
                return sentinel
            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_get_environment(self):
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.source)

    def test_get_open_orders(self):

        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The market order for sid 1 should have filled; the three
                # limit orders for sid 2 should still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                # Cancelling every remaining order leaves nothing open.
                for order in orders_2:
                    algo.cancel_order(order)

                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.source)

    def test_schedule_function(self):
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            # Count distinct simulated days so the scheduled function can be
            # checked to have fired exactly once per day.
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        # The test expects the minutely source to produce this many bars;
        # each bar should run pre, handle_data, f, g, post in that order.
        num_bars = 779
        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.source)

        self.assertEqual(len(expected_data), num_bars)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        expected_functions = [pre, handle_data, f, g, post] * num_bars
        self.assertEqual(
            len(function_stack),
            len(expected_functions),
            'Incorrect number of functions called: %d != %d' %
            (len(function_stack), len(expected_functions)),
        )
        # Use distinct names so the event functions f and g are not
        # rebound by the loop variables.
        for n, (actual, expected) in enumerate(
                zip(function_stack, expected_functions)):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_schedule_funtion_rule_creation(self, mode):
        def nop(*args, **kwargs):
            return None

        self.sim_params.data_frequency = mode
        algo = TradingAlgorithm(
            initialize=nop,
            handle_data=nop,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Schedule something for NOT Always.
        algo.schedule_function(nop, time_rule=zipline.utils.events.Never())

        event_rule = algo.event_manager._events[1].rule

        self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

        inner_rule = event_rule.rule
        self.assertIsInstance(inner_rule, zipline.utils.events.ComposedRule)

        first = inner_rule.first
        second = inner_rule.second
        composer = inner_rule.composer

        self.assertIsInstance(first, zipline.utils.events.Always)

        # In daily mode the time rule is expected to collapse to Always; in
        # minute mode the requested Never rule is kept.
        if mode == 'daily':
            self.assertIsInstance(second, zipline.utils.events.Always)
        else:
            self.assertIsInstance(second, zipline.utils.events.Never)

        self.assertIs(composer, zipline.utils.events.ComposedRule.lazy_and)

    def test_asset_lookup(self):

        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.symbol(1)

        with self.assertRaises(TypeError):
            algo.symbol((1,))

        with self.assertRaises(TypeError):
            algo.symbol({1})

        with self.assertRaises(TypeError):
            algo.symbol([1])

        with self.assertRaises(TypeError):
            algo.symbol({'foo': 'bar'})

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('PLAY')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('FOOBAR')

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_symbol(1)

        with self.assertRaises(TypeError):
            algo.future_symbol((1,))

        with self.assertRaises(TypeError):
            algo.future_symbol({1})

        with self.assertRaises(TypeError):
            algo.future_symbol([1])

        with self.assertRaises(TypeError):
            algo.future_symbol({'foo': 'bar'})

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # A date string is accepted and parsed to the same UTC timestamp.
        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_chain(1)

        with self.assertRaises(TypeError):
            algo.future_chain((1,))

        with self.assertRaises(TypeError):
            algo.future_chain({1})

        with self.assertRaises(TypeError):
            algo.future_chain([1])

        with self.assertRaises(TypeError):
            algo.future_chain({'foo': 'bar'})

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Sids are numbered i + 3, clear of the class-level sids [1, 2].
        # They are written to a fresh TradingEnvironment created below.
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        # Create two assets with the same symbol but different
        # non-overlapping date ranges.
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the end dates of both assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end date is used as the
        # as_of_date. Only the symbol is checked here, not which of the two
        # sids is returned.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
            algo.set_symbol_lookup_date('foobar')
642
643
644
class TestTransformAlgorithm(TestCase):
    """Runs the test algorithms against the supported data-source kinds.

    Exercises SpecificEquityTrades, DataFrame and Panel inputs, the order
    helper methods for both equities and futures, and instant fills.
    """

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1, 133])

        # A second environment in which sid 0 is a Future (contract
        # multiplier 10) rather than an Equity.
        futures_metadata = {0: {'contract_multiplier': 10}}
        cls.futures_env = TradingEnvironment()
        cls.futures_env.write_data(futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        # Release both environments built in setUpClass.
        del cls.env
        del cls.futures_env

    def setUp(self):
        setup_logger(self)
        self.sim_params = factory.create_simulation_parameters(num_days=4,
                                                               env=self.env)

        trade_history = factory.create_trade_history(
            133,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=trade_history,
            env=self.env,
        )
        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

        self.panel_source, self.panel = \
            factory.create_test_panel_source(self.sim_params, self.env)

    def tearDown(self):
        teardown_logger(self)

    def test_source_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
            sids=[133]
        )
        algo.run(self.source)
        self.assertEqual(len(algo.sources), 1)
        self.assertIsInstance(algo.sources[0], SpecificEquityTrades)

    def test_invalid_order_parameters(self):
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

    def test_multi_source_as_input(self):
        sim_params = SimulationParameters(
            self.df.index[0],
            self.df.index[-1],
            env=self.env,
        )
        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            sids=[0, 1],
            env=self.env,
        )
        algo.run([self.source, self.df_source], overwrite_sim_params=False)
        self.assertEqual(len(algo.sources), 2)

    def test_df_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.df)
        self.assertIsInstance(algo.sources[0], DataFrameSource)

    def test_panel_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
            sids=[0, 1])
        panel = self.panel.copy()
        # Materialize the Equity objects eagerly rather than passing
        # pandas a lazy iterator.
        panel.items = pd.Index([Equity(item) for item in panel.items])
        algo.run(panel)
        self.assertIsInstance(algo.sources[0], DataPanelSource)

    def test_df_of_assets_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=TradingEnvironment(),  # new env without assets
        )
        df = self.df.copy()
        df.columns = pd.Index([Equity(col) for col in df.columns])
        algo.run(df)
        self.assertIsInstance(algo.sources[0], DataFrameSource)

    def test_panel_of_assets_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=TradingEnvironment(),  # new env without assets
            sids=[0, 1])
        algo.run(self.panel)
        self.assertIsInstance(algo.sources[0], DataPanelSource)

    def test_run_twice(self):
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.df)

        # Create a new trading algorithm, which will
        # use the newly instantiated environment.
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.df)

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        self.sim_params.data_frequency = 'daily'
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'daily')

        self.sim_params.data_frequency = 'minute'
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'minute')

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )

        # Ensure that the environment's asset 0 is an Equity
        asset_to_test = algo.sid(0)
        self.assertIsInstance(asset_to_test, Equity)

        algo.run(self.df)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.futures_env,
        )

        # Ensure that the environment's asset 0 is a Future
        asset_to_test = algo.sid(0)
        self.assertIsInstance(asset_to_test, Future)

        algo.run(self.df)

    def test_order_method_style_forwarding(self):

        method_names_to_test = ['order',
                                'order_value',
                                'order_percent',
                                'order_target',
                                'order_target_percent',
                                'order_target_value']

        for name in method_names_to_test:
            # Don't supply an env so the TradingAlgorithm builds a new one for
            # each method
            algo = TestOrderStyleForwardingAlgorithm(
                sim_params=self.sim_params,
                instant_fill=False,
                method_name=name
            )
            algo.run(self.df)

    def test_order_instant(self):
        algo = TestOrderInstantAlgorithm(sim_params=self.sim_params,
                                         env=self.env,
                                         instant_fill=True)
        algo.run(self.df)

    def test_minute_data(self):
        source = RandomWalkSource(freq='minute',
                                  start=pd.Timestamp('2000-1-3',
                                                     tz='UTC'),
                                  end=pd.Timestamp('2000-1-4',
                                                   tz='UTC'))
        self.sim_params.data_frequency = 'minute'
        algo = TestOrderInstantAlgorithm(sim_params=self.sim_params,
                                         env=self.env,
                                         instant_fill=True)
        algo.run(source)
862
863
864
class TestPositions(TestCase):
    """Checks on the positions reported in daily performance output."""

    def setUp(self):
        setup_logger(self)
        self.env = TradingEnvironment()
        self.sim_params = factory.create_simulation_parameters(num_days=4,
                                                               env=self.env)
        self.env.write_data(equities_identifiers=[1, 133])

        # Four daily trade events for sid 1.
        trade_history = factory.create_trade_history(
            1,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=trade_history,
            env=self.env,
        )

        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

    def tearDown(self):
        teardown_logger(self)

    def test_empty_portfolio(self):
        """num_positions goes 0 -> 1 -> 0 as a position is opened and closed.
        """
        algo = EmptyPositionsAlgorithm(sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.df)

        expected_position_count = [
            0,  # Before entering the first position
            1,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        for i, expected in enumerate(expected_position_count):
            self.assertEqual(daily_stats.ix[i]['num_positions'],
                             expected)

    def test_noop_orders(self):
        """AmbitiousStopLimitAlgorithm must end every day with no positions.
        """
        algo = AmbitiousStopLimitAlgorithm(sid=1,
                                           sim_params=self.sim_params,
                                           env=self.env)
        daily_stats = algo.run(self.source)

        # Verify that positions are empty for all dates.
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
        self.assertTrue(empty_positions.all())
918
919
920
class TestAlgoScript(TestCase):
    """Run algorithms given as scripts or as initialize/handle_data callables.

    Many tests push a TradingAlgorithm through ``simfactory``'s test zipline
    using ``self.zipline_test_config``, which ``setUp`` rebuilds per test.
    """

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(
            equities_identifiers=[0, 1, 133]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        days = 251
        # The simulation parameters are built against the shared
        # class-level environment.
        self.sim_params = factory.create_simulation_parameters(num_days=days,
                                                               env=self.env)

        setup_logger(self)
        trade_history = factory.create_trade_history(
            133,
            [10.0] * days,
            [100] * days,
            timedelta(days=1),
            self.sim_params,
            self.env
        )

        self.source = SpecificEquityTrades(
            sids=[133],
            event_list=trade_history,
            env=self.env,
        )

        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

        # Base config for simfactory.create_test_zipline; individual tests
        # add keys such as 'algorithm' and 'trade_count' before using it.
        self.zipline_test_config = {
            'sid': 0,
        }

    def tearDown(self):
        teardown_logger(self)

    def _drain_script(self, script, trade_count):
        """Run ``script`` through a test zipline and drain its output.

        Builds a TradingAlgorithm from ``script`` on the shared sim params
        and environment, registers it with ``set_algo_instance``, stores it
        and ``trade_count`` in ``self.zipline_test_config``, and returns
        whatever ``drain_zipline`` returns.
        """
        test_algo = TradingAlgorithm(
            script=script,
            sim_params=self.sim_params,
            env=self.env,
        )
        set_algo_instance(test_algo)

        self.zipline_test_config['algorithm'] = test_algo
        self.zipline_test_config['trade_count'] = trade_count

        zipline = simfactory.create_test_zipline(
            **self.zipline_test_config)
        return drain_zipline(self, zipline)

    def test_noop(self):
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.df)

    def test_noop_string(self):
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.df)

    def test_api_calls(self):
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api)
        algo.run(self.df)

    def test_api_calls_string(self):
        algo = TradingAlgorithm(script=api_algo)
        algo.run(self.df)

    def test_api_get_environment(self):
        platform = 'zipline'
        # Use sid not already in test database.
        metadata = {3: {'symbol': 'TEST'}}
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                equities_metadata=metadata,
                                platform=platform)
        algo.run(self.df)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        # Use sid not already in test database.
        metadata = {3: {'symbol': 'TEST'}}
        algo = TradingAlgorithm(script=api_symbol_algo,
                                equities_metadata=metadata)
        algo.run(self.df)

    def test_fixed_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1""",
            sim_params=self.sim_params,
            env=self.env,
        )
        set_algo_instance(test_algo)

        self.zipline_test_config['algorithm'] = test_algo
        self.zipline_test_config['trade_count'] = 200

        # this matches the value in the algotext initialize
        # method, and will be used inside assert_single_position
        # to confirm we have as many transactions as orders we
        # placed.
        self.zipline_test_config['order_count'] = 1

        zipline = simfactory.create_test_zipline(
            **self.zipline_test_config)

        output, _ = assert_single_position(self, zipline)

        # confirm the slippage and commission on a sample
        # transaction
        recorded_price = output[1]['daily_perf']['recorded_vars']['price']
        transaction = output[1]['daily_perf']['transactions'][0]
        self.assertEqual(100.0, transaction['commission'])
        # NOTE(review): presumably half of the script's 0.10 FixedSlippage
        # spread.
        expected_spread = 0.05
        # NOTE(review): presumably the 100.00 per-trade commission spread
        # over the 1000 shares sold.
        expected_commish = 0.10
        expected_price = recorded_price - expected_spread - expected_commish
        self.assertEqual(expected_price, transaction['price'])

    def test_volshare_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    set_commission(commission.PerShare(0.02))
    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """,
            sim_params=self.sim_params,
            env=self.env,
        )
        set_algo_instance(test_algo)

        self.zipline_test_config['algorithm'] = test_algo
        self.zipline_test_config['trade_count'] = 100

        # 67 will be used inside assert_single_position
        # to confirm we have as many transactions as expected.
        # The algo places 2 trades of 5000 shares each. The trade
        # events have volume ranging from 100 to 950. The volume cap
        # of 0.3 limits the trade volume to a range of 30 - 316 shares.
        # The spreadsheet linked below calculates the total position
        # size over each bar, and predicts 67 txns will be required
        # to fill the two orders. The number of bars and transactions
        # differ because some bars result in multiple txns. See
        # spreadsheet for details:
# https://www.dropbox.com/s/ulrk2qt0nrtrigb/Volume%20Share%20Worksheet.xlsx
        self.zipline_test_config['expected_transactions'] = 67

        zipline = simfactory.create_test_zipline(
            **self.zipline_test_config)
        output, _ = assert_single_position(self, zipline)

        # confirm the slippage and commission on a sample
        # transaction
        per_share_commish = 0.02
        perf = output[1]
        transaction = perf['daily_perf']['transactions'][0]
        commish = transaction['amount'] * per_share_commish
        self.assertEqual(commish, transaction['commission'])
        self.assertEqual(2.029, transaction['price'])

    def test_algo_record_vars(self):
        output, _ = self._drain_script(record_variables, trade_count=200)
        self.assertEqual(len(output), 252)
        incr = [o['daily_perf']['recorded_vars']['incr']
                for o in output[:200]]
        np.testing.assert_array_equal(incr, range(1, 201))

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        # Pass the shared env, matching the environment self.sim_params was
        # built with, instead of letting the algorithm create its own.
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
            env=self.env,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def _algo_record_float_magic_should_pass(self, var_type):
        output, _ = self._drain_script(record_float_magic % var_type,
                                       trade_count=200)
        self.assertEqual(len(output), 252)
        recorded = [o['daily_perf']['recorded_vars']['data']
                    for o in output[:200]]
        np.testing.assert_array_equal(recorded, [np.nan] * 200)

    def test_algo_record_nan(self):
        self._algo_record_float_magic_should_pass('nan')

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        self._drain_script(call_all_order_methods, trade_count=200)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            test_algo.run(self.source)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        self._drain_script(access_portfolio_in_init, trade_count=1)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        self._drain_script(access_account_in_init, trade_count=1)
1247
1248
1249
class TestHistory(TestCase):
    """Exercise history()/add_history() inside minute-frequency algorithms."""

    def setUp(self):
        setup_logger(self)

    def tearDown(self):
        teardown_logger(self)

    @classmethod
    def setUpClass(cls):
        cls._start = pd.Timestamp('1991-01-01', tz='UTC')
        cls._end = pd.Timestamp('1991-01-15', tz='UTC')
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=cls.env
        )
        cls.env.write_data(equities_identifiers=[0, 1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    @property
    def source(self):
        # A fresh RandomWalkSource is built on every access.
        return RandomWalkSource(start=self._start, end=self._end)

    def _run_with_handle_data(self, handle_data):
        """Run a no-op-initialize algorithm with ``handle_data``; return it."""
        algo = TradingAlgorithm(
            initialize=lambda _: None,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)
        return algo

    def _assert_one_bar_history(self, algo):
        """Assert a history container exists with a window length of 1."""
        self.assertIsNotNone(algo.history_container)
        self.assertEqual(algo.history_container.buffer_panel.window_length, 1)

    def test_history(self):
        history_algo = """
from zipline.api import history, add_history

def initialize(context):
    add_history(10, '1d', 'price')

def handle_data(context, data):
    df = history(10, '1d', 'price')
"""

        algo = TradingAlgorithm(
            script=history_algo,
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertIsNotNone(algo.run(self.source))

    def test_history_without_add(self):
        def handle_data(algo, data):
            algo.history(1, '1m', 'price')

        self._assert_one_bar_history(self._run_with_handle_data(handle_data))

    def test_add_history_in_handle_data(self):
        def handle_data(algo, data):
            algo.add_history(1, '1m', 'price')

        self._assert_one_bar_history(self._run_with_handle_data(handle_data))
1324
1325
1326
class TestGetDatetime(TestCase):
    """get_datetime() inside a script must honour the requested timezone."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)

    def tearDown(self):
        teardown_logger(self)

    @parameterized.expand(
        [
            ('default', None,),
            ('utc', 'UTC',),
            ('us_east', 'US/Eastern',),
        ]
    )
    def test_get_datetime(self, name, tz):
        # The script embeds repr(tz), so None / 'UTC' / 'US/Eastern' all
        # become valid literals. On the first bar it raises ValueError
        # unless the datetime is in the expected zone and is 9:31 US/Eastern.
        algo_text = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        source = RandomWalkSource(
            start=to_utc('2014-01-02 9:31'),
            end=to_utc('2014-01-03 9:31'),
        )
        sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=self.env,
        )
        algo = TradingAlgorithm(
            script=algo_text,
            sim_params=sim_params,
            env=self.env,
        )
        algo.run(source)
        # first_bar is only cleared after handle_data has run at least once.
        self.assertFalse(algo.first_bar)
1391
1392
1393
class TestTradingControls(TestCase):
1394
1395
    @classmethod
    def setUpClass(cls):
        """Create one shared environment holding only the traded equity."""
        cls.sid = traded_sid = 133
        cls.env = env = TradingEnvironment()
        env.write_data(equities_identifiers=[traded_sid])
1400
1401
    @classmethod
    def tearDownClass(cls):
        """Drop the environment created in setUpClass."""
        delattr(cls, 'env')
1404
1405
    def setUp(self):
        """Build 4-day sim params and a daily trade source for ``self.sid``."""
        sim_params = factory.create_simulation_parameters(num_days=4,
                                                          env=self.env)
        self.sim_params = sim_params

        prices = [10.0, 10.0, 11.0, 11.0]
        amounts = [100, 100, 100, 300]
        self.trade_history = factory.create_trade_history(
            self.sid, prices, amounts, timedelta(days=1),
            sim_params, self.env,
        )

        self.source = SpecificEquityTrades(
            event_list=self.trade_history,
            env=self.env,
        )
1421
1422
    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        """Run ``algo`` with ``handle_data`` and check its outcome.

        When ``expected_exc`` is truthy the run must raise it; otherwise the
        run must complete. Either way ``algo.order_count`` must equal
        ``expected_order_count`` afterwards, and ``self.source`` is rewound
        so the same test can run another check.
        """
        algo._handle_data = handle_data
        if expected_exc:
            run_guard = self.assertRaises(expected_exc)
        else:
            run_guard = nullctx()
        with run_guard:
            algo.run(self.source)
        self.assertEqual(algo.order_count, expected_order_count)
        self.source.rewind()
1433
1434
    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        """Assert ``algo`` runs cleanly and ends with ``order_count`` orders.

        The default of 4 assumes one order per handle_data call.
        """
        self._check_algo(algo, handle_data,
                         expected_order_count=order_count,
                         expected_exc=None)
1437
1438
    def check_algo_fails(self, algo, handle_data, order_count):
        """Assert ``algo`` raises TradingControlViolation.

        ``algo.order_count`` must equal ``order_count`` when the violation
        is raised.
        """
        self._check_algo(algo, handle_data,
                         expected_order_count=order_count,
                         expected_exc=TradingControlViolation)
1443
1444
    def test_set_max_position_size(self):
        """Orders pushing a position past max_shares/max_notional must fail."""

        # Buy one share four times.  Should be fine.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 1)
            algo.order_count += 1
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 3)
            algo.order_count += 1

        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 3)
            algo.order_count += 1

        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, handle_data, 2)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, handle_data)

        # Set the trading control sid to None, then BUY ALL THE THINGS!. Should
        # fail because setting sid to None makes the control apply to all sids.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 10000)
            algo.order_count += 1
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, handle_data, 0)
1504
1505
    def test_set_do_not_order_list(self):
        """Ordering a restricted sid fails; unrelated restrictions don't."""

        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 100)
            algo.order_count += 1

        # set the restricted list to be the sid, and fail.
        restricted_algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_fails(restricted_algo, handle_data, 0)

        # set the restricted list to exclude the sid, and succeed
        unrestricted_algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_succeeds(unrestricted_algo, handle_data)
1533
1534
    def test_set_max_order_size(self):
        """Single orders above max_shares/max_notional must be rejected."""

        def order_fixed(amount):
            # Build a handle_data that orders ``amount`` shares every call.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amount)
                algo.order_count += 1
            return handle_data

        def order_increasing(algo, data):
            # Order 1, then 2, then 3, ... shares on successive calls.
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        def make_algo(**limits):
            return SetMaxOrderSizeAlgorithm(sim_params=self.sim_params,
                                            env=self.env,
                                            **limits)

        # Buy one share.
        self.check_algo_succeeds(
            make_algo(sid=self.sid, max_shares=10, max_notional=500.0),
            order_fixed(1),
        )

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed shares.
        self.check_algo_fails(
            make_algo(sid=self.sid, max_shares=3, max_notional=500.0),
            order_increasing,
            3,
        )

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed notional.
        self.check_algo_fails(
            make_algo(sid=self.sid, max_shares=10, max_notional=40.0),
            order_increasing,
            3,
        )

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        self.check_algo_succeeds(
            make_algo(sid=self.sid + 1, max_shares=1, max_notional=1.0),
            order_fixed(10000),
        )

        # Leave the trading control sid unset, then BUY ALL THE THINGS!.
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        self.check_algo_fails(
            make_algo(max_shares=1, max_notional=1.0),
            order_fixed(10000),
            0,
        )
1596
1597
    def test_set_max_order_count(self):
        """The per-day order-count limit is checked across trading sessions."""

        # Override the default setUp to use six-hour intervals instead of full
        # days so we can exercise trading-session rollover logic.
        trade_history = factory.create_trade_history(
            self.sid,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(hours=6),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(event_list=trade_history,
                                           env=self.env)

        # Place five one-share orders on every handle_data call.
        def handle_data(algo, data):
            for _ in range(5):
                algo.order(algo.sid(self.sid), 1)
                algo.order_count += 1

        # A limit of 3 is exceeded partway through the first call.
        algo = SetMaxOrderCountAlgorithm(3, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Second call to handle_data is the same day as the first, so the last
        # order of the second call should fail.
        algo = SetMaxOrderCountAlgorithm(9, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_fails(algo, handle_data, 9)

        # At most ten orders are placed per day, so a limit of 10 should pass
        # even though 20 orders are placed over the whole run.
        algo = SetMaxOrderCountAlgorithm(10, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_succeeds(algo, handle_data, order_count=20)
1632
1633
    def test_long_only(self):
        """SetLongOnlyAlgorithm must reject orders that leave a short position.
        """

        def order_sequence(amounts):
            # Build a handle_data that orders amounts[i] on the i-th call.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amounts[algo.order_count])
                algo.order_count += 1
            return handle_data

        def sell_one(algo, data):
            algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1

        def alternate_buy_sell(algo, data):
            amount = 1 if algo.order_count % 2 == 0 else -1
            algo.order(algo.sid(self.sid), amount)
            algo.order_count += 1

        def make_algo():
            return SetLongOnlyAlgorithm(sim_params=self.sim_params,
                                        env=self.env)

        # Sell immediately -> fail immediately.
        self.check_algo_fails(make_algo(), sell_one, 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        self.check_algo_succeeds(make_algo(), alternate_buy_sell)

        # Buy on first three days, then sell off holdings.  Should succeed.
        self.check_algo_succeeds(make_algo(), order_sequence([1, 1, 1, -3]))

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        self.check_algo_fails(make_algo(), order_sequence([1, 1, 1, -4]), 3)
1668
1669
    def test_register_post_init(self):
1670
1671
        def initialize(algo):
1672
            algo.initialized = True
1673
1674
        def handle_data(algo, data):
1675
1676
            with self.assertRaises(RegisterTradingControlPostInit):
1677
                algo.set_max_position_size(self.sid, 1, 1)
1678
            with self.assertRaises(RegisterTradingControlPostInit):
1679
                algo.set_max_order_size(self.sid, 1, 1)
1680
            with self.assertRaises(RegisterTradingControlPostInit):
1681
                algo.set_max_order_count(1)
1682
            with self.assertRaises(RegisterTradingControlPostInit):
1683
                algo.set_long_only()
1684
1685
        algo = TradingAlgorithm(initialize=initialize,
1686
                                handle_data=handle_data,
1687
                                sim_params=self.sim_params,
1688
                                env=self.env)
1689
        algo.run(self.source)
1690
        self.source.rewind()
1691
1692
    def test_asset_date_bounds(self):
1693
1694
        # Run the algorithm with a sid that ends far in the future
1695
        temp_env = TradingEnvironment()
1696
        df_source, _ = factory.create_test_df_source(self.sim_params, temp_env)
1697
        metadata = {0: {'start_date': '1990-01-01',
1698
                        'end_date': '2020-01-01'}}
1699
        algo = SetAssetDateBoundsAlgorithm(
1700
            equities_metadata=metadata,
1701
            sim_params=self.sim_params,
1702
            env=temp_env,
1703
        )
1704
        algo.run(df_source)
1705
1706
        # Run the algorithm with a sid that has already ended
1707
        temp_env = TradingEnvironment()
1708
        df_source, _ = factory.create_test_df_source(self.sim_params, temp_env)
1709
        metadata = {0: {'start_date': '1989-01-01',
1710
                        'end_date': '1990-01-01'}}
1711
        algo = SetAssetDateBoundsAlgorithm(
1712
            equities_metadata=metadata,
1713
            sim_params=self.sim_params,
1714
            env=temp_env,
1715
        )
1716
        with self.assertRaises(TradingControlViolation):
1717
            algo.run(df_source)
1718
1719
        # Run the algorithm with a sid that has not started
1720
        temp_env = TradingEnvironment()
1721
        df_source, _ = factory.create_test_df_source(self.sim_params, temp_env)
1722
        metadata = {0: {'start_date': '2020-01-01',
1723
                        'end_date': '2021-01-01'}}
1724
        algo = SetAssetDateBoundsAlgorithm(
1725
            equities_metadata=metadata,
1726
            sim_params=self.sim_params,
1727
            env=temp_env,
1728
        )
1729
        with self.assertRaises(TradingControlViolation):
1730
            algo.run(df_source)
1731
1732
        # Run the algorithm with a sid that starts on the first day and
1733
        # ends on the last day of the algorithm's parameters (*not* an error).
1734
        temp_env = TradingEnvironment()
1735
        df_source, _ = factory.create_test_df_source(self.sim_params, temp_env)
1736
        metadata = {0: {'start_date': '2006-01-03',
1737
                        'end_date': '2006-01-06'}}
1738
        algo = SetAssetDateBoundsAlgorithm(
1739
            equities_metadata=metadata,
1740
            sim_params=self.sim_params,
1741
            env=temp_env,
1742
        )
1743
        algo.run(df_source)
1744
1745
1746
class TestAccountControls(TestCase):
1747
1748
    @classmethod
1749
    def setUpClass(cls):
1750
        cls.sidint = 133
1751
        cls.env = TradingEnvironment()
1752
        cls.env.write_data(
1753
            equities_identifiers=[cls.sidint]
1754
        )
1755
1756
    @classmethod
1757
    def tearDownClass(cls):
1758
        del cls.env
1759
1760
    def setUp(self):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.

Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.

You can also find more detailed suggestions in the “Code” section of your repository.

Loading history...
1761
        self.sim_params = factory.create_simulation_parameters(
1762
            num_days=4, env=self.env
1763
        )
1764
        self.trade_history = factory.create_trade_history(
1765
            self.sidint,
1766
            [10.0, 10.0, 11.0, 11.0],
1767
            [100, 100, 100, 300],
1768
            timedelta(days=1),
1769
            self.sim_params,
1770
            self.env,
1771
        )
1772
1773
        self.source = SpecificEquityTrades(
1774
            event_list=self.trade_history,
1775
            env=self.env,
1776
        )
1777
1778
    def _check_algo(self,
1779
                    algo,
1780
                    handle_data,
1781
                    expected_exc):
1782
1783
        algo._handle_data = handle_data
1784
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
1785
            algo.run(self.source)
1786
        self.source.rewind()
1787
1788
    def check_algo_succeeds(self, algo, handle_data):
1789
        # Default for order_count assumes one order per handle_data call.
1790
        self._check_algo(algo, handle_data, None)
1791
1792
    def check_algo_fails(self, algo, handle_data):
1793
        self._check_algo(algo,
1794
                         handle_data,
1795
                         AccountControlViolation)
1796
1797
    def test_set_max_leverage(self):
1798
1799
        # Set max leverage to 0 so buying one share fails.
1800
        def handle_data(algo, data):
1801
            algo.order(algo.sid(self.sidint), 1)
1802
1803
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
1804
                                       env=self.env)
1805
        self.check_algo_fails(algo, handle_data)
1806
1807
        # Set max leverage to 1 so buying one share passes
1808
        def handle_data(algo, data):
1809
            algo.order(algo.sid(self.sidint), 1)
1810
1811
        algo = SetMaxLeverageAlgorithm(1,  sim_params=self.sim_params,
1812
                                       env=self.env)
1813
        self.check_algo_succeeds(algo, handle_data)
1814
1815
1816
class TestClosePosAlgo(TestCase):
1817
1818
    def setUp(self):
1819
        self.env = TradingEnvironment()
1820
        self.days = self.env.trading_days[:4]
1821
        self.panel = pd.Panel({1: pd.DataFrame({
1822
            'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 0],
1823
            'type': [DATASOURCE_TYPE.TRADE,
1824
                     DATASOURCE_TYPE.TRADE,
1825
                     DATASOURCE_TYPE.TRADE,
1826
                     DATASOURCE_TYPE.CLOSE_POSITION]},
1827
            index=self.days)
1828
        })
1829
        self.no_close_panel = pd.Panel({1: pd.DataFrame({
1830
            'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 1e9],
1831
            'type': [DATASOURCE_TYPE.TRADE,
1832
                     DATASOURCE_TYPE.TRADE,
1833
                     DATASOURCE_TYPE.TRADE,
1834
                     DATASOURCE_TYPE.TRADE]},
1835
            index=self.days)
1836
        })
1837
1838
    def test_close_position_equity(self):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.

Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.

You can also find more detailed suggestions in the “Code” section of your repository.

Loading history...
1839
        metadata = {1: {'symbol': 'TEST',
1840
                        'end_date': self.days[3]}}
1841
        self.env.write_data(equities_data=metadata)
1842
        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
1843
                             commission=PerShare(0),
1844
                             env=self.env)
1845
        data = DataPanelSource(self.panel)
1846
1847
        # Check results
1848
        expected_positions = [0, 1, 1, 0]
1849
        expected_pnl = [0, 0, 1, 2]
1850
        results = algo.run(data)
1851
        self.check_algo_positions(results, expected_positions)
1852
        self.check_algo_pnl(results, expected_pnl)
1853
1854
    def test_close_position_future(self):
1855
        metadata = {1: {'symbol': 'TEST'}}
1856
        self.env.write_data(futures_data=metadata)
1857
        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
1858
                             commission=PerShare(0),
1859
                             env=self.env)
1860
        data = DataPanelSource(self.panel)
1861
1862
        # Check results
1863
        expected_positions = [0, 1, 1, 0]
1864
        expected_pnl = [0, 0, 1, 2]
1865
        results = algo.run(data)
1866
        self.check_algo_pnl(results, expected_pnl)
1867
        self.check_algo_positions(results, expected_positions)
1868
1869
    def test_auto_close_future(self):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.

Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.

You can also find more detailed suggestions in the “Code” section of your repository.

Loading history...
1870
        metadata = {1: {'symbol': 'TEST',
1871
                        'auto_close_date': self.env.trading_days[4]}}
1872
        self.env.write_data(futures_data=metadata)
1873
        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
1874
                             commission=PerShare(0),
1875
                             env=self.env)
1876
        data = DataPanelSource(self.no_close_panel)
1877
1878
        # Check results
1879
        results = algo.run(data)
1880
1881
        expected_positions = [0, 1, 1, 0]
1882
        self.check_algo_positions(results, expected_positions)
1883
1884
        expected_pnl = [0, 0, 1, 2]
1885
        self.check_algo_pnl(results, expected_pnl)
1886
1887
    def check_algo_pnl(self, results, expected_pnl):
1888
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)
1889
1890
    def check_algo_positions(self, results, expected_positions):
1891
        for i, amount in enumerate(results.positions):
1892
            if amount:
1893
                actual_position = amount[0]['amount']
1894
            else:
1895
                actual_position = 0
1896
1897
            self.assertEqual(
1898
                actual_position, expected_positions[i],
1899
                "position for day={0} not equal, actual={1}, expected={2}".
1900
                format(i, actual_position, expected_positions[i]))
1901
1902
1903
class TestFutureFlip(TestCase):
1904
    def setUp(self):
1905
        self.env = TradingEnvironment()
1906
        self.days = self.env.trading_days[:4]
1907
        self.trades_panel = pd.Panel({1: pd.DataFrame({
1908
            'price': [1, 2, 4], 'volume': [1e9, 1e9, 1e9],
1909
            'type': [DATASOURCE_TYPE.TRADE,
1910
                     DATASOURCE_TYPE.TRADE,
1911
                     DATASOURCE_TYPE.TRADE]},
1912
            index=self.days[:3])
1913
        })
1914
1915
    def test_flip_algo(self):
1916
        metadata = {1: {'symbol': 'TEST',
1917
                        'end_date': self.days[3],
1918
                        'contract_multiplier': 5}}
1919
        self.env.write_data(futures_data=metadata)
1920
1921
        algo = FutureFlipAlgo(sid=1, amount=1, env=self.env,
1922
                              commission=PerShare(0),
1923
                              order_count=0,  # not applicable but required
1924
                              instant_fill=True)
1925
        data = DataPanelSource(self.trades_panel)
1926
1927
        results = algo.run(data)
1928
1929
        expected_positions = [1, -1, 0]
1930
        self.check_algo_positions(results, expected_positions)
1931
1932
        expected_pnl = [0, 5, -10]
1933
        self.check_algo_pnl(results, expected_pnl)
1934
1935
    def check_algo_pnl(self, results, expected_pnl):
1936
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)
1937
1938
    def check_algo_positions(self, results, expected_positions):
1939
        for i, amount in enumerate(results.positions):
1940
            if amount:
1941
                actual_position = amount[0]['amount']
1942
            else:
1943
                actual_position = 0
1944
1945
            self.assertEqual(
1946
                actual_position, expected_positions[i],
1947
                "position for day={0} not equal, actual={1}, expected={2}".
1948
                format(i, actual_position, expected_positions[i]))
1949
1950
1951
class TestTradingAlgorithm(TestCase):
1952
    def setUp(self):
1953
        self.env = TradingEnvironment()
1954
        self.days = self.env.trading_days[:4]
1955
        self.panel = pd.Panel({1: pd.DataFrame({
1956
            'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 0],
1957
            'type': [DATASOURCE_TYPE.TRADE,
1958
                     DATASOURCE_TYPE.TRADE,
1959
                     DATASOURCE_TYPE.TRADE,
1960
                     DATASOURCE_TYPE.CLOSE_POSITION]},
1961
            index=self.days)
1962
        })
1963
1964
    def test_analyze_called(self):
1965
        self.perf_ref = None
1966
1967
        def initialize(context):
1968
            pass
1969
1970
        def handle_data(context, data):
1971
            pass
1972
1973
        def analyze(context, perf):
1974
            self.perf_ref = perf
1975
1976
        algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data,
1977
                                analyze=analyze)
1978
        results = algo.run(self.panel)
1979
        self.assertIs(results, self.perf_ref)
1980