from contextlib import contextmanager
from functools import wraps
from itertools import (
    count,
    product,
)
import operator
import os
import shutil
from string import ascii_uppercase
import tempfile
from bcolz import ctable

from logbook import FileHandler
from mock import patch
from numpy.testing import assert_allclose, assert_array_equal
import pandas as pd
from pandas.tseries.offsets import MonthBegin
from six import iteritems, itervalues
from six.moves import filter
from sqlalchemy import create_engine

from zipline.assets import AssetFinder
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame
from zipline.assets.futures import CME_CODE_TO_MONTH
from zipline.data.data_portal import DataPortal
from zipline.data.us_equity_minutes import (
    MinuteBarWriterFromDataFrames,
    BcolzMinuteBarReader
)
from zipline.data.us_equity_pricing import SQLiteAdjustmentWriter, OHLC, \
    UINT32_MAX, BcolzDailyBarWriter
from zipline.finance.order import ORDER_STATUS
from zipline.utils import security_list

import numpy as np
from numpy import (
    float64,
    uint32
)


EPOCH = pd.Timestamp(0, tz='UTC')


def seconds_to_timestamp(seconds):
    return pd.Timestamp(seconds, unit='s', tz='UTC')


def to_utc(time_str):
    """Convert a string in US/Eastern time to UTC"""
    return pd.Timestamp(time_str, tz='US/Eastern').tz_convert('UTC')


def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since the
    UTC epoch.

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    return int((pd.Timestamp(s, tz='UTC') - EPOCH).total_seconds())


def setup_logger(test, path='test.log'):
    test.log_handler = FileHandler(path)
    test.log_handler.push_application()


def teardown_logger(test):
    test.log_handler.pop_application()
    test.log_handler.close()


def drain_zipline(test, zipline):
    output = []
    transaction_count = 0
    msg_counter = 0
    # start the simulation
    for update in zipline:
        msg_counter += 1
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += \
                len(update['daily_perf']['transactions'])

    return output, transaction_count


def check_algo_results(test,
                       results,
                       expected_transactions_count=None,
                       expected_order_count=None,
                       expected_positions_count=None,
                       sid=None):

    if expected_transactions_count is not None:
        txns = flatten_list(results["transactions"])
        test.assertEqual(expected_transactions_count, len(txns))

    if expected_positions_count is not None:
        raise NotImplementedError

    if expected_order_count is not None:
        # de-dup orders on id, because orders are put back into perf packets
        # whenever a txn is filled
        orders = {order['id'] for order in
                  flatten_list(results["orders"])}

        test.assertEqual(expected_order_count, len(orders))


def flatten_list(xs):
    return [item for sublist in xs for item in sublist]


def assert_single_position(test, zipline):

    output, transaction_count = drain_zipline(test, zipline)

    if 'expected_transactions' in test.zipline_test_config:
        test.assertEqual(
            test.zipline_test_config['expected_transactions'],
            transaction_count
        )
    else:
        test.assertEqual(
            test.zipline_test_config['order_count'],
            transaction_count
        )

    # the final message is the risk report, the second to
    # last is the final day's results. Positions is a list of
    # dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # confirm that all orders were filled.
    # iterate over the output updates, overwriting
    # orders when they are updated. Then check the status on all.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' in update:
            if 'orders' in update['daily_perf']:
                for order in update['daily_perf']['orders']:
                    orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "")

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = test.zipline_test_config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count


class ExceptionSource(object):

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def next(self):
        5 / 0

    def __next__(self):
        5 / 0


@contextmanager
def security_list_copy():
    old_dir = security_list.SECURITY_LISTS_DIR
    new_dir = tempfile.mkdtemp()
    try:
        for subdir in os.listdir(old_dir):
            shutil.copytree(os.path.join(old_dir, subdir),
                            os.path.join(new_dir, subdir))
        # Yield exactly once, outside the copy loop: a @contextmanager
        # generator that yields more than once raises RuntimeError on exit.
        with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
                patch.object(security_list, 'using_copy', True,
                             create=True):
            yield
    finally:
        shutil.rmtree(new_dir, True)


def add_security_data(adds, deletes):
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')
    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)
    del_path = os.path.join(directory, "delete")
    with open(del_path, 'w') as f:
        for sym in deletes:
            f.write(sym)
            f.write('\n')
    add_path = os.path.join(directory, "add")
    with open(add_path, 'w') as f:
        for sym in adds:
            f.write(sym)
            f.write('\n')
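
# A minimal usage sketch (hypothetical symbols; requires the on-disk
# security list fixtures that `security_list_copy` duplicates):
#
#     with security_list_copy():
#         add_security_data(adds=['AAPL'], deletes=['GOOG'])
#         # code under test now reads the patched SECURITY_LISTS_DIR,
#         # including the add/delete entries written above.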


def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs, (v0, v1) from values such that

    `pred(v0, v1) == True`

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : generator
       Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    return filter(lambda pair: pred(*pair), product(values, repeat=2))


def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    If `include_diagonal` is True, returns all pairs such that v0 <= v1.
    If `include_diagonal` is False, returns all pairs such that v0 < v1.
    """
    return all_pairs_matching_predicate(
        values,
        operator.le if include_diagonal else operator.lt,
    )
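
# For example (a sketch; `values` must be orderable by the chosen operator):
#
#     >>> list(product_upper_triangle([1, 2, 3]))
#     [(1, 2), (1, 3), (2, 3)]
#     >>> list(product_upper_triangle([1, 2, 3], include_diagonal=True))
#     [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]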


def all_subindices(index):
    """
    Return all valid sub-indices of a pandas Index.
    """
    return (
        index[start:stop]
        for start, stop in product_upper_triangle(range(len(index) + 1))
    )
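
# E.g., a three-element index yields its six non-empty slices (a sketch):
#
#     >>> idx = pd.Index(['a', 'b', 'c'])
#     >>> [list(sub) for sub in all_subindices(idx)]
#     [['a'], ['a', 'b'], ['a', 'b', 'c'], ['b'], ['b', 'c'], ['c']]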


def make_rotating_equity_info(num_assets,
                              first_start,
                              frequency,
                              periods_between_starts,
                              asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.Offset (e.g. trading_day)
        Frequency used to interpret the next two arguments.
    periods_between_starts : int
        Create a new asset every `frequency` * `periods_between_starts`.
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime` periods.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    return pd.DataFrame(
        {
            'symbol': [chr(ord('A') + i) for i in range(num_assets)],
            # Start a new asset every `periods_between_starts` periods.
            'start_date': pd.date_range(
                first_start,
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            # Each asset lasts for `asset_lifetime` periods.
            'end_date': pd.date_range(
                first_start + (asset_lifetime * frequency),
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            'exchange': 'TEST',
        },
        index=range(num_assets),
    )
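
# A minimal sketch (hypothetical arguments; any pandas offset works as
# `frequency`, e.g. a business day):
#
#     frame = make_rotating_equity_info(
#         num_assets=3,
#         first_start=pd.Timestamp('2014-01-01', tz='UTC'),
#         frequency=pd.tseries.offsets.BDay(),
#         periods_between_starts=2,
#         asset_lifetime=5,
#     )
#     # start_dates: 2014-01-01, 2014-01-03, 2014-01-07 (every 2 business
#     # days); each end_date is 5 business days after its start_date.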


def make_simple_equity_info(sids, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    sids : array-like of int
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(sids)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    return pd.DataFrame(
        {
            'symbol': symbols,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        },
        index=sids,
    )


def make_future_info(first_sid,
                     root_symbols,
                     years,
                     notice_date_func,
                     expiration_date_func,
                     start_date_func,
                     month_codes=None):
    """
    Create a DataFrame representing futures for `root_symbols` during `years`.

    Generates a contract per triple of (symbol, year, month) supplied to
    `root_symbols`, `years`, and `month_codes`.

    Parameters
    ----------
    first_sid : int
        The first sid to use for assigning sids to the created contracts.
    root_symbols : list[str]
        A list of root symbols for which to create futures.
    years : list[int or str]
        Years (e.g. 2014), for which to produce individual contracts.
    notice_date_func : (Timestamp) -> Timestamp
        Function to generate notice dates from first of the month associated
        with asset month code.  Return NaT to simulate futures with no notice
        date.
    expiration_date_func : (Timestamp) -> Timestamp
        Function to generate expiration dates from first of the month
        associated with asset month code.
    start_date_func : (Timestamp) -> Timestamp, optional
        Function to generate start dates from first of the month associated
        with each asset month code.  Defaults to a start_date one year prior
        to the month_code date.
    month_codes : dict[str -> [1..12]], optional
        Dictionary of month codes for which to create contracts.  Entries
        should be strings mapped to values from 1 (January) to 12 (December).
        Default is zipline.assets.futures.CME_CODE_TO_MONTH.

    Returns
    -------
    futures_info : pd.DataFrame
        DataFrame of futures data suitable for passing to an
        AssetDBWriterFromDataFrame.
    """
    if month_codes is None:
        month_codes = CME_CODE_TO_MONTH

    year_strs = list(map(str, years))
    years = [pd.Timestamp(s, tz='UTC') for s in year_strs]

    # Pairs of string/date like ('K06', 2006-05-01)
    contract_suffix_to_beginning_of_month = tuple(
        (month_code + year_str[-2:], year + MonthBegin(month_num))
        for ((year, year_str), (month_code, month_num))
        in product(
            zip(years, year_strs),
            iteritems(month_codes),
        )
    )

    contracts = []
    parts = product(root_symbols, contract_suffix_to_beginning_of_month)
    for sid, (root_sym, (suffix, month_begin)) in enumerate(parts, first_sid):
        contracts.append({
            'sid': sid,
            'root_symbol': root_sym,
            'symbol': root_sym + suffix,
            'start_date': start_date_func(month_begin),
            'notice_date': notice_date_func(month_begin),
            'expiration_date': expiration_date_func(month_begin),
            'contract_multiplier': 500,
        })
    return pd.DataFrame.from_records(contracts, index='sid').convert_objects()


def make_commodity_future_info(first_sid,
                               root_symbols,
                               years,
                               month_codes=None):
    """
    Make futures testing data that simulates the notice/expiration date
    behavior of physical commodities like oil.

    Parameters
    ----------
    first_sid : int
    root_symbols : list[str]
    years : list[int]
    month_codes : dict[str -> int]

    Expiration dates are on the 20th of the month prior to the month code.
    Notice dates are on the 20th two months prior to the month code.
    Start dates are one year before the contract month.

    See Also
    --------
    make_future_info
    """
    nineteen_days = pd.Timedelta(days=19)
    one_year = pd.Timedelta(days=365)
    return make_future_info(
        first_sid=first_sid,
        root_symbols=root_symbols,
        years=years,
        notice_date_func=lambda dt: dt - MonthBegin(2) + nineteen_days,
        expiration_date_func=lambda dt: dt - MonthBegin(1) + nineteen_days,
        start_date_func=lambda dt: dt - one_year,
        month_codes=month_codes,
    )
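
# A minimal sketch: two oil-style root symbols over two years (the root
# symbols are hypothetical; month codes come from the default CME mapping):
#
#     futures = make_commodity_future_info(
#         first_sid=1000,
#         root_symbols=['CL', 'HO'],
#         years=[2014, 2015],
#     )
#     # e.g. the May 2014 'CL' contract gets symbol 'CLK14', notice date
#     # 2014-03-20, expiration date 2014-04-20, and start date 2013-05-01.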


def make_simple_asset_info(assets, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    assets : array-like
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(assets)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    return pd.DataFrame(
        {
            'sid': assets,
            'symbol': symbols,
            'asset_type': ['equity'] * num_assets,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        }
    )


def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that also verifies that inputs
    are ndarrays.

    See Also
    --------
    np.testing.assert_allclose
    """
    if type(actual) != type(desired):
        raise AssertionError("%s != %s" % (type(actual), type(desired)))
    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )


def check_arrays(x, y, err_msg='', verbose=True):
    """
    Wrapper around np.testing.assert_array_equal that also verifies that
    inputs are ndarrays.

    See Also
    --------
    np.testing.assert_array_equal
    """
    if type(x) != type(y):
        raise AssertionError("%s != %s" % (type(x), type(y)))
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)


class UnexpectedAttributeAccess(Exception):
    pass


class ExplodingObject(object):
    """
    Object that will raise an exception on any attribute access.

    Useful for verifying that an object is never touched during a
    function/method call.
    """
    def __getattribute__(self, name):
        raise UnexpectedAttributeAccess(name)


class DailyBarWriterFromDataFrames(BcolzDailyBarWriter):
    _csv_dtypes = {
        'open': float64,
        'high': float64,
        'low': float64,
        'close': float64,
        'volume': float64,
    }

    def __init__(self, asset_map):
        self._asset_map = asset_map

    def gen_tables(self, assets):
        for asset in assets:
            yield asset, ctable.fromdataframe(assets[asset])

    def to_uint32(self, array, colname):
        arrmax = array.max()
        if colname in OHLC:
            # Prices are stored as uint32s in thousandths of a dollar.
            self.check_uint_safe(arrmax * 1000, colname)
            return (array * 1000).astype(uint32)
        elif colname == 'volume':
            self.check_uint_safe(arrmax, colname)
            return array.astype(uint32)
        elif colname == 'day':
            # Days are stored as seconds since the epoch.
            nanos_per_second = (1000 * 1000 * 1000)
            self.check_uint_safe(arrmax.view(int) / nanos_per_second, colname)
            return (array.view(int) / nanos_per_second).astype(uint32)

    @staticmethod
    def check_uint_safe(value, colname):
        if value >= UINT32_MAX:
            raise ValueError(
                "Value %s from column '%s' is too large" % (value, colname)
            )
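
# For reference, a sketch of the storage convention implemented above: an
# OHLC price of 10.25 is written as the uint32 10250 (thousandths of a
# dollar), and a datetime64[ns] day of 2014-01-02 is written as its epoch
# timestamp in seconds, 1388620800.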


def write_minute_data(tempdir, minutes, sids, sid_path_func=None):
    assets = {}

    length = len(minutes)

    for sid_idx, sid in enumerate(sids):
        assets[sid] = pd.DataFrame({
            "open": (np.array(range(10, 10 + length)) + sid_idx) * 1000,
            "high": (np.array(range(15, 15 + length)) + sid_idx) * 1000,
            "low": (np.array(range(8, 8 + length)) + sid_idx) * 1000,
            "close": (np.array(range(10, 10 + length)) + sid_idx) * 1000,
            "volume": np.array(range(100, 100 + length)) + sid_idx,
            "minute": minutes
        }, index=minutes)

    MinuteBarWriterFromDataFrames(pd.Timestamp('2002-01-02', tz='UTC')).write(
        tempdir.path, assets, sid_path_func=sid_path_func)

    return tempdir.path


def write_daily_data(tempdir, sim_params, sids):
    path = os.path.join(tempdir.path, "testdaily.bcolz")
    assets = {}
    length = sim_params.days_in_period
    for sid_idx, sid in enumerate(sids):
        assets[sid] = pd.DataFrame({
            "open": (np.array(range(10, 10 + length)) + sid_idx),
            "high": (np.array(range(15, 15 + length)) + sid_idx),
            "low": (np.array(range(8, 8 + length)) + sid_idx),
            "close": (np.array(range(10, 10 + length)) + sid_idx),
            "volume": np.array(range(100, 100 + length)) + sid_idx,
            "day": [day.value for day in sim_params.trading_days]
        }, index=sim_params.trading_days)

    DailyBarWriterFromDataFrames(assets).write(
        path,
        sim_params.trading_days,
        assets
    )

    return path


def create_data_portal(env, tempdir, sim_params, sids, sid_path_func=None,
                       adjustment_reader=None):
    if sim_params.data_frequency == "daily":
        daily_path = write_daily_data(tempdir, sim_params, sids)

        return DataPortal(
            env,
            daily_equities_path=daily_path,
            sim_params=sim_params,
            adjustment_reader=adjustment_reader
        )
    else:
        minutes = env.minutes_for_days_in_range(
            sim_params.first_open,
            sim_params.last_close
        )

        minute_path = write_minute_data(tempdir, minutes, sids,
                                        sid_path_func)

        equity_minute_reader = BcolzMinuteBarReader(minute_path)

        return DataPortal(
            env,
            equity_minute_reader=equity_minute_reader,
            sim_params=sim_params,
            adjustment_reader=adjustment_reader
        )
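
# A minimal usage sketch (hypothetical setup; assumes a TradingEnvironment
# `env`, SimulationParameters `sim_params`, and a testfixtures-style temp
# directory like the `tempdir.path` usage elsewhere in this module):
#
#     portal = create_data_portal(env, tempdir, sim_params, sids=[1, 2])
#     # synthetic daily or minute bars are written under tempdir, and the
#     # returned DataPortal reads them back for the simulation.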


def create_data_portal_from_trade_history(env, tempdir, sim_params,
                                          trades_by_sid):
    if sim_params.data_frequency == "daily":
        path = os.path.join(tempdir.path, "testdaily.bcolz")
        assets = {}
        for sidint, trades in iteritems(trades_by_sid):
            opens = []
            highs = []
            lows = []
            closes = []
            volumes = []
            for trade in trades:
                opens.append(trade["open_price"])
                highs.append(trade["high"])
                lows.append(trade["low"])
                closes.append(trade["close_price"])
                volumes.append(trade["volume"])

            assets[sidint] = pd.DataFrame({
                "open": np.array(opens),
                "high": np.array(highs),
                "low": np.array(lows),
                "close": np.array(closes),
                "volume": np.array(volumes),
                "day": [day.value for day in sim_params.trading_days]
            }, index=sim_params.trading_days)

        DailyBarWriterFromDataFrames(assets).write(
            path,
            sim_params.trading_days,
            assets
        )

        return DataPortal(
            env,
            daily_equities_path=path,
            sim_params=sim_params,
        )
    else:
        minutes = env.minutes_for_days_in_range(
            sim_params.first_open,
            sim_params.last_close
        )

        length = len(minutes)
        assets = {}

        for sidint, trades in iteritems(trades_by_sid):
            opens = np.zeros(length)
            highs = np.zeros(length)
            lows = np.zeros(length)
            closes = np.zeros(length)
            volumes = np.zeros(length)

            for trade in trades:
                # put each trade in its minute slot
                idx = minutes.searchsorted(trade.dt)

                opens[idx] = trade.open_price * 1000
                highs[idx] = trade.high * 1000
                lows[idx] = trade.low * 1000
                closes[idx] = trade.close_price * 1000
                volumes[idx] = trade.volume

            assets[sidint] = pd.DataFrame({
                "open": opens,
                "high": highs,
                "low": lows,
                "close": closes,
                "volume": volumes,
                "minute": minutes
            }, index=minutes)

        MinuteBarWriterFromDataFrames(pd.Timestamp('2002-01-02', tz='UTC')).\
            write(tempdir.path, assets)

        equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

        return DataPortal(
            env,
            equity_minute_reader=equity_minute_reader,
            sim_params=sim_params
        )


class FakeDataPortal(object):

    def __init__(self):
        self._adjustment_reader = None

    def setup_offset_cache(self, minutes_by_day, minutes_to_day):
        pass


class FetcherDataPortal(DataPortal):
    """
    Mock data portal that returns fake data for history and non-fetcher
    spot value.
    """
    def __init__(self, env, sim_params):
        super(FetcherDataPortal, self).__init__(
            env,
            sim_params
        )

    def get_spot_value(self, asset, field, dt=None):
        # if this is a fetcher field, exercise the regular code path
        if self._check_extra_sources(asset, field, (dt or self.current_dt)):
            return super(FetcherDataPortal, self).get_spot_value(
                asset, field, dt)

        # otherwise just return a fixed value
        return int(asset)

    def setup_offset_cache(self, minutes_by_day, minutes_to_day):
        pass

    def _get_daily_window_for_sid(self, asset, field, days_in_window,
                                  extra_slot=True):
        return np.arange(days_in_window, dtype=np.float64)

    def _get_minute_window_for_asset(self, asset, field, minutes_for_window):
        return np.arange(minutes_for_window, dtype=np.float64)


class tmp_assets_db(object):
    """Create a temporary assets sqlite database.
    This is meant to be used as a context manager.

    Parameters
    ----------
    **frames : pd.DataFrame, optional
        The frames to feed to the writer. By default this writes an
        'equities' frame with the sids map(ord, 'ABC') and the symbols
        ('A', 'B', 'C').
    """
    def __init__(self, **frames):
        self._eng = None
        if not frames:
            frames = {
                'equities': make_simple_equity_info(
                    list(map(ord, 'ABC')),
                    pd.Timestamp(0),
                    pd.Timestamp('2015'),
                )
            }
        self._data = AssetDBWriterFromDataFrame(**frames)

    def __enter__(self):
        self._eng = eng = create_engine('sqlite://')
        self._data.write_all(eng)
        return eng

    def __exit__(self, *excinfo):
        assert self._eng is not None, '_eng was not set in __enter__'
        self._eng.dispose()


class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder using an in-memory sqlite db.

    Parameters
    ----------
    finder_cls : type, optional
        The asset finder class to instantiate. Defaults to AssetFinder.
    **frames : pd.DataFrame, optional
        The frames to feed to the writer.
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        self._finder_cls = finder_cls
        super(tmp_asset_finder, self).__init__(**frames)

    def __enter__(self):
        return self._finder_cls(super(tmp_asset_finder, self).__enter__())
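
# A minimal usage sketch (relies on the default 'equities' frame above, so
# sid ord('A') == 65 exists):
#
#     with tmp_asset_finder() as finder:
#         asset_a = finder.retrieve_asset(ord('A'))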


class SubTestFailures(AssertionError):
    def __init__(self, *failures):
        self.failures = failures

    def __str__(self):
        return 'failures:\n  %s' % '\n  '.join(
            '\n    '.join((
                ', '.join('%s=%r' % item for item in scope.items()),
                '%s: %s' % (type(exc).__name__, exc),
            ))
            for scope, exc in self.failures
        )


def subtest(iterator, *_names):
    """Construct a subtest in a unittest.

    This works by decorating a function as a subtest. The test will be run
    by iterating over the ``iterator`` and *unpacking the values into the
    function. If any of the runs fail, the failure is recorded and the rest
    of the runs still execute. Finally, if any failed, all of the results
    will be dumped as one failure.

    Parameters
    ----------
    iterator : iterable[iterable]
        The iterator of arguments to pass to the function.
    *_names : iterator[str]
        The names to use for each element of ``iterator``. These will be used
        to print the scope when a test fails. If not provided, it will use the
        integer index of the value as the name.

    Examples
    --------

    ::

       class MyTest(TestCase):
           def test_thing(self):
               # Example usage inside another test.
               @subtest(([n] for n in range(100000)), 'n')
               def subtest(n):
                   self.assertEqual(n % 2, 0, 'n was not even')
               subtest()

           @subtest(([n] for n in range(100000)), 'n')
           def test_decorated_function(self, n):
               # Example usage to parameterize an entire function.
               self.assertEqual(n % 2, 1, 'n was not odd')

    Notes
    -----
    We use this when we:

    * Will never want to run each parameter individually.
    * Have a large parameter space we are testing
      (see tests/utils/test_events.py).

    ``nose_parameterized.expand`` will create a test for each parameter
    combination which bloats the test output and makes the travis pages slow.

    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
    nose2 do not support ``addSubTest``.
    """
    def dec(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            names = _names
            failures = []
            for scope in iterator:
                scope = tuple(scope)
                try:
                    f(*(args + scope), **kwargs)
                except Exception as e:
                    # Fall back to integer indices when no names were given;
                    # build a fresh count() per failure so that the indices
                    # restart at 0 for each scope.
                    failures.append((dict(zip(names or count(), scope)), e))
            if failures:
                raise SubTestFailures(*failures)

        return wrapped
    return dec


class MockDailyBarReader(object):
    def spot_price(self, col, sid, dt):
        return 100


def create_mock_adjustments(tempdir, days, splits=None, dividends=None,
                            mergers=None):
    path = tempdir.getpath("test_adjustments.db")

    writer = SQLiteAdjustmentWriter(path, days, MockDailyBarReader())
    if splits is None:
        splits = pd.DataFrame({
            # Hackery to make the dtypes correct on an empty frame.
            'effective_date': np.array([], dtype=int),
            'ratio': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }, index=pd.DatetimeIndex([], tz='UTC'))
    else:
        splits = pd.DataFrame(splits)

    if mergers is None:
        mergers = pd.DataFrame({
            # Hackery to make the dtypes correct on an empty frame.
            'effective_date': np.array([], dtype=int),
            'ratio': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }, index=pd.DatetimeIndex([], tz='UTC'))
    else:
        mergers = pd.DataFrame(mergers)

    if dividends is None:
        data = {
            # Hackery to make the dtypes correct on an empty frame.
            'ex_date': np.array([], dtype='datetime64[ns]'),
            'pay_date': np.array([], dtype='datetime64[ns]'),
            'record_date': np.array([], dtype='datetime64[ns]'),
            'declared_date': np.array([], dtype='datetime64[ns]'),
            'amount': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }
        dividends = pd.DataFrame(
            data,
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['ex_date',
                     'pay_date',
                     'record_date',
                     'declared_date',
                     'amount',
                     'sid']
        )
    elif not isinstance(dividends, pd.DataFrame):
        dividends = pd.DataFrame(dividends)

    writer.write(splits, mergers, dividends)

    return path
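

# A minimal usage sketch (hypothetical values; `days` is the DatetimeIndex of
# trading days covered by the test, and effective dates follow the integer
# seconds-since-epoch convention of the empty frames above):
#
#     db_path = create_mock_adjustments(
#         tempdir,
#         days,
#         splits=[{'effective_date': str_to_seconds('2014-01-06'),
#                  'ratio': 0.5,
#                  'sid': 1}],
#     )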