Total Complexity | 88 |
Total Lines | 865 |
Duplicated Lines | 0 % |
Complex classes like zipline.assets.AssetFinder often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # Copyright 2015 Quantopian, Inc. |
||
79 | class AssetFinder(object): |
||
80 | """ |
||
81 | An AssetFinder is an interface to a database of Asset metadata written by |
||
82 | an ``AssetDBWriter``. |
||
83 | |||
84 | This class provides methods for looking up assets by unique integer id or |
||
85 | by symbol. For historical reasons, we refer to these unique ids as 'sids'. |
||
86 | |||
87 | Parameters |
||
88 | ---------- |
||
89 | engine : str or SQLAlchemy.engine |
||
90 | An engine with a connection to the asset database to use, or a string |
||
91 | that can be parsed by SQLAlchemy as a URI. |
||
92 | |||
93 | See Also |
||
94 | -------- |
||
95 | :class:`zipline.assets.asset_writer.AssetDBWriter` |
||
96 | """ |
||
97 | # Token used as a substitute for pickling objects that contain a |
||
98 | # reference to an AssetFinder. |
||
99 | PERSISTENT_TOKEN = "<AssetFinder>" |
||
100 | |||
def __init__(self, engine):
    """Bind to the asset database and reflect its tables onto `self`.

    Parameters
    ----------
    engine : str or SQLAlchemy.engine
        Engine (or URI string, per the class docstring) connected to the
        asset database.
    """
    self.engine = engine
    metadata = sa.MetaData(bind=engine)
    metadata.reflect(only=asset_db_table_names)
    # Expose each reflected table as an attribute (e.g. self.equities).
    for table_name in asset_db_table_names:
        setattr(self, table_name, metadata.tables[table_name])

    # Fail fast if the database schema version doesn't match this code.
    check_version_info(self.version_info, ASSET_DB_VERSION)

    # Read-through caches for sid lookups.  `_asset_cache` maps
    # sid -> Asset (objects may be shared with the equity/future lookup
    # results); `_asset_type_cache` maps sid -> asset-type string and
    # exists to minimize lookups on the asset type routing.  Both are
    # populated lazily on first retrieval via `retrieve_asset`.
    self._caches = (self._asset_cache, self._asset_type_cache) = {}, {}

    # Populated on first call to `lifetimes`.
    self._asset_lifetimes = None
def _reset_caches(self):
    """
    Reset our asset caches.

    You probably shouldn't call this method.
    """
    # Exists only as a workaround for the in-place mutating behavior of
    # `TradingAlgorithm._write_and_map_id_index_to_sids`.  No one else
    # should be calling this.
    for cache in self._caches:
        cache.clear()
def lookup_asset_types(self, sids):
    """
    Retrieve asset types for a list of sids.

    Parameters
    ----------
    sids : list[int]

    Returns
    -------
    types : dict[sid -> str or None]
        Asset types for the provided sids.
    """
    found = {}
    missing = set()

    # Serve what we can from the read-through type cache.
    for sid in sids:
        try:
            found[sid] = self._asset_type_cache[sid]
        except KeyError:
            missing.add(sid)

    if not missing:
        return found

    router_cols = self.asset_router.c

    # Query the router table in chunks to stay under SQLite's
    # bound-variable limit.
    for assets in self._group_into_chunks(missing):
        query = sa.select((router_cols.sid, router_cols.asset_type)).where(
            router_cols.sid.in_(map(int, assets))
        )
        for sid, type_ in query.execute().fetchall():
            missing.remove(sid)
            found[sid] = self._asset_type_cache[sid] = type_

    # Anything still unresolved is unknown; cache the negative result so
    # we don't query for it again.
    for sid in missing:
        found[sid] = self._asset_type_cache[sid] = None

    return found
@staticmethod
def _group_into_chunks(items, chunk_size=SQLITE_MAX_VARIABLE_NUMBER):
    """Split `items` into lists of at most `chunk_size` elements each."""
    as_list = list(items)
    return [as_list[start:start + chunk_size]
            for start in range(0, len(as_list), chunk_size)]
