1
|
|
|
from contextlib import contextmanager |
2
|
|
|
from functools import wraps |
3
|
|
|
from itertools import ( |
4
|
|
|
count, |
5
|
|
|
product, |
6
|
|
|
) |
7
|
|
|
import operator |
8
|
|
|
import os |
9
|
|
|
import shutil |
10
|
|
|
from string import ascii_uppercase |
11
|
|
|
import tempfile |
12
|
|
|
|
13
|
|
|
from logbook import FileHandler |
14
|
|
|
from mock import patch |
15
|
|
|
from numpy.testing import assert_allclose, assert_array_equal |
16
|
|
|
import pandas as pd |
17
|
|
|
from pandas.tseries.offsets import MonthBegin |
18
|
|
|
from six import iteritems, itervalues |
19
|
|
|
from six.moves import filter |
20
|
|
|
from sqlalchemy import create_engine |
21
|
|
|
|
22
|
|
|
from zipline.assets import AssetFinder |
23
|
|
|
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame |
24
|
|
|
from zipline.assets.futures import CME_CODE_TO_MONTH |
25
|
|
|
from zipline.finance.blotter import ORDER_STATUS |
26
|
|
|
from zipline.utils import security_list |
27
|
|
|
|
28
|
|
|
|
29
|
|
|
# The Unix epoch (1970-01-01 00:00:00) as a timezone-aware UTC timestamp.
EPOCH = pd.Timestamp('1970-01-01', tz='UTC')
30
|
|
|
|
31
|
|
|
|
32
|
|
|
def seconds_to_timestamp(seconds):
    """Return the UTC ``pd.Timestamp`` lying `seconds` seconds past the epoch."""
    timestamp = pd.Timestamp(seconds, unit='s', tz='UTC')
    return timestamp
34
|
|
|
|
35
|
|
|
|
36
|
|
|
def to_utc(time_str):
    """Interpret `time_str` as US/Eastern wall-clock time and return it in UTC."""
    eastern = pd.Timestamp(time_str, tz='US/Eastern')
    return eastern.tz_convert('UTC')
39
|
|
|
|
40
|
|
|
|
41
|
|
|
def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since the epoch.

    The string is parsed as a UTC timestamp; the elapsed time since ``EPOCH``
    is converted with ``int``, which drops any fractional second.

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    elapsed = pd.Timestamp(s, tz='UTC') - EPOCH
    return int(elapsed.total_seconds())
