|
1
|
|
|
from contextlib import contextmanager |
|
2
|
|
|
from functools import wraps |
|
3
|
|
|
from itertools import ( |
|
4
|
|
|
combinations, |
|
5
|
|
|
count, |
|
6
|
|
|
product, |
|
7
|
|
|
) |
|
8
|
|
|
import operator |
|
9
|
|
|
import os |
|
10
|
|
|
import shutil |
|
11
|
|
|
from string import ascii_uppercase |
|
12
|
|
|
import tempfile |
|
13
|
|
|
|
|
14
|
|
|
from logbook import FileHandler |
|
15
|
|
|
from mock import patch |
|
16
|
|
|
from numpy.testing import assert_allclose, assert_array_equal |
|
17
|
|
|
import pandas as pd |
|
18
|
|
|
from pandas.tseries.offsets import MonthBegin |
|
19
|
|
|
from six import iteritems, itervalues |
|
20
|
|
|
from six.moves import filter |
|
21
|
|
|
from sqlalchemy import create_engine |
|
22
|
|
|
from toolz import concat |
|
23
|
|
|
|
|
24
|
|
|
from zipline.assets import AssetFinder |
|
25
|
|
|
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame |
|
26
|
|
|
from zipline.assets.futures import CME_CODE_TO_MONTH |
|
27
|
|
|
from zipline.finance.order import ORDER_STATUS |
|
28
|
|
|
from zipline.utils import security_list |
|
29
|
|
|
from zipline.utils.tradingcalendar import trading_days |
|
30
|
|
|
|
|
31
|
|
|
|
|
32
|
|
|
# The Unix epoch (1970-01-01 00:00:00) as a timezone-aware UTC timestamp.
EPOCH = pd.Timestamp('1970-01-01', tz='UTC')
|
33
|
|
|
|
|
34
|
|
|
|
|
35
|
|
|
def seconds_to_timestamp(seconds):
    """
    Convert a number of seconds since the Unix epoch into a UTC-localized
    ``pd.Timestamp``.
    """
    stamp = pd.Timestamp(seconds, unit='s', tz='UTC')
    return stamp
|
37
|
|
|
|
|
38
|
|
|
|
|
39
|
|
|
def to_utc(time_str):
    """
    Interpret ``time_str`` in the US/Eastern timezone and return the same
    instant as a UTC ``pd.Timestamp``.
    """
    eastern_time = pd.Timestamp(time_str, tz='US/Eastern')
    return eastern_time.tz_convert('UTC')
|
42
|
|
|
|
|
43
|
|
|
|
|
44
|
|
|
def str_to_seconds(s):
    """
    Convert a pandas-intelligible string, interpreted as UTC, to (integer)
    seconds since the Unix epoch.

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    moment = pd.Timestamp(s, tz='UTC')
    elapsed = moment - EPOCH
    return int(elapsed.total_seconds())
|
55
|
|
|
|
|
56
|
|
|
|
|
57
|
|
|
def setup_logger(test, path='test.log'):
    """
    Create a logbook ``FileHandler`` writing to ``path``, store it on
    ``test.log_handler`` and push it as the application-wide handler.

    Pair with ``teardown_logger``, which pops and closes
    ``test.log_handler``.
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
|
60
|
|
|
|
|
61
|
|
|
|
|
62
|
|
|
def teardown_logger(test):
    """
    Undo ``setup_logger``: pop ``test.log_handler`` off the application
    handler stack, then close it.

    The handler is closed even if popping it raises, so its resources are
    always released; the original exception still propagates.
    """
    handler = test.log_handler
    try:
        handler.pop_application()
    finally:
        handler.close()
|
65
|
|
|
|
|
66
|
|
|
|
|
67
|
|
|
def drain_zipline(test, zipline):
    """
    Exhaust ``zipline`` and collect every message it yields.

    Parameters
    ----------
    test : object
        Unused; kept so existing callers' signatures keep working.
    zipline : iterable[dict]
        The simulation's message stream.

    Returns
    -------
    output : list
        Every message yielded by ``zipline``, in order.
    transaction_count : int
        Total length of ``update['daily_perf']['transactions']`` over all
        messages that contain a ``'daily_perf'`` entry.
    """
    output = []
    transaction_count = 0
    # Iterating starts (and drives) the simulation.
    for update in zipline:
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])
    return output, transaction_count
