Completed
Pull Request — master (#875)
by Eddie
02:28
created

tests.TestMiscellaneousAPI.handle_data()   A

Complexity

Conditions 3

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 2
rs 10
#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
16
from datetime import timedelta
17
from mock import MagicMock
18
from nose_parameterized import parameterized
19
from six.moves import range, map
20
from textwrap import dedent
21
from unittest import TestCase
22
23
import numpy as np
24
import pandas as pd
25
26
from zipline.utils.api_support import ZiplineAPI
27
from zipline.utils.control_flow import nullctx
28
from zipline.utils.test_utils import (
29
    setup_logger,
30
    teardown_logger
31
)
32
import zipline.utils.factory as factory
33
import zipline.utils.simfactory as simfactory
34
35
from zipline.errors import (
36
    OrderDuringInitialize,
37
    RegisterTradingControlPostInit,
38
    TradingControlViolation,
39
    AccountControlViolation,
40
    SymbolNotFound,
41
    RootSymbolNotFound,
42
    UnsupportedDatetimeFormat,
43
)
44
from zipline.test_algorithms import (
45
    access_account_in_init,
46
    access_portfolio_in_init,
47
    AmbitiousStopLimitAlgorithm,
48
    EmptyPositionsAlgorithm,
49
    InvalidOrderAlgorithm,
50
    RecordAlgorithm,
51
    FutureFlipAlgo,
52
    TestAlgorithm,
53
    TestOrderAlgorithm,
54
    TestOrderInstantAlgorithm,
55
    TestOrderPercentAlgorithm,
56
    TestOrderStyleForwardingAlgorithm,
57
    TestOrderValueAlgorithm,
58
    TestRegisterTransformAlgorithm,
59
    TestTargetAlgorithm,
60
    TestTargetPercentAlgorithm,
61
    TestTargetValueAlgorithm,
62
    SetLongOnlyAlgorithm,
63
    SetAssetDateBoundsAlgorithm,
64
    SetMaxPositionSizeAlgorithm,
65
    SetMaxOrderCountAlgorithm,
66
    SetMaxOrderSizeAlgorithm,
67
    SetDoNotOrderListAlgorithm,
68
    SetMaxLeverageAlgorithm,
69
    api_algo,
70
    api_get_environment_algo,
71
    api_symbol_algo,
72
    call_all_order_methods,
73
    call_order_in_init,
74
    handle_data_api,
75
    handle_data_noop,
76
    initialize_api,
77
    initialize_noop,
78
    noop_algo,
79
    record_float_magic,
80
    record_variables,
81
)
82
from zipline.utils.context_tricks import CallbackManager
83
import zipline.utils.events
84
from zipline.utils.test_utils import (
85
    assert_single_position,
86
    drain_zipline,
87
    to_utc,
88
)
89
90
from zipline.sources import (SpecificEquityTrades,
91
                             DataFrameSource,
92
                             DataPanelSource,
93
                             RandomWalkSource)
94
from zipline.assets import Equity
95
96
from zipline.finance.execution import LimitOrder
97
from zipline.finance.trading import SimulationParameters
98
from zipline.utils.api_support import set_algo_instance
99
from zipline.utils.events import DateRuleFactory, TimeRuleFactory, Always
100
from zipline.algorithm import TradingAlgorithm
101
from zipline.protocol import DATASOURCE_TYPE
102
from zipline.finance.trading import TradingEnvironment
103
from zipline.finance.commission import PerShare
104
105
# Because test cases appear to reuse some resources.
106
_multiprocess_can_split_ = False
107
108
109
class TestRecordAlgorithm(TestCase):
    """Check the variables recorded by RecordAlgorithm over a 4-day run."""

    @classmethod
    def setUpClass(cls):
        # One shared environment holding the single equity (sid 133) that
        # the trade history in setUp is generated for.
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[133])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        env = self.env
        sim_params = factory.create_simulation_parameters(num_days=4,
                                                          env=env)
        self.sim_params = sim_params

        # One trade per day for sid 133.
        prices = [10.0, 10.0, 11.0, 11.0]
        volumes = [100, 100, 100, 300]
        trade_history = factory.create_trade_history(
            133, prices, volumes, timedelta(days=1), sim_params, env,
        )

        self.source = SpecificEquityTrades(event_list=trade_history, env=env)
        self.df_source, self.df = factory.create_test_df_source(sim_params,
                                                                env)

    def test_record_incr(self):
        algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env)
        output = algo.run(self.source)

        row_count = len(output)
        # 1, 2, ..., row_count -- reused for every incrementing column.
        counting_up = range(1, row_count + 1)
        expectations = [
            ('incr', counting_up),
            ('name', counting_up),
            ('name2', [2] * row_count),
            ('name3', counting_up),
        ]
        for column, expected in expectations:
            np.testing.assert_array_equal(output[column].values, expected)
