| Total Complexity | 153 | 
| Total Lines | 1209 | 
| Duplicated Lines | 0 % | 
Complex classes like zipline.TradingAlgorithm often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | #  | 
            ||
| 115 | class TradingAlgorithm(object):  | 
            ||
| 116 | """  | 
            ||
| 117 | Base class for trading algorithms. Inherit and overload  | 
            ||
| 118 | initialize() and handle_data(data).  | 
            ||
| 119 | |||
| 120 | A new algorithm could look like this:  | 
            ||
| 121 | ```  | 
            ||
| 122 | from zipline.api import order, symbol  | 
            ||
| 123 | |||
| 124 | def initialize(context):  | 
            ||
| 125 |         context.sid = symbol('AAPL') | 
            ||
| 126 | context.amount = 100  | 
            ||
| 127 | |||
| 128 | def handle_data(context, data):  | 
            ||
| 129 | sid = context.sid  | 
            ||
| 130 | amount = context.amount  | 
            ||
| 131 | order(sid, amount)  | 
            ||
| 132 | ```  | 
            ||
| 133 | To then to run this algorithm pass these functions to  | 
            ||
| 134 | TradingAlgorithm:  | 
            ||
| 135 | |||
| 136 | my_algo = TradingAlgorithm(initialize, handle_data)  | 
            ||
| 137 | stats = my_algo.run(data)  | 
            ||
| 138 | |||
| 139 | """  | 
            ||
