Completed
Push — master ( 5c3ca1...d3d362 )
by Joe
01:27
created

zipline.utils.gen_calendars()   A

Complexity

Conditions 2

Size

Total Lines 11

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 11
rs 9.4286
1
from contextlib import contextmanager
2
from functools import wraps
3
from itertools import (
4
    combinations,
5
    count,
6
    product,
7
)
8
import operator
9
import os
10
import shutil
11
from string import ascii_uppercase
12
import tempfile
13
14
from logbook import FileHandler
15
from mock import patch
16
from numpy.testing import assert_allclose, assert_array_equal
17
import pandas as pd
18
from pandas.tseries.offsets import MonthBegin
19
from six import iteritems, itervalues
20
from six.moves import filter
21
from sqlalchemy import create_engine
22
from toolz import concat
23
24
from zipline.assets import AssetFinder
25
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame
26
from zipline.assets.futures import CME_CODE_TO_MONTH
27
from zipline.finance.order import ORDER_STATUS
28
from zipline.utils import security_list
29
from zipline.utils.tradingcalendar import trading_days
30
31
32
# The Unix epoch (1970-01-01) as a tz-aware UTC Timestamp; used as the
# reference point when converting datetimes to integer seconds.
EPOCH = pd.Timestamp(0, tz='UTC')
33
34
35
def seconds_to_timestamp(seconds):
    """Return ``seconds`` since the Unix epoch as a tz-aware UTC Timestamp."""
    return pd.Timestamp(seconds, tz='UTC', unit='s')
37
38
39
def to_utc(time_str):
    """Interpret ``time_str`` as a US/Eastern wall-clock time and return the
    equivalent UTC Timestamp.
    """
    eastern = pd.Timestamp(time_str, tz='US/Eastern')
    return eastern.tz_convert('UTC')
42
43
44
def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since the
    Unix epoch, interpreted as UTC.

    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    as_utc = pd.Timestamp(s, tz='UTC')
    return int((as_utc - pd.Timestamp(0, tz='UTC')).total_seconds())
55
56
57
def setup_logger(test, path='test.log'):
    """Attach a logbook ``FileHandler`` writing to ``path`` onto ``test``
    and push it onto the application's handler stack.
    """
    handler = FileHandler(path)
    test.log_handler = handler
    handler.push_application()
60
61
62
def teardown_logger(test):
    """Pop and close the log handler previously installed by
    ``setup_logger``.
    """
    handler = test.log_handler
    handler.pop_application()
    handler.close()
65
66
67
def drain_zipline(test, zipline):
    """Run ``zipline`` to completion, collecting every update it yields.

    Returns a tuple of (list of all updates, total number of transactions
    seen across daily performance messages).  ``test`` is unused but kept
    for interface compatibility with the helpers that call this.
    """
    output = []
    transaction_count = 0
    for update in zipline:
        output.append(update)
        # Only daily performance packets carry transactions.
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])
    return output, transaction_count
80
81
82
def assert_single_position(test, zipline):
    """Run ``zipline`` and assert that the simulation ends with exactly one
    position, held in sid ``test.zipline_test_config['sid']``, and that
    every order was completely filled.

    Returns the same (output, transaction_count) pair as ``drain_zipline``.
    """
    output, transaction_count = drain_zipline(test, zipline)

    config = test.zipline_test_config
    if 'expected_transactions' in config:
        expected = config['expected_transactions']
    else:
        expected = config['order_count']
    test.assertEqual(expected, transaction_count)

    # The final message is the risk report; the second to last is the
    # final day's results.  Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Confirm that all orders were filled: walk the updates, keeping only
    # the most recent version of each order, then check every final status.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' not in update:
            continue
        daily_perf = update['daily_perf']
        if 'orders' not in daily_perf:
            continue
        for order in daily_perf['orders']:
            orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(order['status'], ORDER_STATUS.FILLED, "")

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count
132
133
134
class ExceptionSource(object):
    """A data source whose iteration always fails.

    Every call to ``next``/``__next__`` evaluates ``5 / 0`` and therefore
    raises ``ZeroDivisionError``; useful for exercising error handling in
    code that consumes data sources.
    """

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def next(self):
        # Python 2 spelling of the iteration protocol.
        5 / 0

    def __next__(self):
        # Python 3 spelling of the iteration protocol.
        5 / 0