149
150
151
class TestMiscellaneousAPI(TestCase):
    """Exercise assorted TradingAlgorithm API methods.

    Covers dispatch through ``zipline.api``, ``get_environment``,
    ``get_open_orders``, ``schedule_function``, custom event contexts, and
    the ``symbol`` / ``symbols`` / ``sid`` / ``future_symbol`` /
    ``future_chain`` / ``set_symbol_lookup_date`` lookups.
    """

    @classmethod
    def setUpClass(cls):
        cls.sids = [1, 2]
        cls.env = TradingEnvironment()

        # Two equities share the symbol 'PLAY' with non-overlapping
        # lifetimes, so symbol() resolution depends on the lookup date.
        metadata = {3: {'symbol': 'PLAY',
                        'start_date': '2002-01-01',
                        'end_date': '2004-01-01'},
                    4: {'symbol': 'PLAY',
                        'start_date': '2005-01-01',
                        'end_date': '2006-01-01'}}

        # Four futures contracts under the 'CL' root symbol.
        futures_metadata = {
            5: {
                'symbol': 'CLG06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2005-12-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')},
            6: {
                'root_symbol': 'CL',
                'symbol': 'CLK06',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-03-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')},
            7: {
                'symbol': 'CLQ06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2005-12-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-06-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')},
            8: {
                'symbol': 'CLX06',
                'root_symbol': 'CL',
                'start_date': pd.Timestamp('2006-02-01', tz='UTC'),
                'notice_date': pd.Timestamp('2006-09-20', tz='UTC'),
                'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')}
        }
        cls.env.write_data(equities_identifiers=cls.sids,
                           equities_data=metadata,
                           futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)
        # Two days of minute-frequency bars for the class-level sids.
        self.sim_params = factory.create_simulation_parameters(
            num_days=2,
            data_frequency='minute',
            emission_rate='minute',
            env=self.env,
        )
        self.source = factory.create_minutely_trade_source(
            self.sids,
            sim_params=self.sim_params,
            concurrent=True,
            env=self.env,
        )

    def tearDown(self):
        teardown_logger(self)

    def test_zipline_api_resolves_dynamically(self):
        # Make a dummy algo.
        algo = TradingAlgorithm(
            initialize=lambda context: None,
            handle_data=lambda context, data: None,
            sim_params=self.sim_params,
        )

        # Verify that api methods get resolved dynamically by patching them
        # out and then calling them.
        for method in algo.all_api_methods():
            name = method.__name__
            sentinel = object()

            # fake_method closes over the loop-local ``sentinel``. That is
            # safe because it is invoked within the same iteration.
            def fake_method(*args, **kwargs):
                return sentinel
            setattr(algo, name, fake_method)
            with ZiplineAPI(algo):
                self.assertIs(sentinel, getattr(zipline.api, name)())

    def test_get_environment(self):
        expected_env = {
            'arena': 'backtest',
            'data_frequency': 'minute',
            'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
            'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
            'capital_base': 100000.0,
            'platform': 'zipline'
        }

        def initialize(algo):
            # With no argument only the platform is returned; '*' returns
            # the full environment dict.
            self.assertEqual('zipline', algo.get_environment())
            self.assertEqual(expected_env, algo.get_environment('*'))

        def handle_data(algo, data):
            pass

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.source)

    def test_get_open_orders(self):

        def initialize(algo):
            algo.minute = 0

        def handle_data(algo, data):
            if algo.minute == 0:

                # Should be filled by the next minute
                algo.order(algo.sid(1), 1)

                # Won't be filled because the price is too low.
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))
                algo.order(algo.sid(2), 1, style=LimitOrder(0.01))

                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [1, 2])

                self.assertEqual(all_orders[1], algo.get_open_orders(1))
                self.assertEqual(len(all_orders[1]), 1)

                self.assertEqual(all_orders[2], algo.get_open_orders(2))
                self.assertEqual(len(all_orders[2]), 3)

            if algo.minute == 1:
                # The market order on sid 1 should have filled.
                # The three limit orders on sid 2 should still be open.
                all_orders = algo.get_open_orders()
                self.assertEqual(list(all_orders.keys()), [2])

                self.assertEqual([], algo.get_open_orders(1))

                orders_2 = algo.get_open_orders(2)
                self.assertEqual(all_orders[2], orders_2)
                self.assertEqual(len(all_orders[2]), 3)

                for order in orders_2:
                    algo.cancel_order(order)

                # Cancelling every sid 2 order leaves nothing open.
                all_orders = algo.get_open_orders()
                self.assertEqual(all_orders, {})

            algo.minute += 1

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.source)

    def test_schedule_function(self):
        date_rules = DateRuleFactory
        time_rules = TimeRuleFactory

        def incrementer(algo, data):
            # Scheduled at market open, so it must only ever fire at 14:31.
            algo.func_called += 1
            self.assertEqual(
                algo.get_datetime().time(),
                datetime.time(hour=14, minute=31),
            )

        def initialize(algo):
            algo.func_called = 0
            algo.days = 1
            algo.date = None
            algo.schedule_function(
                func=incrementer,
                date_rule=date_rules.every_day(),
                time_rule=time_rules.market_open(),
            )

        def handle_data(algo, data):
            # Count the distinct simulation days seen.
            if not algo.date:
                algo.date = algo.get_datetime().date()

            if algo.date < algo.get_datetime().date():
                algo.days += 1
                algo.date = algo.get_datetime().date()

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

        # An every_day() rule fires exactly once per simulated day.
        self.assertEqual(algo.func_called, algo.days)

    def test_event_context(self):
        # Number of handle_data calls this test expects for the simulation
        # built in setUp.
        expected_bar_count = 779

        expected_data = []
        collected_data_pre = []
        collected_data_post = []
        function_stack = []

        def pre(data):
            function_stack.append(pre)
            collected_data_pre.append(data)

        def post(data):
            function_stack.append(post)
            collected_data_post.append(data)

        def initialize(context):
            context.add_event(Always(), f)
            context.add_event(Always(), g)

        def handle_data(context, data):
            function_stack.append(handle_data)
            expected_data.append(data)

        def f(context, data):
            function_stack.append(f)

        def g(context, data):
            function_stack.append(g)

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            sim_params=self.sim_params,
            create_event_context=CallbackManager(pre, post),
            env=self.env,
        )
        algo.run(self.source)

        self.assertEqual(len(expected_data), expected_bar_count)
        self.assertEqual(collected_data_pre, expected_data)
        self.assertEqual(collected_data_post, expected_data)

        # On every bar the event context wraps handle_data and both Always()
        # events: pre, handle_data, f, g, post.
        expected_functions = (
            [pre, handle_data, f, g, post] * expected_bar_count
        )
        self.assertEqual(
            len(function_stack),
            len(expected_functions),
            'Incorrect number of functions called: %s != %s' %
            (len(function_stack), len(expected_functions)),
        )
        # Use distinct loop names so the local callbacks f/g are not
        # shadowed.
        for n, (actual, expected) in enumerate(
                zip(function_stack, expected_functions)):
            self.assertEqual(
                actual,
                expected,
                'function at position %d was incorrect, expected %s but got %s'
                % (n, expected.__name__, actual.__name__),
            )

    @parameterized.expand([
        ('daily',),
        # Must be a 1-tuple like the entry above, not a bare string.
        ('minute',),
    ])
    def test_schedule_funtion_rule_creation(self, mode):
        def nop(*args, **kwargs):
            return None

        self.sim_params.data_frequency = mode
        algo = TradingAlgorithm(
            initialize=nop,
            handle_data=nop,
            sim_params=self.sim_params,
            env=self.env,
        )

        # Schedule something for NOT Always.
        algo.schedule_function(nop, time_rule=zipline.utils.events.Never())

        event_rule = algo.event_manager._events[1].rule

        self.assertIsInstance(event_rule, zipline.utils.events.OncePerDay)

        inner_rule = event_rule.rule
        self.assertIsInstance(inner_rule, zipline.utils.events.ComposedRule)

        first = inner_rule.first
        second = inner_rule.second
        composer = inner_rule.composer

        self.assertIsInstance(first, zipline.utils.events.Always)

        # In daily mode the requested time rule is replaced by Always.
        if mode == 'daily':
            self.assertIsInstance(second, zipline.utils.events.Always)
        else:
            self.assertIsInstance(second, zipline.utils.events.Never)

        self.assertIs(composer, zipline.utils.events.ComposedRule.lazy_and)

    def test_asset_lookup(self):

        algo = TradingAlgorithm(env=self.env)

        # Test before either PLAY existed
        algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC')
        with self.assertRaises(SymbolNotFound):
            algo.symbol('PLAY')
        with self.assertRaises(SymbolNotFound):
            algo.symbols('PLAY')

        # Test when first PLAY exists
        algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC')
        list_result = algo.symbols('PLAY')
        self.assertEqual(3, list_result[0])

        # Test after first PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC')
        self.assertEqual(3, algo.symbol('PLAY'))

        # Test after second PLAY begins
        algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))

        # Test after second PLAY ends
        algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC')
        self.assertEqual(4, algo.symbol('PLAY'))
        list_result = algo.symbols('PLAY')
        self.assertEqual(4, list_result[0])

        # Test lookup SID
        self.assertIsInstance(algo.sid(3), Equity)
        self.assertIsInstance(algo.sid(4), Equity)

        # Supplying a non-string argument to symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.symbol(1)

        with self.assertRaises(TypeError):
            algo.symbol((1,))

        with self.assertRaises(TypeError):
            algo.symbol({1})

        with self.assertRaises(TypeError):
            algo.symbol([1])

        with self.assertRaises(TypeError):
            algo.symbol({'foo': 'bar'})

    def test_future_symbol(self):
        """ Tests the future_symbol API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that we get the correct fields for the CLG06 symbol
        cl = algo.future_symbol('CLG06')
        self.assertEqual(cl.sid, 5)
        self.assertEqual(cl.symbol, 'CLG06')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.start_date, pd.Timestamp('2005-12-01', tz='UTC'))
        self.assertEqual(cl.notice_date, pd.Timestamp('2005-12-20', tz='UTC'))
        self.assertEqual(cl.expiration_date,
                         pd.Timestamp('2006-01-20', tz='UTC'))

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('')

        # 'PLAY' exists only as an equity, not as a future.
        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('PLAY')

        with self.assertRaises(SymbolNotFound):
            algo.future_symbol('FOOBAR')

        # Supplying a non-string argument to future_symbol()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_symbol(1)

        with self.assertRaises(TypeError):
            algo.future_symbol((1,))

        with self.assertRaises(TypeError):
            algo.future_symbol({1})

        with self.assertRaises(TypeError):
            algo.future_symbol([1])

        with self.assertRaises(TypeError):
            algo.future_symbol({'foo': 'bar'})

    def test_future_chain(self):
        """ Tests the future_chain API function.
        """
        algo = TradingAlgorithm(env=self.env)
        algo.datetime = pd.Timestamp('2006-12-01', tz='UTC')

        # Check that the fields of the FutureChain object are set correctly
        cl = algo.future_chain('CL')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, algo.datetime)

        # Check the fields are set correctly if an as_of_date is supplied
        as_of_date = pd.Timestamp('1952-08-11', tz='UTC')

        cl = algo.future_chain('CL', as_of_date=as_of_date)
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # A date string is accepted and parsed to the same timestamp.
        cl = algo.future_chain('CL', as_of_date='1952-08-11')
        self.assertEqual(cl.root_symbol, 'CL')
        self.assertEqual(cl.as_of_date, as_of_date)

        # Check that weird capitalization is corrected
        cl = algo.future_chain('cL')
        self.assertEqual(cl.root_symbol, 'CL')

        cl = algo.future_chain('cl')
        self.assertEqual(cl.root_symbol, 'CL')

        # Check that invalid root symbols raise RootSymbolNotFound
        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('CLZ')

        with self.assertRaises(RootSymbolNotFound):
            algo.future_chain('')

        # Check that invalid dates raise UnsupportedDatetimeFormat
        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', 'my_finger_slipped')

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.future_chain('CL', '2015-09-')

        # Supplying a non-string argument to future_chain()
        # should result in a TypeError.
        with self.assertRaises(TypeError):
            algo.future_chain(1)

        with self.assertRaises(TypeError):
            algo.future_chain((1,))

        with self.assertRaises(TypeError):
            algo.future_chain({1})

        with self.assertRaises(TypeError):
            algo.future_chain([1])

        with self.assertRaises(TypeError):
            algo.future_chain({'foo': 'bar'})

    def test_set_symbol_lookup_date(self):
        """
        Test the set_symbol_lookup_date API method.
        """
        # Sids are numbered from 3 (i + 3) so they do not overlap the
        # class-level sids [1, 2]. The assets are written into a fresh
        # TradingEnvironment below.
        dates = pd.date_range('2013-01-01', freq='2D', periods=2, tz='UTC')
        # Create two assets with the same symbol but different
        # non-overlapping date ranges.
        metadata = pd.DataFrame.from_records(
            [
                {
                    'sid': i + 3,
                    'symbol': 'DUP',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                }
                for i, date in enumerate(dates)
            ]
        )
        env = TradingEnvironment()
        env.write_data(equities_df=metadata)
        algo = TradingAlgorithm(env=env)

        # Set the period end to a date after the period end
        # dates for our assets.
        algo.sim_params.period_end = pd.Timestamp('2015-01-01', tz='UTC')

        # With no symbol lookup date set, the period end date is used as
        # the as_of_date. The lookup resolves to one of the two 'DUP'
        # assets; only the symbol is asserted here.
        result = algo.symbol('DUP')
        self.assertEqual(result.symbol, 'DUP')

        # By first calling set_symbol_lookup_date, the relevant asset
        # should be returned by lookup_symbol
        for i, date in enumerate(dates):
            algo.set_symbol_lookup_date(date)
            result = algo.symbol('DUP')
            self.assertEqual(result.symbol, 'DUP')
            self.assertEqual(result.sid, i + 3)

        with self.assertRaises(UnsupportedDatetimeFormat):
            algo.set_symbol_lookup_date('foobar')
642
643
644
class TestTransformAlgorithm(TestCase):
    """Run test algorithms against the supported input source types.

    Covers event sources, DataFrames, Panels (keyed by sid or by Equity),
    multiple sources, repeated runs, and the order-method test algorithms.
    """

    @classmethod
    def setUpClass(cls):
        futures_metadata = {3: {'contract_multiplier': 10}}
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1, 133],
                           futures_data=futures_metadata)

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)
        self.sim_params = factory.create_simulation_parameters(num_days=4,
                                                               env=self.env)

        # One trade per day for sid 133.
        trade_history = factory.create_trade_history(
            133,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=trade_history,
            env=self.env,
        )
        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

        self.panel_source, self.panel = \
            factory.create_test_panel_source(self.sim_params, self.env)

    def tearDown(self):
        teardown_logger(self)

    def test_source_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
            sids=[133]
        )
        algo.run(self.source)
        self.assertEqual(len(algo.sources), 1)
        # assertIsInstance, unlike a bare assert, survives `python -O` and
        # reports the actual type on failure.
        self.assertIsInstance(algo.sources[0], SpecificEquityTrades)

    def test_invalid_order_parameters(self):
        algo = InvalidOrderAlgorithm(
            sids=[133],
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

    def test_multi_source_as_input(self):
        # Span the DataFrame's index so both sources fit the simulation.
        sim_params = SimulationParameters(
            self.df.index[0],
            self.df.index[-1],
            env=self.env,
        )
        algo = TestRegisterTransformAlgorithm(
            sim_params=sim_params,
            sids=[0, 1],
            env=self.env,
        )
        algo.run([self.source, self.df_source], overwrite_sim_params=False)
        self.assertEqual(len(algo.sources), 2)

    def test_df_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.df)
        self.assertIsInstance(algo.sources[0], DataFrameSource)

    def test_panel_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
            sids=[0, 1])
        panel = self.panel.copy()
        # Materialize the mapped items so the Index never depends on being
        # handed a lazy iterator (six.moves.map is lazy on Python 3).
        panel.items = pd.Index(list(map(Equity, panel.items)))
        algo.run(panel)
        self.assertIsInstance(algo.sources[0], DataPanelSource)

    def test_df_of_assets_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=TradingEnvironment(),  # new env without assets
        )
        df = self.df.copy()
        df.columns = pd.Index(list(map(Equity, df.columns)))
        algo.run(df)
        self.assertIsInstance(algo.sources[0], DataFrameSource)

    def test_panel_of_assets_as_input(self):
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=TradingEnvironment(),  # new env without assets
            sids=[0, 1])
        algo.run(self.panel)
        self.assertIsInstance(algo.sources[0], DataPanelSource)

    def test_run_twice(self):
        algo1 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res1 = algo1.run(self.df)

        # Create a new trading algorithm, which will
        # use the newly instantiated environment.
        algo2 = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            sids=[0, 1]
        )

        res2 = algo2.run(self.df)

        np.testing.assert_array_equal(res1, res2)

    def test_data_frequency_setting(self):
        self.sim_params.data_frequency = 'daily'
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'daily')

        self.sim_params.data_frequency = 'minute'
        algo = TestRegisterTransformAlgorithm(
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertEqual(algo.sim_params.data_frequency, 'minute')

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods(self, algo_class):
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.df)

    @parameterized.expand([
        (TestOrderAlgorithm,),
        (TestOrderValueAlgorithm,),
        (TestTargetAlgorithm,),
        (TestOrderPercentAlgorithm,),
        (TestTargetValueAlgorithm,),
    ])
    def test_order_methods_for_future(self, algo_class):
        # NOTE(review): this runs on self.df, the same data as
        # test_order_methods, so it may never trade the future (sid 3)
        # registered in setUpClass -- confirm intended coverage.
        algo = algo_class(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.df)

    def test_order_method_style_forwarding(self):

        method_names_to_test = ['order',
                                'order_value',
                                'order_percent',
                                'order_target',
                                'order_target_percent',
                                'order_target_value']

        for name in method_names_to_test:
            # Don't supply an env so the TradingAlgorithm builds a new one for
            # each method
            algo = TestOrderStyleForwardingAlgorithm(
                sim_params=self.sim_params,
                instant_fill=False,
                method_name=name
            )
            algo.run(self.df)

    def test_order_instant(self):
        algo = TestOrderInstantAlgorithm(sim_params=self.sim_params,
                                         env=self.env,
                                         instant_fill=True)
        algo.run(self.df)

    def test_minute_data(self):
        source = RandomWalkSource(freq='minute',
                                  start=pd.Timestamp('2000-1-3',
                                                     tz='UTC'),
                                  end=pd.Timestamp('2000-1-4',
                                                   tz='UTC'))
        self.sim_params.data_frequency = 'minute'
        algo = TestOrderInstantAlgorithm(sim_params=self.sim_params,
                                         env=self.env,
                                         instant_fill=True)
        algo.run(source)
850
851
852
class TestPositions(TestCase):
    """
    Check how positions show up in the daily stats returned by
    ``TradingAlgorithm.run``.
    """

    def setUp(self):
        setup_logger(self)
        self.env = TradingEnvironment()
        self.sim_params = factory.create_simulation_parameters(num_days=4,
                                                               env=self.env)
        self.env.write_data(equities_identifiers=[1, 133])

        # Four daily trades for sid 1.
        trade_history = factory.create_trade_history(
            1,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(
            event_list=trade_history,
            env=self.env,
        )

        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

    def tearDown(self):
        teardown_logger(self)

    def test_empty_portfolio(self):
        """
        The ``num_positions`` column should drop back to zero once the
        algorithm exits its only position.
        """
        algo = EmptyPositionsAlgorithm(sim_params=self.sim_params,
                                       env=self.env)
        daily_stats = algo.run(self.df)

        expected_position_count = [
            0,  # Before entering the first position
            1,  # After entering, exiting on this date
            0,  # After exiting
            0,
        ]

        for i, expected in enumerate(expected_position_count):
            # Select the i-th row by position. ``.ix`` is deprecated and was
            # removed in pandas 1.0; ``.iloc`` makes the positional lookup
            # explicit.
            self.assertEqual(daily_stats.iloc[i]['num_positions'],
                             expected)

    def test_noop_orders(self):
        """
        Orders that never fill must not leave any positions behind.
        """
        algo = AmbitiousStopLimitAlgorithm(sid=1,
                                           sim_params=self.sim_params,
                                           env=self.env)
        daily_stats = algo.run(self.source)

        # Verify that positions are empty for all dates.
        empty_positions = daily_stats.positions.map(lambda x: len(x) == 0)
        self.assertTrue(empty_positions.all())
907
908
class TestAlgoScript(TestCase):
    """
    Run TradingAlgorithm instances built from function pairs and from
    script strings, checking API calls, slippage/commission models,
    ``record`` and the order methods.
    """

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(
            equities_identifiers=[0, 1, 133]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        days = 251
        # The class-level environment is passed in explicitly.
        # NOTE(review): an earlier comment here claimed that
        # create_simulation_parameters creates a new TradingEnvironment;
        # that looks stale now that ``env`` is passed -- confirm.
        self.sim_params = factory.create_simulation_parameters(num_days=days,
                                                               env=self.env)

        setup_logger(self)
        trade_history = factory.create_trade_history(
            133,
            [10.0] * days,
            [100] * days,
            timedelta(days=1),
            self.sim_params,
            self.env
        )

        self.source = SpecificEquityTrades(
            sids=[133],
            event_list=trade_history,
            env=self.env,
        )

        self.df_source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

        # Keyword arguments for simfactory.create_test_zipline; individual
        # tests add 'algorithm', 'trade_count' and any expectations.
        self.zipline_test_config = {
            'sid': 0,
        }

    def tearDown(self):
        teardown_logger(self)

    def _drain_script_algo(self, script, trade_count):
        """
        Build a TradingAlgorithm from ``script``, run it through a test
        zipline configured with ``trade_count`` and return the drained
        output list.

        Shared by the tests below that only need the raw per-bar output.
        """
        test_algo = TradingAlgorithm(
            script=script,
            sim_params=self.sim_params,
            env=self.env,
        )
        set_algo_instance(test_algo)

        self.zipline_test_config['algorithm'] = test_algo
        self.zipline_test_config['trade_count'] = trade_count

        zipline = simfactory.create_test_zipline(
            **self.zipline_test_config)
        output, _ = drain_zipline(self, zipline)
        return output

    def test_noop(self):
        algo = TradingAlgorithm(initialize=initialize_noop,
                                handle_data=handle_data_noop)
        algo.run(self.df)

    def test_noop_string(self):
        algo = TradingAlgorithm(script=noop_algo)
        algo.run(self.df)

    def test_api_calls(self):
        algo = TradingAlgorithm(initialize=initialize_api,
                                handle_data=handle_data_api)
        algo.run(self.df)

    def test_api_calls_string(self):
        algo = TradingAlgorithm(script=api_algo)
        algo.run(self.df)

    def test_api_get_environment(self):
        platform = 'zipline'
        # Use sid not already in test database.
        metadata = {3: {'symbol': 'TEST'}}
        algo = TradingAlgorithm(script=api_get_environment_algo,
                                equities_metadata=metadata,
                                platform=platform)
        algo.run(self.df)
        self.assertEqual(algo.environment, platform)

    def test_api_symbol(self):
        # Use sid not already in test database.
        metadata = {3: {'symbol': 'TEST'}}
        algo = TradingAlgorithm(script=api_symbol_algo,
                                equities_metadata=metadata)
        algo.run(self.df)

    def test_fixed_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import (slippage,
                         commission,
                         set_slippage,
                         set_commission,
                         order,
                         record,
                         sid)

def initialize(context):
    model = slippage.FixedSlippage(spread=0.10)
    set_slippage(model)
    set_commission(commission.PerTrade(100.00))
    context.count = 1
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        order(sid(0), -1000)
    record(price=data[0].price)

    context.incr += 1""",
            sim_params=self.sim_params,
            env=self.env,
        )
        set_algo_instance(test_algo)

        self.zipline_test_config['algorithm'] = test_algo
        self.zipline_test_config['trade_count'] = 200

        # this matches the value in the algotext initialize
        # method, and will be used inside assert_single_position
        # to confirm we have as many transactions as orders we
        # placed.
        self.zipline_test_config['order_count'] = 1

        zipline = simfactory.create_test_zipline(
            **self.zipline_test_config)

        output, _ = assert_single_position(self, zipline)

        # confirm the slippage and commission on a sample
        # transaction
        recorded_price = output[1]['daily_perf']['recorded_vars']['price']
        transaction = output[1]['daily_perf']['transactions'][0]
        self.assertEqual(100.0, transaction['commission'])
        expected_spread = 0.05
        expected_commish = 0.10
        expected_price = recorded_price - expected_spread - expected_commish
        self.assertEqual(expected_price, transaction['price'])

    def test_volshare_slippage(self):
        # verify order -> transaction -> portfolio position.
        # --------------
        test_algo = TradingAlgorithm(
            script="""
from zipline.api import *

def initialize(context):
    model = slippage.VolumeShareSlippage(
                            volume_limit=.3,
                            price_impact=0.05
                       )
    set_slippage(model)
    set_commission(commission.PerShare(0.02))
    context.count = 2
    context.incr = 0

def handle_data(context, data):
    if context.incr < context.count:
        # order small lots to be sure the
        # order will fill in a single transaction
        order(sid(0), 5000)
    record(price=data[0].price)
    record(volume=data[0].volume)
    record(incr=context.incr)
    context.incr += 1
    """,
            sim_params=self.sim_params,
            env=self.env,
        )
        set_algo_instance(test_algo)

        self.zipline_test_config['algorithm'] = test_algo
        self.zipline_test_config['trade_count'] = 100

        # 67 will be used inside assert_single_position
        # to confirm we have as many transactions as expected.
        # The algo places 2 trades of 5000 shares each. The trade
        # events have volume ranging from 100 to 950. The volume cap
        # of 0.3 limits the trade volume to a range of 30 - 316 shares.
        # The spreadsheet linked below calculates the total position
        # size over each bar, and predicts 67 txns will be required
        # to fill the two orders. The number of bars and transactions
        # differ because some bars result in multiple txns. See
        # spreadsheet for details:
