Completed
Pull Request — master (#846)
by Warren
05:15 queued 03:42
created

lookup_expired_futures()   A

Complexity

Conditions 1

Size

Total Lines 19

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 19
rs 9.4286
1
# Copyright 2015 Quantopian, Inc.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
from abc import ABCMeta
16
from numbers import Integral
17
from operator import itemgetter
18
19
from logbook import Logger
20
import numpy as np
21
import pandas as pd
22
from pandas import isnull
23
from six import with_metaclass, string_types, viewkeys
24
from six.moves import map as imap, range
25
import sqlalchemy as sa
26
27
from zipline.errors import (
28
    EquitiesNotFound,
29
    FutureContractsNotFound,
30
    MapAssetIdentifierIndexError,
31
    MultipleSymbolsFound,
32
    RootSymbolNotFound,
33
    SidsNotFound,
34
    SymbolNotFound,
35
)
36
from zipline.assets import (
37
    Asset, Equity, Future,
38
)
39
from zipline.assets.asset_writer import (
40
    split_delimited_symbol,
41
    check_version_info,
42
    ASSET_DB_VERSION,
43
    asset_db_table_names,
44
    SQLITE_MAX_VARIABLE_NUMBER
45
)
46
from zipline.utils.control_flow import invert
47
48
# Module-level logbook logger, created with the name 'assets.py'.
log = Logger('assets.py')
49
50
# A set of fields that need to be converted to strings before building an
51
# Asset to avoid unicode fields
52
_asset_str_fields = frozenset({
53
    'symbol',
54
    'asset_name',
55
    'exchange',
56
})
57
58
# A set of fields that need to be converted to timestamps in UTC
59
_asset_timestamp_fields = frozenset({
60
    'start_date',
61
    'end_date',
62
    'first_traded',
63
    'notice_date',
64
    'expiration_date',
65
    'auto_close_date',
66
})
67
68
69
def _convert_asset_timestamp_fields(dict_):
    """
    Convert the date-valued Asset init args in ``dict_`` to pd.Timestamps.

    Only keys listed in ``_asset_timestamp_fields`` are touched.  Each such
    value is converted to a UTC ``pd.Timestamp``; values that convert to a
    null timestamp are replaced with None.  ``dict_`` is modified in place
    and is also returned.
    """
    date_keys = _asset_timestamp_fields & viewkeys(dict_)
    for name in date_keys:
        stamp = pd.Timestamp(dict_[name], tz='UTC')
        if isnull(stamp):
            dict_[name] = None
        else:
            dict_[name] = stamp
    return dict_
77
78
79
class AssetFinder(object):
80
    """
81
    An AssetFinder is an interface to a database of Asset metadata written by
82
    an ``AssetDBWriter``.
83
84
    This class provides methods for looking up assets by unique integer id or
85
    by symbol.  For historical reasons, we refer to these unique ids as 'sids'.
86
87
    Parameters
88
    ----------
89
    engine : str or SQLAlchemy.engine
90
        An engine with a connection to the asset database to use, or a string
91
        that can be parsed by SQLAlchemy as a URI.
92
93
    See Also
94
    --------
95
    :class:`zipline.assets.asset_writer.AssetDBWriter`
96
    """
97
    # Token used as a substitute for pickling objects that contain a
98
    # reference to an AssetFinder.
99
    PERSISTENT_TOKEN = "<AssetFinder>"
100
101
    def __init__(self, engine):
        self.engine = engine

        # Reflect only the asset db tables and expose each of them as an
        # attribute named after the table.
        metadata = sa.MetaData(bind=engine)
        metadata.reflect(only=asset_db_table_names)
        tables = metadata.tables
        for name in asset_db_table_names:
            setattr(self, name, tables[name])

        # Check the version info of the db for compatibility
        check_version_info(self.version_info, ASSET_DB_VERSION)

        # `_asset_cache` maps sid -> asset; its objects may be shared with
        # the results from the equity and future lookups.
        #
        # `_asset_type_cache` maps sid -> asset type and exists to minimize
        # lookups on the asset type routing.
        #
        # Both caches are read through, i.e. accessing an asset through
        # retrieve_asset will populate the cache on first retrieval.
        self._asset_cache = {}
        self._asset_type_cache = {}
        self._caches = (self._asset_cache, self._asset_type_cache)

        # Populated on first call to `lifetimes`.
        self._asset_lifetimes = None
124
125
    def _reset_caches(self):
        """
        Empty every cache listed in ``self._caches``.

        You probably shouldn't call this method.
        """
        # The only intended caller is the workaround for the in-place
        # mutating behavior of
        # `TradingAlgorithm._write_and_map_id_index_to_sids`.
        for cached in self._caches:
            cached.clear()
136
137
    def lookup_asset_types(self, sids):
        """
        Retrieve asset types for a list of sids.

        Sids already present in the asset type cache are answered from the
        cache.  The remaining sids are looked up in the asset router table in
        chunks, and every result (including failed lookups) is cached.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[sid -> str or None]
            Asset types for the provided sids.  Sids that are not present in
            the asset router table map to None.
        """
        type_cache = self._asset_type_cache
        found = {}
        missing = set()

        for sid in sids:
            try:
                found[sid] = type_cache[sid]
            except KeyError:
                missing.add(sid)

        if not missing:
            return found

        router_cols = self.asset_router.c

        for assets in self._group_into_chunks(missing):
            query = sa.select((router_cols.sid, router_cols.asset_type)).where(
                router_cols.sid.in_([int(sid) for sid in assets])
            )
            for sid, type_ in query.execute().fetchall():
                missing.remove(sid)
                found[sid] = type_cache[sid] = type_

        # Only once every chunk has been queried do we know which sids are
        # really absent; marking them inside the loop would temporarily cache
        # sids of later chunks as unknown.
        for sid in missing:
            found[sid] = type_cache[sid] = None

        return found
176
177
    @staticmethod
    def _group_into_chunks(items, chunk_size=SQLITE_MAX_VARIABLE_NUMBER):
        """
        Split ``items`` into consecutive lists of at most ``chunk_size``
        elements each and return the list of those lists.
        """
        pool = list(items)
        chunks = []
        for start in range(0, len(pool), chunk_size):
            chunks.append(pool[start:start + chunk_size])
        return chunks
182
183
    def group_by_type(self, sids):
        """
        Group a list of sids by asset type.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[str or None -> list[int]]
            A dict mapping unique asset types to lists of sids drawn from sids.
            Sids whose lookup failed are grouped under the key None.
        """
        sid_to_type = self.lookup_asset_types(sids)
        return invert(sid_to_type)
198
199
    def retrieve_asset(self, sid, default_none=False):
        """
        Retrieve the Asset for a given sid.

        This is `retrieve_all` for a single sid; `default_none` is passed
        through unchanged.
        """
        assets = self.retrieve_all((sid,), default_none=default_none)
        return assets[0]
204
205
    def retrieve_all(self, sids, default_none=False):
        """
        Retrieve all assets in `sids`.

        Parameters
        ----------
        sids : iterable of int
            Assets to retrieve.
        default_none : bool
            If True, return None for failed lookups.
            If False, raise `SidsNotFound`.

        Returns
        -------
        assets : list[Asset or None]
            A list of the same length as `sids` containing Assets (or Nones)
            corresponding to the requested sids.

        Raises
        ------
        SidsNotFound
            When a requested sid is not found and default_none=False.
        """
        # `sids` is walked more than once below, so materialize it up front;
        # otherwise a one-shot iterator would produce an empty result.
        sids = list(sids)

        hits, missing = {}, set()
        for sid in sids:
            try:
                asset = self._asset_cache[sid]
            except KeyError:
                missing.add(sid)
                continue
            if not default_none and asset is None:
                # Bail early if we've already cached that we don't know
                # about an asset.
                raise SidsNotFound(sids=[sid])
            hits[sid] = asset

        # All requests were cache hits.  Return requested sids in order.
        if not missing:
            return [hits[sid] for sid in sids]

        # Look up cache misses by type.
        type_to_assets = self.group_by_type(missing)

        # Sids with no known type are failures.  Remember them in the asset
        # cache as None so that later lookups can fail fast.
        failures = {failure: None for failure in type_to_assets.pop(None, ())}
        hits.update(failures)
        self._asset_cache.update(failures)

        if failures and not default_none:
            raise SidsNotFound(sids=list(failures))

        # We don't update the asset cache here because it should already be
        # updated by `self.retrieve_equities` and
        # `self.retrieve_futures_contracts`.
        hits.update(self.retrieve_equities(type_to_assets.pop('equity', ())))
        hits.update(
            self.retrieve_futures_contracts(type_to_assets.pop('future', ()))
        )

        # We shouldn't know about any other asset types.
        if type_to_assets:
            raise AssertionError(
                "Found asset types: %s" % list(type_to_assets.keys())
            )

        return [hits[sid] for sid in sids]
271
272
    def retrieve_equities(self, sids):
        """
        Retrieve Equity objects for a list of sids.

        Users generally shouldn't need to call this method (instead, they
        should prefer the more general/friendly `retrieve_all`), but it has a
        documented interface and tests because it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        equities : dict[int -> Equity]

        Raises
        ------
        EquitiesNotFound
            When any requested asset isn't found.
        """
        equities_table = self.equities
        return self._retrieve_assets(sids, equities_table, Equity)
294
295
    def _retrieve_equity(self, sid):
        """Retrieve the single Equity with the given sid."""
        by_sid = self.retrieve_equities((sid,))
        return by_sid[sid]
297
298
    def retrieve_futures_contracts(self, sids):
        """
        Retrieve Future objects for an iterable of sids.

        Users generally shouldn't need to call this method (instead, they
        should prefer the more general/friendly `retrieve_all`), but it has a
        documented interface and tests because it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        futures : dict[int -> Future]

        Raises
        ------
        FutureContractsNotFound
            When any requested asset isn't found.
        """
        return self._retrieve_assets(sids, self.futures_contracts, Future)
320
321
    @staticmethod
    def _select_assets_by_sid(asset_tbl, sids):
        """
        Build a select of all columns of ``asset_tbl`` restricted to the rows
        whose sid is one of ``sids`` (each converted with ``int``).
        """
        sid_filter = asset_tbl.c.sid.in_(map(int, sids))
        return sa.select([asset_tbl]).where(sid_filter)
326
327
    @staticmethod
    def _select_asset_by_symbol(asset_tbl, symbol):
        """
        Build a select of all columns of ``asset_tbl`` restricted to the rows
        whose symbol equals ``symbol``.
        """
        symbol_filter = asset_tbl.c.symbol == symbol
        return sa.select([asset_tbl]).where(symbol_filter)
330
331
    def _retrieve_assets(self, sids, asset_tbl, asset_type):
        """
        Internal function for loading assets from a table.

        This should be the only method of `AssetFinder` that writes Assets into
        self._asset_cache.

        Parameters
        ---------
        sids : iterable of int
            Asset ids to look up.
        asset_tbl : sqlalchemy.Table
            Table from which to query assets.
        asset_type : type
            Type of asset to be constructed.

        Returns
        -------
        assets : dict[int -> Asset]
            Dict mapping requested sids to the retrieved assets.

        Raises
        ------
        EquitiesNotFound
            When `asset_type` is Equity and a requested sid has no row in
            `asset_tbl`.
        FutureContractsNotFound
            When `asset_type` is anything else and a requested sid has no row
            in `asset_tbl`.
        """
        # `sids` is needed for the emptiness check, for chunking, and for the
        # final miss detection.  Materialize it once so that one-shot
        # iterators are handled correctly in all three places.
        sids = list(sids)

        # Fastpath for empty request.
        if not sids:
            return {}

        cache = self._asset_cache
        hits = {}

        for assets in self._group_into_chunks(sids):
            # Load misses from the db.
            query = self._select_assets_by_sid(asset_tbl, assets)

            for row in imap(dict, query.execute().fetchall()):
                asset = asset_type(**_convert_asset_timestamp_fields(row))
                sid = asset.sid
                hits[sid] = cache[sid] = asset

        # If we get here, it means something in our code thought that a
        # particular sid was an equity/future and called this function with a
        # concrete type, but we couldn't actually resolve the asset.  This is
        # an error in our code, not a user-input error.
        misses = tuple(set(sids) - viewkeys(hits))
        if misses:
            if asset_type == Equity:
                raise EquitiesNotFound(sids=misses)
            else:
                raise FutureContractsNotFound(sids=misses)
        return hits
379
380
    def _get_fuzzy_candidates(self, fuzzy_symbol):
        """
        Fetch the sid rows of all equities whose fuzzy symbol equals
        ``fuzzy_symbol``, ordered by start date and then end date, both
        descending.
        """
        eq = self.equities.c
        query = sa.select(
            (eq.sid,)
        ).where(
            eq.fuzzy_symbol == fuzzy_symbol
        ).order_by(
            eq.start_date.desc(),
            eq.end_date.desc()
        )
        return query.execute().fetchall()
388
389
    def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
        """
        Fetch the sid rows of all equities whose fuzzy symbol equals
        ``fuzzy_symbol`` and whose [start_date, end_date] range contains
        ``ad_value``, ordered by start date and then end date, both
        descending.
        """
        eq = self.equities.c
        alive_with_symbol = sa.and_(
            eq.fuzzy_symbol == fuzzy_symbol,
            eq.start_date <= ad_value,
            eq.end_date >= ad_value
        )
        query = sa.select((eq.sid,)).where(alive_with_symbol).order_by(
            eq.start_date.desc(),
            eq.end_date.desc(),
        )
        return query.execute().fetchall()
403
404
    def _get_split_candidates_in_range(self,
                                       company_symbol,
                                       share_class_symbol,
                                       ad_value):
        """
        Fetch the sid rows of all equities that match both
        ``company_symbol`` and ``share_class_symbol`` and whose
        [start_date, end_date] range contains ``ad_value``, ordered by start
        date and then end date, both descending.
        """
        eq = self.equities.c
        alive_with_symbol = sa.and_(
            eq.company_symbol == company_symbol,
            eq.share_class_symbol == share_class_symbol,
            eq.start_date <= ad_value,
            eq.end_date >= ad_value
        )
        query = sa.select((eq.sid,)).where(alive_with_symbol).order_by(
            eq.start_date.desc(),
            eq.end_date.desc(),
        )
        return query.execute().fetchall()
422
423
    def _get_split_candidates(self, company_symbol, share_class_symbol):
        """
        Fetch the sid rows of all equities that match both
        ``company_symbol`` and ``share_class_symbol``, ordered by start date
        and then end date, both descending.
        """
        eq = self.equities.c
        with_symbol = sa.and_(
            eq.company_symbol == company_symbol,
            eq.share_class_symbol == share_class_symbol
        )
        query = sa.select((eq.sid,)).where(with_symbol).order_by(
            eq.start_date.desc(),
            eq.end_date.desc(),
        )
        return query.execute().fetchall()
436
437
    def _resolve_no_matching_candidates(self,
                                        company_symbol,
                                        share_class_symbol,
                                        ad_value):
        """
        Fetch the sid rows of all equities that match both
        ``company_symbol`` and ``share_class_symbol`` and that started on or
        before ``ad_value``, ordered by end date descending.
        """
        eq = self.equities.c
        started_with_symbol = sa.and_(
            eq.company_symbol == company_symbol,
            eq.share_class_symbol == share_class_symbol,
            eq.start_date <= ad_value
        )
        query = sa.select((eq.sid,)).where(
            started_with_symbol,
        ).order_by(
            eq.end_date.desc(),
        )
        return query.execute().fetchall()
451
452
    def _get_best_candidate(self, candidates):
        """Retrieve the Equity for the first row of ``candidates``."""
        top = candidates[0]
        return self._retrieve_equity(top['sid'])
454
455
    def _get_equities_from_candidates(self, candidates):
        """
        Retrieve the Equity objects for the sids of ``candidates``.

        Returns a list of Equities in the same order as ``candidates``.
        """
        # Build a real list: the sids are consumed twice (once by
        # `retrieve_equities`, once to order the result), and a lazy `map`
        # would be exhausted after the first use on Python 3.
        sids = [candidate['sid'] for candidate in candidates]
        results = self.retrieve_equities(sids)
        return [results[sid] for sid in sids]
459
460
    def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
        """
        Return matching Equity of name symbol in database.

        Without an as_of_date, exactly one match is returned as is; no match
        raises SymbolNotFound and several matches raise MultipleSymbolsFound.

        With an as_of_date, the best Equity alive on that date is returned.
        If there is none, the Equity that started on or before the date and
        has the latest end_date is returned; if there is still none,
        SymbolNotFound is raised.

        If fuzzy is True, a single match on the fuzzy symbol takes precedence
        over the split company/share class symbol search.
        """
        symbol_parts = split_delimited_symbol(symbol)
        company_symbol, share_class_symbol, fuzzy_symbol = symbol_parts

        if not as_of_date:
            # If this is a fuzzy look-up, check if there is exactly one match
            # for the fuzzy symbol
            if fuzzy:
                fuzzy_hits = self._get_fuzzy_candidates(fuzzy_symbol)
                if len(fuzzy_hits) == 1:
                    return self._get_best_candidate(fuzzy_hits)

            exact_hits = self._get_split_candidates(company_symbol,
                                                    share_class_symbol)
            if not exact_hits:
                raise SymbolNotFound(symbol=symbol)
            if len(exact_hits) > 1:
                raise MultipleSymbolsFound(
                    symbol=symbol,
                    options=self._get_equities_from_candidates(exact_hits)
                )
            return self._get_best_candidate(exact_hits)

        # Format inputs
        as_of_date = pd.Timestamp(as_of_date).normalize()
        ad_value = as_of_date.value

        if fuzzy:
            # Search for a single exact match on the fuzzy column
            fuzzy_hits = self._get_fuzzy_candidates_in_range(fuzzy_symbol,
                                                             ad_value)

            # If exactly one SID exists for fuzzy_symbol, return that sid
            if len(fuzzy_hits) == 1:
                return self._get_best_candidate(fuzzy_hits)

        # Search for exact matches of the split-up company_symbol and
        # share_class_symbol.  The rows come back ordered so that the first
        # one has the latest start_date, with end_date as a tie-breaker.
        candidates = self._get_split_candidates_in_range(
            company_symbol,
            share_class_symbol,
            ad_value
        )

        # If no SID exists for symbol on that date, fall back to the SID with
        # the highest-but-not-over end_date
        if not candidates:
            candidates = self._resolve_no_matching_candidates(
                company_symbol,
                share_class_symbol,
                ad_value
            )

        if candidates:
            return self._get_best_candidate(candidates)

        raise SymbolNotFound(symbol=symbol)
531
532
    def lookup_future_symbol(self, symbol):
        """ Return the Future object for a given symbol.

        Parameters
        ----------
        symbol : str
            The symbol of the desired contract.

        Returns
        -------
        Future
            A Future object.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.

        """
        query = self._select_asset_by_symbol(self.futures_contracts, symbol)
        row = query.execute().fetchone()

        # Nothing came back from the futures contracts table.
        if not row:
            raise SymbolNotFound(symbol=symbol)
        return self.retrieve_asset(row['sid'])
559
560
    def lookup_future_chain(self, root_symbol, as_of_date):
        """ Return the futures chain for a given root symbol.

        Parameters
        ----------
        root_symbol : str
            Root symbol of the desired future.

        as_of_date : pd.Timestamp or pd.NaT
            Date at which the chain determination is rooted. I.e. the
            existing contract whose notice date/expiration date is first
            after this date is the primary contract, etc. If NaT is
            given, the chain is unbounded, and all contracts for this
            root symbol are returned.

        Returns
        -------
        list
            A list of Future objects, the chain for the given
            parameters.

        Raises
        ------
        RootSymbolNotFound
            Raised when a future chain could not be found for the given
            root symbol.
        """
        fc_cols = self.futures_contracts.c

        if as_of_date is pd.NaT:
            # Unbounded chain: every contract for this root symbol, ordered
            # by notice date.
            query = sa.select((fc_cols.sid,)).where(
                (fc_cols.root_symbol == root_symbol),
            ).order_by(
                fc_cols.notice_date.asc(),
            )
        else:
            as_of_date = as_of_date.value
            nat_value = pd.NaT.value

            # Filter to contracts that are still valid. If both dates
            # exist, use the one that comes first in time (i.e. the lower
            # value). If either notice_date or expiration_date is NaT, use
            # the other. If both are NaT, the contract cannot be included
            # in any chain.
            still_valid = sa.case(
                [
                    (
                        fc_cols.notice_date == nat_value,
                        fc_cols.expiration_date >= as_of_date
                    ),
                    (
                        fc_cols.expiration_date == nat_value,
                        fc_cols.notice_date >= as_of_date
                    )
                ],
                else_=(
                    sa.func.min(
                        fc_cols.notice_date,
                        fc_cols.expiration_date
                    ) >= as_of_date
                )
            )

            # Sort using expiration_date if valid. If it's NaT, use
            # notice_date instead.
            sort_key = sa.case(
                [
                    (
                        fc_cols.expiration_date == nat_value,
                        fc_cols.notice_date
                    )
                ],
                else_=fc_cols.expiration_date
            )

            query = sa.select((fc_cols.sid,)).where(
                (fc_cols.root_symbol == root_symbol) & still_valid
            ).order_by(
                sort_key.asc()
            )

        sids = [row['sid'] for row in query.execute().fetchall()]

        if not sids:
            # Check if root symbol exists.
            count = sa.select((sa.func.count(fc_cols.sid),)).where(
                fc_cols.root_symbol == root_symbol,
            ).scalar()
            if count == 0:
                raise RootSymbolNotFound(root_symbol=root_symbol)

        contracts = self.retrieve_futures_contracts(sids)
        return [contracts[sid] for sid in sids]
656
657
    def lookup_expired_futures(self, start, end):
        """
        Return the sids of the futures contracts that expire in
        [start, end).

        A contract's date is the earlier of its notice date and its
        expiration date; if only one of the two is set, that one is used.
        Contracts with neither date set are never returned.

        Parameters
        ----------
        start : pd.Timestamp
            Inclusive lower bound.  Only its ``.value`` is used.
        end : pd.Timestamp
            Exclusive upper bound.  Only its ``.value`` is used.

        Returns
        -------
        sids : list[int]
            Matching sids, ordered ascending by expiration date, or by
            notice date for contracts without an expiration date.
        """
        start = start.value
        end = end.value

        fc_cols = self.futures_contracts.c

        # Missing dates are stored as the NaT integer value.  Map them to
        # NULL so they do not take part in the comparisons below.
        nat_value = pd.NaT.value
        nd = sa.func.nullif(fc_cols.notice_date, nat_value)
        ed = sa.func.nullif(fc_cols.expiration_date, nat_value)

        # Prefer the earlier of the two dates; coalesce falls back to
        # whichever date is set when min() yields NULL.
        date = sa.func.coalesce(sa.func.min(nd, ed), ed, nd)

        query = sa.select((fc_cols.sid,)).where(
            (date >= start) & (date < end)
        ).order_by(
            sa.func.coalesce(ed, nd).asc()
        )

        return [row['sid'] for row in query.execute().fetchall()]
676
677
    @property
    def sids(self):
        """Tuple of every sid stored in the asset router table."""
        rows = sa.select((self.asset_router.c.sid,)).execute().fetchall()
        return tuple(row['sid'] for row in rows)
683
684
    def _lookup_generic_scalar(self,
                               asset_convertible,
                               as_of_date,
                               matches,
                               missing):
        """
        Convert asset_convertible to an asset.

        Assets are taken as is, integers are looked up as sids, and strings
        are looked up as symbols as of `as_of_date`.

        On success, append to matches.
        On failure, append to missing.

        Raises NotAssetConvertible for any other kind of input.
        """
        if isinstance(asset_convertible, Asset):
            matches.append(asset_convertible)
            return None

        if isinstance(asset_convertible, Integral):
            try:
                found = self.retrieve_asset(int(asset_convertible))
            except SidsNotFound:
                missing.append(asset_convertible)
            else:
                matches.append(found)
            return None

        if isinstance(asset_convertible, string_types):
            try:
                found = self.lookup_symbol(asset_convertible, as_of_date)
            except SymbolNotFound:
                missing.append(asset_convertible)
            else:
                matches.append(found)
            return None

        raise NotAssetConvertible(
            "Input was %s, not AssetConvertible."
            % asset_convertible
        )
719
720
    def lookup_generic(self,
                       asset_convertible_or_iterable,
                       as_of_date):
        """
        Convert a AssetConvertible or iterable of AssetConvertibles into
        a list of Asset objects.

        This method exists primarily as a convenience for implementing
        user-facing APIs that can handle multiple kinds of input.  It should
        not be used for internal code where we already know the expected types
        of our inputs.

        Returns a pair of objects, the first of which is the result of the
        conversion, and the second of which is a list containing any values
        that couldn't be resolved.

        For a scalar input the first element is a single Asset, and a failed
        lookup raises SidsNotFound or SymbolNotFound instead of being
        reported in the second element.
        """
        matches, missing = [], []

        # Interpret input as scalar.
        if isinstance(asset_convertible_or_iterable, AssetConvertible):
            self._lookup_generic_scalar(
                asset_convertible=asset_convertible_or_iterable,
                as_of_date=as_of_date,
                matches=matches,
                missing=missing,
            )
            try:
                first_match = matches[0]
            except IndexError:
                # Nothing matched: pick the error by whether the input is
                # int-like.
                if hasattr(asset_convertible_or_iterable, '__int__'):
                    raise SidsNotFound(sids=[asset_convertible_or_iterable])
                raise SymbolNotFound(symbol=asset_convertible_or_iterable)
            return first_match, missing

        # Interpret input as iterable.
        try:
            iterator = iter(asset_convertible_or_iterable)
        except TypeError:
            raise NotAssetConvertible(
                "Input was not a AssetConvertible "
                "or iterable of AssetConvertible."
            )

        for element in iterator:
            self._lookup_generic_scalar(element, as_of_date, matches, missing)
        return matches, missing
767
768
    def map_identifier_index_to_sids(self, index, as_of_date):
        """
        This method is for use in sanitizing a user's DataFrame or Panel
        inputs.

        Takes the given index of identifiers, checks their types, builds assets
        if necessary, and returns a list of the sids that correspond to the
        input index.

        Parameters
        ----------
        index : Iterable
            An iterable containing ints, strings, or Assets.  It must also
            support ``index[0]``.
        as_of_date : pandas.Timestamp
            A date to be used to resolve any dual-mapped symbols

        Returns
        -------
        List
            A list of integer sids corresponding to the input index.  If the
            first identifier is an integer, `index` itself is returned.

        Raises
        ------
        MapAssetIdentifierIndexError
            When the first identifier is not an integer, string, or Asset.
        ValueError
            When any identifier could not be resolved to an asset.
        """
        # The identifiers are assumed to all have the same type, so only the
        # first one is inspected.
        first_identifier = index[0]

        # Ensure that input is AssetConvertible (integer, string, or Asset)
        if not isinstance(first_identifier, AssetConvertible):
            raise MapAssetIdentifierIndexError(obj=first_identifier)

        # If sids are provided, no mapping is necessary
        if isinstance(first_identifier, Integral):
            return index

        # Look up all Assets for mapping
        matches, missing = [], []
        for identifier in index:
            self._lookup_generic_scalar(identifier, as_of_date,
                                        matches, missing)

        if missing:
            raise ValueError("Missing assets for identifiers: %s" % missing)

        # Return a list of the sids of the found assets
        found_sids = []
        for asset in matches:
            found_sids.append(asset.sid)
        return found_sids
813
814
    def _compute_asset_lifetimes(self):
        """
        Compute a recarray of asset lifetimes.

        The result has the integer fields ``sid``, ``start`` and ``end``,
        with one record per row of the equities table.  Missing starts are
        replaced with 0 and missing ends with the maximum int.
        """
        equities_cols = self.equities.c
        rows = sa.select((
            equities_cols.sid,
            equities_cols.start_date,
            equities_cols.end_date,
        )).execute()

        # use doubles so we get NaNs
        buf = np.array(tuple(rows), dtype='<f8')

        float_fields = [
            ('sid', '<f8'),
            ('start', '<f8'),
            ('end', '<f8')
        ]
        lifetimes = np.recarray(
            buf=buf,
            shape=(len(buf),),
            dtype=float_fields,
        )

        start = lifetimes.start
        end = lifetimes.end
        start[np.isnan(start)] = 0  # convert missing starts to 0
        end[np.isnan(end)] = np.iinfo(int).max  # convert missing end to INTMAX

        # Cast the results back down to int.
        int_fields = [
            ('sid', '<i8'),
            ('start', '<i8'),
            ('end', '<i8'),
        ]
        return lifetimes.astype(int_fields)
847
848
    def lifetimes(self, dates, include_start_date):
        """
        Compute a DataFrame representing asset lifetimes for the specified date
        range.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            The dates for which to compute lifetimes.
        include_start_date : bool
            Whether or not to count the asset as alive on its start_date.

            This is useful in a backtesting context where `lifetimes` is being
            used to signify "do I have data for this asset as of the morning of
            this date?"  For many financial metrics, (e.g. daily close), data
            isn't available for an asset until the end of the asset's first
            day.

        Returns
        -------
        lifetimes : pd.DataFrame
            A frame of dtype bool with `dates` as index and an Int64Index of
            assets as columns.  The value at `lifetimes.loc[date, asset]` will
            be True iff `asset` existed on `date`.  If `include_start_date` is
            False, then lifetimes.loc[date, asset] will be false when date ==
            asset.start_date.

        See Also
        --------
        numpy.putmask
        zipline.pipeline.engine.SimplePipelineEngine._compute_root_mask
        """
        # The lifetimes are computed once and then reused.  This is a less
        # than ideal place to do this, because assets added to the finder
        # after the first call will not show up in the cached value.
        if self._asset_lifetimes is None:
            self._asset_lifetimes = self._compute_asset_lifetimes()
        lifetimes = self._asset_lifetimes

        # Column vector of the raw dates, compared against every asset.
        raw_dates = dates.asi8[:, None]
        if include_start_date:
            has_started = lifetimes.start <= raw_dates
        else:
            has_started = lifetimes.start < raw_dates
        not_ended = raw_dates <= lifetimes.end
        mask = has_started & not_ended

        return pd.DataFrame(mask, index=dates, columns=lifetimes.sid)
896
897
898
class AssetConvertible(with_metaclass(ABCMeta)):
    """
    Abstract marker type for values that can be converted to the
    integer representation of an Asset.

    The class defines no attributes or methods of its own; types take part
    by being registered as virtual subclasses.

    Includes Asset, six.string_types, and Integral
    """
906
907
908
# Register every asset-convertible type as a virtual subclass, in the order
# Integral, Asset, then the string types.  six.string_types is used for
# Python2/3 compatibility.
for _type in (Integral, Asset) + tuple(string_types):
    AssetConvertible.register(_type)
913
914
915
class NotAssetConvertible(ValueError):
    """
    ValueError subclass for values that are not asset-convertible.

    NOTE(review): the raise sites are outside this section of the file.
    """
917
918
919
class AssetFinderCachedEquities(AssetFinder):
    """
    An AssetFinder that reads every row of the equities table into memory
    at construction time and overrides the methods that lookup_symbol uses
    to look up those equities, so that they are answered from in-memory
    dictionaries.

    Attributes
    ----------
    hashed_equities : list
        All rows fetched from the equities table when the finder was built.
    fuzzy_symbol_hashed_equities : dict
        Maps a fuzzy symbol to the list of Equity objects having it.
    company_share_class_hashed_equities : dict
        Maps a (company_symbol, share_class_symbol) tuple to the list of
        Equity objects having that combination.
    """

    def __init__(self, engine):
        super(AssetFinderCachedEquities, self).__init__(engine)
        self.fuzzy_symbol_hashed_equities = {}
        self.company_share_class_hashed_equities = {}
        # Fetch the complete equities table once; the lookup dictionaries
        # are then built from these rows.
        all_equities_query = sa.select(self.equities.c)
        self.hashed_equities = all_equities_query.execute().fetchall()
        self._load_hashed_equities()

    def _load_hashed_equities(self):
        """
        Fill the two lookup dictionaries from `self.hashed_equities`.

        Each row is converted to an Equity, which is appended both to the
        list keyed by its fuzzy symbol and to the list keyed by its
        (company symbol, share class symbol) pair.
        """
        by_split_symbol = self.company_share_class_hashed_equities
        by_fuzzy_symbol = self.fuzzy_symbol_hashed_equities
        for row in self.hashed_equities:
            split_key = (row['company_symbol'], row['share_class_symbol'])
            fuzzy_key = row['fuzzy_symbol']
            equity = self._convert_row_to_equity(row)
            by_split_symbol.setdefault(split_key, []).append(equity)
            by_fuzzy_symbol.setdefault(fuzzy_key, []).append(equity)

    def _convert_row_to_equity(self, row):
        """
        Build an Equity object from a SQLAlchemy equity row.
        """
        fields = _convert_asset_timestamp_fields(dict(row))
        return Equity(**fields)

    def _get_fuzzy_candidates(self, fuzzy_symbol):
        """
        Return the cached equities for `fuzzy_symbol`, or an empty tuple
        when none are cached.
        """
        cache = self.fuzzy_symbol_hashed_equities
        return cache.get(fuzzy_symbol, ())

    def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
        """
        Return the cached equities for `fuzzy_symbol` that were active at
        `ad_value` (compared against the `value` of each equity's start
        and end dates).
        """
        candidates = self._get_fuzzy_candidates(fuzzy_symbol)
        return only_active_assets(ad_value, candidates)

    def _get_split_candidates(self, company_symbol, share_class_symbol):
        """
        Return the cached equities for the (company_symbol,
        share_class_symbol) pair, or an empty tuple when none are cached.
        """
        split_key = (company_symbol, share_class_symbol)
        return self.company_share_class_hashed_equities.get(split_key, ())

    def _get_split_candidates_in_range(self,
                                       company_symbol,
                                       share_class_symbol,
                                       ad_value):
        """
        Return the cached equities for the symbol pair that were active at
        `ad_value`, ordered by (start_date, end_date) descending.
        """
        active = only_active_assets(
            ad_value,
            self._get_split_candidates(company_symbol, share_class_symbol),
        )
        return sorted(
            active,
            key=lambda equity: (equity.start_date, equity.end_date),
            reverse=True,
        )

    def _resolve_no_matching_candidates(self,
                                        company_symbol,
                                        share_class_symbol,
                                        ad_value):
        """
        Return the cached equities for the symbol pair whose start date is
        on or before `ad_value`, whatever their end date, ordered by
        end_date descending.  The list is empty when nothing qualifies.
        """
        already_started = [
            equity
            for equity in self._get_split_candidates(
                company_symbol, share_class_symbol
            )
            if equity.start_date.value <= ad_value
        ]
        return sorted(
            already_started,
            key=lambda equity: equity.end_date,
            reverse=True,
        )

    def _get_best_candidate(self, candidates):
        """
        Return the first element of `candidates`.
        """
        return candidates[0]

    def _get_equities_from_candidates(self, candidates):
        """
        Return `candidates` unchanged; the cached candidates are already
        Equity objects.
        """
        return candidates
1010
1011
1012
def was_active(reference_date_value, asset):
    """
    Report whether `asset` existed at the time given by
    `reference_date_value`.

    Parameters
    ----------
    reference_date_value : int
        The moment of interest, as nanoseconds since EPOCH.  This is
        generally the `value` attribute of a pandas Timestamp.
    asset : Asset
        The asset whose start and end dates are checked.

    Returns
    -------
    was_active : bool
        True when `reference_date_value` lies between the asset's
        start_date and end_date, both ends inclusive.
    """
    # The end date is only consulted once the start-date check has passed.
    has_started = asset.start_date.value <= reference_date_value
    return has_started and reference_date_value <= asset.end_date.value
1036
1037
1038
def only_active_assets(reference_date_value, assets):
    """
    Keep only the assets that were alive at the time given by
    `reference_date_value`.

    Parameters
    ----------
    reference_date_value : int
        The moment of interest, as nanoseconds since EPOCH.  This is
        generally the `value` attribute of a pandas Timestamp.
    assets : iterable[Asset]
        The assets to filter.

    Returns
    -------
    active_assets : list
        The members of `assets` for which `was_active` is true, in their
        original iteration order.
    """
    active_assets = []
    for asset in assets:
        if was_active(reference_date_value, asset):
            active_assets.append(asset)
    return active_assets
1058