52
|
|
|
|
53
|
|
|
|
54
|
|
|
def setup_logger(test, path='test.log'):
    """
    Create a ``FileHandler`` for `path`, store it on `test` as
    ``test.log_handler`` and activate it with ``push_application()``.
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
57
|
|
|
|
58
|
|
|
|
59
|
|
|
def teardown_logger(test):
    """
    Deactivate and close the handler stored in ``test.log_handler``.

    The handler is popped with ``pop_application()`` before it is closed.
    """
    handler = test.log_handler
    handler.pop_application()
    handler.close()
62
|
|
|
|
63
|
|
|
|
64
|
|
|
def drain_zipline(test, zipline):
    """
    Exhaust `zipline`, collecting every update and counting transactions.

    Parameters
    ----------
    test : object
        Not used by this function; kept so existing callers keep working.
    zipline : iterable
        Iterable of update messages. Each message must support ``in``; when
        it contains a 'daily_perf' entry, the length of
        ``update['daily_perf']['transactions']`` is added to the count.

    Returns
    -------
    output : list
        Every update produced by `zipline`, in order.
    transaction_count : int
        Total number of transactions found in the 'daily_perf' entries.
    """
    output = []
    transaction_count = 0
    # Consuming the iterable is what runs the simulation.
    for update in zipline:
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])

    return output, transaction_count
77
|
|
|
|
78
|
|
|
|
79
|
|
|
def assert_single_position(test, zipline):
    """
    Drain `zipline` and assert that the run ended with exactly one position.

    The following are checked through ``test.assertEqual``:

    * the transaction count equals
      ``test.zipline_test_config['expected_transactions']`` when that key is
      present, otherwise ``test.zipline_test_config['order_count']``;
    * the latest version of every order reported in a 'daily_perf' message
      has status ``ORDER_STATUS.FILLED``;
    * the second-to-last message holds exactly one position, whose 'sid'
      equals ``test.zipline_test_config['sid']``.

    Returns
    -------
    output : list
        All updates emitted by `zipline`.
    transaction_count : int
        Number of transactions counted while draining.
    """
    output, transaction_count = drain_zipline(test, zipline)

    config = test.zipline_test_config
    if 'expected_transactions' in config:
        expected_count = config['expected_transactions']
    else:
        expected_count = config['order_count']
    test.assertEqual(expected_count, transaction_count)

    # The final message is the risk report; the one before it carries the
    # final day's results. Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Walk all updates, keeping only the most recent version of each order,
    # then confirm that every one of them was filled.
    latest_orders = {}
    for update in output:
        if 'daily_perf' not in update:
            continue
        daily_perf = update['daily_perf']
        if 'orders' not in daily_perf:
            continue
        for order in daily_perf['orders']:
            latest_orders[order['id']] = order

    for order in itervalues(latest_orders):
        test.assertEqual(order['status'], ORDER_STATUS.FILLED, "")

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = test.zipline_test_config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count
129
|
|
|
|
130
|
|
|
|
131
|
|
|
class ExceptionSource(object):
    """Iterator whose every advance raises ``ZeroDivisionError``."""

    def __init__(self):
        pass

    def get_hash(self):
        """Return the fixed identifier string of this source."""
        return "ExceptionSource"

    def __iter__(self):
        return self

    def __next__(self):
        # Deliberate division by zero: consuming this source must fail.
        5 / 0

    # Python 2 spelling of the iterator protocol.
    next = __next__
147
|
|
|
|
148
|
|
|
|
149
|
|
|
@contextmanager
def security_list_copy():
    """
    Context manager that points ``security_list`` at a scratch copy of its
    data directory.

    Every entry of ``security_list.SECURITY_LISTS_DIR`` is copied with
    ``shutil.copytree`` into a fresh temporary directory. While the context
    is active, ``SECURITY_LISTS_DIR`` is patched to that directory and a
    ``using_copy`` attribute set to True is added to the module. The
    temporary directory is removed on exit, ignoring removal errors.
    """
    source_dir = security_list.SECURITY_LISTS_DIR
    scratch_dir = tempfile.mkdtemp()
    try:
        for entry in os.listdir(source_dir):
            shutil.copytree(
                os.path.join(source_dir, entry),
                os.path.join(scratch_dir, entry),
            )
        with patch.object(security_list, 'SECURITY_LISTS_DIR', scratch_dir):
            with patch.object(security_list, 'using_copy', True,
                              create=True):
                yield
    finally:
        shutil.rmtree(scratch_dir, ignore_errors=True)
163
|
|
|
|
164
|
|
|
|
165
|
|
|
def add_security_data(adds, deletes):
    """
    Write "add" and "delete" symbol files into the copied security lists.

    Parameters
    ----------
    adds : iterable of str
        Symbols written, one per line, to the "add" file.
    deletes : iterable of str
        Symbols written, one per line, to the "delete" file.

    Raises
    ------
    Exception
        If ``security_list`` has no ``using_copy`` attribute, i.e. the call
        is not made inside a ``security_list_copy`` context.
    """
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')

    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125",
    )
    if not os.path.exists(directory):
        os.makedirs(directory)

    # The "delete" file is written first, then the "add" file.
    for filename, symbols in (("delete", deletes), ("add", adds)):
        with open(os.path.join(directory, filename), 'w') as f:
            for sym in symbols:
                f.write(sym)
                f.write('\n')
185
|
|
|
|
186
|
|
|
|
187
|
|
|
def all_pairs_matching_predicate(values, pred):
    """
    Lazily yield every pair (v0, v1) drawn from `values` for which
    `pred(v0, v1)` is truthy.

    Parameters
    ----------
    values : iterable
        Source of the elements; pairs come from its Cartesian square.
    pred : function
        Two-argument callable deciding whether a pair is kept.

    Returns
    -------
    pairs_iterator : generator
        Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    candidates = product(values, repeat=2)
    return (pair for pair in candidates if pred(*pair))