| 140 | |||
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
    :Optional:
        initialize : function
            Function that is called with a single
            argument at the beginning of the simulation.
        handle_data : function
            Function that is called with 2 arguments
            (context and data) on every bar.
        before_trading_start : function
            Only picked up when both ``initialize`` and ``handle_data``
            are given as functions.
        analyze : function
            Only picked up when both ``initialize`` and ``handle_data``
            are given as functions; called with the daily stats after
            ``run``.
        script : str
            Algoscript that contains initialize and
            handle_data function definition. Must not be combined with
            the function arguments above.
        algo_filename : str <default: '<string>'>
            Filename used when compiling ``script``.
        namespace : dict
            Globals dict the script is executed in (default: new dict).
        platform : str <default: 'zipline'>
            Value reported by ``get_environment('platform')``.
        data_frequency : {'daily', 'minute'}
            The duration of the bars.
        capital_base : float <default: DEFAULT_CAPITAL_BASE>
            How much capital to start with.
        sim_params : SimulationParameters
            If omitted, built from ``capital_base``, ``start`` and ``end``.
        start, end : optional
            Simulation bounds; only consumed when ``sim_params`` is
            omitted.
        env : TradingEnvironment
            Environment to use; a new TradingEnvironment is created when
            omitted.
        asset_finder : An AssetFinder object
            NOTE(review): not consumed by this constructor; the finder is
            taken from ``env``. Confirm whether this keyword is still
            supported.
        equities_metadata : can be either:
            - dict
            - pandas.DataFrame
            - object with 'read' property
            If dict is provided, it must have the following structure:
            * keys are the identifiers
            * values are dicts containing the metadata, with the metadata
              field name as the key
            If pandas.DataFrame is provided, it must have the
            following structure:
            * column names must be the metadata fields
            * index must be the different asset identifiers
            * array contents should be the metadata value
            If an object with a 'read' property is provided, 'read' must
            return rows containing at least one of 'sid' or 'symbol' along
            with the other metadata fields.
        futures_metadata : optional
            Futures metadata written to the TradingEnvironment.
        identifiers : List
            Any asset identifiers that are not provided in the
            equities_metadata, but will be traded by this TradingAlgorithm
        get_pipeline_loader : callable, optional
            Passed to ``init_engine``; None yields a NoOpPipelineEngine.
        blotter : Blotter, optional
            Defaults to a Blotter with VolumeShareSlippage and PerShare
            commission.
        create_event_context : callable, optional
            Passed to the EventManager as ``create_context``.
        benchmark_sid : optional
            Benchmark asset identifier.

    Any keyword arguments not consumed above, together with ``*args``, are
    stored and forwarded to ``initialize`` on the first run.

    :Raises:
        ValueError
            If ``script`` is combined with any of ``initialize``,
            ``handle_data``, ``before_trading_start`` or ``analyze``, or if
            ``script`` does not define ``handle_data``.
    """
    self.sources = []

    # List of trading controls to be used to validate orders.
    self.trading_controls = []

    # List of account controls to be checked on each bar.
    self.account_controls = []

    self._recorded_vars = {}
    self.namespace = kwargs.pop('namespace', {})

    self._platform = kwargs.pop('platform', 'zipline')

    self.logger = None

    self.data_portal = None

    # If an env has been provided, pop it
    self.trading_environment = kwargs.pop('env', None)

    if self.trading_environment is None:
        self.trading_environment = TradingEnvironment()

    # Update the TradingEnvironment with the provided asset metadata
    self.trading_environment.write_data(
        equities_data=kwargs.pop('equities_metadata', {}),
        equities_identifiers=kwargs.pop('identifiers', []),
        futures_data=kwargs.pop('futures_metadata', {}),
    )

    # set the capital base
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params is None:
        self.sim_params = create_simulation_parameters(
            capital_base=self.capital_base,
            start=kwargs.pop('start', None),
            end=kwargs.pop('end', None),
            env=self.trading_environment,
        )
    else:
        self.sim_params.update_internal_from_env(self.trading_environment)

    self.perf_tracker = None
    # Pull in the environment's new AssetFinder for quick reference
    self.asset_finder = self.trading_environment.asset_finder

    # Initialize Pipeline API data.
    self.init_engine(kwargs.pop('get_pipeline_loader', None))
    self._pipelines = {}
    # Create an always-expired cache so that we compute the first time data
    # is requested.
    self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter(
            slippage_func=VolumeShareSlippage(),
            commission=PerShare()
        )

    # The symbol lookup date specifies the date to use when resolving
    # symbols to sids, and can be set using set_symbol_lookup_date()
    self._symbol_lookup_date = None

    self._portfolio = None
    self._account = None

    # If string is passed in, execute and get reference to
    # functions.
    self.algoscript = kwargs.pop('script', None)

    self._initialize = None
    self._before_trading_start = None
    self._analyze = None

    self.event_manager = EventManager(
        create_context=kwargs.pop('create_event_context', None),
    )

    self._load_algo_functions(kwargs)

    self.event_manager.add_event(
        zipline.utils.events.Event(
            zipline.utils.events.Always(),
            # We pass handle_data.__func__ to get the unbound method.
            # We will explicitly pass the algorithm to bind it again.
            self.handle_data.__func__,
        ),
        prepend=True,
    )

    # If method not defined, NOOP
    if self._initialize is None:
        self._initialize = lambda x: None

    # Alternative way of setting data_frequency for backwards
    # compatibility.
    if 'data_frequency' in kwargs:
        self.data_frequency = kwargs.pop('data_frequency')

    # Consume benchmark_sid before the leftover kwargs are stored for
    # initialize(), so it is obvious it is never forwarded there.
    self.benchmark_sid = kwargs.pop('benchmark_sid', None)

    # Prepare the algo for initialization
    self.initialized = False
    self.initialize_args = args
    self.initialize_kwargs = kwargs


def _load_algo_functions(self, kwargs):
    """Resolve the user's algorithm functions from ``script`` or ``kwargs``.

    Sets ``_initialize``, ``_handle_data``, ``_before_trading_start`` and
    ``_analyze``. Every keyword it consumes is popped from ``kwargs``. The
    dict is mutated in place so that the remainder can be forwarded to
    ``initialize``.

    :Raises:
        ValueError
            If ``script`` is combined with function keywords, or the script
            does not define ``handle_data``.
    """
    if self.algoscript is not None:
        # This conflict check used to live in the kwargs-only branch, which
        # is unreachable when a script is given. Stray functions then ended
        # up in initialize_kwargs and failed later with a TypeError.
        conflicting = [
            name for name in ('initialize', 'handle_data',
                              'before_trading_start', 'analyze')
            if name in kwargs
        ]
        if conflicting:
            raise ValueError(
                'You can not set script and initialize/handle_data. '
                'Got: {0}'.format(', '.join(conflicting))
            )

        filename = kwargs.pop('algo_filename', None)
        if filename is None:
            filename = '<string>'
        # NOTE: the script is executed with full interpreter privileges;
        # only run trusted algorithm code.
        code = compile(self.algoscript, filename, 'exec')
        exec_(code, self.namespace)
        self._initialize = self.namespace.get('initialize')
        if 'handle_data' not in self.namespace:
            raise ValueError('You must define a handle_data function.')
        self._handle_data = self.namespace['handle_data']

        self._before_trading_start = \
            self.namespace.get('before_trading_start')
        # Optional analyze function, gets called after run
        self._analyze = self.namespace.get('analyze')

    elif kwargs.get('initialize') and kwargs.get('handle_data'):
        self._initialize = kwargs.pop('initialize')
        self._handle_data = kwargs.pop('handle_data')
        self._before_trading_start = kwargs.pop('before_trading_start',
                                                None)
        self._analyze = kwargs.pop('analyze', None)
| 313 | |||
def init_engine(self, get_loader):
    """
    Build the pipeline engine for this algorithm and store it on
    ``self.engine``.

    With no ``get_loader`` a NoOpPipelineEngine is used. Otherwise a
    SimplePipelineEngine is built from ``get_loader``, the environment's
    trading days and the algorithm's asset finder.
    """
    if get_loader is None:
        self.engine = NoOpPipelineEngine()
        return

    self.engine = SimplePipelineEngine(
        get_loader,
        self.trading_environment.trading_days,
        self.asset_finder,
    )
| 328 | |||
def initialize(self, *args, **kwargs):
    """
    Run the user-supplied initialize function.

    ``self._initialize`` receives the algorithm as its first argument,
    followed by any extra positional and keyword arguments. The call runs
    inside ``ZiplineAPI(self)`` so that Zipline API functions act on this
    algorithm.
    """
    with ZiplineAPI(self):
        user_initialize = self._initialize
        user_initialize(self, *args, **kwargs)
| 336 | |||
def before_trading_start(self, data):
    """
    Forward ``data`` to the user's before_trading_start function.

    Does nothing when no such function was supplied.
    """
    hook = self._before_trading_start
    if hook is not None:
        hook(self, data)
| 342 | |||
def handle_data(self, data):
    """
    Pass the current bar's ``data`` to the user's handle_data function,
    then validate every account control.
    """
    user_handle_data = self._handle_data
    user_handle_data(self, data)

    # Account controls, unlike trading controls, can be breached without
    # any order being placed. They are therefore checked on every bar,
    # whether or not the algorithm traded.
    self.validate_account_controls()
| 350 | |||
def analyze(self, perf):
    """
    Hand the performance results to the user's analyze function.

    The call runs inside a ``ZiplineAPI(self)`` context. Nothing happens
    when no analyze function was supplied.
    """
    if self._analyze is not None:
        with ZiplineAPI(self):
            self._analyze(self, perf)
| 357 | |||
def __repr__(self):
    """
    Return a multi-line, constructor-like description of the algorithm.

    N.B. this does not yet represent a string that can be used
    to instantiate an exact copy of an algorithm.

    However, it is getting close, and provides some value as something
    that can be inspected interactively.
    """
    # Every argument line except the last ends with a comma. The
    # ``capital_base`` line was previously missing its comma.
    return """
{class_name}(
    capital_base={capital_base},
    sim_params={sim_params},
    initialized={initialized},
    slippage={slippage},
    commission={commission},
    blotter={blotter},
    recorded_vars={recorded_vars})