def group_by_type(self, sids):
    """
    Group a list of sids by asset type.

    Parameters
    ----------
    sids : list[int]

    Returns
    -------
    types : dict[str or None -> list[int]]
        A dict mapping unique asset types to lists of sids drawn from sids.
        If we fail to look up an asset, we assign it a key of None.
    """
    # `invert` flips the sid -> type mapping into type -> [sids].
    return invert(self.lookup_asset_types(sids))
def retrieve_asset(self, sid, default_none=False):
    """
    Retrieve the Asset for a given sid.

    Thin wrapper over `retrieve_all` for the single-sid case.
    """
    return self.retrieve_all((sid,), default_none=default_none)[0]
def retrieve_all(self, sids, default_none=False):
    """
    Retrieve all assets in `sids`.

    Parameters
    ----------
    sids : iterable of int
        Assets to retrieve.
    default_none : bool
        If True, return None for failed lookups.
        If False, raise `SidsNotFound`.

    Returns
    -------
    assets : list[Asset or None]
        A list of the same length as `sids` containing Assets (or Nones)
        corresponding to the requested sids.

    Raises
    ------
    SidsNotFound
        When a requested sid is not found and default_none=False.
    """
    hits, missing = {}, set()
    for sid in sids:
        try:
            asset = self._asset_cache[sid]
            if not default_none and asset is None:
                # Bail early if we've already cached that we don't know
                # about an asset.
                raise SidsNotFound(sids=[sid])
            hits[sid] = asset
        except KeyError:
            missing.add(sid)

    # All requests were cache hits. Return requested sids in order.
    if not missing:
        return [hits[sid] for sid in sids]

    update_hits = hits.update

    # Look up cache misses by type.
    type_to_assets = self.group_by_type(missing)

    # Handle failures: sids whose asset type could not be resolved.
    # Cache the Nones so repeated lookups of unknown sids stay cheap.
    failures = {failure: None for failure in type_to_assets.pop(None, ())}
    update_hits(failures)
    self._asset_cache.update(failures)

    if failures and not default_none:
        raise SidsNotFound(sids=list(failures))

    # We don't update the asset cache here because it should already be
    # updated by `self.retrieve_equities`.
    update_hits(self.retrieve_equities(type_to_assets.pop('equity', ())))
    update_hits(
        self.retrieve_futures_contracts(type_to_assets.pop('future', ()))
    )

    # We shouldn't know about any other asset types.
    if type_to_assets:
        raise AssertionError(
            "Found asset types: %s" % list(type_to_assets.keys())
        )

    return [hits[sid] for sid in sids]
def retrieve_equities(self, sids):
    """
    Retrieve Equity objects for a list of sids.

    Users generally shouldn't need to call this method (instead, they
    should prefer the more general/friendly `retrieve_assets`), but it
    has a documented interface and tests because it's used upstream.

    Parameters
    ----------
    sids : iterable[int]

    Returns
    -------
    equities : dict[int -> Equity]

    Raises
    ------
    EquitiesNotFound
        When any requested asset isn't found.
    """
    return self._retrieve_assets(sids, self.equities, Equity)
def _retrieve_equity(self, sid):
    """Retrieve a single Equity by sid; raises EquitiesNotFound on miss."""
    return self.retrieve_equities((sid,))[sid]
def retrieve_futures_contracts(self, sids):
    """
    Retrieve Future objects for an iterable of sids.

    Users generally shouldn't need to call this method (instead, they
    should prefer the more general/friendly `retrieve_assets`), but it
    has a documented interface and tests because it's used upstream.

    Parameters
    ----------
    sids : iterable[int]

    Returns
    -------
    futures : dict[int -> Future]

    Raises
    ------
    FutureContractsNotFound
        When any requested asset isn't found.
    """
    # Docstring corrected: this returns Futures (not Equities) and
    # `_retrieve_assets` raises FutureContractsNotFound for this type.
    return self._retrieve_assets(sids, self.futures_contracts, Future)