150
151
152
@contextmanager
def security_list_copy():
    """Context manager that points ``security_list.SECURITY_LISTS_DIR`` at a
    temporary copy of the security lists directory.

    While the context is active, ``security_list.using_copy`` is set so that
    helpers like ``add_security_data`` can verify they are mutating the copy
    rather than the real data.  The temporary directory is removed on exit.
    """
    old_dir = security_list.SECURITY_LISTS_DIR
    new_dir = tempfile.mkdtemp()
    try:
        for subdir in os.listdir(old_dir):
            shutil.copytree(os.path.join(old_dir, subdir),
                            os.path.join(new_dir, subdir))
        # Yield exactly once, *after* all subdirectories have been copied.
        # The previous version yielded inside the loop above, which raises
        # RuntimeError ("generator didn't stop") from @contextmanager
        # whenever the source directory contains more than one subdirectory.
        with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
                patch.object(security_list, 'using_copy', True,
                             create=True):
            yield
    finally:
        shutil.rmtree(new_dir, True)
166
167
168
def add_security_data(adds, deletes):
    """Write ``adds`` and ``deletes`` symbol lists into the leveraged ETF
    security list directory.

    Must be called inside a ``security_list_copy`` context so that the real
    security list data is never modified.
    """
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')
    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)

    def _write_symbols(filename, symbols):
        # One symbol per line.
        with open(os.path.join(directory, filename), 'w') as f:
            for sym in symbols:
                f.write(sym)
                f.write('\n')

    _write_symbols("delete", deletes)
    _write_symbols("add", adds)
188
189
190
def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs, (v0, v1) from values such that

    `pred(v0, v1) == True`

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : generator
       Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    # Lazily filter the full cartesian square of ``values``.
    return (pair for pair in product(values, repeat=2) if pred(*pair))
