| Metric | Value |
| --- | --- |
| Total Complexity | 87 |
| Total Lines | 797 |
| Duplicated Lines | 0 % |
Complex classes like zipline.assets.AssetFinder often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefixes or suffixes.
Once you have determined the fields and methods that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
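In AssetFinder below, the `_get_fuzzy_candidates*` and `_get_split_candidates*` helpers share naming prefixes and touch only the `equities` table, so they form one plausible seam for such a refactoring. The following is a minimal, hypothetical sketch of Extract Class along that seam; `SymbolCandidateQueries` and its method names are illustrative and not part of the zipline API:

```python
import sqlalchemy as sa


class SymbolCandidateQueries(object):
    """Candidate-selection queries extracted from AssetFinder.

    AssetFinder would build one of these from its ``equities`` table in
    ``__init__`` and delegate its ``_get_*_candidates*`` calls to it.
    """

    def __init__(self, equities_table):
        self.equities = equities_table

    def fuzzy_candidates(self, fuzzy_symbol):
        # Mirrors AssetFinder._get_fuzzy_candidates in the listing below.
        return sa.select(
            (self.equities.c.sid,)
        ).where(
            self.equities.c.fuzzy_symbol == fuzzy_symbol
        ).order_by(
            self.equities.c.start_date.desc(),
            self.equities.c.end_date.desc(),
        ).execute().fetchall()
```

The reported source of the class follows.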
# Copyright 2015 Quantopian, Inc.
# (intervening module-level code omitted from this listing)


class AssetFinder(object):
    """
    An AssetFinder is an interface to a database of Asset metadata written by
    an ``AssetDBWriter``.

    This class provides methods for looking up assets by unique integer id or
    by symbol. For historical reasons, we refer to these unique ids as 'sids'.

    Parameters
    ----------
    engine : str or SQLAlchemy.engine
        An engine with a connection to the asset database to use, or a string
        that can be parsed by SQLAlchemy as a URI.

    See Also
    --------
    :class:`zipline.assets.asset_writer.AssetDBWriter`
    """
    # Token used as a substitute for pickling objects that contain a
    # reference to an AssetFinder.
    PERSISTENT_TOKEN = "<AssetFinder>"

    def __init__(self, engine):

        self.engine = engine
        metadata = sa.MetaData(bind=engine)
        metadata.reflect(only=asset_db_table_names)
        for table_name in asset_db_table_names:
            setattr(self, table_name, metadata.tables[table_name])

        # Check the version info of the db for compatibility
        check_version_info(self.version_info, ASSET_DB_VERSION)

        # Cache for lookup of assets by sid, the objects in the asset lookup
        # may be shared with the results from equity and future lookup caches.
        #
        # The top level cache exists to minimize lookups on the asset type
        # routing.
        #
        # The caches are read through, i.e. accessing an asset through
        # retrieve_asset will populate the cache on first retrieval.
        self._caches = (self._asset_cache, self._asset_type_cache) = {}, {}

        # Populated on first call to `lifetimes`.
        self._asset_lifetimes = None

    def _reset_caches(self):
        """
        Reset our asset caches.

        You probably shouldn't call this method.
        """
        # This method exists as a workaround for the in-place mutating behavior
        # of `TradingAlgorithm._write_and_map_id_index_to_sids`. No one else
        # should be calling this.
        for cache in self._caches:
            cache.clear()

    def lookup_asset_types(self, sids):
        """
        Retrieve asset types for a list of sids.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[sid -> str or None]
            Asset types for the provided sids.
        """
        found = {}
        missing = set()

        for sid in sids:
            try:
                found[sid] = self._asset_type_cache[sid]
            except KeyError:
                missing.add(sid)

        if not missing:
            return found

        router_cols = self.asset_router.c

        for assets in self._group_into_chunks(missing):
            query = sa.select((router_cols.sid, router_cols.asset_type)).where(
                self.asset_router.c.sid.in_(map(int, assets))
            )
            for sid, type_ in query.execute().fetchall():
                missing.remove(sid)
                found[sid] = self._asset_type_cache[sid] = type_

        for sid in missing:
            found[sid] = self._asset_type_cache[sid] = None

        return found

    @staticmethod
    def _group_into_chunks(items, chunk_size=SQLITE_MAX_VARIABLE_NUMBER):
        items = list(items)
        return [items[x:x+chunk_size]
                for x in range(0, len(items), chunk_size)]

    def group_by_type(self, sids):
        """
        Group a list of sids by asset type.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[str or None -> list[int]]
            A dict mapping unique asset types to lists of sids drawn from sids.
            If we fail to look up an asset, we assign it a key of None.
        """
        return invert(self.lookup_asset_types(sids))

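    # Usage sketch (illustrative only; ``finder`` stands for an AssetFinder
    # over a populated asset database, and the sids are placeholders):
    #
    #     finder.lookup_asset_types([1, 2, 999999])
    #     # -> {1: 'equity', 2: 'future', 999999: None}
    #
    # group_by_type([1, 2, 999999]) groups the same sids keyed by 'equity',
    # 'future', or None for sids that could not be resolved.
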
    def retrieve_asset(self, sid, default_none=False):
        """
        Retrieve the Asset for a given sid.
        """
        return self.retrieve_all((sid,), default_none=default_none)[0]

    def retrieve_all(self, sids, default_none=False):
        """
        Retrieve all assets in `sids`.

        Parameters
        ----------
        sids : iterable of int
            Assets to retrieve.
        default_none : bool
            If True, return None for failed lookups.
            If False, raise `SidsNotFound`.

        Returns
        -------
        assets : list[Asset or None]
            A list of the same length as `sids` containing Assets (or Nones)
            corresponding to the requested sids.

        Raises
        ------
        SidsNotFound
            When a requested sid is not found and default_none=False.
        """
        hits, missing, failures = {}, set(), []
        for sid in sids:
            try:
                asset = self._asset_cache[sid]
                if not default_none and asset is None:
                    # Bail early if we've already cached that we don't know
                    # about an asset.
                    raise SidsNotFound(sids=[sid])
                hits[sid] = asset
            except KeyError:
                missing.add(sid)

        # All requests were cache hits. Return requested sids in order.
        if not missing:
            return [hits[sid] for sid in sids]

        update_hits = hits.update

        # Look up cache misses by type.
        type_to_assets = self.group_by_type(missing)

        # Handle failures
        failures = {failure: None for failure in type_to_assets.pop(None, ())}
        update_hits(failures)
        self._asset_cache.update(failures)

        if failures and not default_none:
            raise SidsNotFound(sids=list(failures))

        # We don't update the asset cache here because it should already be
        # updated by `self.retrieve_equities`.
        update_hits(self.retrieve_equities(type_to_assets.pop('equity', ())))
        update_hits(
            self.retrieve_futures_contracts(type_to_assets.pop('future', ()))
        )

        # We shouldn't know about any other asset types.
        if type_to_assets:
            raise AssertionError(
                "Found asset types: %s" % list(type_to_assets.keys())
            )

        return [hits[sid] for sid in sids]

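    # Usage sketch (illustrative only; ``finder`` and the sids are
    # placeholders):
    #
    #     assets = finder.retrieve_all([2, 1, 999999], default_none=True)
    #     # Results preserve the input order; unknown sids map to None.
    #
    #     finder.retrieve_asset(999999)  # raises SidsNotFound when the sid is
    #                                    # unknown and default_none is False
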
    def retrieve_equities(self, sids):
        """
        Retrieve Equity objects for a list of sids.

        Users generally shouldn't need to use this method (instead, they
        should prefer the more general/friendly `retrieve_all`), but it has
        a documented interface and tests because it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        equities : dict[int -> Equity]

        Raises
        ------
        EquitiesNotFound
            When any requested asset isn't found.
        """
        return self._retrieve_assets(sids, self.equities, Equity)

    def _retrieve_equity(self, sid):
        return self.retrieve_equities((sid,))[sid]

    def retrieve_futures_contracts(self, sids):
        """
        Retrieve Future objects for an iterable of sids.

        Users generally shouldn't need to use this method (instead, they
        should prefer the more general/friendly `retrieve_all`), but it has
        a documented interface and tests because it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        futures : dict[int -> Future]

        Raises
        ------
        FutureContractsNotFound
            When any requested asset isn't found.
        """
        return self._retrieve_assets(sids, self.futures_contracts, Future)

    @staticmethod
    def _select_assets_by_sid(asset_tbl, sids):
        return sa.select([asset_tbl]).where(
            asset_tbl.c.sid.in_(map(int, sids))
        )

    @staticmethod
    def _select_asset_by_symbol(asset_tbl, symbol):
        return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol)

    def _retrieve_assets(self, sids, asset_tbl, asset_type):
        """
        Internal function for loading assets from a table.

        This should be the only method of `AssetFinder` that writes Assets into
        self._asset_cache.

        Parameters
        ----------
        sids : iterable of int
            Asset ids to look up.
        asset_tbl : sqlalchemy.Table
            Table from which to query assets.
        asset_type : type
            Type of asset to be constructed.

        Returns
        -------
        assets : dict[int -> Asset]
            Dict mapping requested sids to the retrieved assets.
        """
        # Fastpath for empty request.
        if not sids:
            return {}

        cache = self._asset_cache
        hits = {}

        for assets in self._group_into_chunks(sids):
            # Load misses from the db.
            query = self._select_assets_by_sid(asset_tbl, assets)

            for row in imap(dict, query.execute().fetchall()):
                asset = asset_type(**_convert_asset_timestamp_fields(row))
                sid = asset.sid
                hits[sid] = cache[sid] = asset

        # If we get here, it means something in our code thought that a
        # particular sid was an equity/future and called this function with a
        # concrete type, but we couldn't actually resolve the asset. This is
        # an error in our code, not a user-input error.
        misses = tuple(set(sids) - viewkeys(hits))
        if misses:
            if asset_type == Equity:
                raise EquitiesNotFound(sids=misses)
            else:
                raise FutureContractsNotFound(sids=misses)
        return hits

    def _get_fuzzy_candidates(self, fuzzy_symbol):
        candidates = sa.select(
            (self.equities.c.sid,)
        ).where(self.equities.c.fuzzy_symbol == fuzzy_symbol).order_by(
            self.equities.c.start_date.desc(),
            self.equities.c.end_date.desc()
        ).execute().fetchall()
        return candidates

    def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
        candidates = sa.select(
            (self.equities.c.sid,)
        ).where(
            sa.and_(
                self.equities.c.fuzzy_symbol == fuzzy_symbol,
                self.equities.c.start_date <= ad_value,
                self.equities.c.end_date >= ad_value
            )
        ).order_by(
            self.equities.c.start_date.desc(),
            self.equities.c.end_date.desc(),
        ).execute().fetchall()
        return candidates

    def _get_split_candidates_in_range(self,
                                       company_symbol,
                                       share_class_symbol,
                                       ad_value):
        candidates = sa.select(
            (self.equities.c.sid,)
        ).where(
            sa.and_(
                self.equities.c.company_symbol == company_symbol,
                self.equities.c.share_class_symbol == share_class_symbol,
                self.equities.c.start_date <= ad_value,
                self.equities.c.end_date >= ad_value
            )
        ).order_by(
            self.equities.c.start_date.desc(),
            self.equities.c.end_date.desc(),
        ).execute().fetchall()
        return candidates

    def _get_split_candidates(self, company_symbol, share_class_symbol):
        candidates = sa.select(
            (self.equities.c.sid,)
        ).where(
            sa.and_(
                self.equities.c.company_symbol == company_symbol,
                self.equities.c.share_class_symbol == share_class_symbol
            )
        ).order_by(
            self.equities.c.start_date.desc(),
            self.equities.c.end_date.desc(),
        ).execute().fetchall()
        return candidates

    def _resolve_no_matching_candidates(self,
                                        company_symbol,
                                        share_class_symbol,
                                        ad_value):
        candidates = sa.select((self.equities.c.sid,)).where(
            sa.and_(
                self.equities.c.company_symbol == company_symbol,
                self.equities.c.share_class_symbol ==
                share_class_symbol,
                self.equities.c.start_date <= ad_value),
        ).order_by(
            self.equities.c.end_date.desc(),
        ).execute().fetchall()
        return candidates

    def _get_best_candidate(self, candidates):
        return self._retrieve_equity(candidates[0]['sid'])

    def _get_equities_from_candidates(self, candidates):
        sids = map(itemgetter('sid'), candidates)
        results = self.retrieve_equities(sids)
        return [results[sid] for sid in sids]

    def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
        """
        Return the Equity matching ``symbol`` in the database.

        If multiple Equities are found and as_of_date is not set,
        raises MultipleSymbolsFound.

        If no Equity was active at as_of_date, raises SymbolNotFound.
        """
        company_symbol, share_class_symbol, fuzzy_symbol = \
            split_delimited_symbol(symbol)
        if as_of_date:
            # Format inputs
            as_of_date = pd.Timestamp(as_of_date).normalize()
            ad_value = as_of_date.value

            if fuzzy:
                # Search for a single exact match on the fuzzy column
                candidates = self._get_fuzzy_candidates_in_range(fuzzy_symbol,
                                                                 ad_value)

                # If exactly one SID exists for fuzzy_symbol, return that sid
                if len(candidates) == 1:
                    return self._get_best_candidate(candidates)

            # Search for exact matches of the split-up company_symbol and
            # share_class_symbol
            candidates = self._get_split_candidates_in_range(
                company_symbol,
                share_class_symbol,
                ad_value
            )

            # If exactly one SID exists for symbol, return that symbol
            # If multiple SIDs exist for symbol, return latest start_date with
            # end_date as a tie-breaker
            if candidates:
                return self._get_best_candidate(candidates)

            # If no SID exists for symbol, return SID with the
            # highest-but-not-over end_date
            elif not candidates:
                candidates = self._resolve_no_matching_candidates(
                    company_symbol,
                    share_class_symbol,
                    ad_value
                )
                if candidates:
                    return self._get_best_candidate(candidates)

            raise SymbolNotFound(symbol=symbol)

        else:
            # If this is a fuzzy look-up, check if there is exactly one match
            # for the fuzzy symbol
            if fuzzy:
                candidates = self._get_fuzzy_candidates(fuzzy_symbol)
                if len(candidates) == 1:
                    return self._get_best_candidate(candidates)

            candidates = self._get_split_candidates(company_symbol,
                                                    share_class_symbol)
            if len(candidates) == 1:
                return self._get_best_candidate(candidates)
            elif not candidates:
                raise SymbolNotFound(symbol=symbol)
            else:
                raise MultipleSymbolsFound(
                    symbol=symbol,
                    options=self._get_equities_from_candidates(candidates)
                )

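    # Usage sketch (illustrative only; the symbols and date are placeholders):
    #
    #     finder.lookup_symbol('AAPL', as_of_date=pd.Timestamp('2015-01-05'))
    #     finder.lookup_symbol('BRK_A', as_of_date=None, fuzzy=True)
    #
    # With as_of_date=None and more than one candidate, MultipleSymbolsFound
    # is raised with the matching Equities attached as ``options``.
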
    def lookup_future_symbol(self, symbol):
        """ Return the Future object for a given symbol.

        Parameters
        ----------
        symbol : str
            The symbol of the desired contract.

        Returns
        -------
        Future
            A Future object.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.

        """

        data = self._select_asset_by_symbol(self.futures_contracts, symbol)\
            .execute().fetchone()

        # If no data found, raise an exception
        if not data:
            raise SymbolNotFound(symbol=symbol)
        return self.retrieve_asset(data['sid'])

    def lookup_future_chain(self, root_symbol, as_of_date):
        """ Return the futures chain for a given root symbol.

        Parameters
        ----------
        root_symbol : str
            Root symbol of the desired future.

        as_of_date : pd.Timestamp or pd.NaT
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date/expiration date is first
            after this date is the primary contract, etc. If NaT is
            given, the chain is unbounded, and all contracts for this
            root symbol are returned.

        Returns
        -------
        list
            A list of Future objects, the chain for the given
            parameters.

        Raises
        ------
        RootSymbolNotFound
            Raised when a future chain could not be found for the given
            root symbol.
        """

        fc_cols = self.futures_contracts.c

        if as_of_date is pd.NaT:
            # If the as_of_date is NaT, get all contracts for this
            # root symbol.
            sids = list(map(
                itemgetter('sid'),
                sa.select((fc_cols.sid,)).where(
                    (fc_cols.root_symbol == root_symbol),
                ).order_by(
                    fc_cols.notice_date.asc(),
                ).execute().fetchall()))
        else:
            as_of_date = as_of_date.value

            sids = list(map(
                itemgetter('sid'),
                sa.select((fc_cols.sid,)).where(
                    (fc_cols.root_symbol == root_symbol) &

                    # Filter to contracts that are still valid. If both
                    # exist, use the one that comes first in time (i.e.
                    # the lower value). If either notice_date or
                    # expiration_date is NaT, use the other. If both are
                    # NaT, the contract cannot be included in any chain.
                    sa.case(
                        [
                            (
                                fc_cols.notice_date == pd.NaT.value,
                                fc_cols.expiration_date >= as_of_date
                            ),
                            (
                                fc_cols.expiration_date == pd.NaT.value,
                                fc_cols.notice_date >= as_of_date
                            )
                        ],
                        else_=(
                            sa.func.min(
                                fc_cols.notice_date,
                                fc_cols.expiration_date
                            ) >= as_of_date
                        )
                    )
                ).order_by(
                    # Sort using expiration_date if valid. If it's NaT,
                    # use notice_date instead.
                    sa.case(
                        [
                            (
                                fc_cols.expiration_date == pd.NaT.value,
                                fc_cols.notice_date
                            )
                        ],
                        else_=fc_cols.expiration_date
                    ).asc()
                ).execute().fetchall()
            ))

        if not sids:
            # Check if root symbol exists.
            count = sa.select((sa.func.count(fc_cols.sid),)).where(
                fc_cols.root_symbol == root_symbol,
            ).scalar()
            if count == 0:
                raise RootSymbolNotFound(root_symbol=root_symbol)

        contracts = self.retrieve_futures_contracts(sids)
        return [contracts[sid] for sid in sids]

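    # Usage sketch (illustrative only; the root symbol and date are
    # placeholders):
    #
    #     finder.lookup_future_chain('CL', pd.Timestamp('2015-01-05'))
    #     # -> contracts ordered by expiration/notice date, nearest first
    #
    #     finder.lookup_future_chain('CL', pd.NaT)
    #     # -> all contracts for the root symbol, unbounded by date
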
    @property
    def sids(self):
        return tuple(map(
            itemgetter('sid'),
            sa.select((self.asset_router.c.sid,)).execute().fetchall(),
        ))

    def _lookup_generic_scalar(self,
                               asset_convertible,
                               as_of_date,
                               matches,
                               missing):
        """
        Convert asset_convertible to an asset.

        On success, append to matches.
        On failure, append to missing.
        """
        if isinstance(asset_convertible, Asset):
            matches.append(asset_convertible)

        elif isinstance(asset_convertible, Integral):
            try:
                result = self.retrieve_asset(int(asset_convertible))
            except SidsNotFound:
                missing.append(asset_convertible)
                return None
            matches.append(result)

        elif isinstance(asset_convertible, string_types):
            try:
                matches.append(
                    self.lookup_symbol(asset_convertible, as_of_date)
                )
            except SymbolNotFound:
                missing.append(asset_convertible)
                return None
        else:
            raise NotAssetConvertible(
                "Input was %s, not AssetConvertible."
                % asset_convertible
            )

    def lookup_generic(self,
                       asset_convertible_or_iterable,
                       as_of_date):
        """
        Convert an AssetConvertible or iterable of AssetConvertibles into
        a list of Asset objects.

        This method exists primarily as a convenience for implementing
        user-facing APIs that can handle multiple kinds of input. It should
        not be used for internal code where we already know the expected types
        of our inputs.

        Returns a pair of objects, the first of which is the result of the
        conversion, and the second of which is a list containing any values
        that couldn't be resolved.
        """
        matches = []
        missing = []

        # Interpret input as scalar.
        if isinstance(asset_convertible_or_iterable, AssetConvertible):
            self._lookup_generic_scalar(
                asset_convertible=asset_convertible_or_iterable,
                as_of_date=as_of_date,
                matches=matches,
                missing=missing,
            )
            try:
                return matches[0], missing
            except IndexError:
                if hasattr(asset_convertible_or_iterable, '__int__'):
                    raise SidsNotFound(sids=[asset_convertible_or_iterable])
                else:
                    raise SymbolNotFound(symbol=asset_convertible_or_iterable)

        # Interpret input as iterable.
        try:
            iterator = iter(asset_convertible_or_iterable)
        except TypeError:
            raise NotAssetConvertible(
                "Input was not a AssetConvertible "
                "or iterable of AssetConvertible."
            )

        for obj in iterator:
            self._lookup_generic_scalar(obj, as_of_date, matches, missing)
        return matches, missing

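    # Usage sketch (illustrative only): a scalar input returns a single Asset
    # plus the list of misses, while an iterable returns (matches, missing).
    #
    #     finder.lookup_generic(1, as_of_date=None)
    #     finder.lookup_generic(['AAPL', 999999], pd.Timestamp('2015-01-05'))
    #     # -> (list with the 'AAPL' Equity, [999999])
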
    def map_identifier_index_to_sids(self, index, as_of_date):
        """
        This method is for use in sanitizing a user's DataFrame or Panel
        inputs.

        Takes the given index of identifiers, checks their types, builds assets
        if necessary, and returns a list of the sids that correspond to the
        input index.

        Parameters
        ----------
        index : Iterable
            An iterable containing ints, strings, or Assets
        as_of_date : pandas.Timestamp
            A date to be used to resolve any dual-mapped symbols

        Returns
        -------
        List
            A list of integer sids corresponding to the input index
        """
        # This method assumes that the type of the objects in the index is
        # consistent and can, therefore, be taken from the first identifier
        first_identifier = index[0]

        # Ensure that input is AssetConvertible (integer, string, or Asset)
        if not isinstance(first_identifier, AssetConvertible):
            raise MapAssetIdentifierIndexError(obj=first_identifier)

        # If sids are provided, no mapping is necessary
        if isinstance(first_identifier, Integral):
            return index

        # Look up all Assets for mapping
        matches = []
        missing = []
        for identifier in index:
            self._lookup_generic_scalar(identifier, as_of_date,
                                        matches, missing)

        if missing:
            raise ValueError("Missing assets for identifiers: %s" % missing)

        # Return a list of the sids of the found assets
        return [asset.sid for asset in matches]

    def _compute_asset_lifetimes(self):
        """
        Compute and cache a recarray of asset lifetimes.
        """
        equities_cols = self.equities.c
        buf = np.array(
            tuple(
                sa.select((
                    equities_cols.sid,
                    equities_cols.start_date,
                    equities_cols.end_date,
                )).execute(),
            ), dtype='<f8',  # use doubles so we get NaNs
        )
        lifetimes = np.recarray(
            buf=buf,
            shape=(len(buf),),
            dtype=[
                ('sid', '<f8'),
                ('start', '<f8'),
                ('end', '<f8')
            ],
        )
        start = lifetimes.start
        end = lifetimes.end
        start[np.isnan(start)] = 0  # convert missing starts to 0
        end[np.isnan(end)] = np.iinfo(int).max  # convert missing end to INTMAX
        # Cast the results back down to int.
        return lifetimes.astype([
            ('sid', '<i8'),
            ('start', '<i8'),
            ('end', '<i8'),
        ])

    def lifetimes(self, dates, include_start_date):
        """
        Compute a DataFrame representing asset lifetimes for the specified date
        range.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            The dates for which to compute lifetimes.
        include_start_date : bool
            Whether or not to count the asset as alive on its start_date.

            This is useful in a backtesting context where `lifetimes` is being
            used to signify "do I have data for this asset as of the morning of
            this date?" For many financial metrics (e.g. daily close), data
            isn't available for an asset until the end of the asset's first
            day.

        Returns
        -------
        lifetimes : pd.DataFrame
            A frame of dtype bool with `dates` as index and an Int64Index of
            assets as columns. The value at `lifetimes.loc[date, asset]` will
            be True iff `asset` existed on `date`. If `include_start_date` is
            False, then lifetimes.loc[date, asset] will be False when date ==
            asset.start_date.

        See Also
        --------
        numpy.putmask
        zipline.pipeline.engine.SimplePipelineEngine._compute_root_mask
        """
        # This is a less than ideal place to do this, because if someone adds
        # assets to the finder after we've touched lifetimes we won't have
        # those new assets available. Mutability is not my favorite
        # programming feature.
        if self._asset_lifetimes is None:
            self._asset_lifetimes = self._compute_asset_lifetimes()
        lifetimes = self._asset_lifetimes

        raw_dates = dates.asi8[:, None]
        if include_start_date:
            mask = lifetimes.start <= raw_dates
        else:
            mask = lifetimes.start < raw_dates
        mask &= (raw_dates <= lifetimes.end)

        return pd.DataFrame(mask, index=dates, columns=lifetimes.sid)
878 | |||
1039 |
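Taken together, a minimal end-to-end sketch of the class in use (the database path, symbol, and dates below are placeholders, and the asset database is assumed to have been written beforehand by `zipline.assets.asset_writer.AssetDBWriter`):

```python
import pandas as pd
import sqlalchemy as sa

from zipline.assets import AssetFinder

# Connect to an asset database produced earlier by AssetDBWriter.
engine = sa.create_engine('sqlite:///assets.db')
finder = AssetFinder(engine)

# Resolve a ticker as of a particular session.
aapl = finder.lookup_symbol('AAPL', as_of_date=pd.Timestamp('2015-01-05'))

# Bulk retrieval by sid; the result order matches the input order.
assets = finder.retrieve_all([aapl.sid], default_none=True)

# Boolean lifetimes frame: rows are the requested dates, columns are all
# known equity sids.
sessions = pd.date_range('2015-01-02', '2015-01-09', freq='B')
alive = finder.lifetimes(sessions, include_start_date=False)
print(alive[aapl.sid])
```

Because the caches are read-through, repeated lookups after the first call are served from `_asset_cache` rather than from the database.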