# https://www.dropbox.com/s/ulrk2qt0nrtrigb/Volume%20Share%20Worksheet.xlsx
        self.zipline_test_config['expected_transactions'] = 67

        zipline = simfactory.create_test_zipline(
            **self.zipline_test_config)
        output, _ = assert_single_position(self, zipline)

        # confirm the slippage and commission on a sample
        # transaction
        per_share_commish = 0.02
        perf = output[1]
        transaction = perf['daily_perf']['transactions'][0]
        commish = transaction['amount'] * per_share_commish
        self.assertEqual(commish, transaction['commission'])
        self.assertEqual(2.029, transaction['price'])

    def test_algo_record_vars(self):
        output = self._drain_script_algo(record_variables, trade_count=200)
        self.assertEqual(len(output), 252)

        incr = [o['daily_perf']['recorded_vars']['incr']
                for o in output[:200]]
        np.testing.assert_array_equal(incr, range(1, 201))

    def test_algo_record_allow_mock(self):
        """
        Test that values from "MagicMock"ed methods can be passed to record.

        Relevant for our basic/validation and methods like history, which
        will end up returning a MagicMock instead of a DataFrame.
        """
        test_algo = TradingAlgorithm(
            script=record_variables,
            sim_params=self.sim_params,
        )
        set_algo_instance(test_algo)

        test_algo.record(foo=MagicMock())

    def _algo_record_float_magic_should_pass(self, var_type):
        """
        Record the float value named by ``var_type`` (e.g. 'nan') on every
        bar and check the first 200 recorded values are all NaN.
        """
        output = self._drain_script_algo(record_float_magic % var_type,
                                         trade_count=200)
        self.assertEqual(len(output), 252)

        recorded = [o['daily_perf']['recorded_vars']['data']
                    for o in output[:200]]
        # assert_array_equal treats NaNs in matching positions as equal.
        np.testing.assert_array_equal(recorded, [np.nan] * 200)

    def test_algo_record_nan(self):
        self._algo_record_float_magic_should_pass('nan')

    def test_order_methods(self):
        """
        Only test that order methods can be called without error.
        Correct filling of orders is tested in zipline.
        """
        self._drain_script_algo(call_all_order_methods, trade_count=200)

    def test_order_in_init(self):
        """
        Test that calling order in initialize
        will raise an error.
        """
        with self.assertRaises(OrderDuringInitialize):
            test_algo = TradingAlgorithm(
                script=call_order_in_init,
                sim_params=self.sim_params,
                env=self.env,
            )
            set_algo_instance(test_algo)
            test_algo.run(self.source)

    def test_portfolio_in_init(self):
        """
        Test that accessing portfolio in init doesn't break.
        """
        self._drain_script_algo(access_portfolio_in_init, trade_count=1)

    def test_account_in_init(self):
        """
        Test that accessing account in init doesn't break.
        """
        self._drain_script_algo(access_account_in_init, trade_count=1)
