_select_asset_by_symbol()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 3
rs 10
1
# Copyright 2015 Quantopian, Inc.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
from abc import ABCMeta
16
from numbers import Integral
17
from operator import itemgetter
18
19
from logbook import Logger
20
import numpy as np
21
import pandas as pd
22
from pandas import isnull
23
from six import with_metaclass, string_types, viewkeys
24
from six.moves import map as imap, range
25
import sqlalchemy as sa
26
27
from zipline.errors import (
28
    EquitiesNotFound,
29
    FutureContractsNotFound,
30
    MapAssetIdentifierIndexError,
31
    MultipleSymbolsFound,
32
    RootSymbolNotFound,
33
    SidsNotFound,
34
    SymbolNotFound,
35
)
36
from zipline.assets import (
37
    Asset, Equity, Future,
38
)
39
from zipline.assets.asset_writer import (
40
    check_version_info,
41
    split_delimited_symbol,
42
    asset_db_table_names,
43
    SQLITE_MAX_VARIABLE_NUMBER,
44
)
45
from zipline.assets.asset_db_schema import (
46
    ASSET_DB_VERSION
47
)
48
from zipline.utils.control_flow import invert
49
50
# Module-level logbook Logger for this file; its channel name is the literal
# 'assets.py'.
log = Logger('assets.py')
51
52
# A set of fields that need to be converted to strings before building an
53
# Asset to avoid unicode fields
54
_asset_str_fields = frozenset({
55
    'symbol',
56
    'asset_name',
57
    'exchange',
58
})
59
60
# A set of fields that need to be converted to timestamps in UTC
61
_asset_timestamp_fields = frozenset({
62
    'start_date',
63
    'end_date',
64
    'first_traded',
65
    'notice_date',
66
    'expiration_date',
67
    'auto_close_date',
68
})
69
70
71
def _convert_asset_timestamp_fields(dict_):
72
    """
73
    Takes in a dict of Asset init args and converts dates to pd.Timestamps
74
    """
75
    for key in (_asset_timestamp_fields & viewkeys(dict_)):
76
        value = pd.Timestamp(dict_[key], tz='UTC')
77
        dict_[key] = None if isnull(value) else value
78
    return dict_