216
217
218
def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    If `include_diagonal` is True, returns all pairs such that v0 <= v1.
    If `include_diagonal` is False, returns all pairs such that v0 < v1.
    """
    if include_diagonal:
        pred = operator.le
    else:
        pred = operator.lt
    return all_pairs_matching_predicate(values, pred)
229
230
231
def all_subindices(index):
    """
    Return a generator of all contiguous sub-indices of a pandas Index,
    including the empty slice.
    """
    # Slice bounds range over [0, len(index)]; the upper-triangle product
    # guarantees start < stop (plus never repeats a (start, stop) pair).
    bounds = range(len(index) + 1)
    return (
        index[start:stop]
        for start, stop in product_upper_triangle(bounds)
    )
239
240
241
def make_rotating_equity_info(num_assets,
                              first_start,
                              frequency,
                              periods_between_starts,
                              asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.Offset (e.g. trading_day)
        Frequency used to interpret next two arguments.
    periods_between_starts : int
        Create a new asset every `frequency` * `periods_between_new`
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime` days.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    # Consecutive assets start `periods_between_starts` periods apart.
    spacing = periods_between_starts * frequency
    starts = pd.date_range(first_start, freq=spacing, periods=num_assets)
    # Each asset lives for `asset_lifetime` periods after its start.
    ends = pd.date_range(
        first_start + (asset_lifetime * frequency),
        freq=spacing,
        periods=num_assets,
    )
    # Symbols are consecutive capital letters: 'A', 'B', 'C', ...
    symbols = [chr(ord('A') + i) for i in range(num_assets)]
    return pd.DataFrame(
        {
            'symbol': symbols,
            'start_date': starts,
            'end_date': ends,
            'exchange': 'TEST',
        },
        index=range(num_assets),
    )
287
288
289
def make_simple_equity_info(sids, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    sids : array-like of int
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(sids)
    if symbols is None:
        # Default symbols: one capital letter per asset.
        symbols = list(ascii_uppercase[:num_assets])
    data = {
        'symbol': symbols,
        'start_date': [start_date] * num_assets,
        'end_date': [end_date] * num_assets,
        'exchange': 'TEST',
    }
    return pd.DataFrame(data, index=sids)
320
321
322
def make_future_info(first_sid,
                     root_symbols,
                     years,
                     notice_date_func,
                     expiration_date_func,
                     start_date_func,
                     month_codes=None):
    """
    Create a DataFrame representing futures for `root_symbols` during `year`.

    Generates a contract per triple of (symbol, year, month) supplied to
    `root_symbols`, `years`, and `month_codes`.

    Parameters
    ----------
    first_sid : int
        The first sid to use for assigning sids to the created contracts.
    root_symbols : list[str]
        A list of root symbols for which to create futures.
    years : list[int or str]
        Years (e.g. 2014), for which to produce individual contracts.
    notice_date_func : (Timestamp) -> Timestamp
        Function to generate notice dates from first of the month associated
        with asset month code.  Return NaT to simulate futures with no notice
        date.
    expiration_date_func : (Timestamp) -> Timestamp
        Function to generate expiration dates from first of the month
        associated with asset month code.
    start_date_func : (Timestamp) -> Timestamp, optional
        Function to generate start dates from first of the month associated
        with each asset month code.  Defaults to a start_date one year prior
        to the month_code date.
    month_codes : dict[str -> [1..12]], optional
        Dictionary of month codes for which to create contracts.  Entries
        should be strings mapped to values from 1 (January) to 12 (December).
        Default is zipline.futures.CME_CODE_TO_MONTH

    Returns
    -------
    futures_info : pd.DataFrame
        DataFrame of futures data suitable for passing to an
        AssetDBWriterFromDataFrame.
    """
    if month_codes is None:
        month_codes = CME_CODE_TO_MONTH

    year_strs = list(map(str, years))
    years = [pd.Timestamp(s, tz='UTC') for s in year_strs]

    # Pairs of string/date like ('K06', 2006-05-01)
    contract_suffix_to_beginning_of_month = tuple(
        (month_code + year_str[-2:], year + MonthBegin(month_num))
        for ((year, year_str), (month_code, month_num))
        in product(
            zip(years, year_strs),
            month_codes.items(),
        )
    )

    contracts = []
    parts = product(root_symbols, contract_suffix_to_beginning_of_month)
    for sid, (root_sym, (suffix, month_begin)) in enumerate(parts, first_sid):
        contracts.append({
            'sid': sid,
            'root_symbol': root_sym,
            'symbol': root_sym + suffix,
            'start_date': start_date_func(month_begin),
            'notice_date': notice_date_func(month_begin),
            # BUG FIX: this column was previously filled by calling
            # notice_date_func, so expiration_date always silently equaled
            # notice_date and expiration_date_func was never used.
            'expiration_date': expiration_date_func(month_begin),
            'contract_multiplier': 500,
        })
    frame = pd.DataFrame.from_records(contracts, index='sid')
    # ``convert_objects`` was deprecated and later removed from pandas;
    # prefer its replacement ``infer_objects`` when available so this helper
    # works on both old and new pandas versions.
    try:
        return frame.infer_objects()
    except AttributeError:
        return frame.convert_objects()
394
395
396
def make_commodity_future_info(first_sid,
                               root_symbols,
                               years,
                               month_codes=None):
    """
    Make futures testing data that simulates the notice/expiration date
    behavior of physical commodities like oil.

    Parameters
    ----------
    first_sid : int
    root_symbols : list[str]
    years : list[int]
    month_codes : dict[str -> int]

    Expiration dates are on the 20th of the month prior to the month code.
    Notice dates are on the 20th two months prior to the month code.
    Start dates are one year before the contract month.

    See Also
    --------
    make_future_info
    """
    nineteen_days = pd.Timedelta(days=19)
    one_year = pd.Timedelta(days=365)

    def notice(dt):
        # 20th of the month two months before the contract month.
        return dt - MonthBegin(2) + nineteen_days

    def expiration(dt):
        # 20th of the month immediately before the contract month.
        return dt - MonthBegin(1) + nineteen_days

    def start(dt):
        # One year before the contract month.
        return dt - one_year

    return make_future_info(
        first_sid=first_sid,
        root_symbols=root_symbols,
        years=years,
        notice_date_func=notice,
        expiration_date_func=expiration,
        start_date_func=start,
        month_codes=month_codes,
    )
430
431
432
def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that also verifies that the
    inputs have the same type.

    Raises
    ------
    AssertionError
        If ``type(actual) != type(desired)`` or the values are not close
        within the given tolerances.

    See Also
    --------
    np.assert_allclose
    """
    if type(actual) != type(desired):
        raise AssertionError("%s != %s" % (type(actual), type(desired)))
    # Forward *all* keyword options; the previous version dropped rtol,
    # atol, and verbose, so caller-supplied tolerances were silently
    # ignored.
    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )
