Completed
Push — master ( ce3727...c8ca48 )
by
unknown
01:15
created

zipline.utils.wrapped()   B

Complexity

Conditions 5

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 5
dl 0
loc 14
rs 8.5455
1
from contextlib import contextmanager
2
from functools import wraps
3
from itertools import (
4
    count,
5
    product,
6
)
7
import operator
8
import os
9
import shutil
10
from string import ascii_uppercase
11
import tempfile
12
13
from logbook import FileHandler
14
from mock import patch
15
from numpy.testing import assert_allclose, assert_array_equal
16
import pandas as pd
17
from pandas.tseries.offsets import MonthBegin
18
from six import iteritems, itervalues
19
from six.moves import filter
20
from sqlalchemy import create_engine
21
22
from zipline.assets import AssetFinder
23
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame
24
from zipline.assets.futures import CME_CODE_TO_MONTH
25
from zipline.finance.blotter import ORDER_STATUS
26
from zipline.utils import security_list
27
28
29
# The Unix epoch as a timezone-aware UTC Timestamp; subtracted from other
# timestamps in str_to_seconds() below.
EPOCH = pd.Timestamp(0, tz='UTC')
30
31
32
def seconds_to_timestamp(seconds):
    """Return ``seconds`` since the Unix epoch as a UTC-localized Timestamp."""
    ts = pd.Timestamp(seconds, unit='s', tz='UTC')
    return ts
34
35
36
def to_utc(time_str):
    """Convert a string in US/Eastern time to UTC"""
    eastern = pd.Timestamp(time_str, tz='US/Eastern')
    return eastern.tz_convert('UTC')
39
40
41
def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since UTC.

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    epoch = pd.Timestamp(0, tz='UTC')
    delta = pd.Timestamp(s, tz='UTC') - epoch
    return int(delta.total_seconds())
52
53
54
def setup_logger(test, path='test.log'):
    """Install a logbook FileHandler writing to ``path`` on ``test`` and
    push it onto the application handler stack.
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
57
58
59
def teardown_logger(test):
    """Pop and close the log handler installed by ``setup_logger``."""
    handler = test.log_handler
    handler.pop_application()
    handler.close()
62
63
64
def drain_zipline(test, zipline):
    """Run a zipline simulation to completion, collecting its updates.

    Parameters
    ----------
    test : unittest.TestCase
        Unused; kept for signature compatibility with existing callers.
    zipline : iterable
        Iterable of perf-update dicts produced by the simulation.

    Returns
    -------
    (list, int)
        Every update yielded, and the total number of transactions seen
        across all 'daily_perf' updates.
    """
    # The original version also kept a `msg_counter` that was incremented
    # but never read or returned; it has been removed.
    output = []
    transaction_count = 0
    for update in zipline:
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])

    return output, transaction_count
77
78
79
def assert_single_position(test, zipline):
    """Drain ``zipline`` and assert that the simulation finishes with
    exactly one open position in the configured sid, that every order
    filled, and that the transaction count matches the test config.

    Returns the drained output and the transaction count.
    """
    output, transaction_count = drain_zipline(test, zipline)

    config = test.zipline_test_config
    if 'expected_transactions' in config:
        expected = config['expected_transactions']
    else:
        expected = config['order_count']
    test.assertEqual(expected, transaction_count)

    # The final message is the risk report, the second to last is the
    # final day's results. Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Confirm that all orders were filled: walk the output updates,
    # overwriting orders as they are updated, then check every status.
    orders_by_id = {}
    for update in output:
        daily_perf = update.get('daily_perf')
        if daily_perf and 'orders' in daily_perf:
            for order in daily_perf['orders']:
                orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "")

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count
129
130
131
class ExceptionSource(object):
    """An iterable whose iteration always fails.

    Every call to ``next``/``__next__`` evaluates ``5 / 0`` and therefore
    raises ZeroDivisionError; useful for exercising error handling in
    consumers of data sources.
    """

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def __next__(self):
        5 / 0

    # Python 2 iterator protocol uses .next(); alias it to the same body.
    next = __next__
147
148
149
@contextmanager
def security_list_copy():
    """Context manager that redirects ``security_list.SECURITY_LISTS_DIR``
    to a temporary copy of the real directory so tests can mutate it.

    Also sets ``security_list.using_copy = True`` for the duration so
    ``add_security_data`` can verify it is running inside this manager.
    The temporary tree is removed on exit.
    """
    old_dir = security_list.SECURITY_LISTS_DIR
    new_dir = tempfile.mkdtemp()
    try:
        for subdir in os.listdir(old_dir):
            shutil.copytree(os.path.join(old_dir, subdir),
                            os.path.join(new_dir, subdir))
        # BUG FIX: the patch/yield used to live inside the loop above.
        # With more than one subdirectory, the @contextmanager generator
        # would yield more than once (RuntimeError: generator didn't
        # stop), and later subdirs were only copied after the with-body
        # had already run.  Yield exactly once, after the full copy.
        with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
                patch.object(security_list, 'using_copy', True,
                             create=True):
            yield
    finally:
        shutil.rmtree(new_dir, True)
163
164
165
def add_security_data(adds, deletes):
    """Write ``adds`` and ``deletes`` symbol files into the leveraged ETF
    security-list directory.

    Must be called inside the ``security_list_copy`` context manager,
    which marks the module with a ``using_copy`` attribute.
    """
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')
    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)

    def write_symbols(filename, symbols):
        # One symbol per line.
        with open(os.path.join(directory, filename), 'w') as f:
            for sym in symbols:
                f.write(sym)
                f.write('\n')

    write_symbols("delete", deletes)
    write_symbols("add", adds)
185
186
187
def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs, (v0, v1) from values such that

    `pred(v0, v1) == True`

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : generator
       Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    # Lazily filter the full cartesian square of `values`.
    return (pair for pair in product(values, repeat=2) if pred(*pair))
213
214
215
def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    If `include_diagonal` is True, returns all pairs such that v0 <= v1.
    If `include_diagonal` is False, returns all pairs such that v0 < v1.
    """
    if include_diagonal:
        comparison = operator.le
    else:
        comparison = operator.lt
    return all_pairs_matching_predicate(values, comparison)
