Completed
Pull Request — master (#846) by Warren, created 03:06

lookup_expired_futures()   A

Complexity

Conditions 1

Size

Total Lines 67

Duplication

Lines 0
Ratio 0 %
Metric   Value
cc       1
dl       0
loc      67
rs       9.2817

How to fix: Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a coherent chunk of the long method into its own, well-named method and call it from the original.
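For example, the start-date filter that lookup_future_chain and lookup_expired_futures both build inline (the nested sa.case(...) expressions in the listing below) is a natural extraction target. The following is a minimal sketch of that extraction; the helper name _valid_on_or_after is ours and does not appear in the file under review.

import pandas as pd
import sqlalchemy as sa


# Illustrative helper (not part of the file under review): extracted from the
# duplicated sa.case(...) expressions in lookup_future_chain and
# lookup_expired_futures.
def _valid_on_or_after(fc_cols, date_value):
    """Build the SQL filter for contracts still valid at ``date_value``.

    If notice_date is NaT, fall back to expiration_date (and vice versa);
    otherwise require the earlier of the two dates to be >= ``date_value``.
    """
    return sa.case(
        [
            (fc_cols.notice_date == pd.NaT.value,
             fc_cols.expiration_date >= date_value),
            (fc_cols.expiration_date == pd.NaT.value,
             fc_cols.notice_date >= date_value),
        ],
        else_=(
            sa.func.min(fc_cols.notice_date, fc_cols.expiration_date)
            >= date_value
        ),
    )

With such a helper (plus a mirrored one for the <= end bound), the 67-line body of lookup_expired_futures reduces to a single select/order_by expression, and the duplication between it and lookup_future_chain disappears.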

1
# Copyright 2015 Quantopian, Inc.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
from abc import ABCMeta
16
from numbers import Integral
17
from operator import itemgetter
18
19
from logbook import Logger
20
import numpy as np
21
import pandas as pd
22
from pandas import isnull
23
from six import with_metaclass, string_types, viewkeys
24
from six.moves import map as imap, range
25
import sqlalchemy as sa
26
27
from zipline.errors import (
28
    EquitiesNotFound,
29
    FutureContractsNotFound,
30
    MapAssetIdentifierIndexError,
31
    MultipleSymbolsFound,
32
    RootSymbolNotFound,
33
    SidsNotFound,
34
    SymbolNotFound,
35
)
36
from zipline.assets import (
37
    Asset, Equity, Future,
38
)
39
from zipline.assets.asset_writer import (
40
    split_delimited_symbol,
41
    check_version_info,
42
    ASSET_DB_VERSION,
43
    asset_db_table_names,
44
    SQLITE_MAX_VARIABLE_NUMBER
45
)
46
from zipline.utils.control_flow import invert
47
48
log = Logger('assets.py')
49
50
# A set of fields that need to be converted to strings before building an
51
# Asset to avoid unicode fields
52
_asset_str_fields = frozenset({
53
    'symbol',
54
    'asset_name',
55
    'exchange',
56
})
57
58
# A set of fields that need to be converted to timestamps in UTC
59
_asset_timestamp_fields = frozenset({
60
    'start_date',
61
    'end_date',
62
    'first_traded',
63
    'notice_date',
64
    'expiration_date',
65
    'auto_close_date',
66
})
67
68
69
def _convert_asset_timestamp_fields(dict_):
70
    """
71
    Takes in a dict of Asset init args and converts dates to pd.Timestamps
72
    """
73
    for key in (_asset_timestamp_fields & viewkeys(dict_)):
74
        value = pd.Timestamp(dict_[key], tz='UTC')
75
        dict_[key] = None if isnull(value) else value
76
    return dict_
77
78
79
class AssetFinder(object):
80
    """
81
    An AssetFinder is an interface to a database of Asset metadata written by
82
    an ``AssetDBWriter``.
83
84
    This class provides methods for looking up assets by unique integer id or
85
    by symbol.  For historical reasons, we refer to these unique ids as 'sids'.
86
87
    Parameters
88
    ----------
89
    engine : str or SQLAlchemy.engine
90
        An engine with a connection to the asset database to use, or a string
91
        that can be parsed by SQLAlchemy as a URI.
92
93
    See Also
94
    --------
95
    :class:`zipline.assets.asset_writer.AssetDBWriter`
96
    """
97
    # Token used as a substitute for pickling objects that contain a
98
    # reference to an AssetFinder.
99
    PERSISTENT_TOKEN = "<AssetFinder>"
100
101
    def __init__(self, engine):
102
103
        self.engine = engine
104
        metadata = sa.MetaData(bind=engine)
105
        metadata.reflect(only=asset_db_table_names)
106
        for table_name in asset_db_table_names:
107
            setattr(self, table_name, metadata.tables[table_name])
108
109
        # Check the version info of the db for compatibility
110
        check_version_info(self.version_info, ASSET_DB_VERSION)
111
112
        # Cache for lookup of assets by sid, the objects in the asset lookup
113
        # may be shared with the results from equity and future lookup caches.
114
        #
115
        # The top level cache exists to minimize lookups on the asset type
116
        # routing.
117
        #
118
        # The caches are read through, i.e. accessing an asset through
119
        # retrieve_asset will populate the cache on first retrieval.
120
        self._caches = (self._asset_cache, self._asset_type_cache) = {}, {}
121
122
        # Populated on first call to `lifetimes`.
123
        self._asset_lifetimes = None
124
125
    def _reset_caches(self):
126
        """
127
        Reset our asset caches.
128
129
        You probably shouldn't call this method.
130
        """
131
        # This method exists as a workaround for the in-place mutating behavior
132
        # of `TradingAlgorithm._write_and_map_id_index_to_sids`.  No one else
133
        # should be calling this.
134
        for cache in self._caches:
135
            cache.clear()
136
137
    def lookup_asset_types(self, sids):
138
        """
139
        Retrieve asset types for a list of sids.
140
141
        Parameters
142
        ----------
143
        sids : list[int]
144
145
        Returns
146
        -------
147
        types : dict[sid -> str or None]
148
            Asset types for the provided sids.
149
        """
150
        found = {}
151
        missing = set()
152
153
        for sid in sids:
154
            try:
155
                found[sid] = self._asset_type_cache[sid]
156
            except KeyError:
157
                missing.add(sid)
158
159
        if not missing:
160
            return found
161
162
        router_cols = self.asset_router.c
163
164
        for assets in self._group_into_chunks(missing):
165
            query = sa.select((router_cols.sid, router_cols.asset_type)).where(
166
                self.asset_router.c.sid.in_(map(int, assets))
167
            )
168
            for sid, type_ in query.execute().fetchall():
169
                missing.remove(sid)
170
                found[sid] = self._asset_type_cache[sid] = type_
171
172
            for sid in missing:
173
                found[sid] = self._asset_type_cache[sid] = None
174
175
        return found
176
177
    @staticmethod
178
    def _group_into_chunks(items, chunk_size=SQLITE_MAX_VARIABLE_NUMBER):
179
        items = list(items)
180
        return [items[x:x+chunk_size]
181
                for x in range(0, len(items), chunk_size)]
182
183
    def group_by_type(self, sids):
184
        """
185
        Group a list of sids by asset type.
186
187
        Parameters
188
        ----------
189
        sids : list[int]
190
191
        Returns
192
        -------
193
        types : dict[str or None -> list[int]]
194
            A dict mapping unique asset types to lists of sids drawn from sids.
195
            If we fail to look up an asset, we assign it a key of None.
196
        """
197
        return invert(self.lookup_asset_types(sids))
198
199
    def retrieve_asset(self, sid, default_none=False):
200
        """
201
        Retrieve the Asset for a given sid.
202
        """
203
        return self.retrieve_all((sid,), default_none=default_none)[0]
204
205
    def retrieve_all(self, sids, default_none=False):
206
        """
207
        Retrieve all assets in `sids`.
208
209
        Parameters
210
        ----------
211
        sids : iterable of int
212
            Assets to retrieve.
213
        default_none : bool
214
            If True, return None for failed lookups.
215
            If False, raise `SidsNotFound`.
216
217
        Returns
218
        -------
219
        assets : list[Asset or None]
220
            A list of the same length as `sids` containing Assets (or Nones)
221
            corresponding to the requested sids.
222
223
        Raises
224
        ------
225
        SidsNotFound
226
            When a requested sid is not found and default_none=False.
227
        """
228
        hits, missing, failures = {}, set(), []
229
        for sid in sids:
230
            try:
231
                asset = self._asset_cache[sid]
232
                if not default_none and asset is None:
233
                    # Bail early if we've already cached that we don't know
234
                    # about an asset.
235
                    raise SidsNotFound(sids=[sid])
236
                hits[sid] = asset
237
            except KeyError:
238
                missing.add(sid)
239
240
        # All requests were cache hits.  Return requested sids in order.
241
        if not missing:
242
            return [hits[sid] for sid in sids]
243
244
        update_hits = hits.update
245
246
        # Look up cache misses by type.
247
        type_to_assets = self.group_by_type(missing)
248
249
        # Handle failures
250
        failures = {failure: None for failure in type_to_assets.pop(None, ())}
251
        update_hits(failures)
252
        self._asset_cache.update(failures)
253
254
        if failures and not default_none:
255
            raise SidsNotFound(sids=list(failures))
256
257
        # We don't update the asset cache here because it should already be
258
        # updated by `self.retrieve_equities`.
259
        update_hits(self.retrieve_equities(type_to_assets.pop('equity', ())))
260
        update_hits(
261
            self.retrieve_futures_contracts(type_to_assets.pop('future', ()))
262
        )
263
264
        # We shouldn't know about any other asset types.
265
        if type_to_assets:
266
            raise AssertionError(
267
                "Found asset types: %s" % list(type_to_assets.keys())
268
            )
269
270
        return [hits[sid] for sid in sids]
271
272
    def retrieve_equities(self, sids):
273
        """
274
        Retrieve Equity objects for a list of sids.
275
276
        Users generally shouldn't need to use this method (instead, they should
277
        prefer the more general/friendly `retrieve_all`), but it has a
278
        documented interface and tests because it's used upstream.
279
280
        Parameters
281
        ----------
282
        sids : iterable[int]
283
284
        Returns
285
        -------
286
        equities : dict[int -> Equity]
287
288
        Raises
289
        ------
290
        EquitiesNotFound
291
            When any requested asset isn't found.
292
        """
293
        return self._retrieve_assets(sids, self.equities, Equity)
294
295
    def _retrieve_equity(self, sid):
296
        return self.retrieve_equities((sid,))[sid]
297
298
    def retrieve_futures_contracts(self, sids):
299
        """
300
        Retrieve Future objects for an iterable of sids.
301
302
        Users generally shouldn't need to use this method (instead, they should
303
        prefer the more general/friendly `retrieve_all`), but it has a
304
        documented interface and tests because it's used upstream.
305
306
        Parameters
307
        ----------
308
        sids : iterable[int]
309
310
        Returns
311
        -------
312
        futures : dict[int -> Future]
313
314
        Raises
315
        ------
316
        FutureContractsNotFound
317
            When any requested asset isn't found.
318
        """
319
        return self._retrieve_assets(sids, self.futures_contracts, Future)
320
321
    @staticmethod
322
    def _select_assets_by_sid(asset_tbl, sids):
323
        return sa.select([asset_tbl]).where(
324
            asset_tbl.c.sid.in_(map(int, sids))
325
        )
326
327
    @staticmethod
328
    def _select_asset_by_symbol(asset_tbl, symbol):
329
        return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol)
330
331
    def _retrieve_assets(self, sids, asset_tbl, asset_type):
332
        """
333
        Internal function for loading assets from a table.
334
335
        This should be the only method of `AssetFinder` that writes Assets into
336
        self._asset_cache.
337
338
        Parameters
339
        ----------
340
        sids : iterable of int
341
            Asset ids to look up.
342
        asset_tbl : sqlalchemy.Table
343
            Table from which to query assets.
344
        asset_type : type
345
            Type of asset to be constructed.
346
347
        Returns
348
        -------
349
        assets : dict[int -> Asset]
350
            Dict mapping requested sids to the retrieved assets.
351
        """
352
        # Fastpath for empty request.
353
        if not sids:
354
            return {}
355
356
        cache = self._asset_cache
357
        hits = {}
358
359
        for assets in self._group_into_chunks(sids):
360
            # Load misses from the db.
361
            query = self._select_assets_by_sid(asset_tbl, assets)
362
363
            for row in imap(dict, query.execute().fetchall()):
364
                asset = asset_type(**_convert_asset_timestamp_fields(row))
365
                sid = asset.sid
366
                hits[sid] = cache[sid] = asset
367
368
        # If we get here, it means something in our code thought that a
369
        # particular sid was an equity/future and called this function with a
370
        # concrete type, but we couldn't actually resolve the asset.  This is
371
        # an error in our code, not a user-input error.
372
        misses = tuple(set(sids) - viewkeys(hits))
373
        if misses:
374
            if asset_type == Equity:
375
                raise EquitiesNotFound(sids=misses)
376
            else:
377
                raise FutureContractsNotFound(sids=misses)
378
        return hits
379
380
    def _get_fuzzy_candidates(self, fuzzy_symbol):
381
        candidates = sa.select(
382
            (self.equities.c.sid,)
383
        ).where(self.equities.c.fuzzy_symbol == fuzzy_symbol).order_by(
384
            self.equities.c.start_date.desc(),
385
            self.equities.c.end_date.desc()
386
        ).execute().fetchall()
387
        return candidates
388
389
    def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
390
        candidates = sa.select(
391
            (self.equities.c.sid,)
392
        ).where(
393
            sa.and_(
394
                self.equities.c.fuzzy_symbol == fuzzy_symbol,
395
                self.equities.c.start_date <= ad_value,
396
                self.equities.c.end_date >= ad_value
397
            )
398
        ).order_by(
399
            self.equities.c.start_date.desc(),
400
            self.equities.c.end_date.desc(),
401
        ).execute().fetchall()
402
        return candidates
403
404
    def _get_split_candidates_in_range(self,
405
                                       company_symbol,
406
                                       share_class_symbol,
407
                                       ad_value):
408
        candidates = sa.select(
409
            (self.equities.c.sid,)
410
        ).where(
411
            sa.and_(
412
                self.equities.c.company_symbol == company_symbol,
413
                self.equities.c.share_class_symbol == share_class_symbol,
414
                self.equities.c.start_date <= ad_value,
415
                self.equities.c.end_date >= ad_value
416
            )
417
        ).order_by(
418
            self.equities.c.start_date.desc(),
419
            self.equities.c.end_date.desc(),
420
        ).execute().fetchall()
421
        return candidates
422
423
    def _get_split_candidates(self, company_symbol, share_class_symbol):
424
        candidates = sa.select(
425
            (self.equities.c.sid,)
426
        ).where(
427
            sa.and_(
428
                self.equities.c.company_symbol == company_symbol,
429
                self.equities.c.share_class_symbol == share_class_symbol
430
            )
431
        ).order_by(
432
            self.equities.c.start_date.desc(),
433
            self.equities.c.end_date.desc(),
434
        ).execute().fetchall()
435
        return candidates
436
437
    def _resolve_no_matching_candidates(self,
438
                                        company_symbol,
439
                                        share_class_symbol,
440
                                        ad_value):
441
        candidates = sa.select((self.equities.c.sid,)).where(
442
            sa.and_(
443
                self.equities.c.company_symbol == company_symbol,
444
                self.equities.c.share_class_symbol ==
445
                share_class_symbol,
446
                self.equities.c.start_date <= ad_value),
447
        ).order_by(
448
            self.equities.c.end_date.desc(),
449
        ).execute().fetchall()
450
        return candidates
451
452
    def _get_best_candidate(self, candidates):
453
        return self._retrieve_equity(candidates[0]['sid'])
454
455
    def _get_equities_from_candidates(self, candidates):
456
        sids = map(itemgetter('sid'), candidates)
457
        results = self.retrieve_equities(sids)
458
        return [results[sid] for sid in sids]
459
460
    def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
461
        """
462
        Return the Equity matching the given symbol in the database.
463
464
        If multiple Equities are found and as_of_date is not set,
465
        raises MultipleSymbolsFound.
466
467
        If no Equity was active at as_of_date, raises SymbolNotFound.
468
        """
469
        company_symbol, share_class_symbol, fuzzy_symbol = \
470
            split_delimited_symbol(symbol)
471
        if as_of_date:
472
            # Format inputs
473
            as_of_date = pd.Timestamp(as_of_date).normalize()
474
            ad_value = as_of_date.value
475
476
            if fuzzy:
477
                # Search for a single exact match on the fuzzy column
478
                candidates = self._get_fuzzy_candidates_in_range(fuzzy_symbol,
479
                                                                 ad_value)
480
481
                # If exactly one SID exists for fuzzy_symbol, return that sid
482
                if len(candidates) == 1:
483
                    return self._get_best_candidate(candidates)
484
485
            # Search for exact matches of the split-up company_symbol and
486
            # share_class_symbol
487
            candidates = self._get_split_candidates_in_range(
488
                company_symbol,
489
                share_class_symbol,
490
                ad_value
491
            )
492
493
            # If exactly one SID exists for symbol, return that symbol
494
            # If multiple SIDs exist for symbol, return latest start_date with
495
            # end_date as a tie-breaker
496
            if candidates:
497
                return self._get_best_candidate(candidates)
498
499
            # If no SID exists for symbol, return SID with the
500
            # highest-but-not-over end_date
501
            elif not candidates:
502
                candidates = self._resolve_no_matching_candidates(
503
                    company_symbol,
504
                    share_class_symbol,
505
                    ad_value
506
                )
507
                if candidates:
508
                    return self._get_best_candidate(candidates)
509
510
            raise SymbolNotFound(symbol=symbol)
511
512
        else:
513
            # If this is a fuzzy look-up, check if there is exactly one match
514
            # for the fuzzy symbol
515
            if fuzzy:
516
                candidates = self._get_fuzzy_candidates(fuzzy_symbol)
517
                if len(candidates) == 1:
518
                    return self._get_best_candidate(candidates)
519
520
            candidates = self._get_split_candidates(company_symbol,
521
                                                    share_class_symbol)
522
            if len(candidates) == 1:
523
                return self._get_best_candidate(candidates)
524
            elif not candidates:
525
                raise SymbolNotFound(symbol=symbol)
526
            else:
527
                raise MultipleSymbolsFound(
528
                    symbol=symbol,
529
                    options=self._get_equities_from_candidates(candidates)
530
                )
531
532
    def lookup_future_symbol(self, symbol):
533
        """ Return the Future object for a given symbol.
534
535
        Parameters
536
        ----------
537
        symbol : str
538
            The symbol of the desired contract.
539
540
        Returns
541
        -------
542
        Future
543
            A Future object.
544
545
        Raises
546
        ------
547
        SymbolNotFound
548
            Raised when no contract named 'symbol' is found.
549
550
        """
551
552
        data = self._select_asset_by_symbol(self.futures_contracts, symbol)\
553
                   .execute().fetchone()
554
555
        # If no data found, raise an exception
556
        if not data:
557
            raise SymbolNotFound(symbol=symbol)
558
        return self.retrieve_asset(data['sid'])
559
560
    def lookup_future_chain(self, root_symbol, as_of_date):
561
        """ Return the futures chain for a given root symbol.
562
563
        Parameters
564
        ----------
565
        root_symbol : str
566
            Root symbol of the desired future.
567
568
        as_of_date : pd.Timestamp or pd.NaT
569
            Date at which the chain determination is rooted. I.e. the
570
            existing contract whose notice date/expiration date is first
571
            after this date is the primary contract, etc. If NaT is
572
            given, the chain is unbounded, and all contracts for this
573
            root symbol are returned.
574
575
        Returns
576
        -------
577
        list
578
            A list of Future objects, the chain for the given
579
            parameters.
580
581
        Raises
582
        ------
583
        RootSymbolNotFound
584
            Raised when a future chain could not be found for the given
585
            root symbol.
586
        """
587
588
        fc_cols = self.futures_contracts.c
589
590
        if as_of_date is pd.NaT:
591
            # If the as_of_date is NaT, get all contracts for this
592
            # root symbol.
593
            sids = list(map(
594
                itemgetter('sid'),
595
                sa.select((fc_cols.sid,)).where(
596
                    (fc_cols.root_symbol == root_symbol),
597
                ).order_by(
598
                    fc_cols.notice_date.asc(),
599
                ).execute().fetchall()))
600
        else:
601
            as_of_date = as_of_date.value
602
603
            sids = list(map(
604
                itemgetter('sid'),
605
                sa.select((fc_cols.sid,)).where(
606
                    (fc_cols.root_symbol == root_symbol) &
607
608
                    # Filter to contracts that are still valid. If both
609
                    # exist, use the one that comes first in time (i.e.
610
                    # the lower value). If either notice_date or
611
                    # expiration_date is NaT, use the other. If both are
612
                    # NaT, the contract cannot be included in any chain.
613
                    sa.case(
614
                        [
615
                            (
616
                                fc_cols.notice_date == pd.NaT.value,
617
                                fc_cols.expiration_date >= as_of_date
618
                            ),
619
                            (
620
                                fc_cols.expiration_date == pd.NaT.value,
621
                                fc_cols.notice_date >= as_of_date
622
                            )
623
                        ],
624
                        else_=(
625
                            sa.func.min(
626
                                fc_cols.notice_date,
627
                                fc_cols.expiration_date
628
                            ) >= as_of_date
629
                        )
630
                    )
631
                ).order_by(
632
                    # Sort using expiration_date if valid. If it's NaT,
633
                    # use notice_date instead.
634
                    sa.case(
635
                        [
636
                            (
637
                                fc_cols.expiration_date == pd.NaT.value,
638
                                fc_cols.notice_date
639
                            )
640
                        ],
641
                        else_=fc_cols.expiration_date
642
                    ).asc()
643
                ).execute().fetchall()
644
            ))
645
646
        if not sids:
647
            # Check if root symbol exists.
648
            count = sa.select((sa.func.count(fc_cols.sid),)).where(
649
                fc_cols.root_symbol == root_symbol,
650
            ).scalar()
651
            if count == 0:
652
                raise RootSymbolNotFound(root_symbol=root_symbol)
653
654
        contracts = self.retrieve_futures_contracts(sids)
655
        return [contracts[sid] for sid in sids]
656
657
    def lookup_expired_futures(self, start, end):
658
        start = start.value
659
        end = end.value
660
661
        fc_cols = self.futures_contracts.c
662
663
        sids = list(map(
664
            itemgetter('sid'),
665
            sa.select((fc_cols.sid,)).where(
666
667
                    # Filter to contracts that are still valid. If both
668
                    # exist, use the one that comes first in time (i.e.
669
                    # the lower value). If either notice_date or
670
                    # expiration_date is NaT, use the other. If both are
671
                    # NaT, the contract cannot be included in any chain.
672
                    sa.case(
673
                        [
674
                            (
675
                                fc_cols.notice_date == pd.NaT.value,
676
                                fc_cols.expiration_date >= start
677
                            ),
678
                            (
679
                                fc_cols.expiration_date == pd.NaT.value,
680
                                fc_cols.notice_date >= start
681
                            )
682
                        ],
683
                        else_=(
684
                            sa.func.min(
685
                                fc_cols.notice_date,
686
                                fc_cols.expiration_date
687
                            ) >= start
688
                        )
689
                    )
690
                    & sa.case(
691
                        [
692
                            (
693
                                fc_cols.notice_date == pd.NaT.value,
694
                                fc_cols.expiration_date <= end
695
                            ),
696
                            (
697
                                fc_cols.expiration_date == pd.NaT.value,
698
                                fc_cols.notice_date <= end
699
                            )
700
                        ],
701
                        else_=(
702
                            sa.func.min(
703
                                fc_cols.notice_date,
704
                                fc_cols.expiration_date
705
                            ) <= end
706
                        )
707
                    )
708
                    ).order_by(
709
                # Sort using expiration_date if valid. If it's
710
                # NaT, use notice_date instead.
711
                sa.case(
712
                    [
713
                        (
714
                            fc_cols.expiration_date == pd.NaT.value,
715
                            fc_cols.notice_date
716
                        )
717
                    ],
718
                    else_=fc_cols.expiration_date
719
                ).asc()
720
            ).execute().fetchall()
721
        ))
722
723
        return sids
724
725
    @property
726
    def sids(self):
727
        return tuple(map(
728
            itemgetter('sid'),
729
            sa.select((self.asset_router.c.sid,)).execute().fetchall(),
730
        ))
731
732
    def _lookup_generic_scalar(self,
733
                               asset_convertible,
734
                               as_of_date,
735
                               matches,
736
                               missing):
737
        """
738
        Convert asset_convertible to an asset.
739
740
        On success, append to matches.
741
        On failure, append to missing.
742
        """
743
        if isinstance(asset_convertible, Asset):
744
            matches.append(asset_convertible)
745
746
        elif isinstance(asset_convertible, Integral):
747
            try:
748
                result = self.retrieve_asset(int(asset_convertible))
749
            except SidsNotFound:
750
                missing.append(asset_convertible)
751
                return None
752
            matches.append(result)
753
754
        elif isinstance(asset_convertible, string_types):
755
            try:
756
                matches.append(
757
                    self.lookup_symbol(asset_convertible, as_of_date)
758
                )
759
            except SymbolNotFound:
760
                missing.append(asset_convertible)
761
                return None
762
        else:
763
            raise NotAssetConvertible(
764
                "Input was %s, not AssetConvertible."
765
                % asset_convertible
766
            )
767
768
    def lookup_generic(self,
769
                       asset_convertible_or_iterable,
770
                       as_of_date):
771
        """
772
        Convert an AssetConvertible or iterable of AssetConvertibles into
773
        a list of Asset objects.
774
775
        This method exists primarily as a convenience for implementing
776
        user-facing APIs that can handle multiple kinds of input.  It should
777
        not be used for internal code where we already know the expected types
778
        of our inputs.
779
780
        Returns a pair of objects, the first of which is the result of the
781
        conversion, and the second of which is a list containing any values
782
        that couldn't be resolved.
783
        """
784
        matches = []
785
        missing = []
786
787
        # Interpret input as scalar.
788
        if isinstance(asset_convertible_or_iterable, AssetConvertible):
789
            self._lookup_generic_scalar(
790
                asset_convertible=asset_convertible_or_iterable,
791
                as_of_date=as_of_date,
792
                matches=matches,
793
                missing=missing,
794
            )
795
            try:
796
                return matches[0], missing
797
            except IndexError:
798
                if hasattr(asset_convertible_or_iterable, '__int__'):
799
                    raise SidsNotFound(sids=[asset_convertible_or_iterable])
800
                else:
801
                    raise SymbolNotFound(symbol=asset_convertible_or_iterable)
802
803
        # Interpret input as iterable.
804
        try:
805
            iterator = iter(asset_convertible_or_iterable)
806
        except TypeError:
807
            raise NotAssetConvertible(
808
                "Input was not a AssetConvertible "
809
                "or iterable of AssetConvertible."
810
            )
811
812
        for obj in iterator:
813
            self._lookup_generic_scalar(obj, as_of_date, matches, missing)
814
        return matches, missing
815
816
    def map_identifier_index_to_sids(self, index, as_of_date):
817
        """
818
        This method is for use in sanitizing a user's DataFrame or Panel
819
        inputs.
820
821
        Takes the given index of identifiers, checks their types, builds assets
822
        if necessary, and returns a list of the sids that correspond to the
823
        input index.
824
825
        Parameters
826
        ----------
827
        index : Iterable
828
            An iterable containing ints, strings, or Assets
829
        as_of_date : pandas.Timestamp
830
            A date to be used to resolve any dual-mapped symbols
831
832
        Returns
833
        -------
834
        List
835
            A list of integer sids corresponding to the input index
836
        """
837
        # This method assumes that the type of the objects in the index is
838
        # consistent and can, therefore, be taken from the first identifier
839
        first_identifier = index[0]
840
841
        # Ensure that input is AssetConvertible (integer, string, or Asset)
842
        if not isinstance(first_identifier, AssetConvertible):
843
            raise MapAssetIdentifierIndexError(obj=first_identifier)
844
845
        # If sids are provided, no mapping is necessary
846
        if isinstance(first_identifier, Integral):
847
            return index
848
849
        # Look up all Assets for mapping
850
        matches = []
851
        missing = []
852
        for identifier in index:
853
            self._lookup_generic_scalar(identifier, as_of_date,
854
                                        matches, missing)
855
856
        if missing:
857
            raise ValueError("Missing assets for identifiers: %s" % missing)
858
859
        # Return a list of the sids of the found assets
860
        return [asset.sid for asset in matches]
861
862
    def _compute_asset_lifetimes(self):
863
        """
864
        Compute and cache a recarray of asset lifetimes.
865
        """
866
        equities_cols = self.equities.c
867
        buf = np.array(
868
            tuple(
869
                sa.select((
870
                    equities_cols.sid,
871
                    equities_cols.start_date,
872
                    equities_cols.end_date,
873
                )).execute(),
874
            ), dtype='<f8',  # use doubles so we get NaNs
875
        )
876
        lifetimes = np.recarray(
877
            buf=buf,
878
            shape=(len(buf),),
879
            dtype=[
880
                ('sid', '<f8'),
881
                ('start', '<f8'),
882
                ('end', '<f8')
883
            ],
884
        )
885
        start = lifetimes.start
886
        end = lifetimes.end
887
        start[np.isnan(start)] = 0  # convert missing starts to 0
888
        end[np.isnan(end)] = np.iinfo(int).max  # convert missing end to INTMAX
889
        # Cast the results back down to int.
890
        return lifetimes.astype([
891
            ('sid', '<i8'),
892
            ('start', '<i8'),
893
            ('end', '<i8'),
894
        ])
895
896
    def lifetimes(self, dates, include_start_date):
897
        """
898
        Compute a DataFrame representing asset lifetimes for the specified date
899
        range.
900
901
        Parameters
902
        ----------
903
        dates : pd.DatetimeIndex
904
            The dates for which to compute lifetimes.
905
        include_start_date : bool
906
            Whether or not to count the asset as alive on its start_date.
907
908
            This is useful in a backtesting context where `lifetimes` is being
909
            used to signify "do I have data for this asset as of the morning of
910
            this date?"  For many financial metrics (e.g. daily close), data
911
            isn't available for an asset until the end of the asset's first
912
            day.
913
914
        Returns
915
        -------
916
        lifetimes : pd.DataFrame
917
            A frame of dtype bool with `dates` as index and an Int64Index of
918
            assets as columns.  The value at `lifetimes.loc[date, asset]` will
919
            be True iff `asset` existed on `date`.  If `include_start_date` is
920
            False, then lifetimes.loc[date, asset] will be false when date ==
921
            asset.start_date.
922
923
        See Also
924
        --------
925
        numpy.putmask
926
        zipline.pipeline.engine.SimplePipelineEngine._compute_root_mask
927
        """
928
        # This is a less than ideal place to do this, because if someone adds
929
        # assets to the finder after we've touched lifetimes we won't have
930
        # those new assets available.  Mutability is not my favorite
931
        # programming feature.
932
        if self._asset_lifetimes is None:
933
            self._asset_lifetimes = self._compute_asset_lifetimes()
934
        lifetimes = self._asset_lifetimes
935
936
        raw_dates = dates.asi8[:, None]
937
        if include_start_date:
938
            mask = lifetimes.start <= raw_dates
939
        else:
940
            mask = lifetimes.start < raw_dates
941
        mask &= (raw_dates <= lifetimes.end)
942
943
        return pd.DataFrame(mask, index=dates, columns=lifetimes.sid)
944
945
946
class AssetConvertible(with_metaclass(ABCMeta)):
947
    """
948
    ABC for types that are convertible to integer representations of
949
    Assets.
950
951
    Includes Asset, six.string_types, and Integral
952
    """
953
    pass
954
955
956
AssetConvertible.register(Integral)
957
AssetConvertible.register(Asset)
958
# Use six.string_types for Python2/3 compatibility
959
for _type in string_types:
960
    AssetConvertible.register(_type)
961
962
963
class NotAssetConvertible(ValueError):
964
    pass
965
966
967
class AssetFinderCachedEquities(AssetFinder):
968
    """
969
    An extension to AssetFinder that loads all equities from the equities table
970
    into memory and overrides the methods that lookup_symbol uses to look up
971
    those equities.
972
    """
973
974
    def __init__(self, engine):
975
        super(AssetFinderCachedEquities, self).__init__(engine)
976
        self.fuzzy_symbol_hashed_equities = {}
977
        self.company_share_class_hashed_equities = {}
978
        self.hashed_equities = sa.select(self.equities.c).execute().fetchall()
979
        self._load_hashed_equities()
980
981
    def _load_hashed_equities(self):
982
        """
983
        Populates two maps - fuzzy symbol to list of equities having that
984
        fuzzy symbol and company symbol/share class symbol to list of
985
        equities having that combination of company symbol/share class symbol.
986
        """
987
        for equity in self.hashed_equities:
988
            company_symbol = equity['company_symbol']
989
            share_class_symbol = equity['share_class_symbol']
990
            fuzzy_symbol = equity['fuzzy_symbol']
991
            asset = self._convert_row_to_equity(equity)
992
            self.company_share_class_hashed_equities.setdefault(
993
                (company_symbol, share_class_symbol),
994
                []
995
            ).append(asset)
996
            self.fuzzy_symbol_hashed_equities.setdefault(
997
                fuzzy_symbol, []
998
            ).append(asset)
999
1000
    def _convert_row_to_equity(self, row):
1001
        """
1002
        Converts a SQLAlchemy equity row to an Equity object.
1003
        """
1004
        return Equity(**_convert_asset_timestamp_fields(dict(row)))
1005
1006
    def _get_fuzzy_candidates(self, fuzzy_symbol):
1007
        return self.fuzzy_symbol_hashed_equities.get(fuzzy_symbol, ())
1008
1009
    def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
1010
        return only_active_assets(
1011
            ad_value,
1012
            self._get_fuzzy_candidates(fuzzy_symbol),
1013
        )
1014
1015
    def _get_split_candidates(self, company_symbol, share_class_symbol):
1016
        return self.company_share_class_hashed_equities.get(
1017
            (company_symbol, share_class_symbol),
1018
            (),
1019
        )
1020
1021
    def _get_split_candidates_in_range(self,
1022
                                       company_symbol,
1023
                                       share_class_symbol,
1024
                                       ad_value):
1025
        return sorted(
1026
            only_active_assets(
1027
                ad_value,
1028
                self._get_split_candidates(company_symbol, share_class_symbol),
1029
            ),
1030
            key=lambda x: (x.start_date, x.end_date),
1031
            reverse=True,
1032
        )
1033
1034
    def _resolve_no_matching_candidates(self,
1035
                                        company_symbol,
1036
                                        share_class_symbol,
1037
                                        ad_value):
1038
        equities = self._get_split_candidates(
1039
            company_symbol, share_class_symbol
1040
        )
1041
        partial_candidates = []
1042
        for equity in equities:
1043
            if equity.start_date.value <= ad_value:
1044
                partial_candidates.append(equity)
1045
        if partial_candidates:
1046
            partial_candidates = sorted(
1047
                partial_candidates,
1048
                key=lambda x: x.end_date,
1049
                reverse=True
1050
            )
1051
        return partial_candidates
1052
1053
    def _get_best_candidate(self, candidates):
1054
        return candidates[0]
1055
1056
    def _get_equities_from_candidates(self, candidates):
1057
        return candidates
1058
1059
1060
def was_active(reference_date_value, asset):
1061
    """
1062
    Whether or not `asset` was active at the time corresponding to
1063
    `reference_date_value`.
1064
1065
    Parameters
1066
    ----------
1067
    reference_date_value : int
1068
        Date, represented as nanoseconds since EPOCH, for which we want to know
1069
        if `asset` was alive.  This is generally the result of accessing the
1070
        `value` attribute of a pandas Timestamp.
1071
    asset : Asset
1072
        The asset object to check.
1073
1074
    Returns
1075
    -------
1076
    was_active : bool
1077
        Whether or not the `asset` existed at the specified time.
1078
    """
1079
    return (
1080
        asset.start_date.value
1081
        <= reference_date_value
1082
        <= asset.end_date.value
1083
    )
1084
1085
1086
def only_active_assets(reference_date_value, assets):
1087
    """
1088
    Filter an iterable of Asset objects down to just assets that were alive at
1089
    the time corresponding to `reference_date_value`.
1090
1091
    Parameters
1092
    ----------
1093
    reference_date_value : int
1094
        Date, represented as nanoseconds since EPOCH, for which we want to know
1095
        if `asset` was alive.  This is generally the result of accessing the
1096
        `value` attribute of a pandas Timestamp.
1097
    assets : iterable[Asset]
1098
        The assets to filter.
1099
1100
    Returns
1101
    -------
1102
    active_assets : list
1103
        List of the active assets from `assets` on the requested date.
1104
    """
1105
    return [a for a in assets if was_active(reference_date_value, a)]
1106