449
450
451
def check_arrays(x, y, err_msg='', verbose=True):
    """
    Wrapper around np.testing.assert_array_equal that also verifies that the
    inputs have the same type.

    Raises
    ------
    AssertionError
        If ``type(x) != type(y)`` or the arrays are not element-wise equal.

    See Also
    --------
    np.assert_array_equal
    """
    if type(x) != type(y):
        raise AssertionError("%s != %s" % (type(x), type(y)))
    # Pass the caller's ``verbose`` through instead of hard-coding True.
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
463
464
465
class UnexpectedAttributeAccess(Exception):
    """Raised when an attribute of an ``ExplodingObject`` is accessed."""
    pass
467
468
469
class ExplodingObject(object):
    """
    Object that raises ``UnexpectedAttributeAccess`` the moment any of its
    attributes is read.

    Useful for verifying that an object is never touched during a
    function/method call.
    """
    def __getattribute__(self, attr_name):
        raise UnexpectedAttributeAccess(attr_name)
478
479
480
class tmp_assets_db(object):
    """Create a temporary assets sqlite database.
    This is meant to be used as a context manager.

    Parameters
    ----------
    data : pd.DataFrame, optional
        The data to feed to the writer. By default this maps:
        ('A', 'B', 'C') -> map(ord, 'ABC')
    """
    def __init__(self, **frames):
        self._eng = None
        if not frames:
            # Default to a trivial equities table for sids 65-67 ('A'-'C').
            default_equities = make_simple_equity_info(
                [ord(char) for char in 'ABC'],
                pd.Timestamp(0),
                pd.Timestamp('2015'),
            )
            frames = {'equities': default_equities}
        self._data = AssetDBWriterFromDataFrame(**frames)

    def __enter__(self):
        # In-memory sqlite database; populated on entry, disposed on exit.
        engine = create_engine('sqlite://')
        self._eng = engine
        self._data.write_all(engine)
        return engine

    def __exit__(self, *excinfo):
        assert self._eng is not None, '_eng was not set in __enter__'
        self._eng.dispose()
510
511
512
class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder using an in memory sqlite db.

    Parameters
    ----------
    data : dict, optional
        The data to feed to the writer
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        self._finder_cls = finder_cls
        super(tmp_asset_finder, self).__init__(**frames)

    def __enter__(self):
        # Wrap the engine produced by the parent in the finder class.
        engine = super(tmp_asset_finder, self).__enter__()
        return self._finder_cls(engine)
526
527
528
class SubTestFailures(AssertionError):
    """Aggregated failures raised by :func:`subtest`.

    Each failure is a ``(scope, exception)`` pair where ``scope`` maps
    parameter names to the values that produced the failure; ``__str__``
    renders one indented entry per failure.
    """
    def __init__(self, *failures):
        self.failures = failures

    def __str__(self):
        # NOTE: the previous version had a trailing comma after
        # ``self.failures`` inside the join() call, which is a SyntaxError
        # ("generator expression must be parenthesized if not sole
        # argument").
        return 'failures:\n  %s' % '\n  '.join(
            '\n    '.join((
                ', '.join('%s=%r' % item for item in scope.items()),
                '%s: %s' % (type(exc).__name__, exc),
            )) for scope, exc in self.failures
        )
