Completed
Pull Request — master (#858)
by Eddie
10:07 queued 01:13
created

tests.TestBenchmark   A

Complexity

Total Complexity 11

Size/Duplication

Total Lines 180
Duplicated Lines 0 %
Metric Value
dl 0
loc 180
rs 10
wmc 11

6 Methods

Rating   Name   Duplication   Size   Complexity  
A tearDownClass() 0 4 1
B test_asset_IPOed_same_day() 0 36 2
A test_normal() 0 19 2
A test_no_stock_dividends_allowed() 0 13 2
B test_asset_not_trading() 0 27 3
A setUpClass() 0 71 1
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import os
16
from unittest import TestCase
17
from datetime import timedelta
18
import numpy as np
19
import pandas as pd
20
from testfixtures import TempDirectory
21
from zipline.data.us_equity_pricing import SQLiteAdjustmentWriter, \
22
    SQLiteAdjustmentReader
23
from zipline.errors import (
24
    BenchmarkAssetNotAvailableTooEarly,
25
    BenchmarkAssetNotAvailableTooLate,
26
    InvalidBenchmarkAsset)
27
28
from zipline.finance.trading import TradingEnvironment
29
from zipline.sources.benchmark_source import BenchmarkSource
30
from zipline.utils import factory
31
from zipline.utils.test_utils import create_data_portal, write_minute_data
32
from .test_perf_tracking import MockDailyBarSpotReader
33
34
35
class TestBenchmark(TestCase):
    """Tests for ``BenchmarkSource``.

    One ``TradingEnvironment``, one temporary directory and one data portal
    are built in ``setUpClass`` and shared by every test in this class, so
    tests must not leave modifications to them behind.
    """

    @classmethod
    def setUpClass(cls):
        """Build the fixtures shared by every test.

        Four equities are written to the trading environment:

        * sids 1, 2 and 4 trade from the first simulation day until one
          calendar day past the last one;
        * sid 3 only trades from ``trading_days[100]`` to
          ``trading_days[-100]``, so it is unavailable at both ends of the
          simulation.

        Sid 4 also receives a stock dividend (paid in sid 5), which should
        make it unusable as a benchmark. Splits, mergers and cash dividends
        are written as empty tables.
        """
        cls.env = TradingEnvironment()
        cls.tempdir = TempDirectory()

        cls.sim_params = factory.create_simulation_parameters()
        trading_days = cls.sim_params.trading_days

        # Lifetime shared by the assets that trade for the whole simulation.
        full_period = {
            "start_date": trading_days[0],
            "end_date": trading_days[-1] + timedelta(days=1),
        }

        cls.env.write_data(equities_data={
            1: dict(full_period),
            2: dict(full_period),
            3: {
                "start_date": trading_days[100],
                "end_date": trading_days[-100],
            },
            4: dict(full_period),
        })

        dbpath = os.path.join(cls.tempdir.path, "adjustments.db")

        writer = SQLiteAdjustmentWriter(dbpath, cls.env.trading_days,
                                        MockDailyBarSpotReader())
        # The same empty frame is reused for both splits and mergers.
        splits = mergers = pd.DataFrame(
            {
                # Hackery to make the dtypes correct on an empty frame.
                'effective_date': np.array([], dtype=int),
                'ratio': np.array([], dtype=float),
                'sid': np.array([], dtype=int),
            },
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        dividends = pd.DataFrame({
            'sid': np.array([], dtype=np.uint32),
            'amount': np.array([], dtype=np.float64),
            'declared_date': np.array([], dtype='datetime64[ns]'),
            'ex_date': np.array([], dtype='datetime64[ns]'),
            'pay_date': np.array([], dtype='datetime64[ns]'),
            'record_date': np.array([], dtype='datetime64[ns]'),
        })

        # A single stock dividend on sid 4; test_no_stock_dividends_allowed
        # relies on it.
        declared_date = trading_days[45]
        ex_date = trading_days[50]
        record_date = pay_date = trading_days[55]

        stock_dividends = pd.DataFrame({
            'sid': np.array([4], dtype=np.uint32),
            'payment_sid': np.array([5], dtype=np.uint32),
            'ratio': np.array([2], dtype=np.float64),
            'declared_date': np.array([declared_date], dtype='datetime64[ns]'),
            'ex_date': np.array([ex_date], dtype='datetime64[ns]'),
            'record_date': np.array([record_date], dtype='datetime64[ns]'),
            'pay_date': np.array([pay_date], dtype='datetime64[ns]'),
        })
        writer.write(splits, mergers, dividends,
                     stock_dividends=stock_dividends)

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            [1, 2, 3, 4],
            adjustment_reader=SQLiteAdjustmentReader(dbpath)
        )

    @classmethod
    def tearDownClass(cls):
        """Drop the shared environment and delete the temporary directory."""
        del cls.env
        cls.tempdir.cleanup()

    def test_normal(self):
        """Benchmark returns equal the daily pct_change of close prices."""
        days_to_use = self.sim_params.trading_days[1:]

        source = BenchmarkSource(
            1, self.env, days_to_use, self.data_portal
        )

        # should be the equivalent of getting the price history, then doing
        # a pct_change on it
        manually_calculated = self.data_portal.get_history_window(
            [1], days_to_use[-1], len(days_to_use), "1d", "close_price"
        )[1].pct_change()

        # Compare every day except the first one: pct_change has no previous
        # value there, so manually_calculated holds NaN at position 0.
        for position, day in enumerate(days_to_use[1:], start=1):
            self.assertEqual(
                source.get_value(day),
                manually_calculated.iloc[position]
            )

    def test_asset_not_trading(self):
        """A benchmark whose lifetime doesn't cover the sim days is rejected.

        Sid 3 trades only from ``trading_days[100]`` to
        ``trading_days[-100]``. A window starting before its start date
        should raise ``BenchmarkAssetNotAvailableTooEarly``. A window
        running past its end date should raise
        ``BenchmarkAssetNotAvailableTooLate``.
        """
        with self.assertRaises(BenchmarkAssetNotAvailableTooEarly) as exc:
            BenchmarkSource(
                3,
                self.env,
                self.sim_params.trading_days[1:],
                self.data_portal
            )

        self.assertEqual(
            '3 does not exist on 2006-01-04 00:00:00+00:00. '
            'It started trading on 2006-05-26 00:00:00+00:00.',
            exc.exception.message
        )

        with self.assertRaises(BenchmarkAssetNotAvailableTooLate) as exc2:
            BenchmarkSource(
                3,
                self.env,
                self.sim_params.trading_days[120:],
                self.data_portal
            )

        self.assertEqual(
            '3 does not exist on 2006-06-26 00:00:00+00:00. '
            'It stopped trading on 2006-08-09 00:00:00+00:00.',
            exc2.exception.message
        )

    def test_asset_IPOed_same_day(self):
        """A benchmark that starts on the first sim day still yields returns.

        Minute data is written for sid 2 over the first six trading days.
        The first benchmark value is expected to be 0.0. Every later value
        should match the daily pct_change of close prices.
        """
        # Write minute data for sid 2 covering the first few trading days.
        minutes = self.env.minutes_for_days_in_range(
            self.sim_params.trading_days[0],
            self.sim_params.trading_days[5]
        )

        path = write_minute_data(
            self.tempdir,
            minutes,
            [2]
        )

        # The data portal is shared by the whole class. Restore its previous
        # minute-data path when this test finishes so other tests don't
        # depend on whether this one ran first.
        missing = object()
        previous_path = getattr(
            self.data_portal, '_minutes_equities_path', missing
        )

        def restore_minutes_path():
            if previous_path is missing:
                del self.data_portal._minutes_equities_path
            else:
                self.data_portal._minutes_equities_path = previous_path

        self.addCleanup(restore_minutes_path)
        self.data_portal._minutes_equities_path = path

        days_to_use = self.sim_params.trading_days

        source = BenchmarkSource(
            2,
            self.env,
            days_to_use,
            self.data_portal
        )

        # first value should be 0.0, coming from daily data
        self.assertAlmostEqual(0.0, source.get_value(days_to_use[0]))

        manually_calculated = self.data_portal.get_history_window(
            [2], days_to_use[-1], len(days_to_use), "1d", "close_price"
        )[2].pct_change()

        for position, day in enumerate(days_to_use[1:], start=1):
            self.assertEqual(
                source.get_value(day),
                manually_calculated.iloc[position]
            )

    def test_no_stock_dividends_allowed(self):
        """An asset with a stock dividend is rejected as the benchmark."""
        # try to use sid(4) as benchmark, should blow up due to the presence
        # of a stock dividend

        with self.assertRaises(InvalidBenchmarkAsset) as exc:
            BenchmarkSource(
                4, self.env, self.sim_params.trading_days, self.data_portal
            )

        self.assertEqual("4 cannot be used as the benchmark because it has a "
                         "stock dividend on 2006-03-16 00:00:00.  Choose "
                         "another asset to use as the benchmark.",
                         exc.exception.message)
215