213
|
|
|
|
214
|
|
|
|
215
|
|
|
def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over ordered pairs (v0, v1) drawn from `values`.

    With `include_diagonal` True the pairs satisfy v0 <= v1; with the
    default, False, they satisfy v0 < v1.
    """
    if include_diagonal:
        comparison = operator.le
    else:
        comparison = operator.lt
    return all_pairs_matching_predicate(values, comparison)
226
|
|
|
|
227
|
|
|
|
228
|
|
|
def all_subindices(index):
    """
    Return a generator over the slices ``index[start:stop]`` for every
    ``start < stop`` taken from ``range(len(index) + 1)``.
    """
    bounds = range(len(index) + 1)
    return (index[lo:hi] for lo, hi in product_upper_triangle(bounds))
236
|
|
|
|
237
|
|
|
|
238
|
|
|
def make_rotating_equity_info(num_assets,
                              first_start,
                              frequency,
                              periods_between_starts,
                              asset_lifetime):
    """
    Build a DataFrame of assets whose lifetimes are staggered, so that assets
    keep rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.Offset (e.g. trading_day)
        Frequency used to interpret next two arguments.
    periods_between_starts : int
        A new asset starts every `frequency` * `periods_between_starts`.
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime`.

    Returns
    -------
    info : pd.DataFrame
        One row per asset, indexed 0..num_assets-1, with columns 'symbol'
        (consecutive characters starting at 'A'), 'start_date', 'end_date'
        and 'exchange' (always 'TEST').
    """
    # Spacing between consecutive assets; shared by start and end dates.
    step = periods_between_starts * frequency
    sids = range(num_assets)

    symbols = [chr(ord('A') + offset) for offset in sids]
    start_dates = pd.date_range(first_start, freq=step, periods=num_assets)
    end_dates = pd.date_range(
        first_start + (asset_lifetime * frequency),
        freq=step,
        periods=num_assets,
    )

    columns = {
        'symbol': symbols,
        'start_date': start_dates,
        'end_date': end_dates,
        'exchange': 'TEST',
    }
    return pd.DataFrame(columns, index=sids)
284
|
|
|
|
285
|
|
|
|
286
|
|
|
def make_simple_equity_info(sids, start_date, end_date, symbols=None):
    """
    Build a DataFrame of assets that all exist for the whole span between
    `start_date` and `end_date`.

    Parameters
    ----------
    sids : array-like of int
        Used as the index of the result.
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame with columns 'symbol', 'start_date', 'end_date' and
        'exchange' (always 'TEST'), indexed by `sids`.
    """
    asset_count = len(sids)
    if symbols is None:
        # Default symbols are the first `asset_count` uppercase ASCII letters.
        symbols = [letter for letter in ascii_uppercase[:asset_count]]

    columns = {
        'symbol': symbols,
        'start_date': [start_date for _ in range(asset_count)],
        'end_date': [end_date for _ in range(asset_count)],
        'exchange': 'TEST',
    }
    return pd.DataFrame(columns, index=sids)
317
|
|
|
|
318
|
|
|
|
319
|
|
|
def make_future_info(first_sid,
                     root_symbols,
                     years,
                     notice_date_func,
                     expiration_date_func,
                     start_date_func=None,
                     month_codes=None):
    """
    Create a DataFrame representing futures for `root_symbols` during `years`.

    Generates a contract per triple of (symbol, year, month) supplied to
    `root_symbols`, `years`, and `month_codes`.

    Parameters
    ----------
    first_sid : int
        The first sid to use for assigning sids to the created contracts.
    root_symbols : list[str]
        A list of root symbols for which to create futures.
    years : list[int or str]
        Years (e.g. 2014), for which to produce individual contracts.
    notice_date_func : (Timestamp) -> Timestamp
        Function to generate notice dates from the month timestamp computed
        for each contract. Return NaT to simulate futures with no notice
        date.
    expiration_date_func : (Timestamp) -> Timestamp
        Function to generate expiration dates from the month timestamp
        computed for each contract.
    start_date_func : (Timestamp) -> Timestamp, optional
        Function to generate start dates from the month timestamp computed
        for each contract. Defaults to 365 days before that timestamp.
    month_codes : dict[str -> [1..12]], optional
        Dictionary of month codes for which to create contracts. Entries
        should be strings mapped to values from 1 (January) to 12 (December).
        Default is ``CME_CODE_TO_MONTH``.

    Returns
    -------
    futures_info : pd.DataFrame
        DataFrame of futures data, indexed by 'sid', suitable for passing to
        an AssetDBWriterFromDataFrame.
    """
    if month_codes is None:
        month_codes = CME_CODE_TO_MONTH

    if start_date_func is None:
        one_year = pd.Timedelta(days=365)

        def start_date_func(dt):
            return dt - one_year

    year_strs = list(map(str, years))
    years = [pd.Timestamp(s, tz='UTC') for s in year_strs]

    # Pairs of (contract suffix, month timestamp): the suffix is the month
    # code followed by the last two digits of the year, and the timestamp is
    # `year + MonthBegin(month_num)`.
    contract_suffix_to_beginning_of_month = tuple(
        (month_code + year_str[-2:], year + MonthBegin(month_num))
        for ((year, year_str), (month_code, month_num))
        in product(
            zip(years, year_strs),
            month_codes.items(),
        )
    )

    contracts = []
    parts = product(root_symbols, contract_suffix_to_beginning_of_month)
    for sid, (root_sym, (suffix, month_begin)) in enumerate(parts, first_sid):
        contracts.append({
            'sid': sid,
            'root_symbol': root_sym,
            'symbol': root_sym + suffix,
            'start_date': start_date_func(month_begin),
            'notice_date': notice_date_func(month_begin),
            # Must come from `expiration_date_func`, not the notice function.
            'expiration_date': expiration_date_func(month_begin),
            'contract_multiplier': 500,
        })

    frame = pd.DataFrame.from_records(contracts, index='sid')
    # `convert_objects` was removed in pandas 1.0. Keep using it where it
    # still exists and fall back to `infer_objects` otherwise.
    convert = getattr(frame, 'convert_objects', None)
    if convert is None:
        convert = frame.infer_objects
    return convert()
391
|
|
|
|
392
|
|
|
|
393
|
|
|
def make_commodity_future_info(first_sid,
                               root_symbols,
                               years,
                               month_codes=None):
    """
    Make futures testing data that simulates the notice/expiration date
    behavior of physical commodities like oil.

    Parameters
    ----------
    first_sid : int
    root_symbols : list[str]
    years : list[int]
    month_codes : dict[str -> int]

    For a contract month timestamp ``dt`` the date functions handed to
    ``make_future_info`` are:

    * expiration: ``dt - MonthBegin(1)`` plus 19 days, i.e. the 20th of the
      month prior to the month code;
    * notice: ``dt - MonthBegin(2)`` plus 19 days, i.e. the 20th two months
      prior to the month code;
    * start: 365 days before ``dt``.

    See Also
    --------
    make_future_info
    """
    nineteen_days = pd.Timedelta(days=19)
    one_year = pd.Timedelta(days=365)

    def notice_date(dt):
        return dt - MonthBegin(2) + nineteen_days

    def expiration_date(dt):
        return dt - MonthBegin(1) + nineteen_days

    def start_date(dt):
        return dt - one_year

    return make_future_info(
        first_sid=first_sid,
        root_symbols=root_symbols,
        years=years,
        notice_date_func=notice_date,
        expiration_date_func=expiration_date,
        start_date_func=start_date,
        month_codes=month_codes,
    )
427
|
|
|
|
428
|
|
|
|
429
|
|
|
def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that also verifies that both
    inputs have exactly the same type.

    Parameters
    ----------
    actual, desired : array-like
        Values to compare.
    rtol : float, optional
        Relative tolerance, forwarded to ``assert_allclose``.
    atol : float, optional
        Absolute tolerance, forwarded to ``assert_allclose``.
    err_msg : str, optional
        Message forwarded to ``assert_allclose``.
    verbose : bool, optional
        Forwarded to ``assert_allclose``.

    Raises
    ------
    AssertionError
        If ``type(actual) != type(desired)``, or if ``assert_allclose``
        rejects the values.

    See Also
    --------
    np.assert_allclose
    """
    if type(actual) != type(desired):
        raise AssertionError("%s != %s" % (type(actual), type(desired)))
    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )
446
|
|
|
|
447
|
|
|
|
448
|
|
|
def check_arrays(x, y, err_msg='', verbose=True):
    """
    Wrapper around np.testing.assert_array_equal that also verifies that both
    inputs have exactly the same type.

    Parameters
    ----------
    x, y : array-like
        Values to compare.
    err_msg : str, optional
        Message forwarded to ``assert_array_equal``.
    verbose : bool, optional
        Forwarded to ``assert_array_equal``.

    Raises
    ------
    AssertionError
        If ``type(x) != type(y)``, or if ``assert_array_equal`` rejects the
        values.

    See Also
    --------
    np.assert_array_equal
    """
    if type(x) != type(y):
        raise AssertionError("%s != %s" % (type(x), type(y)))
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
460
|
|
|
|
461
|
|
|
|
462
|
|
|
class UnexpectedAttributeAccess(Exception):
    """Signals that an attribute was read on an object meant to stay untouched."""
464
|
|
|
|
465
|
|
|
|
466
|
|
|
class ExplodingObject(object):
    """
    Object whose every attribute lookup raises ``UnexpectedAttributeAccess``.

    Pass an instance where a function or method is expected never to touch
    its argument; any access fails with the requested attribute name.
    """
    def __getattribute__(self, name):
        # Overriding __getattribute__ intercepts all lookups, including
        # attributes that would otherwise be found normally.
        raise UnexpectedAttributeAccess(name)
475
|
|
|
|
476
|
|
|
|
477
|
|
|
class tmp_assets_db(object):
    """Create a temporary assets sqlite database.
    This is meant to be used as a context manager.

    Parameters
    ----------
    **frames
        Keyword arguments passed to ``AssetDBWriterFromDataFrame``. When
        none are given, a single 'equities' frame is built with
        ``make_simple_equity_info`` for the sids ``map(ord, 'ABC')``,
        spanning ``pd.Timestamp(0)`` to ``pd.Timestamp('2015')``.
    """
    def __init__(self, **frames):
        self._eng = None
        if not frames:
            default_sids = [ord(letter) for letter in 'ABC']
            default_equities = make_simple_equity_info(
                default_sids,
                pd.Timestamp(0),
                pd.Timestamp('2015'),
            )
            frames = {'equities': default_equities}
        self._data = AssetDBWriterFromDataFrame(**frames)

    def __enter__(self):
        # 'sqlite://' is the URL the engine is created from; the writer
        # then writes all of its data through that engine.
        engine = create_engine('sqlite://')
        self._eng = engine
        self._data.write_all(engine)
        return engine

    def __exit__(self, *excinfo):
        assert self._eng is not None, '_eng was not set in __enter__'
        self._eng.dispose()
507
|
|
|
|
508
|
|
|
|
509
|
|
|
class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder on top of a ``tmp_assets_db``.

    Parameters
    ----------
    finder_cls : type, optional
        Class called with the database engine to build the object returned
        on entering the context. Defaults to ``AssetFinder``.
    **frames
        Keyword arguments forwarded to ``tmp_assets_db``.
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        self._finder_cls = finder_cls
        super(tmp_asset_finder, self).__init__(**frames)

    def __enter__(self):
        engine = super(tmp_asset_finder, self).__enter__()
        return self._finder_cls(engine)
523
|
|
|
|
524
|
|
|
|
525
|
|
|
class SubTestFailures(AssertionError):
    """
    Aggregate of several sub-test failures, reported as one assertion error.

    Parameters
    ----------
    *failures : tuple[dict, Exception]
        One ``(scope, exc)`` pair per failed run, where ``scope`` maps
        parameter names to the values used and ``exc`` is the exception
        that run raised.
    """
    def __init__(self, *failures):
        self.failures = failures

    def __str__(self):
        # The generator expression is the sole argument and has no trailing
        # comma: a trailing comma after an unparenthesized generator
        # expression is a SyntaxError on Python 3.7+.
        return 'failures:\n %s' % '\n '.join(
            '\n '.join((
                ', '.join('%s=%r' % item for item in scope.items()),
                '%s: %s' % (type(exc).__name__, exc),
            )) for scope, exc in self.failures
        )
536
|
|
|
|
537
|
|
|
|
538
|
|
|
def subtest(iterator, *_names):
    """Construct a subtest in a unittest.

    This works by decorating a function as a subtest. The test will be run
    by iterating over the ``iterator`` and *unpacking the values into the
    function. If any of the runs fail, the failure is recorded and the rest
    of the runs still execute. Finally, if any failed, all of the failures
    are raised together as one ``SubTestFailures``.

    Parameters
    ----------
    iterator : iterable[iterable]
        The iterator of arguments to pass to the function.
    *name : iterator[str]
        The names to use for each element of ``iterator``. These will be used
        to print the scope when a test fails. If not provided, it will use the
        integer index of the value as the name.

    Examples
    --------

    ::

       class MyTest(TestCase):
           def test_thing(self):
               # Example usage inside another test.
               @subtest(([n] for n in range(100000)), 'n')
               def subtest(n):
                   self.assertEqual(n % 2, 0, 'n was not even')
               subtest()

           @subtest(([n] for n in range(100000)), 'n')
           def test_decorated_function(self, n):
               # Example usage to parameterize an entire function.
               self.assertEqual(n % 2, 1, 'n was not odd')

    Notes
    -----
    We use this when we:

    * Will never want to run each parameter individually.
    * Have a large parameter space we are testing
      (see tests/utils/test_events.py).

    ``nose_parameterized.expand`` will create a test for each parameter
    combination which bloats the test output and makes the travis pages slow.

    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
    nose2 do not support ``addSubTest``.
    """
    def dec(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            failures = []
            for scope in iterator:
                scope = tuple(scope)
                try:
                    f(*args + scope, **kwargs)
                except Exception as e:
                    # Without explicit names, key each value by its position
                    # in the scope. A fresh counter per failure keeps the
                    # keys starting at 0 instead of continuing where the
                    # previous failure stopped.
                    names = _names or count()
                    failures.append((dict(zip(names, scope)), e))
            if failures:
                raise SubTestFailures(*failures)

        return wrapped
    return dec
606
|
|
|
|