""".strip().format(class_name=self.__class__.__name__,
                   capital_base=self.capital_base,
                   sim_params=repr(self.sim_params),
                   initialized=self.initialized,
                   slippage=repr(self.blotter.slippage_func),
                   commission=repr(self.blotter.commission),
                   blotter=repr(self.blotter),
                   recorded_vars=repr(self.recorded_vars))
| 383 | |||
def _create_clock(self):
    """
    Build the simulation clock matching ``self.sim_params.data_frequency``.

    For ``'minute'`` data, a MinuteSimulationClock is built from the market
    open/close times of the simulation's trading days. The data portal's
    offset cache is then primed from that clock. Any other frequency gets
    a DailySimulationClock over the simulation's trading days.
    """
    if self.sim_params.data_frequency == 'minute':
        env = self.trading_environment
        # Label-based row selection of the simulation's trading days.
        # ``.ix`` was deprecated in pandas 0.20 and removed in 1.0; ``.loc``
        # performs the same label lookup on every supported version.
        trading_o_and_c = env.open_and_closes.loc[
            self.sim_params.trading_days]
        # Convert to int64 nanoseconds since the epoch.
        market_opens = trading_o_and_c['market_open'].values.astype(
            'datetime64[ns]').astype(np.int64)
        market_closes = trading_o_and_c['market_close'].values.astype(
            'datetime64[ns]').astype(np.int64)

        minutely_emission = self.sim_params.emission_rate == "minute"

        clock = MinuteSimulationClock(
            self.sim_params.trading_days,
            market_opens,
            market_closes,
            env.trading_days,
            minutely_emission
        )
        self.data_portal.setup_offset_cache(
            clock.minutes_by_day,
            clock.minutes_to_day)
        return clock
    else:
        return DailySimulationClock(self.sim_params.trading_days)
| 412 | |||
def _create_benchmark_source(self):
    """
    Build the BenchmarkSource for this run.

    It is built from the benchmark sid, the trading environment, the
    simulation's trading days and emission rate, and the current data
    portal.
    """
    sim_params = self.sim_params
    return BenchmarkSource(
        self.benchmark_sid,
        self.trading_environment,
        sim_params.trading_days,
        self.data_portal,
        emission_rate=sim_params.emission_rate,
    )
| 421 | |||
def _create_generator(self, sim_params):
    """
    Prepare the algorithm for a simulation and return its event generator.

    Parameters
    ----------
    sim_params : SimulationParameters or None
        If not None, replaces ``self.sim_params`` before anything else is
        built.

    Returns
    -------
    The generator returned by ``AlgorithmSimulator.transform()``. The
    simulator is also stored on ``self.trading_client``.
    """
    if sim_params is not None:
        self.sim_params = sim_params

    if self.perf_tracker is None:
        # HACK: When running with the `run` method, we set perf_tracker to
        # None so that it will be overwritten here.
        self.perf_tracker = PerformanceTracker(
            sim_params=self.sim_params,
            env=self.trading_environment,
            data_portal=self.data_portal
        )

        # Set the dt initially to the period start by forcing it to change.
        self.on_dt_changed(self.sim_params.period_start)

    if not self.initialized:
        self.initialize(*self.initialize_args, **self.initialize_kwargs)
        self.initialized = True

    # Use the resolved ``self.sim_params`` rather than the argument. When
    # the caller passes None, the argument would hand the simulator None
    # instead of the parameters used everywhere else in this method.
    self.trading_client = AlgorithmSimulator(
        self,
        self.sim_params,
        self.data_portal,
        self._create_clock(),
        self._create_benchmark_source()
    )

    return self.trading_client.transform()
| 451 | |||
def get_generator(self):
    """
    Return the simulation generator for this algorithm.

    Subclasses may override this to customise how the generator is built.
    They can delegate to ``_create_generator`` for the standard
    construction.
    """
    current_params = self.sim_params
    return self._create_generator(current_params)
| 459 | |||
def run(self, data_portal=None):
    """Run the algorithm.

    :Arguments:
        data_portal : DataPortal
            Source of market data for the simulation; stored on
            ``self.data_portal``.

    :Returns:
        daily_stats : pandas.DataFrame
            Daily performance metrics such as returns, alpha etc.
    """
    self.data_portal = data_portal

    # Force a fresh performance tracker in case this is a repeat run.
    self.perf_tracker = None

    # Drain the simulation; every yielded item is a perf dictionary.
    perfs = list(self.get_generator())

    # Flatten the perf messages into a DataFrame of daily stats.
    daily_stats = self._create_daily_stats(perfs)

    self.analyze(daily_stats)

    return daily_stats
| 489 | |||
def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
    """
    Make sure every identifier has an asset, then map them to sids.

    Assets and integer identifiers are first looked up by sid. Anything not
    found, and every identifier of another type, is written to the trading
    environment as a new equity identifier. Returns
    ``asset_finder.map_identifier_index_to_sids(identifiers, as_of_date)``.
    """
    def _existing_asset(identifier):
        # Only Assets and integer sids can be resolved directly.
        if isinstance(identifier, Asset):
            return self.asset_finder.retrieve_asset(sid=identifier.sid,
                                                    default_none=True)
        if isinstance(identifier, Integral):
            return self.asset_finder.retrieve_asset(sid=identifier,
                                                    default_none=True)
        return None

    unresolved = [
        identifier for identifier in identifiers
        if _existing_asset(identifier) is None
    ]

    self.trading_environment.write_data(equities_identifiers=unresolved)

    # Drop any cache misses recorded during the lookups above. The real fix
    # is to build the AssetFinder only in `run()`, once all the data is
    # available.
    self.asset_finder._reset_caches()

    return self.asset_finder.map_identifier_index_to_sids(
        identifiers, as_of_date,
    )
| 518 | |||
def _create_daily_stats(self, perfs):
    """
    Turn the perf messages emitted during a run into a daily stats frame.

    Each message that has a ``'daily_perf'`` entry is flattened: its
    ``recorded_vars`` and ``cumulative_risk_metrics`` are merged into the
    daily_perf dict, which is mutated in place and becomes one row of the
    result. Any other message is stored as ``self.risk_report``; the last
    one wins.

    Returns
    -------
    pandas.DataFrame
        One row per daily message, indexed by its ``period_close`` as a
        UTC DatetimeIndex.
    """
    daily_perfs = []
    # TODO: the loop here could overwrite expected properties
    # of daily_perf. Could potentially raise or log a
    # warning.
    for perf in perfs:
        if 'daily_perf' in perf:
            daily_perf = perf['daily_perf']
            daily_perf.update(daily_perf.pop('recorded_vars'))
            daily_perf.update(perf['cumulative_risk_metrics'])
            daily_perfs.append(daily_perf)
        else:
            self.risk_report = perf

    # ``np.datetime64`` has no documented ``utc`` keyword and drops
    # timezone information. Build an explicit UTC index with pandas.
    daily_dts = pd.DatetimeIndex(
        [daily_perf['period_close'] for daily_perf in daily_perfs],
        tz='UTC',
    )
    daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)

    return daily_stats
| 541 | |||
@api_method
def get_environment(self, field='platform'):
    """
    Query the execution environment.

    ``field`` is one of 'arena', 'data_frequency', 'start', 'end',
    'capital_base' or 'platform'. Pass '*' to get the whole dict. Any other
    field raises KeyError.
    """
    params = self.sim_params
    env = dict(
        arena=params.arena,
        data_frequency=params.data_frequency,
        start=params.first_open,
        end=params.last_close,
        capital_base=params.capital_base,
        platform=self._platform,
    )
    if field != '*':
        return env[field]
    return env
| 556 | |||
@api_method
def fetch_csv(self, url,
              pre_func=None,
              post_func=None,
              date_column='date',
              date_format=None,
              timezone=pytz.utc.zone,
              symbol=None,
              mask=True,
              symbol_column=None,
              special_params_checker=None,
              **kwargs):
    """
    Load an external CSV from ``url`` and feed it to the data portal.

    A PandasRequestsCSV source is built for the simulation period
    (``sim_params.period_start`` to ``sim_params.period_end``) using the
    given parsing options; extra keyword arguments are forwarded to it. Its
    ``df`` is handed to ``data_portal.handle_extra_source`` and the source
    object is returned.
    """
    env = self.trading_environment
    period_start = self.sim_params.period_start
    period_end = self.sim_params.period_end

    source = PandasRequestsCSV(
        url, pre_func, post_func,
        env, period_start, period_end,
        date_column, date_format, timezone,
        symbol, mask, symbol_column,
        data_frequency=self.data_frequency,
        special_params_checker=special_params_checker,
        **kwargs
    )

    # Register the fetched frame with the data portal.
    self.data_portal.handle_extra_source(source.df)

    return source
| 593 | |||
def add_event(self, rule=None, callback=None):
    """
    Register ``callback`` to run whenever ``rule`` triggers.

    This wraps both in a ``zipline.utils.events.Event`` and adds it to the
    algorithm's EventManager.
    """
    manager = self.event_manager
    manager.add_event(zipline.utils.events.Event(rule, callback))
| 601 | |||
@api_method
def schedule_function(self,
                      func,
                      date_rule=None,
                      time_rule=None,
                      half_days=True):
    """
    Schedule ``func`` to be called according to timed rules.

    A falsy ``date_rule`` defaults to every day. In minute mode a falsy
    ``time_rule`` defaults to market open. In any other mode ``time_rule``
    is ignored and replaced with ``Always()``. ``half_days`` is passed
    through to ``make_eventrule``.
    """
    if not date_rule:
        date_rule = DateRuleFactory.every_day()

    if self.sim_params.data_frequency == 'minute':
        if not time_rule:
            time_rule = TimeRuleFactory.market_open()
    else:
        # If we are in daily mode the time_rule is ignored.
        time_rule = zipline.utils.events.Always()

    rule = make_eventrule(date_rule, time_rule, half_days)
    self.add_event(rule, func)
| 621 | |||
@api_method
def record(self, *args, **kwargs):
    """
    Track and record values each day.

    Positional arguments are read as alternating ``name, value`` pairs; an
    unpaired trailing positional is ignored. Keyword arguments are recorded
    as ``name=value``. Later entries overwrite earlier ones with the same
    name.
    """
    # Even-indexed positionals are names, odd-indexed ones their values.
    names = args[0::2]
    values = args[1::2]
    for name, value in chain(zip(names, values), iteritems(kwargs)):
        self._recorded_vars[name] = value
| 637 | |||
@api_method
def set_benchmark(self, benchmark_sid):
    """
    Set the benchmark used for this algorithm.

    Only allowed while ``self.initialized`` is False (e.g. during
    ``initialize``). Otherwise raises SetBenchmarkOutsideInitialize.
    """
    if not self.initialized:
        self.benchmark_sid = benchmark_sid
        return

    raise SetBenchmarkOutsideInitialize()
| 644 | |||
@api_method
@preprocess(symbol_str=ensure_upper_case)
def symbol(self, symbol_str):
    """
    Look up an Asset by ticker symbol through the asset finder.

    Default symbol lookup for any source that directly maps the
    symbol to the Asset (e.g. yahoo finance). ``symbol_str`` is first
    normalized by ``ensure_upper_case`` via ``preprocess``.
    """
    # Resolve as of the user-set lookup date. Fall back to the end of the
    # simulation period when no date has been set.
    lookup_date = self._symbol_lookup_date
    if lookup_date is None:
        lookup_date = self.sim_params.period_end

    return self.asset_finder.lookup_symbol(symbol_str,
                                           as_of_date=lookup_date)
| 661 | |||
@api_method
def symbols(self, *args):
    """
    Look up several symbols at once.

    Each identifier in ``args`` is resolved with ``self.symbol``; the
    Assets are returned as a list in the same order.
    """
    return list(map(self.symbol, args))
| 669 | |||
@api_method
def sid(self, a_sid):
    """
    Look up an Asset by its integer sid.

    Default sid lookup for any source that directly maps the integer sid
    to the Asset.
    """
    finder = self.asset_finder
    return finder.retrieve_asset(a_sid)
| 677 | |||
@api_method
@preprocess(symbol=ensure_upper_case)
def future_symbol(self, symbol):
    """Look up a futures contract by its symbol.

    Parameters
    ----------
    symbol : str
        Symbol of the wanted contract; normalized by ``ensure_upper_case``
        before the lookup.

    Returns
    -------
    Future
        The matching Future object.

    Raises
    ------
    SymbolNotFound
        If no contract named ``symbol`` exists.
    """
    finder = self.asset_finder
    return finder.lookup_future_symbol(symbol)
| 700 | |||
@api_method
@preprocess(root_symbol=ensure_upper_case)
def future_chain(self, root_symbol, as_of_date=None):
    """ Look up a future chain with the specified parameters.

    Parameters
    ----------
    root_symbol : str
        The root symbol of a future chain.
    as_of_date : datetime.datetime or pandas.Timestamp or str, optional
        Date at which the chain determination is rooted. I.e. the
        existing contract whose notice date is first after this date is
        the primary contract, etc. Naive values are interpreted as UTC;
        timezone-aware values are converted to UTC.

    Returns
    -------
    FutureChain
        The future chain matching the specified parameters.

    Raises
    ------
    RootSymbolNotFound
        If a future chain could not be found for the given root symbol.
    UnsupportedDatetimeFormat
        If ``as_of_date`` cannot be parsed as a timestamp.
    """
    if as_of_date:
        try:
            parsed = pd.Timestamp(as_of_date)
        except ValueError:
            raise UnsupportedDatetimeFormat(input=as_of_date,
                                            method='future_chain')
        # Normalize to UTC. ``pd.Timestamp(x, tz='UTC')`` fails for inputs
        # that already carry a timezone, so localize naive values and
        # convert aware ones explicitly.
        if parsed.tzinfo is None:
            as_of_date = parsed.tz_localize('UTC')
        else:
            as_of_date = parsed.tz_convert('UTC')
    return FutureChain(
        asset_finder=self.asset_finder,
        get_datetime=self.get_datetime,
        root_symbol=root_symbol,
        as_of_date=as_of_date
    )
| 737 | |||
def _calculate_order_value_amount(self, asset, value):
    """
    Convert a currency ``value`` into a share/contract count for ``asset``.

    The count is ``value / (last_price * multiplier)``. ``multiplier`` is
    ``asset.contract_multiplier`` for a Future and 1 for anything else, so
    for Futures ``value`` is really an exposure. The result is not rounded.

    If the last price is (tolerantly) zero, no count can be inferred. In
    that case 0 is returned, and a debug message is logged when
    ``self.logger`` is set.
    """
    price = self.trading_client.current_data[asset].price

    if tolerant_equals(price, 0):
        message = "Price of 0 for {psid}; can't infer value".format(
            psid=asset
        )
        if self.logger:
            self.logger.debug(message)
        # Dividing by a zero price is meaningless: signal "no order".
        return 0

    multiplier = asset.contract_multiplier if isinstance(asset, Future) else 1
    return value / (price * multiplier)
| 760 | |||
@api_method
def order(self, sid, amount,
          limit_price=None,
          stop_price=None,
          style=None):
    """
    Place an order for ``amount`` shares/contracts of ``sid``.

    ``amount`` is first snapped to an integer. A value within .0001 of an
    integer becomes that integer; anything else is truncated toward zero
    (3.9999 -> 4, 5.5 -> 5, -5.5 -> -5).

    Pass either an ExecutionStyle via ``style`` or the deprecated
    ``limit_price``/``stop_price`` arguments, not both.

    Returns whatever ``self.blotter.order`` returns.

    Raises
    ------
    ZiplineError subclasses from ``validate_order_params`` when the
    arguments are invalid.
    """
    share_count = int(round_if_near_integer(amount))

    self.validate_order_params(
        sid, share_count, limit_price, stop_price, style,
    )

    # Fold the deprecated price arguments into an ExecutionStyle.
    execution_style = self.__convert_order_params_for_blotter(
        limit_price, stop_price, style,
    )
    return self.blotter.order(sid, share_count, execution_style)
| 787 | |||
def validate_order_params(self,
                          asset,
                          amount,
                          limit_price,
                          stop_price,
                          style):
    """
    Helper method for validating parameters to the order API function.

    Checks, in order:
    1. the algorithm has finished initialization;
    2. an explicit ``style`` is not combined with the deprecated
       ``limit_price`` / ``stop_price`` arguments;
    3. ``asset`` is an ``Asset`` instance;
    4. every registered trading control accepts the order.

    Raises
    ------
    OrderDuringInitialize
        If called before initialization has completed.
    UnsupportedOrderParameters
        If ``style`` is combined with a limit or stop price, or ``asset``
        is not an ``Asset``.
    Any exception raised by a trading control's ``validate`` propagates.
    """
    if not self.initialized:
        raise OrderDuringInitialize(
            msg="order() can only be called from within handle_data()"
        )

    if style:
        # Test against None, not truthiness. Otherwise a price of 0 passed
        # together with a style slips past this check. The later
        # conversion step assumes that, when a style is given, both prices
        # are None.
        if limit_price is not None:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if stop_price is not None:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

    if not isinstance(asset, Asset):
        raise UnsupportedOrderParameters(
            msg="Passing non-Asset argument to 'order()' is not supported."
                " Use 'sid()' or 'symbol()' methods to look up an Asset."
        )

    for control in self.trading_controls:
        control.validate(asset,
                         amount,
                         self.portfolio,
                         self.get_datetime(),
                         self.trading_client.current_data)
| 828 | |||
| 829 | @staticmethod  | 
            ||
| 830 | def __convert_order_params_for_blotter(limit_price, stop_price, style):  | 
            ||
| 831 | """  | 
            ||