|
80
|
|
|
|
|
81
|
|
|
|
|
82
|
|
|
def assert_single_position(test, zipline):
    """
    Run ``zipline`` to completion and assert that it ends holding exactly
    one position, in ``test.zipline_test_config['sid']``.

    Also asserts that the transaction count equals
    ``test.zipline_test_config['expected_transactions']`` (or
    ``['order_count']`` when that key is absent), and that the last reported
    state of every order has status ``ORDER_STATUS.FILLED``.

    Parameters
    ----------
    test : object
        Provides ``assertEqual`` (presumably a ``unittest.TestCase``) and a
        ``zipline_test_config`` dict.
    zipline : iterable[dict]
        The simulation's message stream.

    Returns
    -------
    output : list
        Every message yielded by ``zipline``.
    transaction_count : int
        Number of transactions seen across all ``'daily_perf'`` messages.
    """
    output, transaction_count = drain_zipline(test, zipline)
    config = test.zipline_test_config

    if 'expected_transactions' in config:
        expected_count = config['expected_transactions']
    else:
        expected_count = config['order_count']
    test.assertEqual(expected_count, transaction_count)

    # the final message is the risk report, the second to
    # last is the final day's results. Positions is a list of
    # dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Confirm that all orders were filled. Later updates of an order
    # overwrite earlier ones, so only each order's final state is checked.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' in update:
            if 'orders' in update['daily_perf']:
                for order in update['daily_perf']['orders']:
                    orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "Order %r was not filled." % (order['id'],),
        )

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count
|
132
|
|
|
|
|
133
|
|
|
|
|
134
|
|
|
class ExceptionSource(object):
    """
    Iterator that can never produce a value: every step raises
    ``ZeroDivisionError``.
    """

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def next(self):
        # Dividing by zero makes each iteration step raise.
        5 / 0

    # Python 3 spelling of the iterator protocol; behaves exactly like the
    # Python 2 ``next`` above.
    __next__ = next
|
150
|
|
|
|
|
151
|
|
|
|
|
152
|
|
|
@contextmanager
def security_list_copy():
    """
    Context manager that works on a scratch copy of the security lists.

    Every entry of ``security_list.SECURITY_LISTS_DIR`` is copied with
    ``shutil.copytree`` (so each entry is expected to be a directory) into a
    fresh temporary directory. While the context is active,
    ``security_list.SECURITY_LISTS_DIR`` points at that copy and
    ``security_list.using_copy`` is ``True`` (created if missing). The
    temporary directory is removed on exit, ignoring removal errors.
    """
    original_dir = security_list.SECURITY_LISTS_DIR
    scratch_dir = tempfile.mkdtemp()
    try:
        for entry in os.listdir(original_dir):
            shutil.copytree(
                os.path.join(original_dir, entry),
                os.path.join(scratch_dir, entry),
            )
        with patch.object(security_list, 'SECURITY_LISTS_DIR', scratch_dir):
            with patch.object(security_list, 'using_copy', True,
                              create=True):
                yield
    finally:
        shutil.rmtree(scratch_dir, True)
|
166
|
|
|
|
|
167
|
|
|
|
|
168
|
|
|
def add_security_data(adds, deletes):
    """
    Write ``add`` and ``delete`` symbol files into the
    ``leveraged_etf_list/20150127/20150125`` directory of the current
    security-list directory.

    Must be called inside ``security_list_copy()``; otherwise a plain
    ``Exception`` is raised. Each symbol is written on its own line, and
    existing files are overwritten.

    Parameters
    ----------
    adds : iterable[str]
        Symbols written to the ``add`` file.
    deletes : iterable[str]
        Symbols written to the ``delete`` file.
    """
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')

    target_dir = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    def write_symbols(filename, symbols):
        # One symbol per line.
        with open(os.path.join(target_dir, filename), 'w') as out:
            for symbol in symbols:
                out.write(symbol)
                out.write('\n')

    write_symbols("delete", deletes)
    write_symbols("add", adds)
|
188
|
|
|
|
|
189
|
|
|
|
|
190
|
|
|
def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs, (v0, v1) from values such that

    `pred(v0, v1) == True`

    Pairs are drawn from the full Cartesian product of ``values`` with
    itself, in product order.

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : generator
        Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    return (
        (left, right)
        for left, right in product(values, repeat=2)
        if pred(left, right)
    )