79
80
81
class AssetFinder(object):
    """
    Interface to a database of Asset metadata produced by an
    ``AssetDBWriter``.

    Assets can be looked up either by their unique integer id or by symbol;
    for historical reasons the unique ids are called 'sids'.

    Parameters
    ----------
    engine : str or SQLAlchemy.engine
        An engine connected to the asset database, or a string SQLAlchemy
        can parse as a database URI.

    See Also
    --------
    :class:`zipline.assets.asset_writer.AssetDBWriter`
    """
    # Placeholder substituted for an AssetFinder reference when pickling
    # objects that hold one.
    PERSISTENT_TOKEN = "<AssetFinder>"
    def __init__(self, engine):
        """
        Reflect the asset tables from ``engine`` and set up empty caches.

        ``engine`` may be a SQLAlchemy engine or a URI string; strings are
        turned into an engine with ``sa.create_engine`` so that
        ``self.engine`` is always an engine object.
        """
        # The class docstring promises that a URI string is accepted;
        # normalize it here so self.engine is never left as a bare string.
        if isinstance(engine, string_types):
            engine = sa.create_engine(engine)
        self.engine = engine

        # Reflect the asset tables and expose each as an attribute of the
        # same name (the rest of this class reads e.g. self.equities,
        # self.futures_contracts, self.asset_router, self.version_info).
        metadata = sa.MetaData(bind=engine)
        metadata.reflect(only=asset_db_table_names)
        for table_name in asset_db_table_names:
            setattr(self, table_name, metadata.tables[table_name])

        # Check the version info of the db for compatibility
        check_version_info(self.version_info, ASSET_DB_VERSION)

        # Cache for lookup of assets by sid, the objects in the asset lookup
        # may be shared with the results from equity and future lookup caches.
        #
        # The top level cache exists to minimize lookups on the asset type
        # routing.
        #
        # The caches are read through, i.e. accessing an asset through
        # retrieve_asset will populate the cache on first retrieval.
        self._caches = (self._asset_cache, self._asset_type_cache) = {}, {}

        # Populated on first call to `lifetimes`.
        self._asset_lifetimes = None
    def _reset_caches(self):
        """
        Empty every per-instance asset cache held in ``self._caches``.

        Not a public API: this is a workaround for
        `TradingAlgorithm._write_and_map_id_index_to_sids`, which mutates the
        asset database in place.  Nothing else should call it.
        """
        for each_cache in self._caches:
            each_cache.clear()
    def lookup_asset_types(self, sids):
        """
        Retrieve asset types for a list of sids.

        Results come from ``self._asset_type_cache`` when possible.  Cache
        misses are resolved against the ``asset_router`` table in chunks
        (sized by `_group_into_chunks`).  Every resolved entry, including
        sids that were not found (cached as None), is written back to the
        cache.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[sid -> str or None]
            Asset types for the provided sids; None for sids absent from the
            asset router.
        """
        found = {}
        missing = set()

        for sid in sids:
            try:
                found[sid] = self._asset_type_cache[sid]
            except KeyError:
                missing.add(sid)

        if not missing:
            return found

        router_cols = self.asset_router.c

        # _group_into_chunks copies `missing` into a list, so removing from
        # the set while iterating the chunks below is safe.
        for assets in self._group_into_chunks(missing):
            query = sa.select((router_cols.sid, router_cols.asset_type)).where(
                router_cols.sid.in_([int(sid) for sid in assets])
            )
            for sid, type_ in query.execute().fetchall():
                missing.remove(sid)
                found[sid] = self._asset_type_cache[sid] = type_

        # Only after *all* chunks have been queried is a leftover sid truly
        # unknown.  Doing this inside the chunk loop would rescan the whole
        # remaining set once per chunk and transiently cache sids from
        # not-yet-queried chunks as None.
        for sid in missing:
            found[sid] = self._asset_type_cache[sid] = None

        return found
    @staticmethod
    def _group_into_chunks(items, chunk_size=SQLITE_MAX_VARIABLE_NUMBER):
        """
        Split ``items`` into consecutive lists of at most ``chunk_size``
        elements, preserving order.

        The default size is SQLITE_MAX_VARIABLE_NUMBER, so each chunk can be
        bound into a single ``IN (...)`` clause.
        """
        materialized = list(items)
        chunks = []
        for start in range(0, len(materialized), chunk_size):
            chunks.append(materialized[start:start + chunk_size])
        return chunks
    def group_by_type(self, sids):
        """
        Partition ``sids`` by asset type.

        Parameters
        ----------
        sids : list[int]

        Returns
        -------
        types : dict[str or None -> list[int]]
            Maps each asset type found to the sids of that type.  Sids whose
            type could not be looked up are grouped under the key None.
        """
        sid_to_type = self.lookup_asset_types(sids)
        return invert(sid_to_type)
    def retrieve_asset(self, sid, default_none=False):
        """
        Retrieve the Asset for a single sid.

        Single-sid wrapper over `retrieve_all`; ``default_none`` and the
        exceptions raised have the same meaning there.
        """
        (asset,) = self.retrieve_all((sid,), default_none=default_none)
        return asset
    def retrieve_all(self, sids, default_none=False):
        """
        Retrieve all assets in `sids`.

        Parameters
        ----------
        sids : iterable of int
            Assets to retrieve.  Any iterable is accepted, including one-shot
            iterators such as generators.
        default_none : bool
            If True, return None for failed lookups.
            If False, raise `SidsNotFound`.

        Returns
        -------
        assets : list[Asset or None]
            A list of the same length as `sids` containing Assets (or Nones)
            corresponding to the requested sids.

        Raises
        ------
        SidsNotFound
            When a requested sid is not found and default_none=False.
        """
        # `sids` is traversed here and again when building the ordered
        # result; materialize it so a generator doesn't silently produce [].
        sids = list(sids)

        hits, missing = {}, set()
        for sid in sids:
            try:
                asset = self._asset_cache[sid]
            except KeyError:
                missing.add(sid)
                continue
            if not default_none and asset is None:
                # Bail early if we've already cached that we don't know
                # about an asset.
                raise SidsNotFound(sids=[sid])
            hits[sid] = asset

        # All requests were cache hits.  Return requested sids in order.
        if not missing:
            return [hits[sid] for sid in sids]

        # Look up cache misses by type.
        type_to_assets = self.group_by_type(missing)

        # Sids with no known type are failures.  Cache them as None so later
        # requests short-circuit in the cache loop above.
        failures = {failure: None for failure in type_to_assets.pop(None, ())}
        hits.update(failures)
        self._asset_cache.update(failures)

        if failures and not default_none:
            raise SidsNotFound(sids=list(failures))

        # The asset cache itself is filled by `_retrieve_assets` (reached via
        # the type-specific retrieval methods), so only `hits` is updated.
        hits.update(self.retrieve_equities(type_to_assets.pop('equity', ())))
        hits.update(
            self.retrieve_futures_contracts(type_to_assets.pop('future', ()))
        )

        # We shouldn't know about any other asset types.
        if type_to_assets:
            raise AssertionError(
                "Found asset types: %s" % list(type_to_assets.keys())
            )

        return [hits[sid] for sid in sids]
    def retrieve_equities(self, sids):
        """
        Retrieve Equity objects for a list of sids.

        Users generally shouldn't need to use this method (instead, they
        should prefer the more general/friendly `retrieve_all` or
        `retrieve_asset`), but it has a documented interface and tests
        because it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        equities : dict[int -> Equity]

        Raises
        ------
        EquitiesNotFound
            When any requested asset isn't found.
        """
        return self._retrieve_assets(sids, self.equities, Equity)
    def _retrieve_equity(self, sid):
        """Return the single Equity for ``sid`` via `retrieve_equities`."""
        equities_by_sid = self.retrieve_equities((sid,))
        return equities_by_sid[sid]
    def retrieve_futures_contracts(self, sids):
        """
        Retrieve Future objects for an iterable of sids.

        Users generally shouldn't need to use this method (instead, they
        should prefer the more general/friendly `retrieve_all` or
        `retrieve_asset`), but it has a documented interface and tests
        because it's used upstream.

        Parameters
        ----------
        sids : iterable[int]

        Returns
        -------
        futures : dict[int -> Future]

        Raises
        ------
        FutureContractsNotFound
            When any requested contract isn't found.
        """
        return self._retrieve_assets(sids, self.futures_contracts, Future)
    @staticmethod
    def _select_assets_by_sid(asset_tbl, sids):
        """Build a SELECT of every ``asset_tbl`` row whose sid is in ``sids``."""
        sid_filter = asset_tbl.c.sid.in_(map(int, sids))
        return sa.select([asset_tbl]).where(sid_filter)
    @staticmethod
    def _select_asset_by_symbol(asset_tbl, symbol):
        """Build a SELECT of the ``asset_tbl`` rows whose symbol is ``symbol``."""
        symbol_matches = asset_tbl.c.symbol == symbol
        return sa.select([asset_tbl]).where(symbol_matches)
    def _retrieve_assets(self, sids, asset_tbl, asset_type):
        """
        Internal function for loading assets from a table.

        This should be the only method of `AssetFinder` that writes Assets into
        self._asset_cache.

        Parameters
        ----------
        sids : iterable of int
            Asset ids to look up.  May be a one-shot iterator.
        asset_tbl : sqlalchemy.Table
            Table from which to query assets.
        asset_type : type
            Type of asset to be constructed.

        Returns
        -------
        assets : dict[int -> Asset]
            Dict mapping requested sids to the retrieved assets.

        Raises
        ------
        EquitiesNotFound
            If ``asset_type`` is Equity and any requested sid is absent.
        FutureContractsNotFound
            For any other ``asset_type`` when a requested sid is absent.
        """
        # The sids are consumed by the chunked query and again when computing
        # misses.  Materialize once so a one-shot iterator doesn't make the
        # second pass empty and hide real misses.
        sids = list(sids)

        # Fastpath for empty request.
        if not sids:
            return {}

        cache = self._asset_cache
        hits = {}

        for assets in self._group_into_chunks(sids):
            # Load this chunk from the db.
            query = self._select_assets_by_sid(asset_tbl, assets)

            for row in imap(dict, query.execute().fetchall()):
                asset = asset_type(**_convert_asset_timestamp_fields(row))
                sid = asset.sid
                hits[sid] = cache[sid] = asset

        # If we get here, it means something in our code thought that a
        # particular sid was an equity/future and called this function with a
        # concrete type, but we couldn't actually resolve the asset.  This is
        # an error in our code, not a user-input error.
        misses = tuple(set(sids) - viewkeys(hits))
        if misses:
            if asset_type == Equity:
                raise EquitiesNotFound(sids=misses)
            raise FutureContractsNotFound(sids=misses)
        return hits
    def _get_fuzzy_candidates(self, fuzzy_symbol):
        """
        Return sid rows for every equity whose fuzzy symbol equals
        ``fuzzy_symbol``, ordered by start_date then end_date, both
        descending.
        """
        cols = self.equities.c
        query = sa.select((cols.sid,)).where(
            cols.fuzzy_symbol == fuzzy_symbol,
        ).order_by(
            cols.start_date.desc(),
            cols.end_date.desc(),
        )
        return query.execute().fetchall()
    def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
        """
        Like `_get_fuzzy_candidates`, restricted to equities with
        ``start_date <= ad_value <= end_date``.

        ``ad_value`` is compared directly against the stored date columns;
        `lookup_symbol` passes a normalized ``pd.Timestamp(...).value``.
        """
        cols = self.equities.c
        in_range = sa.and_(
            cols.fuzzy_symbol == fuzzy_symbol,
            cols.start_date <= ad_value,
            cols.end_date >= ad_value,
        )
        query = sa.select((cols.sid,)).where(in_range).order_by(
            cols.start_date.desc(),
            cols.end_date.desc(),
        )
        return query.execute().fetchall()
    def _get_split_candidates_in_range(self,
                                       company_symbol,
                                       share_class_symbol,
                                       ad_value):
        """
        Return sid rows for equities matching both ``company_symbol`` and
        ``share_class_symbol`` with ``start_date <= ad_value <= end_date``.

        Rows are ordered by start_date then end_date, both descending.
        """
        cols = self.equities.c
        conditions = sa.and_(
            cols.company_symbol == company_symbol,
            cols.share_class_symbol == share_class_symbol,
            cols.start_date <= ad_value,
            cols.end_date >= ad_value,
        )
        ordering = (cols.start_date.desc(), cols.end_date.desc())
        query = sa.select((cols.sid,)).where(conditions).order_by(*ordering)
        return query.execute().fetchall()
    def _get_split_candidates(self, company_symbol, share_class_symbol):
        """
        Return sid rows for every equity matching both ``company_symbol``
        and ``share_class_symbol``, ordered by start_date then end_date,
        both descending.
        """
        cols = self.equities.c
        conditions = sa.and_(
            cols.company_symbol == company_symbol,
            cols.share_class_symbol == share_class_symbol,
        )
        ordering = (cols.start_date.desc(), cols.end_date.desc())
        query = sa.select((cols.sid,)).where(conditions).order_by(*ordering)
        return query.execute().fetchall()
    def _resolve_no_matching_candidates(self,
                                        company_symbol,
                                        share_class_symbol,
                                        ad_value):
        """
        Fallback for dated lookups: return sid rows for equities matching the
        split symbol whose start_date is on or before ``ad_value``.

        There is no end_date filter.  Rows are ordered by end_date
        descending, so the most recently ended match comes first.
        """
        cols = self.equities.c
        started_by_date = sa.and_(
            cols.company_symbol == company_symbol,
            cols.share_class_symbol == share_class_symbol,
            cols.start_date <= ad_value,
        )
        query = sa.select((cols.sid,)).where(started_by_date).order_by(
            cols.end_date.desc(),
        )
        return query.execute().fetchall()
    def _get_best_candidate(self, candidates):
        """Return the Equity for the first (highest-ranked) candidate row."""
        best_sid = candidates[0]['sid']
        return self._retrieve_equity(best_sid)
    def _get_equities_from_candidates(self, candidates):
        """
        Return the Equity for every candidate row, in candidate order.

        The sids are collected into a list because they are traversed twice.
        With the builtin ``map`` (lazy on Python 3), `retrieve_equities` would
        exhaust it and this method would silently return [].
        """
        sids = [row['sid'] for row in candidates]
        results = self.retrieve_equities(sids)
        return [results[sid] for sid in sids]
    def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
        """
        Return the Equity named by ``symbol``.

        ``symbol`` is split by `split_delimited_symbol` into company,
        share-class and fuzzy parts.

        With a truthy ``as_of_date`` (normalized to midnight):
          * if ``fuzzy`` and exactly one equity alive on that date has the
            fuzzy symbol, return it;
          * else return the best equity alive on that date matching the split
            symbol (latest start_date, end_date as tie-breaker);
          * else return the split-symbol match that started on or before the
            date with the latest end_date;
          * else raise SymbolNotFound.

        Without ``as_of_date``:
          * if ``fuzzy`` and exactly one equity has the fuzzy symbol, return
            it;
          * else the split symbol must match exactly one equity.  SymbolNotFound
            is raised for no match and MultipleSymbolsFound for several.
        """
        company_symbol, share_class_symbol, fuzzy_symbol = (
            split_delimited_symbol(symbol)
        )

        if not as_of_date:
            if fuzzy:
                fuzzy_matches = self._get_fuzzy_candidates(fuzzy_symbol)
                if len(fuzzy_matches) == 1:
                    return self._get_best_candidate(fuzzy_matches)

            split_matches = self._get_split_candidates(company_symbol,
                                                       share_class_symbol)
            if not split_matches:
                raise SymbolNotFound(symbol=symbol)
            if len(split_matches) > 1:
                raise MultipleSymbolsFound(
                    symbol=symbol,
                    options=self._get_equities_from_candidates(split_matches)
                )
            return self._get_best_candidate(split_matches)

        ad_value = pd.Timestamp(as_of_date).normalize().value

        if fuzzy:
            fuzzy_matches = self._get_fuzzy_candidates_in_range(fuzzy_symbol,
                                                                ad_value)
            if len(fuzzy_matches) == 1:
                return self._get_best_candidate(fuzzy_matches)

        split_matches = self._get_split_candidates_in_range(
            company_symbol,
            share_class_symbol,
            ad_value
        )
        if not split_matches:
            # Nothing alive on the date; fall back to the latest-ending
            # equity that had already started by then.
            split_matches = self._resolve_no_matching_candidates(
                company_symbol,
                share_class_symbol,
                ad_value
            )
        if split_matches:
            return self._get_best_candidate(split_matches)

        raise SymbolNotFound(symbol=symbol)
    def lookup_future_symbol(self, symbol):
        """
        Return the Future object for a given contract symbol.

        Parameters
        ----------
        symbol : str
            The symbol of the desired contract.

        Returns
        -------
        Future
            The contract, loaded through `retrieve_asset`.

        Raises
        ------
        SymbolNotFound
            Raised when no contract named 'symbol' is found.
        """
        query = self._select_asset_by_symbol(self.futures_contracts, symbol)
        row = query.execute().fetchone()

        if not row:
            raise SymbolNotFound(symbol=symbol)
        return self.retrieve_asset(row['sid'])
    def lookup_future_chain(self, root_symbol, as_of_date):
        """
        Return the futures chain for a given root symbol.

        Parameters
        ----------
        root_symbol : str
            Root symbol of the desired future.
        as_of_date : pd.Timestamp or pd.NaT
            Date the chain is rooted at: the first contract is the one whose
            notice date or expiration date (whichever comes first) is the
            earliest on or after this date, and so on.  If NaT, the chain is
            unbounded and every contract for the root symbol is returned,
            ordered by notice date.

        Returns
        -------
        list
            Future objects forming the chain, in chain order.

        Raises
        ------
        RootSymbolNotFound
            Raised when a future chain could not be found for the given
            root symbol.
        """
        fc_cols = self.futures_contracts.c
        same_root = fc_cols.root_symbol == root_symbol

        if as_of_date is pd.NaT:
            # Unbounded chain: every contract for this root.
            query = sa.select((fc_cols.sid,)).where(
                same_root,
            ).order_by(
                fc_cols.notice_date.asc(),
            )
        else:
            as_of_value = as_of_date.value
            nat_value = pd.NaT.value

            # Keep contracts that are still valid.  Compare against the
            # earlier of notice_date/expiration_date.  If one of them is NaT,
            # compare against the other; if both are NaT the first branch
            # compares a NaT expiration and the contract never qualifies.
            still_valid = sa.case(
                [
                    (fc_cols.notice_date == nat_value,
                     fc_cols.expiration_date >= as_of_value),
                    (fc_cols.expiration_date == nat_value,
                     fc_cols.notice_date >= as_of_value),
                ],
                else_=(
                    sa.func.min(fc_cols.notice_date, fc_cols.expiration_date)
                    >= as_of_value
                ),
            )
            # Order by expiration_date, or by notice_date when expiration is
            # NaT.
            chain_order = sa.case(
                [(fc_cols.expiration_date == nat_value, fc_cols.notice_date)],
                else_=fc_cols.expiration_date,
            )
            query = sa.select((fc_cols.sid,)).where(
                same_root & still_valid,
            ).order_by(chain_order.asc())

        sids = [row['sid'] for row in query.execute().fetchall()]

        if not sids:
            # An empty chain is fine for a known root; an unknown root is an
            # error.
            count = sa.select((sa.func.count(fc_cols.sid),)).where(
                same_root,
            ).scalar()
            if count == 0:
                raise RootSymbolNotFound(root_symbol=root_symbol)

        contracts = self.retrieve_futures_contracts(sids)
        return [contracts[sid] for sid in sids]
    @property
    def sids(self):
        """Tuple of every sid present in the asset router table."""
        rows = sa.select((self.asset_router.c.sid,)).execute().fetchall()
        return tuple(row['sid'] for row in rows)
    def _lookup_generic_scalar(self,
                               asset_convertible,
                               as_of_date,
                               matches,
                               missing):
        """
        Convert asset_convertible to an asset.

        On success, append to matches.
        On failure, append to missing.

        Parameters
        ----------
        asset_convertible : Asset, Integral, or string
            An Asset is used as-is.  An integer is treated as a sid and
            loaded with `retrieve_asset`.  A string is resolved with
            `lookup_symbol`.
        as_of_date
            Forwarded to `lookup_symbol` for string inputs.
        matches : list
            Resolved Assets are appended here.
        missing : list
            Inputs that could not be resolved are appended here.

        Raises
        ------
        NotAssetConvertible
            If ``asset_convertible`` is none of the supported types.
        """
        if isinstance(asset_convertible, Asset):
            matches.append(asset_convertible)

        elif isinstance(asset_convertible, Integral):
            try:
                result = self.retrieve_asset(int(asset_convertible))
            except SidsNotFound:
                missing.append(asset_convertible)
                return None
            matches.append(result)

        elif isinstance(asset_convertible, string_types):
            try:
                matches.append(
                    self.lookup_symbol(asset_convertible, as_of_date)
                )
            except SymbolNotFound:
                missing.append(asset_convertible)
                return None
        else:
            # Wrap in a 1-tuple.  A tuple input used directly as the %
            # operand would be unpacked as format arguments and raise
            # TypeError instead of NotAssetConvertible.
            raise NotAssetConvertible(
                "Input was %s, not AssetConvertible."
                % (asset_convertible,)
            )
    def lookup_generic(self,
                       asset_convertible_or_iterable,
                       as_of_date):
        """
        Convert an AssetConvertible, or an iterable of them, into Assets.

        Intended for user-facing APIs that accept several kinds of input.
        Internal code that already knows its input types should not use it.

        Returns
        -------
        (result, missing)
            For a scalar input, ``result`` is the single resolved Asset.  For
            an iterable input, it is a list of the resolved Assets in input
            order.  ``missing`` lists the inputs that could not be resolved.

        Raises
        ------
        SidsNotFound
            Scalar input with ``__int__`` that could not be resolved.
        SymbolNotFound
            Any other scalar input that could not be resolved.
        NotAssetConvertible
            Input is neither AssetConvertible nor iterable.
        """
        matches = []
        missing = []

        if not isinstance(asset_convertible_or_iterable, AssetConvertible):
            # Treat the input as an iterable of AssetConvertibles.
            try:
                iterator = iter(asset_convertible_or_iterable)
            except TypeError:
                raise NotAssetConvertible(
                    "Input was not a AssetConvertible "
                    "or iterable of AssetConvertible."
                )
            for obj in iterator:
                self._lookup_generic_scalar(obj, as_of_date, matches, missing)
            return matches, missing

        # Scalar input.
        self._lookup_generic_scalar(
            asset_convertible=asset_convertible_or_iterable,
            as_of_date=as_of_date,
            matches=matches,
            missing=missing,
        )
        if matches:
            return matches[0], missing
        if hasattr(asset_convertible_or_iterable, '__int__'):
            raise SidsNotFound(sids=[asset_convertible_or_iterable])
        raise SymbolNotFound(symbol=asset_convertible_or_iterable)
    def map_identifier_index_to_sids(self, index, as_of_date):
        """
        This method is for use in sanitizing a user's DataFrame or Panel
        inputs.

        Takes the given index of identifiers, checks their types, builds assets
        if necessary, and returns a list of the sids that correspond to the
        input index.  The type of the first identifier decides the path taken.
        If it is an integer, ``index`` is returned unchanged.

        Parameters
        ----------
        index : Sequence
            A sized, indexable collection (e.g. a pandas Index or a list)
            containing ints, strings, or Assets.
        as_of_date : pandas.Timestamp
            A date to be used to resolve any dual-mapped symbols

        Returns
        -------
        List
            A list of integer sids corresponding to the input index.  An
            empty index maps to an empty list.

        Raises
        ------
        MapAssetIdentifierIndexError
            If the first identifier is not AssetConvertible.
        ValueError
            If any identifier cannot be resolved to an asset.
        """
        # Guard before index[0]: an empty index would otherwise fail with a
        # bare IndexError.  Use len() rather than truthiness because pandas
        # Index objects refuse bool().
        if len(index) == 0:
            return []

        # This method assumes that the type of the objects in the index is
        # consistent and can, therefore, be taken from the first identifier
        first_identifier = index[0]

        # Ensure that input is AssetConvertible (integer, string, or Asset)
        if not isinstance(first_identifier, AssetConvertible):
            raise MapAssetIdentifierIndexError(obj=first_identifier)

        # If sids are provided, no mapping is necessary
        if isinstance(first_identifier, Integral):
            return index

        # Look up all Assets for mapping
        matches = []
        missing = []
        for identifier in index:
            self._lookup_generic_scalar(identifier, as_of_date,
                                        matches, missing)

        if missing:
            raise ValueError("Missing assets for identifiers: %s" % missing)

        # Return a list of the sids of the found assets
        return [asset.sid for asset in matches]
    def _compute_asset_lifetimes(self):
        """
        Compute a recarray of (sid, start, end) int64 records, one per
        equity.

        start/end hold the stored start_date/end_date values (`lifetimes`
        compares them against ``DatetimeIndex.asi8``).  A missing start
        becomes 0.  A missing end becomes the largest int64, so such equities
        never expire.
        """
        equities_cols = self.equities.c
        buf = np.array(
            tuple(
                sa.select((
                    equities_cols.sid,
                    equities_cols.start_date,
                    equities_cols.end_date,
                )).execute(),
            ), dtype='<f8',  # use doubles so we get NaNs
        )
        lifetimes = np.recarray(
            buf=buf,
            shape=(len(buf),),
            dtype=[
                ('sid', '<f8'),
                ('start', '<f8'),
                ('end', '<f8')
            ],
        )
        start = lifetimes.start
        end = lifetimes.end
        start[np.isnan(start)] = 0  # convert missing starts to 0

        # Record missing ends now, but fill them only after the int cast.
        # The int64 maximum is not representable as a double (it rounds up
        # to 2**63), and casting 2**63 back to int64 overflows, typically to
        # the *smallest* int64, which would make these equities never alive.
        # np.int64 (not the platform `int`, 32-bit on some platforms) keeps
        # the sentinel at the 64-bit maximum everywhere.
        missing_end = np.isnan(end)
        end[missing_end] = 0

        # Cast the results back down to int.
        result = lifetimes.astype([
            ('sid', '<i8'),
            ('start', '<i8'),
            ('end', '<i8'),
        ])
        result['end'][missing_end] = np.iinfo(np.int64).max
        return result
    def lifetimes(self, dates, include_start_date):
        """
        Compute a DataFrame representing asset lifetimes for the specified date
        range.

        Parameters
        ----------
        dates : pd.DatetimeIndex
            The dates for which to compute lifetimes.
        include_start_date : bool
            Whether or not to count the asset as alive on its start_date.

            This is useful in a backtesting context where `lifetimes` is being
            used to signify "do I have data for this asset as of the morning of
            this date?"  For many financial metrics, (e.g. daily close), data
            isn't available for an asset until the end of the asset's first
            day.

        Returns
        -------
        lifetimes : pd.DataFrame
            A frame of dtype bool with `dates` as index and an Int64Index of
            assets as columns.  The value at `lifetimes.loc[date, asset]` will
            be True iff `asset` existed on `date`.  If `include_start_date` is
            False, then lifetimes.loc[date, asset] will be false when date ==
            asset.start_date.

        See Also
        --------
        numpy.putmask
        zipline.pipeline.engine.SimplePipelineEngine._compute_root_mask
        """
        # The lifetimes table is computed once and reused.  Assets added to
        # the database after the first call will not appear here.
        if self._asset_lifetimes is None:
            self._asset_lifetimes = self._compute_asset_lifetimes()
        spans = self._asset_lifetimes

        # A column of dates broadcast against a row of assets gives an
        # (n_dates, n_assets) boolean grid.
        date_values = dates.asi8[:, None]
        if include_start_date:
            started = spans.start <= date_values
        else:
            started = spans.start < date_values
        alive = started & (date_values <= spans.end)

        return pd.DataFrame(alive, index=dates, columns=spans.sid)