| 832 | Helper method for converting deprecated limit_price and stop_price  | 
            ||
| 833 | arguments into ExecutionStyle instances.  | 
            ||
| 834 | |||
| 835 | This function assumes that either style == None or (limit_price,  | 
            ||
| 836 | stop_price) == (None, None).  | 
            ||
| 837 | """  | 
            ||
| 838 | # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.  | 
            ||
| 839 | if style:  | 
            ||
| 840 | assert (limit_price, stop_price) == (None, None)  | 
            ||
| 841 | return style  | 
            ||
| 842 | if limit_price and stop_price:  | 
            ||
| 843 | return StopLimitOrder(limit_price, stop_price)  | 
            ||
| 844 | if limit_price:  | 
            ||
| 845 | return LimitOrder(limit_price)  | 
            ||
| 846 | if stop_price:  | 
            ||
| 847 | return StopOrder(stop_price)  | 
            ||
| 848 | else:  | 
            ||
| 849 | return MarketOrder()  | 
            ||
| 850 | |||
@api_method
def order_value(self, sid, value,
                limit_price=None, stop_price=None, style=None):
    """
    Place an order sized by currency value rather than by share count.

    ``value`` is divided by the asset's last price to get the number of
    shares (scaled by the contract multiplier for a Future, so for Futures
    ``value`` is really exposure). The count is then passed to ``order``.
    A zero last price yields a count of 0.

    value > 0 :: Buy/Cover
    value < 0 :: Sell/Short
    Market order:    order_value(sid, value)
    Limit order:     order_value(sid, value, limit_price)
    Stop order:      order_value(sid, value, None, stop_price)
    StopLimit order: order_value(sid, value, limit_price, stop_price)
    """
    return self.order(
        sid,
        self._calculate_order_value_amount(sid, value),
        limit_price=limit_price,
        stop_price=stop_price,
        style=style,
    )
| 873 | |||
@property
def recorded_vars(self):
    """A shallow copy of ``self._recorded_vars``.

    Mutating the returned container does not affect the algorithm's own
    record. Nested values are shared, because the copy is shallow.
    """
    snapshot = copy(self._recorded_vars)
    return snapshot
| 877 | |||
@property
def portfolio(self):
    """The current portfolio, as returned by ``self.updated_portfolio()``."""
    current = self.updated_portfolio()
    return current
| 881 | |||
def updated_portfolio(self):
    """
    Return the portfolio for the current datetime, fetching it lazily.

    The value is cached on ``self._portfolio``. It is only requested from
    ``self.perf_tracker.get_portfolio(self.datetime)`` when the cache is
    empty and a performance tracker exists. With no tracker the cached
    value (possibly None) is returned.
    """
    needs_fetch = self._portfolio is None and self.perf_tracker is not None
    if needs_fetch:
        self._portfolio = self.perf_tracker.get_portfolio(self.datetime)
    return self._portfolio
| 887 | |||
@property
def account(self):
    """The current account, as returned by ``self.updated_account()``."""
    current = self.updated_account()
    return current
| 891 | |||
def updated_account(self):
    """
    Return the account for the current datetime, fetching it lazily.

    The value is cached on ``self._account``. It is only requested from
    ``self.perf_tracker.get_account(self.datetime)`` when the cache is
    empty and a performance tracker exists. With no tracker the cached
    value (possibly None) is returned.
    """
    needs_fetch = self._account is None and self.perf_tracker is not None
    if needs_fetch:
        self._account = self.perf_tracker.get_account(self.datetime)
    return self._account
| 897 | |||
def set_logger(self, logger):
    """Attach ``logger``; it receives debug messages, e.g. zero-price warnings
    emitted when sizing value-based orders."""
    self.logger = logger
| 900 | |||
def on_dt_changed(self, dt):
    """
    Advance the algorithm's clock to ``dt``.

    The simulation loop calls this whenever the current dt changes. Logic
    that must run exactly once at the start of each datetime group belongs
    here. It forwards ``dt`` to the performance tracker and the blotter,
    then clears the cached portfolio and account so they are re-fetched
    for the new time on next access.

    ``dt`` must be a ``datetime`` whose tzinfo is ``pytz.utc``. This is
    enforced with asserts.
    """
    assert isinstance(dt, datetime), \
        "Attempt to set algorithm's current time with non-datetime"
    assert dt.tzinfo == pytz.utc, \
        "Algorithm expects a utc datetime"

    self.datetime = dt
    tracker = self.perf_tracker
    tracker.set_date(dt)
    blotter = self.blotter
    blotter.set_date(dt)

    # Invalidate per-datetime caches; rebuilt lazily by updated_*().
    self._portfolio = self._account = None
| 920 | |||
@api_method
def get_datetime(self, tz=None):
    """
    Return the current simulation datetime.

    Parameters
    ----------
    tz : tzinfo or str, optional
        If given, the UTC datetime is converted to this zone. Strings are
        resolved with ``pytz.timezone``.

    Returns
    -------
    datetime.datetime
        The algorithm's clock, in UTC unless ``tz`` was supplied. It is
        returned directly, since datetime objects are immutable.
    """
    now = self.datetime
    assert now.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

    if tz is None:
        return now

    zone = pytz.timezone(tz) if isinstance(tz, string_types) else tz
    return now.astimezone(zone)
| 936 | |||
def update_dividends(self, dividend_frame):
    """
    Pass ``dividend_frame`` to the performance tracker for dividend
    processing. Its columns should include at least the entries in
    zp.DIVIDEND_FIELDS.
    """
    tracker = self.perf_tracker
    tracker.update_dividends(dividend_frame)
| 943 | |||
@api_method
def set_slippage(self, slippage):
    """
    Install ``slippage`` as the blotter's slippage model.

    Raises
    ------
    UnsupportedSlippageModel
        If ``slippage`` is not a SlippageModel instance. This is checked
        first.
    OverrideSlippagePostInit
        If called after initialization has completed.
    """
    is_model = isinstance(slippage, SlippageModel)
    if not is_model:
        raise UnsupportedSlippageModel()
    if self.initialized:
        raise OverrideSlippagePostInit()
    blotter = self.blotter
    blotter.slippage_func = slippage
| 951 | |||
@api_method
def set_commission(self, commission):
    """
    Install ``commission`` as the blotter's commission model.

    Raises
    ------
    UnsupportedCommissionModel
        If ``commission`` is not a PerShare, PerTrade or PerDollar
        instance. This is checked first.
    OverrideCommissionPostInit
        If called after initialization has completed.
    """
    supported = (PerShare, PerTrade, PerDollar)
    if not isinstance(commission, supported):
        raise UnsupportedCommissionModel()

    if self.initialized:
        raise OverrideCommissionPostInit()
    blotter = self.blotter
    blotter.commission = commission
| 960 | |||
@api_method
def set_symbol_lookup_date(self, dt):
    """
    Set the date used when resolving symbols to sids. A symbol may map to
    different firms or underlying assets at different times.

    ``dt`` is parsed as a UTC ``pd.Timestamp``. On failure,
    UnsupportedDatetimeFormat is raised and the previous lookup date is
    kept.
    """
    try:
        lookup_date = pd.Timestamp(dt, tz='UTC')
    except ValueError:
        raise UnsupportedDatetimeFormat(input=dt,
                                        method='set_symbol_lookup_date')
    self._symbol_lookup_date = lookup_date
| 973 | |||
# Kept for backwards compatibility: the value actually lives on
# ``self.sim_params``.
@property
def data_frequency(self):
    """The simulation's data frequency (expected 'daily' or 'minute')."""
    return self.sim_params.data_frequency