@staticmethod
def _select_assets_by_sid(asset_tbl, sids):
    """Build a SELECT over `asset_tbl` restricted to the given sids."""
    sid_filter = asset_tbl.c.sid.in_(map(int, sids))
    return sa.select([asset_tbl]).where(sid_filter)
@staticmethod
def _select_asset_by_symbol(asset_tbl, symbol):
    """Build a SELECT over `asset_tbl` matching an exact symbol."""
    return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol)
def _retrieve_assets(self, sids, asset_tbl, asset_type):
    """
    Internal function for loading assets from a table.

    This should be the only method of `AssetFinder` that writes Assets
    into self._asset_cache.

    Parameters
    ----------
    sids : iterable of int
        Asset ids to look up.
    asset_tbl : sqlalchemy.Table
        Table from which to query assets.
    asset_type : type
        Type of asset to be constructed.

    Returns
    -------
    assets : dict[int -> Asset]
        Dict mapping requested sids to the retrieved assets.
    """
    # Fastpath for empty request.
    if not sids:
        return {}

    cache = self._asset_cache
    hits = {}

    # Query in chunks to stay under SQLite's bound-variable limit.
    for assets in self._group_into_chunks(sids):
        query = self._select_assets_by_sid(asset_tbl, assets)

        for row in (dict(r) for r in query.execute().fetchall()):
            asset = asset_type(**_convert_asset_timestamp_fields(row))
            sid = asset.sid
            hits[sid] = cache[sid] = asset

    # If we get here with misses, something in our code thought that a
    # particular sid was an equity/future and called this function with a
    # concrete type, but we couldn't actually resolve the asset. This is
    # an error in our code, not a user-input error.
    misses = tuple(set(sids).difference(hits))
    if misses:
        if asset_type == Equity:
            raise EquitiesNotFound(sids=misses)
        else:
            raise FutureContractsNotFound(sids=misses)
    return hits
def _get_fuzzy_candidates(self, fuzzy_symbol):
    """Equity sid rows whose fuzzy_symbol matches exactly, ordered by
    latest start_date (end_date as tie-breaker)."""
    query = sa.select(
        (self.equities.c.sid,)
    ).where(
        self.equities.c.fuzzy_symbol == fuzzy_symbol
    ).order_by(
        self.equities.c.start_date.desc(),
        self.equities.c.end_date.desc()
    )
    return query.execute().fetchall()
def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
    """Equity sid rows whose fuzzy_symbol matches exactly and that were
    alive at `ad_value`, ordered by latest start_date (end_date as
    tie-breaker)."""
    query = sa.select(
        (self.equities.c.sid,)
    ).where(
        sa.and_(
            self.equities.c.fuzzy_symbol == fuzzy_symbol,
            self.equities.c.start_date <= ad_value,
            self.equities.c.end_date >= ad_value
        )
    ).order_by(
        self.equities.c.start_date.desc(),
        self.equities.c.end_date.desc(),
    )
    return query.execute().fetchall()
def _get_split_candidates_in_range(self,
                                   company_symbol,
                                   share_class_symbol,
                                   ad_value):
    """Equity sid rows matching the split-up symbol that were alive at
    `ad_value`, ordered by latest start_date (end_date as tie-breaker)."""
    query = sa.select(
        (self.equities.c.sid,)
    ).where(
        sa.and_(
            self.equities.c.company_symbol == company_symbol,
            self.equities.c.share_class_symbol == share_class_symbol,
            self.equities.c.start_date <= ad_value,
            self.equities.c.end_date >= ad_value
        )
    ).order_by(
        self.equities.c.start_date.desc(),
        self.equities.c.end_date.desc(),
    )
    return query.execute().fetchall()
def _get_split_candidates(self, company_symbol, share_class_symbol):
    """Equity sid rows matching the split-up symbol regardless of date,
    ordered by latest start_date (end_date as tie-breaker)."""
    query = sa.select(
        (self.equities.c.sid,)
    ).where(
        sa.and_(
            self.equities.c.company_symbol == company_symbol,
            self.equities.c.share_class_symbol == share_class_symbol
        )
    ).order_by(
        self.equities.c.start_date.desc(),
        self.equities.c.end_date.desc(),
    )
    return query.execute().fetchall()
