Completed
Pull Request — master (#858) by Eddie
05:34 (queued 02:25)

zipline.sources.PandasCSV.load_df() — F

Complexity
Conditions: 10

Size
Total Lines: 110

Duplication
Lines: 0
Ratio: 0%

Metric  Value
cc      10
dl      0
loc     110
rs      3.1304
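The cc figure (cyclomatic complexity, the same number reported as "Conditions") can be reproduced locally. The report does not say which analyzer produced these metrics, so the following is only a minimal sketch assuming a radon-compatible tool and an illustrative file path:

from radon.complexity import cc_visit

# Path is an assumption; point it at the module that defines PandasCSV.
with open('zipline/sources/requests_csv.py') as f:
    source = f.read()

# cc_visit returns top-level functions and classes; methods hang off a
# class block's .methods list, so fall back to the block itself for functions.
for block in cc_visit(source):
    for func in getattr(block, 'methods', [block]):
        if func.name == 'load_df':
            # Cyclomatic complexity of load_df, i.e. the reported cc value.
            print(func.name, func.complexity)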

How to fix: Long Method, Complexity

Long Method

Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments inside a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name (see the sketch below).

Commonly applied refactorings include Extract Method.
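A minimal, hypothetical sketch of Extract Method (the names below are invented for illustration and are not taken from zipline): the commented block becomes a small helper whose name is derived from the comment.

# Before: a long method that needs a comment to explain one of its steps.
class ReportBuilder(object):
    def build(self, rows):
        # drop rows that have no total
        valid = [r for r in rows if r.get('total') is not None]
        return sum(r['total'] for r in valid)


# After: the commented step is extracted and the comment becomes the name.
class RefactoredReportBuilder(object):
    def build(self, rows):
        return sum(r['total'] for r in self._drop_rows_without_total(rows))

    @staticmethod
    def _drop_rows_without_total(rows):
        return [r for r in rows if r.get('total') is not None]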

Complexity

Complex code like zipline.sources.PandasCSV.load_df() often does a lot of different things. To break it down, we need to identify a cohesive component within the containing class. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring, illustrated below. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
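As a hypothetical illustration of Extract Class (names invented, not taken from the code under review): fields sharing a date_ prefix form a cohesive component and move into their own class.

# Before: one class holds loosely related fields, several sharing a 'date_' prefix.
class CSVSource(object):
    def __init__(self, date_column, date_format, timezone, symbol_column):
        self.date_column = date_column
        self.date_format = date_format
        self.timezone = timezone
        self.symbol_column = symbol_column


# After: the 'date_*' fields and their behaviour live in an extracted class.
class DateParsing(object):
    def __init__(self, column, format_str, timezone):
        self.column = column
        self.format_str = format_str
        self.timezone = timezone


class RefactoredCSVSource(object):
    def __init__(self, date_parsing, symbol_column):
        self.date_parsing = date_parsing  # a DateParsing instance
        self.symbol_column = symbol_column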

from six import StringIO
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import hashlib
from textwrap import dedent
import pandas as pd
from pandas import read_csv
import numpy
from logbook import Logger
import pytz
import warnings
import requests

from zipline.errors import (
    MultipleSymbolsFound,
    SymbolNotFound,
    ZiplineError)
from zipline.protocol import (
    DATASOURCE_TYPE,
    Event
)
from zipline.assets import Equity

logger = Logger('Requests Source Logger')


def roll_dts_to_midnight(dts, env):
    return pd.DatetimeIndex(
        (dts.tz_convert('US/Eastern') - pd.Timedelta(hours=16)).date,
        tz='UTC',
    ) + env.trading_day


class FetcherEvent(Event):
    pass


class FetcherCSVRedirectError(ZiplineError):
    msg = dedent(
        """\
        Attempt to fetch_csv from a redirected url. {url}
        must be changed to {new_url}
        """
    )

    def __init__(self, *args, **kwargs):
        self.url = kwargs["url"]
        self.new_url = kwargs["new_url"]
        self.extra = kwargs["extra"]

        super(FetcherCSVRedirectError, self).__init__(*args, **kwargs)

# The following optional arguments are supported for
# requests backed data sources.
# see http://docs.python-requests.org/en/latest/api/#main-interface
# for a full list.
ALLOWED_REQUESTS_KWARGS = {
    'params',
    'headers',
    'auth',
    'cert'}


# The following optional arguments are supported for pandas' read_csv
# function, and may be passed as kwargs to the datasource below.
# see http://pandas.pydata.org/
# pandas-docs/stable/generated/pandas.io.parsers.read_csv.html
ALLOWED_READ_CSV_KWARGS = {
    'sep',
    'dialect',
    'doublequote',
    'escapechar',
    'quotechar',
    'quoting',
    'skipinitialspace',
    'lineterminator',
    'header',
    'index_col',
    'names',
    'prefix',
    'skiprows',
    'skipfooter',
    'skip_footer',
    'na_values',
    'true_values',
    'false_values',
    'delimiter',
    'converters',
    'dtype',
    'delim_whitespace',
    'as_recarray',
    'na_filter',
    'compact_ints',
    'use_unsigned',
    'buffer_lines',
    'warn_bad_lines',
    'error_bad_lines',
    'keep_default_na',
    'thousands',
    'comment',
    'decimal',
    'keep_date_col',
    'nrows',
    'chunksize',
    'encoding',
    'usecols'
}

SHARED_REQUESTS_KWARGS = {
    'stream': True,
    'allow_redirects': False,
}


def mask_requests_args(url, validating=False, params_checker=None, **kwargs):
    requests_kwargs = {key: val for (key, val) in kwargs.iteritems()
                       if key in ALLOWED_REQUESTS_KWARGS}
    if params_checker is not None:
        url, s_params = params_checker(url)
        if s_params:
            if 'params' in requests_kwargs:
                requests_kwargs['params'].update(s_params)
            else:
                requests_kwargs['params'] = s_params

    # Giving the connection 30 seconds. This timeout does not
    # apply to the download of the response body.
    # (Note that Quandl links can take >10 seconds to return their
    # first byte on occasion)
    requests_kwargs['timeout'] = 1.0 if validating else 30.0
    requests_kwargs.update(SHARED_REQUESTS_KWARGS)

    request_pair = namedtuple("RequestPair", ("requests_kwargs", "url"))
    return request_pair(requests_kwargs, url)


class PandasCSV(object):
    __metaclass__ = ABCMeta

    def __init__(self,
                 pre_func,
                 post_func,
                 env,
                 start_date,
                 end_date,
                 date_column,
                 date_format,
                 timezone,
                 symbol,
                 mask,
                 symbol_column,
                 data_frequency,
                 **kwargs):

        self.start_date = start_date
        self.end_date = end_date
        self.date_column = date_column
        self.date_format = date_format
        self.timezone = timezone
        self.mask = mask
        self.symbol_column = symbol_column or "symbol"
        self.data_frequency = data_frequency

        invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS
        if invalid_kwargs:
            raise TypeError(
                "Unexpected keyword arguments: %s" % invalid_kwargs,
            )

        self.pandas_kwargs = self.mask_pandas_args(kwargs)

        self.symbol = symbol

        self.env = env
        self.finder = env.asset_finder

        self.pre_func = pre_func
        self.post_func = post_func

    @property
    def fields(self):
        return self.df.columns.tolist()

    def get_hash(self):
        return self.namestring

    @abstractmethod
    def fetch_data(self):
        return

    @staticmethod
    def parse_date_str_series(format_str, tz, date_str_series, data_frequency,
                              env):
        """
        Efficient parsing for a 1d Pandas/numpy object containing string
        representations of dates.

        Note: pd.to_datetime is significantly faster when no format string is
        passed, and in pandas 0.12.0 the %p strptime directive is not correctly
        handled if a format string is explicitly passed, but AM/PM is handled
        properly if format=None.

        Moreover, we were previously ignoring this parameter unintentionally
        because we were incorrectly passing it as a positional.  For all these
        reasons, we ignore the format_str parameter when parsing datetimes.
        """

        # Explicitly ignoring this parameter.  See note above.
        if format_str is not None:
            logger.warn(
                "The 'format_str' parameter to fetch_csv is deprecated. "
                "Ignoring and defaulting to pandas default date parsing."
            )
            format_str = None

        tz_str = str(tz)
        if tz_str == pytz.utc.zone:
            parsed = pd.to_datetime(
                date_str_series.values,
                format=format_str,
                utc=True,
                coerce=True,
            )
        else:
            parsed = pd.to_datetime(
                date_str_series.values,
                format=format_str,
                coerce=True,
            ).tz_localize(tz_str).tz_convert('UTC')

        if data_frequency == 'daily':
            parsed = roll_dts_to_midnight(parsed, env)
        return parsed

    def mask_pandas_args(self, kwargs):
        pandas_kwargs = {key: val for (key, val) in kwargs.iteritems()
                         if key in ALLOWED_READ_CSV_KWARGS}
        if 'usecols' in pandas_kwargs:
            usecols = pandas_kwargs['usecols']
            if usecols and self.date_column not in usecols:
                # make a new list so we don't modify user's,
                # and to ensure it is mutable
                with_date = list(usecols)
                with_date.append(self.date_column)
                pandas_kwargs['usecols'] = with_date

        # No strings in the 'symbol' column should be interpreted as NaNs
        pandas_kwargs.setdefault('keep_default_na', False)
        pandas_kwargs.setdefault('na_values', {'symbol': []})

        return pandas_kwargs

    def _lookup_unconflicted_symbol(self, symbol):
        """
        Attempt to find a unique asset whose symbol is the given string.

        If multiple assets have held the given symbol, return a 0.

        If no asset has held the given symbol, return a  NaN.
        """
        try:
            uppered = symbol.upper()
        except AttributeError:
            # The mapping fails because symbol was a non-string
            return numpy.nan

        try:
            return self.finder.lookup_symbol(uppered, as_of_date=None)
        except MultipleSymbolsFound:
            # Fill conflicted entries with zeros to mark that they need to be
            # resolved by date.
            return 0
        except SymbolNotFound:
            # Fill not found entries with nans.
            return numpy.nan

    def load_df(self):
        df = self.fetch_data()

        if self.pre_func:
            df = self.pre_func(df)

        # Batch-convert the user-specifed date column into timestamps.
        df['dt'] = self.parse_date_str_series(
            self.date_format,
            self.timezone,
            df[self.date_column],
            self.data_frequency,
            self.env
        ).values

        # ignore rows whose dates we couldn't parse
        df = df[df['dt'].notnull()]

        if self.symbol is not None:
            df['sid'] = self.symbol
        elif self.finder:

            df.sort(self.symbol_column)

            # Pop the 'sid' column off of the DataFrame, just in case the user
            # has assigned it, and throw a warning
            try:
                df.pop('sid')
                warnings.warn(
                    "Assignment of the 'sid' column of a DataFrame is "
                    "not supported by Fetcher. The 'sid' column has been "
                    "overwritten.",
                    category=UserWarning,
                    stacklevel=2,
                )
            except KeyError:
                # There was no 'sid' column, so no warning is necessary
                pass

            # Fill entries for any symbols that don't require a date to
            # uniquely identify.  Entries for which multiple securities exist
            # are replaced with zeroes, while entries for which no asset
            # exists are replaced with NaNs.
            unique_symbols = df[self.symbol_column].unique()
            sid_series = pd.Series(
                data=map(self._lookup_unconflicted_symbol, unique_symbols),
                index=unique_symbols,
                name='sid',
            )
            df = df.join(sid_series, on=self.symbol_column)

            # Fill any zero entries left in our sid column by doing a lookup
            # using both symbol and the row date.
            conflict_rows = df[df['sid'] == 0]
            for row_idx, row in conflict_rows.iterrows():
                try:
                    asset = self.finder.lookup_symbol(
                        row[self.symbol_column],
                        # Replacing tzinfo here is necessary because of the
                        # timezone metadata bug described below.
                        row['dt'].replace(tzinfo=pytz.utc),

                        # It's possible that no asset comes back here if our
                        # lookup date is from before any asset held the
                        # requested symbol.  Mark such cases as NaN so that
                        # they get dropped in the next step.
                    ) or numpy.nan
                except SymbolNotFound:
                    asset = numpy.nan

                # Assign the resolved asset to the cell
                df.ix[row_idx, 'sid'] = asset

            # Filter out rows containing symbols that we failed to find.
            length_before_drop = len(df)
            df = df[df['sid'].notnull()]
            no_sid_count = length_before_drop - len(df)
            if no_sid_count:
                logger.warn(
                    "Dropped {} rows from fetched csv.".format(no_sid_count),
                    no_sid_count,
                    extra={'syslog': True},
                )
        else:
            df['sid'] = df['symbol']

        # Dates are localized to UTC when they come out of
        # parse_date_str_series, but we need to re-localize them here because
        # of a bug that wasn't fixed until
        # https://github.com/pydata/pandas/pull/7092.
        # We should be able to remove the call to tz_localize once we're on
        # pandas 0.14.0

        # We don't set 'dt' as the index until here because the Symbol parsing
        # operations above depend on having a unique index for the dataframe,
        # and the 'dt' column can contain multiple dates for the same entry.
        df.drop_duplicates(["sid", "dt"])
        df.set_index(['dt'], inplace=True)
        df = df.tz_localize('UTC')
        df.sort_index(inplace=True)

        cols_to_drop = [self.date_column]
        if self.symbol is None:
            cols_to_drop.append(self.symbol_column)
        df = df[df.columns.drop(cols_to_drop)]

        if self.post_func:
            df = self.post_func(df)

        return df

    def __iter__(self):
        asset_cache = {}
        for dt, series in self.df.iterrows():
            if dt < self.start_date:
                continue

            if dt > self.end_date:
                return

            event = FetcherEvent()
            # when dt column is converted to be the dataframe's index
            # the dt column is dropped. So, we need to manually copy
            # dt into the event.
            event.dt = dt
            for k, v in series.iteritems():
                # convert numpy integer types to
                # int. This assumes we are on a 64bit
                # platform that will not lose information
                # by casting.
                # TODO: this is only necessary on the
                # amazon qexec instances. would be good
                # to figure out how to use the numpy dtypes
                # without this check and casting.
                if isinstance(v, numpy.integer):
                    v = int(v)

                setattr(event, k, v)

            # If it has start_date, then it's already an Asset
            # object from asset_for_symbol, and we don't have to
            # transform it any further. Checking for start_date is
            # faster than isinstance.
            if event.sid in asset_cache:
                event.sid = asset_cache[event.sid]
            elif hasattr(event.sid, 'start_date'):
                # Clone for user algo code, if we haven't already.
                asset_cache[event.sid] = event.sid
            elif self.finder and isinstance(event.sid, int):
                asset = self.finder.retrieve_asset(event.sid,
                                                   default_none=True)
                if asset:
                    # Clone for user algo code.
                    event.sid = asset_cache[asset] = asset
                elif self.mask:
                    # When masking drop all non-mappable values.
                    continue
                elif self.symbol is None:
                    # If the event's sid property is an int we coerce
                    # it into an Equity.
                    event.sid = asset_cache[event.sid] = Equity(event.sid)

            event.type = DATASOURCE_TYPE.CUSTOM
            event.source_id = self.namestring
            yield event


class PandasRequestsCSV(PandasCSV):
    # maximum 100 megs to prevent DDoS
    MAX_DOCUMENT_SIZE = (1024 * 1024) * 100

    # maximum number of bytes to read in at a time
    CONTENT_CHUNK_SIZE = 4096

    def __init__(self,
                 url,
                 pre_func,
                 post_func,
                 env,
                 start_date,
                 end_date,
                 date_column,
                 date_format,
                 timezone,
                 symbol,
                 mask,
                 symbol_column,
                 data_frequency,
                 special_params_checker=None,
                 **kwargs):

        # Peel off extra requests kwargs, forwarding the remaining kwargs to
        # the superclass.
        # Also returns possible https updated url if sent to http quandl ds
        # If url hasn't changed, will just return the original.
        self._requests_kwargs, self.url =\
            mask_requests_args(url,
                               params_checker=special_params_checker,
                               **kwargs)

        remaining_kwargs = {
            k: v for k, v in kwargs.iteritems()
            if k not in self.requests_kwargs
        }

        self.namestring = type(self).__name__

        super(PandasRequestsCSV, self).__init__(
            pre_func,
            post_func,
            env,
            start_date,
            end_date,
            date_column,
            date_format,
            timezone,
            symbol,
            mask,
            symbol_column,
            data_frequency,
            **remaining_kwargs
        )

        self.fetch_size = None
        self.fetch_hash = None
        self.df = self.load_df()

        self.special_params_checker = special_params_checker

    @property
    def requests_kwargs(self):
        return self._requests_kwargs

    def fetch_url(self, url):
        info = "checking {url} with {params}"
        logger.info(info.format(url=url, params=self.requests_kwargs))
        # setting decode_unicode=True sometimes results in a
        # UnicodeEncodeError exception, so instead we'll use
        # pandas logic for decoding content
        try:
            response = requests.get(url, **self.requests_kwargs)
        except requests.exceptions.ConnectionError:
            raise Exception('Could not connect to %s' % url)

        if not response.ok:
            raise Exception('Problem reaching %s' % url)
        elif response.is_redirect:
            # On the offchance we don't catch a redirect URL
            # in validation, this will catch it.
            new_url = response.headers['location']
            raise FetcherCSVRedirectError(
                url=url,
                new_url=new_url,
                extra={
                    'old_url': url,
                    'new_url': new_url
                }
            )

        content_length = 0
        logger.info('{} connection established in {:.1f} seconds'.format(
            url, response.elapsed.total_seconds()))
        for chunk in response.iter_content(self.CONTENT_CHUNK_SIZE):
            if content_length > self.MAX_DOCUMENT_SIZE:
                raise Exception('Document size too big.')
            if chunk:
                content_length += len(chunk)
                yield chunk

        return

    def fetch_data(self):
        # create a data frame directly from the full text of
        # the response from the returned file-descriptor.
        data = self.fetch_url(self.url)
        fd = StringIO()

        for chunk in data:
            fd.write(chunk)
        self.fetch_size = fd.tell()

        fd.seek(0)

        try:
            # see if pandas can parse csv data
            frames = read_csv(fd, **self.pandas_kwargs)
            frames_hash = hashlib.md5(fd.getvalue())
            self.fetch_hash = frames_hash.hexdigest()
        except pd.parser.CParserError:
            # could not parse the data, raise exception
            raise Exception('Error parsing remote CSV data.')
        finally:
            fd.close()

        return frames