1235
1236
1237
class TestHistory(TestCase):
    """
    Exercise the ``history`` / ``add_history`` API on minute-frequency runs.
    """

    @classmethod
    def setUpClass(cls):
        cls._start = pd.Timestamp('1991-01-01', tz='UTC')
        cls._end = pd.Timestamp('1991-01-15', tz='UTC')
        cls.env = TradingEnvironment()
        cls.sim_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=cls.env,
        )
        cls.env.write_data(equities_identifiers=[0, 1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)

    def tearDown(self):
        teardown_logger(self)

    @property
    def source(self):
        # A new random-walk source over the class date range on each access.
        return RandomWalkSource(start=self._start, end=self._end)

    def _assert_one_bar_window_after(self, handle_data):
        # Run ``handle_data`` with a no-op initialize and check that a
        # history container with a one-bar window was created on demand.
        algo = TradingAlgorithm(
            initialize=lambda _: None,
            handle_data=handle_data,
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(self.source)

        self.assertIsNotNone(algo.history_container)
        self.assertEqual(algo.history_container.buffer_panel.window_length, 1)

    def test_history(self):
        history_algo = """
from zipline.api import history, add_history

def initialize(context):
    add_history(10, '1d', 'price')

def handle_data(context, data):
    df = history(10, '1d', 'price')
"""

        script_algo = TradingAlgorithm(
            script=history_algo,
            sim_params=self.sim_params,
            env=self.env,
        )
        self.assertIsNot(script_algo.run(self.source), None)

    def test_history_without_add(self):
        # Calling history() without a prior add_history() must still work.
        def handle_data(algo, data):
            algo.history(1, '1m', 'price')

        self._assert_one_bar_window_after(handle_data)

    def test_add_history_in_handle_data(self):
        # add_history() is allowed from handle_data, not just initialize.
        def handle_data(algo, data):
            algo.add_history(1, '1m', 'price')

        self._assert_one_bar_window_after(handle_data)
1312
1313
1314
class TestGetDatetime(TestCase):
    """
    ``get_datetime`` must honour the requested timezone on minute bars.
    """

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)

    def tearDown(self):
        teardown_logger(self)

    @parameterized.expand([
        ('default', None),
        ('utc', 'UTC'),
        ('us_east', 'US/Eastern'),
    ])
    def test_get_datetime(self, name, tz):
        # The generated script raises on the first bar unless the timestamp
        # returned by get_datetime(tz) is in the expected zone ('UTC' when
        # tz is None) and falls at 9:31 US/Eastern.
        script_text = dedent(
            """
            import pandas as pd
            from zipline.api import get_datetime

            def initialize(context):
                context.tz = {tz} or 'UTC'
                context.first_bar = True

            def handle_data(context, data):
                if context.first_bar:
                    dt = get_datetime({tz})
                    if dt.tz.zone != context.tz:
                        raise ValueError("Mismatched Zone")
                    elif dt.tz_convert("US/Eastern").hour != 9:
                        raise ValueError("Mismatched Hour")
                    elif dt.tz_convert("US/Eastern").minute != 31:
                        raise ValueError("Mismatched Minute")
                context.first_bar = False
            """.format(tz=repr(tz))
        )

        minute_source = RandomWalkSource(
            start=to_utc('2014-01-02 9:31'),
            end=to_utc('2014-01-03 9:31'),
        )
        minute_params = factory.create_simulation_parameters(
            data_frequency='minute',
            env=self.env,
        )

        script_algo = TradingAlgorithm(
            script=script_text,
            sim_params=minute_params,
            env=self.env,
        )
        script_algo.run(minute_source)

        # first_bar is flipped by handle_data, proving at least one bar ran.
        self.assertFalse(script_algo.first_bar)
1379
1380
1381
class TestTradingControls(TestCase):
1382
1383
    @classmethod
    def setUpClass(cls):
        # One equity, shared by every trading-control test in this class.
        cls.sid = 133
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[cls.sid])
1388
1389
    @classmethod
    def tearDownClass(cls):
        # Drop the class-level environment built in setUpClass.
        del cls.env