|
216
|
|
|
|
|
217
|
|
|
|
|
218
|
|
|
def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    Pairs are filtered on the values themselves (not their positions):
    with ``include_diagonal`` True, every pair with v0 <= v1 is kept;
    otherwise only pairs with v0 < v1.
    """
    if include_diagonal:
        keep = operator.le
    else:
        keep = operator.lt
    return all_pairs_matching_predicate(values, keep)
|
229
|
|
|
|
|
230
|
|
|
|
|
231
|
|
|
def all_subindices(index):
    """
    Return a generator over every slice ``index[start:stop]`` with
    ``0 <= start < stop <= len(index)``, i.e. every non-empty contiguous
    sub-index of the pandas Index ``index``.
    """
    bounds = product_upper_triangle(range(len(index) + 1))
    return (index[lo:hi] for lo, hi in bounds)
|
239
|
|
|
|
|
240
|
|
|
|
|
241
|
|
|
def make_rotating_equity_info(num_assets,
                              first_start,
                              frequency,
                              periods_between_starts,
                              asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : pd.DateOffset (e.g. trading_day)
        Unit used to interpret the next two arguments. It is multiplied by
        integers, so it must be an offset object rather than a string.
    periods_between_starts : int
        A new asset starts every `frequency` * `periods_between_starts`.
    asset_lifetime : int
        Each asset ends `frequency` * `asset_lifetime` after it starts.

    Returns
    -------
    info : pd.DataFrame
        DataFrame indexed 0..num_assets-1 with columns ``symbol`` ('A',
        'B', ...), ``start_date``, ``end_date`` and ``exchange`` ('TEST').
    """
    # Offset separating consecutive assets' start (and end) dates.
    stride = periods_between_starts * frequency
    symbols = [chr(ord('A') + offset) for offset in range(num_assets)]
    starts = pd.date_range(first_start, freq=stride, periods=num_assets)
    ends = pd.date_range(
        first_start + (asset_lifetime * frequency),
        freq=stride,
        periods=num_assets,
    )
    return pd.DataFrame(
        {
            'symbol': symbols,
            'start_date': starts,
            'end_date': ends,
            'exchange': 'TEST',
        },
        index=range(num_assets),
    )
