from contextlib import contextmanager
from functools import wraps
from itertools import (
    count,
    product,
)
import operator
import os
import shutil
from string import ascii_uppercase
import tempfile
from bcolz import ctable

from logbook import FileHandler
from mock import patch
from numpy.testing import assert_allclose, assert_array_equal
import pandas as pd
from pandas.tseries.offsets import MonthBegin
from six import iteritems, itervalues
from six.moves import filter
from sqlalchemy import create_engine

from zipline.assets import AssetFinder
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame
from zipline.assets.futures import CME_CODE_TO_MONTH
from zipline.data.data_portal import DataPortal
from zipline.data.us_equity_minutes import MinuteBarWriterFromDataFrames
from zipline.data.us_equity_pricing import SQLiteAdjustmentWriter, OHLC, \
    UINT32_MAX, BcolzDailyBarWriter
from zipline.finance.order import ORDER_STATUS
from zipline.utils import security_list

import numpy as np
from numpy import (
    float64,
    uint32,
)


EPOCH = pd.Timestamp(0, tz='UTC')


def seconds_to_timestamp(seconds):
    return pd.Timestamp(seconds, unit='s', tz='UTC')
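# Illustrative round-trip (values match the ``str_to_seconds`` doctest
# below):
#
# >>> seconds_to_timestamp(1388534400)
# Timestamp('2014-01-01 00:00:00+0000', tz='UTC')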


def to_utc(time_str):
    """Convert a string in US/Eastern time to UTC."""
    return pd.Timestamp(time_str, tz='US/Eastern').tz_convert('UTC')
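# Example (a sketch; US/Eastern is UTC-5 on this winter date, so the
# clock moves forward five hours):
#
# >>> to_utc('2014-01-06 09:31')
# Timestamp('2014-01-06 14:31:00+0000', tz='UTC')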


def str_to_seconds(s):
    """
    Convert a pandas-intelligible string to (integer) seconds since the
    UTC epoch.

    >>> from pandas import Timestamp
    >>> (Timestamp('2014-01-01') - Timestamp(0)).total_seconds()
    1388534400.0
    >>> str_to_seconds('2014-01-01')
    1388534400
    """
    return int((pd.Timestamp(s, tz='UTC') - EPOCH).total_seconds())


def setup_logger(test, path='test.log'):
    test.log_handler = FileHandler(path)
    test.log_handler.push_application()


def teardown_logger(test):
    test.log_handler.pop_application()
    test.log_handler.close()
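# A minimal usage sketch, assuming a unittest.TestCase subclass; the
# handler is pushed in setUp and popped in tearDown:
#
#     class SomeTest(TestCase):
#         def setUp(self):
#             setup_logger(self)
#
#         def tearDown(self):
#             teardown_logger(self)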


def drain_zipline(test, zipline):
    output = []
    transaction_count = 0
    msg_counter = 0
    # start the simulation
    for update in zipline:
        msg_counter += 1
        output.append(update)
        if 'daily_perf' in update:
            transaction_count += len(update['daily_perf']['transactions'])

    return output, transaction_count


def check_algo_results(test,
                       results,
                       expected_transactions_count=None,
                       expected_order_count=None,
                       expected_positions_count=None,
                       sid=None):

    if expected_transactions_count is not None:
        txns = flatten_list(results["transactions"])
        test.assertEqual(expected_transactions_count, len(txns))

    if expected_positions_count is not None:
        raise NotImplementedError

    if expected_order_count is not None:
        # de-dup orders on id, because orders are put back into perf packets
        # whenever a txn against them is filled
        orders = set([order['id'] for order in
                      flatten_list(results["orders"])])

        test.assertEqual(expected_order_count, len(orders))


def flatten_list(nested):
    return [item for sublist in nested for item in sublist]
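# For example, flattening exactly one level of nesting:
#
# >>> flatten_list([[1, 2], [3], []])
# [1, 2, 3]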


def assert_single_position(test, zipline):

    output, transaction_count = drain_zipline(test, zipline)

    if 'expected_transactions' in test.zipline_test_config:
        test.assertEqual(
            test.zipline_test_config['expected_transactions'],
            transaction_count
        )
    else:
        test.assertEqual(
            test.zipline_test_config['order_count'],
            transaction_count
        )

    # The final message is the risk report; the second to last is the
    # final day's results. Positions is a list of dicts.
    closing_positions = output[-2]['daily_perf']['positions']

    # Confirm that all orders were filled: iterate over the output
    # updates, overwriting orders as they are updated, then check the
    # status on all of them.
    orders_by_id = {}
    for update in output:
        if 'daily_perf' in update:
            if 'orders' in update['daily_perf']:
                for order in update['daily_perf']['orders']:
                    orders_by_id[order['id']] = order

    for order in itervalues(orders_by_id):
        test.assertEqual(
            order['status'],
            ORDER_STATUS.FILLED,
            "")

    test.assertEqual(
        len(closing_positions),
        1,
        "Portfolio should have one position."
    )

    sid = test.zipline_test_config['sid']
    test.assertEqual(
        closing_positions[0]['sid'],
        sid,
        "Portfolio should have one position in " + str(sid)
    )

    return output, transaction_count


class ExceptionSource(object):

    def __init__(self):
        pass

    def get_hash(self):
        return "ExceptionSource"

    def __iter__(self):
        return self

    def next(self):
        5 / 0

    def __next__(self):
        5 / 0


@contextmanager
def security_list_copy():
    old_dir = security_list.SECURITY_LISTS_DIR
    new_dir = tempfile.mkdtemp()
    try:
        for subdir in os.listdir(old_dir):
            shutil.copytree(os.path.join(old_dir, subdir),
                            os.path.join(new_dir, subdir))
        # Yield exactly once, after all subdirectories have been copied;
        # a @contextmanager generator must not yield once per loop
        # iteration.
        with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
                patch.object(security_list, 'using_copy', True,
                             create=True):
            yield
    finally:
        shutil.rmtree(new_dir, True)


def add_security_data(adds, deletes):
    if not hasattr(security_list, 'using_copy'):
        raise Exception('add_security_data must be used within '
                        'security_list_copy context')
    directory = os.path.join(
        security_list.SECURITY_LISTS_DIR,
        "leveraged_etf_list/20150127/20150125"
    )
    if not os.path.exists(directory):
        os.makedirs(directory)
    del_path = os.path.join(directory, "delete")
    with open(del_path, 'w') as f:
        for sym in deletes:
            f.write(sym)
            f.write('\n')
    add_path = os.path.join(directory, "add")
    with open(add_path, 'w') as f:
        for sym in adds:
            f.write(sym)
            f.write('\n')
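# A hedged usage sketch (symbols are made up). ``add_security_data`` must
# run inside ``security_list_copy`` so the real lists on disk are never
# touched:
#
#     with security_list_copy():
#         add_security_data(adds=['NEWETF'], deletes=['OLDETF'])
#         ...  # code under test sees the modified copy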


def all_pairs_matching_predicate(values, pred):
    """
    Return an iterator of all pairs, (v0, v1), from values such that

    `pred(v0, v1) == True`

    Parameters
    ----------
    values : iterable
    pred : function

    Returns
    -------
    pairs_iterator : generator
       Generator yielding pairs matching `pred`.

    Examples
    --------
    >>> from zipline.utils.test_utils import all_pairs_matching_predicate
    >>> from operator import eq, lt
    >>> list(all_pairs_matching_predicate(range(5), eq))
    [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
    >>> list(all_pairs_matching_predicate("abcd", lt))
    [('a', 'b'), ('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd'), ('c', 'd')]
    """
    return filter(lambda pair: pred(*pair), product(values, repeat=2))


def product_upper_triangle(values, include_diagonal=False):
    """
    Return an iterator over pairs, (v0, v1), drawn from values.

    If `include_diagonal` is True, returns all pairs such that v0 <= v1.
    If `include_diagonal` is False, returns all pairs such that v0 < v1.
    """
    return all_pairs_matching_predicate(
        values,
        operator.le if include_diagonal else operator.lt,
    )
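# For example (a small sketch):
#
# >>> list(product_upper_triangle([1, 2, 3]))
# [(1, 2), (1, 3), (2, 3)]
# >>> list(product_upper_triangle([1, 2, 3], include_diagonal=True))
# [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]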


def all_subindices(index):
    """
    Return all valid sub-indices of a pandas Index.
    """
    return (
        index[start:stop]
        for start, stop in product_upper_triangle(range(len(index) + 1))
    )


def make_rotating_equity_info(num_assets,
                              first_start,
                              frequency,
                              periods_between_starts,
                              asset_lifetime):
    """
    Create a DataFrame representing lifetimes of assets that are constantly
    rotating in and out of existence.

    Parameters
    ----------
    num_assets : int
        How many assets to create.
    first_start : pd.Timestamp
        The start date for the first asset.
    frequency : str or pd.tseries.offsets.Offset (e.g. trading_day)
        Frequency used to interpret the next two arguments.
    periods_between_starts : int
        Create a new asset every `frequency` * `periods_between_starts`.
    asset_lifetime : int
        Each asset exists for `frequency` * `asset_lifetime` periods.

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    return pd.DataFrame(
        {
            'symbol': [chr(ord('A') + i) for i in range(num_assets)],
            # Start a new asset every `periods_between_starts` periods.
            'start_date': pd.date_range(
                first_start,
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            # Each asset lasts for `asset_lifetime` periods.
            'end_date': pd.date_range(
                first_start + (asset_lifetime * frequency),
                freq=(periods_between_starts * frequency),
                periods=num_assets,
            ),
            'exchange': 'TEST',
        },
        index=range(num_assets),
    )
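# A hedged example call (parameter values assumed):
#
#     make_rotating_equity_info(
#         num_assets=3,
#         first_start=pd.Timestamp('2014-01-01', tz='UTC'),
#         frequency=pd.Timedelta(days=1),
#         periods_between_starts=2,
#         asset_lifetime=5,
#     )
#
# creates assets 'A', 'B', 'C' starting on Jan 1, 3, and 5 respectively,
# each ending five days after its start (Jan 6, 8, and 10).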


def make_simple_equity_info(sids, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    sids : array-like of int
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(sids)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    return pd.DataFrame(
        {
            'symbol': symbols,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        },
        index=sids,
    )


def make_future_info(first_sid,
                     root_symbols,
                     years,
                     notice_date_func,
                     expiration_date_func,
                     start_date_func,
                     month_codes=None):
    """
    Create a DataFrame representing futures for `root_symbols` during `years`.

    Generates a contract per triple of (symbol, year, month) supplied to
    `root_symbols`, `years`, and `month_codes`.

    Parameters
    ----------
    first_sid : int
        The first sid to use for assigning sids to the created contracts.
    root_symbols : list[str]
        A list of root symbols for which to create futures.
    years : list[int or str]
        Years (e.g. 2014), for which to produce individual contracts.
    notice_date_func : (Timestamp) -> Timestamp
        Function to generate notice dates from the first of the month
        associated with each asset month code.  Return NaT to simulate
        futures with no notice date.
    expiration_date_func : (Timestamp) -> Timestamp
        Function to generate expiration dates from the first of the month
        associated with each asset month code.
    start_date_func : (Timestamp) -> Timestamp, optional
        Function to generate start dates from the first of the month
        associated with each asset month code.  Defaults to a start_date
        one year prior to the month_code date.
    month_codes : dict[str -> [1..12]], optional
        Dictionary of month codes for which to create contracts.  Entries
        should be strings mapped to values from 1 (January) to 12 (December).
        Default is zipline.assets.futures.CME_CODE_TO_MONTH.

    Returns
    -------
    futures_info : pd.DataFrame
        DataFrame of futures data suitable for passing to an
        AssetDBWriterFromDataFrame.
    """
    if month_codes is None:
        month_codes = CME_CODE_TO_MONTH

    year_strs = list(map(str, years))
    years = [pd.Timestamp(s, tz='UTC') for s in year_strs]

    # Pairs of string/date like ('K06', 2006-05-01)
    contract_suffix_to_beginning_of_month = tuple(
        (month_code + year_str[-2:], year + MonthBegin(month_num))
        for ((year, year_str), (month_code, month_num))
        in product(
            zip(years, year_strs),
            iteritems(month_codes),
        )
    )

    contracts = []
    parts = product(root_symbols, contract_suffix_to_beginning_of_month)
    for sid, (root_sym, (suffix, month_begin)) in enumerate(parts, first_sid):
        contracts.append({
            'sid': sid,
            'root_symbol': root_sym,
            'symbol': root_sym + suffix,
            'start_date': start_date_func(month_begin),
            'notice_date': notice_date_func(month_begin),
            'expiration_date': expiration_date_func(month_begin),
            'contract_multiplier': 500,
        })
    return pd.DataFrame.from_records(contracts, index='sid').convert_objects()


def make_commodity_future_info(first_sid,
                               root_symbols,
                               years,
                               month_codes=None):
    """
    Make futures testing data that simulates the notice/expiration date
    behavior of physical commodities like oil.

    Parameters
    ----------
    first_sid : int
    root_symbols : list[str]
    years : list[int]
    month_codes : dict[str -> int]

    Expiration dates are on the 20th of the month prior to the month code.
    Notice dates are on the 20th two months prior to the month code.
    Start dates are one year before the contract month.

    See Also
    --------
    make_future_info
    """
    nineteen_days = pd.Timedelta(days=19)
    one_year = pd.Timedelta(days=365)
    return make_future_info(
        first_sid=first_sid,
        root_symbols=root_symbols,
        years=years,
        notice_date_func=lambda dt: dt - MonthBegin(2) + nineteen_days,
        expiration_date_func=lambda dt: dt - MonthBegin(1) + nineteen_days,
        start_date_func=lambda dt: dt - one_year,
        month_codes=month_codes,
    )
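# Worked example for one contract (a sketch): for month code 'K' (May) in
# 2006, ``month_begin`` is 2006-05-01, so the notice date is
# 2006-05-01 - MonthBegin(2) + 19 days = 2006-03-20, and the expiration
# date is 2006-05-01 - MonthBegin(1) + 19 days = 2006-04-20, matching the
# "20th of the month prior" rules described above.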


def make_simple_asset_info(assets, start_date, end_date, symbols=None):
    """
    Create a DataFrame representing assets that exist for the full duration
    between `start_date` and `end_date`.

    Parameters
    ----------
    assets : array-like
    start_date : pd.Timestamp
    end_date : pd.Timestamp
    symbols : list, optional
        Symbols to use for the assets.
        If not provided, symbols are generated from the sequence 'A', 'B', ...

    Returns
    -------
    info : pd.DataFrame
        DataFrame representing newly-created assets.
    """
    num_assets = len(assets)
    if symbols is None:
        symbols = list(ascii_uppercase[:num_assets])
    return pd.DataFrame(
        {
            'sid': assets,
            'symbol': symbols,
            'asset_type': ['equity'] * num_assets,
            'start_date': [start_date] * num_assets,
            'end_date': [end_date] * num_assets,
            'exchange': 'TEST',
        }
    )


def check_allclose(actual,
                   desired,
                   rtol=1e-07,
                   atol=0,
                   err_msg='',
                   verbose=True):
    """
    Wrapper around np.testing.assert_allclose that also verifies that inputs
    are of the same type.

    See Also
    --------
    np.testing.assert_allclose
    """
    if type(actual) != type(desired):
        raise AssertionError("%s != %s" % (type(actual), type(desired)))
    # Forward all tolerance/reporting arguments instead of silently
    # dropping them.
    return assert_allclose(
        actual,
        desired,
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
        verbose=verbose,
    )


def check_arrays(x, y, err_msg='', verbose=True):
    """
    Wrapper around np.testing.assert_array_equal that also verifies that
    inputs are of the same type.

    See Also
    --------
    np.testing.assert_array_equal
    """
    if type(x) != type(y):
        raise AssertionError("%s != %s" % (type(x), type(y)))
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)


class UnexpectedAttributeAccess(Exception):
    pass


class ExplodingObject(object):
    """
    Object that will raise an exception on any attribute access.

    Useful for verifying that an object is never touched during a
    function/method call.
    """
    def __getattribute__(self, name):
        raise UnexpectedAttributeAccess(name)
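# A minimal usage sketch (``make_algo`` here is a hypothetical helper,
# not part of this module):
#
#     algo = make_algo(data_portal=ExplodingObject())  # hypothetical
#     algo.run()  # any attribute access raises UnexpectedAttributeAccess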


class DailyBarWriterFromDataFrames(BcolzDailyBarWriter):
    _csv_dtypes = {
        'open': float64,
        'high': float64,
        'low': float64,
        'close': float64,
        'volume': float64,
    }

    def __init__(self, asset_map):
        self._asset_map = asset_map

    def gen_tables(self, assets):
        for asset in assets:
            yield asset, ctable.fromdataframe(assets[asset])

    def to_uint32(self, array, colname):
        arrmax = array.max()
        if colname in OHLC:
            self.check_uint_safe(arrmax * 1000, colname)
            return (array * 1000).astype(uint32)
        elif colname == 'volume':
            self.check_uint_safe(arrmax, colname)
            return array.astype(uint32)
        elif colname == 'day':
            nanos_per_second = (1000 * 1000 * 1000)
            self.check_uint_safe(arrmax.view(int) / nanos_per_second, colname)
            return (array.view(int) / nanos_per_second).astype(uint32)

    @staticmethod
    def check_uint_safe(value, colname):
        if value >= UINT32_MAX:
            raise ValueError(
                "Value %s from column '%s' is too large" % (value, colname)
            )
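# For example, an OHLC price of 10.001 is stored as the uint32 value
# 10001 (scaled by 1000), and 'day' values, which arrive as nanosecond
# timestamps, are divided by 1e9 so they fit into a uint32 as whole
# seconds.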


def write_minute_data(tempdir, minutes, sids, sid_path_func=None):
    assets = {}

    length = len(minutes)

    for sid_idx, sid in enumerate(sids):
        assets[sid] = pd.DataFrame({
            "open": (np.array(range(10, 10 + length)) + sid_idx) * 1000,
            "high": (np.array(range(15, 15 + length)) + sid_idx) * 1000,
            "low": (np.array(range(8, 8 + length)) + sid_idx) * 1000,
            "close": (np.array(range(10, 10 + length)) + sid_idx) * 1000,
            "volume": np.array(range(100, 100 + length)) + sid_idx,
            "minute": minutes
        }, index=minutes)

    MinuteBarWriterFromDataFrames().write(tempdir.path, assets,
                                          sid_path_func=sid_path_func)

    return tempdir.path


def write_daily_data(tempdir, sim_params, sids):
    path = os.path.join(tempdir.path, "testdaily.bcolz")
    assets = {}
    length = sim_params.days_in_period
    for sid_idx, sid in enumerate(sids):
        assets[sid] = pd.DataFrame({
            "open": (np.array(range(10, 10 + length)) + sid_idx),
            "high": (np.array(range(15, 15 + length)) + sid_idx),
            "low": (np.array(range(8, 8 + length)) + sid_idx),
            "close": (np.array(range(10, 10 + length)) + sid_idx),
            "volume": np.array(range(100, 100 + length)) + sid_idx,
            "day": [day.value for day in sim_params.trading_days]
        }, index=sim_params.trading_days)

    DailyBarWriterFromDataFrames(assets).write(
        path,
        sim_params.trading_days,
        assets
    )

    return path


def create_data_portal(env, tempdir, sim_params, sids, sid_path_func=None,
                       adjustment_reader=None):
    if sim_params.data_frequency == "daily":
        daily_path = write_daily_data(tempdir, sim_params, sids)

        return DataPortal(
            env,
            daily_equities_path=daily_path,
            sim_params=sim_params,
            adjustment_reader=adjustment_reader
        )
    else:
        minutes = env.minutes_for_days_in_range(
            sim_params.first_open,
            sim_params.last_close
        )

        minute_path = write_minute_data(tempdir, minutes, sids,
                                        sid_path_func)

        return DataPortal(
            env,
            minutes_equities_path=minute_path,
            sim_params=sim_params,
            adjustment_reader=adjustment_reader
        )


def create_data_portal_from_trade_history(env, tempdir, sim_params,
                                          trades_by_sid):
    if sim_params.data_frequency == "daily":
        path = os.path.join(tempdir.path, "testdaily.bcolz")
        assets = {}
        for sidint, trades in iteritems(trades_by_sid):
            opens = []
            highs = []
            lows = []
            closes = []
            volumes = []
            for trade in trades:
                opens.append(trade["open_price"])
                highs.append(trade["high"])
                lows.append(trade["low"])
                closes.append(trade["close_price"])
                volumes.append(trade["volume"])

            assets[sidint] = pd.DataFrame({
                "open": np.array(opens),
                "high": np.array(highs),
                "low": np.array(lows),
                "close": np.array(closes),
                "volume": np.array(volumes),
                "day": [day.value for day in sim_params.trading_days]
            }, index=sim_params.trading_days)

        DailyBarWriterFromDataFrames(assets).write(
            path,
            sim_params.trading_days,
            assets
        )

        return DataPortal(
            env,
            daily_equities_path=path,
            sim_params=sim_params,
        )
    else:
        minutes = env.minutes_for_days_in_range(
            sim_params.first_open,
            sim_params.last_close
        )

        length = len(minutes)
        assets = {}

        # Use six.iteritems rather than dict.iteritems so this also works
        # under Python 3.
        for sidint, trades in iteritems(trades_by_sid):
            opens = np.zeros(length)
            highs = np.zeros(length)
            lows = np.zeros(length)
            closes = np.zeros(length)
            volumes = np.zeros(length)

            for trade in trades:
                # put each trade in the right minute slot
                idx = minutes.searchsorted(trade.dt)

                opens[idx] = trade.open_price * 1000
                highs[idx] = trade.high * 1000
                lows[idx] = trade.low * 1000
                closes[idx] = trade.close_price * 1000
                volumes[idx] = trade.volume

            assets[sidint] = pd.DataFrame({
                "open": opens,
                "high": highs,
                "low": lows,
                "close": closes,
                "volume": volumes,
                "minute": minutes
            }, index=minutes)

        MinuteBarWriterFromDataFrames().write(tempdir.path, assets)

        return DataPortal(
            env,
            minutes_equities_path=tempdir.path,
            sim_params=sim_params
        )


class FakeDataPortal(object):

    def __init__(self):
        self._adjustment_reader = None

    def setup_offset_cache(self, minutes_by_day, minutes_to_day):
        pass


class FetcherDataPortal(DataPortal):
    """
    Mock DataPortal that returns fake data for history and for non-fetcher
    spot values.
    """
    def __init__(self, env, sim_params):
        super(FetcherDataPortal, self).__init__(
            env,
            sim_params
        )

    def get_spot_value(self, asset, field, dt=None):
        # if this is a fetcher field, exercise the regular code path
        if self._check_extra_sources(asset, field, (dt or self.current_dt)):
            return super(FetcherDataPortal, self).get_spot_value(
                asset, field, dt)

        # otherwise just return a fixed value
        return int(asset)

    def _get_daily_window_for_sid(self, asset, field, days_in_window,
                                  extra_slot=True):
        return np.arange(days_in_window, dtype=np.float64)

    def _get_minute_window_for_asset(self, asset, field, minutes_for_window):
        return np.arange(minutes_for_window, dtype=np.float64)


class tmp_assets_db(object):
    """Create a temporary assets sqlite database.
    This is meant to be used as a context manager.

    Parameters
    ----------
    **frames : pd.DataFrame, optional
        The frames to feed to the writer. By default this maps
        equities ('A', 'B', 'C') -> map(ord, 'ABC').
    """
    def __init__(self, **frames):
        self._eng = None
        if not frames:
            frames = {
                'equities': make_simple_equity_info(
                    list(map(ord, 'ABC')),
                    pd.Timestamp(0),
                    pd.Timestamp('2015'),
                )
            }
        self._data = AssetDBWriterFromDataFrame(**frames)

    def __enter__(self):
        self._eng = eng = create_engine('sqlite://')
        self._data.write_all(eng)
        return eng

    def __exit__(self, *excinfo):
        assert self._eng is not None, '_eng was not set in __enter__'
        self._eng.dispose()


class tmp_asset_finder(tmp_assets_db):
    """Create a temporary asset finder using an in-memory sqlite db.

    Parameters
    ----------
    **frames : pd.DataFrame, optional
        The frames to feed to the writer.
    """
    def __init__(self, finder_cls=AssetFinder, **frames):
        self._finder_cls = finder_cls
        super(tmp_asset_finder, self).__init__(**frames)

    def __enter__(self):
        return self._finder_cls(super(tmp_asset_finder, self).__enter__())
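# A hedged usage sketch (``retrieve_asset`` is assumed to be available on
# the finder). With the default frames, sids are ``map(ord, 'ABC')``, so
# sid 65 is equity 'A':
#
#     with tmp_asset_finder() as finder:
#         asset = finder.retrieve_asset(65)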
821
822
823
class SubTestFailures(AssertionError):
824
    def __init__(self, *failures):
825
        self.failures = failures
826
827
    def __str__(self):
828
        return 'failures:\n  %s' % '\n  '.join(
829
            '\n    '.join((
830
                ', '.join('%s=%r' % item for item in scope.items()),
831
                '%s: %s' % (type(exc).__name__, exc),
832
            )) for scope, exc in self.failures,
833
        )
834
835
836
def subtest(iterator, *_names):
837
    """Construct a subtest in a unittest.
838
839
    This works by decorating a function as a subtest. The test will be run
840
    by iterating over the ``iterator`` and *unpacking the values into the
841
    function. If any of the runs fail, the result will be put into a set and
842
    the rest of the tests will be run. Finally, if any failed, all of the
843
    results will be dumped as one failure.
844
845
    Parameters
846
    ----------
847
    iterator : iterable[iterable]
848
        The iterator of arguments to pass to the function.
849
    *name : iterator[str]
850
        The names to use for each element of ``iterator``. These will be used
851
        to print the scope when a test fails. If not provided, it will use the
852
        integer index of the value as the name.
853
854
    Examples
855
    --------
856
857
    ::
858
859
       class MyTest(TestCase):
860
           def test_thing(self):
861
               # Example usage inside another test.
862
               @subtest(([n] for n in range(100000)), 'n')
863
               def subtest(n):
864
                   self.assertEqual(n % 2, 0, 'n was not even')
865
               subtest()
866
867
           @subtest(([n] for n in range(100000)), 'n')
868
           def test_decorated_function(self, n):
869
               # Example usage to parameterize an entire function.
870
               self.assertEqual(n % 2, 1, 'n was not odd')
871
872
    Notes
873
    -----
874
    We use this when we:
875
876
    * Will never want to run each parameter individually.
877
    * Have a large parameter space we are testing
878
      (see tests/utils/test_events.py).
879
880
    ``nose_parameterized.expand`` will create a test for each parameter
881
    combination which bloats the test output and makes the travis pages slow.
882
883
    We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
884
    nose2 do not support ``addSubTest``.
885
    """
886
    def dec(f):
887
        @wraps(f)
888
        def wrapped(*args, **kwargs):
889
            names = _names
890
            failures = []
891
            for scope in iterator:
892
                scope = tuple(scope)
893
                try:
894
                    f(*args + scope, **kwargs)
895
                except Exception as e:
896
                    if not names:
897
                        names = count()
898
                    failures.append((dict(zip(names, scope)), e))
899
            if failures:
900
                raise SubTestFailures(*failures)
901
902
        return wrapped
903
    return dec


class MockDailyBarReader(object):
    def spot_price(self, col, sid, dt):
        return 100


def create_mock_adjustments(tempdir, days, splits=None, dividends=None,
                            mergers=None):
    path = tempdir.getpath("test_adjustments.db")

    writer = SQLiteAdjustmentWriter(path, days, MockDailyBarReader())
    if splits is None:
        splits = pd.DataFrame({
            # Hackery to make the dtypes correct on an empty frame.
            'effective_date': np.array([], dtype=int),
            'ratio': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }, index=pd.DatetimeIndex([], tz='UTC'))
    else:
        splits = pd.DataFrame(splits)

    if mergers is None:
        mergers = pd.DataFrame({
            # Hackery to make the dtypes correct on an empty frame.
            'effective_date': np.array([], dtype=int),
            'ratio': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }, index=pd.DatetimeIndex([], tz='UTC'))
    else:
        mergers = pd.DataFrame(mergers)

    if dividends is None:
        data = {
            # Hackery to make the dtypes correct on an empty frame.
            'ex_date': np.array([], dtype='datetime64[ns]'),
            'pay_date': np.array([], dtype='datetime64[ns]'),
            'record_date': np.array([], dtype='datetime64[ns]'),
            'declared_date': np.array([], dtype='datetime64[ns]'),
            'amount': np.array([], dtype=float),
            'sid': np.array([], dtype=int),
        }
        dividends = pd.DataFrame(
            data,
            index=pd.DatetimeIndex([], tz='UTC'),
            columns=['ex_date',
                     'pay_date',
                     'record_date',
                     'declared_date',
                     'amount',
                     'sid']
        )
    elif not isinstance(dividends, pd.DataFrame):
        dividends = pd.DataFrame(dividends)

    writer.write(splits, mergers, dividends)

    return path