| Metric | Value |
| --- | --- |
| Total Complexity | 42 |
| Total Lines | 305 |
| Duplicated Lines | 0% |
Complex classes like `zipline.sources.PandasCSV` often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often the faster option. A sketch of Extract Class applied to this class appears below, ahead of the flagged source.
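In `PandasCSV`, for instance, the symbol-resolution members cluster together: the `finder` field, `_lookup_unconflicted_symbol`, and the sid back-filling in `load_df` all deal with mapping symbol strings to assets. Here is a minimal sketch of Extract Class for that component; the `SymbolResolver` name and its `lookup_as_of` helper are hypothetical, not part of zipline, and `lookup_as_of` merely wraps the dated `lookup_symbol` call that `load_df` already makes:

```python
import numpy

from zipline.errors import MultipleSymbolsFound, SymbolNotFound


class SymbolResolver(object):
    """Hypothetical extracted component: maps symbol strings to assets.

    Owns the asset_finder dependency that PandasCSV currently reaches
    through self.finder, leaving the CSV class to focus on parsing.
    """

    def __init__(self, asset_finder):
        self.finder = asset_finder

    def lookup_unconflicted(self, symbol):
        # Moved from PandasCSV._lookup_unconflicted_symbol: resolve a
        # symbol that does not need a date to identify it uniquely.
        try:
            uppered = symbol.upper()
        except AttributeError:
            # Non-string symbols can never match an asset.
            return numpy.nan

        try:
            return self.finder.lookup_symbol(uppered, as_of_date=None)
        except MultipleSymbolsFound:
            # Zero marks entries that must be re-resolved by date.
            return 0
        except SymbolNotFound:
            return numpy.nan

    def lookup_as_of(self, symbol, dt):
        # Wraps the dated lookup that load_df performs for conflicted
        # rows; NaN means no asset held the symbol as of dt.
        try:
            return self.finder.lookup_symbol(symbol, dt) or numpy.nan
        except SymbolNotFound:
            return numpy.nan
```

`PandasCSV.__init__` would then set `self.resolver = SymbolResolver(env.asset_finder)`, and `load_df` would go through the resolver instead of touching `self.finder` directly, shrinking the class and giving the lookup rules a home that can be tested in isolation. The source flagged by the report: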
```python
from six import StringIO, iteritems

# ... (the report omits the intervening lines between the imports and the
# flagged class) ...


class PandasCSV(object):
    __metaclass__ = ABCMeta

    def __init__(self,
                 pre_func,
                 post_func,
                 env,
                 start_date,
                 end_date,
                 date_column,
                 date_format,
                 timezone,
                 symbol,
                 mask,
                 symbol_column,
                 data_frequency,
                 **kwargs):

        self.start_date = start_date
        self.end_date = end_date
        self.date_column = date_column
        self.date_format = date_format
        self.timezone = timezone
        self.mask = mask
        self.symbol_column = symbol_column or "symbol"
        self.data_frequency = data_frequency

        invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS
        if invalid_kwargs:
            raise TypeError(
                "Unexpected keyword arguments: %s" % invalid_kwargs,
            )

        self.pandas_kwargs = self.mask_pandas_args(kwargs)

        self.symbol = symbol

        self.env = env
        self.finder = env.asset_finder

        self.pre_func = pre_func
        self.post_func = post_func

    @property
    def fields(self):
        return self.df.columns.tolist()

    def get_hash(self):
        return self.namestring

    @abstractmethod
    def fetch_data(self):
        return

    @staticmethod
    def parse_date_str_series(format_str, tz, date_str_series, data_frequency,
                              env):
        """
        Efficient parsing for a 1d Pandas/numpy object containing string
        representations of dates.

        Note: pd.to_datetime is significantly faster when no format string is
        passed, and in pandas 0.12.0 the %p strptime directive is not correctly
        handled if a format string is explicitly passed, but AM/PM is handled
        properly if format=None.

        Moreover, we were previously ignoring this parameter unintentionally
        because we were incorrectly passing it as a positional. For all these
        reasons, we ignore the format_str parameter when parsing datetimes.
        """

        # Explicitly ignoring this parameter. See note above.
        if format_str is not None:
            logger.warn(
                "The 'format_str' parameter to fetch_csv is deprecated. "
                "Ignoring and defaulting to pandas default date parsing."
            )
            format_str = None

        tz_str = str(tz)
        if tz_str == pytz.utc.zone:
            parsed = pd.to_datetime(
                date_str_series.values,
                format=format_str,
                utc=True,
                coerce=True,
            )
        else:
            parsed = pd.to_datetime(
                date_str_series.values,
                format=format_str,
                coerce=True,
            ).tz_localize(tz_str).tz_convert('UTC')

        if data_frequency == 'daily':
            parsed = roll_dts_to_midnight(parsed, env)
        return parsed

    def mask_pandas_args(self, kwargs):
        pandas_kwargs = {key: val for (key, val) in iteritems(kwargs)
                         if key in ALLOWED_READ_CSV_KWARGS}
        if 'usecols' in pandas_kwargs:
            usecols = pandas_kwargs['usecols']
            if usecols and self.date_column not in usecols:
                # Make a new list so we don't modify the user's,
                # and to ensure it is mutable.
                with_date = list(usecols)
                with_date.append(self.date_column)
                pandas_kwargs['usecols'] = with_date

        # No strings in the 'symbol' column should be interpreted as NaNs.
        pandas_kwargs.setdefault('keep_default_na', False)
        pandas_kwargs.setdefault('na_values', {'symbol': []})

        return pandas_kwargs

    def _lookup_unconflicted_symbol(self, symbol):
        """
        Attempt to find a unique asset whose symbol is the given string.

        If multiple assets have held the given symbol, return 0.

        If no asset has held the given symbol, return NaN.
        """
        try:
            uppered = symbol.upper()
        except AttributeError:
            # The mapping fails because symbol was a non-string.
            return numpy.nan

        try:
            return self.finder.lookup_symbol(uppered, as_of_date=None)
        except MultipleSymbolsFound:
            # Fill conflicted entries with zeros to mark that they need to be
            # resolved by date.
            return 0
        except SymbolNotFound:
            # Fill not-found entries with NaNs.
            return numpy.nan

    def load_df(self):
        df = self.fetch_data()

        if self.pre_func:
            df = self.pre_func(df)

        # Batch-convert the user-specified date column into timestamps.
        df['dt'] = self.parse_date_str_series(
            self.date_format,
            self.timezone,
            df[self.date_column],
            self.data_frequency,
            self.env
        ).values

        # Ignore rows whose dates we couldn't parse.
        df = df[df['dt'].notnull()]

        if self.symbol is not None:
            df['sid'] = self.symbol
        elif self.finder:

            df = df.sort(self.symbol_column)

            # Pop the 'sid' column off of the DataFrame, just in case the
            # user has assigned it, and throw a warning.
            try:
                df.pop('sid')
                warnings.warn(
                    "Assignment of the 'sid' column of a DataFrame is "
                    "not supported by Fetcher. The 'sid' column has been "
                    "overwritten.",
                    category=UserWarning,
                    stacklevel=2,
                )
            except KeyError:
                # There was no 'sid' column, so no warning is necessary.
                pass

            # Fill entries for any symbols that don't require a date to
            # uniquely identify. Entries for which multiple securities exist
            # are replaced with zeroes, while entries for which no asset
            # exists are replaced with NaNs.
            unique_symbols = df[self.symbol_column].unique()
            sid_series = pd.Series(
                data=map(self._lookup_unconflicted_symbol, unique_symbols),
                index=unique_symbols,
                name='sid',
            )
            df = df.join(sid_series, on=self.symbol_column)

            # Fill any zero entries left in our sid column by doing a lookup
            # using both symbol and the row date.
            conflict_rows = df[df['sid'] == 0]
            for row_idx, row in conflict_rows.iterrows():
                try:
                    asset = self.finder.lookup_symbol(
                        row[self.symbol_column],
                        # Replacing tzinfo here is necessary because of the
                        # timezone metadata bug described below.
                        row['dt'].replace(tzinfo=pytz.utc),

                        # It's possible that no asset comes back here if our
                        # lookup date is from before any asset held the
                        # requested symbol. Mark such cases as NaN so that
                        # they get dropped in the next step.
                    ) or numpy.nan
                except SymbolNotFound:
                    asset = numpy.nan

                # Assign the resolved asset to the cell.
                df.ix[row_idx, 'sid'] = asset

            # Filter out rows containing symbols that we failed to find.
            length_before_drop = len(df)
            df = df[df['sid'].notnull()]
            no_sid_count = length_before_drop - len(df)
            if no_sid_count:
                logger.warn(
                    "Dropped {} rows from fetched csv.".format(no_sid_count),
                    extra={'syslog': True},
                )
        else:
            df['sid'] = df['symbol']

        # Dates are localized to UTC when they come out of
        # parse_date_str_series, but we need to re-localize them here because
        # of a bug that wasn't fixed until
        # https://github.com/pydata/pandas/pull/7092.
        # We should be able to remove the call to tz_localize once we're on
        # pandas 0.14.0.

        # We don't set 'dt' as the index until here because the symbol-parsing
        # operations above depend on having a unique index for the dataframe,
        # and the 'dt' column can contain multiple dates for the same entry.
        df = df.drop_duplicates(["sid", "dt"])
        df.set_index(['dt'], inplace=True)
        df = df.tz_localize('UTC')
        df.sort_index(inplace=True)

        cols_to_drop = [self.date_column]
        if self.symbol is None:
            cols_to_drop.append(self.symbol_column)
        df = df[df.columns.drop(cols_to_drop)]

        if self.post_func:
            df = self.post_func(df)

        return df

    def __iter__(self):
        asset_cache = {}
        for dt, series in self.df.iterrows():
            if dt < self.start_date:
                continue

            if dt > self.end_date:
                return

            event = FetcherEvent()
            # When the dt column is converted to be the dataframe's index,
            # the dt column is dropped, so we need to manually copy dt into
            # the event.
            event.dt = dt
            for k, v in series.iteritems():
                # Convert numpy integer types to int. This assumes we are on
                # a 64-bit platform that will not lose information by casting.
                # TODO: this is only necessary on the amazon qexec instances.
                # It would be good to figure out how to use the numpy dtypes
                # without this check and casting.
                if isinstance(v, numpy.integer):
                    v = int(v)

                setattr(event, k, v)

            # If it has start_date, then it's already an Asset object from
            # asset_for_symbol, and we don't have to transform it any
            # further. Checking for start_date is faster than isinstance.
            if event.sid in asset_cache:
                event.sid = asset_cache[event.sid]
            elif hasattr(event.sid, 'start_date'):
                # Clone for user algo code, if we haven't already.
                asset_cache[event.sid] = event.sid
            elif self.finder and isinstance(event.sid, int):
                asset = self.finder.retrieve_asset(event.sid,
                                                   default_none=True)
                if asset:
                    # Clone for user algo code.
                    event.sid = asset_cache[asset] = asset
                elif self.mask:
                    # When masking, drop all non-mappable values.
                    continue
                elif self.symbol is None:
                    # If the event's sid property is an int, we coerce it
                    # into an Equity.
                    event.sid = asset_cache[event.sid] = Equity(event.sid)

            event.type = DATASOURCE_TYPE.CUSTOM
            event.source_id = self.namestring
            yield event
```