@data_frequency.setter
def data_frequency(self, value):
    # Only 'daily' and 'minute' are accepted. The check is an assert, so
    # it is skipped under ``python -O``.
    assert value in ('daily', 'minute')
    params = self.sim_params
    params.data_frequency = value
| 983 | |||
@api_method
def order_percent(self, sid, percent,
                  limit_price=None, stop_price=None, style=None):
    """
    Place an order in ``sid`` worth ``percent`` of the current portfolio
    value.

    ``percent`` must be expressed as a decimal fraction (0.50 means 50%).
    The resulting value is passed to ``order_value``, so a negative
    percent sells/shorts.
    """
    # NOTE: the previous docstring spelled "50\%", an invalid string
    # escape that makes Python emit a DeprecationWarning/SyntaxWarning at
    # compile time.
    value = self.portfolio.portfolio_value * percent
    return self.order_value(sid, value,
                            limit_price=limit_price,
                            stop_price=stop_price,
                            style=style)
| 998 | |||
@api_method
def order_target(self, sid, target,
                 limit_price=None, stop_price=None, style=None):
    """
    Adjust the position in ``sid`` to ``target`` shares.

    If ``sid`` has no current position, ``target`` shares are ordered
    outright. Otherwise the order is for ``target`` minus the shares
    currently held.
    """
    positions = self.portfolio.positions
    if sid in positions:
        delta = target - positions[sid].amount
    else:
        delta = target

    return self.order(sid, delta,
                      limit_price=limit_price,
                      stop_price=stop_price,
                      style=style)
