Completed
Pull Request — master (#858)
by Eddie
10:07 queued, 01:13 created

zipline.sources.PandasCSV.__iter__() — rated F

Complexity
    Conditions    13

Size
    Total Lines   54

Duplication
    Lines         0
    Ratio         0 %

Metric                        Value
cc (cyclomatic complexity)    13
dl (duplicated lines)         0
loc (lines of code)           54
rs                            3.5512

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part into a new method, and to use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method, as in the sketch below.
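
As a minimal sketch of Extract Method (the class and method names here are illustrative, not taken from zipline), a commented block becomes a named helper:

import pandas as pd


class CSVSource(object):
    # Hypothetical example class, for illustration only.
    def __init__(self, date_column):
        self.date_column = date_column

    def load(self, df):
        # Before the refactoring, the body of _parse_dates_to_utc lived
        # here behind a "# parse dates to UTC" comment; the comment has
        # become the method name.
        df['dt'] = self._parse_dates_to_utc(df)
        return df

    def _parse_dates_to_utc(self, df):
        # Extracted method: parse the configured date column to UTC.
        return pd.to_datetime(df[self.date_column], utc=True)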

Complexity

Complex methods like zipline.sources.PandasCSV.__iter__() often do a lot of different things. To break the surrounding class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring, as in the sketch below. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
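A minimal Extract Class sketch, again with illustrative names only: the symbol-lookup state and behavior move out of the source class into their own cohesive component.

class SymbolResolver(object):
    # Hypothetical extracted class holding the lookup-related state.
    def __init__(self, finder):
        self.finder = finder

    def resolve(self, symbol):
        # Delegate symbol lookups to the injected finder.
        return self.finder.lookup_symbol(symbol.upper(), as_of_date=None)


class CSVSource(object):
    def __init__(self, finder):
        # The source now collaborates with the extracted component
        # instead of carrying the lookup logic itself.
        self.symbols = SymbolResolver(finder)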

from six import StringIO, iteritems
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import hashlib
from textwrap import dedent
import pandas as pd
from pandas import read_csv
import numpy
from logbook import Logger
import pytz
import warnings
import requests

from zipline.errors import (
    MultipleSymbolsFound,
    SymbolNotFound,
    ZiplineError
)
from zipline.protocol import (
    DATASOURCE_TYPE,
    Event
)
from zipline.assets import Equity

logger = Logger('Requests Source Logger')


def roll_dts_to_midnight(dts, env):
    # Events timestamped after the 4:00pm US/Eastern close are rolled
    # forward to the next trading day, normalized to UTC midnight.
    return pd.DatetimeIndex(
        (dts.tz_convert('US/Eastern') - pd.Timedelta(hours=16)).date,
        tz='UTC',
    ) + env.trading_day


class FetcherEvent(Event):
    pass


class FetcherCSVRedirectError(ZiplineError):
    msg = dedent(
        """\
        Attempt to fetch_csv from a redirected url. {url}
        must be changed to {new_url}
        """
    )

    def __init__(self, *args, **kwargs):
        self.url = kwargs["url"]
        self.new_url = kwargs["new_url"]
        self.extra = kwargs["extra"]

        super(FetcherCSVRedirectError, self).__init__(*args, **kwargs)

# The following optional arguments are supported for
# requests backed data sources.
# see http://docs.python-requests.org/en/latest/api/#main-interface
# for a full list.
ALLOWED_REQUESTS_KWARGS = {
    'params',
    'headers',
    'auth',
    'cert'}


# The following optional arguments are supported for pandas' read_csv
# function, and may be passed as kwargs to the datasource below.
# see http://pandas.pydata.org/
# pandas-docs/stable/generated/pandas.io.parsers.read_csv.html
ALLOWED_READ_CSV_KWARGS = {
    'sep',
    'dialect',
    'doublequote',
    'escapechar',
    'quotechar',
    'quoting',
    'skipinitialspace',
    'lineterminator',
    'header',
    'index_col',
    'names',
    'prefix',
    'skiprows',
    'skipfooter',
    'skip_footer',
    'na_values',
    'true_values',
    'false_values',
    'delimiter',
    'converters',
    'dtype',
    'delim_whitespace',
    'as_recarray',
    'na_filter',
    'compact_ints',
    'use_unsigned',
    'buffer_lines',
    'warn_bad_lines',
    'error_bad_lines',
    'keep_default_na',
    'thousands',
    'comment',
    'decimal',
    'keep_date_col',
    'nrows',
    'chunksize',
    'encoding',
    'usecols'
}

SHARED_REQUESTS_KWARGS = {
    'stream': True,
    'allow_redirects': False,
}

def mask_requests_args(url, validating=False, params_checker=None, **kwargs):
    requests_kwargs = {key: val for (key, val) in iteritems(kwargs)
                       if key in ALLOWED_REQUESTS_KWARGS}
    if params_checker is not None:
        url, s_params = params_checker(url)
        if s_params:
            if 'params' in requests_kwargs:
                requests_kwargs['params'].update(s_params)
            else:
                requests_kwargs['params'] = s_params

    # Giving the connection 30 seconds. This timeout does not
    # apply to the download of the response body.
    # (Note that Quandl links can take >10 seconds to return their
    # first byte on occasion)
    requests_kwargs['timeout'] = 1.0 if validating else 30.0
    requests_kwargs.update(SHARED_REQUESTS_KWARGS)

    request_pair = namedtuple("RequestPair", ("requests_kwargs", "url"))
    return request_pair(requests_kwargs, url)


class PandasCSV(object):
    __metaclass__ = ABCMeta

    def __init__(self,
                 pre_func,
                 post_func,
                 env,
                 start_date,
                 end_date,
                 date_column,
                 date_format,
                 timezone,
                 symbol,
                 mask,
                 symbol_column,
                 data_frequency,
                 **kwargs):

        self.start_date = start_date
        self.end_date = end_date
        self.date_column = date_column
        self.date_format = date_format
        self.timezone = timezone
        self.mask = mask
        self.symbol_column = symbol_column or "symbol"
        self.data_frequency = data_frequency

        invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS
        if invalid_kwargs:
            raise TypeError(
                "Unexpected keyword arguments: %s" % invalid_kwargs,
            )

        self.pandas_kwargs = self.mask_pandas_args(kwargs)

        self.symbol = symbol

        self.env = env
        self.finder = env.asset_finder

        self.pre_func = pre_func
        self.post_func = post_func

    @property
    def fields(self):
        return self.df.columns.tolist()

    def get_hash(self):
        return self.namestring

    @abstractmethod
    def fetch_data(self):
        return

    @staticmethod
    def parse_date_str_series(format_str, tz, date_str_series, data_frequency,
                              env):
        """
        Efficient parsing for a 1d Pandas/numpy object containing string
        representations of dates.

        Note: pd.to_datetime is significantly faster when no format string is
        passed, and in pandas 0.12.0 the %p strptime directive is not correctly
        handled if a format string is explicitly passed, but AM/PM is handled
        properly if format=None.

        Moreover, we were previously ignoring this parameter unintentionally
        because we were incorrectly passing it as a positional.  For all these
        reasons, we ignore the format_str parameter when parsing datetimes.
        """

        # Explicitly ignoring this parameter.  See note above.
        if format_str is not None:
            logger.warn(
                "The 'format_str' parameter to fetch_csv is deprecated. "
                "Ignoring and defaulting to pandas default date parsing."
            )
            format_str = None

        tz_str = str(tz)
        if tz_str == pytz.utc.zone:
            parsed = pd.to_datetime(
                date_str_series.values,
                format=format_str,
                utc=True,
                coerce=True,
            )
        else:
            parsed = pd.to_datetime(
                date_str_series.values,
                format=format_str,
                coerce=True,
            ).tz_localize(tz_str).tz_convert('UTC')

        if data_frequency == 'daily':
            parsed = roll_dts_to_midnight(parsed, env)
        return parsed

    def mask_pandas_args(self, kwargs):
        pandas_kwargs = {key: val for (key, val) in iteritems(kwargs)
                         if key in ALLOWED_READ_CSV_KWARGS}
        if 'usecols' in pandas_kwargs:
            usecols = pandas_kwargs['usecols']
            if usecols and self.date_column not in usecols:
                # make a new list so we don't modify the user's,
                # and to ensure it is mutable
                with_date = list(usecols)
                with_date.append(self.date_column)
                pandas_kwargs['usecols'] = with_date

        # No strings in the 'symbol' column should be interpreted as NaNs
        pandas_kwargs.setdefault('keep_default_na', False)
        pandas_kwargs.setdefault('na_values', {'symbol': []})

        return pandas_kwargs

    def _lookup_unconflicted_symbol(self, symbol):
        """
        Attempt to find a unique asset whose symbol is the given string.

        If multiple assets have held the given symbol, return 0.

        If no asset has held the given symbol, return NaN.
        """
        try:
            uppered = symbol.upper()
        except AttributeError:
            # The mapping fails because symbol was a non-string
            return numpy.nan

        try:
            return self.finder.lookup_symbol(uppered, as_of_date=None)
        except MultipleSymbolsFound:
            # Fill conflicted entries with zeros to mark that they need to be
            # resolved by date.
            return 0
        except SymbolNotFound:
            # Fill not-found entries with NaNs.
            return numpy.nan

    def load_df(self):
        df = self.fetch_data()

        if self.pre_func:
            df = self.pre_func(df)

        # Batch-convert the user-specified date column into timestamps.
        df['dt'] = self.parse_date_str_series(
            self.date_format,
            self.timezone,
            df[self.date_column],
            self.data_frequency,
            self.env
        ).values

        # ignore rows whose dates we couldn't parse
        df = df[df['dt'].notnull()]

        if self.symbol is not None:
            df['sid'] = self.symbol
        elif self.finder:

            df = df.sort(self.symbol_column)

            # Pop the 'sid' column off of the DataFrame, just in case the user
            # has assigned it, and throw a warning
            try:
                df.pop('sid')
                warnings.warn(
                    "Assignment of the 'sid' column of a DataFrame is "
                    "not supported by Fetcher. The 'sid' column has been "
                    "overwritten.",
                    category=UserWarning,
                    stacklevel=2,
                )
            except KeyError:
                # There was no 'sid' column, so no warning is necessary
                pass

            # Fill entries for any symbols that don't require a date to
            # uniquely identify.  Entries for which multiple securities exist
            # are replaced with zeroes, while entries for which no asset
            # exists are replaced with NaNs.
            unique_symbols = df[self.symbol_column].unique()
            sid_series = pd.Series(
                data=map(self._lookup_unconflicted_symbol, unique_symbols),
                index=unique_symbols,
                name='sid',
            )
            df = df.join(sid_series, on=self.symbol_column)

            # Fill any zero entries left in our sid column by doing a lookup
            # using both symbol and the row date.
            conflict_rows = df[df['sid'] == 0]
            for row_idx, row in conflict_rows.iterrows():
                try:
                    asset = self.finder.lookup_symbol(
                        row[self.symbol_column],
                        # Replacing tzinfo here is necessary because of the
                        # timezone metadata bug described below.
                        row['dt'].replace(tzinfo=pytz.utc),

                        # It's possible that no asset comes back here if our
                        # lookup date is from before any asset held the
                        # requested symbol.  Mark such cases as NaN so that
                        # they get dropped in the next step.
                    ) or numpy.nan
                except SymbolNotFound:
                    asset = numpy.nan

                # Assign the resolved asset to the cell
                df.ix[row_idx, 'sid'] = asset

            # Filter out rows containing symbols that we failed to find.
            length_before_drop = len(df)
            df = df[df['sid'].notnull()]
            no_sid_count = length_before_drop - len(df)
            if no_sid_count:
                logger.warn(
                    "Dropped {} rows from fetched csv.".format(no_sid_count),
                    extra={'syslog': True},
                )
        else:
            df['sid'] = df['symbol']

        # Dates are localized to UTC when they come out of
        # parse_date_str_series, but we need to re-localize them here because
        # of a bug that wasn't fixed until
        # https://github.com/pydata/pandas/pull/7092.
        # We should be able to remove the call to tz_localize once we're on
        # pandas 0.14.0

        # We don't set 'dt' as the index until here because the symbol-parsing
        # operations above depend on having a unique index for the dataframe,
        # and the 'dt' column can contain multiple dates for the same entry.
        df = df.drop_duplicates(["sid", "dt"])
        df.set_index(['dt'], inplace=True)
        df = df.tz_localize('UTC')
        df.sort_index(inplace=True)

        cols_to_drop = [self.date_column]
        if self.symbol is None:
            cols_to_drop.append(self.symbol_column)
        df = df[df.columns.drop(cols_to_drop)]

        if self.post_func:
            df = self.post_func(df)

        return df

    def __iter__(self):
        asset_cache = {}
        for dt, series in self.df.iterrows():
            if dt < self.start_date:
                continue

            if dt > self.end_date:
                return

            event = FetcherEvent()
            # When the dt column is converted to be the dataframe's index,
            # the dt column is dropped, so we need to manually copy
            # dt into the event.
            event.dt = dt
            for k, v in series.iteritems():
                # Convert numpy integer types to int. This assumes we are on
                # a 64-bit platform that will not lose information by casting.
                # TODO: this is only necessary on the amazon qexec instances.
                # It would be good to figure out how to use the numpy dtypes
                # without this check and casting.
                if isinstance(v, numpy.integer):
                    v = int(v)

                setattr(event, k, v)

            # If it has start_date, then it's already an Asset
            # object from asset_for_symbol, and we don't have to
            # transform it any further. Checking for start_date is
            # faster than isinstance.
            if event.sid in asset_cache:
                event.sid = asset_cache[event.sid]
            elif hasattr(event.sid, 'start_date'):
                # Clone for user algo code, if we haven't already.
                asset_cache[event.sid] = event.sid
            elif self.finder and isinstance(event.sid, int):
                asset = self.finder.retrieve_asset(event.sid,
                                                   default_none=True)
                if asset:
                    # Clone for user algo code.
                    event.sid = asset_cache[asset] = asset
                elif self.mask:
                    # When masking, drop all non-mappable values.
                    continue
                elif self.symbol is None:
                    # If the event's sid property is an int, we coerce
                    # it into an Equity.
                    event.sid = asset_cache[event.sid] = Equity(event.sid)

            event.type = DATASOURCE_TYPE.CUSTOM
            event.source_id = self.namestring
            yield event


class PandasRequestsCSV(PandasCSV):
    # maximum 100 megs to prevent DDoS
    MAX_DOCUMENT_SIZE = (1024 * 1024) * 100

    # maximum number of bytes to read in at a time
    CONTENT_CHUNK_SIZE = 4096

    def __init__(self,
                 url,
                 pre_func,
                 post_func,
                 env,
                 start_date,
                 end_date,
                 date_column,
                 date_format,
                 timezone,
                 symbol,
                 mask,
                 symbol_column,
                 data_frequency,
                 special_params_checker=None,
                 **kwargs):

        # Peel off extra requests kwargs, forwarding the remaining kwargs to
        # the superclass.
        # mask_requests_args also returns a possibly https-upgraded url if an
        # http Quandl url was passed in; if the url hasn't changed, the
        # original is returned.
        self._requests_kwargs, self.url = \
            mask_requests_args(url,
                               params_checker=special_params_checker,
                               **kwargs)

        remaining_kwargs = {
            k: v for k, v in iteritems(kwargs)
            if k not in self.requests_kwargs
        }

        self.namestring = type(self).__name__

        super(PandasRequestsCSV, self).__init__(
            pre_func,
            post_func,
            env,
            start_date,
            end_date,
            date_column,
            date_format,
            timezone,
            symbol,
            mask,
            symbol_column,
            data_frequency,
            **remaining_kwargs
        )

        self.fetch_size = None
        self.fetch_hash = None

        self.df = self.load_df()

        self.special_params_checker = special_params_checker

    @property
    def requests_kwargs(self):
        return self._requests_kwargs

    def fetch_url(self, url):
        info = "checking {url} with {params}"
        logger.info(info.format(url=url, params=self.requests_kwargs))
        # Setting decode_unicode=True sometimes results in a
        # UnicodeEncodeError exception, so instead we'll use
        # pandas' logic for decoding content.
        try:
            response = requests.get(url, **self.requests_kwargs)
        except requests.exceptions.ConnectionError:
            raise Exception('Could not connect to %s' % url)

        if not response.ok:
            raise Exception('Problem reaching %s' % url)
        elif response.is_redirect:
            # On the off chance we don't catch a redirect URL
            # in validation, this will catch it.
            new_url = response.headers['location']
            raise FetcherCSVRedirectError(
                url=url,
                new_url=new_url,
                extra={
                    'old_url': url,
                    'new_url': new_url
                }
            )

        content_length = 0
        logger.info('{} connection established in {:.1f} seconds'.format(
            url, response.elapsed.total_seconds()))

        # Use the decode_unicode flag to ensure that the output of this is
        # a string, and not bytes.
        for chunk in response.iter_content(self.CONTENT_CHUNK_SIZE,
                                           decode_unicode=True):
            if content_length > self.MAX_DOCUMENT_SIZE:
                raise Exception('Document size too big.')
            if chunk:
                content_length += len(chunk)
                yield chunk

        return

    def fetch_data(self):
        # Create a data frame directly from the full text of
        # the response from the returned file descriptor.
        data = self.fetch_url(self.url)
        fd = StringIO()

        if isinstance(data, str):
            fd.write(data)
        else:
            for chunk in data:
                fd.write(chunk)

        self.fetch_size = fd.tell()

        fd.seek(0)

        try:
            # see if pandas can parse csv data
            frames = read_csv(fd, **self.pandas_kwargs)

            frames_hash = hashlib.md5(str(fd.getvalue()).encode('utf-8'))
            self.fetch_hash = frames_hash.hexdigest()
        except pd.parser.CParserError:
            # could not parse the data, raise exception
            raise Exception('Error parsing remote CSV data.')
        finally:
            fd.close()

        return frames
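
Applied to the flagged method itself, one plausible way to cut __iter__()'s 13 conditions is to extract the sid-resolution branching into a helper so the loop body reads linearly. This is a hedged, untested sketch of the Extract Method advice above: _resolve_sid and _make_event are hypothetical helpers, not part of zipline (the _make_event body would simply wrap the existing FetcherEvent/setattr loop).

    def _resolve_sid(self, sid, asset_cache):
        # Return the asset to attach to the event, or None to drop it.
        # (Caching here is keyed by the raw sid for simplicity.)
        if sid in asset_cache:
            return asset_cache[sid]
        if hasattr(sid, 'start_date'):
            # Already an Asset object; cache and pass through.
            asset_cache[sid] = sid
            return sid
        if self.finder and isinstance(sid, int):
            asset = self.finder.retrieve_asset(sid, default_none=True)
            if asset:
                asset_cache[sid] = asset
                return asset
            if self.mask:
                return None  # non-mappable value: drop the event
            if self.symbol is None:
                asset_cache[sid] = Equity(sid)
                return asset_cache[sid]
        return sid

    def __iter__(self):
        asset_cache = {}
        for dt, series in self.df.iterrows():
            if dt < self.start_date:
                continue
            if dt > self.end_date:
                return

            event = self._make_event(dt, series)  # setattr loop, extracted
            sid = self._resolve_sid(event.sid, asset_cache)
            if sid is None:
                continue
            event.sid = sid

            event.type = DATASOURCE_TYPE.CUSTOM
            event.source_id = self.namestring
            yield event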