Completed
Pull Request — master (#858)
created by Eddie at 02:02

test_fetch_csv_with_multi_symbols()   A

Complexity
  Conditions: 1

Size
  Total Lines: 18

Duplication
  Lines: 0
  Ratio: 0 %

Metric   Value
cc       1        (cyclomatic complexity)
dl       0        (duplicated lines)
loc      18       (lines of code)
rs       9.4286
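
For context on the "Conditions" / cc figures above: cyclomatic complexity is 1 plus the number of branch points in a function. The sketch below is an illustrative approximation written for this review; the tool that generated this report is not named on the page, and its exact algorithm may differ.

import ast

# Node types that open an extra execution path.
BRANCHES = (ast.If, ast.For, ast.While, ast.BoolOp,
            ast.ExceptHandler, ast.IfExp, ast.comprehension)

def approx_complexity(func_source):
    """Return 1 + the number of branch points in the given source."""
    tree = ast.parse(func_source)
    return 1 + sum(isinstance(node, BRANCHES) for node in ast.walk(tree))

# A straight-line test method such as test_fetch_csv_with_multi_symbols has
# no branch points, so it scores 1, matching "Conditions 1" / "cc 1" above.
print(approx_complexity("def f():\n    return 1"))  # -> 1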
#
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import TestCase
from nose_parameterized import parameterized

import pandas as pd
import numpy as np
import responses
from mock import patch
from zipline import TradingAlgorithm
from zipline.errors import UnsupportedOrderParameters
from zipline.finance.trading import TradingEnvironment
from zipline.sources.requests_csv import mask_requests_args

from zipline.utils import factory
from zipline.utils.test_utils import FetcherDataPortal

from .resources.fetcher_inputs.fetcher_test_data import (
    MULTI_SIGNAL_CSV_DATA,
    AAPL_CSV_DATA,
    AAPL_MINUTE_CSV_DATA,
    IBM_CSV_DATA,
    ANNUAL_AAPL_CSV_DATA,
    AAPL_IBM_CSV_DATA,
    NOMATCH_CSV_DATA,
    CPIAUCSL_DATA,
    PALLADIUM_DATA,
    FETCHER_UNIVERSE_DATA,
    NON_ASSET_FETCHER_UNIVERSE_DATA,
    FETCHER_UNIVERSE_DATA_TICKER_COLUMN,
    FETCHER_ALTERNATE_COLUMN_HEADER,
)


class FetcherTestCase(TestCase):
    @classmethod
    def setUpClass(cls):
        responses.start()
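        # Serve each fake.urls.com CSV from an in-memory fixture string via
        # the `responses` mock, so fetch_csv never touches the real network.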
        responses.add(responses.GET,
                      'https://fake.urls.com/aapl_minute_csv_data.csv',
                      body=AAPL_MINUTE_CSV_DATA, content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/aapl_csv_data.csv',
                      body=AAPL_CSV_DATA, content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/multi_signal_csv_data.csv',
                      body=MULTI_SIGNAL_CSV_DATA, content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/nomatch_csv_data.csv',
                      body=NOMATCH_CSV_DATA, content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/cpiaucsl_data.csv',
                      body=CPIAUCSL_DATA, content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/ibm_csv_data.csv',
                      body=IBM_CSV_DATA, content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/aapl_ibm_csv_data.csv',
                      body=AAPL_IBM_CSV_DATA, content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/palladium_data.csv',
                      body=PALLADIUM_DATA, content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/fetcher_universe_data.csv',
                      body=FETCHER_UNIVERSE_DATA, content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/bad_fetcher_universe_data.csv',
                      body=NON_ASSET_FETCHER_UNIVERSE_DATA,
                      content_type='text/csv')
        responses.add(responses.GET,
                      'https://fake.urls.com/annual_aapl_csv_data.csv',
                      body=ANNUAL_AAPL_CSV_DATA, content_type='text/csv')

        cls.sim_params = factory.create_simulation_parameters()
        cls.env = TradingEnvironment()
        cls.env.write_data(
            equities_data={
                24: {
                    "start_date": pd.Timestamp("2006-01-01", tz='UTC'),
                    "end_date": pd.Timestamp("2007-01-01", tz='UTC'),
                    "symbol": "AAPL",
                    "asset_type": "equity",
                    "exchange": "nasdaq"
                },
                3766: {
                    "start_date": pd.Timestamp("2006-01-01", tz='UTC'),
                    "end_date": pd.Timestamp("2007-01-01", tz='UTC'),
                    "symbol": "IBM",
                    "asset_type": "equity",
                    "exchange": "nasdaq"
                },
                5061: {
                    "start_date": pd.Timestamp("2006-01-01", tz='UTC'),
                    "end_date": pd.Timestamp("2007-01-01", tz='UTC'),
                    "symbol": "MSFT",
                    "asset_type": "equity",
                    "exchange": "nasdaq"
                },
                14848: {
                    "start_date": pd.Timestamp("2006-01-01", tz='UTC'),
                    "end_date": pd.Timestamp("2007-01-01", tz='UTC'),
                    "symbol": "YHOO",
                    "asset_type": "equity",
                    "exchange": "nasdaq"
                },
                25317: {
                    "start_date": pd.Timestamp("2006-01-01", tz='UTC'),
                    "end_date": pd.Timestamp("2007-01-01", tz='UTC'),
                    "symbol": "DELL",
                    "asset_type": "equity",
                    "exchange": "nasdaq"
                }
            }
        )
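        # For reference in the tests below, the sids written above map to:
        # 24=AAPL, 3766=IBM, 5061=MSFT, 14848=YHOO, 25317=DELL.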

    def run_algo(self, code, sim_params=None, data_frequency="daily"):
        """Run `code` as a TradingAlgorithm against a FetcherDataPortal and
        return the resulting daily performance frame."""
        if sim_params is None:
            sim_params = self.sim_params

        test_algo = TradingAlgorithm(
            script=code,
            sim_params=sim_params,
            env=self.env,
            data_frequency=data_frequency
        )

        results = test_algo.run(
            # Use the resolved sim_params (not self.sim_params) so that a
            # caller-supplied date range also reaches the data portal.
            data_portal=FetcherDataPortal(self.env, sim_params)
        )

        return results

    def test_minutely_fetcher(self):
        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-03", tz='UTC'),
            end=pd.Timestamp("2006-01-31", tz='UTC'),
            emission_rate="minute",
            data_frequency="minute"
        )

        test_algo = TradingAlgorithm(
            script="""
from zipline.api import fetch_csv, record, sid

def initialize(context):
    fetch_csv('https://fake.urls.com/aapl_minute_csv_data.csv')

def handle_data(context, data):
    record(aapl_signal=data[sid(24)].signal)
""", sim_params=sim_params, data_frequency="minute", env=self.env)

        # manually setting data portal and getting generator because we need
        # the minutely emission packets here.  TradingAlgorithm.run() only
        # returns daily packets.
        test_algo.data_portal = FetcherDataPortal(self.env, sim_params)
        gen = test_algo.get_generator()
        perf_packets = list(gen)

        signal = [result["minute_perf"]["recorded_vars"]["aapl_signal"] for
                  result in perf_packets if "minute_perf" in result]

        self.assertEqual(20 * 390, len(signal))

        # csv data is:
        # symbol,date,signal
        # aapl,1/4/06 4:01PM,-1
        # aapl,1/5/06 4:00PM,5
        # aapl,1/6/06 9:30AM,6
        # aapl,1/9/06 12:01PM,9

        # dates are interpreted as UTC time
        # market hours are 14:31-21:00 UTC on each of those days

        # day1 starts at 2006-01-03 14:31
        # day1 ends at 2006-01-03 21:00
        # day2 starts at 2006-01-04 14:31
        # -1 starts at 2006-01-04 16:01
        # day2 ends at 2006-01-04 21:00
        # day3 starts at 2006-01-05 14:31
        # 5 starts at 2006-01-05 16:00
        # day3 ends at 2006-01-05 21:00
        # 6 starts at 2006-01-06 9:30
        # day4 starts at 2006-01-06 14:31
        # day4 ends at 2006-01-06 21:00
        # 9 starts at 2006-01-09 12:01
        # day5 starts at 2006-01-09 14:31
        # day5 ends at 2006-01-09 21:00
        # ...
        # day20 ends at 2006-01-31 21:00

        # 480 NaNs
        # 389 -1s
        # 301 5s
        # 390 6s
        # 6240 9s
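
        # Sanity check on those bucket sizes (each session has 390 minute
        # bars, and 20 sessions = 7800 bars total):
        #   480  = 390 (all of day1, no data yet) + 90 (day2, 14:31-16:00)
        #   389  = 300 (day2, 16:01-21:00) + 89 (day3, 14:31-15:59)
        #   301  = day3, 16:00-21:00
        #   390  = all of day4 (the 6 arrived before that day's open)
        #   6240 = 16 remaining sessions * 390 (the 9 arrived before day5)
        #   480 + 389 + 301 + 390 + 6240 == 7800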

        values = [result["minute_perf"]["recorded_vars"]["aapl_signal"]
                  for result in perf_packets if "minute_perf" in result]

        np.testing.assert_array_equal([np.NaN] * 480, values[0:480])
        np.testing.assert_array_equal([-1.0] * 389, values[480:869])
        np.testing.assert_array_equal([5.0] * 301, values[869:1170])
        np.testing.assert_array_equal([6.0] * 390, values[1170:1560])
        np.testing.assert_array_equal([9.0] * 6240, values[1560:])

    def test_fetch_csv_with_multi_symbols(self):
        results = self.run_algo(
            """
from zipline.api import fetch_csv, record, sid

def initialize(context):
    fetch_csv('https://fake.urls.com/multi_signal_csv_data.csv')
    context.stocks = [sid(3766), sid(25317)]

def handle_data(context, data):
    record(ibm_signal=data[sid(3766)]["signal"])
    record(dell_signal=data[sid(25317)]["signal"])

    assert "signal" not in data[sid(24)]
    """)

        self.assertEqual(5, results["ibm_signal"].iloc[-1])
        self.assertEqual(5, results["dell_signal"].iloc[-1])

    def test_fetch_csv_with_nomatch_symbol(self):
        """
        The algorithm loads data with a symbol column that contains a
        symbol matching nothing in our database. Letting these kinds of
        events through creates complications for the order method, so we
        drop these values instead of letting the securities roll through.

        This test also ensures that if a given symbol has *multiple*
        matches, but *none* of the matches fall inside the test range,
        then none of them are found (nor does the algo crash).
        """
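        # The dropping described above amounts to filtering out rows whose
        # symbol cannot be resolved to a known asset. Illustrative pandas
        # sketch only, not zipline's actual implementation:
        #
        #     known = {"AAPL", "IBM", "MSFT", "YHOO", "DELL"}
        #     df = df[df["symbol"].str.upper().isin(known)]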
        results = self.run_algo(
            """
from zipline.api import fetch_csv, sid, record

def initialize(context):
    fetch_csv('https://fake.urls.com/nomatch_csv_data.csv',
              mask=True)
    context.stocks = [sid(3766), sid(25317)]

def handle_data(context, data):
    if "signal" in data[sid(3766)]:
        record(ibm_signal=data[sid(3766)]["signal"])

    if "signal" in data[sid(25317)]:
        record(dell_signal=data[sid(25317)]["signal"])
            """)

        self.assertNotIn("dell_signal", results.columns)
        self.assertNotIn("ibm_signal", results.columns)

    def test_fetch_csv_with_pure_signal_file(self):
        results = self.run_algo(
            """
from zipline.api import fetch_csv, sid, record

def clean(df):
    return df.rename(columns={'Value': 'cpi', 'Date': 'date'})

def initialize(context):
    fetch_csv(
        'https://fake.urls.com/cpiaucsl_data.csv',
        symbol='urban',
        pre_func=clean,
        date_format='%Y-%m-%d'
        )
    context.stocks = [sid(3766), sid(25317)]

def handle_data(context, data):
    cur_cpi = data['urban']['cpi']
    record(cpi=cur_cpi)
            """)

        self.assertEqual(results["cpi"][-1], 203.1)

    def test_algo_fetch_csv(self):
        # [review annotation (0 ignored issues): Duplication introduced.
        #  "This code seems to be duplicated in your project. If you need to
        #  duplicate the same code in three or more different places, we
        #  strongly encourage you to look into extracting the code into a
        #  single class or operation."]
        results = self.run_algo(
            """
from zipline.api import fetch_csv, record, sid

def normalize(df):
    df['scaled'] = df['signal'] * 10
    return df

def initialize(context):
    fetch_csv('https://fake.urls.com/aapl_csv_data.csv',
              post_func=normalize)
    context.checked_name = False

def handle_data(context, data):
    record(
        signal=data[sid(24)]['signal'],
        scaled=data[sid(24)]['scaled'],
        price=data[sid(24)].price)
        """)

        self.assertEqual(5, results["signal"][-1])
        self.assertEqual(50, results["scaled"][-1])
        self.assertEqual(24, results["price"][-1])  # fake value

    def test_algo_fetch_csv_with_extra_symbols(self):
        # [review annotation (0 ignored issues): Duplication introduced;
        #  same message as on test_algo_fetch_csv above.]
        results = self.run_algo(
            """
from zipline.api import fetch_csv, record, sid

def normalize(df):
    df['scaled'] = df['signal'] * 10
    return df

def initialize(context):
    fetch_csv('https://fake.urls.com/aapl_ibm_csv_data.csv',
              post_func=normalize,
              mask=True)

def handle_data(context, data):
    if 'signal' in data[sid(24)]:
        record(
            signal=data[sid(24)]['signal'],
            scaled=data[sid(24)]['scaled'],
            price=data[sid(24)].price)
            """
        )

        self.assertEqual(5, results["signal"][-1])
        self.assertEqual(50, results["scaled"][-1])
        self.assertEqual(24, results["price"][-1])  # fake value

    @parameterized.expand([("unspecified", ""),
                           ("none", "usecols=None"),
                           ("empty", "usecols=[]"),
                           ("without date", "usecols=['Value']"),
                           ("with date", "usecols=('Value', 'Date')")])
    def test_usecols(self, testname, usecols):
        code = """
from zipline.api import fetch_csv, sid, record

def clean(df):
    return df.rename(columns={{'Value':'cpi'}})

def initialize(context):
    fetch_csv(
        'https://fake.urls.com/cpiaucsl_data.csv',
        symbol='urban',
        pre_func=clean,
        date_column='Date',
        date_format='%Y-%m-%d',{usecols}
        )
    context.stocks = [sid(3766), sid(25317)]

def handle_data(context, data):
    if {should_have_data}:
        assert 'cpi' in data['urban']
    else:
        assert 'cpi' not in data['urban']
        """
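
        # Note: the algo source above is a str.format template. Doubled
        # braces ({{...}}) survive formatting as literal braces, {usecols}
        # splices in the parameterized keyword argument (e.g.
        # "usecols=['Value']"), and {should_have_data} becomes the literal
        # True or False.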
        results = self.run_algo(
            code.format(
                usecols=usecols,
                should_have_data=testname in [
                    'none',
                    'unspecified',
                    'without date',
                    'with date',
                ],
            )
        )

        # 251 trading days in 2006
        self.assertEqual(len(results), 251)

    def test_sources_merge_custom_ticker(self):
        requests_kwargs = {}

        def capture_kwargs(zelf, url, **kwargs):
            requests_kwargs.update(
                mask_requests_args(url, kwargs).requests_kwargs
            )
            return PALLADIUM_DATA

        # Patching fetch_url instead of using responses in this test so that
        # we can intercept the requests keyword arguments and confirm that
        # they're correct.
        with patch('zipline.sources.requests_csv.PandasRequestsCSV.fetch_url',
                   new=capture_kwargs):
            results = self.run_algo(
                """
from zipline.api import fetch_csv, record, sid

def rename_col(df):
    df = df.rename(columns={'New York 15:00': 'price'})
    df = df.fillna(method='ffill')
    return df[['price', 'sid']]

def initialize(context):
    fetch_csv('https://dl.dropbox.com/u/16705795/PALL.csv',
        date_column='Date',
        symbol='palladium',
        post_func=rename_col,
        date_format='%Y-%m-%d'
        )
    context.stock = sid(24)

def handle_data(context, data):
    palladium = data['palladium']
    aapl = data[context.stock]
    if 'price' in palladium:
        record(palladium=palladium.price)
    if 'price' in aapl:
        record(aapl=aapl.price)
        """)

            np.testing.assert_array_equal([24] * 251, results["aapl"])
            self.assertEqual(337, results["palladium"].iloc[-1])

            expected = {
                'allow_redirects': False,
                'stream': True,
                'timeout': 30.0,
            }

            self.assertEqual(expected, requests_kwargs)

    @parameterized.expand([("symbol", FETCHER_UNIVERSE_DATA, None),
                           ("arglebargle", FETCHER_UNIVERSE_DATA_TICKER_COLUMN,
                            FETCHER_ALTERNATE_COLUMN_HEADER)])
    def test_fetcher_universe(self, name, data, column_name):
        # Patching fetch_url here rather than using responses because (a) it's
        # easier given the parameterization, and (b) there are enough tests
        # using responses that the fetch_url code is getting a good workout,
        # so we don't have to use it in every test.
        with patch('zipline.sources.requests_csv.PandasRequestsCSV.fetch_url',
                   new=lambda *a, **k: data):
            sim_params = factory.create_simulation_parameters(
                start=pd.Timestamp("2006-01-09", tz='UTC'),
                end=pd.Timestamp("2006-01-11", tz='UTC')
            )

            algocode = """
from pandas import Timestamp
from zipline.api import fetch_csv, record, sid, get_datetime

def initialize(context):
    fetch_csv(
        'https://dl.dropbox.com/u/16705795/dtoc_history.csv',
        date_format='%m/%d/%Y'{token}
    )
    context.expected_sids = {{
        Timestamp('2006-01-09 00:00:00+0000', tz='UTC'): [24, 3766, 5061],
        Timestamp('2006-01-10 00:00:00+0000', tz='UTC'): [24, 3766, 5061],
        Timestamp('2006-01-11 00:00:00+0000', tz='UTC'): [24, 3766, 5061, 14848]
    }}
    context.bar_count = 0

def handle_data(context, data):
    expected = context.expected_sids[get_datetime()]
    actual = data.fetcher_assets
    for stk in expected:
        if stk not in actual:
            raise Exception(
                "{{stk}} is missing on dt={{dt}}".format(
                    stk=stk, dt=get_datetime()))

    record(sid_count=len(actual))
    record(bar_count=context.bar_count)
    context.bar_count += 1
            """
            replacement = ""
            if column_name:
                replacement = ",symbol_column='%s'\n" % column_name
            real_algocode = algocode.format(token=replacement)

            results = self.run_algo(real_algocode, sim_params=sim_params)

            self.assertEqual(len(results), 3)
            self.assertEqual(3, results["sid_count"].iloc[0])
            self.assertEqual(3, results["sid_count"].iloc[1])
            self.assertEqual(4, results["sid_count"].iloc[2])

    def test_fetcher_universe_non_security_return(self):
        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-09", tz='UTC'),
            end=pd.Timestamp("2006-01-10", tz='UTC')
        )

        self.run_algo(
            """
from zipline.api import fetch_csv

def initialize(context):
    fetch_csv(
        'https://fake.urls.com/bad_fetcher_universe_data.csv',
        date_format='%m/%d/%Y'
    )

def handle_data(context, data):
    if len(data.fetcher_assets) > 0:
        raise Exception("Shouldn't be any assets in fetcher_assets!")
            """,
            sim_params=sim_params,
        )

    def test_order_against_data(self):
        with self.assertRaises(UnsupportedOrderParameters):
            self.run_algo("""
from zipline.api import fetch_csv, order, sid

def rename_col(df):
    return df.rename(columns={'New York 15:00': 'price'})

def initialize(context):
    fetch_csv('https://fake.urls.com/palladium_data.csv',
        date_column='Date',
        symbol='palladium',
        post_func=rename_col,
        date_format='%Y-%m-%d'
        )
    context.stock = sid(24)

def handle_data(context, data):
    order('palladium', 100)
            """)

    def test_fetcher_universe_minute(self):
        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-09", tz='UTC'),
            end=pd.Timestamp("2006-01-11", tz='UTC'),
            data_frequency="minute"
        )

        results = self.run_algo(
            """
from pandas import Timestamp
from zipline.api import fetch_csv, record, get_datetime

def initialize(context):
    fetch_csv(
        'https://fake.urls.com/fetcher_universe_data.csv',
        date_format='%m/%d/%Y'
    )
    context.expected_sids = {
        Timestamp('2006-01-09 00:00:00+0000', tz='UTC'): [24, 3766, 5061],
        Timestamp('2006-01-10 00:00:00+0000', tz='UTC'): [24, 3766, 5061],
        Timestamp('2006-01-11 00:00:00+0000', tz='UTC'): [24, 3766, 5061, 14848]
    }
    context.bar_count = 0

def handle_data(context, data):
    expected = context.expected_sids[get_datetime().replace(hour=0, minute=0)]
    actual = data.fetcher_assets
    for stk in expected:
        if stk not in actual:
            raise Exception("{stk} is missing".format(stk=stk))

    record(sid_count=len(actual))
    record(bar_count=context.bar_count)
    context.bar_count += 1
        """, sim_params=sim_params, data_frequency="minute"
        )

        self.assertEqual(3, len(results))
        self.assertEqual(3, results["sid_count"].iloc[0])
        self.assertEqual(3, results["sid_count"].iloc[1])
        self.assertEqual(4, results["sid_count"].iloc[2])