| 1021 | |||
@api_method
def order_target_value(self, sid, target,
                       limit_price=None, stop_price=None, style=None):
    """
    Adjust the position in ``sid`` so that it is worth ``target``.

    ``target`` is converted to a share count at the last price and handed
    to ``order_target``, which orders the difference from the current
    holding. For a Future, ``target`` is the target exposure (scaled by
    the contract multiplier), since Futures have no 'value'.

    Returns
    -------
    The result of ``order_target``. Returns None without placing an order
    when a non-zero ``target`` cannot be converted to shares, e.g. because
    the last price is zero.
    """
    target_amount = self._calculate_order_value_amount(sid, target)

    # A finite, non-zero target only converts to 0 shares when the price
    # could not be used (_calculate_order_value_amount returns 0 for a
    # zero price, meaning "don't place any order"). Passing that 0 on to
    # order_target would instead close the whole existing position, so
    # place no order. A target of 0 still means "close the position".
    if target_amount == 0 and target != 0:
        return None

    return self.order_target(sid, target_amount,
                             limit_price=limit_price,
                             stop_price=stop_price,
                             style=style)
| 1039 | |||
@api_method
def order_target_percent(self, sid, target,
                         limit_price=None, stop_price=None, style=None):
    """
    Adjust the position in ``sid`` to ``target`` percent of the current
    portfolio value.

    If no position exists, this is equivalent to placing a new order.
    Otherwise the difference between the target and the current holding
    is ordered (via ``order_target_value``).

    ``target`` must be expressed as a decimal fraction (0.50 means 50%).
    """
    # NOTE: the previous docstring spelled "50\%", an invalid string
    # escape that makes Python emit a DeprecationWarning/SyntaxWarning at
    # compile time.
    target_value = self.portfolio.portfolio_value * target
    return self.order_target_value(sid, target_value,
                                   limit_price=limit_price,
                                   stop_price=stop_price,
                                   style=style)