|
287
|
|
|
|
|
288
|
|
|
|
|
289
|
|
|
def make_simple_equity_info(sids, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    sids : array-like of int
        Used as the DataFrame's index.
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets. If not provided, symbols are taken
        from 'A', 'B', ..., which only covers up to 26 assets.

    Returns
    -------
    info : pd.DataFrame
        DataFrame with columns ``symbol``, ``start_date``, ``end_date`` and
        ``exchange`` ('TEST').
    """
    n_assets = len(sids)
    if symbols is None:
        symbols = list(ascii_uppercase[:n_assets])
    columns = {
        'symbol': symbols,
        'start_date': [start_date] * n_assets,
        'end_date': [end_date] * n_assets,
        'exchange': 'TEST',
    }
    return pd.DataFrame(columns, index=sids)
|
320
|
|
|
|
|
321
|
|
|
|
|
322
|
|
|
def make_future_info(first_sid,
                     root_symbols,
                     years,
                     notice_date_func,
                     expiration_date_func,
                     start_date_func=None,
                     month_codes=None):
    """
    Create a DataFrame representing futures for `root_symbols` during
    `years`.

    Generates a contract per triple of (symbol, year, month) supplied to
    `root_symbols`, `years`, and `month_codes`.

    Parameters
    ----------
    first_sid : int
        The first sid to use for assigning sids to the created contracts.
    root_symbols : list[str]
        A list of root symbols for which to create futures.
    years : list[int or str]
        Years (e.g. 2014), for which to produce individual contracts.
    notice_date_func : (Timestamp) -> Timestamp
        Function to generate notice dates from the first of the month
        associated with the asset's month code. Return NaT to simulate
        futures with no notice date.
    expiration_date_func : (Timestamp) -> Timestamp
        Function to generate expiration dates from the first of the month
        associated with the asset's month code.
    start_date_func : (Timestamp) -> Timestamp, optional
        Function to generate start dates from the first of the month
        associated with each asset's month code. Defaults to a start_date
        365 days prior to that date.
    month_codes : dict[str -> [1..12]], optional
        Dictionary of month codes for which to create contracts. Entries
        should be strings mapped to values from 1 (January) to 12
        (December). Default is zipline.assets.futures.CME_CODE_TO_MONTH.

    Returns
    -------
    futures_info : pd.DataFrame
        DataFrame of futures data, indexed by sid, suitable for passing to
        an AssetDBWriterFromDataFrame.
    """
    if month_codes is None:
        month_codes = CME_CODE_TO_MONTH

    if start_date_func is None:
        one_year = pd.Timedelta(days=365)
        start_date_func = lambda dt: dt - one_year  # noqa: E731

    year_strs = list(map(str, years))
    year_starts = [pd.Timestamp(s, tz='UTC') for s in year_strs]

    # Pairs of string/date like ('K06', 2006-05-01). ``replace(month=...)``
    # yields the first of the contract month itself; adding
    # ``MonthBegin(month_num)`` to January 1st (already a month begin) would
    # land one month too late.
    contract_suffix_to_beginning_of_month = tuple(
        (month_code + year_str[-2:], year_start.replace(month=month_num))
        for ((year_start, year_str), (month_code, month_num))
        in product(
            zip(year_starts, year_strs),
            month_codes.items(),
        )
    )

    contracts = []
    parts = product(root_symbols, contract_suffix_to_beginning_of_month)
    for sid, (root_sym, (suffix, month_begin)) in enumerate(parts, first_sid):
        contracts.append({
            'sid': sid,
            'root_symbol': root_sym,
            'symbol': root_sym + suffix,
            'start_date': start_date_func(month_begin),
            'notice_date': notice_date_func(month_begin),
            'expiration_date': expiration_date_func(month_begin),
            'contract_multiplier': 500,
        })

    frame = pd.DataFrame.from_records(contracts, index='sid')
    # ``convert_objects`` was deprecated in pandas 0.21 and later removed;
    # prefer its replacement ``infer_objects`` when it is available.
    if hasattr(frame, 'infer_objects'):
        return frame.infer_objects()
    return frame.convert_objects()
|
394
|
|
|
|
|
395
|
|
|
|
|
396
|
|
|
def make_commodity_future_info(first_sid,
                               root_symbols,
                               years,
                               month_codes=None):
    """
    Make futures testing data that simulates the notice/expiration date
    behavior of physical commodities like oil.

    Parameters
    ----------
    first_sid : int
    root_symbols : list[str]
    years : list[int]
    month_codes : dict[str -> int]

    For each contract, ``make_future_info`` passes a month-begin timestamp
    ``dt`` to the date callbacks, and this function uses:

    * notice date: ``dt - MonthBegin(2) + 19 days``
    * expiration date: ``dt - MonthBegin(1) + 19 days``
    * start date: ``dt - 365 days``

    When ``dt`` is the first of the contract month, expiration falls on the
    20th of the prior month and notice on the 20th two months prior.

    See Also
    --------
    make_future_info
    """
    nineteen_days = pd.Timedelta(days=19)
    one_year = pd.Timedelta(days=365)

    def notice_date(dt):
        return dt - MonthBegin(2) + nineteen_days

    def expiration_date(dt):
        return dt - MonthBegin(1) + nineteen_days

    def start_date(dt):
        return dt - one_year

    return make_future_info(
        first_sid=first_sid,
        root_symbols=root_symbols,
        years=years,
        notice_date_func=notice_date,
        expiration_date_func=expiration_date,
        start_date_func=start_date,
        month_codes=month_codes,
    )
|
430
|
|
|
|
|
431
|
|
|
|
|
432
|
|
|
def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that first verifies that both
    inputs have exactly the same type (e.g. both ndarrays).

    ``rtol``, ``atol``, ``err_msg`` and ``verbose`` are forwarded unchanged
    to ``assert_allclose``.

    Raises
    ------
    AssertionError
        If the input types differ or the values are not close.

    See Also
    --------
    np.testing.assert_allclose
    """
    if type(actual) != type(desired):
        raise AssertionError("%s != %s" % (type(actual), type(desired)))
    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )
|
449
|
|
|
|
|
450
|
|
|
|
|
451
|
|
|
def check_arrays(x, y, err_msg='', verbose=True):
    """
    Wrapper around np.testing.assert_array_equal that first verifies that
    both inputs have exactly the same type (e.g. both ndarrays).

    ``err_msg`` and ``verbose`` are forwarded unchanged to
    ``assert_array_equal``.

    Raises
    ------
    AssertionError
        If the input types differ or the arrays are not equal.

    See Also
    --------
    np.testing.assert_array_equal
    """
    if type(x) != type(y):
        raise AssertionError("%s != %s" % (type(x), type(y)))
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
|
463
|
|
|
|
|
464
|
|
|
|
|
465
|
|
|
class UnexpectedAttributeAccess(Exception):
    """Raised by ``ExplodingObject`` when any of its attributes is read."""