539
540
541
def subtest(iterator, *_names):
    """Construct a subtest in a unittest.

    This works by decorating a function as a subtest. The test will be run
    by iterating over the ``iterator`` and *unpacking the values into the
    function. If any of the runs fail, the result will be put into a set and
    the rest of the tests will be run. Finally, if any failed, all of the
    results will be dumped as one failure.

    Parameters
    ----------
    iterator : iterable[iterable]
        The iterator of arguments to pass to the function.
    *name : iterator[str]
        The names to use for each element of ``iterator``. These will be used
        to print the scope when a test fails. If not provided, it will use the
        integer index of the value as the name.

    Examples
    --------

    ::

       class MyTest(TestCase):
           def test_thing(self):
               # Example usage inside another test.
               @subtest(([n] for n in range(100000)), 'n')
               def subtest(n):
                   self.assertEqual(n % 2, 0, 'n was not even')
               subtest()

           @subtest(([n] for n in range(100000)), 'n')
           def test_decorated_function(self, n):
               # Example usage to parameterize an entire function.
               self.assertEqual(n % 2, 1, 'n was not odd')

    Notes
    -----
    We use this when we:

    * Will never want to run each parameter individually.
    * Have a large parameter space we are testing
      (see tests/utils/test_events.py).

    ``nose_parameterized.expand`` will create a test for each parameter
    combination which bloats the test output and makes the travis pages slow.

    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
    nose2 do not support ``addSubTest``.
    """
    def dec(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            scope_names = _names
            failures = []
            for params in iterator:
                params = tuple(params)
                try:
                    f(*args + params, **kwargs)
                except Exception as exc:
                    if not scope_names:
                        # No names supplied; fall back to integer indices.
                        # The count() iterator is shared across failures,
                        # preserving the original numbering behavior.
                        scope_names = count()
                    failures.append((dict(zip(scope_names, params)), exc))
            if failures:
                raise SubTestFailures(*failures)

        return wrapped
    return dec
609
610
611
def assert_timestamp_equal(left, right, compare_nat_equal=True, msg=""):
    """
    Assert that two pandas Timestamp objects are the same.

    Parameters
    ----------
    left, right : pd.Timestamp
        The values to compare.
    compare_nat_equal : bool, optional
        Whether to consider `NaT` values equal.  Defaults to True.
    msg : str, optional
        A message to forward to `pd.util.testing.assert_equal`.
    """
    # NaT != NaT under ordinary comparison; treat the pair as equal when
    # requested and skip the pandas assertion entirely.
    both_nat = left is pd.NaT and right is pd.NaT
    if compare_nat_equal and both_nat:
        return
    return pd.util.testing.assert_equal(left, right, msg=msg)
627
628
629
def powerset(values):
    """
    Return an iterator over the power set (i.e., the set of all subsets)
    of entries in `values`, yielded as tuples in order of increasing size.
    """
    return (
        subset
        for size in range(len(values) + 1)
        for subset in combinations(values, size)
    )
634
635
636
def to_series(knowledge_dates, earning_dates):
    """
    Helper for converting a dict of strings to a Series of datetimes.

    This is just for making the test cases more readable.
    """
    series_index = pd.to_datetime(knowledge_dates)
    series_values = pd.to_datetime(earning_dates)
    return pd.Series(index=series_index, data=series_values)
646
647
648
def num_days_in_range(dates, start, end):
    """
    Return the number of days in `dates` between start and end, inclusive.
    """
    # slice_locs returns the half-open positional bounds [start_loc, end_loc)
    # covering the inclusive label range.
    start_loc, end_loc = dates.slice_locs(start, end)
    return end_loc - start_loc
654
655
656
def gen_calendars(start, stop, critical_dates):
    """
    Generate calendars to use as test inputs.

    Yields 1-tuples of DatetimeIndex: every calendar formed by dropping a
    subset of ``critical_dates`` from the daily range [start, stop],
    followed by the real trading calendar restricted to the same window.
    """
    every_day = pd.date_range(start, stop, tz='utc')
    for dropped in powerset(critical_dates):
        # Consumers unpack these as argument tuples, so yield 1-tuples.
        yield (every_day.drop(list(dropped)),)

    # Also exercise the real trading calendar.
    yield (trading_days[trading_days.slice_indexer(start, stop)],)
667