Completed
Pull Request — master (#940)
by Joe
01:26
created

tests.wrapper()   A

Complexity

Conditions 1

Size

Total Lines 11

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 11
rs 9.4286
1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from datetime import timedelta
16
from functools import wraps
17
from itertools import product
18
from nose_parameterized import parameterized
19
import operator
20
import random
21
from six import itervalues
22
from six.moves import map
23
from unittest import TestCase
24
25
import numpy as np
26
from numpy.testing import assert_allclose
27
28
from zipline.finance.trading import TradingEnvironment
29
from zipline.algorithm import TradingAlgorithm
30
import zipline.utils.factory as factory
31
from zipline.api import add_transform, get_datetime
32
33
34
def handle_data_wrapper(f):
    """
    Wrap a test body so it can serve as an algorithm's ``handle_data``.

    On every bar the wrapper first records bookkeeping on ``context`` that
    the test assertions read back:

    * ``context.mins_for_days`` -- one counter per calendar date, counting
      how many bars were seen on that date.
    * ``context.price_bars[sid]`` / ``context.vol_bars[sid]`` -- per-sid
      price and volume series for sids 1, 2 and 3. A sid missing from
      ``data`` records a NaN price and zero volume. A sid whose bar
      timestamp differs from the current algorithm time records zero volume.
    * ``context.last_close_prices[sid]`` -- the first value of a two-bar
      daily ``close_price`` history.

    ``context.warmup`` is decremented each time a new date starts. The
    wrapped function is only called, and its result returned, once
    ``warmup`` has dropped below zero. Otherwise the wrapper returns None.
    """
    @wraps(f)
    def wrapper(context, data):
        now = get_datetime()
        today = now.date()

        if today == context.current_date:
            # Another bar on the same date.
            context.mins_for_days[-1] += 1
        else:
            # First bar of a new date: start a fresh counter and burn one
            # warmup step.
            context.warmup -= 1
            context.mins_for_days.append(1)
            context.current_date = today

        closes = context.history(2, '1d', 'close_price')
        for sid in (1, 2, 3):
            if sid in data:
                bar = data[sid]
                # A bar whose dt lags the clock is presumably stale (the sid
                # did not trade this bar), so its volume must not count.
                volume = bar.volume if bar.dt == now else 0
                price = bar.price
            else:
                volume, price = 0, np.nan
            context.vol_bars[sid].append(volume)
            context.price_bars[sid].append(price)

            # NOTE(review): ``closes[sid][0]`` looks like a positional lookup
            # on a date-indexed series; confirm against the history API.
            context.last_close_prices[sid] = closes[sid][0]

        if context.warmup < 0:
            return f(context, data)

    return wrapper
64
65
66
def initialize_with(test_case, tfm_name, days):
    """
    Build an ``initialize`` callback that prepares an algorithm context for
    a transform test.

    Parameters
    ----------
    test_case : TestCase
        Stored on the context as ``context.test_case``.
    tfm_name : str
        Name of the transform registered via ``add_transform``.
    days : int or None
        Window length forwarded to ``add_transform``. It also determines the
        number of warmup days.

    Returns
    -------
    callable
        ``initialize(context)``, which seeds the bookkeeping fields that
        ``handle_data_wrapper`` updates and then registers the transform.
    """
    def initialize(context):
        context.test_case = test_case
        context.days = days
        context.mins_for_days = []
        # Indexed directly by sid (1, 2, 3); slot 0 is an unused placeholder.
        context.price_bars = (None, [np.nan], [np.nan], [np.nan])
        context.vol_bars = (None, [np.nan], [np.nan], [np.nan])
        # handle_data_wrapper decrements warmup on each new date and only
        # runs the test body once it is negative. A window of ``days`` needs
        # ``days + 1`` dates skipped; no window falls back to 2.
        if context.days:
            context.warmup = days + 1
        else:
            context.warmup = 2

        context.current_date = None

        # Also indexed by sid, with slot 0 unused.
        context.last_close_prices = [np.nan, np.nan, np.nan, np.nan]
        add_transform(tfm_name, days)

    return initialize
84
85
86
def windows_with_frequencies(*args):
    """
    Return an iterator of ``(data_frequency, days)`` parameter pairs.

    Every window length given in ``args`` is paired with both the 'daily'
    and the 'minute' frequency, daily first. When called with no
    arguments, a single ``None`` window is used. The result is a one-shot
    ``itertools.product`` iterator.
    """
    if not args:
        args = (None,)
    frequencies = ('daily', 'minute')
    return product(frequencies, args)
89
90
91
def with_algo(f):
    """
    Turn a ``test_<transform>(context, data)`` function into a test method.

    The transform name is taken from the function name with its ``test_``
    prefix removed. The returned method, ``wrapper(self, data_frequency,
    days=None)``, looks up the simulation parameters and source for
    ``data_frequency`` in ``self.sim_and_source``. It then runs a
    ``TradingAlgorithm`` whose ``initialize`` registers the transform and
    whose ``handle_data`` is ``f`` wrapped by ``handle_data_wrapper``.

    Raises
    ------
    ValueError
        If ``f``'s name does not start with ``test_``.
    """
    prefix = 'test_'
    test_name = f.__name__
    if test_name.startswith(prefix):
        transform_name = test_name[len(prefix):]
    else:
        raise ValueError('This must decorate a test case')

    @wraps(f)
    def wrapper(self, data_frequency, days=None):
        sim_params, source = self.sim_and_source[data_frequency]
        TradingAlgorithm(
            initialize=initialize_with(self, transform_name, days),
            handle_data=handle_data_wrapper(f),
            sim_params=sim_params,
            env=self.env,
        ).run(source)

    return wrapper
111
112
113
class TransformTestCase(TestCase):
    """
    Exercise the simple transforms (mavg, stddev, vwap, returns) by running
    each one inside a TradingAlgorithm. Each transform's output is compared
    with a naive, hand-rolled computation over the bars recorded on the
    algorithm context.
    """
    @classmethod
    def setUpClass(cls):
        # Seed before building anything below. Presumably the factory
        # helpers draw from ``random``, so the construction order must stay
        # fixed for the generated data to be reproducible.
        random.seed(0)
        cls.sids = (1, 2, 3)

        minute_params = factory.create_simulation_parameters(
            num_days=3,
            data_frequency='minute',
            emission_rate='minute',
        )
        daily_params = factory.create_simulation_parameters(
            num_days=30,
            data_frequency='daily',
            emission_rate='daily',
        )

        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[1, 2, 3])

        # Minute source first, then daily, as in the original sequence.
        minute_source = factory.create_minutely_trade_source(
            cls.sids,
            sim_params=minute_params,
            env=cls.env,
        )
        daily_source = factory.create_trade_source(
            cls.sids,
            trade_time_increment=timedelta(days=1),
            sim_params=daily_params,
            env=cls.env,
        )
        # Keyed by data frequency: (sim_params, source).
        cls.sim_and_source = {
            'minute': (minute_params, minute_source),
            'daily': (daily_params, daily_source),
        }

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def tearDown(self):
        """
        Rewind every source after each test, because running an algorithm
        consumes the source it was given.
        """
        for _params, trade_source in itervalues(self.sim_and_source):
            trade_source.rewind()

    @parameterized.expand(windows_with_frequencies(1, 2, 3, 4))
    @with_algo
    def test_mavg(context, data):
        """
        Check the mavg transform against a naive mean of the prices
        recorded over the last ``context.days`` days of bars.
        """
        window = sum(context.mins_for_days[-context.days:])

        for sid in data:
            actual = data[sid].mavg(context.days)
            expected = np.mean(context.price_bars[sid][-window:])
            assert_allclose(actual, expected)

    @parameterized.expand(windows_with_frequencies(2, 3, 4))
    @with_algo
    def test_stddev(context, data):
        """
        Check the stddev transform against a naive sample standard deviation
        (``ddof=1``) of the prices recorded over the last ``context.days``
        days of bars.
        """
        window = sum(context.mins_for_days[-context.days:])

        for sid in data:
            actual = data[sid].stddev(context.days)
            expected = np.std(context.price_bars[sid][-window:], ddof=1)
            assert_allclose(actual, expected)

    @parameterized.expand(windows_with_frequencies(2, 3, 4))
    @with_algo
    def test_vwap(context, data):
        """
        Check the vwap transform against a hand-rolled volume-weighted
        average price over the last ``context.days`` days of bars. Missing
        (NaN) prices are treated as zero in the weighted sum.
        """
        window = sum(context.mins_for_days[-context.days:])
        for sid in data:
            prices = context.price_bars[sid][-window:]
            volumes = context.vol_bars[sid][-window:]
            clean_prices = np.nan_to_num(np.array(prices))
            weighted_total = sum(map(operator.mul, clean_prices, volumes))
            expected = weighted_total / sum(volumes)

            assert_allclose(data[sid].vwap(context.days), expected)

    @parameterized.expand(windows_with_frequencies())
    @with_algo
    def test_returns(context, data):
        """
        Check the returns transform against the simple return from the
        close captured by the handle_data wrapper (the first value of a
        two-bar daily close history, presumably the prior day's close) to
        the current price.
        """
        for sid in data:
            prev_close = context.last_close_prices[sid]
            expected = (data[sid].price - prev_close) / prev_close

            assert_allclose(data[sid].returns(), expected)
221