Completed — Pull Request #858 (master), by Eddie, created 01:38

tests.pipeline.ClosesOnly — Rating: B

Complexity
    Total Complexity: 37

Size/Duplication
    Total Lines: 262
    Duplicated Lines: 0%

    Metric   Value
    dl       0
    loc      262
    rs       8.6
    wmc      37

14 Methods

    Rating   Name                                      Duplication   Size   Complexity
    A        tearDownClass()                           0             3      1
    A        setUpClass()                              0             3      1
    B        test_get_output_nonexistent_pipeline()    0             27     6
    A        expected_close()                          0             6      2
    D        test_attach_pipeline_after_initialize()   0             40     8
    B        test_pipeline_output_after_initialize()   0             28     6
    D        test_assets_appear_on_correct_days()      0             49     8
    A        initialize()                              0             4      1
    A        handle_data()                             0             2      4
    A        exists()                                  0             2      1
    A        barf()                                    0             2      1
    A        before_trading_start()                    0             2      1
    B        setUp()                                   0             93     4
    A        late_attach()                             0             3      1
"""
Tests for Algorithms using the Pipeline API.
"""
import os
from unittest import TestCase
from os.path import (
    dirname,
    join,
    realpath,
)

from nose_parameterized import parameterized
from numpy import (
    array,
    arange,
    full_like,
    float64,
    nan,
    uint32,
)
from numpy.testing import assert_almost_equal
from pandas import (
    concat,
    DataFrame,
    date_range,
    DatetimeIndex,
    read_csv,
    Series,
    Timestamp,
)
from six import iteritems, itervalues
from testfixtures import TempDirectory

from zipline.algorithm import TradingAlgorithm
from zipline.api import (
    attach_pipeline,
    pipeline_output,
    get_datetime,
)
from zipline.data.data_portal import DataPortal
from zipline.errors import (
    AttachPipelineAfterInitialize,
    PipelineOutputDuringInitialize,
    NoSuchPipeline,
)
from zipline.data.us_equity_pricing import (
    BcolzDailyBarReader,
    DailyBarWriterFromCSVs,
    SQLiteAdjustmentWriter,
    SQLiteAdjustmentReader,
)
from zipline.finance import trading
from zipline.finance.trading import SimulationParameters
from zipline.lib.adjustment import MULTIPLY
from zipline.pipeline import Pipeline
from zipline.pipeline.factors import VWAP
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.loaders.frame import DataFrameLoader
from zipline.pipeline.loaders.equity_pricing_loader import (
    USEquityPricingLoader,
)
from zipline.utils.test_utils import (
    make_simple_equity_info,
    str_to_seconds,
    DailyBarWriterFromDataFrames,
    FakeDataPortal,
)
from zipline.utils.tradingcalendar import (
    trading_day,
    trading_days,
)


TEST_RESOURCE_PATH = join(
    dirname(dirname(realpath(__file__))),  # zipline_repo/tests
    'resources',
    'pipeline_inputs',
)


def rolling_vwap(df, length):
    """Simple rolling VWAP implementation for testing."""
    closes = df['close'].values
    volumes = df['volume'].values
    product = closes * volumes
    out = full_like(closes, nan)
    for upper_bound in range(length, len(closes) + 1):
        bounds = slice(upper_bound - length, upper_bound)
        out[upper_bound - 1] = product[bounds].sum() / volumes[bounds].sum()

    return Series(out, index=df.index)


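# A minimal sanity check of rolling_vwap (hypothetical data, not part of the
# original suite): with constant volume, the rolling VWAP reduces to a rolling
# mean of closes, since sum(close * volume) / sum(volume) cancels the volumes.
#
#     frame = DataFrame({'close': [1.0, 2.0, 3.0, 4.0],
#                        'volume': [1.0, 1.0, 1.0, 1.0]})
#     rolling_vwap(frame, 2).tolist()  # -> [nan, 1.5, 2.5, 3.5]

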
class ClosesOnly(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.tempdir = TempDirectory()

    @classmethod
    def tearDownClass(cls):
        cls.tempdir.cleanup()

    def setUp(self):
        self.env = env = trading.TradingEnvironment()
        self.dates = date_range(
            '2014-01-01', '2014-02-01', freq=trading_day, tz='UTC'
        )
        asset_info = DataFrame.from_records([
            {
                'sid': 1,
                'symbol': 'A',
                'start_date': self.dates[10],
                'end_date': self.dates[13],
                'exchange': 'TEST',
            },
            {
                'sid': 2,
                'symbol': 'B',
                'start_date': self.dates[11],
                'end_date': self.dates[14],
                'exchange': 'TEST',
            },
            {
                'sid': 3,
                'symbol': 'C',
                'start_date': self.dates[12],
                'end_date': self.dates[15],
                'exchange': 'TEST',
            },
        ])
        self.first_asset_start = min(asset_info.start_date)
        self.last_asset_end = max(asset_info.end_date)
        env.write_data(equities_df=asset_info)
        self.asset_finder = finder = env.asset_finder

        sids = (1, 2, 3)
        self.assets = finder.retrieve_all(sids)

        # View of the baseline data: the close for sid N on the k-th session
        # is N * (k + 1).
        self.closes = DataFrame(
            {sid: arange(1, len(self.dates) + 1) * sid for sid in sids},
            index=self.dates,
            dtype=float,
        )

        # Create a data portal holding the data in self.closes.
        data = {}
        for sid in sids:
            data[sid] = DataFrame({
                "open": self.closes[sid].values,
                "high": self.closes[sid].values,
                "low": self.closes[sid].values,
                "close": self.closes[sid].values,
                "volume": self.closes[sid].values,
                "day": [day.value for day in self.dates],
            })

        path = os.path.join(self.tempdir.path, "testdaily.bcolz")

        DailyBarWriterFromDataFrames(data).write(
            path,
            self.dates,
            data,
        )

        daily_bar_reader = BcolzDailyBarReader(path)

        self.data_portal = DataPortal(
            self.env,
            equity_daily_reader=daily_bar_reader,
        )

        # Add a split for 'A' on its second date.
        self.split_asset = self.assets[0]
        self.split_date = self.split_asset.start_date + trading_day
        self.split_ratio = 0.5
        self.adjustments = DataFrame.from_records([
            {
                'sid': self.split_asset.sid,
                'value': self.split_ratio,
                'kind': MULTIPLY,
                'start_date': Timestamp('NaT'),
                'end_date': self.split_date,
                'apply_date': self.split_date,
            }
        ])

        # View of the data as it appears on/after the split: every close at
        # or before the split date is multiplied by the split ratio.
        self.adj_closes = adj_closes = self.closes.copy()
        adj_closes.ix[:self.split_date, self.split_asset] *= self.split_ratio

        self.pipeline_loader = DataFrameLoader(
            column=USEquityPricing.close,
            baseline=self.closes,
            adjustments=self.adjustments,
        )

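    # Worked illustration of the adjusted view (numbers follow from the setup
    # above): asset 1 starts on self.dates[10] with close 11.0 and splits
    # 2-for-1 (ratio 0.5) on its second day, so
    #
    #     self.closes.loc[self.split_date, 1]      # -> 12.0 (as traded)
    #     self.adj_closes.loc[self.split_date, 1]  # -> 6.0  (split-adjusted)
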
    def expected_close(self, date, asset):
        # Before the split date the algorithm sees as-traded closes; on/after
        # it sees the split-adjusted view.
        if date < self.split_date:
            lookup = self.closes
        else:
            lookup = self.adj_closes
        return lookup.loc[date, asset]

    def exists(self, date, asset):
        return asset.start_date <= date <= asset.end_date

    def test_attach_pipeline_after_initialize(self):
        """
        Assert that calling attach_pipeline after initialize raises correctly.
        """
        def initialize(context):
            pass

        def late_attach(context, data):
            attach_pipeline(Pipeline(), 'test')
            raise AssertionError("Shouldn't make it past attach_pipeline!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=late_attach,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(AttachPipelineAfterInitialize):
            algo.run(data_portal=self.data_portal)

        def barf(context, data):
            raise AssertionError("Shouldn't make it past before_trading_start")

        algo = TradingAlgorithm(
            initialize=initialize,
            before_trading_start=late_attach,
            handle_data=barf,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(AttachPipelineAfterInitialize):
            algo.run(data_portal=self.data_portal)

    def test_pipeline_output_after_initialize(self):
        """
        Assert that calling pipeline_output during initialize raises correctly.
        """
        def initialize(context):
            attach_pipeline(Pipeline(), 'test')
            pipeline_output('test')
            raise AssertionError("Shouldn't make it past pipeline_output()")

        def handle_data(context, data):
            raise AssertionError("Shouldn't make it past initialize!")

        def before_trading_start(context, data):
            raise AssertionError("Shouldn't make it past initialize!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(PipelineOutputDuringInitialize):
            algo.run(data_portal=self.data_portal)

    def test_get_output_nonexistent_pipeline(self):
        """
        Assert that requesting the output of a pipeline that was never
        attached raises appropriately.
        """
        def initialize(context):
            attach_pipeline(Pipeline(), 'test')

        def handle_data(context, data):
            raise AssertionError("Shouldn't make it past before_trading_start")

        def before_trading_start(context, data):
            pipeline_output('not_test')
            raise AssertionError("Shouldn't make it past pipeline_output!")

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start - trading_day,
            end=self.last_asset_end + trading_day,
            env=self.env,
        )

        with self.assertRaises(NoSuchPipeline):
            algo.run(data_portal=self.data_portal)

    @parameterized.expand([('default', None),
                           ('day', 1),
                           ('week', 5),
                           ('year', 252),
                           ('all_but_one_day', 'all_but_one_day')])
    def test_assets_appear_on_correct_days(self, test_name, chunksize):
        """
        Assert that assets appear at the correct times during a backtest, with
        correctly-adjusted close price values.
        """

        if chunksize == 'all_but_one_day':
            chunksize = (
                self.dates.get_loc(self.last_asset_end) -
                self.dates.get_loc(self.first_asset_start)
            ) - 1

        def initialize(context):
            p = attach_pipeline(Pipeline(), 'test', chunksize=chunksize)
            p.add(USEquityPricing.close.latest, 'close')

        def handle_data(context, data):
            results = pipeline_output('test')
            date = get_datetime().normalize()
            for asset in self.assets:
                # Assets should appear iff they exist today and yesterday.
                exists_today = self.exists(date, asset)
                existed_yesterday = self.exists(date - trading_day, asset)
                if exists_today and existed_yesterday:
                    latest = results.loc[asset, 'close']
                    self.assertEqual(latest, self.expected_close(date, asset))
                else:
                    self.assertNotIn(asset, results.index)

        before_trading_start = handle_data

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.first_asset_start,
            end=self.last_asset_end,
        )

        # Run over the ~week in the middle of our data when our assets exist.
        algo.run(data_portal=self.data_portal)


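# A note on the chunksize parameterization above (our reading of
# attach_pipeline's behavior, not asserted by the test itself): pipeline
# results are precomputed chunksize simulation days at a time, so the
# 'all_but_one_day' case stops one day short of the full run and forces a
# second chunk to be computed mid-backtest, exercising the chunk boundary.

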
class MockDailyBarSpotReader(object):
    """
    A stand-in for BcolzDailyBarReader which returns a constant value for
    spot price.
    """
    def spot_price(self, sid, day, column):
        return 100.0


class PipelineAlgorithmTestCase(TestCase):

    @classmethod
    def setUpClass(cls):
        cls.AAPL = 1
        cls.MSFT = 2
        cls.BRK_A = 3
        cls.assets = [cls.AAPL, cls.MSFT, cls.BRK_A]
        asset_info = make_simple_equity_info(
            cls.assets,
            Timestamp('2014'),
            Timestamp('2015'),
            ['AAPL', 'MSFT', 'BRK_A'],
        )
        cls.env = trading.TradingEnvironment()
        cls.env.write_data(equities_df=asset_info)
        cls.tempdir = tempdir = TempDirectory()
        tempdir.create()
        try:
            cls.raw_data, cls.bar_reader = cls.create_bar_reader(tempdir)
            cls.adj_reader = cls.create_adjustment_reader(tempdir)
            cls.pipeline_loader = USEquityPricingLoader(
                cls.bar_reader, cls.adj_reader
            )
        except:
            cls.tempdir.cleanup()
            raise

        cls.dates = cls.raw_data[cls.AAPL].index.tz_localize('UTC')
        cls.AAPL_split_date = Timestamp("2014-06-09", tz='UTC')

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    @classmethod
    def create_bar_reader(cls, tempdir):
        resources = {
            cls.AAPL: join(TEST_RESOURCE_PATH, 'AAPL.csv'),
            cls.MSFT: join(TEST_RESOURCE_PATH, 'MSFT.csv'),
            cls.BRK_A: join(TEST_RESOURCE_PATH, 'BRK-A.csv'),
        }
        raw_data = {
            asset: read_csv(path, parse_dates=['day']).set_index('day')
            for asset, path in iteritems(resources)
        }
        # Add 'price' column as an alias because all kinds of stuff in zipline
        # depends on it being present. :/
        for frame in raw_data.values():
            frame['price'] = frame['close']

        writer = DailyBarWriterFromCSVs(resources)
        data_path = tempdir.getpath('testdata.bcolz')
        table = writer.write(data_path, trading_days, cls.assets)
        return raw_data, BcolzDailyBarReader(table)

    @classmethod
    def create_adjustment_reader(cls, tempdir):
        dbpath = tempdir.getpath('adjustments.sqlite')
        writer = SQLiteAdjustmentWriter(dbpath, cls.env.trading_days,
                                        MockDailyBarSpotReader())
        splits = DataFrame.from_records([
            {
                'effective_date': str_to_seconds('2014-06-09'),
                'ratio': (1 / 7.0),
                'sid': cls.AAPL,
            }
        ])
        mergers = DataFrame(
            {
                # Hackery to make the dtypes correct on an empty frame.
                'effective_date': array([], dtype=int),
                'ratio': array([], dtype=float),
                'sid': array([], dtype=int),
            },
            index=DatetimeIndex([], tz='UTC'),
            columns=['effective_date', 'ratio', 'sid'],
        )
        dividends = DataFrame({
            'sid': array([], dtype=uint32),
            'amount': array([], dtype=float64),
            'record_date': array([], dtype='datetime64[ns]'),
            'ex_date': array([], dtype='datetime64[ns]'),
            'declared_date': array([], dtype='datetime64[ns]'),
            'pay_date': array([], dtype='datetime64[ns]'),
        })
        writer.write(splits, mergers, dividends)
        return SQLiteAdjustmentReader(dbpath)

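    # Note on the split convention above (our understanding of the writer's
    # semantics, not asserted here): a ratio of 1/7 with an effective date of
    # 2014-06-09 means prices before that date get multiplied by 1/7, matching
    # AAPL's actual 7-for-1 split. The MockDailyBarSpotReader passed to the
    # writer only matters for computing dividend ratios, and the dividends
    # frame is empty in these tests.
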
    def compute_expected_vwaps(self, window_lengths):
        AAPL, MSFT, BRK_A = self.AAPL, self.MSFT, self.BRK_A

        # Our view of the data before AAPL's split on June 9, 2014.
        raw = {k: v.copy() for k, v in iteritems(self.raw_data)}

        split_date = self.AAPL_split_date
        split_loc = self.dates.get_loc(split_date)
        split_ratio = 7.0

        # Our view of the data after AAPL's split.  All prices from before
        # June 9 get divided by the split ratio, and volumes get multiplied
        # by the split ratio.
        adj = {k: v.copy() for k, v in iteritems(self.raw_data)}
        for column in 'open', 'high', 'low', 'close':
            adj[AAPL].ix[:split_loc, column] /= split_ratio
        adj[AAPL].ix[:split_loc, 'volume'] *= split_ratio

        # length -> asset -> expected vwap
        vwaps = {length: {} for length in window_lengths}
        for length in window_lengths:
            for asset in AAPL, MSFT, BRK_A:
                raw_vwap = rolling_vwap(raw[asset], length)
                adj_vwap = rolling_vwap(adj[asset], length)
                # Shift computed results one day forward so that they're
                # labelled by the date on which they'll be seen in the
                # algorithm. (We can't show the close price for day N until
                # day N + 1.)
                vwaps[length][asset] = concat(
                    [
                        raw_vwap[:split_loc - 1],
                        adj_vwap[split_loc - 1:],
                    ]
                ).shift(1, trading_day)

        # Make sure all the expected vwaps have the same dates.
        vwap_dates = vwaps[1][self.AAPL].index
        for dict_ in itervalues(vwaps):
            # Each value is a dict mapping sid -> expected series.
            for series in itervalues(dict_):
                self.assertTrue((vwap_dates == series.index).all())

        # Spot check expectations near the AAPL split.
        # Length-1 vwap for the morning before the split should be the close
        # price of the previous day.
        before_split = vwaps[1][AAPL].loc[split_date - trading_day]
        assert_almost_equal(before_split, 647.3499, decimal=2)
        assert_almost_equal(
            before_split,
            raw[AAPL].loc[split_date - (2 * trading_day), 'close'],
            decimal=2,
        )

        # Length-1 vwap for the morning of the split should be the close
        # price of the previous day, **ADJUSTED FOR THE SPLIT**.
        on_split = vwaps[1][AAPL].loc[split_date]
        assert_almost_equal(on_split, 645.5700 / split_ratio, decimal=2)
        assert_almost_equal(
            on_split,
            raw[AAPL].loc[split_date - trading_day, 'close'] / split_ratio,
            decimal=2,
        )

        # Length-1 vwap on the day after the split should be the as-traded
        # close on the split day.
        after_split = vwaps[1][AAPL].loc[split_date + trading_day]
        assert_almost_equal(after_split, 93.69999, decimal=2)
        assert_almost_equal(
            after_split,
            raw[AAPL].loc[split_date, 'close'],
            decimal=2,
        )

        return vwaps

    @parameterized.expand([
        (True,),
        (False,),
    ])
    def test_handle_adjustment(self, set_screen):
        AAPL, MSFT, BRK_A = assets = self.AAPL, self.MSFT, self.BRK_A

        window_lengths = [1, 2, 5, 10]
        vwaps = self.compute_expected_vwaps(window_lengths)

        def vwap_key(length):
            return "vwap_%d" % length

        def initialize(context):
            pipeline = Pipeline()
            context.vwaps = []
            for length in vwaps:
                name = vwap_key(length)
                factor = VWAP(window_length=length)
                context.vwaps.append(factor)
                pipeline.add(factor, name=name)

            filter_ = (USEquityPricing.close.latest > 300)
            pipeline.add(filter_, 'filter')
            if set_screen:
                pipeline.set_screen(filter_)

            attach_pipeline(pipeline, 'test')

        def handle_data(context, data):
            today = get_datetime()
            results = pipeline_output('test')
            expect_over_300 = {
                AAPL: today < self.AAPL_split_date,
                MSFT: False,
                BRK_A: True,
            }
            for asset in assets:
                should_pass_filter = expect_over_300[asset]
                if set_screen and not should_pass_filter:
                    self.assertNotIn(asset, results.index)
                    continue

                asset_results = results.loc[asset]
                self.assertEqual(asset_results['filter'], should_pass_filter)
                for length in vwaps:
                    computed = results.loc[asset, vwap_key(length)]
                    expected = vwaps[length][asset].loc[today]
                    # Only having two places of precision here is a bit
                    # unfortunate.
                    assert_almost_equal(computed, expected, decimal=2)

        # Do the same checks in before_trading_start.
        before_trading_start = handle_data

        algo = TradingAlgorithm(
            initialize=initialize,
            handle_data=handle_data,
            before_trading_start=before_trading_start,
            data_frequency='daily',
            get_pipeline_loader=lambda column: self.pipeline_loader,
            start=self.dates[max(window_lengths)],
            end=self.dates[-1],
            env=self.env,
        )

        algo.run(data_portal=FakeDataPortal())
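

# Sketch of the screen semantics exercised by test_handle_adjustment above
# (illustrative pattern only, mirroring the test's own Pipeline usage):
# adding a filter as a column reports a True/False value for every asset,
# while set_screen excludes rows whose screen value is False.
#
#     pipe = Pipeline()
#     filter_ = USEquityPricing.close.latest > 300
#     pipe.add(filter_, 'filter')  # 'filter' column holds True/False
#     pipe.set_screen(filter_)     # assets failing the screen are dropped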