| Metric | Value |
| --- | --- |
| Total Complexity | 42 |
| Total Lines | 305 |
| Duplicated Lines | 0% |
Complex classes like zipline.sources.PandasCSV often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefixes or suffixes.
Once you have determined which fields belong together, you can apply the Extract Class refactoring, as sketched below. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often quicker to apply.
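For example, the symbol-resolution responsibility of PandasCSV (the `finder` and `symbol_column` fields together with `_lookup_unconflicted_symbol`) forms one such cohesive group. The following is a minimal sketch of extracting it into a collaborator; the `SymbolResolver` name and the `zipline.errors` import path are assumptions for illustration, not zipline's actual API.

```python
import numpy

# Assumed import path for the exceptions used in the original class.
from zipline.errors import MultipleSymbolsFound, SymbolNotFound


class SymbolResolver(object):
    """Hypothetical extracted component holding the symbol-lookup
    fields (finder, symbol_column) and behavior of PandasCSV."""

    def __init__(self, asset_finder, symbol_column="symbol"):
        self.finder = asset_finder
        self.symbol_column = symbol_column

    def lookup_unconflicted(self, symbol):
        # Same contract as PandasCSV._lookup_unconflicted_symbol:
        # NaN for non-strings and unknown symbols, 0 for symbols
        # that need a date to disambiguate.
        try:
            uppered = symbol.upper()
        except AttributeError:
            return numpy.nan
        try:
            return self.finder.lookup_symbol(uppered, as_of_date=None)
        except MultipleSymbolsFound:
            return 0
        except SymbolNotFound:
            return numpy.nan
```

PandasCSV would then construct one SymbolResolver in `__init__` and delegate to it from `load_df`, removing two fields and a private method from the class and trimming its complexity score.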
```python
from six import StringIO, iteritems

# ... (intervening imports and module-level helpers omitted from this
# report's excerpt) ...


class PandasCSV(object):
    __metaclass__ = ABCMeta

    def __init__(self,
                 pre_func,
                 post_func,
                 env,
                 start_date,
                 end_date,
                 date_column,
                 date_format,
                 timezone,
                 symbol,
                 mask,
                 symbol_column,
                 data_frequency,
                 **kwargs):

        self.start_date = start_date
        self.end_date = end_date
        self.date_column = date_column
        self.date_format = date_format
        self.timezone = timezone
        self.mask = mask
        self.symbol_column = symbol_column or "symbol"
        self.data_frequency = data_frequency

        invalid_kwargs = set(kwargs) - ALLOWED_READ_CSV_KWARGS
        if invalid_kwargs:
            raise TypeError(
                "Unexpected keyword arguments: %s" % invalid_kwargs,
            )

        self.pandas_kwargs = self.mask_pandas_args(kwargs)

        self.symbol = symbol

        self.env = env
        self.finder = env.asset_finder

        self.pre_func = pre_func
        self.post_func = post_func

    @property
    def fields(self):
        return self.df.columns.tolist()

    def get_hash(self):
        return self.namestring

    @abstractmethod
    def fetch_data(self):
        return

    @staticmethod
    def parse_date_str_series(format_str, tz, date_str_series, data_frequency,
                              env):
        """
        Efficient parsing for a 1d Pandas/numpy object containing string
        representations of dates.

        Note: pd.to_datetime is significantly faster when no format string is
        passed, and in pandas 0.12.0 the %p strptime directive is not
        correctly handled if a format string is explicitly passed, but AM/PM
        is handled properly if format=None.

        Moreover, we were previously ignoring this parameter unintentionally
        because we were incorrectly passing it as a positional. For all these
        reasons, we ignore the format_str parameter when parsing datetimes.
        """

        # Explicitly ignoring this parameter. See note above.
        if format_str is not None:
            logger.warn(
                "The 'format_str' parameter to fetch_csv is deprecated. "
                "Ignoring and defaulting to pandas default date parsing."
            )
            format_str = None

        tz_str = str(tz)
        if tz_str == pytz.utc.zone:
            parsed = pd.to_datetime(
                date_str_series.values,
                format=format_str,
                utc=True,
                coerce=True,
            )
        else:
            parsed = pd.to_datetime(
                date_str_series.values,
                format=format_str,
                coerce=True,
            ).tz_localize(tz_str).tz_convert('UTC')

        if data_frequency == 'daily':
            parsed = roll_dts_to_midnight(parsed, env)
        return parsed

    def mask_pandas_args(self, kwargs):
        pandas_kwargs = {key: val for (key, val) in iteritems(kwargs)
                         if key in ALLOWED_READ_CSV_KWARGS}
        if 'usecols' in pandas_kwargs:
            usecols = pandas_kwargs['usecols']
            if usecols and self.date_column not in usecols:
                # Make a new list so we don't modify the user's,
                # and to ensure it is mutable.
                with_date = list(usecols)
                with_date.append(self.date_column)
                pandas_kwargs['usecols'] = with_date

        # No strings in the 'symbol' column should be interpreted as NaNs.
        pandas_kwargs.setdefault('keep_default_na', False)
        pandas_kwargs.setdefault('na_values', {'symbol': []})

        return pandas_kwargs

    def _lookup_unconflicted_symbol(self, symbol):
        """
        Attempt to find a unique asset whose symbol is the given string.

        If multiple assets have held the given symbol, return 0.

        If no asset has held the given symbol, return NaN.
        """
        try:
            uppered = symbol.upper()
        except AttributeError:
            # The mapping fails because symbol was a non-string.
            return numpy.nan

        try:
            return self.finder.lookup_symbol(uppered, as_of_date=None)
        except MultipleSymbolsFound:
            # Fill conflicted entries with zeros to mark that they need to
            # be resolved by date.
            return 0
        except SymbolNotFound:
            # Fill not-found entries with NaNs.
            return numpy.nan

    def load_df(self):
        df = self.fetch_data()

        if self.pre_func:
            df = self.pre_func(df)

        # Batch-convert the user-specified date column into timestamps.
        df['dt'] = self.parse_date_str_series(
            self.date_format,
            self.timezone,
            df[self.date_column],
            self.data_frequency,
            self.env,
        ).values

        # Ignore rows whose dates we couldn't parse.
        df = df[df['dt'].notnull()]

        if self.symbol is not None:
            df['sid'] = self.symbol
        elif self.finder:

            # df.sort returns a sorted copy, so the result must be
            # reassigned for the sort to take effect.
            df = df.sort(self.symbol_column)

            # Pop the 'sid' column off of the DataFrame, just in case the
            # user has assigned it, and throw a warning.
            try:
                df.pop('sid')
                warnings.warn(
                    "Assignment of the 'sid' column of a DataFrame is "
                    "not supported by Fetcher. The 'sid' column has been "
                    "overwritten.",
                    category=UserWarning,
                    stacklevel=2,
                )
            except KeyError:
                # There was no 'sid' column, so no warning is necessary.
                pass

            # Fill entries for any symbols that don't require a date to
            # uniquely identify.  Entries for which multiple securities
            # exist are replaced with zeroes, while entries for which no
            # asset exists are replaced with NaNs.
            unique_symbols = df[self.symbol_column].unique()
            sid_series = pd.Series(
                data=map(self._lookup_unconflicted_symbol, unique_symbols),
                index=unique_symbols,
                name='sid',
            )
            df = df.join(sid_series, on=self.symbol_column)

            # Fill any zero entries left in our sid column by doing a
            # lookup using both symbol and the row date.
            conflict_rows = df[df['sid'] == 0]
            for row_idx, row in conflict_rows.iterrows():
                try:
                    asset = self.finder.lookup_symbol(
                        row[self.symbol_column],
                        # Replacing tzinfo here is necessary because of the
                        # timezone metadata bug described below.
                        row['dt'].replace(tzinfo=pytz.utc),

                        # It's possible that no asset comes back here if our
                        # lookup date is from before any asset held the
                        # requested symbol.  Mark such cases as NaN so that
                        # they get dropped in the next step.
                    ) or numpy.nan
                except SymbolNotFound:
                    asset = numpy.nan

                # Assign the resolved asset to the cell.
                df.ix[row_idx, 'sid'] = asset

            # Filter out rows containing symbols that we failed to find.
            length_before_drop = len(df)
            df = df[df['sid'].notnull()]
            no_sid_count = length_before_drop - len(df)
            if no_sid_count:
                logger.warn(
                    "Dropped {0} rows from fetched csv.".format(no_sid_count),
                    extra={'syslog': True},
                )
        else:
            df['sid'] = df['symbol']

        # Dates are localized to UTC when they come out of
        # parse_date_str_series, but we need to re-localize them here because
        # of a bug that wasn't fixed until
        # https://github.com/pydata/pandas/pull/7092.
        # We should be able to remove the call to tz_localize once we're on
        # pandas 0.14.0.

        # We don't set 'dt' as the index until here because the symbol
        # parsing operations above depend on having a unique index for the
        # dataframe, and the 'dt' column can contain multiple dates for the
        # same entry.
        # Note: drop_duplicates returns a copy, so reassign the result.
        df = df.drop_duplicates(["sid", "dt"])
        df.set_index(['dt'], inplace=True)
        df = df.tz_localize('UTC')
        df.sort_index(inplace=True)

        cols_to_drop = [self.date_column]
        if self.symbol is None:
            cols_to_drop.append(self.symbol_column)
        df = df[df.columns.drop(cols_to_drop)]

        if self.post_func:
            df = self.post_func(df)

        return df

    def __iter__(self):
        asset_cache = {}
        for dt, series in self.df.iterrows():
            if dt < self.start_date:
                continue

            if dt > self.end_date:
                return

            event = FetcherEvent()
            # When the dt column is converted to be the dataframe's index,
            # the dt column is dropped.  So we need to manually copy dt
            # into the event.
            event.dt = dt
            for k, v in series.iteritems():
                # Convert numpy integer types to int.  This assumes we are
                # on a 64-bit platform that will not lose information by
                # casting.
                # TODO: this is only necessary on the amazon qexec
                # instances.  It would be good to figure out how to use the
                # numpy dtypes without this check and casting.
                if isinstance(v, numpy.integer):
                    v = int(v)

                setattr(event, k, v)

            # If it has start_date, then it's already an Asset object from
            # asset_for_symbol, and we don't have to transform it any
            # further.  Checking for start_date is faster than isinstance.
            if event.sid in asset_cache:
                event.sid = asset_cache[event.sid]
            elif hasattr(event.sid, 'start_date'):
                # Clone for user algo code, if we haven't already.
                asset_cache[event.sid] = event.sid
            elif self.finder and isinstance(event.sid, int):
                asset = self.finder.retrieve_asset(event.sid,
                                                   default_none=True)
                if asset:
                    # Clone for user algo code.
                    event.sid = asset_cache[asset] = asset
            elif self.mask:
                # When masking, drop all non-mappable values.
                continue
            elif self.symbol is None:
                # If the event's sid property is an int, we coerce it into
                # an Equity.
                event.sid = asset_cache[event.sid] = Equity(event.sid)

            event.type = DATASOURCE_TYPE.CUSTOM
            event.source_id = self.namestring
            yield event
```
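Reading the class end to end, the pipeline is: the abstract `fetch_data` supplies a raw DataFrame, `load_df` parses dates and resolves symbols to sids, and `__iter__` replays the rows as FetcherEvents between `start_date` and `end_date`. The following is a minimal sketch of how a concrete subclass plugs in; `InMemoryCSV`, its `namestring`, the sample data, and the wiring comments are hypothetical illustrations under the assumption that `env` behaves like a TradingEnvironment exposing an `asset_finder`, not zipline's shipped API.

```python
import pandas as pd
import pytz
from six import StringIO

SAMPLE = (
    "date,symbol,signal\n"
    "2014-01-02,AAPL,1\n"
    "2014-01-03,AAPL,2\n"
)


class InMemoryCSV(PandasCSV):
    """Hypothetical concrete source: parses a CSV string using the
    read_csv kwargs that mask_pandas_args whitelisted."""

    namestring = "in-memory-csv"  # consumed by get_hash and emitted events

    def fetch_data(self):
        return pd.read_csv(StringIO(SAMPLE), **self.pandas_kwargs)


# Typical wiring (commented out because it needs a live `env`):
#
# source = InMemoryCSV(
#     None, None, env,                       # pre_func, post_func, env
#     pd.Timestamp('2014-01-01', tz='UTC'),  # start_date
#     pd.Timestamp('2014-01-31', tz='UTC'),  # end_date
#     'date', None, pytz.utc,                # date_column, date_format, tz
#     None, False, 'symbol', 'daily',        # symbol, mask, column, freq
# )
# source.df = source.load_df()
# events = list(source)  # FetcherEvents within [start_date, end_date]
```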