| 1057 | |||
@api_method
def get_open_orders(self, sid=None):
    """
    Return the blotter's open orders as API objects.

    Parameters
    ----------
    sid : optional
        If given, only the open orders stored under this key are returned.

    Returns
    -------
    dict or list
        Without ``sid``: a dict mapping each key of
        ``self.blotter.open_orders`` (presumably assets) to its list of
        open orders, omitting keys whose list is empty. With ``sid``: the
        list for that key, or ``[]`` if the key is absent.
    """
    open_orders = self.blotter.open_orders

    if sid is not None:
        if sid not in open_orders:
            return []
        return [pending.to_api_obj() for pending in open_orders[sid]]

    return {
        key: [pending.to_api_obj() for pending in batch]
        for key, batch in iteritems(open_orders)
        if batch
    }
| 1070 | |||
@api_method
def get_order(self, order_id):
    """
    Return the API object for the order with ``order_id``, or None if the
    blotter has no such order.
    """
    known_orders = self.blotter.orders
    if order_id not in known_orders:
        return None
    return known_orders[order_id].to_api_obj()
| 1075 | |||
@api_method
def cancel_order(self, order_param):
    """
    Ask the blotter to cancel an order.

    ``order_param`` may be a ``zipline.protocol.Order``, in which case its
    ``id`` is used, or an order id.
    """
    if isinstance(order_param, zipline.protocol.Order):
        target_id = order_param.id
    else:
        target_id = order_param

    self.blotter.cancel(target_id)
| 1083 | |||
@api_method
@require_initialized(HistoryInInitialize())
def history(self, sids, bar_count, frequency, field, ffill=True):
    """
    Return ``bar_count`` bars of ``field`` for ``sids`` at ``frequency``.

    The request is anchored at the current ``self.datetime`` and
    delegated to ``self.data_portal.get_history_window``. The
    ``require_initialized`` decorator rejects calls made during
    initialization.

    Raises
    ------
    Exception
        If no data portal is configured.
    """
    if self.data_portal is None:
        raise Exception("no data portal!")

    window_args = (sids, self.datetime, bar_count, frequency, field, ffill)
    return self.data_portal.get_history_window(*window_args)
| 1098 | |||
| 1099 | ####################  | 
            ||
| 1100 | # Account Controls #  | 
            ||
| 1101 | ####################  | 
            ||
| 1102 | |||
def register_account_control(self, control):
    """
    Add ``control`` to the AccountControls checked on each bar.

    Raises RegisterAccountControlPostInit once initialization has
    finished.
    """
    if not self.initialized:
        self.account_controls.append(control)
        return
    raise RegisterAccountControlPostInit()
| 1110 | |||
def validate_account_controls(self):
    """
    Run every registered AccountControl against the current portfolio,
    account, datetime and the trading client's current data.

    The arguments are re-read for each control. Any exception raised by a
    control propagates to the caller.
    """
    for account_control in self.account_controls:
        snapshot = (
            self.portfolio,
            self.account,
            self.get_datetime(),
            self.trading_client.current_data,
        )
        account_control.validate(*snapshot)
| 1117 | |||
@api_method
def set_max_leverage(self, max_leverage=None):
    """
    Limit the algorithm's maximum leverage by registering a MaxLeverage
    account control. This is subject to ``register_account_control``'s
    rule that controls cannot be added after initialization.
    """
    self.register_account_control(MaxLeverage(max_leverage))
| 1125 | |||
| 1126 | ####################  | 
            ||
| 1127 | # Trading Controls #  | 
            ||
| 1128 | ####################  | 
            ||
| 1129 | |||
def register_trading_control(self, control):
    """
    Add ``control`` to the TradingControls checked before each order.

    Raises RegisterTradingControlPostInit once initialization has
    finished.
    """
    if not self.initialized:
        self.trading_controls.append(control)
        return
    raise RegisterTradingControlPostInit()
| 1137 | |||
@api_method
def set_max_position_size(self,
                          sid=None,
                          max_shares=None,
                          max_notional=None):
    """
    Cap the number of shares and/or dollar value held in ``sid``.

    Limits are absolute values, enforced when the algorithm attempts to
    place an order for ``sid``. Because the check happens at order time,
    splits/dividends can push holdings past ``max_shares``, and price
    improvement can push them past ``max_notional``.

    An order that would increase the absolute share count or dollar value
    beyond one of these limits raises a TradingControlException.
    """
    limit = MaxPositionSize(asset=sid,
                            max_shares=max_shares,
                            max_notional=max_notional)
    self.register_trading_control(limit)
