Completed
Pull Request — master (#858)
by Eddie, created 01:58

tests.pipeline.ClosesOnly    Rating: B

Complexity

    Total Complexity      37

Size/Duplication

    Total Lines           265
    Duplicated Lines      0 %

Metric  Value
dl      0       (duplicated lines)
loc     265     (lines of code)
rs      8.6
wmc     37      (weighted method count)

14 Methods

Rating  Name                                      Duplication  Size  Complexity
B       test_get_output_nonexistent_pipeline()    0            27    6
A       expected_close()                          0            6     2
D       test_attach_pipeline_after_initialize()   0            40    8
B       test_pipeline_output_after_initialize()   0            28    6
D       test_assets_appear_on_correct_days()      0            49    8
A       initialize()                              0            4     1
A       exists()                                  0            2     1
A       barf()                                    0            2     1
B       setUp()                                   0            96    4
A       tearDownClass()                           0            3     1
A       late_attach()                             0            3     1
A       setUpClass()                              0            3     1
A       handle_data()                             0            2     1
A       before_trading_start()                    0            2     1
"""
Tests for Algorithms using the Pipeline API.
"""
import os
from unittest import TestCase
from os.path import (
    dirname,
    join,
    realpath,
)

from nose_parameterized import parameterized
from numpy import (
    array,
    arange,
    full_like,
    float64,
    nan,
    uint32,
)
from numpy.testing import assert_almost_equal
from pandas import (
    concat,
    DataFrame,
    date_range,
    DatetimeIndex,
    read_csv,
    Series,
    Timestamp,
)
from six import iteritems, itervalues
from testfixtures import TempDirectory

from zipline.algorithm import TradingAlgorithm
from zipline.api import (
    attach_pipeline,
    pipeline_output,
    get_datetime,
)
from zipline.data.data_portal import DataPortal
from zipline.errors import (
    AttachPipelineAfterInitialize,
    PipelineOutputDuringInitialize,
    NoSuchPipeline,
)
from zipline.data.us_equity_pricing import (
    BcolzDailyBarReader,
    DailyBarWriterFromCSVs,
    SQLiteAdjustmentWriter,
    SQLiteAdjustmentReader,
)
from zipline.finance import trading
from zipline.finance.trading import SimulationParameters
from zipline.lib.adjustment import MULTIPLY
from zipline.pipeline import Pipeline
from zipline.pipeline.factors import VWAP
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.loaders.frame import DataFrameLoader
from zipline.pipeline.loaders.equity_pricing_loader import (
    USEquityPricingLoader,
)
from zipline.utils.test_utils import (
    make_simple_equity_info,
    str_to_seconds,
    DailyBarWriterFromDataFrames,
    FakeDataPortal,
)
from zipline.utils.tradingcalendar import (
    trading_day,
    trading_days,
)


TEST_RESOURCE_PATH = join(
    dirname(dirname(realpath(__file__))),  # zipline_repo/tests
    'resources',
    'pipeline_inputs',
)


def rolling_vwap(df, length):
    """Simple rolling VWAP implementation for testing."""
    closes = df['close'].values
    volumes = df['volume'].values
    product = closes * volumes
    out = full_like(closes, nan)
    for upper_bound in range(length, len(closes) + 1):
        bounds = slice(upper_bound - length, upper_bound)
        out[upper_bound - 1] = product[bounds].sum() / volumes[bounds].sum()

    return Series(out, index=df.index)
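
# A quick sanity check of rolling_vwap, as an illustrative sketch with
# made-up numbers (these values are not among the test fixtures): for
# closes [10, 20, 30] and volumes [1, 3, 1],
#
#     rolling_vwap(DataFrame({'close': [10.0, 20.0, 30.0],
#                             'volume': [1.0, 3.0, 1.0]}), 2)
#
# yields [nan, 17.5, 22.5]: the first window is not yet full, then
# (10*1 + 20*3) / (1 + 3) == 17.5 and (20*3 + 30*1) / (3 + 1) == 22.5.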


class ClosesOnly(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def setUp(self):
        self.env = env = trading.TradingEnvironment()
        self.dates = date_range(
            '2014-01-01', '2014-02-01', freq=trading_day, tz='UTC'
        )
        asset_info = DataFrame.from_records([
            {
                'sid': 1,
                'symbol': 'A',
                'start_date': self.dates[10],
                'end_date': self.dates[13],
                'exchange': 'TEST',
            },
            {
                'sid': 2,
                'symbol': 'B',
                'start_date': self.dates[11],
                'end_date': self.dates[14],
                'exchange': 'TEST',
            },
            {
                'sid': 3,
                'symbol': 'C',
                'start_date': self.dates[12],
                'end_date': self.dates[15],
                'exchange': 'TEST',
            },
        ])
        self.first_asset_start = min(asset_info.start_date)
        self.last_asset_end = max(asset_info.end_date)
        env.write_data(equities_df=asset_info)
        self.asset_finder = finder = env.asset_finder

        sids = (1, 2, 3)
        self.assets = finder.retrieve_all(sids)

        # View of the baseline data.
        self.closes = DataFrame(
            {sid: arange(1, len(self.dates) + 1) * sid for sid in sids},
            index=self.dates,
            dtype=float,
        )

        # Create a data portal holding the data in self.closes
        data = {}
        for sid in sids:
            data[sid] = DataFrame({
                "open": self.closes[sid].values,
                "high": self.closes[sid].values,
                "low": self.closes[sid].values,
                "close": self.closes[sid].values,
                "volume": self.closes[sid].values,
                "day": [day.value for day in self.dates]
            })

        path = os.path.join(self.tempdir.path, "testdaily.bcolz")

        DailyBarWriterFromDataFrames(data).write(
            path,
            self.dates,
            data
        )

        self.data_portal = DataPortal(
            self.env,
            daily_equities_path=path,
            sim_params=SimulationParameters(
                period_start=self.dates[0],
                period_end=self.dates[-1],
                env=self.env
            )
        )

        # Add a split for 'A' on its second date.
        self.split_asset = self.assets[0]
        self.split_date = self.split_asset.start_date + trading_day
        self.split_ratio = 0.5
        self.adjustments = DataFrame.from_records([
            {
                'sid': self.split_asset.sid,
                'value': self.split_ratio,
                'kind': MULTIPLY,
                'start_date': Timestamp('NaT'),
                'end_date': self.split_date,
                'apply_date': self.split_date,
            }
        ])

        # View of the data on/after the split.
        self.adj_closes = adj_closes = self.closes.copy()
        adj_closes.ix[:self.split_date, self.split_asset] *= self.split_ratio

        self.pipeline_loader = DataFrameLoader(
            column=USEquityPricing.close,
            baseline=self.closes,
            adjustments=self.adjustments,
        )

    def expected_close(self, date, asset):
        if date < self.split_date:
            lookup = self.closes
        else:
            lookup = self.adj_closes
        return lookup.loc[date, asset]
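
    # A worked illustration, derived from the fixtures above: asset A's raw
    # close on self.split_date is 12.0 (closes are arange(1, N + 1) * sid),
    # so expected_close(self.split_date, A) returns the split-adjusted 6.0,
    # while queries for dates before the split see the unadjusted values.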

    def exists(self, date, asset):
        return asset.start_date <= date <= asset.end_date

    def test_attach_pipeline_after_initialize(self):
        """
        Assert that calling attach_pipeline after initialize raises correctly.
        """
        def initialize(context):
            pass

        def late_attach(context, data):
            attach_pipeline(Pipeline(), 'test')
            raise AssertionError("Shouldn't make it past attach_pipeline!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=late_attach,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(AttachPipelineAfterInitialize):
            algo.run(data_portal=self.data_portal)

        def barf(context, data):
            raise AssertionError("Shouldn't make it past before_trading_start")

        algo = TradingAlgorithm(
            initialize=initialize,
            before_trading_start=late_attach,
            handle_data=barf,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(AttachPipelineAfterInitialize):
            algo.run(data_portal=self.data_portal)

    def test_pipeline_output_after_initialize(self):
        """
        Assert that calling pipeline_output during initialize raises
        correctly.
        """
        def initialize(context):
            attach_pipeline(Pipeline(), 'test')
            pipeline_output('test')
            raise AssertionError("Shouldn't make it past pipeline_output()")

        def handle_data(context, data):
            raise AssertionError("Shouldn't make it past initialize!")

        def before_trading_start(context, data):
            raise AssertionError("Shouldn't make it past initialize!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(PipelineOutputDuringInitialize):
            algo.run(data_portal=self.data_portal)

    def test_get_output_nonexistent_pipeline(self):
        """
        Assert that requesting output of a pipeline that was never attached
        raises NoSuchPipeline.
        """
        def initialize(context):
            attach_pipeline(Pipeline(), 'test')

        def handle_data(context, data):
            raise AssertionError("Shouldn't make it past before_trading_start")

        def before_trading_start(context, data):
            pipeline_output('not_test')
            raise AssertionError("Shouldn't make it past pipeline_output!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(NoSuchPipeline):
            algo.run(data_portal=self.data_portal)

    @parameterized.expand([('default', None),
                           ('day', 1),
                           ('week', 5),
                           ('year', 252),
                           ('all_but_one_day', 'all_but_one_day')])
    def test_assets_appear_on_correct_days(self, test_name, chunksize):
        """
        Assert that assets appear at correct times during a backtest, with
        correctly-adjusted close price values.
        """

        if chunksize == 'all_but_one_day':
            chunksize = (
                self.dates.get_loc(self.last_asset_end) -
                self.dates.get_loc(self.first_asset_start)
            ) - 1

        def initialize(context):
            p = attach_pipeline(Pipeline(), 'test', chunksize=chunksize)
            p.add(USEquityPricing.close.latest, 'close')

        def handle_data(context, data):
            results = pipeline_output('test')
            date = get_datetime().normalize()
            for asset in self.assets:
                # Assets should appear iff they exist today and yesterday.
                exists_today = self.exists(date, asset)
                existed_yesterday = self.exists(date - trading_day, asset)
                if exists_today and existed_yesterday:
                    latest = results.loc[asset, 'close']
                    self.assertEqual(latest, self.expected_close(date, asset))
                else:
                    self.assertNotIn(asset, results.index)

        before_trading_start = handle_data

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start,
            end=self.last_asset_end,
            env=self.env,
        )

        # Run for a week in the middle of our data.
        algo.run(data_portal=self.data_portal)
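
    # Aside on the "exist today and yesterday" condition above: pipeline
    # output for day N is computed from data through day N - 1 (as a comment
    # below puts it, we can't show the close price for day N until day
    # N + 1), so an asset with no data yesterday has nothing to show today.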


class MockDailyBarSpotReader(object):
    """
    A stand-in for BcolzDailyBarReader that returns a constant value for
    spot prices.
    """
    def spot_price(self, sid, day, column):
        return 100.0


class PipelineAlgorithmTestCase(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.AAPL = 1
        cls.MSFT = 2
        cls.BRK_A = 3
        cls.assets = [cls.AAPL, cls.MSFT, cls.BRK_A]
        asset_info = make_simple_equity_info(
            cls.assets,
            Timestamp('2014'),
            Timestamp('2015'),
            ['AAPL', 'MSFT', 'BRK_A'],
        )
        cls.env = trading.TradingEnvironment()
        cls.env.write_data(equities_df=asset_info)
        cls.tempdir = tempdir = TempDirectory()
        tempdir.create()
        try:
            cls.raw_data, cls.bar_reader = cls.create_bar_reader(tempdir)
            cls.adj_reader = cls.create_adjustment_reader(tempdir)
            cls.pipeline_loader = USEquityPricingLoader(
                cls.bar_reader, cls.adj_reader
            )
        except:
            cls.tempdir.cleanup()
            raise

        cls.dates = cls.raw_data[cls.AAPL].index.tz_localize('UTC')
        cls.AAPL_split_date = Timestamp("2014-06-09", tz='UTC')

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    @classmethod
    def create_bar_reader(cls, tempdir):
        resources = {
            cls.AAPL: join(TEST_RESOURCE_PATH, 'AAPL.csv'),
            cls.MSFT: join(TEST_RESOURCE_PATH, 'MSFT.csv'),
            cls.BRK_A: join(TEST_RESOURCE_PATH, 'BRK-A.csv'),
        }
        raw_data = {
            asset: read_csv(path, parse_dates=['day']).set_index('day')
            for asset, path in iteritems(resources)
        }
        # Add 'price' column as an alias because all kinds of stuff in zipline
        # depends on it being present. :/
        for frame in raw_data.values():
            frame['price'] = frame['close']

        writer = DailyBarWriterFromCSVs(resources)
        data_path = tempdir.getpath('testdata.bcolz')
        table = writer.write(data_path, trading_days, cls.assets)
        return raw_data, BcolzDailyBarReader(table)

    @classmethod
    def create_adjustment_reader(cls, tempdir):
        dbpath = tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, cls.env.trading_days,
                                        MockDailyBarSpotReader())
        splits = DataFrame.from_records([
            {
                'effective_date': str_to_seconds('2014-06-09'),
                'ratio': (1 / 7.0),
                'sid': cls.AAPL,
            }
        ])
        mergers = DataFrame(
            {
                # Hackery to make the dtypes correct on an empty frame.
                'effective_date': array([], dtype=int),
                'ratio': array([], dtype=float),
                'sid': array([], dtype=int),
            },
            index=DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        dividends = DataFrame({
            'sid': array([], dtype=uint32),
            'amount': array([], dtype=float64),
            'record_date': array([], dtype='datetime64[ns]'),
            'ex_date': array([], dtype='datetime64[ns]'),
            'declared_date': array([], dtype='datetime64[ns]'),
            'pay_date': array([], dtype='datetime64[ns]'),
        })
        writer.write(splits, mergers, dividends)
        return SQLiteAdjustmentReader(dbpath)

    def compute_expected_vwaps(self, window_lengths):
        AAPL, MSFT, BRK_A = self.AAPL, self.MSFT, self.BRK_A

        # Our view of the data before AAPL's split on June 9, 2014.
        raw = {k: v.copy() for k, v in iteritems(self.raw_data)}

        split_date = self.AAPL_split_date
        split_loc = self.dates.get_loc(split_date)
        split_ratio = 7.0

        # Our view of the data after AAPL's split.  All prices from before June
        # 9 get divided by the split ratio, and volumes get multiplied by the
        # split ratio.
        adj = {k: v.copy() for k, v in iteritems(self.raw_data)}
        for column in 'open', 'high', 'low', 'close':
            adj[AAPL].ix[:split_loc, column] /= split_ratio
        adj[AAPL].ix[:split_loc, 'volume'] *= split_ratio

        # length -> asset -> expected vwap
        vwaps = {length: {} for length in window_lengths}
        for length in window_lengths:
            for asset in AAPL, MSFT, BRK_A:
                raw_vwap = rolling_vwap(raw[asset], length)
                adj_vwap = rolling_vwap(adj[asset], length)
                # Shift computed results one day forward so that they're
                # labelled by the date on which they'll be seen in the
                # algorithm. (We can't show the close price for day N until day
                # N + 1.)
                vwaps[length][asset] = concat(
                    [
                        raw_vwap[:split_loc - 1],
                        adj_vwap[split_loc - 1:]
                    ]
                ).shift(1, trading_day)

        # Make sure all the expected vwaps have the same dates.
        vwap_dates = vwaps[1][self.AAPL].index
        for dict_ in itervalues(vwaps):
            # Each value is a dict mapping sid -> expected series.
            for series in itervalues(dict_):
                self.assertTrue((vwap_dates == series.index).all())

        # Spot check expectations near the AAPL split.
        # length 1 vwap for the morning before the split should be the close
        # price of the previous day.
        before_split = vwaps[1][AAPL].loc[split_date - trading_day]
        assert_almost_equal(before_split, 647.3499, decimal=2)
        assert_almost_equal(
            before_split,
            raw[AAPL].loc[split_date - (2 * trading_day), 'close'],
            decimal=2,
        )

        # length 1 vwap for the morning of the split should be the close price
        # of the previous day, **ADJUSTED FOR THE SPLIT**.
        on_split = vwaps[1][AAPL].loc[split_date]
        assert_almost_equal(on_split, 645.5700 / split_ratio, decimal=2)
        assert_almost_equal(
            on_split,
            raw[AAPL].loc[split_date - trading_day, 'close'] / split_ratio,
            decimal=2,
        )

        # length 1 vwap on the day after the split should be the as-traded
        # close on the split day.
        after_split = vwaps[1][AAPL].loc[split_date + trading_day]
        assert_almost_equal(after_split, 93.69999, decimal=2)
        assert_almost_equal(
            after_split,
            raw[AAPL].loc[split_date, 'close'],
            decimal=2,
        )

        return vwaps
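
    # Why the length-1 spot checks above work: a length-1 VWAP is just the
    # close, since sum(close * volume) / sum(volume) over a single bar
    # reduces to close. Combined with the one-day shift, the vwaps[1][AAPL]
    # value labelled with a given day is the (possibly split-adjusted) close
    # of the previous trading day.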

    @parameterized.expand([
        (True,),
        (False,),
    ])
    def test_handle_adjustment(self, set_screen):
        AAPL, MSFT, BRK_A = assets = self.AAPL, self.MSFT, self.BRK_A

        window_lengths = [1, 2, 5, 10]
        vwaps = self.compute_expected_vwaps(window_lengths)

        def vwap_key(length):
            return "vwap_%d" % length

        def initialize(context):
            pipeline = Pipeline()
            context.vwaps = []
            for length in vwaps:
                name = vwap_key(length)
                factor = VWAP(window_length=length)
                context.vwaps.append(factor)
                pipeline.add(factor, name=name)

            filter_ = (USEquityPricing.close.latest > 300)
            pipeline.add(filter_, 'filter')
            if set_screen:
                pipeline.set_screen(filter_)

            attach_pipeline(pipeline, 'test')

        def handle_data(context, data):
            today = get_datetime()
            results = pipeline_output('test')
            expect_over_300 = {
                AAPL: today < self.AAPL_split_date,
                MSFT: False,
                BRK_A: True,
            }
            for asset in assets:
                should_pass_filter = expect_over_300[asset]
                if set_screen and not should_pass_filter:
                    self.assertNotIn(asset, results.index)
                    continue

                asset_results = results.loc[asset]
                self.assertEqual(asset_results['filter'], should_pass_filter)
                for length in vwaps:
                    computed = results.loc[asset, vwap_key(length)]
                    expected = vwaps[length][asset].loc[today]
                    # Only having two places of precision here is a bit
                    # unfortunate.
                    assert_almost_equal(computed, expected, decimal=2)

        # Do the same checks in before_trading_start
        before_trading_start = handle_data

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.dates[max(window_lengths)],
            end=self.dates[-1],
            env=self.env,
        )

        algo.run(data_portal=FakeDataPortal())