Completed
Pull Request — master (#858)
by Eddie
10:07 (queued 01:13)

test_pipeline_output_after_initialize()    Rating: B

Complexity
    Conditions: 6

Size
    Total Lines: 28

Duplication
    Lines: 0
    Ratio: 0 %

Metric    Value
cc        6
dl        0
loc       28
rs        7.5385

2 Methods

Rating    Name                                                 Duplication    Size    Complexity
A         tests.pipeline.ClosesOnly.before_trading_start()    0              2       1
A         tests.pipeline.ClosesOnly.initialize()              0              4       1
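The per-function metrics above use conventional static-analysis abbreviations: cc is cyclomatic complexity (the "Conditions" count), dl is duplicated lines, loc is lines of code; rs appears to be a tool-specific rating score. As a minimal sketch of how the cc and loc figures can be reproduced, assuming the radon library (the report does not name its analyzer, so radon and the file path here are illustrative assumptions):

    # Hypothetical reproduction of the cc/loc metrics using radon.
    from radon.complexity import cc_visit
    from radon.raw import analyze

    with open('tests/pipeline/test_pipeline_algo.py') as f:  # path is an assumption
        source = f.read()

    # One entry per function/method, each with a cyclomatic complexity score.
    for block in cc_visit(source):
        print(block.name, block.complexity)

    # Raw metrics for the whole module, including total lines of code.
    print(analyze(source).loc)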
"""
Tests for Algorithms using the Pipeline API.
"""
import os
from unittest import TestCase
from os.path import (
    dirname,
    join,
    realpath,
)

from nose_parameterized import parameterized
from numpy import (
    array,
    arange,
    full_like,
    float64,
    nan,
    uint32,
)
from numpy.testing import assert_almost_equal
from pandas import (
    concat,
    DataFrame,
    date_range,
    DatetimeIndex,
    read_csv,
    Series,
    Timestamp,
)
from six import iteritems, itervalues
from testfixtures import TempDirectory

from zipline.algorithm import TradingAlgorithm
from zipline.api import (
    attach_pipeline,
    pipeline_output,
    get_datetime,
)
from zipline.data.data_portal import DataPortal
from zipline.errors import (
    AttachPipelineAfterInitialize,
    PipelineOutputDuringInitialize,
    NoSuchPipeline,
)
from zipline.data.us_equity_pricing import (
    BcolzDailyBarReader,
    DailyBarWriterFromCSVs,
    SQLiteAdjustmentWriter,
    SQLiteAdjustmentReader,
)
from zipline.finance import trading
from zipline.lib.adjustment import MULTIPLY
from zipline.pipeline import Pipeline
from zipline.pipeline.factors import VWAP
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.loaders.frame import DataFrameLoader
from zipline.pipeline.loaders.equity_pricing_loader import (
    USEquityPricingLoader,
)
from zipline.utils.test_utils import (
    make_simple_equity_info,
    str_to_seconds,
    DailyBarWriterFromDataFrames,
    FakeDataPortal,
)
from zipline.utils.tradingcalendar import (
    trading_day,
    trading_days,
)


TEST_RESOURCE_PATH = join(
    dirname(dirname(realpath(__file__))),  # zipline_repo/tests
    'resources',
    'pipeline_inputs',
)


def rolling_vwap(df, length):
    """Simple rolling vwap implementation for testing."""
    closes = df['close'].values
    volumes = df['volume'].values
    product = closes * volumes
    out = full_like(closes, nan)
    for upper_bound in range(length, len(closes) + 1):
        bounds = slice(upper_bound - length, upper_bound)
        out[upper_bound - 1] = product[bounds].sum() / volumes[bounds].sum()

    return Series(out, index=df.index)
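
# Illustrative cross-check for rolling_vwap (an addition for exposition, not
# part of the original test suite): the same quantity can be computed with
# pandas' rolling windows, assuming a pandas version that provides
# Series.rolling. Both leave the first `length - 1` entries as NaN.
def rolling_vwap_pandas(df, length):
    product = df['close'] * df['volume']
    return product.rolling(length).sum() / df['volume'].rolling(length).sum()
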
class ClosesOnly(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def setUp(self):
        self.env = env = trading.TradingEnvironment()
        self.dates = date_range(
            '2014-01-01', '2014-02-01', freq=trading_day, tz='UTC'
        )
        asset_info = DataFrame.from_records([
            {
                'sid': 1,
                'symbol': 'A',
                'start_date': self.dates[10],
                'end_date': self.dates[13],
                'exchange': 'TEST',
            },
            {
                'sid': 2,
                'symbol': 'B',
                'start_date': self.dates[11],
                'end_date': self.dates[14],
                'exchange': 'TEST',
            },
            {
                'sid': 3,
                'symbol': 'C',
                'start_date': self.dates[12],
                'end_date': self.dates[15],
                'exchange': 'TEST',
            },
        ])
        self.first_asset_start = min(asset_info.start_date)
        self.last_asset_end = max(asset_info.end_date)
        env.write_data(equities_df=asset_info)
        self.asset_finder = finder = env.asset_finder

        sids = (1, 2, 3)
        self.assets = finder.retrieve_all(sids)

        # View of the baseline data.
        self.closes = DataFrame(
            {sid: arange(1, len(self.dates) + 1) * sid for sid in sids},
            index=self.dates,
            dtype=float,
        )
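        # (Descriptive note: with this construction, sid 1's closes are
        # 1.0, 2.0, 3.0, ..., sid 2's are 2.0, 4.0, 6.0, ..., which makes
        # the expected values asserted below easy to compute by hand.)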

        # Create a data portal holding the data in self.closes
        data = {}
        for sid in sids:
            data[sid] = DataFrame({
                "open": self.closes[sid].values,
                "high": self.closes[sid].values,
                "low": self.closes[sid].values,
                "close": self.closes[sid].values,
                "volume": self.closes[sid].values,
                "day": [day.value for day in self.dates]
            })

        path = os.path.join(self.tempdir.path, "testdaily.bcolz")

        DailyBarWriterFromDataFrames(data).write(
            path,
            self.dates,
            data
        )

        daily_bar_reader = BcolzDailyBarReader(path)

        self.data_portal = DataPortal(
            self.env,
            equity_daily_reader=daily_bar_reader,
        )

        # Add a split for 'A' on its second date.
        self.split_asset = self.assets[0]
        self.split_date = self.split_asset.start_date + trading_day
        self.split_ratio = 0.5
        self.adjustments = DataFrame.from_records([
            {
                'sid': self.split_asset.sid,
                'value': self.split_ratio,
                'kind': MULTIPLY,
                'start_date': Timestamp('NaT'),
                'end_date': self.split_date,
                'apply_date': self.split_date,
            }
        ])
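        # (Descriptive note on the semantics assumed here: a MULTIPLY
        # adjustment of 0.5 whose end_date and apply_date are both the
        # split date means that, once the simulation reaches the split
        # date, closes dated on or before it are reported at half their
        # raw values, mirroring the adjusted view built below.)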

        # View of the data on/after the split.
        self.adj_closes = adj_closes = self.closes.copy()
        adj_closes.ix[:self.split_date, self.split_asset] *= self.split_ratio

        self.pipeline_loader = DataFrameLoader(
            column=USEquityPricing.close,
            baseline=self.closes,
            adjustments=self.adjustments,
        )

    def expected_close(self, date, asset):
        if date < self.split_date:
            lookup = self.closes
        else:
            lookup = self.adj_closes
        return lookup.loc[date, asset]

    def exists(self, date, asset):
        return asset.start_date <= date <= asset.end_date

    def test_attach_pipeline_after_initialize(self):
        """
        Assert that calling attach_pipeline after initialize raises correctly.
        """
        def initialize(context):
            pass

        def late_attach(context, data):
            attach_pipeline(Pipeline(), 'test')
            raise AssertionError("Shouldn't make it past attach_pipeline!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=late_attach,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(AttachPipelineAfterInitialize):
            algo.run(data_portal=self.data_portal)

        def barf(context, data):
            raise AssertionError("Shouldn't make it past before_trading_start")

        algo = TradingAlgorithm(
            initialize=initialize,
            before_trading_start=late_attach,
            handle_data=barf,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(AttachPipelineAfterInitialize):
            algo.run(data_portal=self.data_portal)

    def test_pipeline_output_after_initialize(self):
        """
        Assert that calling pipeline_output during initialize raises correctly.
        """
        def initialize(context):
            attach_pipeline(Pipeline(), 'test')
            pipeline_output('test')
            raise AssertionError("Shouldn't make it past pipeline_output()")

        def handle_data(context, data):
            raise AssertionError("Shouldn't make it past initialize!")

        def before_trading_start(context, data):
            raise AssertionError("Shouldn't make it past initialize!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(PipelineOutputDuringInitialize):
            algo.run(data_portal=self.data_portal)

    def test_get_output_nonexistent_pipeline(self):
        """
        Assert that requesting the output of a pipeline that was never
        attached raises appropriately.
        """
        def initialize(context):
            attach_pipeline(Pipeline(), 'test')

        def handle_data(context, data):
            raise AssertionError("Shouldn't make it past before_trading_start")

        def before_trading_start(context, data):
            pipeline_output('not_test')
            raise AssertionError("Shouldn't make it past pipeline_output!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(NoSuchPipeline):
            algo.run(data_portal=self.data_portal)

    @parameterized.expand([('default', None),
                           ('day', 1),
                           ('week', 5),
                           ('year', 252),
                           ('all_but_one_day', 'all_but_one_day')])
    def test_assets_appear_on_correct_days(self, test_name, chunksize):
        """
        Assert that assets appear at correct times during a backtest, with
        correctly-adjusted close price values.
        """

        if chunksize == 'all_but_one_day':
            chunksize = (
                self.dates.get_loc(self.last_asset_end) -
                self.dates.get_loc(self.first_asset_start)
            ) - 1

        def initialize(context):
            p = attach_pipeline(Pipeline(), 'test', chunksize=chunksize)
            p.add(USEquityPricing.close.latest, 'close')

        def handle_data(context, data):
            results = pipeline_output('test')
            date = get_datetime().normalize()
            for asset in self.assets:
                # Assets should appear iff they exist today and yesterday.
                exists_today = self.exists(date, asset)
                existed_yesterday = self.exists(date - trading_day, asset)
                if exists_today and existed_yesterday:
                    latest = results.loc[asset, 'close']
                    self.assertEqual(latest, self.expected_close(date, asset))
                else:
                    self.assertNotIn(asset, results.index)

        before_trading_start = handle_data

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start,
            end=self.last_asset_end,
            env=self.env,
        )

        # Run for a week in the middle of our data.
        algo.run(data_portal=self.data_portal)


class MockDailyBarSpotReader(object):
    """
    A stand-in for BcolzDailyBarReader that returns a constant spot price.
    """
    def spot_price(self, sid, day, column):
        return 100.0


class PipelineAlgorithmTestCase(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.AAPL = 1
        cls.MSFT = 2
        cls.BRK_A = 3
        cls.assets = [cls.AAPL, cls.MSFT, cls.BRK_A]
        asset_info = make_simple_equity_info(
            cls.assets,
            Timestamp('2014'),
            Timestamp('2015'),
            ['AAPL', 'MSFT', 'BRK_A'],
        )
        cls.env = trading.TradingEnvironment()
        cls.env.write_data(equities_df=asset_info)
        cls.tempdir = tempdir = TempDirectory()
        tempdir.create()
        try:
            cls.raw_data, cls.bar_reader = cls.create_bar_reader(tempdir)
            cls.adj_reader = cls.create_adjustment_reader(tempdir)
            cls.pipeline_loader = USEquityPricingLoader(
                cls.bar_reader, cls.adj_reader
            )
        except:
            cls.tempdir.cleanup()
            raise

        cls.dates = cls.raw_data[cls.AAPL].index.tz_localize('UTC')
        cls.AAPL_split_date = Timestamp("2014-06-09", tz='UTC')

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    @classmethod
    def create_bar_reader(cls, tempdir):
        resources = {
            cls.AAPL: join(TEST_RESOURCE_PATH, 'AAPL.csv'),
            cls.MSFT: join(TEST_RESOURCE_PATH, 'MSFT.csv'),
            cls.BRK_A: join(TEST_RESOURCE_PATH, 'BRK-A.csv'),
        }
        raw_data = {
            asset: read_csv(path, parse_dates=['day']).set_index('day')
            for asset, path in iteritems(resources)
        }
        # Add 'price' column as an alias because all kinds of stuff in zipline
        # depends on it being present. :/
        for frame in raw_data.values():
            frame['price'] = frame['close']

        writer = DailyBarWriterFromCSVs(resources)
        data_path = tempdir.getpath('testdata.bcolz')
        table = writer.write(data_path, trading_days, cls.assets)
        return raw_data, BcolzDailyBarReader(table)

    @classmethod
    def create_adjustment_reader(cls, tempdir):
        dbpath = tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, cls.env.trading_days,
                                        MockDailyBarSpotReader())
        splits = DataFrame.from_records([
            {
                'effective_date': str_to_seconds('2014-06-09'),
                'ratio': (1 / 7.0),
                'sid': cls.AAPL,
            }
        ])
        mergers = DataFrame(
            {
                # Hackery to make the dtypes correct on an empty frame.
                'effective_date': array([], dtype=int),
                'ratio': array([], dtype=float),
                'sid': array([], dtype=int),
            },
            index=DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        dividends = DataFrame({
            'sid': array([], dtype=uint32),
            'amount': array([], dtype=float64),
            'record_date': array([], dtype='datetime64[ns]'),
            'ex_date': array([], dtype='datetime64[ns]'),
            'declared_date': array([], dtype='datetime64[ns]'),
            'pay_date': array([], dtype='datetime64[ns]'),
        })
        writer.write(splits, mergers, dividends)
        return SQLiteAdjustmentReader(dbpath)

    def compute_expected_vwaps(self, window_lengths):
        AAPL, MSFT, BRK_A = self.AAPL, self.MSFT, self.BRK_A

        # Our view of the data before AAPL's split on June 9, 2014.
        raw = {k: v.copy() for k, v in iteritems(self.raw_data)}

        split_date = self.AAPL_split_date
        split_loc = self.dates.get_loc(split_date)
        split_ratio = 7.0

        # Our view of the data after AAPL's split.  All prices from before June
        # 9 get divided by the split ratio, and volumes get multiplied by the
        # split ratio.
        adj = {k: v.copy() for k, v in iteritems(self.raw_data)}
        for column in 'open', 'high', 'low', 'close':
            adj[AAPL].ix[:split_loc, column] /= split_ratio
        adj[AAPL].ix[:split_loc, 'volume'] *= split_ratio

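        # (Worked check: a 7-for-1 split divides each pre-split price by 7
        # and multiplies the matching volume by 7, so each price * volume
        # term in the VWAP numerator is unchanged; only the volume
        # denominator scales, which is why a VWAP window lying entirely
        # before the split shrinks by exactly the split ratio.)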
        # length -> asset -> expected vwap
        vwaps = {length: {} for length in window_lengths}
        for length in window_lengths:
            for asset in AAPL, MSFT, BRK_A:
                raw_vwap = rolling_vwap(raw[asset], length)
                adj_vwap = rolling_vwap(adj[asset], length)
                # Shift computed results one day forward so that they're
                # labelled by the date on which they'll be seen in the
                # algorithm. (We can't show the close price for day N until day
                # N + 1.)
                vwaps[length][asset] = concat(
                    [
                        raw_vwap[:split_loc - 1],
                        adj_vwap[split_loc - 1:]
                    ]
                ).shift(1, trading_day)

        # Make sure all the expected vwaps have the same dates.
        vwap_dates = vwaps[1][self.AAPL].index
        for dict_ in itervalues(vwaps):
            # Each value is a dict mapping sid -> expected series.
            for series in itervalues(dict_):
                self.assertTrue((vwap_dates == series.index).all())

        # Spot check expectations near the AAPL split.
        # length 1 vwap for the morning before the split should be the close
        # price of the previous day.
        before_split = vwaps[1][AAPL].loc[split_date - trading_day]
        assert_almost_equal(before_split, 647.3499, decimal=2)
        assert_almost_equal(
            before_split,
            raw[AAPL].loc[split_date - (2 * trading_day), 'close'],
            decimal=2,
        )

        # length 1 vwap for the morning of the split should be the close price
        # of the previous day, **ADJUSTED FOR THE SPLIT**.
        on_split = vwaps[1][AAPL].loc[split_date]
        assert_almost_equal(on_split, 645.5700 / split_ratio, decimal=2)
        assert_almost_equal(
            on_split,
            raw[AAPL].loc[split_date - trading_day, 'close'] / split_ratio,
            decimal=2,
        )

        # length 1 vwap on the day after the split should be the as-traded
        # close on the split day.
        after_split = vwaps[1][AAPL].loc[split_date + trading_day]
        assert_almost_equal(after_split, 93.69999, decimal=2)
        assert_almost_equal(
            after_split,
            raw[AAPL].loc[split_date, 'close'],
            decimal=2,
        )

        return vwaps

    @parameterized.expand([
        (True,),
        (False,),
    ])
    def test_handle_adjustment(self, set_screen):
        AAPL, MSFT, BRK_A = assets = self.AAPL, self.MSFT, self.BRK_A

        window_lengths = [1, 2, 5, 10]
        vwaps = self.compute_expected_vwaps(window_lengths)

        def vwap_key(length):
            return "vwap_%d" % length

        def initialize(context):
            pipeline = Pipeline()
            context.vwaps = []
            for length in vwaps:
                name = vwap_key(length)
                factor = VWAP(window_length=length)
                context.vwaps.append(factor)
                pipeline.add(factor, name=name)

            filter_ = (USEquityPricing.close.latest > 300)
            pipeline.add(filter_, 'filter')
            if set_screen:
                pipeline.set_screen(filter_)

            attach_pipeline(pipeline, 'test')

        def handle_data(context, data):
            today = get_datetime()
            results = pipeline_output('test')
            expect_over_300 = {
                AAPL: today < self.AAPL_split_date,
                MSFT: False,
                BRK_A: True,
            }
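            # (Context for these expectations: per the spot checks in
            # compute_expected_vwaps, AAPL trades in the mid-$600s before
            # its 7-for-1 split and in the low $90s after it, so it clears
            # the >300 filter only pre-split; BRK_A always does, and MSFT
            # never does.)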
            for asset in assets:
                should_pass_filter = expect_over_300[asset]
                if set_screen and not should_pass_filter:
                    self.assertNotIn(asset, results.index)
                    continue

                asset_results = results.loc[asset]
                self.assertEqual(asset_results['filter'], should_pass_filter)
                for length in vwaps:
                    computed = results.loc[asset, vwap_key(length)]
                    expected = vwaps[length][asset].loc[today]
                    # Only having two places of precision here is a bit
                    # unfortunate.
                    assert_almost_equal(computed, expected, decimal=2)

        # Do the same checks in before_trading_start
        before_trading_start = handle_data

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.dates[max(window_lengths)],
            end=self.dates[-1],
            env=self.env,
        )

        algo.run(data_portal=FakeDataPortal())
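
To run this module on its own (a sketch assuming a source checkout of zipline with its nose-based test setup, as the nose_parameterized import suggests; the exact file path is an assumption):

    nosetests tests/pipeline/test_pipeline_algo.py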