Completed
Pull Request — master (#867)
by Joe, created 02:09

tests.VersioningTestCase (Rating: A)

Complexity
    Total Complexity: 13

Size / Duplication
    Total Lines: 63
    Duplicated Lines: 0%

    Metric    Value
    dl        0
    loc       63
    rs        10
    wmc       13

2 Methods

    Rating    Name                           Duplication    Size    Complexity
    F         test_object_serialization()    0              48      9
    A         load_state_from_disk()         0              11      4
#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import TestCase
from itertools import product
from textwrap import dedent
import warnings

from nose_parameterized import parameterized
import numpy as np
import pandas as pd
from pandas.util.testing import assert_frame_equal
from pandas.tseries.tools import normalize_date

from .history_cases import (
    HISTORY_CONTAINER_TEST_CASES,
)
from zipline import TradingAlgorithm
from zipline.errors import HistoryInInitialize, IncompatibleHistoryFrequency
from zipline.finance import trading
from zipline.finance.trading import (
    SimulationParameters,
    TradingEnvironment,
)
from zipline.history import history
from zipline.history.history_container import HistoryContainer
from zipline.protocol import BarData
from zipline.sources import RandomWalkSource, DataFrameSource
import zipline.utils.factory as factory
from zipline.utils.test_utils import subtest

# Cases are over the July 4th holiday, to ensure use of trading calendar.

#      March 2013
# Su Mo Tu We Th Fr Sa
#                 1  2
#  3  4  5  6  7  8  9
# 10 11 12 13 14 15 16
# 17 18 19 20 21 22 23
# 24 25 26 27 28 29 30
# 31
#      April 2013
# Su Mo Tu We Th Fr Sa
#     1  2  3  4  5  6
#  7  8  9 10 11 12 13
# 14 15 16 17 18 19 20
# 21 22 23 24 25 26 27
# 28 29 30
#
#       May 2013
# Su Mo Tu We Th Fr Sa
#           1  2  3  4
#  5  6  7  8  9 10 11
# 12 13 14 15 16 17 18
# 19 20 21 22 23 24 25
# 26 27 28 29 30 31
#
#      June 2013
# Su Mo Tu We Th Fr Sa
#                    1
#  2  3  4  5  6  7  8
#  9 10 11 12 13 14 15
# 16 17 18 19 20 21 22
# 23 24 25 26 27 28 29
# 30
#      July 2013
# Su Mo Tu We Th Fr Sa
#     1  2  3  4  5  6
#  7  8  9 10 11 12 13
# 14 15 16 17 18 19 20
# 21 22 23 24 25 26 27
# 28 29 30 31
#
# Times to be converted via:
# pd.Timestamp('2013-07-05 9:31', tz='US/Eastern').tz_convert('UTC')},

INDEX_TEST_CASES_RAW = {
    'week of daily data': {
        'input': {'bar_count': 5,
                  'frequency': '1d',
                  'algo_dt': '2013-07-05 9:31AM'},
        'expected': [
            '2013-06-28 4:00PM',
            '2013-07-01 4:00PM',
            '2013-07-02 4:00PM',
            '2013-07-03 1:00PM',
            '2013-07-05 9:31AM',
        ]
    },
    'five minutes on july 5th open': {
        'input': {'bar_count': 5,
                  'frequency': '1m',
                  'algo_dt': '2013-07-05 9:31AM'},
        'expected': [
            '2013-07-03 12:57PM',
            '2013-07-03 12:58PM',
            '2013-07-03 12:59PM',
            '2013-07-03 1:00PM',
            '2013-07-05 9:31AM',
        ]
    },
}


def to_timestamp(dt_str):
    return pd.Timestamp(dt_str, tz='US/Eastern').tz_convert('UTC')


def convert_cases(cases):
    """
    Convert raw strings to values comparable with system data.
    """
    cases = cases.copy()
    for case in cases.values():
        case['input']['algo_dt'] = to_timestamp(case['input']['algo_dt'])
        case['expected'] = pd.DatetimeIndex([to_timestamp(dt_str) for dt_str
                                             in case['expected']])
    return cases

INDEX_TEST_CASES = convert_cases(INDEX_TEST_CASES_RAW)


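# Build a HistorySpec from the raw case input and return the index that
# history.index_at_dt() produces for the case's algo_dt.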
def get_index_at_dt(case_input, env):
    history_spec = history.HistorySpec(
        case_input['bar_count'],
        case_input['frequency'],
        None,
        False,
        env=env,
        data_frequency='minute',
    )
    return history.index_at_dt(history_spec, case_input['algo_dt'], env=env)


class TestHistoryIndex(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.environment = TradingEnvironment()

    @classmethod
    def tearDownClass(cls):
        del cls.environment

    @parameterized.expand(
        [(name, case['input'], case['expected'])
         for name, case in INDEX_TEST_CASES.items()]
    )
    def test_index_at_dt(self, name, case_input, expected):
        history_index = get_index_at_dt(case_input, self.environment)

        history_series = pd.Series(index=history_index)
        expected_series = pd.Series(index=expected)

        pd.util.testing.assert_series_equal(history_series, expected_series)


class TestHistoryContainer(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def bar_data_dt(self, bar_data, require_unique=True):
        """
        Get a dt to associate with the given BarData object.

        If require_unique == True, throw an error if multiple unique dt's are
        encountered.  Otherwise, return the earliest dt encountered.
        """
        dts = {sid_data['dt'] for sid_data in bar_data.values()}
        if require_unique and len(dts) > 1:
            self.fail("Multiple unique dts ({0}) in {1}".format(dts, bar_data))

        return sorted(dts)[0]

    @parameterized.expand(
        [(name,
          case['specs'],
          case['sids'],
          case['dt'],
          case['updates'],
          case['expected'])
         for name, case in HISTORY_CONTAINER_TEST_CASES.items()]
    )
    def test_history_container(self,
                               name,
                               specs,
                               sids,
                               dt,
                               updates,
                               expected):

        for spec in specs:
            # Sanity check on test input.
            self.assertEqual(len(expected[spec.key_str]), len(updates))

        container = HistoryContainer(
            {spec.key_str: spec for spec in specs}, sids, dt, 'minute',
            env=self.env,
        )

        for update_count, update in enumerate(updates):

            bar_dt = self.bar_data_dt(update)
            container.update(update, bar_dt)

            for spec in specs:
                pd.util.testing.assert_frame_equal(
                    container.get_history(spec, bar_dt),
                    expected[spec.key_str][update_count],
                    check_dtype=False,
                    check_column_type=True,
                    check_index_type=True,
                    check_frame_type=True,
                )

    def test_multiple_specs_on_same_bar(self):
        """
        Test that a ffill and a non-ffill spec both get
        the correct results when called on the same tick.
        """
        spec = history.HistorySpec(
            bar_count=3,
            frequency='1m',
            field='price',
            ffill=True,
            data_frequency='minute',
            env=self.env,
        )
        no_fill_spec = history.HistorySpec(
            bar_count=3,
            frequency='1m',
            field='price',
            ffill=False,
            data_frequency='minute',
            env=self.env,
        )

        specs = {spec.key_str: spec, no_fill_spec.key_str: no_fill_spec}
        initial_sids = [1, ]
        initial_dt = pd.Timestamp(
            '2013-06-28 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container = HistoryContainer(
            specs, initial_sids, initial_dt, 'minute', env=self.env,
        )

        bar_data = BarData()
        container.update(bar_data, initial_dt)
        # Add data on bar two of first day.
        second_bar_dt = pd.Timestamp(
            '2013-06-28 9:32AM', tz='US/Eastern').tz_convert('UTC')
        bar_data[1] = {
            'price': 10,
            'dt': second_bar_dt
        }
        container.update(bar_data, second_bar_dt)

        third_bar_dt = pd.Timestamp(
            '2013-06-28 9:33AM', tz='US/Eastern').tz_convert('UTC')

        del bar_data[1]

        # Update with no data, producing a NaN for the third bar.
        container.update(bar_data, third_bar_dt)
        prices = container.get_history(spec, third_bar_dt)
        no_fill_prices = container.get_history(no_fill_spec, third_bar_dt)
        self.assertEqual(prices.values[-1], 10)
        self.assertTrue(np.isnan(no_fill_prices.values[-1]),
                        "Last price should be np.nan")

    def test_container_nans_and_daily_roll(self):

        spec = history.HistorySpec(
            bar_count=3,
            frequency='1d',
            field='price',
            ffill=True,
            data_frequency='minute',
            env=self.env,
        )
        specs = {spec.key_str: spec}
        initial_sids = [1, ]
        initial_dt = pd.Timestamp(
            '2013-06-28 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container = HistoryContainer(
            specs, initial_sids, initial_dt, 'minute', env=self.env,
        )

        bar_data = BarData()
        container.update(bar_data, initial_dt)
        # There is no backfill (no db) and no data for the first bar,
        # so all values should be NaNs.
        prices = container.get_history(spec, initial_dt)
        nan_values = np.isnan(prices[1])
        self.assertTrue(all(nan_values), nan_values)

        # Add data on bar two of first day.
        second_bar_dt = pd.Timestamp(
            '2013-06-28 9:32AM', tz='US/Eastern').tz_convert('UTC')

        bar_data[1] = {
            'price': 10,
            'dt': second_bar_dt
        }
        container.update(bar_data, second_bar_dt)

        prices = container.get_history(spec, second_bar_dt)
        # Prices should be
        #                             1
        # 2013-06-26 20:00:00+00:00 NaN
        # 2013-06-27 20:00:00+00:00 NaN
        # 2013-06-28 13:32:00+00:00  10

        self.assertTrue(np.isnan(prices[1].ix[0]))
        self.assertTrue(np.isnan(prices[1].ix[1]))
        self.assertEqual(prices[1].ix[2], 10)

        third_bar_dt = pd.Timestamp(
            '2013-06-28 9:33AM', tz='US/Eastern').tz_convert('UTC')

        del bar_data[1]

        container.update(bar_data, third_bar_dt)

        prices = container.get_history(spec, third_bar_dt)
        # The last price should be forward-filled.

        # Prices should be
        #                             1
        # 2013-06-26 20:00:00+00:00 NaN
        # 2013-06-27 20:00:00+00:00 NaN
        # 2013-06-28 13:33:00+00:00  10

        self.assertEquals(prices[1][third_bar_dt], 10)

        # Note that we did not fill in data at the close.
        # There was a bug where a NaN was being introduced because the
        # last value of 'raw' data was used instead of a ffilled close price.

        day_two_first_bar_dt = pd.Timestamp(
            '2013-07-01 9:31AM', tz='US/Eastern').tz_convert('UTC')

        bar_data[1] = {
            'price': 20,
            'dt': day_two_first_bar_dt
        }

        container.update(bar_data, day_two_first_bar_dt)

        prices = container.get_history(spec, day_two_first_bar_dt)

        # Prices should be

        #                              1
        # 2013-06-27 20:00:00+00:00  NaN
        # 2013-06-28 20:00:00+00:00   10
        # 2013-07-01 13:31:00+00:00   20

        self.assertTrue(np.isnan(prices[1].ix[0]))
        self.assertEqual(prices[1].ix[1], 10)
        self.assertEqual(prices[1].ix[2], 20)

        # Clear out the bar data

        del bar_data[1]

        day_three_first_bar_dt = pd.Timestamp(
            '2013-07-02 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container.update(bar_data, day_three_first_bar_dt)

        prices = container.get_history(spec, day_three_first_bar_dt)

        #                             1
        # 2013-06-28 20:00:00+00:00  10
        # 2013-07-01 20:00:00+00:00  20
        # 2013-07-02 13:31:00+00:00  20

        self.assertEqual(prices[1].ix[0], 10)
        self.assertEqual(prices[1].ix[1], 20)
        self.assertEqual(prices[1].ix[2], 20)

        day_four_first_bar_dt = pd.Timestamp(
            '2013-07-03 9:31AM', tz='US/Eastern').tz_convert('UTC')

        container.update(bar_data, day_four_first_bar_dt)

        prices = container.get_history(spec, day_four_first_bar_dt)

        #                             1
        # 2013-07-01 20:00:00+00:00  20
        # 2013-07-02 20:00:00+00:00  20
        # 2013-07-03 13:31:00+00:00  20

        self.assertEqual(prices[1].ix[0], 20)
        self.assertEqual(prices[1].ix[1], 20)
        self.assertEqual(prices[1].ix[2], 20)


class TestHistoryAlgo(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.env = trading.TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0, 1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        np.random.seed(123)

    def test_history_daily(self):
        bar_count = 3
        algo_text = """
from zipline.api import history, add_history

def initialize(context):
    add_history(bar_count={bar_count}, frequency='1d', field='price')
    context.history_trace = []

def handle_data(context, data):
    prices = history(bar_count={bar_count}, frequency='1d', field='price')
    context.history_trace.append(prices)
""".format(bar_count=bar_count).strip()

        #      March 2006
        # Su Mo Tu We Th Fr Sa
        #          1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31

        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-30', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end, data_frequency='daily', env=self.env,
        )

        _, df = factory.create_test_df_source(sim_params, self.env)
        df = df.astype(np.float64)
        source = DataFrameSource(df)

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='daily',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        output = test_algo.run(source)
        self.assertIsNotNone(output)

        history_trace = test_algo.history_trace

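        # Once the window is full (after bar_count - 1 warm-up bars), each
        # recorded frame should match the corresponding bar_count-day slice
        # of the source frame.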
        for i, received in enumerate(history_trace[bar_count - 1:]):
            expected = df.iloc[i:i + bar_count]
            assert_frame_equal(expected, received)

    def test_history_daily_data_1m_window(self):
        algo_text = """
from zipline.api import history, add_history

def initialize(context):
    add_history(bar_count=1, frequency='1m', field='price')

def handle_data(context, data):
    prices = history(bar_count=3, frequency='1d', field='price')
""".strip()

        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-30', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        with self.assertRaises(IncompatibleHistoryFrequency):
            algo = TradingAlgorithm(
                script=algo_text,
                data_frequency='daily',
                sim_params=sim_params,
                env=TestHistoryAlgo.env,
            )
            source = RandomWalkSource(start=start, end=end)
            algo.run(source)

    def test_basic_history(self):
        algo_text = """
from zipline.api import history, add_history

def initialize(context):
    add_history(bar_count=2, frequency='1d', field='price')

def handle_data(context, data):
    prices = history(bar_count=2, frequency='1d', field='price')
    prices['prices_times_two'] = prices[1] * 2
    context.last_prices = prices
""".strip()

        #      March 2006
        # Su Mo Tu We Th Fr Sa
        #          1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31
        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-21', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start,
                                  end=end)
        output = test_algo.run(source)
        self.assertIsNotNone(output)

        last_prices = test_algo.last_prices[0]
        oldest_dt = pd.Timestamp(
            '2006-03-20 4:00 PM', tz='US/Eastern').tz_convert('UTC')
        newest_dt = pd.Timestamp(
            '2006-03-21 4:00 PM', tz='US/Eastern').tz_convert('UTC')

        self.assertEquals(oldest_dt, last_prices.index[0])
        self.assertEquals(newest_dt, last_prices.index[-1])

        # Random, depends on seed
        self.assertEquals(139.36946942498648, last_prices[oldest_dt])
        self.assertEquals(180.15661995395106, last_prices[newest_dt])

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_history_in_bts_price_days(self, data_freq):
        """
        Test calling history() in before_trading_start()
        with daily price bars.
        """
        algo_text = """
from zipline.api import history

def initialize(context):
    context.first_bts_call = True

def before_trading_start(context, data):
    if not context.first_bts_call:
        prices_bts = history(bar_count=3, frequency='1d', field='price')
        context.prices_bts = prices_bts
    context.first_bts_call = False

def handle_data(context, data):
    prices_hd = history(bar_count=3, frequency='1d', field='price')
    context.prices_hd = prices_hd
""".strip()

        #      March 2006
        # Su Mo Tu We Th Fr Sa
        #          1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31
        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-22', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end, data_frequency=data_freq)

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency=data_freq,
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start, end=end, freq=data_freq)
        output = test_algo.run(source)
        self.assertIsNotNone(output)

        # Get the prices recorded by history() within handle_data()
        prices_hd = test_algo.prices_hd[0]
        # Get the prices recorded by history() within BTS
        prices_bts = test_algo.prices_bts[0]

        # before_trading_start() is timestamped to midnight prior to
        # the day's trading. Since no equity trades occur at midnight,
        # the price recorded for this time is forward filled from the
        # last trade - typically ~4pm the previous day. This results
        # in the OHLCV data recorded by history() in BTS lagging
        # that recorded by history in handle_data().
        # The trace of the pricing data from history() called within
        # handle_data() vs. BTS in the above algo is as follows:

        #  When called within handle_data()
        # ---------------------------------
        # 2006-03-20 21:00:00    139.369469
        # 2006-03-21 21:00:00    180.156620
        # 2006-03-22 21:00:00    221.344654

        #       When called within BTS
        # ---------------------------------
        # 2006-03-17 21:00:00           NaN
        # 2006-03-20 21:00:00    139.369469
        # 2006-03-22 00:00:00    180.156620

        # Get relevant Timestamps for the history() call within handle_data()
        oldest_hd_dt = pd.Timestamp(
            '2006-03-20 4:00 PM', tz='US/Eastern').tz_convert('UTC')
        penultimate_hd_dt = pd.Timestamp(
            '2006-03-21 4:00 PM', tz='US/Eastern').tz_convert('UTC')

        # Get relevant Timestamps for the history() call within BTS
        penultimate_bts_dt = pd.Timestamp(
            '2006-03-20 4:00 PM', tz='US/Eastern').tz_convert('UTC')
        newest_bts_dt = normalize_date(pd.Timestamp(
            '2006-03-22 04:00 PM', tz='US/Eastern').tz_convert('UTC'))

        if data_freq == 'daily':
            # If we're dealing with daily data, then we record
            # canonicalized timestamps, so make the conversion here:
            oldest_hd_dt = normalize_date(oldest_hd_dt)
            penultimate_hd_dt = normalize_date(penultimate_hd_dt)
            penultimate_bts_dt = normalize_date(penultimate_bts_dt)

        self.assertEquals(prices_hd[oldest_hd_dt],
                          prices_bts[penultimate_bts_dt])
        self.assertEquals(prices_hd[penultimate_hd_dt],
                          prices_bts[newest_bts_dt])

    def test_history_in_bts_price_minutes(self):
        """
        Test calling history() in before_trading_start()
        with minutely price bars.
        """
        algo_text = """
from zipline.api import history

def initialize(context):
    context.first_bts_call = True

def before_trading_start(context, data):
    if not context.first_bts_call:
        price_bts = history(bar_count=1, frequency='1m', field='price')
        context.price_bts = price_bts
    context.first_bts_call = False

def handle_data(context, data):
    pass

""".strip()

        #      March 2006
        # Su Mo Tu We Th Fr Sa
        #          1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31
        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-22', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start, end=end)
        output = test_algo.run(source)
        self.assertIsNotNone(output)

        # Get the prices recorded by history() within BTS
        price_bts_0 = test_algo.price_bts[0]
        price_bts_1 = test_algo.price_bts[1]

        # The prices recorded by history() in BTS should
        # be the closing prices of the previous day, which are:
        #
        #          sid | close on 2006-03-21
        #         ----------------------------
        #           0  | 180.15661995395106
        #           1  | 578.41665003444723

        # These are not 'real' price values. They are the product of
        # RandomWalkSource, which produces random walk OHLCV timeseries. For a
        # given seed these values are deterministic.
        self.assertEquals(180.15661995395106, price_bts_0.ix[0])
        self.assertEquals(578.41665003444723, price_bts_1.ix[0])

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_history_in_bts_volume_days(self, data_freq):
        """
        Test calling history() in before_trading_start()
        with daily volume bars.
        """
        algo_text = """
from zipline.api import history

def initialize(context):
    context.first_bts_call = True

def before_trading_start(context, data):
    if not context.first_bts_call:
        volume_bts = history(bar_count=2, frequency='1d', field='volume')
        context.volume_bts = volume_bts
    context.first_bts_call = False

def handle_data(context, data):
    volume_hd = history(bar_count=2, frequency='1d', field='volume')
    context.volume_hd = volume_hd
""".strip()

        #      March 2006
        # Su Mo Tu We Th Fr Sa
        #          1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31
        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-22', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end, data_frequency=data_freq)

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency=data_freq,
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start, end=end, freq=data_freq)
        output = test_algo.run(source)
        self.assertIsNotNone(output)

        # Get the volume recorded by history() within handle_data()
        volume_hd_0 = test_algo.volume_hd[0]
        volume_hd_1 = test_algo.volume_hd[1]
        # Get the volume recorded by history() within BTS
        volume_bts_0 = test_algo.volume_bts[0]
        volume_bts_1 = test_algo.volume_bts[1]

        penultimate_hd_dt = pd.Timestamp(
            '2006-03-21 4:00 PM', tz='US/Eastern').tz_convert('UTC')
        # Midnight of the day on which BTS is invoked.
        newest_bts_dt = normalize_date(pd.Timestamp(
            '2006-03-22 04:00 PM', tz='US/Eastern').tz_convert('UTC'))

        if data_freq == 'daily':
            # If we're dealing with daily data, then we record
            # canonicalized timestamps, so make the conversion here:
            penultimate_hd_dt = normalize_date(penultimate_hd_dt)

        # When history() is called in BTS, its 'current' volume value
        # should equal the previous day's total volume.
        self.assertEquals(volume_hd_0[penultimate_hd_dt],
                          volume_bts_0[newest_bts_dt])
        self.assertEquals(volume_hd_1[penultimate_hd_dt],
                          volume_bts_1[newest_bts_dt])

    def test_history_in_bts_volume_minutes(self):
        """
        Test calling history() in before_trading_start()
        with minutely volume bars.
        """
        algo_text = """
from zipline.api import history

def initialize(context):
    context.first_bts_call = True

def before_trading_start(context, data):
    if not context.first_bts_call:
        volume_bts = history(bar_count=2, frequency='1m', field='volume')
        context.volume_bts = volume_bts
    context.first_bts_call = False

def handle_data(context, data):
    pass
""".strip()

        #      March 2006
        # Su Mo Tu We Th Fr Sa
        #          1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31
        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-22', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start, end=end)
        output = test_algo.run(source)
        self.assertIsNotNone(output)

        # Get the volumes recorded for sid 0 by history() within BTS
        volume_bts_0 = test_algo.volume_bts[0]
        # Get the volumes recorded for sid 1 by history() within BTS
        volume_bts_1 = test_algo.volume_bts[1]

        # The values recorded on 2006-03-22 by history() in BTS
        # should equal the final volume values for the trading
        # day 2006-03-21:
        #                             0       1
        #   2006-03-21 20:59:00  215548  439908
        #   2006-03-21 21:00:00  985645  664313
        #
        # Note: These are not 'real' volume values. They are the product of
        # RandomWalkSource, which produces random walk OHLCV timeseries. For a
        # given seed these values are deterministic.
        self.assertEquals(215548, volume_bts_0.ix[0])
        self.assertEquals(985645, volume_bts_0.ix[1])
        self.assertEquals(439908, volume_bts_1.ix[0])
        self.assertEquals(664313, volume_bts_1.ix[1])

    def test_basic_history_one_day(self):
        algo_text = """
from zipline.api import history, add_history

def initialize(context):
    add_history(bar_count=1, frequency='1d', field='price')

def handle_data(context, data):
    prices = history(bar_count=1, frequency='1d', field='price')
    context.last_prices = prices
""".strip()

        #      March 2006
        # Su Mo Tu We Th Fr Sa
        #          1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31

        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-21', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start,
                                  end=end)
        output = test_algo.run(source)

        self.assertIsNotNone(output)

        last_prices = test_algo.last_prices[0]
        # oldest and newest should be the same if there is only 1 bar
        oldest_dt = pd.Timestamp(
            '2006-03-21 4:00 PM', tz='US/Eastern').tz_convert('UTC')
        newest_dt = pd.Timestamp(
            '2006-03-21 4:00 PM', tz='US/Eastern').tz_convert('UTC')

        self.assertEquals(oldest_dt, last_prices.index[0])
        self.assertEquals(newest_dt, last_prices.index[-1])

        # Random, depends on seed
        self.assertEquals(180.15661995395106, last_prices[oldest_dt])
        self.assertEquals(180.15661995395106, last_prices[newest_dt])

    def test_basic_history_positional_args(self):
        """
        Ensure that positional args work.
        """
        algo_text = """
from zipline.api import history, add_history

def initialize(context):
    add_history(2, '1d', 'price')

def handle_data(context, data):

    prices = history(2, '1d', 'price')
    context.last_prices = prices
""".strip()

        #      March 2006
        # Su Mo Tu We Th Fr Sa
        #          1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31

        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-21', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start,
                                  end=end)
        output = test_algo.run(source)
        self.assertIsNotNone(output)

        last_prices = test_algo.last_prices[0]
        oldest_dt = pd.Timestamp(
            '2006-03-20 4:00 PM', tz='US/Eastern').tz_convert('UTC')
        newest_dt = pd.Timestamp(
            '2006-03-21 4:00 PM', tz='US/Eastern').tz_convert('UTC')

        self.assertEquals(oldest_dt, last_prices.index[0])
        self.assertEquals(newest_dt, last_prices.index[-1])

        self.assertEquals(139.36946942498648, last_prices[oldest_dt])
        self.assertEquals(180.15661995395106, last_prices[newest_dt])

    def test_history_with_volume(self):
        algo_text = """
from zipline.api import history, add_history, record

def initialize(context):
    add_history(3, '1d', 'volume')

def handle_data(context, data):
    volume = history(3, '1d', 'volume')

    record(current_volume=volume[0].ix[-1])
""".strip()

        #      April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='minute'
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start,
                                  end=end)
        output = test_algo.run(source)

        np.testing.assert_equal(output.ix[0, 'current_volume'],
                                212218404.0)

    def test_history_with_high(self):
        algo_text = """
from zipline.api import history, add_history, record

def initialize(context):
    add_history(3, '1d', 'high')

def handle_data(context, data):
    highs = history(3, '1d', 'high')

    record(current_high=highs[0].ix[-1])
""".strip()

        #      April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='minute'
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start,
                                  end=end)
        output = test_algo.run(source)

        np.testing.assert_equal(output.ix[0, 'current_high'],
                                139.5370641791925)

    def test_history_with_low(self):
        algo_text = """
from zipline.api import history, add_history, record

def initialize(context):
    add_history(3, '1d', 'low')

def handle_data(context, data):
    lows = history(3, '1d', 'low')

    record(current_low=lows[0].ix[-1])
""".strip()

        #      April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='minute'
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start,
                                  end=end)
        output = test_algo.run(source)

        np.testing.assert_equal(output.ix[0, 'current_low'],
                                99.891436939669944)

    def test_history_with_open(self):
        algo_text = """
from zipline.api import history, add_history, record

def initialize(context):
    add_history(3, '1d', 'open_price')

def handle_data(context, data):
    opens = history(3, '1d', 'open_price')

    record(current_open=opens[0].ix[-1])
""".strip()

        #      April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='minute'
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start,
                                  end=end)
        output = test_algo.run(source)

        np.testing.assert_equal(output.ix[0, 'current_open'],
                                99.991436939669939)

    def test_history_passed_to_func(self):
        """
        Had an issue where MagicMock was causing errors during validation
        with rolling mean.
        """
        algo_text = """
from zipline.api import history, add_history
import pandas as pd

def initialize(context):
    add_history(2, '1d', 'price')

def handle_data(context, data):
    prices = history(2, '1d', 'price')

    pd.rolling_mean(prices, 2)
""".strip()

        #      April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='minute'
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start,
                                  end=end)
        output = test_algo.run(source)

        # At this point, just ensure that there is no crash.
        self.assertIsNotNone(output)

    def test_history_passed_to_talib(self):
        """
        Had an issue where MagicMock was causing errors during validation
        with talib.

        We don't officially support a talib integration, yet.
        But using talib directly should work.
        """
        algo_text = """
import talib
import numpy as np

from zipline.api import history, add_history, record

def initialize(context):
    add_history(2, '1d', 'price')

def handle_data(context, data):
    prices = history(2, '1d', 'price')

    ma_result = talib.MA(np.asarray(prices[0]), timeperiod=2)
    record(ma=ma_result[-1])
""".strip()

        #      April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        # Eddie: this was set to 04-10 but I don't see how that makes
        # sense as it does not generate enough data to get at -2 index
        # below.
        start = pd.Timestamp('2007-04-05', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='daily'
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start,
                                  end=end)
        output = test_algo.run(source)
        # At this point, just ensure that there is no crash.
        self.assertIsNotNone(output)

        recorded_ma = output.ix[-2, 'ma']

        self.assertFalse(pd.isnull(recorded_ma))
        # Depends on seed
        np.testing.assert_almost_equal(recorded_ma,
                                       159.76304468946876)

    @parameterized.expand([
        ('daily',),
        ('minute',),
    ])
    def test_history_container_constructed_at_runtime(self, data_freq):
        algo_text = dedent(
            """\
            from zipline.api import history
            def handle_data(context, data):
                context.prices = history(2, '1d', 'price')
            """
        )
        start = pd.Timestamp('2007-04-05', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency=data_freq,
            emission_rate=data_freq
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency=data_freq,
            sim_params=sim_params,
            env=TestHistoryAlgo.env,
        )

        source = RandomWalkSource(start=start, end=end, freq=data_freq)

        self.assertIsNone(test_algo.history_container)
        test_algo.run(source)
        self.assertIsNotNone(
            test_algo.history_container,
            msg='HistoryContainer was not constructed at runtime',
        )

        container = test_algo.history_container
        self.assertEqual(
            len(container.digest_panels),
            1,
            msg='The HistoryContainer created too many digest panels',
        )

        freq, digest = list(container.digest_panels.items())[0]
        self.assertEqual(
            freq.unit_str,
            'd',
        )

        self.assertEqual(
            digest.window_length,
            1,
            msg='The digest panel is not large enough to service the given'
            ' HistorySpec',
        )

    def test_history_in_initialize(self):
        algo_text = dedent(
            """\
            from zipline.api import history

            def initialize(context):
                history(10, '1d', 'price')

            def handle_data(context, data):
                pass
            """
        )

        start = pd.Timestamp('2007-04-05', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='daily',
            env=self.env,
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=self.env,
        )

        with self.assertRaises(HistoryInInitialize):
            test_algo.initialize()

    @parameterized.expand([
        (1,),
        (2,),
    ])
    def test_history_grow_length_inter_bar(self, incr):
        """
        Tests growing the length of a digest panel with different date_buf
        deltas once per bar.
        """
        algo_text = dedent(
            """\
            from zipline.api import history


            def initialize(context):
                context.bar_count = 1


            def handle_data(context, data):
                prices = history(context.bar_count, '1d', 'price')
                context.test_case.assertEqual(len(prices), context.bar_count)
                context.bar_count += {incr}
            """
        ).format(incr=incr)
        start = pd.Timestamp('2007-04-05', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='daily',
            env=self.env,
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=self.env,
        )
        test_algo.test_case = self

        source = RandomWalkSource(start=start, end=end)

        self.assertIsNone(test_algo.history_container)
        test_algo.run(source)

    @parameterized.expand([
        (1,),
        (2,),
    ])
    def test_history_grow_length_intra_bar(self, incr):
        """
        Tests growing the length of a digest panel with different date_buf
        deltas in a single bar.
        """
        algo_text = dedent(
            """\
            from zipline.api import history


            def initialize(context):
                context.bar_count = 1


            def handle_data(context, data):
                prices = history(context.bar_count, '1d', 'price')
                context.test_case.assertEqual(len(prices), context.bar_count)
                context.bar_count += {incr}
                prices = history(context.bar_count, '1d', 'price')
                context.test_case.assertEqual(len(prices), context.bar_count)
            """
        ).format(incr=incr)
        start = pd.Timestamp('2007-04-05', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate='daily',
            env=self.env,
        )

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
            env=self.env,
        )
        test_algo.test_case = self

        source = RandomWalkSource(start=start, end=end)

        self.assertIsNone(test_algo.history_container)
        test_algo.run(source)


class TestHistoryContainerResize(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()

    @classmethod
    def tearDownClass(cls):
        del cls.env

    @subtest(
        ((freq, field, data_frequency, construct_digest)
         for freq in ('1m', '1d')
         for field in HistoryContainer.VALID_FIELDS
         for data_frequency in ('minute', 'daily')
         for construct_digest in (True, False)
         if not (freq == '1m' and data_frequency == 'daily')),
        'freq',
        'field',
        'data_frequency',
        'construct_digest',
    )
    def test_history_grow_length(self,
                                 freq,
                                 field,
                                 data_frequency,
                                 construct_digest):
        bar_count = 2 if construct_digest else 1
        spec = history.HistorySpec(
            bar_count=bar_count,
            frequency=freq,
            field=field,
            ffill=True,
            data_frequency=data_frequency,
            env=self.env,
        )
        specs = {spec.key_str: spec}
        initial_sids = [1]
        initial_dt = pd.Timestamp(
            '2013-06-28 13:31'
            if data_frequency == 'minute'
            else '2013-06-28 12:00AM',
            tz='UTC',
        )

        container = HistoryContainer(
            specs, initial_sids, initial_dt, data_frequency, env=self.env,
        )

        if construct_digest:
            self.assertEqual(
                container.digest_panels[spec.frequency].window_length, 1,
            )

        bar_data = BarData()
        container.update(bar_data, initial_dt)

        to_add = (
            history.HistorySpec(
                bar_count=bar_count + 1,
                frequency=freq,
                field=field,
                ffill=True,
                data_frequency=data_frequency,
                env=self.env,
            ),
            history.HistorySpec(
                bar_count=bar_count + 2,
                frequency=freq,
                field=field,
                ffill=True,
                data_frequency=data_frequency,
                env=self.env,
            ),
        )

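        # Each call to ensure_spec() should resize the digest panel for this
        # frequency so that its window covers the larger bar_count.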
        for spec in to_add:
            container.ensure_spec(spec, initial_dt, bar_data)

            self.assertEqual(
                container.digest_panels[spec.frequency].window_length,
                spec.bar_count - 1,
            )

            self.assert_history(container, spec, initial_dt)

    @subtest(
        ((bar_count, freq, pair, data_frequency)
         for bar_count in (1, 2)
         for freq in ('1m', '1d')
         for pair in product(HistoryContainer.VALID_FIELDS, repeat=2)
         for data_frequency in ('minute', 'daily')
         if not (freq == '1m' and data_frequency == 'daily')),
        'bar_count',
        'freq',
        'pair',
        'data_frequency',
    )
    def test_history_add_field(self, bar_count, freq, pair, data_frequency):
        first, second = pair
        spec = history.HistorySpec(
            bar_count=bar_count,
            frequency=freq,
            field=first,
            ffill=True,
            data_frequency=data_frequency,
            env=self.env,
        )
        specs = {spec.key_str: spec}
        initial_sids = [1]
        initial_dt = pd.Timestamp(
            '2013-06-28 13:31'
            if data_frequency == 'minute'
            else '2013-06-28 12:00AM',
            tz='UTC',
        )

        container = HistoryContainer(
            specs, initial_sids, initial_dt, data_frequency, env=self.env
        )

        if bar_count > 1:
            self.assertEqual(
                container.digest_panels[spec.frequency].window_length, 1,
            )

        bar_data = BarData()
        container.update(bar_data, initial_dt)

        new_spec = history.HistorySpec(
            bar_count,
            frequency=freq,
            field=second,
            ffill=True,
            data_frequency=data_frequency,
            env=self.env,
        )

        container.ensure_spec(new_spec, initial_dt, bar_data)

        if bar_count > 1:
            digest_panel = container.digest_panels[new_spec.frequency]
            self.assertEqual(digest_panel.window_length, bar_count - 1)
            self.assertIn(second, digest_panel.items)
        else:
            self.assertNotIn(new_spec.frequency, container.digest_panels)

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')

            self.assert_history(container, new_spec, initial_dt)

    @subtest(
        ((bar_count, pair, field, data_frequency)
         for bar_count in (1, 2)
         for pair in product(('1m', '1d'), repeat=2)
         for field in HistoryContainer.VALID_FIELDS
         for data_frequency in ('minute', 'daily')
         if not ('1m' in pair and data_frequency == 'daily')),
        'bar_count',
        'pair',
        'field',
        'data_frequency',
    )
    def test_history_add_freq(self, bar_count, pair, field, data_frequency):
        first, second = pair
        spec = history.HistorySpec(
            bar_count=bar_count,
            frequency=first,
            field=field,
            ffill=True,
            data_frequency=data_frequency,
            env=self.env,
        )
        specs = {spec.key_str: spec}
        initial_sids = [1]
        initial_dt = pd.Timestamp(
            '2013-06-28 13:31'
            if data_frequency == 'minute'
            else '2013-06-28 12:00AM',
            tz='UTC',
        )

        container = HistoryContainer(
            specs, initial_sids, initial_dt, data_frequency, env=self.env,
        )

        if bar_count > 1:
            self.assertEqual(
                container.digest_panels[spec.frequency].window_length, 1,
            )

        bar_data = BarData()
        container.update(bar_data, initial_dt)

        new_spec = history.HistorySpec(
            bar_count,
            frequency=second,
            field=field,
            ffill=True,
            data_frequency=data_frequency,
            env=self.env,
        )

        container.ensure_spec(new_spec, initial_dt, bar_data)

        if bar_count > 1:
            digest_panel = container.digest_panels[new_spec.frequency]
            self.assertEqual(digest_panel.window_length, bar_count - 1)
        else:
            self.assertNotIn(new_spec.frequency, container.digest_panels)

        self.assert_history(container, new_spec, initial_dt)

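    # Verify that the returned history frame has spec.bar_count rows and that
    # its index walks backwards from dt one bar at a time.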
    def assert_history(self, container, spec, dt):
        hst = container.get_history(spec, dt)

        self.assertEqual(len(hst), spec.bar_count)

        back = spec.frequency.prev_bar
        for n in reversed(hst.index):
            self.assertEqual(dt, n)
            dt = back(dt)