def _resolve_no_matching_candidates(self,
                                    company_symbol,
                                    share_class_symbol,
                                    ad_value):
    """Fallback lookup when no equity was alive at the as-of date: rows
    matching the split-up symbol that started on or before `ad_value`,
    ordered by most recent end_date."""
    query = sa.select((self.equities.c.sid,)).where(
        sa.and_(
            self.equities.c.company_symbol == company_symbol,
            self.equities.c.share_class_symbol ==
            share_class_symbol,
            self.equities.c.start_date <= ad_value),
    ).order_by(
        self.equities.c.end_date.desc(),
    )
    return query.execute().fetchall()
def _get_best_candidate(self, candidates):
    """Return the Equity for the first (best-ranked) candidate row."""
    best = candidates[0]
    return self._retrieve_equity(best['sid'])
def _get_equities_from_candidates(self, candidates):
    """Resolve candidate sid rows to Equity objects, preserving order.

    Bug fix: `map` returns a one-shot iterator on Python 3; previously it
    was consumed by `retrieve_equities` and then iterated again (already
    exhausted, yielding an empty list) when building the result.  The
    sids are now materialized into a list once and reused.
    """
    sids = list(map(itemgetter('sid'), candidates))
    results = self.retrieve_equities(sids)
    return [results[sid] for sid in sids]
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
    """
    Return matching Equity of name symbol in database.

    If multiple Equities are found and as_of_date is not set,
    raises MultipleSymbolsFound.

    If no Equity was active at as_of_date raises SymbolNotFound.
    """
    company_symbol, share_class_symbol, fuzzy_symbol = (
        split_delimited_symbol(symbol)
    )
    if as_of_date:
        # Normalize so comparisons are against whole-day boundaries.
        as_of_date = pd.Timestamp(as_of_date).normalize()
        ad_value = as_of_date.value

        if fuzzy:
            # Search for a single exact match on the fuzzy column.
            candidates = self._get_fuzzy_candidates_in_range(
                fuzzy_symbol, ad_value
            )
            # If exactly one SID exists for fuzzy_symbol, return it.
            if len(candidates) == 1:
                return self._get_best_candidate(candidates)

        # Search for exact matches of the split-up company_symbol and
        # share_class_symbol.
        candidates = self._get_split_candidates_in_range(
            company_symbol,
            share_class_symbol,
            ad_value,
        )

        # One or more SIDs were alive at as_of_date.  The candidates are
        # already ordered by latest start_date with end_date as a
        # tie-breaker, so the first row is the best match.
        if candidates:
            return self._get_best_candidate(candidates)

        # No SID was alive at as_of_date: fall back to the SID with the
        # highest-but-not-over end_date.
        candidates = self._resolve_no_matching_candidates(
            company_symbol,
            share_class_symbol,
            ad_value,
        )
        if candidates:
            return self._get_best_candidate(candidates)

        raise SymbolNotFound(symbol=symbol)

    else:
        # If this is a fuzzy look-up, check if there is exactly one match
        # for the fuzzy symbol.
        if fuzzy:
            candidates = self._get_fuzzy_candidates(fuzzy_symbol)
            if len(candidates) == 1:
                return self._get_best_candidate(candidates)

        candidates = self._get_split_candidates(
            company_symbol, share_class_symbol
        )
        if len(candidates) == 1:
            return self._get_best_candidate(candidates)
        if not candidates:
            raise SymbolNotFound(symbol=symbol)
        # Without an as_of_date we can't disambiguate multiple matches.
        raise MultipleSymbolsFound(
            symbol=symbol,
            options=self._get_equities_from_candidates(candidates)
        )