|
467
|
|
|
|
|
468
|
|
|
|
|
469
|
|
|
class ExplodingObject(object):
    """
    Object that raises ``UnexpectedAttributeAccess`` (carrying the attribute
    name) on any attribute access.

    Useful for verifying that an object is never touched during a
    function/method call.
    """

    def __getattribute__(self, attr_name):
        # Every instance attribute lookup fails. Implicit special-method
        # lookups (e.g. ``len(obj)``) go through the type and bypass this.
        raise UnexpectedAttributeAccess(attr_name)
|
478
|
|
|
|
|
479
|
|
|
|
|
480
|
|
|
class tmp_assets_db(object):
    """Create a temporary in-memory sqlite assets database.

    This is meant to be used as a context manager: ``__enter__`` creates a
    SQLAlchemy engine for ``sqlite://``, writes the frames into it and
    returns the engine; ``__exit__`` disposes of it.

    Parameters
    ----------
    **frames
        Keyword arguments forwarded to ``AssetDBWriterFromDataFrame``
        (presumably DataFrames such as ``equities=...``). If none are given,
        an ``equities`` frame is used with sids ``map(ord, 'ABC')`` (symbols
        'A', 'B', 'C'), start date ``pd.Timestamp(0)`` and end date
        ``pd.Timestamp('2015')``.
    """
    def __init__(self, **frames):
        self._eng = None
        if not frames:
            frames = {
                'equities': make_simple_equity_info(
                    list(map(ord, 'ABC')),
                    pd.Timestamp(0),
                    pd.Timestamp('2015'),
                )
            }
        self._data = AssetDBWriterFromDataFrame(**frames)

    def __enter__(self):
        self._eng = eng = create_engine('sqlite://')
        self._data.write_all(eng)
        return eng

    def __exit__(self, *excinfo):
        # Explicit check instead of ``assert`` so it still fires under
        # ``python -O``.
        if self._eng is None:
            raise AssertionError('_eng was not set in __enter__')
        self._eng.dispose()
|
510
|
|
|
|
|
511
|
|
|
|
|
512
|
|
|
class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder using an in memory sqlite db.

    Parameters
    ----------
    finder_cls : callable, optional
        Called with the temporary engine to build the object returned by
        ``__enter__``. Defaults to ``AssetFinder``.
    **frames
        Forwarded to ``tmp_assets_db`` as the data for the writer.
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        self._finder_cls = finder_cls
        super(tmp_asset_finder, self).__init__(**frames)

    def __enter__(self):
        engine = super(tmp_asset_finder, self).__enter__()
        return self._finder_cls(engine)
|
526
|
|
|
|
|
527
|
|
|
|
|
528
|
|
|
class SubTestFailures(AssertionError):
    """
    Aggregate failure raised by ``subtest`` after all runs have finished.

    Parameters
    ----------
    *failures : tuple of (dict, Exception)
        One ``(scope, exception)`` pair per failed run, where ``scope`` maps
        argument names (or positions) to the values that were passed.
    """
    def __init__(self, *failures):
        super(SubTestFailures, self).__init__(*failures)
        self.failures = failures

    def __str__(self):
        # One entry per failure: the scope as ``name=value`` pairs, then the
        # exception's type name and message.
        return 'failures:\n %s' % '\n '.join(
            '\n '.join((
                ', '.join('%s=%r' % item for item in scope.items()),
                '%s: %s' % (type(exc).__name__, exc),
            )) for scope, exc in self.failures
        )