226
227
228
def all_subindices(index):
    """
    Return all valid sub-indices of a pandas Index.

    Yields every slice ``index[start:stop]`` with ``start < stop`` over
    the range [0, len(index)].
    """
    bounds = range(len(index) + 1)
    return (
        index[lo:hi]
        for lo, hi in product_upper_triangle(bounds)
    )
236
237
238
def make_rotating_equity_info(num_assets,
                              first_start,
                              frequency,
                              periods_between_starts,
                              asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.Offset (e.g. trading_day)
        Frequency used to interpret next two arguments.
    periods_between_starts : int
        Create a new asset every `frequency` * `periods_between_new`
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime` days.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    frame_index = range(num_assets)
    spacing = periods_between_starts * frequency

    # Start a new asset every `periods_between_starts` periods.
    start_dates = pd.date_range(
        first_start,
        freq=spacing,
        periods=num_assets,
    )
    # Each asset lasts for `asset_lifetime` periods after its start.
    end_dates = pd.date_range(
        first_start + (asset_lifetime * frequency),
        freq=spacing,
        periods=num_assets,
    )

    return pd.DataFrame(
        {
            'symbol': [chr(ord('A') + i) for i in frame_index],
            'start_date': start_dates,
            'end_date': end_dates,
            'exchange': 'TEST',
        },
        index=frame_index,
    )
284
285
286
def make_simple_equity_info(sids, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    sids : array-like of int
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(sids)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    frame = {
        'symbol': symbols,
        'start_date': [start_date] * num_assets,
        'end_date': [end_date] * num_assets,
        'exchange': 'TEST',
    }
    return pd.DataFrame(frame, index=sids)
317
318
319
def make_future_info(first_sid,
                     root_symbols,
                     years,
                     notice_date_func,
                     expiration_date_func,
                     start_date_func,
                     month_codes=None):
    """
    Create a DataFrame representing futures for `root_symbols` during `year`.

    Generates a contract per triple of (symbol, year, month) supplied to
    `root_symbols`, `years`, and `month_codes`.

    Parameters
    ----------
    first_sid : int
        The first sid to use for assigning sids to the created contracts.
    root_symbols : list[str]
        A list of root symbols for which to create futures.
    years : list[int or str]
        Years (e.g. 2014), for which to produce individual contracts.
    notice_date_func : (Timestamp) -> Timestamp
        Function to generate notice dates from first of the month associated
        with asset month code.  Return NaT to simulate futures with no notice
        date.
    expiration_date_func : (Timestamp) -> Timestamp
        Function to generate expiration dates from first of the month
        associated with asset month code.
    start_date_func : (Timestamp) -> Timestamp, optional
        Function to generate start dates from first of the month associated
        with each asset month code.  Defaults to a start_date one year prior
        to the month_code date.
    month_codes : dict[str -> [1..12]], optional
        Dictionary of month codes for which to create contracts.  Entries
        should be strings mapped to values from 1 (January) to 12 (December).
        Default is zipline.futures.CME_CODE_TO_MONTH

    Returns
    -------
    futures_info : pd.DataFrame
        DataFrame of futures data suitable for passing to an
        AssetDBWriterFromDataFrame.
    """
    if month_codes is None:
        month_codes = CME_CODE_TO_MONTH

    year_strs = list(map(str, years))
    years = [pd.Timestamp(s, tz='UTC') for s in year_strs]

    # Pairs of string/date like ('K06', 2006-05-01)
    contract_suffix_to_beginning_of_month = tuple(
        (month_code + year_str[-2:], year + MonthBegin(month_num))
        for ((year, year_str), (month_code, month_num))
        in product(
            zip(years, year_strs),
            iteritems(month_codes),
        )
    )

    contracts = []
    parts = product(root_symbols, contract_suffix_to_beginning_of_month)
    for sid, (root_sym, (suffix, month_begin)) in enumerate(parts, first_sid):
        contracts.append({
            'sid': sid,
            'root_symbol': root_sym,
            'symbol': root_sym + suffix,
            'start_date': start_date_func(month_begin),
            'notice_date': notice_date_func(month_begin),
            # BUG FIX: this previously called notice_date_func, which made
            # every contract's expiration equal its notice date and left
            # the `expiration_date_func` parameter entirely unused.
            'expiration_date': expiration_date_func(month_begin),
            'contract_multiplier': 500,
        })
    return pd.DataFrame.from_records(contracts, index='sid').convert_objects()
391
392
393
def make_commodity_future_info(first_sid,
                               root_symbols,
                               years,
                               month_codes=None):
    """
    Make futures testing data that simulates the notice/expiration date
    behavior of physical commodities like oil.

    Parameters
    ----------
    first_sid : int
    root_symbols : list[str]
    years : list[int]
    month_codes : dict[str -> int]

    Expiration dates are on the 20th of the month prior to the month code.
    Notice dates are are on the 20th two months prior to the month code.
    Start dates are one year before the contract month.

    See Also
    --------
    make_future_info
    """
    nineteen_days = pd.Timedelta(days=19)
    one_year = pd.Timedelta(days=365)

    def notice_date(dt):
        # 20th of the month, two months before the contract month.
        return dt - MonthBegin(2) + nineteen_days

    def expiration_date(dt):
        # 20th of the month immediately before the contract month.
        return dt - MonthBegin(1) + nineteen_days

    def start_date(dt):
        # Contracts begin trading one year before their month code.
        return dt - one_year

    return make_future_info(
        first_sid=first_sid,
        root_symbols=root_symbols,
        years=years,
        notice_date_func=notice_date,
        expiration_date_func=expiration_date,
        start_date_func=start_date,
        month_codes=month_codes,
    )
427
428
429
def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that also verifies that inputs
    are of the same type.

    Raises
    ------
    AssertionError
        If ``type(actual) != type(desired)``, or if the values are not
        close within ``rtol``/``atol``.

    See Also
    --------
    np.assert_allclose
    """
    if type(actual) != type(desired):
        raise AssertionError("%s != %s" % (type(actual), type(desired)))
    # BUG FIX: rtol/atol/verbose were previously accepted but never
    # forwarded (verbose was hardcoded to True), so callers' tolerances
    # were silently ignored.
    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )
446
447
448
def check_arrays(x, y, err_msg='', verbose=True):
    """
    Wrapper around np.testing.assert_array_equal that also verifies that
    inputs are of the same type.

    Raises
    ------
    AssertionError
        If ``type(x) != type(y)``, or if the arrays are not equal.

    See Also
    --------
    np.assert_array_equal
    """
    if type(x) != type(y):
        raise AssertionError("%s != %s" % (type(x), type(y)))
    # BUG FIX: `verbose` was previously hardcoded to True in this call,
    # ignoring the caller's argument.
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
460
461
462
class UnexpectedAttributeAccess(Exception):
    """Raised by ExplodingObject when any of its attributes is accessed."""
    pass
464
465
466
class ExplodingObject(object):
    """
    Object that will raise an exception on any attribute access.

    Useful for verifying that an object is never touched during a
    function/method call.
    """
    def __getattribute__(self, name):
        # Intercepts *every* instance attribute lookup (no names are
        # exempted) and reports the offending attribute name.
        raise UnexpectedAttributeAccess(name)
475
476
477
class tmp_assets_db(object):
    """Create a temporary assets sqlite database.
    This is meant to be used as a context manager.

    Parameters
    ----------
    data : pd.DataFrame, optional
        The data to feed to the writer. By default this maps:
        ('A', 'B', 'C') -> map(ord, 'ABC')
    """
    @staticmethod
    def _default_frames():
        # Three equities with sids ord('A')..ord('C'), alive from the
        # epoch through the start of 2015.
        return {
            'equities': make_simple_equity_info(
                list(map(ord, 'ABC')),
                pd.Timestamp(0),
                pd.Timestamp('2015'),
            ),
        }

    def __init__(self, **frames):
        self._eng = None
        self._data = AssetDBWriterFromDataFrame(
            **(frames or self._default_frames())
        )

    def __enter__(self):
        eng = create_engine('sqlite://')
        self._eng = eng
        self._data.write_all(eng)
        return eng

    def __exit__(self, *excinfo):
        assert self._eng is not None, '_eng was not set in __enter__'
        self._eng.dispose()
507
508
509
class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder using an in memory sqlite db.

    Parameters
    ----------
    finder_cls : type, optional
        The class to wrap the engine with; defaults to AssetFinder.
    data : dict, optional
        The data to feed to the writer
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        super(tmp_asset_finder, self).__init__(**frames)
        self._finder_cls = finder_cls

    def __enter__(self):
        engine = super(tmp_asset_finder, self).__enter__()
        return self._finder_cls(engine)
523
524
525
class SubTestFailures(AssertionError):
    """Aggregate of per-scope failures collected by ``subtest``.

    Parameters
    ----------
    *failures : (dict, Exception)
        Pairs of (scope mapping names to values, the exception raised
        for that scope).
    """
    def __init__(self, *failures):
        self.failures = failures

    def __str__(self):
        # BUG FIX: the generator expression below used to be followed by
        # a trailing comma inside the join() call; a comma after an
        # unparenthesized generator-expression argument is a SyntaxError
        # on Python >= 3.7.
        return 'failures:\n  %s' % '\n  '.join(
            '\n    '.join((
                ', '.join('%s=%r' % item for item in scope.items()),
                '%s: %s' % (type(exc).__name__, exc),
            ))
            for scope, exc in self.failures
        )
536
537
538
def subtest(iterator, *_names):
    """Construct a subtest in a unittest.

    This works by decorating a function as a subtest. The test will be run
    by iterating over the ``iterator`` and *unpacking the values into the
    function. If any of the runs fail, the result will be put into a set and
    the rest of the tests will be run. Finally, if any failed, all of the
    results will be dumped as one failure.

    Parameters
    ----------
    iterator : iterable[iterable]
        The iterator of arguments to pass to the function.
    *name : iterator[str]
        The names to use for each element of ``iterator``. These will be used
        to print the scope when a test fails. If not provided, it will use the
        integer index of the value as the name.

    Examples
    --------

    ::

       class MyTest(TestCase):
           def test_thing(self):
               # Example usage inside another test.
               @subtest(([n] for n in range(100000)), 'n')
               def subtest(n):
                   self.assertEqual(n % 2, 0, 'n was not even')
               subtest()

           @subtest(([n] for n in range(100000)), 'n')
           def test_decorated_function(self, n):
               # Example usage to parameterize an entire function.
               self.assertEqual(n % 2, 1, 'n was not odd')

    Notes
    -----
    We use this when we:

    * Will never want to run each parameter individually.
    * Have a large parameter space we are testing
      (see tests/utils/test_events.py).

    ``nose_parameterized.expand`` will create a test for each parameter
    combination which bloats the test output and makes the travis pages slow.

    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
    nose2 do not support ``addSubTest``.
    """
    def dec(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            failures = []
            for scope in iterator:
                scope = tuple(scope)
                try:
                    f(*args + scope, **kwargs)
                except Exception as e:
                    # BUG FIX: previously a single shared itertools.count()
                    # was created on the first unnamed failure and partially
                    # consumed by each zip(), so the second and later
                    # failures were labeled with wrong, ever-increasing
                    # indices.  Restart the positional names per scope.
                    names = _names if _names else range(len(scope))
                    failures.append((dict(zip(names, scope)), e))
            if failures:
                raise SubTestFailures(*failures)

        return wrapped
    return dec
606