1392
1393
    def setUp(self):
        # Four daily bars for self.sid; individual tests may replace
        # self.source with a different schedule.
        self.sim_params = factory.create_simulation_parameters(
            num_days=4,
            env=self.env,
        )

        prices = [10.0, 10.0, 11.0, 11.0]
        volumes = [100, 100, 100, 300]
        self.trade_history = factory.create_trade_history(
            self.sid,
            prices,
            volumes,
            timedelta(days=1),
            self.sim_params,
            self.env,
        )

        self.source = SpecificEquityTrades(
            event_list=self.trade_history,
            env=self.env,
        )
1409
1410
    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_order_count,
                    expected_exc):
        # Install the test's handle_data and run over self.source. When
        # expected_exc is given the run must raise it. Afterwards
        # algo.order_count must equal expected_order_count. The source is
        # rewound so the next check can replay the same events.
        algo._handle_data = handle_data

        if expected_exc:
            run_guard = self.assertRaises(expected_exc)
        else:
            run_guard = nullctx()

        with run_guard:
            algo.run(self.source)

        self.assertEqual(algo.order_count, expected_order_count)
        self.source.rewind()
1421
1422
    def check_algo_succeeds(self, algo, handle_data, order_count=4):
        # Expect a clean run. The default of 4 matches one order per
        # handle_data call over the four bars built in setUp.
        self._check_algo(
            algo,
            handle_data,
            order_count,
            None,
        )