class AssetConvertible(with_metaclass(ABCMeta)):
    """
    Abstract marker type for values that can be turned into the integer
    representation of an Asset.

    Concrete types are attached as virtual subclasses through
    ``AssetConvertible.register`` rather than by real inheritance. The
    registered types are Asset, numbers.Integral and every type in
    six.string_types.
    """
888
889
890
# Mark every type whose instances can be converted to an Asset's integer
# representation as a virtual subclass of AssetConvertible. The string types
# come from six.string_types so the set is right on both Python 2 and 3.
for _type in (Integral, Asset) + tuple(string_types):
    AssetConvertible.register(_type)
895
896
897
class NotAssetConvertible(ValueError):
    """
    Error type signalling that a value is not AssetConvertible.

    Derives from ValueError, so generic ``except ValueError`` handlers also
    catch it.
    """
899
900
901
class AssetFinderCachedEquities(AssetFinder):
    """
    An extension to AssetFinder that loads all equities from the equities
    table into memory. It overrides the candidate-lookup hooks used by
    lookup_symbol so they are answered from in-memory dictionaries instead
    of queries.

    The cache is built once, in ``__init__``. Equities written to the
    database afterwards are therefore not visible through these hooks.
    """
    def __init__(self, engine):
        super(AssetFinderCachedEquities, self).__init__(engine)
        # fuzzy_symbol -> list of Equity objects sharing that fuzzy symbol.
        self.fuzzy_symbol_hashed_equities = {}
        # (company_symbol, share_class_symbol) -> list of Equity objects.
        self.company_share_class_hashed_equities = {}
        # Every row of the equities table. ``self.equities`` is presumably
        # the sqlalchemy Table set up by AssetFinder.__init__ -- confirm
        # against the base class.
        self.hashed_equities = sa.select(self.equities.c).execute().fetchall()
        self._load_hashed_equities()

    def _load_hashed_equities(self):
        """
        Populates two maps from ``self.hashed_equities``:

        * fuzzy symbol -> list of equities having that fuzzy symbol, and
        * (company symbol, share class symbol) -> list of equities having
          that combination.

        Each row is converted to an Equity once. The same Equity object is
        shared by both maps.
        """
        for equity in self.hashed_equities:
            company_symbol = equity['company_symbol']
            share_class_symbol = equity['share_class_symbol']
            fuzzy_symbol = equity['fuzzy_symbol']
            asset = self._convert_row_to_equity(equity)
            self.company_share_class_hashed_equities.setdefault(
                (company_symbol, share_class_symbol),
                []
            ).append(asset)
            self.fuzzy_symbol_hashed_equities.setdefault(
                fuzzy_symbol, []
            ).append(asset)

    def _convert_row_to_equity(self, row):
        """
        Converts a SQLAlchemy equity row to an Equity object, after passing
        its fields through _convert_asset_timestamp_fields.
        """
        return Equity(**_convert_asset_timestamp_fields(dict(row)))

    def _get_fuzzy_candidates(self, fuzzy_symbol):
        """
        Return all cached equities with the given fuzzy symbol, or an empty
        tuple if there are none.
        """
        return self.fuzzy_symbol_hashed_equities.get(fuzzy_symbol, ())

    def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
        """
        Return the cached equities with the given fuzzy symbol that were
        active at ``ad_value`` (nanoseconds since the epoch), in cache order.
        """
        return only_active_assets(
            ad_value,
            self._get_fuzzy_candidates(fuzzy_symbol),
        )

    def _get_split_candidates(self, company_symbol, share_class_symbol):
        """
        Return all cached equities for the (company symbol, share class
        symbol) pair, or an empty tuple if there are none.
        """
        return self.company_share_class_hashed_equities.get(
            (company_symbol, share_class_symbol),
            (),
        )

    def _get_split_candidates_in_range(self,
                                       company_symbol,
                                       share_class_symbol,
                                       ad_value):
        """
        Return the split-symbol candidates that were active at ``ad_value``.

        Candidates are ordered by start date, latest first; ties are broken
        by end date, latest first.
        """
        return sorted(
            only_active_assets(
                ad_value,
                self._get_split_candidates(company_symbol, share_class_symbol),
            ),
            key=lambda x: (x.start_date, x.end_date),
            reverse=True,
        )

    def _resolve_no_matching_candidates(self,
                                        company_symbol,
                                        share_class_symbol,
                                        ad_value):
        """
        Fallback candidate list, presumably used when no split candidate is
        active at ``ad_value``.

        Returns every equity for the (company symbol, share class symbol)
        pair whose start date is at or before ``ad_value``, whether or not it
        has since ended. The list is ordered by end date, most recent first,
        and is empty when nothing qualifies.
        """
        # sorted() always returns a new list (empty for no input), so no
        # separate emptiness check is needed before sorting.
        return sorted(
            (
                equity
                for equity in self._get_split_candidates(
                    company_symbol, share_class_symbol
                )
                if equity.start_date.value <= ad_value
            ),
            key=lambda equity: equity.end_date,
            reverse=True,
        )

    def _get_best_candidate(self, candidates):
        """
        Return the first candidate. Candidates are assumed to be non-empty
        and already ordered by preference by the caller.
        """
        return candidates[0]

    def _get_equities_from_candidates(self, candidates):
        """
        Candidates are already Equity objects in this cached finder, so they
        are returned unchanged.
        """
        return candidates