| 1159 | |||
@api_method
def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
    """
    Cap the shares and/or dollar value of any single order for ``sid``.

    Limits are absolute values, enforced when the algorithm attempts to
    place an order for ``sid``. An order exceeding one of them raises a
    TradingControlException.
    """
    limit = MaxOrderSize(asset=sid,
                         max_shares=max_shares,
                         max_notional=max_notional)
    self.register_trading_control(limit)
| 1174 | |||
@api_method
def set_max_order_count(self, max_count):
    """
    Limit how many orders may be placed within a time interval by
    registering a MaxOrderCount control. The interval itself is defined by
    MaxOrderCount and is not configurable here.
    """
    self.register_trading_control(MaxOrderCount(max_count))
| 1183 | |||
@api_method
def set_do_not_order_list(self, restricted_list):
    """
    Forbid ordering the sids in ``restricted_list`` by registering a
    RestrictedListOrder trading control.
    """
    self.register_trading_control(RestrictedListOrder(restricted_list))
| 1191 | |||
@api_method
def set_long_only(self):
    """
    Prevent the algorithm from taking short positions by registering a
    LongOnly trading control.
    """
    long_only = LongOnly()
    self.register_trading_control(long_only)
| 1198 | |||
| 1199 | ##############  | 
            ||
| 1200 | # Pipeline API  | 
            ||
| 1201 | ##############  | 
            ||
@api_method
@require_not_initialized(AttachPipelineAfterInitialize())
def attach_pipeline(self, pipeline, name, chunksize=None):
    """
    Register ``pipeline`` under ``name`` to be computed at the start of
    each day.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to run.
    name : str
        Key under which results are later requested.
    chunksize : int, optional
        Number of trading days computed per batch. By default the first
        batch is 5 days (so results arrive quickly), then 126 days per
        batch (roughly half a year).

    Returns
    -------
    Pipeline
        ``pipeline`` itself, to allow ``p = attach_pipeline(Pipeline(), 'name')``.

    Raises
    ------
    NotImplementedError
        If a pipeline is already attached; only one is supported.
    """
    if self._pipelines:
        raise NotImplementedError("Multiple pipelines are not supported.")

    # chain/repeat already return iterators, so no iter() wrapper is needed.
    if chunksize is not None:
        chunks = repeat(int(chunksize))
    else:
        chunks = chain([5], repeat(126))

    self._pipelines[name] = (pipeline, chunks)
    return pipeline
| 1221 | |||
@api_method
@require_initialized(PipelineOutputDuringInitialize())
def pipeline_output(self, name):
    """
    Get the results of the pipeline registered as ``name``.

    Parameters
    ----------
    name : str
        Name of the pipeline for which results are requested.

    Returns
    -------
    results : pd.DataFrame
        DataFrame containing the results of the requested pipeline for
        the current simulation date.

    Raises
    ------
    NoSuchPipeline
        Raised when no pipeline with the name ``name`` has been registered.

    See Also
    --------
    :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
    """
    # NOTE: only one pipeline can currently be attached, but lookup is by
    # name so multiple pipelines can be supported later.
    try:
        registered_pipeline, chunk_sizes = self._pipelines[name]
    except KeyError:
        raise NoSuchPipeline(
            name=name,
            valid=list(self._pipelines.keys()),
        )
    return self._pipeline_output(registered_pipeline, chunk_sizes)
| 1258 | |||
def _pipeline_output(self, pipeline, chunks):
    """
    Internal implementation of ``pipeline_output``.

    Reads today's rows from ``self._pipeline_cache``. When the cache has
    expired, it runs ``pipeline`` for the next batch of days (batch size
    taken from ``chunks``) and stores the result as a new CachedObject,
    with ``valid_until`` from ``_run_pipeline`` as its expiry.
    """
    today = normalize_date(self.get_datetime())

    try:
        frame = self._pipeline_cache.unwrap(today)
    except Expired:
        frame, valid_until = self._run_pipeline(
            pipeline, today, next(chunks),
        )
        self._pipeline_cache = CachedObject(frame, valid_until)

    # There are no rows for today when no asset passed the pipeline
    # screen; return an empty frame with the same columns instead.
    try:
        return frame.loc[today]
    except KeyError:
        return pd.DataFrame(index=[], columns=frame.columns)
| 1279 | |||
| 1280 | def _run_pipeline(self, pipeline, start_date, chunksize):  | 
            ||
| 1281 | """  | 
            ||
| 1282 | Compute `pipeline`, providing values for at least `start_date`.  | 
            ||
| 1283 | |||
| 1284 | Produces a DataFrame containing data for days between `start_date` and  | 
            ||
| 1285 | `end_date`, where `end_date` is defined by:  | 
            ||
| 1286 | |||
| 1287 | `end_date = min(start_date + chunksize trading days,  | 
            ||
| 1288 | simulation_end)`  | 
            ||
| 1289 | |||
| 1290 | Returns  | 
            ||
| 1291 | -------  | 
            ||
| 1292 | (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)  | 
            ||
| 1293 | |||
| 1294 | See Also  | 
            ||
| 1295 | --------  | 
            ||
| 1296 | PipelineEngine.run_pipeline  | 
            ||
| 1297 | """  | 
            ||
| 1298 | days = self.trading_environment.trading_days  | 
            ||
| 1299 | |||
| 1300 | # Load data starting from the previous trading day...  | 
            ||
| 1301 | start_date_loc = days.get_loc(start_date)  | 
            ||
| 1302 | |||
| 1303 | # ...continuing until either the day before the simulation end, or  | 
            ||
| 1304 | # until chunksize days of data have been loaded.  | 
            ||
| 1305 | sim_end = self.sim_params.last_close.normalize()  | 
            ||
| 1306 | end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))  | 
            ||
| 1307 | end_date = days[end_loc]  | 
            ||
| 1308 | |||
| 1309 | return \  | 
            ||
| 1310 | self.engine.run_pipeline(pipeline, start_date, end_date), end_date  | 
            ||
| 1311 | |||
| 1312 | ##################  | 
            ||
| 1313 | # End Pipeline API  | 
            ||
| 1314 | ##################  | 
            ||
| 1315 | |||
@classmethod
def all_api_methods(cls):
    """
    Return a list of every attribute defined directly on ``cls`` (its own
    ``__dict__``, not inherited ones) that carries a truthy
    ``is_api_method`` attribute.
    """
    api_methods = []
    for member in vars(cls).values():
        if getattr(member, 'is_api_method', False):
            api_methods.append(member)
    return api_methods
| 1325 |