1425
1426
    def check_algo_fails(self, algo, handle_data, order_count):
        # Expect the run to raise TradingControlViolation, with
        # algo.order_count equal to ``order_count`` at that point.
        expected_failure = TradingControlViolation
        self._check_algo(algo, handle_data, order_count, expected_failure)
1431
1432
    def test_set_max_position_size(self):
        """
        Exercise set_max_position_size: share and notional limits on the
        traded sid, a control registered for a different sid, and a control
        with no sid (which applies to every sid).
        """

        def buy(amount):
            # Build a handle_data that orders ``amount`` shares of self.sid on
            # every bar and counts each attempt in algo.order_count.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amount)
                algo.order_count += 1
            return handle_data

        # Buy one share four times.  Should be fine.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, buy(1))

        # Buy three shares four times.  Should bail on the fourth before it's
        # placed.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=500.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, buy(3), 3)

        # Buy three shares four times. Should bail due to max_notional on the
        # third attempt.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
                                           max_shares=10,
                                           max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, buy(3), 2)

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
                                           max_shares=10,
                                           max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_succeeds(algo, buy(10000))

        # Set the trading control sid to None, then BUY ALL THE THINGS!. Should
        # fail because setting sid to None makes the control apply to all sids.
        algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0,
                                           sim_params=self.sim_params,
                                           env=self.env)
        self.check_algo_fails(algo, buy(10000), 0)
1492
1493
    def test_set_do_not_order_list(self):
        # One handle_data drives both cases: order 100 shares of self.sid
        # every bar and count each attempt.
        def handle_data(algo, data):
            algo.order(algo.sid(self.sid), 100)
            algo.order_count += 1

        # Restricting the traded sid must stop the very first order.
        restricted_algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[self.sid],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_fails(restricted_algo, handle_data, 0)

        # A restricted list that excludes the traded sid lets every order
        # through.
        unrestricted_algo = SetDoNotOrderListAlgorithm(
            sid=self.sid,
            restricted_list=[134, 135, 136],
            sim_params=self.sim_params,
            env=self.env,
        )
        self.check_algo_succeeds(unrestricted_algo, handle_data)
1521
1522
    def test_set_max_order_size(self):
        # Factories for the two order patterns used below; both count each
        # attempt in algo.order_count.
        def fixed_size(amount):
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amount)
                algo.order_count += 1
            return handle_data

        def growing_size(algo, data):
            # Order 1, then 2, then 3, then 4 shares.
            algo.order(algo.sid(self.sid), algo.order_count + 1)
            algo.order_count += 1

        def make_algo(**control_kwargs):
            return SetMaxOrderSizeAlgorithm(sim_params=self.sim_params,
                                            env=self.env,
                                            **control_kwargs)

        # Buy one share.
        self.check_algo_succeeds(
            make_algo(sid=self.sid, max_shares=10, max_notional=500.0),
            fixed_size(1),
        )

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed shares.
        self.check_algo_fails(
            make_algo(sid=self.sid, max_shares=3, max_notional=500.0),
            growing_size,
            3,
        )

        # Buy 1, then 2, then 3, then 4 shares.  Bail on the last attempt
        # because we exceed notional.
        self.check_algo_fails(
            make_algo(sid=self.sid, max_shares=10, max_notional=40.0),
            growing_size,
            3,
        )

        # Set the trading control to a different sid, then BUY ALL THE THINGS!.
        # Should continue normally.
        self.check_algo_succeeds(
            make_algo(sid=self.sid + 1, max_shares=1, max_notional=1.0),
            fixed_size(10000),
        )

        # Set the trading control sid to None, then BUY ALL THE THINGS!.
        # Should fail because not specifying a sid makes the trading control
        # apply to all sids.
        self.check_algo_fails(
            make_algo(max_shares=1, max_notional=1.0),
            fixed_size(10000),
            0,
        )
