Total Complexity | 68 |
Total Lines | 677 |
Duplicated Lines | 0 % |
Complex classes like tests.AssetFinderTestCase often do a lot of different things. To break such a class down, you need to identify a cohesive component within it. A common way to find such a component is to look for fields or methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # |
||
class AssetFinderTestCase(TestCase):
    """Tests for AssetFinder: symbol lookup, metadata writing, and retrieval.

    ``setUp`` records the finder class in ``self.asset_finder_type``, and most
    tests build their finder from it, so a subclass can presumably rerun
    these tests against a different finder implementation by overriding it.
    """

    def setUp(self):
        # Fresh environment per test; ``noop_load`` is passed as the loader
        # (presumably so nothing external is loaded -- confirm).
        self.env = TradingEnvironment(load=noop_load)
        self.asset_finder_type = AssetFinder

    def test_lookup_symbol_delimited(self):
        """A delimited symbol ('TEST.1') is found through any of the
        delimiters '-', '/', '_' or '.', but not without a delimiter, and
        not through '@'."""
        as_of = pd.Timestamp('2013-01-01', tz='UTC')
        frame = pd.DataFrame.from_records(
            [
                {
                    'sid': i,
                    'symbol': 'TEST.%d' % i,
                    'company_name': "company%d" % i,
                    'start_date': as_of.value,
                    'end_date': as_of.value,
                    'exchange': uuid.uuid4().hex
                }
                for i in range(3)
            ]
        )
        self.env.write_data(equities_df=frame)
        finder = self.asset_finder_type(self.env.engine)
        # All three assets are retrieved (warming the finder), but only
        # asset 1 is compared against below.
        _, asset_1, _ = (
            finder.retrieve_asset(i) for i in range(3)
        )

        # We do it twice to catch caching bugs.
        for _ in range(2):
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('TEST', as_of)
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('TEST1', as_of)
            # '@' is not a supported delimiter.
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('TEST@1', as_of)

            # Any supported delimiter resolves to the same asset.
            for fuzzy_char in ['-', '/', '_', '.']:
                self.assertEqual(
                    asset_1,
                    finder.lookup_symbol('TEST%s1' % fuzzy_char, as_of)
                )

    def test_lookup_symbol_fuzzy(self):
        """``fuzzy=True`` lets 'PRTYHRD' match 'PRTY_HRD'; exact symbols
        resolve the same way with or without a date and with or without
        fuzzy matching."""
        metadata = {
            0: {'symbol': 'PRTY_HRD'},
            1: {'symbol': 'BRKA'},
            2: {'symbol': 'BRK_A'},
        }
        self.env.write_data(equities_data=metadata)
        finder = self.env.asset_finder
        dt = pd.Timestamp('2013-01-01', tz='UTC')

        # Try combos of looking up PRTYHRD with and without a time or fuzzy.
        # Both non-fuzzy lookups find nothing.
        with self.assertRaises(SymbolNotFound):
            finder.lookup_symbol('PRTYHRD', None)
        with self.assertRaises(SymbolNotFound):
            finder.lookup_symbol('PRTYHRD', dt)
        # Both fuzzy lookups work.
        self.assertEqual(0, finder.lookup_symbol('PRTYHRD', None, fuzzy=True))
        self.assertEqual(0, finder.lookup_symbol('PRTYHRD', dt, fuzzy=True))

        # Try combos of looking up PRTY_HRD, all returning sid 0.
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', None))
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', dt))
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', None, fuzzy=True))
        self.assertEqual(0, finder.lookup_symbol('PRTY_HRD', dt, fuzzy=True))

        # Try combos of looking up BRKA, all returning sid 1.
        self.assertEqual(1, finder.lookup_symbol('BRKA', None))
        self.assertEqual(1, finder.lookup_symbol('BRKA', dt))
        self.assertEqual(1, finder.lookup_symbol('BRKA', None, fuzzy=True))
        self.assertEqual(1, finder.lookup_symbol('BRKA', dt, fuzzy=True))

        # Try combos of looking up BRK_A, all returning sid 2.
        self.assertEqual(2, finder.lookup_symbol('BRK_A', None))
        self.assertEqual(2, finder.lookup_symbol('BRK_A', dt))
        self.assertEqual(2, finder.lookup_symbol('BRK_A', None, fuzzy=True))
        self.assertEqual(2, finder.lookup_symbol('BRK_A', dt, fuzzy=True))

    def test_lookup_symbol(self):
        """A symbol reused by several non-overlapping assets is ambiguous
        without a date and resolves to the right sid for each date."""
        # Incrementing by two so that start and end dates for each
        # generated Asset don't overlap (each Asset's end_date is the
        # day after its start date.)
        dates = pd.date_range('2013-01-01', freq='2D', periods=5, tz='UTC')
        df = pd.DataFrame.from_records(
            [
                {
                    'sid': i,
                    'symbol': 'existing',
                    'start_date': date.value,
                    'end_date': (date + timedelta(days=1)).value,
                    'exchange': 'NYSE',
                }
                for i, date in enumerate(dates)
            ]
        )
        self.env.write_data(equities_df=df)
        finder = self.asset_finder_type(self.env.engine)
        for _ in range(2):  # Run checks twice to test for caching bugs.
            with self.assertRaises(SymbolNotFound):
                finder.lookup_symbol('NON_EXISTING', dates[0])

            with self.assertRaises(MultipleSymbolsFound):
                finder.lookup_symbol('EXISTING', None)

            for i, date in enumerate(dates):
                # Verify that we correctly resolve multiple symbols using
                # the supplied date.
                result = finder.lookup_symbol('EXISTING', date)
                self.assertEqual(result.symbol, 'EXISTING')
                self.assertEqual(result.sid, i)

    def test_lookup_symbol_from_multiple_valid(self):
        """Regression test for the tie-breaking rules used when several
        assets hold the same symbol at the same time."""
        # This test asserts that we resolve conflicts in accordance with the
        # following rules when we have multiple assets holding the same symbol
        # at the same time:

        # If multiple SIDs exist for symbol S at time T, return the candidate
        # SID whose start_date is highest. (200 cases)

        # If multiple SIDs exist for symbol S at time T and the best
        # candidates share the highest start_date, return the SID with the
        # highest end_date. (34 cases)

        # It is the opinion of the author (ssanderson) that we should consider
        # this malformed input and fail here. But this is the current intended
        # behavior of the code, and I accidentally broke it while refactoring.
        # These will serve as regression tests until the time comes that we
        # decide to enforce this as an error.

        # See https://github.com/quantopian/zipline/issues/837 for more
        # details.

        df = pd.DataFrame.from_records(
            [
                {
                    'sid': 1,
                    'symbol': 'multiple',
                    'start_date': pd.Timestamp('2010-01-01'),
                    'end_date': pd.Timestamp('2012-01-01'),
                    'exchange': 'NYSE'
                },
                # Same as asset 1, but with a later end date.
                {
                    'sid': 2,
                    'symbol': 'multiple',
                    'start_date': pd.Timestamp('2010-01-01'),
                    'end_date': pd.Timestamp('2013-01-01'),
                    'exchange': 'NYSE'
                },
                # Same as asset 1, but with a later start_date.
                {
                    'sid': 3,
                    'symbol': 'multiple',
                    'start_date': pd.Timestamp('2011-01-01'),
                    'end_date': pd.Timestamp('2012-01-01'),
                    'exchange': 'NYSE'
                },
            ]
        )

        def check(expected_sid, date):
            # ``finder`` is bound by the ``with`` statement below; the
            # closure looks it up when called.
            result = finder.lookup_symbol(
                'MULTIPLE', date,
            )
            self.assertEqual(result.symbol, 'MULTIPLE')
            self.assertEqual(result.sid, expected_sid)

        with tmp_asset_finder(finder_cls=self.asset_finder_type,
                              equities=df) as finder:
            self.assertIsInstance(finder, self.asset_finder_type)

            # Sids 1 and 2 are eligible here. We should get asset 2 because it
            # has the later end_date.
            check(2, pd.Timestamp('2010-12-31'))

            # Sids 1, 2, and 3 are eligible here. We should get sid 3 because
            # it has a later start_date.
            check(3, pd.Timestamp('2011-01-01'))

    def test_lookup_generic(self):
        """
        Ensure that lookup_generic works with various permutations of inputs.
        """
        with build_lookup_generic_cases(self.asset_finder_type) as cases:
            for finder, symbols, reference_date, expected in cases:
                results, missing = finder.lookup_generic(symbols,
                                                         reference_date)
                self.assertEqual(results, expected)
                self.assertEqual(missing, [])

    def test_lookup_generic_handle_missing(self):
        """``lookup_generic`` returns found assets in input order and reports
        unknown or not-yet-started symbols in ``missing``."""
        data = pd.DataFrame.from_records(
            [
                {
                    'sid': 0,
                    'symbol': 'real',
                    'start_date': pd.Timestamp('2013-1-1', tz='UTC'),
                    'end_date': pd.Timestamp('2014-1-1', tz='UTC'),
                    'exchange': '',
                },
                {
                    'sid': 1,
                    'symbol': 'also_real',
                    'start_date': pd.Timestamp('2013-1-1', tz='UTC'),
                    'end_date': pd.Timestamp('2014-1-1', tz='UTC'),
                    'exchange': '',
                },
                # Sid whose end date is before our query date. We should
                # still correctly find it.
                {
                    'sid': 2,
                    'symbol': 'real_but_old',
                    'start_date': pd.Timestamp('2002-1-1', tz='UTC'),
                    'end_date': pd.Timestamp('2003-1-1', tz='UTC'),
                    'exchange': '',
                },
                # Sid whose start_date is **after** our query date. We should
                # **not** find it.
                {
                    'sid': 3,
                    'symbol': 'real_but_in_the_future',
                    'start_date': pd.Timestamp('2014-1-1', tz='UTC'),
                    'end_date': pd.Timestamp('2020-1-1', tz='UTC'),
                    'exchange': 'THE FUTURE',
                },
            ]
        )
        self.env.write_data(equities_df=data)
        finder = self.asset_finder_type(self.env.engine)
        # A mix of symbols and a raw sid (1).
        results, missing = finder.lookup_generic(
            ['REAL', 1, 'FAKE', 'REAL_BUT_OLD', 'REAL_BUT_IN_THE_FUTURE'],
            pd.Timestamp('2013-02-01', tz='UTC'),
        )

        self.assertEqual(len(results), 3)
        self.assertEqual(results[0].symbol, 'REAL')
        self.assertEqual(results[0].sid, 0)
        self.assertEqual(results[1].symbol, 'ALSO_REAL')
        self.assertEqual(results[1].sid, 1)
        self.assertEqual(results[2].symbol, 'REAL_BUT_OLD')
        self.assertEqual(results[2].sid, 2)

        self.assertEqual(len(missing), 2)
        self.assertEqual(missing[0], 'FAKE')
        self.assertEqual(missing[1], 'REAL_BUT_IN_THE_FUTURE')

    def test_insert_metadata(self):
        """Written equity metadata round-trips into an Equity, and unknown
        metadata fields are not exposed as attributes."""
        data = {0: {'start_date': '2014-01-01',
                    'end_date': '2015-01-01',
                    'symbol': "PLAY",
                    'foo_data': "FOO"}}
        self.env.write_data(equities_data=data)
        finder = self.asset_finder_type(self.env.engine)
        # Test proper insertion.
        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual('PLAY', equity.symbol)
        self.assertEqual(pd.Timestamp('2015-01-01', tz='UTC'),
                         equity.end_date)

        # Test invalid field.
        with self.assertRaises(AttributeError):
            equity.foo_data

    def test_consume_metadata(self):
        """Equity metadata can be written both as a dict and as a
        DataFrame."""
        # Test dict consumption.
        dict_to_consume = {0: {'symbol': 'PLAY'},
                           1: {'symbol': 'MSFT'}}
        self.env.write_data(equities_data=dict_to_consume)
        finder = self.asset_finder_type(self.env.engine)

        equity = finder.retrieve_asset(0)
        self.assertIsInstance(equity, Equity)
        self.assertEqual('PLAY', equity.symbol)

        # Test dataframe consumption. Use ``.loc`` rather than chained
        # ``df[col][row] = ...`` assignment, which may write to a temporary
        # copy and silently leave ``df`` unchanged.
        df = pd.DataFrame(columns=['asset_name', 'exchange'], index=[0, 1])
        df.loc[0, 'asset_name'] = "Dave'N'Busters"
        df.loc[0, 'exchange'] = "NASDAQ"
        df.loc[1, 'asset_name'] = "Microsoft"
        df.loc[1, 'exchange'] = "NYSE"
        # Start from a fresh environment, presumably so these rows do not
        # collide with sids 0 and 1 written above.
        self.env = TradingEnvironment(load=noop_load)
        self.env.write_data(equities_df=df)
        finder = self.asset_finder_type(self.env.engine)
        self.assertEqual('NASDAQ', finder.retrieve_asset(0).exchange)
        self.assertEqual('Microsoft', finder.retrieve_asset(1).asset_name)

    def test_consume_asset_as_identifier(self):
        """Equity and Future objects can be written directly and are
        retrieved as equal assets with the same end dates."""
        # Build some end dates.
        eq_end = pd.Timestamp('2012-01-01', tz='UTC')
        fut_end = pd.Timestamp('2008-01-01', tz='UTC')

        # Build some simple Assets.
        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)
        future_asset = Future(200, symbol="TESTFUT", end_date=fut_end)

        # Consume the Assets.
        self.env.write_data(equities_identifiers=[equity_asset],
                            futures_identifiers=[future_asset])
        finder = self.asset_finder_type(self.env.engine)

        # Test equality with newly built Assets.
        self.assertEqual(equity_asset, finder.retrieve_asset(1))
        self.assertEqual(future_asset, finder.retrieve_asset(200))
        self.assertEqual(eq_end, finder.retrieve_asset(1).end_date)
        self.assertEqual(fut_end, finder.retrieve_asset(200).end_date)

    def test_sid_assignment(self):
        """With ``allow_sid_assignment=True``, bare symbols get distinct
        sids."""
        # This metadata does not contain SIDs.
        metadata = ['PLAY', 'MSFT']

        today = normalize_date(pd.Timestamp('2015-07-09', tz='UTC'))

        # Write data with sid assignment.
        self.env.write_data(equities_identifiers=metadata,
                            allow_sid_assignment=True)

        # Verify that Assets were built and different sids were assigned.
        finder = self.asset_finder_type(self.env.engine)
        play = finder.lookup_symbol('PLAY', today)
        msft = finder.lookup_symbol('MSFT', today)
        self.assertEqual('PLAY', play.symbol)
        self.assertIsNotNone(play.sid)
        self.assertNotEqual(play.sid, msft.sid)

    def test_sid_assignment_failure(self):
        """With ``allow_sid_assignment=False``, writing bare symbols raises
        SidAssignmentError."""
        # This metadata does not contain SIDs.
        metadata = ['PLAY', 'MSFT']

        # Write data without sid assignment, asserting failure.
        with self.assertRaises(SidAssignmentError):
            self.env.write_data(equities_identifiers=metadata,
                                allow_sid_assignment=False)

    def test_security_dates_warning(self):
        """Each of the ``security_*`` attributes emits a DeprecationWarning."""
        # Build an asset with an end_date.
        eq_end = pd.Timestamp('2012-01-01', tz='UTC')
        equity_asset = Equity(1, symbol="TESTEQ", end_date=eq_end)

        # Catch all warnings.
        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")
            equity_asset.security_start_date
            equity_asset.security_end_date
            equity_asset.security_name
            # Verify the warnings.
            self.assertEqual(3, len(w))
            for warning in w:
                self.assertTrue(issubclass(warning.category,
                                           DeprecationWarning))

    def test_lookup_future_chain(self):
        """``lookup_future_chain`` returns the contracts still valid at a
        date, in order, and the whole chain for ``pd.NaT``."""
        metadata = {
            # Notice day is today, so should be valid.
            0: {
                'symbol': 'ADN15',
                'root_symbol': 'AD',
                'notice_date': pd.Timestamp('2015-05-14', tz='UTC'),
                'expiration_date': pd.Timestamp('2015-06-14', tz='UTC'),
                'start_date': pd.Timestamp('2015-01-01', tz='UTC')
            },
            1: {
                'symbol': 'ADV15',
                'root_symbol': 'AD',
                'notice_date': pd.Timestamp('2015-08-14', tz='UTC'),
                'expiration_date': pd.Timestamp('2015-09-14', tz='UTC'),
                'start_date': pd.Timestamp('2015-01-01', tz='UTC')
            },
            # Starts trading today, so should be valid.
            2: {
                'symbol': 'ADF16',
                'root_symbol': 'AD',
                'notice_date': pd.Timestamp('2015-11-16', tz='UTC'),
                'expiration_date': pd.Timestamp('2015-12-16', tz='UTC'),
                'start_date': pd.Timestamp('2015-05-14', tz='UTC')
            },
            # Starts trading in August, so not valid.
            3: {
                'symbol': 'ADX16',
                'root_symbol': 'AD',
                'notice_date': pd.Timestamp('2015-11-16', tz='UTC'),
                'expiration_date': pd.Timestamp('2015-12-16', tz='UTC'),
                'start_date': pd.Timestamp('2015-08-01', tz='UTC')
            },
            # Notice date comes after expiration.
            4: {
                'symbol': 'ADZ16',
                'root_symbol': 'AD',
                'notice_date': pd.Timestamp('2016-11-25', tz='UTC'),
                'expiration_date': pd.Timestamp('2016-11-16', tz='UTC'),
                'start_date': pd.Timestamp('2015-08-01', tz='UTC')
            },
            # This contract has no start date and also this contract should be
            # last in all chains.
            5: {
                'symbol': 'ADZ20',
                'root_symbol': 'AD',
                'notice_date': pd.Timestamp('2020-11-25', tz='UTC'),
                'expiration_date': pd.Timestamp('2020-11-16', tz='UTC')
            },
        }
        self.env.write_data(futures_data=metadata)
        finder = self.asset_finder_type(self.env.engine)
        dt = pd.Timestamp('2015-05-14', tz='UTC')
        dt_2 = pd.Timestamp('2015-10-14', tz='UTC')
        dt_3 = pd.Timestamp('2016-11-17', tz='UTC')

        # Check that we get the expected number of contracts, in the
        # right order.
        ad_contracts = finder.lookup_future_chain('AD', dt)
        self.assertEqual(len(ad_contracts), 6)
        self.assertEqual(ad_contracts[0].sid, 0)
        self.assertEqual(ad_contracts[1].sid, 1)
        self.assertEqual(ad_contracts[5].sid, 5)

        # Check that, when some contracts have expired, the chain has advanced
        # properly to the next contracts.
        ad_contracts = finder.lookup_future_chain('AD', dt_2)
        self.assertEqual(len(ad_contracts), 4)
        self.assertEqual(ad_contracts[0].sid, 2)
        self.assertEqual(ad_contracts[3].sid, 5)

        # Check that when the expiration_date has passed but the
        # notice_date hasn't, contract is still considered invalid.
        ad_contracts = finder.lookup_future_chain('AD', dt_3)
        self.assertEqual(len(ad_contracts), 1)
        self.assertEqual(ad_contracts[0].sid, 5)

        # Check that pd.NaT for as_of_date gives the whole chain.
        ad_contracts = finder.lookup_future_chain('AD', pd.NaT)
        self.assertEqual(len(ad_contracts), 6)
        self.assertEqual(ad_contracts[5].sid, 5)

    def test_map_identifier_index_to_sids(self):
        """Asset objects map to their integer sids, preserving order."""
        # Build a finder with no data written, plus some Assets.
        dt = pd.Timestamp('2014-01-01', tz='UTC')
        finder = self.asset_finder_type(self.env.engine)
        asset1 = Equity(1, symbol="AAPL")
        asset2 = Equity(2, symbol="GOOG")
        asset200 = Future(200, symbol="CLK15")
        asset201 = Future(201, symbol="CLM15")

        # Check for correct mapping and types.
        pre_map = [asset1, asset2, asset200, asset201]
        post_map = finder.map_identifier_index_to_sids(pre_map, dt)
        self.assertListEqual([1, 2, 200, 201], post_map)
        for sid in post_map:
            self.assertIsInstance(sid, int)

        # Change order and check mapping again.
        pre_map = [asset201, asset2, asset200, asset1]
        post_map = finder.map_identifier_index_to_sids(pre_map, dt)
        self.assertListEqual([201, 2, 200, 1], post_map)

    def test_compute_lifetimes(self):
        """``lifetimes`` matches a brute-force lifetime mask for every
        sub-index of the trading dates, with and without start dates."""
        num_assets = 4
        trading_day = self.env.trading_day
        first_start = pd.Timestamp('2015-04-01', tz='UTC')

        frame = make_rotating_equity_info(
            num_assets=num_assets,
            first_start=first_start,
            frequency=trading_day,
            periods_between_starts=3,
            asset_lifetime=5
        )

        self.env.write_data(equities_df=frame)
        finder = self.env.asset_finder

        all_dates = pd.date_range(
            start=first_start,
            end=frame.end_date.max(),
            freq=trading_day,
        )

        for dates in all_subindices(all_dates):
            expected_with_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )
            expected_no_start_raw = full(
                shape=(len(dates), num_assets),
                fill_value=False,
                dtype=bool,
            )

            for i, date in enumerate(dates):
                it = frame[['start_date', 'end_date']].itertuples()
                for j, start, end in it:
                    # This way of doing the checks is redundant, but very
                    # clear.
                    if start <= date <= end:
                        expected_with_start_raw[i, j] = True
                        if start < date:
                            expected_no_start_raw[i, j] = True

            expected_with_start = pd.DataFrame(
                data=expected_with_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=True)
            assert_frame_equal(result, expected_with_start)

            expected_no_start = pd.DataFrame(
                data=expected_no_start_raw,
                index=dates,
                columns=frame.index.values,
            )
            result = finder.lifetimes(dates, include_start_date=False)
            assert_frame_equal(result, expected_no_start)

    def test_sids(self):
        """The ``sids`` property lists every written sid."""
        self.env.write_data(equities_identifiers=[1, 2, 3])
        sids = self.env.asset_finder.sids
        self.assertEqual(3, len(sids))
        self.assertIn(1, sids)
        self.assertIn(2, sids)
        self.assertIn(3, sids)

    def test_group_by_type(self):
        """``group_by_type`` splits mixed sid lists into equity and future
        sets."""
        equities = make_simple_equity_info(
            range(5),
            start_date=pd.Timestamp('2014-01-01'),
            end_date=pd.Timestamp('2015-01-01'),
        )
        futures = make_commodity_future_info(
            first_sid=6,
            root_symbols=['CL'],
            years=[2014],
        )
        # Intersecting sid queries, to exercise loading of partially-cached
        # results.
        queries = [
            ([0, 1, 3], [6, 7]),
            ([0, 2, 3], [7, 10]),
            (list(equities.index), list(futures.index)),
        ]
        with tmp_asset_finder(finder_cls=self.asset_finder_type,
                              equities=equities,
                              futures=futures) as finder:
            for equity_sids, future_sids in queries:
                results = finder.group_by_type(equity_sids + future_sids)
                self.assertEqual(
                    results,
                    {'equity': set(equity_sids), 'future': set(future_sids)},
                )

    @parameterized.expand([
        (Equity, 'retrieve_equities', EquitiesNotFound),
        (Future, 'retrieve_futures_contracts', FutureContractsNotFound),
    ])
    def test_retrieve_specific_type(self, type_, lookup_name, failure_type):
        """Type-specific retrieval returns only assets of that type and raises
        ``failure_type`` if any requested sid is of another type."""
        equities = make_simple_equity_info(
            range(5),
            start_date=pd.Timestamp('2014-01-01'),
            end_date=pd.Timestamp('2015-01-01'),
        )
        max_equity = equities.index.max()
        futures = make_commodity_future_info(
            first_sid=max_equity + 1,
            root_symbols=['CL'],
            years=[2014],
        )
        equity_sids = [0, 1]
        future_sids = [max_equity + 1, max_equity + 2, max_equity + 3]
        if type_ is Equity:
            success_sids = equity_sids
            fail_sids = future_sids
        else:
            fail_sids = equity_sids
            success_sids = future_sids

        with tmp_asset_finder(finder_cls=self.asset_finder_type,
                              equities=equities,
                              futures=futures) as finder:
            lookup = getattr(finder, lookup_name)
            # Run twice to exercise caching.
            for _ in range(2):
                results = lookup(success_sids)
                self.assertIsInstance(results, dict)
                self.assertEqual(set(results.keys()), set(success_sids))
                self.assertEqual(
                    valmap(int, results),
                    dict(zip(success_sids, success_sids)),
                )
                self.assertEqual(
                    {type_},
                    {type(asset) for asset in itervalues(results)},
                )
                with self.assertRaises(failure_type):
                    lookup(fail_sids)
                with self.assertRaises(failure_type):
                    # Should fail if **any** of the assets are bad.
                    lookup([success_sids[0], fail_sids[0]])

    def test_retrieve_all(self):
        """``retrieve_all`` returns assets of the right type and symbol, in
        query order, for empty, single-type, mixed and full queries."""
        equities = make_simple_equity_info(
            range(5),
            start_date=pd.Timestamp('2014-01-01'),
            end_date=pd.Timestamp('2015-01-01'),
        )
        max_equity = equities.index.max()
        futures = make_commodity_future_info(
            first_sid=max_equity + 1,
            root_symbols=['CL'],
            years=[2014],
        )

        with tmp_asset_finder(finder_cls=self.asset_finder_type,
                              equities=equities,
                              futures=futures) as finder:
            all_sids = finder.sids
            self.assertEqual(len(all_sids), len(equities) + len(futures))
            queries = [
                # Empty Query.
                (),
                # Only Equities.
                tuple(equities.index[:2]),
                # Only Futures.
                tuple(futures.index[:3]),
                # Mixed, all cache misses.
                tuple(equities.index[2:]) + tuple(futures.index[3:]),
                # Mixed, all cache hits.
                tuple(equities.index[2:]) + tuple(futures.index[3:]),
                # Everything.
                all_sids,
                all_sids,
            ]
            for sids in queries:
                equity_sids = [i for i in sids if i <= max_equity]
                future_sids = [i for i in sids if i > max_equity]
                results = finder.retrieve_all(sids)
                self.assertEqual(sids, tuple(map(int, results)))

                self.assertEqual(
                    [Equity for _ in equity_sids] +
                    [Future for _ in future_sids],
                    list(map(type, results)),
                )
                self.assertEqual(
                    (
                        list(equities.symbol.loc[equity_sids]) +
                        list(futures.symbol.loc[future_sids])
                    ),
                    list(asset.symbol for asset in results),
                )

    @parameterized.expand([
        (EquitiesNotFound, 'equity', 'equities'),
        (FutureContractsNotFound, 'future contract', 'future contracts'),
        (SidsNotFound, 'asset', 'assets'),
    ])
    def test_error_message_plurality(self,
                                     error_type,
                                     singular,
                                     plural):
        """Not-found errors use the singular noun for one sid and the plural
        noun for several."""
        with self.assertRaises(error_type) as ctx:
            raise error_type(sids=[1])
        self.assertEqual(
            str(ctx.exception),
            "No {singular} found for sid: 1.".format(singular=singular)
        )
        with self.assertRaises(error_type) as ctx:
            raise error_type(sids=[1, 2])
        self.assertEqual(
            str(ctx.exception),
            "No {plural} found for sids: [1, 2].".format(plural=plural)
        )
||
1357 |