def lookup_future_symbol(self, symbol):
    """ Return the Future object for a given symbol.

    Parameters
    ----------
    symbol : str
        The symbol of the desired contract.

    Returns
    -------
    Future
        A Future object.

    Raises
    ------
    SymbolNotFound
        Raised when no contract named 'symbol' is found.

    """
    row = self._select_asset_by_symbol(
        self.futures_contracts, symbol
    ).execute().fetchone()

    # No such contract exists.
    if not row:
        raise SymbolNotFound(symbol=symbol)
    return self.retrieve_asset(row['sid'])
def lookup_future_chain(self, root_symbol, as_of_date):
    """ Return the futures chain for a given root symbol.

    Parameters
    ----------
    root_symbol : str
        Root symbol of the desired future.

    as_of_date : pd.Timestamp or pd.NaT
        Date at which the chain determination is rooted. I.e. the
        existing contract whose notice date/expiration date is first
        after this date is the primary contract, etc. If NaT is
        given, the chain is unbounded, and all contracts for this
        root symbol are returned.

    Returns
    -------
    list
        A list of Future objects, the chain for the given
        parameters.

    Raises
    ------
    RootSymbolNotFound
        Raised when a future chain could not be found for the given
        root symbol.
    """
    fc_cols = self.futures_contracts.c

    if as_of_date is pd.NaT:
        # Unbounded chain: every contract for this root symbol, ordered
        # by notice date.
        rows = sa.select((fc_cols.sid,)).where(
            (fc_cols.root_symbol == root_symbol),
        ).order_by(
            fc_cols.notice_date.asc(),
        ).execute().fetchall()
        sids = list(map(itemgetter('sid'), rows))
    else:
        as_of_date = as_of_date.value

        # Filter to contracts that are still valid. If both notice_date
        # and expiration_date exist, use the one that comes first in time
        # (i.e. the lower value). If either is NaT, use the other. If
        # both are NaT, the contract cannot be included in any chain.
        still_valid = sa.case(
            [
                (
                    fc_cols.notice_date == pd.NaT.value,
                    fc_cols.expiration_date >= as_of_date
                ),
                (
                    fc_cols.expiration_date == pd.NaT.value,
                    fc_cols.notice_date >= as_of_date
                )
            ],
            else_=(
                sa.func.min(
                    fc_cols.notice_date,
                    fc_cols.expiration_date
                ) >= as_of_date
            )
        )

        # Sort using expiration_date if valid.  If it's NaT, use
        # notice_date instead.
        sort_key = sa.case(
            [
                (
                    fc_cols.expiration_date == pd.NaT.value,
                    fc_cols.notice_date
                )
            ],
            else_=fc_cols.expiration_date
        ).asc()

        rows = sa.select((fc_cols.sid,)).where(
            (fc_cols.root_symbol == root_symbol) & still_valid
        ).order_by(sort_key).execute().fetchall()
        sids = list(map(itemgetter('sid'), rows))

    if not sids:
        # Distinguish "no live contracts" from "unknown root symbol".
        count = sa.select((sa.func.count(fc_cols.sid),)).where(
            fc_cols.root_symbol == root_symbol,
        ).scalar()
        if count == 0:
            raise RootSymbolNotFound(root_symbol=root_symbol)

    contracts = self.retrieve_futures_contracts(sids)
    return [contracts[sid] for sid in sids]