1584
1585
    def test_set_max_order_count(self):
        """
        Exercise set_max_order_count, whose limit applies per trading day
        rather than across the whole simulation.
        """

        # Override the default setUp to use six-hour intervals instead of full
        # days so we can exercise trading-session rollover logic.
        trade_history = factory.create_trade_history(
            self.sid,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(hours=6),
            self.sim_params,
            self.env
        )
        self.source = SpecificEquityTrades(event_list=trade_history,
                                           env=self.env)

        def handle_data(algo, data):
            # Place five one-share orders per call, counting each attempt.
            for _ in range(5):
                algo.order(algo.sid(self.sid), 1)
                algo.order_count += 1

        # A limit of 3 is hit during the first call: the fourth order is
        # rejected.
        algo = SetMaxOrderCountAlgorithm(3, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_fails(algo, handle_data, 3)

        # Second call to handle_data is the same day as the first, so the last
        # order of the second call should fail.
        algo = SetMaxOrderCountAlgorithm(9, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_fails(algo, handle_data, 9)

        # Only ten orders are placed per day (two calls of five), so a limit
        # of 10 should pass even though 20 orders are placed in total.
        algo = SetMaxOrderCountAlgorithm(10, sim_params=self.sim_params,
                                         env=self.env)
        self.check_algo_succeeds(algo, handle_data, order_count=20)
1620
1621
    def test_long_only(self):
        def make_algo():
            return SetLongOnlyAlgorithm(sim_params=self.sim_params,
                                        env=self.env)

        def scripted(amounts):
            # Order amounts[n] shares on the n-th call.
            def handle_data(algo, data):
                algo.order(algo.sid(self.sid), amounts[algo.order_count])
                algo.order_count += 1
            return handle_data

        # Sell immediately -> fail immediately.
        def sell_one(algo, data):
            algo.order(algo.sid(self.sid), -1)
            algo.order_count += 1

        self.check_algo_fails(make_algo(), sell_one, 0)

        # Buy on even days, sell on odd days.  Never takes a short position, so
        # should succeed.
        def alternate(algo, data):
            is_even_call = (algo.order_count % 2) == 0
            algo.order(algo.sid(self.sid), 1 if is_even_call else -1)
            algo.order_count += 1

        self.check_algo_succeeds(make_algo(), alternate)

        # Buy on first three days, then sell off holdings.  Should succeed.
        self.check_algo_succeeds(make_algo(), scripted([1, 1, 1, -3]))

        # Buy on first three days, then sell off holdings plus an extra share.
        # Should fail on the last sale.
        self.check_algo_fails(make_algo(), scripted([1, 1, 1, -4]), 3)
1656
1657
    def test_register_post_init(self):
        """Registering a trading control from handle_data must raise.

        Every ``set_*`` trading-control call made after initialization is
        expected to raise ``RegisterTradingControlPostInit``.
        """
        # The assertions below only execute inside handle_data.  Record each
        # call so the test cannot pass vacuously if handle_data never runs.
        # A list is used instead of a counter because the closure must
        # mutate it (the file stays Python 2 compatible, so no ``nonlocal``).
        handle_data_calls = []

        def initialize(algo):
            algo.initialized = True

        def handle_data(algo, data):
            handle_data_calls.append(1)

            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_position_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_size(self.sid, 1, 1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_max_order_count(1)
            with self.assertRaises(RegisterTradingControlPostInit):
                algo.set_long_only()

        algo = TradingAlgorithm(initialize=initialize,
                                handle_data=handle_data,
                                sim_params=self.sim_params,
                                env=self.env)
        algo.run(self.source)
        self.source.rewind()

        self.assertTrue(handle_data_calls,
                        "handle_data was never called during the run")
1679
1680
    def test_asset_date_bounds(self):
        """Orders are only allowed while sid 0 is between its start/end dates.

        Each scenario builds its own TradingEnvironment and test data source,
        then runs SetAssetDateBoundsAlgorithm with sid 0's lifetime set to
        the given bounds.
        """
        def make_algo(start_date, end_date):
            # Build a fresh environment, source and algorithm for one
            # scenario; the caller decides whether running it should raise.
            temp_env = TradingEnvironment()
            df_source, _ = factory.create_test_df_source(self.sim_params,
                                                         temp_env)
            metadata = {0: {'start_date': start_date,
                            'end_date': end_date}}
            algo = SetAssetDateBoundsAlgorithm(
                equities_metadata=metadata,
                sim_params=self.sim_params,
                env=temp_env,
            )
            return algo, df_source

        # A sid that ends far in the future: no error.
        algo, df_source = make_algo('1990-01-01', '2020-01-01')
        algo.run(df_source)

        # A sid that has already ended, then one that has not started yet:
        # both must trip the trading control.
        for start_date, end_date in [('1989-01-01', '1990-01-01'),
                                     ('2020-01-01', '2021-01-01')]:
            algo, df_source = make_algo(start_date, end_date)
            with self.assertRaises(TradingControlViolation):
                algo.run(df_source)

        # A sid that starts on the first day and ends on the last day of the
        # algorithm's parameters (*not* an error).
        algo, df_source = make_algo('2006-01-03', '2006-01-06')
        algo.run(df_source)
1732
1733
1734
class TestAccountControls(TestCase):
    """Account-level controls (here: max leverage) checked during a run."""

    @classmethod
    def setUpClass(cls):
        cls.sidint = 133
        cls.env = TradingEnvironment()
        cls.env.write_data(
            equities_identifiers=[cls.sidint]
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        # Four days of daily trades for the single sid written in setUpClass.
        self.sim_params = factory.create_simulation_parameters(
            num_days=4, env=self.env
        )
        self.trade_history = factory.create_trade_history(
            self.sidint,
            [10.0, 10.0, 11.0, 11.0],
            [100, 100, 100, 300],
            timedelta(days=1),
            self.sim_params,
            self.env,
        )

        self.source = SpecificEquityTrades(
            event_list=self.trade_history,
            env=self.env,
        )

    def _check_algo(self,
                    algo,
                    handle_data,
                    expected_exc):
        """Run ``algo`` over ``self.source`` with ``handle_data`` installed.

        If ``expected_exc`` is truthy the run must raise it; otherwise the
        run must complete without raising.  The source is rewound after the
        run so another check in the same test can replay it.
        """
        algo._handle_data = handle_data
        with self.assertRaises(expected_exc) if expected_exc else nullctx():
            algo.run(self.source)
        self.source.rewind()

    def check_algo_succeeds(self, algo, handle_data):
        """Assert that running ``algo`` raises no exception."""
        self._check_algo(algo, handle_data, None)

    def check_algo_fails(self, algo, handle_data):
        """Assert that running ``algo`` raises AccountControlViolation."""
        self._check_algo(algo,
                         handle_data,
                         AccountControlViolation)

    def test_set_max_leverage(self):
        def handle_data(algo, data):
            # Buy a single share on every call.
            algo.order(algo.sid(self.sidint), 1)

        # Max leverage 0: buying one share fails.
        algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_fails(algo, handle_data)

        # Max leverage 1: buying one share passes.
        algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params,
                                       env=self.env)
        self.check_algo_succeeds(algo, handle_data)
1802
1803
1804
class TestClosePosAlgo(TestCase):
    """Positions are closed by CLOSE_POSITION events or an auto_close_date."""

    def setUp(self):
        self.env = TradingEnvironment()
        self.days = self.env.trading_days[:4]
        # Three trade bars followed by a zero-volume CLOSE_POSITION event.
        self.panel = pd.Panel({1: pd.DataFrame({
            'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 0],
            'type': [DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.CLOSE_POSITION]},
            index=self.days)
        })
        # Same prices but every bar is an ordinary trade.  Used by
        # test_auto_close_future, where closing is driven by the asset's
        # auto_close_date metadata instead.
        self.no_close_panel = pd.Panel({1: pd.DataFrame({
            'price': [1, 1, 2, 4], 'volume': [1e9, 1e9, 1e9, 1e9],
            'type': [DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE]},
            index=self.days)
        })

    def test_close_position_equity(self):
        metadata = {1: {'symbol': 'TEST',
                        'end_date': self.days[3]}}
        self.env.write_data(equities_data=metadata)
        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env)
        data = DataPanelSource(self.panel)

        # Check results
        expected_positions = [0, 1, 1, 0]
        expected_pnl = [0, 0, 1, 2]
        results = algo.run(data)
        self.check_algo_positions(results, expected_positions)
        self.check_algo_pnl(results, expected_pnl)

    def test_close_position_future(self):
        metadata = {1: {'symbol': 'TEST'}}
        self.env.write_data(futures_data=metadata)
        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env)
        data = DataPanelSource(self.panel)

        # Check results
        expected_positions = [0, 1, 1, 0]
        expected_pnl = [0, 0, 1, 2]
        results = algo.run(data)
        self.check_algo_pnl(results, expected_pnl)
        self.check_algo_positions(results, expected_positions)

    def test_auto_close_future(self):
        metadata = {1: {'symbol': 'TEST',
                        'auto_close_date': self.env.trading_days[4]}}
        self.env.write_data(futures_data=metadata)
        algo = TestAlgorithm(sid=1, amount=1, order_count=1,
                             commission=PerShare(0),
                             env=self.env)
        data = DataPanelSource(self.no_close_panel)

        # Check results
        results = algo.run(data)

        expected_positions = [0, 1, 1, 0]
        self.check_algo_positions(results, expected_positions)

        expected_pnl = [0, 0, 1, 2]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the per-day pnl column matches ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the per-day position amount matches ``expected_positions``.

        Each entry of ``results.positions`` is either empty (no position,
        treated as 0) or a sequence whose first element carries an
        ``'amount'`` key.  Only the first element is inspected, since every
        test here trades the single sid 1.  The number of days must match
        exactly, so extra or missing days are reported instead of raising
        IndexError or being silently skipped.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "expected positions for {0} days, got {1}".format(
                len(expected_positions), len(results.positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))
1889
1890
1891
class TestFutureFlip(TestCase):
    """A futures position that goes long, flips short, then ends flat."""

    def setUp(self):
        self.env = TradingEnvironment()
        self.days = self.env.trading_days[:4]
        # Trades on the first three days only.  The fourth day is used as the
        # contract's end_date in test_flip_algo.
        self.trades_panel = pd.Panel({1: pd.DataFrame({
            'price': [1, 2, 4], 'volume': [1e9, 1e9, 1e9],
            'type': [DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE,
                     DATASOURCE_TYPE.TRADE]},
            index=self.days[:3])
        })

    def test_flip_algo(self):
        metadata = {1: {'symbol': 'TEST',
                        'end_date': self.days[3],
                        'contract_multiplier': 5}}
        self.env.write_data(futures_data=metadata)

        algo = FutureFlipAlgo(sid=1, amount=1, env=self.env,
                              commission=PerShare(0),
                              order_count=0,  # not applicable but required
                              instant_fill=True)
        data = DataPanelSource(self.trades_panel)

        results = algo.run(data)

        expected_positions = [1, -1, 0]
        self.check_algo_positions(results, expected_positions)

        # Consistent with contract_multiplier=5: +1 * (2 - 1) * 5 on day two
        # and -1 * (4 - 2) * 5 on day three.
        expected_pnl = [0, 5, -10]
        self.check_algo_pnl(results, expected_pnl)

    def check_algo_pnl(self, results, expected_pnl):
        """Assert the per-day pnl column matches ``expected_pnl``."""
        np.testing.assert_array_almost_equal(results.pnl, expected_pnl)

    def check_algo_positions(self, results, expected_positions):
        """Assert the per-day position amount matches ``expected_positions``.

        Each entry of ``results.positions`` is either empty (no position,
        treated as 0) or a sequence whose first element carries an
        ``'amount'`` key.  Only the first element is inspected, since the
        test trades the single sid 1.  The number of days must match
        exactly, so extra or missing days are reported instead of raising
        IndexError or being silently skipped.
        """
        self.assertEqual(
            len(results.positions), len(expected_positions),
            "expected positions for {0} days, got {1}".format(
                len(expected_positions), len(results.positions)))

        for i, (day_positions, expected) in enumerate(
                zip(results.positions, expected_positions)):
            if day_positions:
                actual_position = day_positions[0]['amount']
            else:
                actual_position = 0

            self.assertEqual(
                actual_position, expected,
                "position for day={0} not equal, actual={1}, expected={2}".
                format(i, actual_position, expected))
1937
1938
1939
class TestTradingAlgorithm(TestCase):
1940
    def setUp(self):
        # Four bars for sid 1: three trades followed by a zero-volume
        # CLOSE_POSITION event.
        self.env = TradingEnvironment()
        self.days = self.env.trading_days[:4]
        trade = DATASOURCE_TYPE.TRADE
        sid_frame = pd.DataFrame(
            {
                'price': [1, 1, 2, 4],
                'volume': [1e9, 1e9, 1e9, 0],
                'type': [trade] * 3 + [DATASOURCE_TYPE.CLOSE_POSITION],
            },
            index=self.days,
        )
        self.panel = pd.Panel({1: sid_frame})
1951
1952
    def test_analyze_called(self):
        # analyze() stores whatever it receives on the test case, so it can
        # be compared by identity with the value returned from run().
        self.perf_ref = None

        def analyze(context, perf):
            self.perf_ref = perf

        def initialize(context):
            pass

        def handle_data(context, data):
            pass

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            analyze=analyze,
        )
        perf = algo.run(self.panel)
        self.assertIs(perf, self.perf_ref)
1968