991
992
993
def was_active(reference_date_value, asset):
    """
    Report whether `asset` was alive at `reference_date_value`.

    The check is inclusive at both ends: an asset counts as active on its
    own start date and on its own end date.

    Parameters
    ----------
    reference_date_value : int
        The instant to test, as nanoseconds since the epoch. This is
        usually the ``value`` attribute of a pandas Timestamp.
    asset : Asset
        The asset whose lifetime is checked.

    Returns
    -------
    was_active : bool
        True iff ``asset.start_date.value <= reference_date_value`` and
        ``reference_date_value <= asset.end_date.value``.
    """
    started = asset.start_date.value <= reference_date_value
    if not started:
        # Short-circuit like a chained comparison would: asset.end_date is
        # never read when the asset had not started yet.
        return started
    return reference_date_value <= asset.end_date.value
    )
1017
1018
1019
def only_active_assets(reference_date_value, assets):
    """
    Keep only the members of `assets` that were alive at
    `reference_date_value`.

    Parameters
    ----------
    reference_date_value : int
        Nanoseconds since the epoch (for example ``Timestamp.value``). Each
        asset's liveness at this instant is tested with `was_active`.
    assets : iterable[Asset]
        Candidate assets. The iterable is consumed once, in order.

    Returns
    -------
    active_assets : list
        The assets from `assets` that were active on the requested date, in
        their original order.
    """
    def _is_alive(asset):
        return was_active(reference_date_value, asset)

    return list(filter(_is_alive, assets))
1039