Completed
Pull Request — master (#858) by Eddie
Ran 05:34 (queued 02:25)

zipline.utils.FakeDataPortal — Rating: A

Complexity

Total Complexity: 1

Size/Duplication

Total Lines: 4
Duplicated Lines: 0%

Metric                        Value
dl (duplicated lines)         0
loc (lines of code)           4
rs                            10
wmc (weighted method count)   1

1 Method

Rating   Name         Duplication   Size   Complexity
A        __init__()   0             2      1
from contextlib import contextmanager
from functools import wraps
from itertools import (
    count,
    product,
)
import operator
import os
import shutil
from string import ascii_uppercase
import tempfile
from bcolz import ctable

from logbook import FileHandler
from mock import patch
from numpy.testing import assert_allclose, assert_array_equal
import pandas as pd
from pandas.tseries.offsets import MonthBegin
from six import iteritems, itervalues
from six.moves import filter
from sqlalchemy import create_engine

from zipline.assets import AssetFinder
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame
from zipline.assets.futures import CME_CODE_TO_MONTH
from zipline.data.data_portal import DataPortal
from zipline.data.minute_writer import MinuteBarWriterFromDataFrames
from zipline.data.us_equity_pricing import SQLiteAdjustmentWriter, OHLC, \
    UINT32_MAX, BcolzDailyBarWriter
from zipline.finance.order import ORDER_STATUS
from zipline.utils import security_list

import numpy as np
from numpy import (
    float64,
    uint32
)


EPOCH = pd.Timestamp(0, tz='UTC')


def seconds_to_timestamp(seconds):
    return pd.Timestamp(seconds, unit='s', tz='UTC')


def to_utc(time_str):
    """Convert a string in US/Eastern time to UTC."""
    return pd.Timestamp(time_str, tz='US/Eastern').tz_convert('UTC')
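
# Illustrative check (not part of the original module): on 2014-01-02 New York
# is on EST (UTC-5), so a 9:31 a.m. Eastern market open maps to 14:31 UTC.
# assert to_utc('2014-01-02 9:31') == pd.Timestamp('2014-01-02 14:31', tz='UTC')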


def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since the
    Unix epoch (UTC).

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    return int((pd.Timestamp(s, tz='UTC') - EPOCH).total_seconds())


def setup_logger(test, path='test.log'):
    test.log_handler = FileHandler(path)
    test.log_handler.push_application()


def teardown_logger(test):
    test.log_handler.pop_application()
    test.log_handler.close()


def drain_zipline(test, zipline):
    output = []
    transaction_count = 0
    msg_counter = 0
    # start the simulation
    for update in zipline:
        msg_counter += 1
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += \
                len(update['daily_perf']['transactions'])

    return output, transaction_count


def check_algo_results(test,
                       results,
                       expected_transactions_count=None,
                       expected_order_count=None,
                       expected_positions_count=None,
                       sid=None):

    if expected_transactions_count is not None:
        txns = flatten_list(results["transactions"])
        test.assertEqual(expected_transactions_count, len(txns))

    if expected_positions_count is not None:
        raise NotImplementedError

    if expected_order_count is not None:
        # de-dup orders on id, because orders are put back into perf packets
        # whenever a txn is filled
        orders = set([order['id'] for order in
                      flatten_list(results["orders"])])

        test.assertEqual(expected_order_count, len(orders))


def flatten_list(xs):
    """Flatten one level of nesting: [[1, 2], [3]] -> [1, 2, 3]."""
    return [item for sublist in xs for item in sublist]


def assert_single_position(test, zipline):

    output, transaction_count = drain_zipline(test, zipline)

    if 'expected_transactions' in test.zipline_test_config:
        test.assertEqual(
            test.zipline_test_config['expected_transactions'],
            transaction_count
        )
    else:
        test.assertEqual(
            test.zipline_test_config['order_count'],
            transaction_count
        )

    # the final message is the risk report; the second-to-last is the final
    # day's results. Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # confirm that all orders were filled.
    # iterate over the output updates, overwriting
    # orders when they are updated. Then check the status on all.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' in update:
            if 'orders' in update['daily_perf']:
                for order in update['daily_perf']['orders']:
                    orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "")

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = test.zipline_test_config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count


class ExceptionSource(object):

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    # Raise ZeroDivisionError on iteration, under both Python 2 (`next`)
    # and Python 3 (`__next__`).
    def next(self):
        5 / 0

    def __next__(self):
        5 / 0


@contextmanager
def security_list_copy():
    old_dir = security_list.SECURITY_LISTS_DIR
    new_dir = tempfile.mkdtemp()
    try:
        for subdir in os.listdir(old_dir):
            shutil.copytree(os.path.join(old_dir, subdir),
                            os.path.join(new_dir, subdir))
            with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
                    patch.object(security_list, 'using_copy', True,
                                 create=True):
                yield
    finally:
        shutil.rmtree(new_dir, True)


def add_security_data(adds, deletes):
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')
    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)
    del_path = os.path.join(directory, "delete")
    with open(del_path, 'w') as f:
        for sym in deletes:
            f.write(sym)
            f.write('\n')
    add_path = os.path.join(directory, "add")
    with open(add_path, 'w') as f:
        for sym in adds:
            f.write(sym)
            f.write('\n')
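
# Illustrative usage (not in the original module): inside the copy context,
# security-list writes land in the temporary directory, so the real
# SECURITY_LISTS_DIR is never modified.
# with security_list_copy():
#     add_security_data(['AAPL'], [])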


def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs (v0, v1) drawn from `values` such that
    `pred(v0, v1)` is True.

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : generator
       Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    return filter(lambda pair: pred(*pair), product(values, repeat=2))


def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    If `include_diagonal` is True, returns all pairs such that v0 <= v1.
    If `include_diagonal` is False, returns all pairs such that v0 < v1.
    """
    return all_pairs_matching_predicate(
        values,
        operator.le if include_diagonal else operator.lt,
    )


def all_subindices(index):
    """
    Return all valid contiguous sub-indices of a pandas Index.
    """
    return (
        index[start:stop]
        for start, stop in product_upper_triangle(range(len(index) + 1))
    )
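
# Illustrative check (not part of the original module): all_subindices yields
# every contiguous, non-empty slice of the index exactly once.
# >>> [list(sub) for sub in all_subindices(pd.Index(['a', 'b', 'c']))]
# [['a'], ['a', 'b'], ['a', 'b', 'c'], ['b'], ['b', 'c'], ['c']]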


def make_rotating_equity_info(num_assets,
                              first_start,
                              frequency,
                              periods_between_starts,
                              asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.Offset (e.g. trading_day)
        Frequency used to interpret the next two arguments.
    periods_between_starts : int
        Create a new asset every `frequency` * `periods_between_starts`.
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime` periods.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    return pd.DataFrame(
        {
            'symbol': [chr(ord('A') + i) for i in range(num_assets)],
            # Start a new asset every `periods_between_starts` periods.
            'start_date': pd.date_range(
                first_start,
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            # Each asset lasts for `asset_lifetime` periods.
            'end_date': pd.date_range(
                first_start + (asset_lifetime * frequency),
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            'exchange': 'TEST',
        },
        index=range(num_assets),
    )
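
# Hedged example (not in the original module): five assets, a new one starting
# every two business days, each living for four business days.
# info = make_rotating_equity_info(
#     num_assets=5,
#     first_start=pd.Timestamp('2014-01-02', tz='UTC'),
#     frequency=pd.tseries.offsets.BDay(),
#     periods_between_starts=2,
#     asset_lifetime=4,
# )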


def make_simple_equity_info(sids, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    sids : array-like of int
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(sids)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    return pd.DataFrame(
        {
            'symbol': symbols,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        },
        index=sids,
    )


def make_future_info(first_sid,
                     root_symbols,
                     years,
                     notice_date_func,
                     expiration_date_func,
                     start_date_func,
                     month_codes=None):
    """
    Create a DataFrame representing futures for `root_symbols` during `years`.

    Generates a contract per triple of (symbol, year, month) supplied to
    `root_symbols`, `years`, and `month_codes`.

    Parameters
    ----------
    first_sid : int
        The first sid to use for assigning sids to the created contracts.
    root_symbols : list[str]
        A list of root symbols for which to create futures.
    years : list[int or str]
        Years (e.g. 2014), for which to produce individual contracts.
    notice_date_func : (Timestamp) -> Timestamp
        Function to generate notice dates from first of the month associated
        with asset month code.  Return NaT to simulate futures with no notice
        date.
    expiration_date_func : (Timestamp) -> Timestamp
        Function to generate expiration dates from first of the month
        associated with asset month code.
    start_date_func : (Timestamp) -> Timestamp, optional
        Function to generate start dates from first of the month associated
        with each asset month code.  Defaults to a start_date one year prior
        to the month_code date.
    month_codes : dict[str -> [1..12]], optional
        Dictionary of month codes for which to create contracts.  Entries
        should be strings mapped to values from 1 (January) to 12 (December).
        Default is zipline.assets.futures.CME_CODE_TO_MONTH.

    Returns
    -------
    futures_info : pd.DataFrame
        DataFrame of futures data suitable for passing to an
        AssetDBWriterFromDataFrame.
    """
    if month_codes is None:
        month_codes = CME_CODE_TO_MONTH

    year_strs = list(map(str, years))
    years = [pd.Timestamp(s, tz='UTC') for s in year_strs]

    # Pairs of string/date like ('K06', 2006-05-01)
    contract_suffix_to_beginning_of_month = tuple(
        (month_code + year_str[-2:], year + MonthBegin(month_num))
        for ((year, year_str), (month_code, month_num))
        in product(
            zip(years, year_strs),
            iteritems(month_codes),
        )
    )

    contracts = []
    parts = product(root_symbols, contract_suffix_to_beginning_of_month)
    for sid, (root_sym, (suffix, month_begin)) in enumerate(parts, first_sid):
        contracts.append({
            'sid': sid,
            'root_symbol': root_sym,
            'symbol': root_sym + suffix,
            'start_date': start_date_func(month_begin),
            'notice_date': notice_date_func(month_begin),
            # Use expiration_date_func here; the original mistakenly called
            # notice_date_func, leaving expiration_date_func unused.
            'expiration_date': expiration_date_func(month_begin),
            'contract_multiplier': 500,
        })
    return pd.DataFrame.from_records(contracts, index='sid').convert_objects()


def make_commodity_future_info(first_sid,
                               root_symbols,
                               years,
                               month_codes=None):
    """
    Make futures testing data that simulates the notice/expiration date
    behavior of physical commodities like oil.

    Parameters
    ----------
    first_sid : int
    root_symbols : list[str]
    years : list[int]
    month_codes : dict[str -> int]

    Expiration dates are on the 20th of the month prior to the month code.
    Notice dates are on the 20th two months prior to the month code.
    Start dates are one year before the contract month.

    See Also
    --------
    make_future_info
    """
    nineteen_days = pd.Timedelta(days=19)
    one_year = pd.Timedelta(days=365)
    return make_future_info(
        first_sid=first_sid,
        root_symbols=root_symbols,
        years=years,
        notice_date_func=lambda dt: dt - MonthBegin(2) + nineteen_days,
        expiration_date_func=lambda dt: dt - MonthBegin(1) + nineteen_days,
        start_date_func=lambda dt: dt - one_year,
        month_codes=month_codes,
    )
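
# Illustrative call (not in the original module): two root symbols over two
# years with the default CME month codes yields 2 * 2 * 12 = 48 contracts.
# futures = make_commodity_future_info(
#     first_sid=1,
#     root_symbols=['CL', 'NG'],
#     years=[2014, 2015],
# )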


def make_simple_asset_info(assets, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    assets : array-like
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(assets)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    return pd.DataFrame(
        {
            'sid': assets,
            'symbol': symbols,
            'asset_type': ['equity'] * num_assets,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        }
    )


def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that also verifies that the
    inputs have the same type.

    See Also
    --------
    np.testing.assert_allclose
    """
    if type(actual) != type(desired):
        raise AssertionError("%s != %s" % (type(actual), type(desired)))
    # Forward all tolerance/reporting options; the original silently dropped
    # rtol, atol, and verbose.
    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )


def check_arrays(x, y, err_msg='', verbose=True):
    """
    Wrapper around np.testing.assert_array_equal that also verifies that the
    inputs have the same type.

    See Also
    --------
    np.testing.assert_array_equal
    """
    if type(x) != type(y):
        raise AssertionError("%s != %s" % (type(x), type(y)))
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)


class UnexpectedAttributeAccess(Exception):
    pass


class ExplodingObject(object):
    """
    Object that will raise an exception on any attribute access.

    Useful for verifying that an object is never touched during a
    function/method call.
    """
    def __getattribute__(self, name):
        raise UnexpectedAttributeAccess(name)
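
# Illustrative usage (not in the original module): pass an ExplodingObject
# where an argument must never be touched; any attribute access raises.
# `run_backtest` here is a hypothetical function under test.
# run_backtest(data, calendar=ExplodingObject())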


class DailyBarWriterFromDataFrames(BcolzDailyBarWriter):
    _csv_dtypes = {
        'open': float64,
        'high': float64,
        'low': float64,
        'close': float64,
        'volume': float64,
    }

    def __init__(self, asset_map):
        self._asset_map = asset_map

    def gen_tables(self, assets):
        for asset in assets:
            yield asset, ctable.fromdataframe(assets[asset])

    def to_uint32(self, array, colname):
        arrmax = array.max()
        if colname in OHLC:
            self.check_uint_safe(arrmax * 1000, colname)
            return (array * 1000).astype(uint32)
        elif colname == 'volume':
            self.check_uint_safe(arrmax, colname)
            return array.astype(uint32)
        elif colname == 'day':
            nanos_per_second = (1000 * 1000 * 1000)
            self.check_uint_safe(arrmax.view(int) / nanos_per_second, colname)
            return (array.view(int) / nanos_per_second).astype(uint32)

    @staticmethod
    def check_uint_safe(value, colname):
        if value >= UINT32_MAX:
            raise ValueError(
                "Value %s from column '%s' is too large" % (value, colname)
            )
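
# Worked example of the uint32 scaling above (illustrative): OHLC prices are
# stored as integer thousandths, so a close of 12.5 becomes 12500; any price
# at or above UINT32_MAX / 1000 (about 4.29 million) fails check_uint_safe.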


def write_minute_data(tempdir, minutes, sids, sid_path_func=None):
    assets = {}

    length = len(minutes)

    for sid_idx, sid in enumerate(sids):
        assets[sid] = pd.DataFrame({
            "open": (np.array(range(10, 10 + length)) + sid_idx) * 1000,
            "high": (np.array(range(15, 15 + length)) + sid_idx) * 1000,
            "low": (np.array(range(8, 8 + length)) + sid_idx) * 1000,
            "close": (np.array(range(10, 10 + length)) + sid_idx) * 1000,
            "volume": np.array(range(100, 100 + length)) + sid_idx,
            "minute": minutes
        }, index=minutes)

    MinuteBarWriterFromDataFrames().write(tempdir.path, assets,
                                          sid_path_func=sid_path_func)

    return tempdir.path


def write_daily_data(tempdir, sim_params, sids):
    path = os.path.join(tempdir.path, "testdaily.bcolz")
    assets = {}
    length = sim_params.days_in_period
    for sid_idx, sid in enumerate(sids):
        assets[sid] = pd.DataFrame({
            "open": (np.array(range(10, 10 + length)) + sid_idx),
            "high": (np.array(range(15, 15 + length)) + sid_idx),
            "low": (np.array(range(8, 8 + length)) + sid_idx),
            "close": (np.array(range(10, 10 + length)) + sid_idx),
            "volume": np.array(range(100, 100 + length)) + sid_idx,
            "day": [day.value for day in sim_params.trading_days]
        }, index=sim_params.trading_days)

    DailyBarWriterFromDataFrames(assets).write(
        path,
        sim_params.trading_days,
        assets
    )

    return path


def create_data_portal(env, tempdir, sim_params, sids, sid_path_func=None,
                       adjustment_reader=None):
    if sim_params.data_frequency == "daily":
        daily_path = write_daily_data(tempdir, sim_params, sids)

        return DataPortal(
            env,
            daily_equities_path=daily_path,
            sim_params=sim_params,
            adjustment_reader=adjustment_reader
        )
    else:
        minutes = env.minutes_for_days_in_range(
            sim_params.first_open,
            sim_params.last_close
        )

        minute_path = write_minute_data(tempdir, minutes, sids,
                                        sid_path_func)

        return DataPortal(
            env,
            minutes_equities_path=minute_path,
            sim_params=sim_params,
            adjustment_reader=adjustment_reader
        )
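
# Hedged sketch of typical usage (assumes a testfixtures-style TempDirectory
# providing `.path`, which is not part of this module):
# with TempDirectory() as tempdir:
#     portal = create_data_portal(env, tempdir, sim_params, sids=[1, 2])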


def create_data_portal_from_trade_history(env, tempdir, sim_params,
                                          trades_by_sid):
    if sim_params.data_frequency == "daily":
        path = os.path.join(tempdir.path, "testdaily.bcolz")
        assets = {}
        for sidint, trades in iteritems(trades_by_sid):
            opens = []
            highs = []
            lows = []
            closes = []
            volumes = []
            for trade in trades:
                opens.append(trade["open_price"])
                highs.append(trade["high"])
                lows.append(trade["low"])
                closes.append(trade["close_price"])
                volumes.append(trade["volume"])

            assets[sidint] = pd.DataFrame({
                "open": np.array(opens),
                "high": np.array(highs),
                "low": np.array(lows),
                "close": np.array(closes),
                "volume": np.array(volumes),
                "day": [day.value for day in sim_params.trading_days]
            }, index=sim_params.trading_days)

        DailyBarWriterFromDataFrames(assets).write(
            path,
            sim_params.trading_days,
            assets
        )

        return DataPortal(
            env,
            daily_equities_path=path,
            sim_params=sim_params,
        )
    else:
        minutes = env.minutes_for_days_in_range(
            sim_params.first_open,
            sim_params.last_close
        )

        length = len(minutes)
        assets = {}

        # Use six's iteritems for Python 2/3 compatibility, matching the
        # daily branch above (the original called dict.iteritems directly).
        for sidint, trades in iteritems(trades_by_sid):
            opens = np.zeros(length)
            highs = np.zeros(length)
            lows = np.zeros(length)
            closes = np.zeros(length)
            volumes = np.zeros(length)

            for trade in trades:
                # put them in the right place
                idx = minutes.searchsorted(trade.dt)

                opens[idx] = trade.open_price * 1000
                highs[idx] = trade.high * 1000
                lows[idx] = trade.low * 1000
                closes[idx] = trade.close_price * 1000
                volumes[idx] = trade.volume

            assets[sidint] = pd.DataFrame({
                "open": opens,
                "high": highs,
                "low": lows,
                "close": closes,
                "volume": volumes,
                "minute": minutes
            }, index=minutes)

        MinuteBarWriterFromDataFrames().write(tempdir.path, assets)

        return DataPortal(
            env,
            minutes_equities_path=tempdir.path,
            sim_params=sim_params
        )


class FakeDataPortal(object):

    def __init__(self):
        self._adjustment_reader = None


class tmp_assets_db(object):
    """Create a temporary assets sqlite database.

    This is meant to be used as a context manager.

    Parameters
    ----------
    **frames : pd.DataFrame, optional
        The frames of asset data to feed to the writer. By default this
        writes three equities, 'A', 'B', and 'C', with sids map(ord, 'ABC').
    """
    def __init__(self, **frames):
        self._eng = None
        if not frames:
            frames = {
                'equities': make_simple_equity_info(
                    list(map(ord, 'ABC')),
                    pd.Timestamp(0),
                    pd.Timestamp('2015'),
                )
            }
        self._data = AssetDBWriterFromDataFrame(**frames)

    def __enter__(self):
        self._eng = eng = create_engine('sqlite://')
        self._data.write_all(eng)
        return eng

    def __exit__(self, *excinfo):
        assert self._eng is not None, '_eng was not set in __enter__'
        self._eng.dispose()


class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder using an in-memory sqlite db.

    Parameters
    ----------
    **frames : pd.DataFrame, optional
        The frames of asset data to feed to the writer.
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        self._finder_cls = finder_cls
        super(tmp_asset_finder, self).__init__(**frames)

    def __enter__(self):
        return self._finder_cls(super(tmp_asset_finder, self).__enter__())
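
# Illustrative usage (not in the original module): the context manager yields
# an AssetFinder backed by a throwaway in-memory sqlite database; with the
# default frames, sid ord('A') == 65 exists.
# with tmp_asset_finder() as finder:
#     equity_a = finder.retrieve_asset(ord('A'))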


class SubTestFailures(AssertionError):
    def __init__(self, *failures):
        self.failures = failures

    def __str__(self):
        return 'failures:\n  %s' % '\n  '.join(
            '\n    '.join((
                ', '.join('%s=%r' % item for item in scope.items()),
                '%s: %s' % (type(exc).__name__, exc),
            )) for scope, exc in self.failures
        )


def subtest(iterator, *_names):
    """Construct a subtest in a unittest.

    This works by decorating a function as a subtest. The test will be run
    by iterating over the ``iterator`` and star-unpacking the values into the
    function. If any of the runs fail, the failure is recorded and the rest
    of the tests are still run. Finally, if any failed, all of the results
    are dumped as one failure.

    Parameters
    ----------
    iterator : iterable[iterable]
        The iterator of arguments to pass to the function.
    *_names : iterator[str]
        The names to use for each element of ``iterator``. These will be used
        to print the scope when a test fails. If not provided, it will use the
        integer index of the value as the name.

    Examples
    --------

    ::

       class MyTest(TestCase):
           def test_thing(self):
               # Example usage inside another test.
               @subtest(([n] for n in range(100000)), 'n')
               def subtest(n):
                   self.assertEqual(n % 2, 0, 'n was not even')
               subtest()

           @subtest(([n] for n in range(100000)), 'n')
           def test_decorated_function(self, n):
               # Example usage to parameterize an entire function.
               self.assertEqual(n % 2, 1, 'n was not odd')

    Notes
    -----
    We use this when we:

    * Will never want to run each parameter individually.
    * Have a large parameter space we are testing
      (see tests/utils/test_events.py).

    ``nose_parameterized.expand`` will create a test for each parameter
    combination which bloats the test output and makes the travis pages slow.

    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
    nose2 do not support ``addSubTest``.
    """
    def dec(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            names = _names
            failures = []
            for scope in iterator:
                scope = tuple(scope)
                try:
                    f(*args + scope, **kwargs)
                except Exception as e:
                    if not names:
                        names = count()
                    failures.append((dict(zip(names, scope)), e))
            if failures:
                raise SubTestFailures(*failures)

        return wrapped
    return dec


class MockDailyBarReader(object):
    def spot_price(self, col, sid, dt):
        return 100


def create_mock_adjustments(tempdir, days, splits=None, dividends=None,
                            mergers=None):
    path = tempdir.getpath("test_adjustments.db")

    # Build the adjustments db; empty frames with the correct dtypes stand in
    # for any adjustment type that is not supplied.
    writer = SQLiteAdjustmentWriter(path, days, MockDailyBarReader())
    if splits is None:
        splits = pd.DataFrame({
            # Hackery to make the dtypes correct on an empty frame.
            'effective_date': np.array([], dtype=int),
            'ratio': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }, index=pd.DatetimeIndex([], tz='UTC'))
    else:
        splits = pd.DataFrame(splits)

    if mergers is None:
        mergers = pd.DataFrame({
            # Hackery to make the dtypes correct on an empty frame.
            'effective_date': np.array([], dtype=int),
            'ratio': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }, index=pd.DatetimeIndex([], tz='UTC'))
    else:
        mergers = pd.DataFrame(mergers)

    if dividends is None:
        data = {
            # Hackery to make the dtypes correct on an empty frame.
            'ex_date': np.array([], dtype='datetime64[ns]'),
            'pay_date': np.array([], dtype='datetime64[ns]'),
            'record_date': np.array([], dtype='datetime64[ns]'),
            'declared_date': np.array([], dtype='datetime64[ns]'),
            'amount': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }
        dividends = pd.DataFrame(
            data,
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['ex_date',
                     'pay_date',
                     'record_date',
                     'declared_date',
                     'amount',
                     'sid']
        )
    elif not isinstance(dividends, pd.DataFrame):
        dividends = pd.DataFrame(dividends)

    writer.write(splits, mergers, dividends)

    return path
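
# Illustrative call (not in the original module): an adjustments db holding a
# single 2:1 split for sid 1 on the last day of `days`.
# adjustments_path = create_mock_adjustments(
#     tempdir, days,
#     splits=[{'effective_date': str_to_seconds(str(days[-1].date())),
#              'ratio': 0.5,
#              'sid': 1}],
# )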