Completed
Pull Request — master (#858)
by Eddie
02:03
created

test_algo_with_rl_violation()   A

Complexity

Conditions 3

Size

Total Lines 17

Duplication

Lines 0
Ratio 0 %
| Metric | Value  |
|--------|--------|
| cc     | 3      |
| dl     | 0      |
| loc    | 17     |
| rs     | 9.4286 |
1
import pandas as pd
2
3
from datetime import timedelta
4
from unittest import TestCase
5
from testfixtures import TempDirectory
6
7
from zipline.algorithm import TradingAlgorithm
8
from zipline.errors import TradingControlViolation
9
from zipline.finance.trading import TradingEnvironment
10
from zipline.utils.test_utils import (
11
    setup_logger, teardown_logger, security_list_copy, add_security_data,)
12
from zipline.utils import factory
13
from zipline.utils.security_list import (
14
    SecurityListSet, load_from_directory)
15
from zipline.utils.test_utils import create_data_portal
16
17
# Restricted-list data loaded from the 'leveraged_etf_list' security-list
# directory. The tests below use its keys as dates (they are passed as
# simulation start dates and offset with ``timedelta``), i.e. knowledge dates.
# NOTE(review): if this is a plain dict, its key order need not be
# chronological -- take ``min(...)`` rather than the first key when the
# earliest knowledge date is required.
LEVERAGED_ETFS = load_from_directory('leveraged_etf_list')
18
19
20
class RestrictedAlgoWithCheck(TradingAlgorithm):
    """Algorithm that consults the leveraged-ETF restricted list itself.

    It registers that list as its do-not-order list, but it only places its
    single 100-share order when its asset is *not* on the list. It should
    therefore never hit the restricted-list trading control.
    """

    def initialize(self, symbol):
        lists = SecurityListSet(self.get_datetime, self.asset_finder)
        self.rl = lists
        self.set_do_not_order_list(lists.leveraged_etf_list)
        self.order_count = 0
        self.sid = self.symbol(symbol)

    def handle_data(self, data):
        # At most one order over the whole run.
        if self.order_count:
            return
        # Skip ordering entirely when the asset is currently restricted.
        if self.sid in self.rl.leveraged_etf_list:
            return
        self.order(self.sid, 100)
        self.order_count += 1
33
34
35
class RestrictedAlgoWithoutCheck(TradingAlgorithm):
    """Algorithm that orders blindly on every bar.

    It registers the leveraged-ETF restricted list as its do-not-order list
    but never checks it before ordering. The tests rely on the resulting
    ``TradingControlViolation`` when the asset is restricted.
    """

    def initialize(self, symbol):
        lists = SecurityListSet(self.get_datetime, self.asset_finder)
        self.rl = lists
        self.set_do_not_order_list(lists.leveraged_etf_list)
        self.order_count = 0
        self.sid = self.symbol(symbol)

    def handle_data(self, data):
        # The counter is bumped only after order() returns, so it counts
        # the orders that made it past the trading controls.
        self.order(self.sid, 100)
        self.order_count = self.order_count + 1
45
46
47
class IterateRLAlgo(TradingAlgorithm):
    """Algorithm that never orders and only iterates the restricted list.

    On every bar it walks the leveraged-ETF restricted list and records in
    ``found`` whether its own asset appears. This exercises iteration over
    the list.
    """

    def initialize(self, symbol):
        lists = SecurityListSet(self.get_datetime, self.asset_finder)
        self.rl = lists
        self.set_do_not_order_list(lists.leveraged_etf_list)
        self.order_count = 0
        self.sid = self.symbol(symbol)
        self.found = False

    def handle_data(self, data):
        # Scan the whole list; ``found`` is sticky and never reset to False.
        target = self.sid
        hits = [stock for stock in self.rl.leveraged_etf_list
                if stock == target]
        if hits:
            self.found = True
59
60
61
class SecurityListTestCase(TestCase):
    """Tests for ``SecurityListSet`` and the restricted-list trading control
    built from the leveraged-ETF security list."""

    @classmethod
    def setUpClass(cls):
        # this is ugly, but we need to create two different
        # TradingEnvironment/DataPortal pairs: ``env``/``data_portal`` start
        # at the earliest leveraged-ETF knowledge date, while
        # ``env2``/``data_portal2`` start just before ``extra_knowledge_date``
        # (used by test_algo_with_rl_violation_after_add).
        cls.env = TradingEnvironment()
        cls.env2 = TradingEnvironment()

        # The key order of LEVERAGED_ETFS need not be chronological, so take
        # the earliest knowledge date explicitly instead of the "first" key.
        cls.first_knowledge_date = min(LEVERAGED_ETFS)

        cls.extra_knowledge_date = pd.Timestamp("2015-01-27", tz='UTC')
        cls.trading_day_before_first_kd = pd.Timestamp("2015-01-23", tz='UTC')

        symbols = ['AAPL', 'GOOG', 'BZQ', 'URTY', 'JFT']

        days = cls.env.days_in_range(
            cls.first_knowledge_date,
            pd.Timestamp("2015-02-17", tz='UTC')
        )

        cls.sim_params = factory.create_simulation_parameters(
            start=cls.first_knowledge_date,
            num_days=4,
            env=cls.env
        )

        cls.sim_params2 = factory.create_simulation_parameters(
            start=cls.trading_day_before_first_kd, num_days=4
        )

        # sids 0..4 map to ``symbols`` in order in both environments.
        equities_metadata = {
            i: {
                'start_date': days[0],
                'end_date': days[-1],
                'symbol': symbol,
            }
            for i, symbol in enumerate(symbols)
        }

        equities_metadata2 = {
            i: {
                'start_date': cls.sim_params2.period_start,
                'end_date': cls.sim_params2.period_end,
                'symbol': symbol,
            }
            for i, symbol in enumerate(symbols)
        }

        cls.env.write_data(equities_data=equities_metadata)
        cls.env2.write_data(equities_data=equities_metadata2)

        cls.tempdir = TempDirectory()
        cls.tempdir2 = TempDirectory()

        cls.data_portal = create_data_portal(
            env=cls.env,
            tempdir=cls.tempdir,
            sim_params=cls.sim_params,
            sids=range(0, 5),
        )

        cls.data_portal2 = create_data_portal(
            env=cls.env2,
            tempdir=cls.tempdir2,
            sim_params=cls.sim_params2,
            sids=range(0, 5)
        )

        setup_logger(cls)

    @classmethod
    def tearDownClass(cls):
        # Release both environments, not just the first one.
        del cls.env
        del cls.env2
        cls.tempdir.cleanup()
        cls.tempdir2.cleanup()
        teardown_logger(cls)

    def _lookup_sids(self, symbols):
        """Return the sids that ``symbols`` resolve to in ``self.env``.

        The lookup is done as of ``self.extra_knowledge_date``, and the
        result preserves the order of ``symbols``.
        """
        finder = self.env.asset_finder
        return [
            finder.lookup_symbol(
                symbol,
                as_of_date=self.extra_knowledge_date,
            ).sid
            for symbol in symbols
        ]

    def test_iterate_over_restricted_list(self):
        algo = IterateRLAlgo(symbol='BZQ', sim_params=self.sim_params,
                             env=self.env)

        algo.run(self.data_portal)
        self.assertTrue(algo.found)

    def test_security_list(self):
        # set the knowledge date to the first day of the
        # leveraged etf knowledge date.
        def get_datetime():
            return self.first_knowledge_date

        rl = SecurityListSet(get_datetime, self.env.asset_finder)
        # assert that a sample from the leveraged list are in restricted
        for sid in self._lookup_sids(["BZQ", "URTY", "JFT"]):
            self.assertIn(sid, rl.leveraged_etf_list)

        # assert that a sample of allowed stocks are not in restricted
        for sid in self._lookup_sids(["AAPL", "GOOG"]):
            self.assertNotIn(sid, rl.leveraged_etf_list)

    def test_security_add(self):
        def get_datetime():
            return pd.Timestamp("2015-01-27", tz='UTC')
        with security_list_copy():
            add_security_data(['AAPL', 'GOOG'], [])
            rl = SecurityListSet(get_datetime, self.env.asset_finder)
            for sid in self._lookup_sids(["AAPL", "GOOG", "BZQ", "URTY"]):
                self.assertIn(sid, rl.leveraged_etf_list)

    def test_security_add_delete(self):
        def get_datetime():
            return pd.Timestamp("2015-01-27", tz='UTC')
        with security_list_copy():
            # Write a delete statement for BZQ and URTY (the mirror image of
            # test_security_add). The previous version asserted that the
            # *strings* were absent from a list of sids, which always passed.
            add_security_data([], ['BZQ', 'URTY'])
            rl = SecurityListSet(get_datetime, self.env.asset_finder)
            for sid in self._lookup_sids(["BZQ", "URTY"]):
                self.assertNotIn(sid, rl.leveraged_etf_list)

    def test_algo_without_rl_violation_via_check(self):
        algo = RestrictedAlgoWithCheck(symbol='BZQ',
                                       sim_params=self.sim_params,
                                       env=self.env)
        algo.run(self.data_portal)

    def test_algo_without_rl_violation(self):
        algo = RestrictedAlgoWithoutCheck(symbol='AAPL',
                                          sim_params=self.sim_params,
                                          env=self.env)
        algo.run(self.data_portal)

    def test_algo_with_rl_violation(self):
        # JFT is repeated because it comes from a different lookup date
        # than BZQ.
        for symbol in ('BZQ', 'JFT'):
            algo = RestrictedAlgoWithoutCheck(symbol=symbol,
                                              sim_params=self.sim_params,
                                              env=self.env)
            with self.assertRaises(TradingControlViolation) as ctx:
                algo.run(self.data_portal)

            self.check_algo_exception(algo, ctx, 0)

    def test_algo_with_rl_violation_after_knowledge_date(self):
        sim_params = factory.create_simulation_parameters(
            start=self.first_knowledge_date + timedelta(days=7),
            num_days=5,
            env=self.env)

        # Use a private directory rather than the shared class-level
        # ``self.tempdir`` so this test's bar data cannot clobber whatever
        # ``self.data_portal`` reads from.
        new_tempdir = TempDirectory()
        try:
            data_portal = create_data_portal(
                self.env,
                new_tempdir,
                sim_params=sim_params,
                sids=range(0, 5)
            )

            algo = RestrictedAlgoWithoutCheck(symbol='BZQ',
                                              sim_params=sim_params,
                                              env=self.env)
            with self.assertRaises(TradingControlViolation) as ctx:
                algo.run(data_portal=data_portal)

            self.check_algo_exception(algo, ctx, 0)
        finally:
            new_tempdir.cleanup()

    def test_algo_with_rl_violation_cumulative(self):
        """
        Add a new restriction, run a test long after both
        knowledge dates, make sure stock from original restriction
        set is still disallowed.
        """
        sim_params = factory.create_simulation_parameters(
            start=self.first_knowledge_date + timedelta(days=7), num_days=4)

        with security_list_copy():
            add_security_data(['AAPL'], [])
            algo = RestrictedAlgoWithoutCheck(
                symbol='BZQ', sim_params=sim_params, env=self.env)
            # NOTE(review): self.data_portal was built for self.sim_params,
            # not this later window. The violation is expected on the very
            # first order, so presumably no bar data is needed -- confirm.
            with self.assertRaises(TradingControlViolation) as ctx:
                algo.run(self.data_portal)

            self.check_algo_exception(algo, ctx, 0)

    def test_algo_without_rl_violation_after_delete(self):
        new_tempdir = TempDirectory()
        try:
            with security_list_copy():
                # add a delete statement removing bzq
                # write a new delete statement file to disk
                add_security_data([], ['BZQ'])

                # now fast-forward to self.extra_knowledge_date.  requires
                # a new env, simparams, and dataportal
                env = TradingEnvironment()
                sim_params = factory.create_simulation_parameters(
                    start=self.extra_knowledge_date, num_days=4, env=env)

                env.write_data(equities_data={
                    "0": {
                        'symbol': 'BZQ',
                        'start_date': sim_params.period_start,
                        'end_date': sim_params.period_end,
                    }
                })

                data_portal = create_data_portal(
                    env,
                    new_tempdir,
                    sim_params,
                    range(0, 5)
                )

                algo = RestrictedAlgoWithoutCheck(
                    symbol='BZQ', sim_params=sim_params, env=env
                )
                algo.run(data_portal)

        finally:
            new_tempdir.cleanup()

    def test_algo_with_rl_violation_after_add(self):
        with security_list_copy():
            add_security_data(['AAPL'], [])

            algo = RestrictedAlgoWithoutCheck(symbol='AAPL',
                                              sim_params=self.sim_params2,
                                              env=self.env2)
            with self.assertRaises(TradingControlViolation) as ctx:
                algo.run(self.data_portal2)

            # Orders on the days before the add takes effect go through.
            self.check_algo_exception(algo, ctx, 2)

    def check_algo_exception(self, algo, ctx, expected_order_count):
        """Assert that a restricted-list violation stopped ``algo``.

        Checks that ``algo`` placed ``expected_order_count`` orders and that
        ``ctx`` captured a ``TradingControlViolation`` whose message names
        the ``RestrictedListOrder`` control.
        """
        self.assertEqual(algo.order_count, expected_order_count)
        exc = ctx.exception
        self.assertIs(type(exc), TradingControlViolation)
        self.assertIn("RestrictedListOrder", str(exc))
322