| 1 |  |  | import pandas as pd | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | from datetime import timedelta | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | from unittest import TestCase | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | from testfixtures import TempDirectory | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | from zipline.algorithm import TradingAlgorithm | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | from zipline.errors import TradingControlViolation | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | from zipline.finance.trading import TradingEnvironment | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | from zipline.utils.test_utils import ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  |     setup_logger, teardown_logger, security_list_copy, add_security_data,) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | from zipline.utils import factory | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  | from zipline.utils.security_list import ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |     SecurityListSet, load_from_directory) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | from zipline.utils.test_utils import create_data_portal | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                # Leveraged-ETF security-list data loaded once, at import time, from
                # the 'leveraged_etf_list' directory. Its keys are dates:
                # SecurityListTestCase.setUpClass passes the first key as a
                # simulation start date.
                # NOTE(review): that usage relies on the mapping iterating its
                # earliest date first -- confirm load_from_directory guarantees order.
                LEVERAGED_ETFS = load_from_directory('leveraged_etf_list')
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                class RestrictedAlgoWithCheck(TradingAlgorithm):
                    """Algorithm that places a single 100-share order for its asset,
                    but only after checking the asset is not on the leveraged-ETF list.

                    The same leveraged-ETF list is registered through
                    ``set_do_not_order_list``; this algorithm consults it explicitly
                    before ordering.
                    """

                    def initialize(self, symbol):
                        """Build the security-list set, register the do-not-order
                        list, reset the order counter and resolve the target asset.

                        :param symbol: ticker resolved through ``self.symbol``.
                        """
                        self.rl = SecurityListSet(
                            self.get_datetime, self.asset_finder)
                        restricted = self.rl.leveraged_etf_list
                        self.set_do_not_order_list(restricted)
                        self.order_count = 0
                        self.sid = self.symbol(symbol)

                    def handle_data(self, data):
                        """Order 100 shares of ``self.sid`` at most once over the run,
                        and only if it is not currently in the leveraged-ETF list."""
                        if self.order_count:
                            # An order was already placed on an earlier bar.
                            return
                        if self.sid in self.rl.leveraged_etf_list:
                            # Restricted asset: do not attempt the order.
                            return
                        self.order(self.sid, 100)
                        self.order_count += 1
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                class RestrictedAlgoWithoutCheck(TradingAlgorithm):
                    """Algorithm that orders its asset on every bar without consulting
                    the leveraged-ETF list, even though that list is registered via
                    ``set_do_not_order_list``.
                    """

                    def initialize(self, symbol):
                        """Build the security-list set, register the do-not-order
                        list, reset the order counter and resolve the target asset.

                        :param symbol: ticker resolved through ``self.symbol``.
                        """
                        self.rl = SecurityListSet(
                            self.get_datetime, self.asset_finder)
                        restricted = self.rl.leveraged_etf_list
                        self.set_do_not_order_list(restricted)
                        self.order_count = 0
                        self.sid = self.symbol(symbol)

                    def handle_data(self, data):
                        """Unconditionally order 100 shares of ``self.sid`` and count
                        the attempt."""
                        asset = self.sid
                        self.order(asset, 100)
                        self.order_count += 1
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                class IterateRLAlgo(TradingAlgorithm):
                    """Algorithm that never orders; it only iterates the leveraged-ETF
                    list each bar and records whether its asset appeared in it."""

                    def initialize(self, symbol):
                        """Build the security-list set, register the do-not-order
                        list, and initialise the counter, target asset and flag.

                        :param symbol: ticker resolved through ``self.symbol``.
                        """
                        self.rl = SecurityListSet(
                            self.get_datetime, self.asset_finder)
                        restricted = self.rl.leveraged_etf_list
                        self.set_do_not_order_list(restricted)
                        self.order_count = 0
                        self.sid = self.symbol(symbol)
                        self.found = False

                    def handle_data(self, data):
                        """Walk the entire leveraged-ETF list and set ``self.found`` to
                        True if ``self.sid`` is among its entries (never reset)."""
                        # Deliberately consume the whole iterator rather than
                        # short-circuiting, so the list is fully iterated every bar.
                        matches = [
                            entry for entry in self.rl.leveraged_etf_list
                            if entry == self.sid
                        ]
                        if matches:
                            self.found = True
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  | class SecurityListTestCase(TestCase): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |     def setUpClass(cls): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |         # this is ugly, but we need to create two different | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |         # TradingEnvironment/DataPortal pairs | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         cls.env = TradingEnvironment() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         cls.env2 = TradingEnvironment() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |         cls.extra_knowledge_date = pd.Timestamp("2015-01-27", tz='UTC') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |         cls.trading_day_before_first_kd = pd.Timestamp("2015-01-23", tz='UTC') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |         symbols = ['AAPL', 'GOOG', 'BZQ', 'URTY', 'JFT'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |         days = cls.env.days_in_range( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |             list(LEVERAGED_ETFS.keys())[0], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |             pd.Timestamp("2015-02-17", tz='UTC') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |         cls.sim_params = factory.create_simulation_parameters( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |             start=list(LEVERAGED_ETFS.keys())[0], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |             num_days=4, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |             env=cls.env | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |         cls.sim_params2 = factory.create_simulation_parameters( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |             start=cls.trading_day_before_first_kd, num_days=4 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |         equities_metadata = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |         for i, symbol in enumerate(symbols): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |             equities_metadata[i] = { | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |                 'start_date': days[0], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |                 'end_date': days[-1], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |                 'symbol': symbol | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |             } | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |         equities_metadata2 = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |         for i, symbol in enumerate(symbols): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |             equities_metadata2[i] = { | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |                 'start_date': cls.sim_params2.period_start, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |                 'end_date': cls.sim_params2.period_end, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |                 'symbol': symbol | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |             } | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |         cls.env.write_data(equities_data=equities_metadata) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |         cls.env2.write_data(equities_data=equities_metadata2) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |         cls.tempdir = TempDirectory() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |         cls.tempdir2 = TempDirectory() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |         cls.data_portal = create_data_portal( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |             env=cls.env, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |             tempdir=cls.tempdir, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |             sim_params=cls.sim_params, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |             sids=range(0, 5), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  |         cls.data_portal2 = create_data_portal( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  |             env=cls.env2, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 123 |  |  |             tempdir=cls.tempdir2, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 124 |  |  |             sim_params=cls.sim_params2, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 125 |  |  |             sids=range(0, 5) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 126 |  |  |         ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 |  |  |         setup_logger(cls) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @classmethod
                def tearDownClass(cls):
                    """Release the class-level fixtures created by ``setUpClass``.

                    Drops both trading environments (``env`` and ``env2``), cleans up
                    both temporary directories, and finally tears down the logger.
                    The nested ``finally`` clauses ensure that a failure while
                    cleaning one fixture does not skip cleanup of the others.
                    """
                    try:
                        # setUpClass works with two environments; release both,
                        # not only the first one.
                        del cls.env
                        del cls.env2
                        try:
                            cls.tempdir.cleanup()
                        finally:
                            # Still remove the second directory if the first
                            # cleanup raised.
                            cls.tempdir2.cleanup()
                    finally:
                        teardown_logger(cls)
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def test_iterate_over_restricted_list(self):
                    """Run IterateRLAlgo on BZQ and assert its ``found`` flag is set."""
                    algorithm = IterateRLAlgo(
                        symbol='BZQ',
                        sim_params=self.sim_params,
                        env=self.env,
                    )
                    algorithm.run(self.data_portal)
                    self.assertTrue(algorithm.found)
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def test_security_list(self):
                    """Check leveraged-ETF membership at the first LEVERAGED_ETFS key.

                    The clock is pinned to the first key of ``LEVERAGED_ETFS``.
                    The sids of BZQ, URTY and JFT must then be in the leveraged ETF
                    list, and the sids of AAPL and GOOG must not be.
                    """
                    def first_knowledge_date():
                        # NOTE(review): "first" is the mapping's iteration order,
                        # which is not necessarily the earliest date -- confirm.
                        return list(LEVERAGED_ETFS.keys())[0]

                    def sids_for(tickers):
                        # Resolve each ticker to its sid as of the extra
                        # knowledge date configured on the test class.
                        return [
                            self.env.asset_finder.lookup_symbol(
                                ticker,
                                as_of_date=self.extra_knowledge_date,
                            ).sid
                            for ticker in tickers
                        ]

                    restricted = SecurityListSet(first_knowledge_date,
                                                 self.env.asset_finder)

                    # A sample of leveraged ETFs must be restricted.
                    for sid in sids_for(["BZQ", "URTY", "JFT"]):
                        self.assertIn(sid, restricted.leveraged_etf_list)

                    # A sample of ordinary stocks must not be restricted.
                    for sid in sids_for(["AAPL", "GOOG"]):
                        self.assertNotIn(sid, restricted.leveraged_etf_list)
            
                                                                                                            
                            
            
                                    
            
            
                | 172 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                def test_security_add(self):
                    """Check the leveraged ETF list after ``add_security_data``.

                    Inside a ``security_list_copy`` context, call
                    ``add_security_data(['AAPL', 'GOOG'], [])``. Then assert that, at
                    2015-01-27, the sids of AAPL, GOOG, BZQ and URTY are all in the
                    leveraged ETF list.
                    """
                    def clock():
                        return pd.Timestamp("2015-01-27", tz='UTC')

                    with security_list_copy():
                        add_security_data(['AAPL', 'GOOG'], [])
                        restricted = SecurityListSet(clock, self.env.asset_finder)
                        expected_sids = [
                            self.env.asset_finder.lookup_symbol(
                                ticker,
                                as_of_date=self.extra_knowledge_date,
                            ).sid
                            for ticker in ["AAPL", "GOOG", "BZQ", "URTY"]
                        ]
                        for sid in expected_sids:
                            self.assertIn(sid, restricted.leveraged_etf_list)
            
                                                                                                            
                            
            
                                    
            
            
                | 188 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def test_security_add_delete(self):
                    """Assert "BZQ" and "URTY" are not in the leveraged ETF list.

                    The check uses the list as of 2015-01-27, inside a
                    ``security_list_copy`` context.

                    NOTE(review): despite its name, this test never calls
                    ``add_security_data``. It also tests membership of the ticker
                    *strings*, while the sibling tests query ``leveraged_etf_list``
                    by sid. If the list holds sids, these assertNotIn checks pass
                    trivially. Confirm the intended add/delete scenario and compare
                    by sid.
                    """
                    with security_list_copy():
                        restricted = SecurityListSet(
                            lambda: pd.Timestamp("2015-01-27", tz='UTC'),
                            self.env.asset_finder,
                        )
                        for ticker in ("BZQ", "URTY"):
                            self.assertNotIn(ticker, restricted.leveraged_etf_list)
            
                                                                                                            
                            
            
                                    
            
            
                | 196 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def test_algo_without_rl_violation_via_check(self):
                    """Run RestrictedAlgoWithCheck on BZQ; the test fails if it raises."""
                    RestrictedAlgoWithCheck(
                        symbol='BZQ',
                        sim_params=self.sim_params,
                        env=self.env,
                    ).run(self.data_portal)
            
                                                                                                            
                            
            
                                    
            
            
                | 202 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def test_algo_without_rl_violation(self):
                    """Run RestrictedAlgoWithoutCheck on AAPL; the test fails if it raises."""
                    RestrictedAlgoWithoutCheck(
                        symbol='AAPL',
                        sim_params=self.sim_params,
                        env=self.env,
                    ).run(self.data_portal)
            
                                                                                                            
                            
            
                                    
            
            
                | 208 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 209 |  |  |     def test_algo_with_rl_violation(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 210 |  |  |         algo = RestrictedAlgoWithoutCheck(symbol='BZQ', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 211 |  |  |                                           sim_params=self.sim_params, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 212 |  |  |                                           env=self.env) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  |         with self.assertRaises(TradingControlViolation) as ctx: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 214 |  |  |             algo.run(self.data_portal) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 215 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 216 |  |  |         self.check_algo_exception(algo, ctx, 0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 217 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 218 |  |  |         # repeat with a symbol from a different lookup date | 
            
                                                                                                            
                            
            
                                    
            
            
                | 219 |  |  |         algo = RestrictedAlgoWithoutCheck(symbol='JFT', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 220 |  |  |                                           sim_params=self.sim_params, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 221 |  |  |                                           env=self.env) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 222 |  |  |         with self.assertRaises(TradingControlViolation) as ctx: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 223 |  |  |             algo.run(self.data_portal) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 224 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 225 |  |  |         self.check_algo_exception(algo, ctx, 0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 226 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def test_algo_with_rl_violation_after_knowledge_date(self):
                    """
                    Order BZQ one week after the first LEVERAGED_ETFS knowledge date.

                    The simulation starts 7 days after the first key of
                    ``LEVERAGED_ETFS`` and runs for 5 days against a data portal
                    built for exactly that window. The run must raise
                    ``TradingControlViolation``, and ``check_algo_exception`` then
                    requires ``algo.order_count == 0``.
                    """
                    # NOTE(review): ``list(...keys())[0]`` depends on the mapping's
                    # iteration order; presumably meant to be the earliest
                    # knowledge date -- confirm.
                    first_knowledge_date = list(LEVERAGED_ETFS.keys())[0]
                    start = first_knowledge_date + timedelta(days=7)

                    params = factory.create_simulation_parameters(
                        start=start,
                        num_days=5,
                        env=self.env,
                    )
                    portal = create_data_portal(
                        self.env,
                        self.tempdir,
                        sim_params=params,
                        sids=range(0, 5),
                    )

                    algo = RestrictedAlgoWithoutCheck(
                        symbol='BZQ',
                        sim_params=params,
                        env=self.env,
                    )
                    with self.assertRaises(TradingControlViolation) as ctx:
                        algo.run(data_portal=portal)

                    self.check_algo_exception(algo, ctx, 0)
            
                                                                                                            
                            
            
                                    
            
            
                | 247 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def test_algo_with_rl_violation_cumulative(self):
                    """
                    Add a new restriction (AAPL) on top of the existing restricted
                    list and make sure BZQ, from the original restriction set, is
                    still disallowed.

                    The simulation starts 7 days after the first key of
                    ``LEVERAGED_ETFS`` and runs for 4 days.
                    NOTE(review): an earlier wording said the run happens "long
                    after both knowledge dates"; whether the start date is also
                    after the knowledge date that ``add_security_data`` assigns is
                    not visible here -- confirm.
                    """
                    sim_params = factory.create_simulation_parameters(
                        start=list(
                            LEVERAGED_ETFS.keys())[0] + timedelta(days=7),
                        num_days=4,
                        # Build the calendar from the same environment the algo
                        # runs in, as the sibling tests do.
                        env=self.env)

                    # security_list_copy() presumably restores the original
                    # security lists on exit, so the AAPL addition stays local.
                    with security_list_copy():
                        add_security_data(['AAPL'], [])
                        algo = RestrictedAlgoWithoutCheck(
                            symbol='BZQ', sim_params=sim_params, env=self.env)
                        # NOTE(review): this runs against the shared
                        # self.data_portal rather than a portal built for
                        # sim_params (unlike the knowledge-date test) -- confirm
                        # that is intended.
                        with self.assertRaises(TradingControlViolation) as ctx:
                            algo.run(self.data_portal)

                        self.check_algo_exception(algo, ctx, 0)
            
                                                                                                            
                            
            
                                    
            
            
                | 266 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def test_algo_without_rl_violation_after_delete(self):
                    """
                    Ordering BZQ after it has been deleted from the list must not raise.

                    A delete statement for BZQ is added. Then a fresh environment,
                    simulation parameters and data portal are built starting at
                    ``self.extra_knowledge_date``. The test passes as long as
                    ``algo.run`` completes without raising.
                    """
                    # TempDirectory's context manager calls cleanup() on exit,
                    # replacing the manual try/finally. security_list_copy() is
                    # exited first, matching the original nesting.
                    with TempDirectory() as new_tempdir, security_list_copy():
                        # add a delete statement removing bzq
                        # write a new delete statement file to disk
                        add_security_data([], ['BZQ'])

                        # now fast-forward to self.extra_knowledge_date.  requires
                        # a new env, simparams, and dataportal
                        env = TradingEnvironment()
                        sim_params = factory.create_simulation_parameters(
                            start=self.extra_knowledge_date, num_days=4, env=env)

                        env.write_data(equities_data={
                            "0": {
                                'symbol': 'BZQ',
                                'start_date': sim_params.period_start,
                                'end_date': sim_params.period_end,
                            }
                        })

                        data_portal = create_data_portal(
                            env,
                            new_tempdir,
                            sim_params=sim_params,
                            sids=range(0, 5),
                        )

                        algo = RestrictedAlgoWithoutCheck(
                            symbol='BZQ', sim_params=sim_params, env=env
                        )
                        # No assertRaises here: the absence of a
                        # TradingControlViolation is the assertion.
                        algo.run(data_portal)
            
                                                                                                            
                            
            
                                    
            
            
                | 303 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def test_algo_with_rl_violation_after_add(self):
                    """
                    A newly added AAPL restriction is enforced.

                    AAPL is added inside ``security_list_copy()``. The algo is then
                    run with ``self.sim_params2`` / ``self.env2`` /
                    ``self.data_portal2``. The run must raise
                    ``TradingControlViolation``, and ``check_algo_exception`` then
                    requires ``algo.order_count == 2``.
                    """
                    with security_list_copy():
                        add_security_data(['AAPL'], [])

                        restricted_algo = RestrictedAlgoWithoutCheck(
                            symbol='AAPL',
                            sim_params=self.sim_params2,
                            env=self.env2,
                        )
                        with self.assertRaises(TradingControlViolation) as raised:
                            restricted_algo.run(self.data_portal2)

                        self.check_algo_exception(restricted_algo, raised, 2)
            
                                                                                                            
                            
            
                                    
            
            
                | 315 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def check_algo_exception(self, algo, ctx, expected_order_count):
                    """
                    Assert that a run stopped on a restricted-list trading control.

                    :param algo: the algorithm that was run; its ``order_count`` is
                        compared with ``expected_order_count``.
                    :param ctx: the context object produced by
                        ``self.assertRaises(TradingControlViolation)``.
                    :param expected_order_count: number of orders the algo is
                        expected to have counted before the violation.
                    """
                    self.assertEqual(algo.order_count, expected_order_count)
                    exc = ctx.exception
                    # Exact type match (not isinstance), so subclasses of
                    # TradingControlViolation do not satisfy the check.
                    self.assertIs(type(exc), TradingControlViolation)
                    # assertIn reports the actual message on failure, unlike
                    # assertTrue(... in ...).
                    self.assertIn("RestrictedListOrder", str(exc))
            
                                                        
            
                                    
            
            
                | 322 |  |  |  |