Pull Request — master (#858)
Created by Eddie · Completed · 05:34 (02:25 queued)

test_get_output_nonexistent_pipeline()   Rating: B

Complexity
  Conditions: 6

Size
  Total Lines: 27

Duplication
  Lines: 0
  Ratio: 0 %

Metric                       Value
cc (cyclomatic complexity)   6
dl (duplicated lines)        0
loc (lines of code)          27
rs                           7.5385

2 Methods

Rating  Name                                               Duplication  Size  Complexity
A       tests.pipeline.ClosesOnly.before_trading_start()   0            3     1
A       tests.pipeline.ClosesOnly.handle_data()            0            2     4
"""
Tests for Algorithms using the Pipeline API.
"""
import os
from unittest import TestCase
from os.path import (
    dirname,
    join,
    realpath,
)

from nose_parameterized import parameterized
from numpy import (
    array,
    arange,
    full_like,
    float64,
    nan,
    uint32,
)
from numpy.testing import assert_almost_equal
from pandas import (
    concat,
    DataFrame,
    date_range,
    DatetimeIndex,
    read_csv,
    Series,
    Timestamp,
)
from six import iteritems, itervalues
from testfixtures import TempDirectory

from zipline.algorithm import TradingAlgorithm
from zipline.api import (
    attach_pipeline,
    pipeline_output,
    get_datetime,
)
from zipline.data.data_portal import DataPortal
from zipline.errors import (
    AttachPipelineAfterInitialize,
    PipelineOutputDuringInitialize,
    NoSuchPipeline,
)
from zipline.data.us_equity_pricing import (
    BcolzDailyBarReader,
    DailyBarWriterFromCSVs,
    SQLiteAdjustmentWriter,
    SQLiteAdjustmentReader,
)
from zipline.finance import trading
from zipline.finance.trading import SimulationParameters
from zipline.pipeline import Pipeline
from zipline.pipeline.factors import VWAP
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.loaders.frame import DataFrameLoader, MULTIPLY
from zipline.pipeline.loaders.equity_pricing_loader import (
    USEquityPricingLoader,
)
from zipline.utils.test_utils import (
    make_simple_equity_info,
    str_to_seconds,
    DailyBarWriterFromDataFrames,
    FakeDataPortal,
)
from zipline.utils.tradingcalendar import (
    trading_day,
    trading_days,
)


TEST_RESOURCE_PATH = join(
    dirname(dirname(realpath(__file__))),  # zipline_repo/tests
    'resources',
    'pipeline_inputs',
)


def rolling_vwap(df, length):
    """Simple rolling VWAP implementation for testing."""
    closes = df['close'].values
    volumes = df['volume'].values
    product = closes * volumes
    out = full_like(closes, nan)
    for upper_bound in range(length, len(closes) + 1):
        bounds = slice(upper_bound - length, upper_bound)
        out[upper_bound - 1] = product[bounds].sum() / volumes[bounds].sum()

    return Series(out, index=df.index)
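

# Illustrative usage of rolling_vwap (an added sketch, not part of the
# original module): for closes [10, 11, 12] with volumes [100, 200, 300],
# a length-2 VWAP is the volume-weighted mean over each trailing two-bar
# window, NaN until the window is full:
#
#   frame = DataFrame({'close': [10.0, 11.0, 12.0],
#                      'volume': [100.0, 200.0, 300.0]})
#   rolling_vwap(frame, 2)
#   # 0          NaN
#   # 1    10.666667   # (10*100 + 11*200) / (100 + 200)
#   # 2    11.600000   # (11*200 + 12*300) / (200 + 300)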


class ClosesOnly(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def setUp(self):
        self.env = env = trading.TradingEnvironment()
        self.dates = date_range(
            '2014-01-01', '2014-02-01', freq=trading_day, tz='UTC'
        )
        asset_info = DataFrame.from_records([
            {
                'sid': 1,
                'symbol': 'A',
                'start_date': self.dates[10],
                'end_date': self.dates[13],
                'exchange': 'TEST',
            },
            {
                'sid': 2,
                'symbol': 'B',
                'start_date': self.dates[11],
                'end_date': self.dates[14],
                'exchange': 'TEST',
            },
            {
                'sid': 3,
                'symbol': 'C',
                'start_date': self.dates[12],
                'end_date': self.dates[15],
                'exchange': 'TEST',
            },
        ])
        self.first_asset_start = min(asset_info.start_date)
        self.last_asset_end = max(asset_info.end_date)
        env.write_data(equities_df=asset_info)
        self.asset_finder = finder = env.asset_finder

        sids = (1, 2, 3)
        self.assets = finder.retrieve_all(sids)

        # View of the baseline data.
        self.closes = DataFrame(
            {sid: arange(1, len(self.dates) + 1) * sid for sid in sids},
            index=self.dates,
            dtype=float,
        )

        # Create a data portal holding the data in self.closes.
        data = {}
        for sid in sids:
            data[sid] = DataFrame({
                "open": self.closes[sid].values,
                "high": self.closes[sid].values,
                "low": self.closes[sid].values,
                "close": self.closes[sid].values,
                "volume": self.closes[sid].values,
                "day": [day.value for day in self.dates],
            })

        path = os.path.join(self.tempdir.path, "testdaily.bcolz")

        DailyBarWriterFromDataFrames(data).write(
            path,
            self.dates,
            data,
        )

        self.data_portal = DataPortal(
            self.env,
            daily_equities_path=path,
            sim_params=SimulationParameters(
                period_start=self.dates[0],
                period_end=self.dates[-1],
                env=self.env,
            ),
        )

        # Add a split for 'A' on its second date.
        self.split_asset = self.assets[0]
        self.split_date = self.split_asset.start_date + trading_day
        self.split_ratio = 0.5
        self.adjustments = DataFrame.from_records([
            {
                'sid': self.split_asset.sid,
                'value': self.split_ratio,
                'kind': MULTIPLY,
                'start_date': Timestamp('NaT'),
                'end_date': self.split_date,
                'apply_date': self.split_date,
            }
        ])

        # View of the data on/after the split.
        self.adj_closes = adj_closes = self.closes.copy()
        adj_closes.ix[:self.split_date, self.split_asset] *= self.split_ratio

        self.pipeline_loader = DataFrameLoader(
            column=USEquityPricing.close,
            baseline=self.closes,
            adjustments=self.adjustments,
        )

    def expected_close(self, date, asset):
        if date < self.split_date:
            lookup = self.closes
        else:
            lookup = self.adj_closes
        return lookup.loc[date, asset]

    def exists(self, date, asset):
        return asset.start_date <= date <= asset.end_date

    def test_attach_pipeline_after_initialize(self):
        """
        Assert that calling attach_pipeline after initialize raises correctly.
        """
        def initialize(context):
            pass

        def late_attach(context, data):
            attach_pipeline(Pipeline(), 'test')
            raise AssertionError("Shouldn't make it past attach_pipeline!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=late_attach,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(AttachPipelineAfterInitialize):
            algo.run(data_portal=self.data_portal)

        def barf(context, data):
            raise AssertionError("Shouldn't make it past before_trading_start")

        algo = TradingAlgorithm(
            initialize=initialize,
            before_trading_start=late_attach,
            handle_data=barf,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(AttachPipelineAfterInitialize):
            algo.run(data_portal=self.data_portal)

    def test_pipeline_output_after_initialize(self):
        """
        Assert that calling pipeline_output during initialize raises
        PipelineOutputDuringInitialize.
        """
        def initialize(context):
            attach_pipeline(Pipeline(), 'test')
            pipeline_output('test')
            raise AssertionError("Shouldn't make it past pipeline_output()")

        def handle_data(context, data):
            raise AssertionError("Shouldn't make it past initialize!")

        def before_trading_start(context, data):
            raise AssertionError("Shouldn't make it past initialize!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(PipelineOutputDuringInitialize):
            algo.run(data_portal=self.data_portal)

    def test_get_output_nonexistent_pipeline(self):
        """
        Assert that requesting the output of a pipeline that was never
        attached raises NoSuchPipeline.
        """
        def initialize(context):
            attach_pipeline(Pipeline(), 'test')

        def handle_data(context, data):
            raise AssertionError("Shouldn't make it past before_trading_start")

        def before_trading_start(context, data):
            pipeline_output('not_test')
            raise AssertionError("Shouldn't make it past pipeline_output!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(NoSuchPipeline):
            algo.run(data_portal=self.data_portal)

    @parameterized.expand([('default', None),
                           ('day', 1),
                           ('week', 5),
                           ('year', 252),
                           ('all_but_one_day', 'all_but_one_day')])
    def test_assets_appear_on_correct_days(self, test_name, chunksize):
        """
        Assert that assets appear at correct times during a backtest, with
        correctly-adjusted close price values.
        """

        if chunksize == 'all_but_one_day':
            chunksize = (
                self.dates.get_loc(self.last_asset_end) -
                self.dates.get_loc(self.first_asset_start)
            ) - 1

        def initialize(context):
            p = attach_pipeline(Pipeline(), 'test', chunksize=chunksize)
            p.add(USEquityPricing.close.latest, 'close')

        def handle_data(context, data):
            results = pipeline_output('test')
            date = get_datetime().normalize()
            for asset in self.assets:
                # Assets should appear iff they exist today and yesterday.
                exists_today = self.exists(date, asset)
                existed_yesterday = self.exists(date - trading_day, asset)
                if exists_today and existed_yesterday:
                    latest = results.loc[asset, 'close']
                    self.assertEqual(latest, self.expected_close(date, asset))
                else:
                    self.assertNotIn(asset, results.index)

        before_trading_start = handle_data

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start,
            end=self.last_asset_end,
            env=self.env,
        )

        # Run for a week in the middle of our data.
        algo.run(data_portal=self.data_portal)


class MockDailyBarSpotReader(object):
    """
    A stand-in for BcolzDailyBarReader that returns a constant value for
    spot price.
    """
    def spot_price(self, sid, day, column):
        return 100.0


class PipelineAlgorithmTestCase(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.AAPL = 1
        cls.MSFT = 2
        cls.BRK_A = 3
        cls.assets = [cls.AAPL, cls.MSFT, cls.BRK_A]
        asset_info = make_simple_equity_info(
            cls.assets,
            Timestamp('2014'),
            Timestamp('2015'),
            ['AAPL', 'MSFT', 'BRK_A'],
        )
        cls.env = trading.TradingEnvironment()
        cls.env.write_data(equities_df=asset_info)
        cls.tempdir = tempdir = TempDirectory()
        tempdir.create()
        try:
            cls.raw_data, cls.bar_reader = cls.create_bar_reader(tempdir)
            cls.adj_reader = cls.create_adjustment_reader(tempdir)
            cls.pipeline_loader = USEquityPricingLoader(
                cls.bar_reader, cls.adj_reader
            )
        except:
            cls.tempdir.cleanup()
            raise

        cls.dates = cls.raw_data[cls.AAPL].index.tz_localize('UTC')
        cls.AAPL_split_date = Timestamp("2014-06-09", tz='UTC')

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    @classmethod
    def create_bar_reader(cls, tempdir):
        resources = {
            cls.AAPL: join(TEST_RESOURCE_PATH, 'AAPL.csv'),
            cls.MSFT: join(TEST_RESOURCE_PATH, 'MSFT.csv'),
            cls.BRK_A: join(TEST_RESOURCE_PATH, 'BRK-A.csv'),
        }
        raw_data = {
            asset: read_csv(path, parse_dates=['day']).set_index('day')
            for asset, path in iteritems(resources)
        }
        # Add 'price' column as an alias because all kinds of stuff in zipline
        # depends on it being present. :/
        for frame in raw_data.values():
            frame['price'] = frame['close']

        writer = DailyBarWriterFromCSVs(resources)
        data_path = tempdir.getpath('testdata.bcolz')
        table = writer.write(data_path, trading_days, cls.assets)
        return raw_data, BcolzDailyBarReader(table)

    @classmethod
    def create_adjustment_reader(cls, tempdir):
        dbpath = tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, cls.env.trading_days,
                                        MockDailyBarSpotReader())
        splits = DataFrame.from_records([
            {
                'effective_date': str_to_seconds('2014-06-09'),
                'ratio': (1 / 7.0),
                'sid': cls.AAPL,
            }
        ])
        mergers = DataFrame(
            {
                # Hackery to make the dtypes correct on an empty frame.
                'effective_date': array([], dtype=int),
                'ratio': array([], dtype=float),
                'sid': array([], dtype=int),
            },
            index=DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        dividends = DataFrame({
            'sid': array([], dtype=uint32),
            'amount': array([], dtype=float64),
            'record_date': array([], dtype='datetime64[ns]'),
            'ex_date': array([], dtype='datetime64[ns]'),
            'declared_date': array([], dtype='datetime64[ns]'),
            'pay_date': array([], dtype='datetime64[ns]'),
        })
        writer.write(splits, mergers, dividends)
        return SQLiteAdjustmentReader(dbpath)

    def compute_expected_vwaps(self, window_lengths):
        AAPL, MSFT, BRK_A = self.AAPL, self.MSFT, self.BRK_A

        # Our view of the data before AAPL's split on June 9, 2014.
        raw = {k: v.copy() for k, v in iteritems(self.raw_data)}

        split_date = self.AAPL_split_date
        split_loc = self.dates.get_loc(split_date)
        split_ratio = 7.0

        # Our view of the data after AAPL's split.  All prices from before
        # June 9 get divided by the split ratio, and volumes get multiplied
        # by the split ratio.
        adj = {k: v.copy() for k, v in iteritems(self.raw_data)}
        for column in 'open', 'high', 'low', 'close':
            adj[AAPL].ix[:split_loc, column] /= split_ratio
        adj[AAPL].ix[:split_loc, 'volume'] *= split_ratio

        # length -> asset -> expected vwap
        vwaps = {length: {} for length in window_lengths}
        for length in window_lengths:
            for asset in AAPL, MSFT, BRK_A:
                raw_vwap = rolling_vwap(raw[asset], length)
                adj_vwap = rolling_vwap(adj[asset], length)
                # Shift computed results one day forward so that they're
                # labelled by the date on which they'll be seen in the
                # algorithm. (We can't show the close price for day N until
                # day N + 1.)
                vwaps[length][asset] = concat(
                    [
                        raw_vwap[:split_loc - 1],
                        adj_vwap[split_loc - 1:],
                    ]
                ).shift(1, trading_day)

        # Make sure all the expected vwaps have the same dates.
        vwap_dates = vwaps[1][self.AAPL].index
        for dict_ in itervalues(vwaps):
            # Each value is a dict mapping sid -> expected series.
            for series in itervalues(dict_):
                self.assertTrue((vwap_dates == series.index).all())

        # Spot check expectations near the AAPL split.
        # Length-1 vwap for the morning before the split should be the close
        # price of the previous day.
        before_split = vwaps[1][AAPL].loc[split_date - trading_day]
        assert_almost_equal(before_split, 647.3499, decimal=2)
        assert_almost_equal(
            before_split,
            raw[AAPL].loc[split_date - (2 * trading_day), 'close'],
            decimal=2,
        )

        # Length-1 vwap for the morning of the split should be the close
        # price of the previous day, **ADJUSTED FOR THE SPLIT**.
        on_split = vwaps[1][AAPL].loc[split_date]
        assert_almost_equal(on_split, 645.5700 / split_ratio, decimal=2)
        assert_almost_equal(
            on_split,
            raw[AAPL].loc[split_date - trading_day, 'close'] / split_ratio,
            decimal=2,
        )

        # Length-1 vwap on the day after the split should be the as-traded
        # close on the split day.
        after_split = vwaps[1][AAPL].loc[split_date + trading_day]
        assert_almost_equal(after_split, 93.69999, decimal=2)
        assert_almost_equal(
            after_split,
            raw[AAPL].loc[split_date, 'close'],
            decimal=2,
        )

        return vwaps
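
    # Illustrative arithmetic for the spot checks above (an added note, not
    # part of the original module): with a 7-for-1 split, the length-1 VWAP
    # seen on the morning of the split day is the prior close adjusted by
    # the split ratio,
    #
    #   645.57 / 7.0 ≈ 92.22
    #
    # while on the day after the split it is simply the as-traded close from
    # the split day itself (~93.70), since no adjustment spans that window.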

    @parameterized.expand([
        (True,),
        (False,),
    ])
    def test_handle_adjustment(self, set_screen):
        AAPL, MSFT, BRK_A = assets = self.AAPL, self.MSFT, self.BRK_A

        window_lengths = [1, 2, 5, 10]
        vwaps = self.compute_expected_vwaps(window_lengths)

        def vwap_key(length):
            return "vwap_%d" % length

        def initialize(context):
            pipeline = Pipeline()
            context.vwaps = []
            for length in vwaps:
                name = vwap_key(length)
                factor = VWAP(window_length=length)
                context.vwaps.append(factor)
                pipeline.add(factor, name=name)

            filter_ = (USEquityPricing.close.latest > 300)
            pipeline.add(filter_, 'filter')
            if set_screen:
                pipeline.set_screen(filter_)

            attach_pipeline(pipeline, 'test')

        def handle_data(context, data):
            today = get_datetime()
            results = pipeline_output('test')
            expect_over_300 = {
                AAPL: today < self.AAPL_split_date,
                MSFT: False,
                BRK_A: True,
            }
            for asset in assets:
                should_pass_filter = expect_over_300[asset]
                if set_screen and not should_pass_filter:
                    self.assertNotIn(asset, results.index)
                    continue

                asset_results = results.loc[asset]
                self.assertEqual(asset_results['filter'], should_pass_filter)
                for length in vwaps:
                    computed = results.loc[asset, vwap_key(length)]
                    expected = vwaps[length][asset].loc[today]
                    # Only having two places of precision here is a bit
                    # unfortunate.
                    assert_almost_equal(computed, expected, decimal=2)

        # Do the same checks in before_trading_start.
        before_trading_start = handle_data

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.dates[max(window_lengths)],
            end=self.dates[-1],
            env=self.env,
        )

        algo.run(data_portal=FakeDataPortal())