from contextlib import contextmanager
from functools import wraps
from itertools import (
    combinations,
    count,
    product,
)
import operator
import os
import shutil
from string import ascii_uppercase
import tempfile

from logbook import FileHandler
from mock import patch
from numpy.testing import assert_allclose, assert_array_equal
import pandas as pd
from pandas.tseries.offsets import MonthBegin
from six import iteritems, itervalues
from six.moves import filter
from sqlalchemy import create_engine
from toolz import concat

from zipline.assets import AssetFinder
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame
from zipline.assets.futures import CME_CODE_TO_MONTH
from zipline.finance.order import ORDER_STATUS
from zipline.utils import security_list
from zipline.utils.tradingcalendar import trading_days


EPOCH = pd.Timestamp(0, tz='UTC')


def seconds_to_timestamp(seconds):
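    """Convert integer seconds since the UNIX epoch into a UTC Timestamp.

    A quick sanity check, mirroring the ``str_to_seconds`` doctest below:

    >>> from pandas import Timestamp
    >>> seconds_to_timestamp(1388534400) == Timestamp('2014-01-01', tz='UTC')
    True
    """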
    return pd.Timestamp(seconds, unit='s', tz='UTC')


def to_utc(time_str):
    """Convert a string in US/Eastern time to UTC"""
    return pd.Timestamp(time_str, tz='US/Eastern').tz_convert('UTC')


def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since the
    UNIX epoch (UTC).

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    return int((pd.Timestamp(s, tz='UTC') - EPOCH).total_seconds())


def setup_logger(test, path='test.log'):
    test.log_handler = FileHandler(path)
    test.log_handler.push_application()


def teardown_logger(test):
    test.log_handler.pop_application()
    test.log_handler.close()


def drain_zipline(test, zipline):
    output = []
    transaction_count = 0
    msg_counter = 0
    # start the simulation
    for update in zipline:
        msg_counter += 1
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += \
                len(update['daily_perf']['transactions'])

    return output, transaction_count


def assert_single_position(test, zipline):

    output, transaction_count = drain_zipline(test, zipline)

    if 'expected_transactions' in test.zipline_test_config:
        test.assertEqual(
            test.zipline_test_config['expected_transactions'],
            transaction_count
        )
    else:
        test.assertEqual(
            test.zipline_test_config['order_count'],
            transaction_count
        )

    # The final message is the risk report, the second to last is the
    # final day's results. Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Confirm that all orders were filled: iterate over the output
    # updates, overwriting orders as they are updated, then check the
    # status on all of them.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' in update:
            if 'orders' in update['daily_perf']:
                for order in update['daily_perf']['orders']:
                    orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "all orders should be filled"
        )

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = test.zipline_test_config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count


class ExceptionSource(object):

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    # Raise ZeroDivisionError under both the Python 2 (`next`) and
    # Python 3 (`__next__`) iterator protocols.
    def next(self):
        5 / 0

    def __next__(self):
        5 / 0


@contextmanager
def security_list_copy():
    old_dir = security_list.SECURITY_LISTS_DIR
    new_dir = tempfile.mkdtemp()
    try:
        for subdir in os.listdir(old_dir):
            shutil.copytree(os.path.join(old_dir, subdir),
                            os.path.join(new_dir, subdir))
        with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
                patch.object(security_list, 'using_copy', True,
                             create=True):
            yield
    finally:
        shutil.rmtree(new_dir, True)


def add_security_data(adds, deletes):
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')
    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)
    del_path = os.path.join(directory, "delete")
    with open(del_path, 'w') as f:
        for sym in deletes:
            f.write(sym)
            f.write('\n')
    add_path = os.path.join(directory, "add")
    with open(add_path, 'w') as f:
        for sym in adds:
            f.write(sym)
            f.write('\n')


def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs, (v0, v1), from values such that

    `pred(v0, v1) == True`

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : generator
        Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    return filter(lambda pair: pred(*pair), product(values, repeat=2))


def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    If `include_diagonal` is True, returns all pairs such that v0 <= v1.
    If `include_diagonal` is False, returns all pairs such that v0 < v1.
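
    Examples
    --------
    Pairs are emitted in the order generated by ``itertools.product``:

    >>> list(product_upper_triangle(range(3)))
    [(0, 1), (0, 2), (1, 2)]
    >>> list(product_upper_triangle(range(3), include_diagonal=True))
    [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]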
    """
    return all_pairs_matching_predicate(
        values,
        operator.le if include_diagonal else operator.lt,
    )


def all_subindices(index):
    """
    Return all valid sub-indices of a pandas Index.
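
    Examples
    --------
    For an index of length ``n``, this yields one slice per pair
    ``0 <= start < stop <= n``:

    >>> import pandas as pd
    >>> len(list(all_subindices(pd.Index([10, 20, 30]))))
    6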
    """
    return (
        index[start:stop]
        for start, stop in product_upper_triangle(range(len(index) + 1))
    )


def make_rotating_equity_info(num_assets,
                              first_start,
                              frequency,
                              periods_between_starts,
                              asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.Offset (e.g. trading_day)
        Frequency used to interpret the next two arguments.
    periods_between_starts : int
        Create a new asset every `frequency` * `periods_between_starts`
        periods.
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime` periods.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
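
    Examples
    --------
    A small sketch of the rotation pattern, using business days as an
    arbitrary illustrative frequency:

    >>> import pandas as pd
    >>> from pandas.tseries.offsets import BDay
    >>> info = make_rotating_equity_info(
    ...     num_assets=3,
    ...     first_start=pd.Timestamp('2015-01-05'),
    ...     frequency=BDay(),
    ...     periods_between_starts=2,
    ...     asset_lifetime=5,
    ... )
    >>> info['symbol'].tolist()
    ['A', 'B', 'C']
    >>> [ts.strftime('%Y-%m-%d') for ts in info['start_date']]
    ['2015-01-05', '2015-01-07', '2015-01-09']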
    """
    return pd.DataFrame(
        {
            'symbol': [chr(ord('A') + i) for i in range(num_assets)],
            # Start a new asset every `periods_between_starts` periods.
            'start_date': pd.date_range(
                first_start,
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            # Each asset lasts for `asset_lifetime` periods.
            'end_date': pd.date_range(
                first_start + (asset_lifetime * frequency),
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            'exchange': 'TEST',
        },
        index=range(num_assets),
    )


def make_simple_equity_info(sids, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    sids : array-like of int
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
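
    Examples
    --------
    Default symbols are drawn from the uppercase alphabet:

    >>> import pandas as pd
    >>> info = make_simple_equity_info(
    ...     [1, 2],
    ...     start_date=pd.Timestamp('2014-01-01'),
    ...     end_date=pd.Timestamp('2015-01-01'),
    ... )
    >>> info['symbol'].tolist()
    ['A', 'B']
    >>> info.index.tolist()
    [1, 2]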
    """
    num_assets = len(sids)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    return pd.DataFrame(
        {
            'symbol': symbols,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        },
        index=sids,
    )


def make_future_info(first_sid,
                     root_symbols,
                     years,
                     notice_date_func,
                     expiration_date_func,
                     start_date_func,
                     month_codes=None):
    """
    Create a DataFrame representing futures for `root_symbols` during `years`.

    Generates a contract per triple of (symbol, year, month) supplied to
    `root_symbols`, `years`, and `month_codes`.

    Parameters
    ----------
    first_sid : int
        The first sid to use for assigning sids to the created contracts.
    root_symbols : list[str]
        A list of root symbols for which to create futures.
    years : list[int or str]
        Years (e.g. 2014), for which to produce individual contracts.
    notice_date_func : (Timestamp) -> Timestamp
        Function to generate notice dates from the first of the month
        associated with each asset's month code. Return NaT to simulate
        futures with no notice date.
    expiration_date_func : (Timestamp) -> Timestamp
        Function to generate expiration dates from the first of the month
        associated with each asset's month code.
    start_date_func : (Timestamp) -> Timestamp
        Function to generate start dates from the first of the month
        associated with each asset's month code. (`make_commodity_future_info`
        uses a start date one year prior to the contract month.)
    month_codes : dict[str -> [1..12]], optional
        Dictionary of month codes for which to create contracts. Entries
        should be strings mapped to values from 1 (January) to 12 (December).
        Default is zipline.assets.futures.CME_CODE_TO_MONTH.

    Returns
    -------
    futures_info : pd.DataFrame
        DataFrame of futures data suitable for passing to an
        AssetDBWriterFromDataFrame.
    """
    if month_codes is None:
        month_codes = CME_CODE_TO_MONTH

    year_strs = list(map(str, years))
    years = [pd.Timestamp(s, tz='UTC') for s in year_strs]

    # Pairs of string/date like ('K06', 2006-05-01).
    contract_suffix_to_beginning_of_month = tuple(
        (month_code + year_str[-2:], year + MonthBegin(month_num))
        for ((year, year_str), (month_code, month_num))
        in product(
            zip(years, year_strs),
            iteritems(month_codes),
        )
    )

    contracts = []
    parts = product(root_symbols, contract_suffix_to_beginning_of_month)
    for sid, (root_sym, (suffix, month_begin)) in enumerate(parts, first_sid):
        contracts.append({
            'sid': sid,
            'root_symbol': root_sym,
            'symbol': root_sym + suffix,
            'start_date': start_date_func(month_begin),
            'notice_date': notice_date_func(month_begin),
            'expiration_date': expiration_date_func(month_begin),
            'contract_multiplier': 500,
        })
    return pd.DataFrame.from_records(contracts, index='sid').convert_objects()


def make_commodity_future_info(first_sid,
                               root_symbols,
                               years,
                               month_codes=None):
    """
    Make futures testing data that simulates the notice/expiration date
    behavior of physical commodities like oil.

    Parameters
    ----------
    first_sid : int
    root_symbols : list[str]
    years : list[int]
    month_codes : dict[str -> int]

    Expiration dates are on the 20th of the month prior to the month code.
    Notice dates are on the 20th two months prior to the month code.
    Start dates are one year before the contract month.

    See Also
    --------
    make_future_info
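
    Examples
    --------
    A minimal sketch with one root symbol and the January month code ('F'):

    >>> info = make_commodity_future_info(
    ...     first_sid=1,
    ...     root_symbols=['CL'],
    ...     years=[2014],
    ...     month_codes={'F': 1},
    ... )
    >>> info['symbol'].tolist()
    ['CLF14']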
    """
    nineteen_days = pd.Timedelta(days=19)
    one_year = pd.Timedelta(days=365)
    return make_future_info(
        first_sid=first_sid,
        root_symbols=root_symbols,
        years=years,
        notice_date_func=lambda dt: dt - MonthBegin(2) + nineteen_days,
        expiration_date_func=lambda dt: dt - MonthBegin(1) + nineteen_days,
        start_date_func=lambda dt: dt - one_year,
        month_codes=month_codes,
    )


def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that also verifies that the
    inputs are of matching types.

    See Also
    --------
    np.testing.assert_allclose
    """
    if type(actual) != type(desired):
        raise AssertionError("%s != %s" % (type(actual), type(desired)))
    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )


def check_arrays(x, y, err_msg='', verbose=True):
    """
    Wrapper around np.testing.assert_array_equal that also verifies that the
    inputs are of matching types.

    See Also
    --------
    np.testing.assert_array_equal
    """
    if type(x) != type(y):
        raise AssertionError("%s != %s" % (type(x), type(y)))
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
463
|
|
|
|
464
|
|
|
|
465
|
|
|
class UnexpectedAttributeAccess(Exception): |
466
|
|
|
pass |
467
|
|
|
|
468
|
|
|
|
469
|
|
|
class ExplodingObject(object): |
470
|
|
|
""" |
471
|
|
|
Object that will raise an exception on any attribute access. |
472
|
|
|
|
473
|
|
|
Useful for verifying that an object is never touched during a |
474
|
|
|
function/method call. |
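
    Examples
    --------
    Any attribute access raises ``UnexpectedAttributeAccess``, naming the
    attribute:

    >>> try:
    ...     ExplodingObject().name
    ... except UnexpectedAttributeAccess as e:
    ...     print(e)
    name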
    """
    def __getattribute__(self, name):
        raise UnexpectedAttributeAccess(name)


class tmp_assets_db(object):
    """Create a temporary assets sqlite database.

    This is meant to be used as a context manager.

    Parameters
    ----------
    **frames : pd.DataFrame, optional
        The frames to feed to the writer. By default this maps:
        ('A', 'B', 'C') -> map(ord, 'ABC')
    """
    def __init__(self, **frames):
        self._eng = None
        if not frames:
            frames = {
                'equities': make_simple_equity_info(
                    list(map(ord, 'ABC')),
                    pd.Timestamp(0),
                    pd.Timestamp('2015'),
                )
            }
        self._data = AssetDBWriterFromDataFrame(**frames)

    def __enter__(self):
        self._eng = eng = create_engine('sqlite://')
        self._data.write_all(eng)
        return eng

    def __exit__(self, *excinfo):
        assert self._eng is not None, '_eng was not set in __enter__'
        self._eng.dispose()


class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder using an in-memory sqlite db.

    Parameters
    ----------
    **frames : pd.DataFrame, optional
        The frames to feed to the writer.
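
    Examples
    --------
    A sketch of typical usage with the default frames, whose sids are
    ``map(ord, 'ABC')``:

    >>> with tmp_asset_finder() as finder:  # doctest: +SKIP
    ...     finder.retrieve_asset(ord('A'))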
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        self._finder_cls = finder_cls
        super(tmp_asset_finder, self).__init__(**frames)

    def __enter__(self):
        return self._finder_cls(super(tmp_asset_finder, self).__enter__())


class SubTestFailures(AssertionError):
    def __init__(self, *failures):
        self.failures = failures

    def __str__(self):
        return 'failures:\n  %s' % '\n  '.join(
            '\n    '.join((
                ', '.join('%s=%r' % item for item in scope.items()),
                '%s: %s' % (type(exc).__name__, exc),
            ))
            for scope, exc in self.failures
        )


def subtest(iterator, *_names):
    """Construct a subtest in a unittest.

    This works by decorating a function as a subtest. The test will be run
    by iterating over the ``iterator`` and *-unpacking the values into the
    function. If any of the runs fail, the failure will be recorded and the
    rest of the runs will still be executed. Finally, if any failed, all of
    the results will be dumped as one failure.

    Parameters
    ----------
    iterator : iterable[iterable]
        The iterator of arguments to pass to the function.
    *_names : iterator[str]
        The names to use for each element of ``iterator``. These will be used
        to print the scope when a test fails. If not provided, it will use the
        integer index of the value as the name.

    Examples
    --------

    ::

        class MyTest(TestCase):
            def test_thing(self):
                # Example usage inside another test.
                @subtest(([n] for n in range(100000)), 'n')
                def subtest(n):
                    self.assertEqual(n % 2, 0, 'n was not even')
                subtest()

            @subtest(([n] for n in range(100000)), 'n')
            def test_decorated_function(self, n):
                # Example usage to parameterize an entire function.
                self.assertEqual(n % 2, 1, 'n was not odd')

    Notes
    -----
    We use this when we:

    * Will never want to run each parameter individually.
    * Have a large parameter space we are testing
      (see tests/utils/test_events.py).

    ``nose_parameterized.expand`` will create a test for each parameter
    combination which bloats the test output and makes the travis pages slow.

    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
    nose2 do not support ``addSubTest``.
    """
    def dec(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            names = _names
            failures = []
            for scope in iterator:
                scope = tuple(scope)
                try:
                    f(*args + scope, **kwargs)
                except Exception as e:
                    if not names:
                        names = count()
                    failures.append((dict(zip(names, scope)), e))
            if failures:
                raise SubTestFailures(*failures)

        return wrapped
    return dec


def assert_timestamp_equal(left, right, compare_nat_equal=True, msg=""):
    """
    Assert that two pandas Timestamp objects are the same.

    Parameters
    ----------
    left, right : pd.Timestamp
        The values to compare.
    compare_nat_equal : bool, optional
        Whether to consider `NaT` values equal. Defaults to True.
    msg : str, optional
        A message to forward to `pd.util.testing.assert_equal`.
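
    Examples
    --------
    With the default ``compare_nat_equal=True``, two `NaT` values pass:

    >>> import pandas as pd
    >>> assert_timestamp_equal(pd.NaT, pd.NaT)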
    """
    if compare_nat_equal and left is pd.NaT and right is pd.NaT:
        return
    return pd.util.testing.assert_equal(left, right, msg=msg)


def powerset(values):
    """
    Return the power set (i.e., the set of all subsets) of entries in
    `values`.
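
    Examples
    --------
    Subsets are yielded as tuples, in order of increasing size:

    >>> list(powerset([1, 2]))
    [(), (1,), (2,), (1, 2)]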
    """
    return concat(combinations(values, i) for i in range(len(values) + 1))


def to_series(knowledge_dates, earning_dates):
    """
    Helper for converting lists of date strings to a Series of datetimes,
    indexed by knowledge date.

    This is just for making the test cases more readable.
    """
    return pd.Series(
        index=pd.to_datetime(knowledge_dates),
        data=pd.to_datetime(earning_dates),
    )


def num_days_in_range(dates, start, end):
    """
    Return the number of days in `dates` between start and end, inclusive.
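
    Examples
    --------
    Both endpoints are counted:

    >>> import pandas as pd
    >>> dates = pd.date_range('2014-01-01', '2014-01-10')
    >>> num_days_in_range(dates, '2014-01-03', '2014-01-05')
    3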
651
|
|
|
""" |
652
|
|
|
start_idx, stop_idx = dates.slice_locs(start, end) |
653
|
|
|
return stop_idx - start_idx |
654
|
|
|
|
655
|
|
|
|
656
|
|
|
def gen_calendars(start, stop, critical_dates): |
657
|
|
|
""" |
658
|
|
|
Generate calendars to use as inputs. |
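
    Yields one calendar per subset of ``critical_dates`` dropped from the
    full date range, plus the real trading calendar, so ``n`` critical
    dates produce ``2 ** n + 1`` calendars:

    >>> import pandas as pd
    >>> critical = [pd.Timestamp('2014-01-02', tz='utc')]
    >>> len(list(gen_calendars('2014-01-01', '2014-01-05', critical)))
    3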
    """
    all_dates = pd.date_range(start, stop, tz='utc')
    for to_drop in map(list, powerset(critical_dates)):
        # Have to yield tuples.
        yield (all_dates.drop(to_drop),)

    # Also test with the trading calendar.
    yield (trading_days[trading_days.slice_indexer(start, stop)],)