def lookup_expired_futures(self, start, end):
    """Return sids of futures contracts whose effective end date falls
    within [start, end].

    The effective end of a contract is the earlier of notice_date and
    expiration_date; if one of the two is NaT the other is used, and if
    both are NaT the contract is excluded.

    Parameters
    ----------
    start : pd.Timestamp
    end : pd.Timestamp

    Returns
    -------
    list[int]
        Sids ordered by expiration_date (notice_date when expiration
        is NaT).
    """
    start = start.value
    end = end.value

    fc_cols = self.futures_contracts.c

    # Effective end on or after `start`.
    at_or_after_start = sa.case(
        [
            (
                fc_cols.notice_date == pd.NaT.value,
                fc_cols.expiration_date >= start
            ),
            (
                fc_cols.expiration_date == pd.NaT.value,
                fc_cols.notice_date >= start
            )
        ],
        else_=(
            sa.func.min(
                fc_cols.notice_date,
                fc_cols.expiration_date
            ) >= start
        )
    )

    # Effective end on or before `end`.
    at_or_before_end = sa.case(
        [
            (
                fc_cols.notice_date == pd.NaT.value,
                fc_cols.expiration_date <= end
            ),
            (
                fc_cols.expiration_date == pd.NaT.value,
                fc_cols.notice_date <= end
            )
        ],
        else_=(
            sa.func.min(
                fc_cols.notice_date,
                fc_cols.expiration_date
            ) <= end
        )
    )

    # Sort using expiration_date if valid; if it's NaT, use notice_date
    # instead.
    sort_key = sa.case(
        [
            (
                fc_cols.expiration_date == pd.NaT.value,
                fc_cols.notice_date
            )
        ],
        else_=fc_cols.expiration_date
    ).asc()

    rows = sa.select((fc_cols.sid,)).where(
        at_or_after_start & at_or_before_end
    ).order_by(sort_key).execute().fetchall()

    return list(map(itemgetter('sid'), rows))
@property
def sids(self):
    """All sids present in the asset router table, as a tuple."""
    rows = sa.select((self.asset_router.c.sid,)).execute().fetchall()
    return tuple(row['sid'] for row in rows)
def _lookup_generic_scalar(self,
                           asset_convertible,
                           as_of_date,
                           matches,
                           missing):
    """
    Convert asset_convertible to an asset.

    On success, append to matches.
    On failure, append to missing.
    """
    # Already an Asset: nothing to resolve.
    if isinstance(asset_convertible, Asset):
        matches.append(asset_convertible)
        return

    # Integral: treat as a sid.
    if isinstance(asset_convertible, Integral):
        try:
            result = self.retrieve_asset(int(asset_convertible))
        except SidsNotFound:
            missing.append(asset_convertible)
            return None
        matches.append(result)
        return

    # String: treat as a symbol, resolved as of `as_of_date`.
    if isinstance(asset_convertible, string_types):
        try:
            matches.append(
                self.lookup_symbol(asset_convertible, as_of_date)
            )
        except SymbolNotFound:
            missing.append(asset_convertible)
            return None
        return

    raise NotAssetConvertible(
        "Input was %s, not AssetConvertible."
        % asset_convertible
    )
def lookup_generic(self,
                   asset_convertible_or_iterable,
                   as_of_date):
    """
    Convert a AssetConvertible or iterable of AssetConvertibles into
    a list of Asset objects.

    This method exists primarily as a convenience for implementing
    user-facing APIs that can handle multiple kinds of input. It should
    not be used for internal code where we already know the expected types
    of our inputs.

    Returns a pair of objects, the first of which is the result of the
    conversion, and the second of which is a list containing any values
    that couldn't be resolved.
    """
    matches = []
    missing = []

    # Scalar case: resolve it and unwrap the single result.
    if isinstance(asset_convertible_or_iterable, AssetConvertible):
        self._lookup_generic_scalar(
            asset_convertible=asset_convertible_or_iterable,
            as_of_date=as_of_date,
            matches=matches,
            missing=missing,
        )
        try:
            return matches[0], missing
        except IndexError:
            # Nothing matched: raise a type-appropriate lookup error.
            if hasattr(asset_convertible_or_iterable, '__int__'):
                raise SidsNotFound(sids=[asset_convertible_or_iterable])
            else:
                raise SymbolNotFound(symbol=asset_convertible_or_iterable)

    # Iterable case.
    try:
        iterator = iter(asset_convertible_or_iterable)
    except TypeError:
        raise NotAssetConvertible(
            "Input was not a AssetConvertible "
            "or iterable of AssetConvertible."
        )

    for obj in iterator:
        self._lookup_generic_scalar(obj, as_of_date, matches, missing)
    return matches, missing
