Completed — Pull Request on master (#858), created by Eddie at 01:43

zipline.utils.MockDailyBarReader — rating: A

Complexity
    Total Complexity: 1 (wmc 1)

Size/Duplication
    Total Lines: 3 (loc 3)
    Duplicated Lines: 0 % (dl 0)
    rs: 10

1 Method
    Rating  Name          Duplication  Size  Complexity
    A       spot_price()  0            2     1

from contextlib import contextmanager
from functools import wraps
from itertools import (
    combinations,
    count,
    product,
)
import operator
import os
import shutil
from string import ascii_uppercase
import tempfile
from bcolz import ctable

from logbook import FileHandler
from mock import patch
from numpy.testing import assert_allclose, assert_array_equal
import pandas as pd
from pandas.tseries.offsets import MonthBegin
from six import iteritems, itervalues
from six.moves import filter
from sqlalchemy import create_engine
from toolz import concat

from zipline.assets import AssetFinder
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame
from zipline.assets.futures import CME_CODE_TO_MONTH
from zipline.data.data_portal import DataPortal
from zipline.data.us_equity_minutes import (
    MinuteBarWriterFromDataFrames,
    BcolzMinuteBarReader
)
from zipline.data.us_equity_pricing import SQLiteAdjustmentWriter, OHLC, \
    UINT32_MAX, BcolzDailyBarWriter, BcolzDailyBarReader
from zipline.finance.order import ORDER_STATUS
from zipline.utils import security_list

import numpy as np
from numpy import (
    float64,
    uint32
)


EPOCH = pd.Timestamp(0, tz='UTC')


def seconds_to_timestamp(seconds):
    return pd.Timestamp(seconds, unit='s', tz='UTC')


def to_utc(time_str):
    """Convert a string in US/Eastern time to UTC"""
    return pd.Timestamp(time_str, tz='US/Eastern').tz_convert('UTC')


def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since the
    UNIX epoch (UTC).

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    return int((pd.Timestamp(s, tz='UTC') - EPOCH).total_seconds())
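

# Illustrative round trips through the helpers above (not part of the
# original suite; the values follow from the UNIX epoch and US/Eastern
# winter time, UTC-5):
#
#     >>> to_utc('2014-01-01 09:31')
#     Timestamp('2014-01-01 14:31:00+0000', tz='UTC')
#     >>> seconds_to_timestamp(str_to_seconds('2014-01-01'))
#     Timestamp('2014-01-01 00:00:00+0000', tz='UTC')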


def setup_logger(test, path='test.log'):
    test.log_handler = FileHandler(path)
    test.log_handler.push_application()


def teardown_logger(test):
    test.log_handler.pop_application()
    test.log_handler.close()


def drain_zipline(test, zipline):
    output = []
    transaction_count = 0
    msg_counter = 0
    # start the simulation
    for update in zipline:
        msg_counter += 1
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += \
                len(update['daily_perf']['transactions'])

    return output, transaction_count


def check_algo_results(test,
                       results,
                       expected_transactions_count=None,
                       expected_order_count=None,
                       expected_positions_count=None,
                       sid=None):

    if expected_transactions_count is not None:
        txns = flatten_list(results["transactions"])
        test.assertEqual(expected_transactions_count, len(txns))

    if expected_positions_count is not None:
        raise NotImplementedError

    if expected_order_count is not None:
        # de-dup orders on id, because orders are put back into perf packets
        # whenever a txn is filled
        orders = {
            order['id'] for order in flatten_list(results["orders"])
        }

        test.assertEqual(expected_order_count, len(orders))
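

# A minimal sketch of how check_algo_results is typically driven from a
# TestCase; ``algo``, ``data_portal``, and the expected counts here are
# hypothetical, not values from this file:
#
#     results = algo.run(data_portal)
#     check_algo_results(
#         self,
#         results,
#         expected_transactions_count=4,
#         expected_order_count=4,
#     )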


def flatten_list(lst):
    return [item for sublist in lst for item in sublist]


def assert_single_position(test, zipline):

    output, transaction_count = drain_zipline(test, zipline)

    if 'expected_transactions' in test.zipline_test_config:
        test.assertEqual(
            test.zipline_test_config['expected_transactions'],
            transaction_count
        )
    else:
        test.assertEqual(
            test.zipline_test_config['order_count'],
            transaction_count
        )

    # The final message is the risk report; the second-to-last is the
    # final day's results. Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Confirm that all orders were filled: iterate over the output updates,
    # overwriting orders as they are updated, then check the status on all.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' in update:
            if 'orders' in update['daily_perf']:
                for order in update['daily_perf']['orders']:
                    orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "")

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = test.zipline_test_config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count


class ExceptionSource(object):

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def next(self):
        # Intentionally raise ZeroDivisionError (Python 2 iterator protocol).
        5 / 0

    def __next__(self):
        # Intentionally raise ZeroDivisionError (Python 3 iterator protocol).
        5 / 0


@contextmanager
def security_list_copy():
    old_dir = security_list.SECURITY_LISTS_DIR
    new_dir = tempfile.mkdtemp()
    try:
        for subdir in os.listdir(old_dir):
            shutil.copytree(os.path.join(old_dir, subdir),
                            os.path.join(new_dir, subdir))
        # Patch and yield once, after *all* subdirs are copied; yielding
        # inside the loop would make this generator yield more than once
        # when several subdirs exist, which @contextmanager forbids.
        with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
                patch.object(security_list, 'using_copy', True,
                             create=True):
            yield
    finally:
        shutil.rmtree(new_dir, True)


def add_security_data(adds, deletes):
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')
    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)
    del_path = os.path.join(directory, "delete")
    with open(del_path, 'w') as f:
        for sym in deletes:
            f.write(sym)
            f.write('\n')
    add_path = os.path.join(directory, "add")
    with open(add_path, 'w') as f:
        for sym in adds:
            f.write(sym)
            f.write('\n')
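

# Sketch of the intended pairing of security_list_copy with
# add_security_data; the symbols here are hypothetical:
#
#     with security_list_copy():
#         add_security_data(adds=['AAPL'], deletes=['BZQ'])
#         ...  # exercise code that consults the patched SECURITY_LISTS_DIR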


def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs, (v0, v1) from values such that

    `pred(v0, v1) == True`

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : generator
       Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    return filter(lambda pair: pred(*pair), product(values, repeat=2))


def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    If `include_diagonal` is True, returns all pairs such that v0 <= v1.
    If `include_diagonal` is False, returns all pairs such that v0 < v1.
    """
    return all_pairs_matching_predicate(
        values,
        operator.le if include_diagonal else operator.lt,
    )
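

# For example:
#
#     >>> list(product_upper_triangle([1, 2, 3]))
#     [(1, 2), (1, 3), (2, 3)]
#     >>> list(product_upper_triangle([1, 2], include_diagonal=True))
#     [(1, 1), (1, 2), (2, 2)]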


def all_subindices(index):
    """
    Return all valid sub-indices of a pandas Index.
    """
    return (
        index[start:stop]
        for start, stop in product_upper_triangle(range(len(index) + 1))
    )


def make_rotating_equity_info(num_assets,
                              first_start,
                              frequency,
                              periods_between_starts,
                              asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.Offset (e.g. trading_day)
        Frequency used to interpret the next two arguments.
    periods_between_starts : int
        Create a new asset every `frequency` * `periods_between_starts`.
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime` periods.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    return pd.DataFrame(
        {
            'symbol': [chr(ord('A') + i) for i in range(num_assets)],
            # Start a new asset every `periods_between_starts` periods.
            'start_date': pd.date_range(
                first_start,
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            # Each asset lasts for `asset_lifetime` periods.
            'end_date': pd.date_range(
                first_start + (asset_lifetime * frequency),
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            'exchange': 'TEST',
        },
        index=range(num_assets),
    )
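

# Sketch (illustrative values only): with a daily frequency, this would
# create five assets, 'A' through 'E', each starting 3 days after the
# previous one and living for 10 days:
#
#     frame = make_rotating_equity_info(
#         num_assets=5,
#         first_start=pd.Timestamp('2014-01-01', tz='UTC'),
#         frequency=pd.tseries.offsets.Day(),
#         periods_between_starts=3,
#         asset_lifetime=10,
#     )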


def make_simple_equity_info(sids, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    sids : array-like of int
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(sids)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    return pd.DataFrame(
        {
            'symbol': symbols,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        },
        index=sids,
    )


def make_future_info(first_sid,
                     root_symbols,
                     years,
                     notice_date_func,
                     expiration_date_func,
                     start_date_func,
                     month_codes=None):
    """
    Create a DataFrame representing futures for `root_symbols` during `years`.

    Generates a contract per triple of (symbol, year, month) supplied to
    `root_symbols`, `years`, and `month_codes`.

    Parameters
    ----------
    first_sid : int
        The first sid to use for assigning sids to the created contracts.
    root_symbols : list[str]
        A list of root symbols for which to create futures.
    years : list[int or str]
        Years (e.g. 2014), for which to produce individual contracts.
    notice_date_func : (Timestamp) -> Timestamp
        Function to generate notice dates from the first of the month
        associated with each asset month code.  Return NaT to simulate
        futures with no notice date.
    expiration_date_func : (Timestamp) -> Timestamp
        Function to generate expiration dates from the first of the month
        associated with each asset month code.
    start_date_func : (Timestamp) -> Timestamp
        Function to generate start dates from the first of the month
        associated with each asset month code.
    month_codes : dict[str -> [1..12]], optional
        Dictionary of month codes for which to create contracts.  Entries
        should be strings mapped to values from 1 (January) to 12 (December).
        Default is zipline.assets.futures.CME_CODE_TO_MONTH.

    Returns
    -------
    futures_info : pd.DataFrame
        DataFrame of futures data suitable for passing to an
        AssetDBWriterFromDataFrame.
    """
    if month_codes is None:
        month_codes = CME_CODE_TO_MONTH

    year_strs = list(map(str, years))
    years = [pd.Timestamp(s, tz='UTC') for s in year_strs]

    # Pairs of string/date like ('K06', 2006-05-01)
    contract_suffix_to_beginning_of_month = tuple(
        (month_code + year_str[-2:], year + MonthBegin(month_num))
        for ((year, year_str), (month_code, month_num))
        in product(
            zip(years, year_strs),
            iteritems(month_codes),
        )
    )

    contracts = []
    parts = product(root_symbols, contract_suffix_to_beginning_of_month)
    for sid, (root_sym, (suffix, month_begin)) in enumerate(parts, first_sid):
        contracts.append({
            'sid': sid,
            'root_symbol': root_sym,
            'symbol': root_sym + suffix,
            'start_date': start_date_func(month_begin),
            'notice_date': notice_date_func(month_begin),
            'expiration_date': expiration_date_func(month_begin),
            'contract_multiplier': 500,
        })
    return pd.DataFrame.from_records(contracts, index='sid').convert_objects()


def make_commodity_future_info(first_sid,
                               root_symbols,
                               years,
                               month_codes=None):
    """
    Make futures testing data that simulates the notice/expiration date
    behavior of physical commodities like oil.

    Parameters
    ----------
    first_sid : int
    root_symbols : list[str]
    years : list[int]
    month_codes : dict[str -> int]

    Expiration dates are on the 20th of the month prior to the month code.
    Notice dates are on the 20th two months prior to the month code.
    Start dates are one year before the contract month.

    See Also
    --------
    make_future_info
    """
    nineteen_days = pd.Timedelta(days=19)
    one_year = pd.Timedelta(days=365)
    return make_future_info(
        first_sid=first_sid,
        root_symbols=root_symbols,
        years=years,
        notice_date_func=lambda dt: dt - MonthBegin(2) + nineteen_days,
        expiration_date_func=lambda dt: dt - MonthBegin(1) + nineteen_days,
        start_date_func=lambda dt: dt - one_year,
        month_codes=month_codes,
    )
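

# Sketch (hypothetical inputs): two root symbols over two years with the
# default CME month codes yield 2 * 2 * 12 = 48 contracts:
#
#     futures = make_commodity_future_info(
#         first_sid=1,
#         root_symbols=['CL', 'NG'],
#         years=[2014, 2015],
#     )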


def make_simple_asset_info(assets, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    assets : array-like
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(assets)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    return pd.DataFrame(
        {
            'sid': assets,
            'symbol': symbols,
            'asset_type': ['equity'] * num_assets,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        }
    )


def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that also verifies that the
    inputs are of the same type.

    See Also
    --------
    np.testing.assert_allclose
    """
    if type(actual) != type(desired):
        raise AssertionError("%s != %s" % (type(actual), type(desired)))
    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )


def check_arrays(x, y, err_msg='', verbose=True):
    """
    Wrapper around np.testing.assert_array_equal that also verifies that the
    inputs are of the same type.

    See Also
    --------
    np.testing.assert_array_equal
    """
    if type(x) != type(y):
        raise AssertionError("%s != %s" % (type(x), type(y)))
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)
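

# Unlike the raw numpy assertions, these wrappers reject inputs of mixed
# type (illustrative):
#
#     check_arrays(np.array([1, 2]), [1, 2])            # AssertionError
#     check_arrays(np.array([1, 2]), np.array([1, 2]))  # passes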
535
536
537
class UnexpectedAttributeAccess(Exception):
538
    pass
539
540
541
class ExplodingObject(object):
542
    """
543
    Object that will raise an exception on any attribute access.
544
545
    Useful for verifying that an object is never touched during a
546
    function/method call.
547
    """
548
    def __getattribute__(self, name):
549
        raise UnexpectedAttributeAccess(name)
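

# Sketch of the intended use: pass an ExplodingObject where an argument
# should never be touched, so the test fails loudly if it is; ``process``
# is a hypothetical function under test:
#
#     process(data, unused_arg=ExplodingObject())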


class DailyBarWriterFromDataFrames(BcolzDailyBarWriter):
    _csv_dtypes = {
        'open': float64,
        'high': float64,
        'low': float64,
        'close': float64,
        'volume': float64,
    }

    def __init__(self, asset_map):
        self._asset_map = asset_map

    def gen_tables(self, assets):
        for asset in assets:
            yield asset, ctable.fromdataframe(assets[asset])

    def to_uint32(self, array, colname):
        arrmax = array.max()
        if colname in OHLC:
            self.check_uint_safe(arrmax * 1000, colname)
            return (array * 1000).astype(uint32)
        elif colname == 'volume':
            self.check_uint_safe(arrmax, colname)
            return array.astype(uint32)
        elif colname == 'day':
            nanos_per_second = (1000 * 1000 * 1000)
            self.check_uint_safe(arrmax.view(int) / nanos_per_second, colname)
            return (array.view(int) / nanos_per_second).astype(uint32)

    @staticmethod
    def check_uint_safe(value, colname):
        if value >= UINT32_MAX:
            raise ValueError(
                "Value %s from column '%s' is too large" % (value, colname)
            )


def write_minute_data(tempdir, minutes, sids, sid_path_func=None):
    assets = {}

    length = len(minutes)

    for sid_idx, sid in enumerate(sids):
        assets[sid] = pd.DataFrame({
            "open": (np.array(range(10, 10 + length)) + sid_idx) * 1000,
            "high": (np.array(range(15, 15 + length)) + sid_idx) * 1000,
            "low": (np.array(range(8, 8 + length)) + sid_idx) * 1000,
            "close": (np.array(range(10, 10 + length)) + sid_idx) * 1000,
            "volume": np.array(range(100, 100 + length)) + sid_idx,
            "minute": minutes
        }, index=minutes)

    MinuteBarWriterFromDataFrames(pd.Timestamp('2002-01-02', tz='UTC')).write(
        tempdir.path, assets, sid_path_func=sid_path_func)

    return tempdir.path


def write_daily_data(tempdir, sim_params, sids):
    path = os.path.join(tempdir.path, "testdaily.bcolz")
    assets = {}
    length = sim_params.days_in_period
    for sid_idx, sid in enumerate(sids):
        assets[sid] = pd.DataFrame({
            "open": (np.array(range(10, 10 + length)) + sid_idx),
            "high": (np.array(range(15, 15 + length)) + sid_idx),
            "low": (np.array(range(8, 8 + length)) + sid_idx),
            "close": (np.array(range(10, 10 + length)) + sid_idx),
            "volume": np.array(range(100, 100 + length)) + sid_idx,
            "day": [day.value for day in sim_params.trading_days]
        }, index=sim_params.trading_days)

    DailyBarWriterFromDataFrames(assets).write(
        path,
        sim_params.trading_days,
        assets
    )

    return path


def create_data_portal(env, tempdir, sim_params, sids, sid_path_func=None,
                       adjustment_reader=None):
    if sim_params.data_frequency == "daily":
        daily_path = write_daily_data(tempdir, sim_params, sids)

        equity_daily_reader = BcolzDailyBarReader(daily_path)

        return DataPortal(
            env,
            equity_daily_reader=equity_daily_reader,
            adjustment_reader=adjustment_reader
        )
    else:
        minutes = env.minutes_for_days_in_range(
            sim_params.first_open,
            sim_params.last_close
        )

        minute_path = write_minute_data(tempdir, minutes, sids,
                                        sid_path_func)

        equity_minute_reader = BcolzMinuteBarReader(minute_path)

        return DataPortal(
            env,
            equity_minute_reader=equity_minute_reader,
            adjustment_reader=adjustment_reader
        )


def create_data_portal_from_trade_history(env, tempdir, sim_params,
                                          trades_by_sid):
    if sim_params.data_frequency == "daily":
        path = os.path.join(tempdir.path, "testdaily.bcolz")
        assets = {}
        for sidint, trades in iteritems(trades_by_sid):
            opens = []
            highs = []
            lows = []
            closes = []
            volumes = []
            for trade in trades:
                opens.append(trade["open_price"])
                highs.append(trade["high"])
                lows.append(trade["low"])
                closes.append(trade["close_price"])
                volumes.append(trade["volume"])

            assets[sidint] = pd.DataFrame({
                "open": np.array(opens),
                "high": np.array(highs),
                "low": np.array(lows),
                "close": np.array(closes),
                "volume": np.array(volumes),
                "day": [day.value for day in sim_params.trading_days]
            }, index=sim_params.trading_days)

        DailyBarWriterFromDataFrames(assets).write(
            path,
            sim_params.trading_days,
            assets
        )

        equity_daily_reader = BcolzDailyBarReader(path)

        return DataPortal(
            env,
            equity_daily_reader=equity_daily_reader,
        )
    else:
        minutes = env.minutes_for_days_in_range(
            sim_params.first_open,
            sim_params.last_close
        )

        length = len(minutes)
        assets = {}

        # Use six.iteritems for Python 2/3 compatibility, as in the daily
        # branch above.
        for sidint, trades in iteritems(trades_by_sid):
            opens = np.zeros(length)
            highs = np.zeros(length)
            lows = np.zeros(length)
            closes = np.zeros(length)
            volumes = np.zeros(length)

            for trade in trades:
                # put each trade in the right minute slot
                idx = minutes.searchsorted(trade.dt)

                opens[idx] = trade.open_price * 1000
                highs[idx] = trade.high * 1000
                lows[idx] = trade.low * 1000
                closes[idx] = trade.close_price * 1000
                volumes[idx] = trade.volume

            assets[sidint] = pd.DataFrame({
                "open": opens,
                "high": highs,
                "low": lows,
                "close": closes,
                "volume": volumes,
                "minute": minutes
            }, index=minutes)

        MinuteBarWriterFromDataFrames(pd.Timestamp('2002-01-02', tz='UTC')).\
            write(tempdir.path, assets)

        equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

        return DataPortal(
            env,
            equity_minute_reader=equity_minute_reader,
            sim_params=sim_params
        )


class FakeDataPortal(object):

    def __init__(self):
        self._adjustment_reader = None

    def setup_offset_cache(self, minutes_by_day, minutes_to_day, trading_days):
        pass


class FetcherDataPortal(DataPortal):
    """
    Mock DataPortal that returns fake data for history and non-fetcher
    spot value.
    """
    def __init__(self, env, sim_params):
        super(FetcherDataPortal, self).__init__(
            env,
            sim_params
        )

    def get_spot_value(self, asset, field, dt, data_frequency):
        # if this is a fetcher field, exercise the regular code path
        if self._check_extra_sources(asset, field, (dt or self.current_dt)):
            return super(FetcherDataPortal, self).get_spot_value(
                asset, field, dt, data_frequency)

        # otherwise just return a fixed value
        return int(asset)

    def setup_offset_cache(self, minutes_by_day, minutes_to_day, trading_days):
        pass

    def _get_daily_window_for_sid(self, asset, field, days_in_window,
                                  extra_slot=True):
        return np.arange(days_in_window, dtype=np.float64)

    def _get_minute_window_for_asset(self, asset, field, minutes_for_window):
        return np.arange(minutes_for_window, dtype=np.float64)


class tmp_assets_db(object):
    """Create a temporary assets sqlite database.
    This is meant to be used as a context manager.

    Parameters
    ----------
    **frames : pd.DataFrame, optional
        The frames to feed to the writer. By default this maps:
        ('A', 'B', 'C') -> map(ord, 'ABC')
    """
    def __init__(self, **frames):
        self._eng = None
        if not frames:
            frames = {
                'equities': make_simple_equity_info(
                    list(map(ord, 'ABC')),
                    pd.Timestamp(0),
                    pd.Timestamp('2015'),
                )
            }
        self._data = AssetDBWriterFromDataFrame(**frames)

    def __enter__(self):
        self._eng = eng = create_engine('sqlite://')
        self._data.write_all(eng)
        return eng

    def __exit__(self, *excinfo):
        assert self._eng is not None, '_eng was not set in __enter__'
        self._eng.dispose()


class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder using an in-memory sqlite db.

    Parameters
    ----------
    **frames : pd.DataFrame, optional
        The frames to feed to the writer.
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        self._finder_cls = finder_cls
        super(tmp_asset_finder, self).__init__(**frames)

    def __enter__(self):
        return self._finder_cls(super(tmp_asset_finder, self).__enter__())
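

# Sketch of typical usage as a context manager (the default frames map
# sids ord('A')..ord('C'), so this lookup is expected to succeed):
#
#     with tmp_asset_finder() as finder:
#         asset = finder.retrieve_asset(ord('A'))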


class SubTestFailures(AssertionError):
    def __init__(self, *failures):
        self.failures = failures

    def __str__(self):
        return 'failures:\n  %s' % '\n  '.join(
            '\n    '.join((
                ', '.join('%s=%r' % item for item in scope.items()),
                '%s: %s' % (type(exc).__name__, exc),
            )) for scope, exc in self.failures
        )


def subtest(iterator, *_names):
    """Construct a subtest in a unittest.

    This works by decorating a function as a subtest. The test will be run
    by iterating over the ``iterator`` and star-unpacking the values into the
    function. If any of the runs fail, the result will be collected in a list
    and the rest of the tests will be run. Finally, if any failed, all of the
    results will be dumped as one failure.

    Parameters
    ----------
    iterator : iterable[iterable]
        The iterator of arguments to pass to the function.
    *_names : iterable[str]
        The names to use for each element of ``iterator``. These will be used
        to print the scope when a test fails. If not provided, it will use the
        integer index of the value as the name.

    Examples
    --------

    ::

       class MyTest(TestCase):
           def test_thing(self):
               # Example usage inside another test.
               @subtest(([n] for n in range(100000)), 'n')
               def subtest(n):
                   self.assertEqual(n % 2, 0, 'n was not even')
               subtest()

           @subtest(([n] for n in range(100000)), 'n')
           def test_decorated_function(self, n):
               # Example usage to parameterize an entire function.
               self.assertEqual(n % 2, 1, 'n was not odd')

    Notes
    -----
    We use this when we:

    * Will never want to run each parameter individually.
    * Have a large parameter space we are testing
      (see tests/utils/test_events.py).

    ``nose_parameterized.expand`` will create a test for each parameter
    combination, which bloats the test output and makes the travis pages slow.

    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
    nose2 do not support ``addSubTest``.
    """
    def dec(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            failures = []
            for scope in iterator:
                scope = tuple(scope)
                try:
                    f(*args + scope, **kwargs)
                except Exception as e:
                    # Fall back to integer indices when no names were given;
                    # use a fresh range per scope so indices restart at 0 for
                    # every failure.
                    names = _names if _names else range(len(scope))
                    failures.append((dict(zip(names, scope)), e))
            if failures:
                raise SubTestFailures(*failures)

        return wrapped
    return dec


class MockDailyBarReader(object):
    def spot_price(self, col, sid, dt):
        return 100


def create_mock_adjustments(tempdir, days, splits=None, dividends=None,
                            mergers=None):
    path = tempdir.getpath("test_adjustments.db")

    # Write the (possibly empty) splits/mergers/dividends frames to a
    # SQLite adjustments db.
    writer = SQLiteAdjustmentWriter(path, days, MockDailyBarReader())
    if splits is None:
        splits = pd.DataFrame({
            # Hackery to make the dtypes correct on an empty frame.
            'effective_date': np.array([], dtype=int),
            'ratio': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }, index=pd.DatetimeIndex([], tz='UTC'))
    else:
        splits = pd.DataFrame(splits)

    if mergers is None:
        mergers = pd.DataFrame({
            # Hackery to make the dtypes correct on an empty frame.
            'effective_date': np.array([], dtype=int),
            'ratio': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }, index=pd.DatetimeIndex([], tz='UTC'))
    else:
        mergers = pd.DataFrame(mergers)

    if dividends is None:
        data = {
            # Hackery to make the dtypes correct on an empty frame.
            'ex_date': np.array([], dtype='datetime64[ns]'),
            'pay_date': np.array([], dtype='datetime64[ns]'),
            'record_date': np.array([], dtype='datetime64[ns]'),
            'declared_date': np.array([], dtype='datetime64[ns]'),
            'amount': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }
        dividends = pd.DataFrame(
            data,
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['ex_date',
                     'pay_date',
                     'record_date',
                     'declared_date',
                     'amount',
                     'sid']
        )
    else:
        if not isinstance(dividends, pd.DataFrame):
            dividends = pd.DataFrame(dividends)

    writer.write(splits, mergers, dividends)

    return path


def assert_timestamp_equal(left, right, compare_nat_equal=True, msg=""):
    """
    Assert that two pandas Timestamp objects are the same.

    Parameters
    ----------
    left, right : pd.Timestamp
        The values to compare.
    compare_nat_equal : bool, optional
        Whether to consider `NaT` values equal.  Defaults to True.
    msg : str, optional
        A message to forward to `pd.util.testing.assert_equal`.
    """
    if compare_nat_equal and left is pd.NaT and right is pd.NaT:
        return
    return pd.util.testing.assert_equal(left, right, msg=msg)


def powerset(values):
    """
    Return the power set (i.e., the set of all subsets) of entries in
    `values`.
    """
    return concat(combinations(values, i) for i in range(len(values) + 1))
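

# For example:
#
#     >>> list(powerset('ab'))
#     [(), ('a',), ('b',), ('a', 'b')]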