|
539
|
|
|
|
|
540
|
|
|
|
|
541
|
|
|
def subtest(iterator, *_names):
    """Construct a subtest in a unittest.

    This works by decorating a function as a subtest. The test will be run
    by iterating over the ``iterator`` and unpacking the values into the
    function. If any of the runs fail, the failure is recorded and the rest
    of the runs still execute. Finally, if any failed, all of the failures
    are raised together as one ``SubTestFailures``.

    Parameters
    ----------
    iterator : iterable[iterable]
        The iterator of arguments to pass to the function. It is iterated on
        every call of the decorated function, so a one-shot iterator (e.g. a
        generator) is exhausted after the first call.
    *_names : str
        The names to use for each element of ``iterator``. These will be used
        to print the scope when a test fails. If not provided, it will use the
        integer index of the value as the name.

    Raises
    ------
    SubTestFailures
        If at least one run raised an ``Exception``.

    Examples
    --------

    ::

        class MyTest(TestCase):
            def test_thing(self):
                # Example usage inside another test.
                @subtest(([n] for n in range(100000)), 'n')
                def subtest(n):
                    self.assertEqual(n % 2, 0, 'n was not even')
                subtest()

            @subtest(([n] for n in range(100000)), 'n')
            def test_decorated_function(self, n):
                # Example usage to parameterize an entire function.
                self.assertEqual(n % 2, 1, 'n was not odd')

    Notes
    -----
    We use this when we:

    * Will never want to run each parameter individually.
    * Have a large parameter space we are testing
      (see tests/utils/test_events.py).

    ``nose_parameterized.expand`` will create a test for each parameter
    combination which bloats the test output and makes the travis pages slow.

    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
    nose2 do not support ``addSubTest``.
    """
    def dec(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            failures = []
            for scope in iterator:
                scope = tuple(scope)
                try:
                    f(*args + scope, **kwargs)
                except Exception as e:
                    # Label values by name, or by position when no names were
                    # given. A fresh ``count()`` per failure keeps every
                    # failure's positional labels starting at 0.
                    names = _names or count()
                    failures.append((dict(zip(names, scope)), e))
            if failures:
                raise SubTestFailures(*failures)

        return wrapped
    return dec
|
609
|
|
|
|
|
610
|
|
|
|
|
611
|
|
|
def assert_timestamp_equal(left, right, compare_nat_equal=True, msg=""):
    """
    Assert that two pandas Timestamp objects are the same.

    Parameters
    ----------
    left, right : pd.Timestamp
        The values to compare.
    compare_nat_equal : bool, optional
        When True (the default), two ``pd.NaT`` values count as equal and
        the check passes without further comparison.
    msg : str, optional
        A message to forward to `pd.util.testing.assert_equal`.
    """
    both_nat = left is pd.NaT and right is pd.NaT
    if not (compare_nat_equal and both_nat):
        return pd.util.testing.assert_equal(left, right, msg=msg)
|
627
|
|
|
|
|
628
|
|
|
|
|
629
|
|
|
def powerset(values):
    """
    Return an iterator over the power set (i.e. all subsets) of `values`.

    Subsets are yielded as tuples, ordered by size from the empty tuple up
    to all of ``values``.
    """
    return (
        subset
        for size in range(len(values) + 1)
        for subset in combinations(values, size)
    )
|
634
|
|
|
|
|
635
|
|
|
|
|
636
|
|
|
def to_series(knowledge_dates, earning_dates):
    """
    Build a Series of datetimes from two parallel sequences of date-like
    values (e.g. lists of strings).

    ``knowledge_dates`` become the datetime index and ``earning_dates`` the
    datetime values. This is just for making the test cases more readable.
    """
    index = pd.to_datetime(knowledge_dates)
    values = pd.to_datetime(earning_dates)
    return pd.Series(values, index=index)
|
646
|
|
|
|
|
647
|
|
|
|
|
648
|
|
|
def num_days_in_range(dates, start, end):
    """
    Return how many entries of the index ``dates`` fall between ``start``
    and ``end``, inclusive (``dates`` is presumably sorted, as
    ``slice_locs`` requires for labels that are not present).
    """
    first, last = dates.slice_locs(start, end)
    return last - first
|
654
|
|
|
|
|
655
|
|
|
|
|
656
|
|
|
def gen_calendars(start, stop, critical_dates):
    """
    Generate calendars to use as inputs.

    First yields, for every subset of ``critical_dates``, the daily UTC
    calendar from ``start`` to ``stop`` with that subset removed; then
    yields the slice of ``trading_days`` between ``start`` and ``stop``.
    Each calendar is wrapped in a 1-tuple.
    """
    every_day = pd.date_range(start, stop, tz='utc')
    for dropped in map(list, powerset(critical_dates)):
        # 1-tuples, presumably so each item can be unpacked as an argument
        # list (e.g. by ``subtest``).
        yield (every_day.drop(dropped),)

    # Also test with the trading calendar.
    yield (trading_days[trading_days.slice_indexer(start, stop)],)
|
667
|
|
|
|