def map_identifier_index_to_sids(self, index, as_of_date):
    """
    This method is for use in sanitizing a user's DataFrame or Panel
    inputs.

    Takes the given index of identifiers, checks their types, builds
    assets if necessary, and returns a list of the sids that correspond
    to the input index.

    Parameters
    ----------
    index : Iterable
        An iterable containing ints, strings, or Assets
    as_of_date : pandas.Timestamp
        A date to be used to resolve any dual-mapped symbols

    Returns
    -------
    List
        A list of integer sids corresponding to the input index
    """
    # We assume the index is homogeneous in type, so the first
    # identifier is representative of the whole.
    first_identifier = index[0]

    # Ensure that input is AssetConvertible (integer, string, or Asset).
    if not isinstance(first_identifier, AssetConvertible):
        raise MapAssetIdentifierIndexError(obj=first_identifier)

    # Integer inputs are already sids; no mapping is necessary.
    if isinstance(first_identifier, Integral):
        return index

    # Look up all Assets for mapping.
    matches = []
    missing = []
    for identifier in index:
        self._lookup_generic_scalar(
            identifier, as_of_date, matches, missing
        )

    if missing:
        raise ValueError("Missing assets for identifiers: %s" % missing)

    # Return a list of the sids of the found assets.
    return [asset.sid for asset in matches]
def _compute_asset_lifetimes(self):
    """
    Compute and cache a recarray of asset lifetimes.
    """
    equities_cols = self.equities.c
    # Use doubles so missing dates come through as NaNs.
    buf = np.array(
        tuple(
            sa.select((
                equities_cols.sid,
                equities_cols.start_date,
                equities_cols.end_date,
            )).execute(),
        ),
        dtype='<f8',
    )
    lifetimes = np.recarray(
        buf=buf,
        shape=(len(buf),),
        dtype=[
            ('sid', '<f8'),
            ('start', '<f8'),
            ('end', '<f8')
        ],
    )
    start = lifetimes.start
    end = lifetimes.end
    start[np.isnan(start)] = 0  # convert missing starts to 0
    end[np.isnan(end)] = np.iinfo(int).max  # convert missing end to INTMAX
    # Cast the results back down to int.
    return lifetimes.astype([
        ('sid', '<i8'),
        ('start', '<i8'),
        ('end', '<i8'),
    ])
def lifetimes(self, dates, include_start_date):
    """
    Compute a DataFrame representing asset lifetimes for the specified
    date range.

    Parameters
    ----------
    dates : pd.DatetimeIndex
        The dates for which to compute lifetimes.
    include_start_date : bool
        Whether or not to count the asset as alive on its start_date.

        This is useful in a backtesting context where `lifetimes` is
        being used to signify "do I have data for this asset as of the
        morning of this date?" For many financial metrics, (e.g. daily
        close), data isn't available for an asset until the end of the
        asset's first day.

    Returns
    -------
    lifetimes : pd.DataFrame
        A frame of dtype bool with `dates` as index and an Int64Index of
        assets as columns. The value at `lifetimes.loc[date, asset]`
        will be True iff `asset` existed on `date`. If
        `include_start_date` is False, then lifetimes.loc[date, asset]
        will be false when date == asset.start_date.

    See Also
    --------
    numpy.putmask
    zipline.pipeline.engine.SimplePipelineEngine._compute_root_mask
    """
    # Computed lazily and cached; note that assets added to the finder
    # after this first touch won't appear in the cached table.
    if self._asset_lifetimes is None:
        self._asset_lifetimes = self._compute_asset_lifetimes()
    lifetimes = self._asset_lifetimes

    # Broadcast the dates (as a column vector of int64 nanos) against
    # the per-asset start/end rows, yielding an (n_dates, n_assets)
    # boolean mask.
    raw_dates = dates.asi8[:, None]
    if include_start_date:
        mask = lifetimes.start <= raw_dates
    else:
        mask = lifetimes.start < raw_dates
    mask &= (raw_dates <= lifetimes.end)

    return pd.DataFrame(mask, index=dates, columns=lifetimes.sid)
1106 |