| Total Complexity | 186 |
| Total Lines | 1389 |
| Duplicated Lines | 0 % |
Complex classes like zipline.TradingAlgorithm often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
| 1 | # |
||
| 112 | class TradingAlgorithm(object): |
||
| 113 | """ |
||
| 114 | Base class for trading algorithms. Inherit and overload |
||
| 115 | initialize() and handle_data(data). |
||
| 116 | |||
| 117 | A new algorithm could look like this: |
||
| 118 | ``` |
||
| 119 | from zipline.api import order, symbol |
||
| 120 | |||
| 121 | def initialize(context): |
||
| 122 | context.sid = symbol('AAPL') |
||
| 123 | context.amount = 100 |
||
| 124 | |||
| 125 | def handle_data(context, data): |
||
| 126 | sid = context.sid |
||
| 127 | amount = context.amount |
||
| 128 | order(sid, amount) |
||
| 129 | ``` |
||
| 130 | To then run this algorithm, pass these functions to |
||
| 131 | TradingAlgorithm: |
||
| 132 | |||
| 133 | my_algo = TradingAlgorithm(initialize, handle_data) |
||
| 134 | stats = my_algo.run(data) |
||
| 135 | |||
| 136 | """ |
||
| 137 | |||
| 138 | def __init__(self, *args, **kwargs): |
||
| 139 | """Initialize sids and other state variables. |
||
| 140 | |||
| 141 | :Arguments: |
||
| 142 | :Optional: |
||
| 143 | initialize : function |
||
| 144 | Function that is called with a single |
||
| 145 | argument at the beginning of the simulation. |
||
| 146 | handle_data : function |
||
| 147 | Function that is called with 2 arguments |
||
| 148 | (context and data) on every bar. |
||
| 149 | script : str |
||
| 150 | Algoscript that contains initialize and |
||
| 151 | handle_data function definition. |
||
| 152 | data_frequency : {'daily', 'minute'} |
||
| 153 | The duration of the bars. |
||
| 154 | capital_base : float <default: 1.0e5> |
||
| 155 | How much capital to start with. |
||
| 156 | instant_fill : bool <default: False> |
||
| 157 | Whether to fill orders immediately or on next bar. |
||
| 158 | asset_finder : An AssetFinder object |
||
| 159 | A new AssetFinder object to be used in this TradingEnvironment |
||
| 160 | equities_metadata : can be either: |
||
| 161 | - dict |
||
| 162 | - pandas.DataFrame |
||
| 163 | - object with 'read' property |
||
| 164 | If dict is provided, it must have the following structure: |
||
| 165 | * keys are the identifiers |
||
| 166 | * values are dicts containing the metadata, with the metadata |
||
| 167 | field name as the key |
||
| 168 | If pandas.DataFrame is provided, it must have the |
||
| 169 | following structure: |
||
| 170 | * column names must be the metadata fields |
||
| 171 | * index must be the different asset identifiers |
||
| 172 | * array contents should be the metadata value |
||
| 173 | If an object with a 'read' property is provided, 'read' must |
||
| 174 | return rows containing at least one of 'sid' or 'symbol' along |
||
| 175 | with the other metadata fields. |
||
| 176 | identifiers : List |
||
| 177 | Any asset identifiers that are not provided in the |
||
| 178 | equities_metadata, but will be traded by this TradingAlgorithm |
||
| 179 | """ |
||
| 180 | self.sources = [] |
||
| 181 | |||
| 182 | # List of trading controls to be used to validate orders. |
||
| 183 | self.trading_controls = [] |
||
| 184 | |||
| 185 | # List of account controls to be checked on each bar. |
||
| 186 | self.account_controls = [] |
||
| 187 | |||
| 188 | self._recorded_vars = {} |
||
| 189 | self.namespace = kwargs.pop('namespace', {}) |
||
| 190 | |||
| 191 | self._platform = kwargs.pop('platform', 'zipline') |
||
| 192 | |||
| 193 | self.logger = None |
||
| 194 | |||
| 195 | self.benchmark_return_source = None |
||
| 196 | |||
| 197 | # default components for transact |
||
| 198 | self.slippage = VolumeShareSlippage() |
||
| 199 | self.commission = PerShare() |
||
| 200 | |||
| 201 | self.instant_fill = kwargs.pop('instant_fill', False) |
||
| 202 | |||
| 203 | # If an env has been provided, pop it |
||
| 204 | self.trading_environment = kwargs.pop('env', None) |
||
| 205 | |||
| 206 | if self.trading_environment is None: |
||
| 207 | self.trading_environment = TradingEnvironment() |
||
| 208 | |||
| 209 | # Update the TradingEnvironment with the provided asset metadata |
||
| 210 | self.trading_environment.write_data( |
||
| 211 | equities_data=kwargs.pop('equities_metadata', {}), |
||
| 212 | equities_identifiers=kwargs.pop('identifiers', []), |
||
| 213 | futures_data=kwargs.pop('futures_metadata', {}), |
||
| 214 | ) |
||
| 215 | |||
| 216 | # set the capital base |
||
| 217 | self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE) |
||
| 218 | self.sim_params = kwargs.pop('sim_params', None) |
||
| 219 | if self.sim_params is None: |
||
| 220 | self.sim_params = create_simulation_parameters( |
||
| 221 | capital_base=self.capital_base, |
||
| 222 | start=kwargs.pop('start', None), |
||
| 223 | end=kwargs.pop('end', None), |
||
| 224 | env=self.trading_environment, |
||
| 225 | ) |
||
| 226 | else: |
||
| 227 | self.sim_params.update_internal_from_env(self.trading_environment) |
||
| 228 | |||
| 229 | # Build a perf_tracker |
||
| 230 | self.perf_tracker = PerformanceTracker(sim_params=self.sim_params, |
||
| 231 | env=self.trading_environment) |
||
| 232 | |||
| 233 | # Pull in the environment's new AssetFinder for quick reference |
||
| 234 | self.asset_finder = self.trading_environment.asset_finder |
||
| 235 | |||
| 236 | # Initialize Pipeline API data. |
||
| 237 | self.init_engine(kwargs.pop('get_pipeline_loader', None)) |
||
| 238 | self._pipelines = {} |
||
| 239 | # Create an always-expired cache so that we compute the first time data |
||
| 240 | # is requested. |
||
| 241 | self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC')) |
||
| 242 | |||
| 243 | self.blotter = kwargs.pop('blotter', None) |
||
| 244 | if not self.blotter: |
||
| 245 | self.blotter = Blotter() |
||
| 246 | |||
| 247 | # Set the dt initally to the period start by forcing it to change |
||
| 248 | self.on_dt_changed(self.sim_params.period_start) |
||
| 249 | |||
| 250 | # The symbol lookup date specifies the date to use when resolving |
||
| 251 | # symbols to sids, and can be set using set_symbol_lookup_date() |
||
| 252 | self._symbol_lookup_date = None |
||
| 253 | |||
| 254 | self.portfolio_needs_update = True |
||
| 255 | self.account_needs_update = True |
||
| 256 | self.performance_needs_update = True |
||
| 257 | self._portfolio = None |
||
| 258 | self._account = None |
||
| 259 | |||
| 260 | self.history_container_class = kwargs.pop( |
||
| 261 | 'history_container_class', HistoryContainer, |
||
| 262 | ) |
||
| 263 | self.history_container = None |
||
| 264 | self.history_specs = {} |
||
| 265 | |||
| 266 | # If string is passed in, execute and get reference to |
||
| 267 | # functions. |
||
| 268 | self.algoscript = kwargs.pop('script', None) |
||
| 269 | |||
| 270 | self._initialize = None |
||
| 271 | self._before_trading_start = None |
||
| 272 | self._analyze = None |
||
| 273 | |||
| 274 | self.event_manager = EventManager( |
||
| 275 | create_context=kwargs.pop('create_event_context', None), |
||
| 276 | ) |
||
| 277 | |||
| 278 | if self.algoscript is not None: |
||
| 279 | filename = kwargs.pop('algo_filename', None) |
||
| 280 | if filename is None: |
||
| 281 | filename = '<string>' |
||
| 282 | code = compile(self.algoscript, filename, 'exec') |
||
| 283 | exec_(code, self.namespace) |
||
| 284 | self._initialize = self.namespace.get('initialize') |
||
| 285 | if 'handle_data' not in self.namespace: |
||
| 286 | raise ValueError('You must define a handle_data function.') |
||
| 287 | else: |
||
| 288 | self._handle_data = self.namespace['handle_data'] |
||
| 289 | |||
| 290 | self._before_trading_start = \ |
||
| 291 | self.namespace.get('before_trading_start') |
||
| 292 | # Optional analyze function, gets called after run |
||
| 293 | self._analyze = self.namespace.get('analyze') |
||
| 294 | |||
| 295 | elif kwargs.get('initialize') and kwargs.get('handle_data'): |
||
| 296 | if self.algoscript is not None: |
||
| 297 | raise ValueError('You can not set script and \ |
||
| 298 | initialize/handle_data.') |
||
| 299 | self._initialize = kwargs.pop('initialize') |
||
| 300 | self._handle_data = kwargs.pop('handle_data') |
||
| 301 | self._before_trading_start = kwargs.pop('before_trading_start', |
||
| 302 | None) |
||
| 303 | self._analyze = kwargs.pop('analyze', None) |
||
| 304 | |||
| 305 | self.event_manager.add_event( |
||
| 306 | zipline.utils.events.Event( |
||
| 307 | zipline.utils.events.Always(), |
||
| 308 | # We pass handle_data.__func__ to get the unbound method. |
||
| 309 | # We will explicitly pass the algorithm to bind it again. |
||
| 310 | self.handle_data.__func__, |
||
| 311 | ), |
||
| 312 | prepend=True, |
||
| 313 | ) |
||
| 314 | |||
| 315 | # If method not defined, NOOP |
||
| 316 | if self._initialize is None: |
||
| 317 | self._initialize = lambda x: None |
||
| 318 | |||
| 319 | # Alternative way of setting data_frequency for backwards |
||
| 320 | # compatibility. |
||
| 321 | if 'data_frequency' in kwargs: |
||
| 322 | self.data_frequency = kwargs.pop('data_frequency') |
||
| 323 | |||
| 324 | self._most_recent_data = None |
||
| 325 | |||
| 326 | # Prepare the algo for initialization |
||
| 327 | self.initialized = False |
||
| 328 | self.initialize_args = args |
||
| 329 | self.initialize_kwargs = kwargs |
||
| 330 | |||
| 331 | def init_engine(self, get_loader): |
||
| 332 | """ |
||
| 333 | Construct and store a PipelineEngine from loader. |
||
| 334 | |||
| 335 | If get_loader is None, constructs a NoOpPipelineEngine. |
||
| 336 | """ |
||
| 337 | if get_loader is not None: |
||
| 338 | self.engine = SimplePipelineEngine( |
||
| 339 | get_loader, |
||
| 340 | self.trading_environment.trading_days, |
||
| 341 | self.asset_finder, |
||
| 342 | ) |
||
| 343 | else: |
||
| 344 | self.engine = NoOpPipelineEngine() |
||
| 345 | |||
| 346 | def initialize(self, *args, **kwargs): |
||
| 347 | """ |
||
| 348 | Call self._initialize with `self` made available to Zipline API |
||
| 349 | functions. |
||
| 350 | """ |
||
| 351 | with ZiplineAPI(self): |
||
| 352 | self._initialize(self, *args, **kwargs) |
||
| 353 | |||
| 354 | def before_trading_start(self, data): |
||
| 355 | if self._before_trading_start is None: |
||
| 356 | return |
||
| 357 | |||
| 358 | self._before_trading_start(self, data) |
||
| 359 | |||
| 360 | def handle_data(self, data): |
||
| 361 | self._most_recent_data = data |
||
| 362 | if self.history_container: |
||
| 363 | self.history_container.update(data, self.datetime) |
||
| 364 | |||
| 365 | self._handle_data(self, data) |
||
| 366 | |||
| 367 | # Unlike trading controls which remain constant unless placing an |
||
| 368 | # order, account controls can change each bar. Thus, must check |
||
| 369 | # every bar no matter if the algorithm places an order or not. |
||
| 370 | self.validate_account_controls() |
||
| 371 | |||
| 372 | def analyze(self, perf): |
||
| 373 | if self._analyze is None: |
||
| 374 | return |
||
| 375 | |||
| 376 | with ZiplineAPI(self): |
||
| 377 | self._analyze(self, perf) |
||
| 378 | |||
| 379 | def __repr__(self): |
||
| 380 | """ |
||
| 381 | N.B. this does not yet represent a string that can be used |
||
| 382 | to instantiate an exact copy of an algorithm. |
||
| 383 | |||
| 384 | However, it is getting close, and provides some value as something |
||
| 385 | that can be inspected interactively. |
||
| 386 | """ |
||
| 387 | return """ |
||
| 388 | {class_name}( |
||
| 389 | capital_base={capital_base} |
||
| 390 | sim_params={sim_params}, |
||
| 391 | initialized={initialized}, |
||
| 392 | slippage={slippage}, |
||
| 393 | commission={commission}, |
||
| 394 | blotter={blotter}, |
||
| 395 | recorded_vars={recorded_vars}) |
||
| 396 | """.strip().format(class_name=self.__class__.__name__, |
||
| 397 | capital_base=self.capital_base, |
||
| 398 | sim_params=repr(self.sim_params), |
||
| 399 | initialized=self.initialized, |
||
| 400 | slippage=repr(self.slippage), |
||
| 401 | commission=repr(self.commission), |
||
| 402 | blotter=repr(self.blotter), |
||
| 403 | recorded_vars=repr(self.recorded_vars)) |
||
| 404 | |||
| 405 | def _create_data_generator(self, source_filter, sim_params=None): |
||
| 406 | """ |
||
| 407 | Create a merged data generator using the sources attached to this |
||
| 408 | algorithm. |
||
| 409 | |||
| 410 | ::source_filter:: is a method that receives events in date |
||
| 411 | sorted order, and returns True for those events that should be |
||
| 412 | processed by the zipline, and False for those that should be |
||
| 413 | skipped. |
||
| 414 | """ |
||
| 415 | if sim_params is None: |
||
| 416 | sim_params = self.sim_params |
||
| 417 | |||
| 418 | if self.benchmark_return_source is None: |
||
| 419 | if sim_params.data_frequency == 'minute' or \ |
||
| 420 | sim_params.emission_rate == 'minute': |
||
| 421 | def update_time(date): |
||
| 422 | return self.trading_environment.get_open_and_close(date)[1] |
||
| 423 | else: |
||
| 424 | def update_time(date): |
||
| 425 | return date |
||
| 426 | benchmark_return_source = [ |
||
| 427 | Event({'dt': update_time(dt), |
||
| 428 | 'returns': ret, |
||
| 429 | 'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK, |
||
| 430 | 'source_id': 'benchmarks'}) |
||
| 431 | for dt, ret in |
||
| 432 | self.trading_environment.benchmark_returns.iteritems() |
||
| 433 | if dt.date() >= sim_params.period_start.date() and |
||
| 434 | dt.date() <= sim_params.period_end.date() |
||
| 435 | ] |
||
| 436 | else: |
||
| 437 | benchmark_return_source = self.benchmark_return_source |
||
| 438 | |||
| 439 | date_sorted = date_sorted_sources(*self.sources) |
||
| 440 | |||
| 441 | if source_filter: |
||
| 442 | date_sorted = filter(source_filter, date_sorted) |
||
| 443 | |||
| 444 | with_benchmarks = date_sorted_sources(benchmark_return_source, |
||
| 445 | date_sorted) |
||
| 446 | |||
| 447 | # Group together events with the same dt field. This depends on the |
||
| 448 | # events already being sorted. |
||
| 449 | return groupby(with_benchmarks, attrgetter('dt')) |
||
| 450 | |||
| 451 | def _create_generator(self, sim_params, source_filter=None): |
||
| 452 | """ |
||
| 453 | Create a basic generator setup using the sources to this algorithm. |
||
| 454 | |||
| 455 | ::source_filter:: is a method that receives events in date |
||
| 456 | sorted order, and returns True for those events that should be |
||
| 457 | processed by the zipline, and False for those that should be |
||
| 458 | skipped. |
||
| 459 | """ |
||
| 460 | |||
| 461 | if not self.initialized: |
||
| 462 | self.initialize(*self.initialize_args, **self.initialize_kwargs) |
||
| 463 | self.initialized = True |
||
| 464 | |||
| 465 | if self.perf_tracker is None: |
||
| 466 | # HACK: When running with the `run` method, we set perf_tracker to |
||
| 467 | # None so that it will be overwritten here. |
||
| 468 | self.perf_tracker = PerformanceTracker( |
||
| 469 | sim_params=sim_params, env=self.trading_environment |
||
| 470 | ) |
||
| 471 | |||
| 472 | self.portfolio_needs_update = True |
||
| 473 | self.account_needs_update = True |
||
| 474 | self.performance_needs_update = True |
||
| 475 | |||
| 476 | self.data_gen = self._create_data_generator(source_filter, sim_params) |
||
| 477 | |||
| 478 | self.trading_client = AlgorithmSimulator(self, sim_params) |
||
| 479 | |||
| 480 | transact_method = transact_partial(self.slippage, self.commission) |
||
| 481 | self.set_transact(transact_method) |
||
| 482 | |||
| 483 | return self.trading_client.transform(self.data_gen) |
||
| 484 | |||
| 485 | def get_generator(self): |
||
| 486 | """ |
||
| 487 | Override this method to add new logic to the construction |
||
| 488 | of the generator. Overrides can use the _create_generator |
||
| 489 | method to get a standard construction generator. |
||
| 490 | """ |
||
| 491 | return self._create_generator(self.sim_params) |
||
| 492 | |||
| 493 | # TODO: make a new subclass, e.g. BatchAlgorithm, and move |
||
| 494 | # the run method to the subclass, and refactor to put the |
||
| 495 | # generator creation logic into get_generator. |
||
| 496 | def run(self, source, overwrite_sim_params=True, |
||
| 497 | benchmark_return_source=None): |
||
| 498 | """Run the algorithm. |
||
| 499 | |||
| 500 | :Arguments: |
||
| 501 | source : can be either: |
||
| 502 | - pandas.DataFrame |
||
| 503 | - zipline source |
||
| 504 | - list of sources |
||
| 505 | |||
| 506 | If pandas.DataFrame is provided, it must have the |
||
| 507 | following structure: |
||
| 508 | * column names must be the different asset identifiers |
||
| 509 | * index must be DatetimeIndex |
||
| 510 | * array contents should be price info. |
||
| 511 | |||
| 512 | :Returns: |
||
| 513 | daily_stats : pandas.DataFrame |
||
| 514 | Daily performance metrics such as returns, alpha etc. |
||
| 515 | |||
| 516 | """ |
||
| 517 | |||
| 518 | # Ensure that source is a DataSource object |
||
| 519 | if isinstance(source, list): |
||
| 520 | if overwrite_sim_params: |
||
| 521 | warnings.warn("""List of sources passed, will not attempt to extract start and end |
||
| 522 | dates. Make sure to set the correct fields in sim_params passed to |
||
| 523 | __init__().""", UserWarning) |
||
| 524 | overwrite_sim_params = False |
||
| 525 | elif isinstance(source, pd.DataFrame): |
||
| 526 | # if DataFrame provided, map columns to sids and wrap |
||
| 527 | # in DataFrameSource |
||
| 528 | copy_frame = source.copy() |
||
| 529 | copy_frame.columns = self._write_and_map_id_index_to_sids( |
||
| 530 | source.columns, source.index[0], |
||
| 531 | ) |
||
| 532 | source = DataFrameSource(copy_frame) |
||
| 533 | |||
| 534 | elif isinstance(source, pd.Panel): |
||
| 535 | # If Panel provided, map items to sids and wrap |
||
| 536 | # in DataPanelSource |
||
| 537 | copy_panel = source.copy() |
||
| 538 | copy_panel.items = self._write_and_map_id_index_to_sids( |
||
| 539 | source.items, source.major_axis[0], |
||
| 540 | ) |
||
| 541 | source = DataPanelSource(copy_panel) |
||
| 542 | |||
| 543 | if isinstance(source, list): |
||
| 544 | self.set_sources(source) |
||
| 545 | else: |
||
| 546 | self.set_sources([source]) |
||
| 547 | |||
| 548 | # Override sim_params if params are provided by the source. |
||
| 549 | if overwrite_sim_params: |
||
| 550 | if hasattr(source, 'start'): |
||
| 551 | self.sim_params.period_start = source.start |
||
| 552 | if hasattr(source, 'end'): |
||
| 553 | self.sim_params.period_end = source.end |
||
| 554 | # Changing period_start and period_close might require updating |
||
| 555 | # of first_open and last_close. |
||
| 556 | self.sim_params.update_internal_from_env( |
||
| 557 | env=self.trading_environment |
||
| 558 | ) |
||
| 559 | |||
| 560 | # The sids field of the source is the reference for the universe at |
||
| 561 | # the start of the run |
||
| 562 | self._current_universe = set() |
||
| 563 | for source in self.sources: |
||
| 564 | for sid in source.sids: |
||
| 565 | self._current_universe.add(sid) |
||
| 566 | # Check that all sids from the source are accounted for in |
||
| 567 | # the AssetFinder. This retrieve call will raise an exception if the |
||
| 568 | # sid is not found. |
||
| 569 | for sid in self._current_universe: |
||
| 570 | self.asset_finder.retrieve_asset(sid) |
||
| 571 | |||
| 572 | # force a reset of the performance tracker, in case |
||
| 573 | # this is a repeat run of the algorithm. |
||
| 574 | self.perf_tracker = None |
||
| 575 | |||
| 576 | # create zipline |
||
| 577 | self.gen = self._create_generator(self.sim_params) |
||
| 578 | |||
| 579 | # Create history containers |
||
| 580 | if self.history_specs: |
||
| 581 | self.history_container = self.history_container_class( |
||
| 582 | self.history_specs, |
||
| 583 | self.current_universe(), |
||
| 584 | self.sim_params.first_open, |
||
| 585 | self.sim_params.data_frequency, |
||
| 586 | self.trading_environment, |
||
| 587 | ) |
||
| 588 | |||
| 589 | # loop through simulated_trading, each iteration returns a |
||
| 590 | # perf dictionary |
||
| 591 | perfs = [] |
||
| 592 | for perf in self.gen: |
||
| 593 | perfs.append(perf) |
||
| 594 | |||
| 595 | # convert perf dict to pandas dataframe |
||
| 596 | daily_stats = self._create_daily_stats(perfs) |
||
| 597 | |||
| 598 | self.analyze(daily_stats) |
||
| 599 | |||
| 600 | return daily_stats |
||
| 601 | |||
| 602 | def _write_and_map_id_index_to_sids(self, identifiers, as_of_date): |
||
| 603 | # Build new Assets for identifiers that can't be resolved as |
||
| 604 | # sids/Assets |
||
| 605 | identifiers_to_build = [] |
||
| 606 | for identifier in identifiers: |
||
| 607 | asset = None |
||
| 608 | |||
| 609 | if isinstance(identifier, Asset): |
||
| 610 | asset = self.asset_finder.retrieve_asset(sid=identifier.sid, |
||
| 611 | default_none=True) |
||
| 612 | elif isinstance(identifier, Integral): |
||
| 613 | asset = self.asset_finder.retrieve_asset(sid=identifier, |
||
| 614 | default_none=True) |
||
| 615 | if asset is None: |
||
| 616 | identifiers_to_build.append(identifier) |
||
| 617 | |||
| 618 | self.trading_environment.write_data( |
||
| 619 | equities_identifiers=identifiers_to_build) |
||
| 620 | |||
| 621 | # We need to clear out any cache misses that were stored while trying |
||
| 622 | # to do lookups. The real fix for this problem is to not construct an |
||
| 623 | # AssetFinder until we `run()` when we actually have all the data we |
||
| 624 | # need to do so. |
||
| 625 | self.asset_finder._reset_caches() |
||
| 626 | |||
| 627 | return self.asset_finder.map_identifier_index_to_sids( |
||
| 628 | identifiers, as_of_date, |
||
| 629 | ) |
||
| 630 | |||
| 631 | def _create_daily_stats(self, perfs): |
||
| 632 | # create daily and cumulative stats dataframe |
||
| 633 | daily_perfs = [] |
||
| 634 | # TODO: the loop here could overwrite expected properties |
||
| 635 | # of daily_perf. Could potentially raise or log a |
||
| 636 | # warning. |
||
| 637 | for perf in perfs: |
||
| 638 | if 'daily_perf' in perf: |
||
| 639 | |||
| 640 | perf['daily_perf'].update( |
||
| 641 | perf['daily_perf'].pop('recorded_vars') |
||
| 642 | ) |
||
| 643 | perf['daily_perf'].update(perf['cumulative_risk_metrics']) |
||
| 644 | daily_perfs.append(perf['daily_perf']) |
||
| 645 | else: |
||
| 646 | self.risk_report = perf |
||
| 647 | |||
| 648 | daily_dts = [np.datetime64(perf['period_close'], utc=True) |
||
| 649 | for perf in daily_perfs] |
||
| 650 | daily_stats = pd.DataFrame(daily_perfs, index=daily_dts) |
||
| 651 | |||
| 652 | return daily_stats |
||
| 653 | |||
| 654 | @api_method |
||
| 655 | def add_transform(self, transform, days=None): |
||
| 656 | """ |
||
| 657 | Ensures that the history container will have enough size to service |
||
| 658 | a simple transform. |
||
| 659 | |||
| 660 | :Arguments: |
||
| 661 | transform : string |
||
| 662 | The transform to add. Must be an element of: |
||
| 663 | {'mavg', 'stddev', 'vwap', 'returns'}. |
||
| 664 | days : int <default=None> |
||
| 665 | The maximum amount of days you will want for this transform. |
||
| 666 | This is not needed for 'returns'. |
||
| 667 | """ |
||
| 668 | if transform not in {'mavg', 'stddev', 'vwap', 'returns'}: |
||
| 669 | raise ValueError('Invalid transform') |
||
| 670 | |||
| 671 | if transform == 'returns': |
||
| 672 | if days is not None: |
||
| 673 | raise ValueError('returns does use days') |
||
| 674 | |||
| 675 | self.add_history(2, '1d', 'price') |
||
| 676 | return |
||
| 677 | elif days is None: |
||
| 678 | raise ValueError('no number of days specified') |
||
| 679 | |||
| 680 | if self.sim_params.data_frequency == 'daily': |
||
| 681 | mult = 1 |
||
| 682 | freq = '1d' |
||
| 683 | else: |
||
| 684 | mult = 390 |
||
| 685 | freq = '1m' |
||
| 686 | |||
| 687 | bars = mult * days |
||
| 688 | self.add_history(bars, freq, 'price') |
||
| 689 | |||
| 690 | if transform == 'vwap': |
||
| 691 | self.add_history(bars, freq, 'volume') |
||
| 692 | |||
| 693 | @api_method |
||
| 694 | def get_environment(self, field='platform'): |
||
| 695 | env = { |
||
| 696 | 'arena': self.sim_params.arena, |
||
| 697 | 'data_frequency': self.sim_params.data_frequency, |
||
| 698 | 'start': self.sim_params.first_open, |
||
| 699 | 'end': self.sim_params.last_close, |
||
| 700 | 'capital_base': self.sim_params.capital_base, |
||
| 701 | 'platform': self._platform |
||
| 702 | } |
||
| 703 | if field == '*': |
||
| 704 | return env |
||
| 705 | else: |
||
| 706 | return env[field] |
||
| 707 | |||
| 708 | def add_event(self, rule=None, callback=None): |
||
| 709 | """ |
||
| 710 | Adds an event to the algorithm's EventManager. |
||
| 711 | """ |
||
| 712 | self.event_manager.add_event( |
||
| 713 | zipline.utils.events.Event(rule, callback), |
||
| 714 | ) |
||
| 715 | |||
| 716 | @api_method |
||
| 717 | def schedule_function(self, |
||
| 718 | func, |
||
| 719 | date_rule=None, |
||
| 720 | time_rule=None, |
||
| 721 | half_days=True): |
||
| 722 | """ |
||
| 723 | Schedules a function to be called with some timed rules. |
||
| 724 | """ |
||
| 725 | date_rule = date_rule or DateRuleFactory.every_day() |
||
| 726 | time_rule = ((time_rule or TimeRuleFactory.market_open()) |
||
| 727 | if self.sim_params.data_frequency == 'minute' else |
||
| 728 | # If we are in daily mode the time_rule is ignored. |
||
| 729 | zipline.utils.events.Always()) |
||
| 730 | |||
| 731 | self.add_event( |
||
| 732 | make_eventrule(date_rule, time_rule, half_days), |
||
| 733 | func, |
||
| 734 | ) |
||
| 735 | |||
| 736 | @api_method |
||
| 737 | def record(self, *args, **kwargs): |
||
| 738 | """ |
||
| 739 | Track and record local variable (i.e. attributes) each day. |
||
| 740 | """ |
||
| 741 | # Make 2 objects both referencing the same iterator |
||
| 742 | args = [iter(args)] * 2 |
||
| 743 | |||
| 744 | # Zip generates list entries by calling `next` on each iterator it |
||
| 745 | # receives. In this case the two iterators are the same object, so the |
||
| 746 | # call to next on args[0] will also advance args[1], resulting in zip |
||
| 747 | # returning (a,b) (c,d) (e,f) rather than (a,a) (b,b) (c,c) etc. |
||
| 748 | positionals = zip(*args) |
||
| 749 | for name, value in chain(positionals, iteritems(kwargs)): |
||
| 750 | self._recorded_vars[name] = value |
||
| 751 | |||
| 752 | @api_method |
||
| 753 | @preprocess(symbol_str=ensure_upper_case) |
||
| 754 | def symbol(self, symbol_str): |
||
| 755 | """ |
||
| 756 | Default symbol lookup for any source that directly maps the |
||
| 757 | symbol to the Asset (e.g. yahoo finance). |
||
| 758 | """ |
||
| 759 | # If the user has not set the symbol lookup date, |
||
| 760 | # use the period_end as the date for symbol->sid resolution. |
||
| 761 | _lookup_date = self._symbol_lookup_date if self._symbol_lookup_date is not None \ |
||
| 762 | else self.sim_params.period_end |
||
| 763 | |||
| 764 | return self.asset_finder.lookup_symbol( |
||
| 765 | symbol_str, |
||
| 766 | as_of_date=_lookup_date, |
||
| 767 | ) |
||
| 768 | |||
| 769 | @api_method |
||
| 770 | def symbols(self, *args): |
||
| 771 | """ |
||
| 772 | Default symbols lookup for any source that directly maps the |
||
| 773 | symbol to the Asset (e.g. yahoo finance). |
||
| 774 | """ |
||
| 775 | return [self.symbol(identifier) for identifier in args] |
||
| 776 | |||
| 777 | @api_method |
||
| 778 | def sid(self, a_sid): |
||
| 779 | """ |
||
| 780 | Default sid lookup for any source that directly maps the integer sid |
||
| 781 | to the Asset. |
||
| 782 | """ |
||
| 783 | return self.asset_finder.retrieve_asset(a_sid) |
||
| 784 | |||
| 785 | @api_method |
||
| 786 | @preprocess(symbol=ensure_upper_case) |
||
| 787 | def future_symbol(self, symbol): |
||
| 788 | """ Lookup a futures contract with a given symbol. |
||
| 789 | |||
| 790 | Parameters |
||
| 791 | ---------- |
||
| 792 | symbol : str |
||
| 793 | The symbol of the desired contract. |
||
| 794 | |||
| 795 | Returns |
||
| 796 | ------- |
||
| 797 | Future |
||
| 798 | A Future object. |
||
| 799 | |||
| 800 | Raises |
||
| 801 | ------ |
||
| 802 | SymbolNotFound |
||
| 803 | Raised when no contract named 'symbol' is found. |
||
| 804 | |||
| 805 | """ |
||
| 806 | return self.asset_finder.lookup_future_symbol(symbol) |
||
| 807 | |||
| 808 | @api_method |
||
| 809 | @preprocess(root_symbol=ensure_upper_case) |
||
| 810 | def future_chain(self, root_symbol, as_of_date=None): |
||
| 811 | """ Look up a future chain with the specified parameters. |
||
| 812 | |||
| 813 | Parameters |
||
| 814 | ---------- |
||
| 815 | root_symbol : str |
||
| 816 | The root symbol of a future chain. |
||
| 817 | as_of_date : datetime.datetime or pandas.Timestamp or str, optional |
||
| 818 | Date at which the chain determination is rooted. I.e. the |
||
| 819 | existing contract whose notice date is first after this date is |
||
| 820 | the primary contract, etc. |
||
| 821 | |||
| 822 | Returns |
||
| 823 | ------- |
||
| 824 | FutureChain |
||
| 825 | The future chain matching the specified parameters. |
||
| 826 | |||
| 827 | Raises |
||
| 828 | ------ |
||
| 829 | RootSymbolNotFound |
||
| 830 | If a future chain could not be found for the given root symbol. |
||
| 831 | """ |
||
| 832 | if as_of_date: |
||
| 833 | try: |
||
| 834 | as_of_date = pd.Timestamp(as_of_date, tz='UTC') |
||
| 835 | except ValueError: |
||
| 836 | raise UnsupportedDatetimeFormat(input=as_of_date, |
||
| 837 | method='future_chain') |
||
| 838 | return FutureChain( |
||
| 839 | asset_finder=self.asset_finder, |
||
| 840 | get_datetime=self.get_datetime, |
||
| 841 | root_symbol=root_symbol, |
||
| 842 | as_of_date=as_of_date |
||
| 843 | ) |
||
| 844 | |||
| 845 | def _calculate_order_value_amount(self, asset, value): |
||
| 846 | """ |
||
| 847 | Calculates how many shares/contracts to order based on the type of |
||
| 848 | asset being ordered. |
||
| 849 | """ |
||
| 850 | last_price = self.trading_client.current_data[asset].price |
||
| 851 | |||
| 852 | if tolerant_equals(last_price, 0): |
||
| 853 | zero_message = "Price of 0 for {psid}; can't infer value".format( |
||
| 854 | psid=asset |
||
| 855 | ) |
||
| 856 | if self.logger: |
||
| 857 | self.logger.debug(zero_message) |
||
| 858 | # Don't place any order |
||
| 859 | return 0 |
||
| 860 | |||
| 861 | if isinstance(asset, Future): |
||
| 862 | value_multiplier = asset.contract_multiplier |
||
| 863 | else: |
||
| 864 | value_multiplier = 1 |
||
| 865 | |||
| 866 | return value / (last_price * value_multiplier) |
||
| 867 | |||
@api_method
def order(self, sid, amount,
          limit_price=None,
          stop_price=None,
          style=None):
    """
    Place an order using the specified parameters.

    ``amount`` is snapped to the nearest integer when it lies within 1e-4
    of one, and otherwise truncated toward zero
    (e.g. 3.9999 -> 4; 5.5 -> 5; -5.5 -> -5).
    The parameters are then checked by ``validate_order_params``, which
    raises on invalid input. The deprecated ``limit_price``/``stop_price``
    arguments are folded into an execution style, and the order is handed
    to the blotter.

    Returns whatever ``self.blotter.order`` returns.
    """

    def _snap_to_int(x, tolerance=1e-4):
        # Only snap when the nearest integer is within ``tolerance`` of x;
        # otherwise leave x alone so int() below truncates toward zero.
        nearest = round(x)
        return nearest if abs(x - nearest) <= tolerance else x

    amount = int(_snap_to_int(amount))

    # Raises a ZiplineError if invalid parameters are detected.
    self.validate_order_params(sid, amount, limit_price, stop_price, style)

    # Convert deprecated limit_price and stop_price parameters to use
    # ExecutionStyle objects.
    execution_style = self.__convert_order_params_for_blotter(
        limit_price, stop_price, style,
    )
    return self.blotter.order(sid, amount, execution_style)
||
| 905 | |||
def validate_order_params(self,
                          asset,
                          amount,
                          limit_price,
                          stop_price,
                          style):
    """
    Helper method for validating parameters to the order API function.

    Parameters
    ----------
    asset : Asset
        The asset being ordered.
    amount : int
        Number of shares/contracts requested.
    limit_price, stop_price : float or None
        Deprecated price arguments; they may not be combined with ``style``.
    style : ExecutionStyle or None
        Explicit execution style.

    Raises
    ------
    OrderDuringInitialize
        If called before the algorithm has finished initializing.
    UnsupportedOrderParameters
        If ``style`` is combined with ``limit_price`` or ``stop_price``, or
        if ``asset`` is not an ``Asset``.

    Every registered trading control is also run, and each may raise its
    own exception (e.g. a TradingControlException).
    """

    if not self.initialized:
        raise OrderDuringInitialize(
            msg="order() can only be called from within handle_data()"
        )

    if style:
        # Compare against None rather than relying on truthiness. The
        # blotter conversion helper asserts that both prices are None
        # whenever a style is given, so a present-but-falsy price (e.g. 0)
        # must be rejected here instead of slipping through.
        if limit_price is not None:
            raise UnsupportedOrderParameters(
                msg="Passing both limit_price and style is not supported."
            )

        if stop_price is not None:
            raise UnsupportedOrderParameters(
                msg="Passing both stop_price and style is not supported."
            )

    if not isinstance(asset, Asset):
        raise UnsupportedOrderParameters(
            msg="Passing non-Asset argument to 'order()' is not supported."
                " Use 'sid()' or 'symbol()' methods to look up an Asset."
        )

    for control in self.trading_controls:
        control.validate(asset,
                         amount,
                         self.updated_portfolio(),
                         self.get_datetime(),
                         self.trading_client.current_data)
||
| 946 | |||
@staticmethod
def __convert_order_params_for_blotter(limit_price, stop_price, style):
    """
    Turn the deprecated ``limit_price``/``stop_price`` arguments into an
    ExecutionStyle instance.

    Callers must guarantee that either ``style`` is None or both prices
    are None. A truthy ``style`` is returned unchanged. Otherwise the
    truthy prices select the style: both give StopLimitOrder, only a limit
    gives LimitOrder, only a stop gives StopOrder, and neither gives
    MarketOrder.
    """
    # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
    if style:
        assert (limit_price, stop_price) == (None, None)
        return style

    has_limit, has_stop = bool(limit_price), bool(stop_price)
    if has_limit and has_stop:
        return StopLimitOrder(limit_price, stop_price)
    if has_limit:
        return LimitOrder(limit_price)
    if has_stop:
        return StopOrder(stop_price)
    return MarketOrder()
||
| 968 | |||
@api_method
def order_value(self, sid, value,
                limit_price=None, stop_price=None, style=None):
    """
    Place an order by desired value rather than desired number of shares.

    The requested value is divided by the asset's last price to imply the
    number of shares to transact. If the Asset being ordered is a Future,
    the price is also scaled by its contract multiplier, so the 'value' is
    actually the exposure, as Futures have no 'value'. If the last price
    is zero, no amount can be inferred and 0 is passed on to order().

    value > 0 :: Buy/Cover
    value < 0 :: Sell/Short
    Market order:    order_value(sid, value)
    Limit order:     order_value(sid, value, limit_price)
    Stop order:      order_value(sid, value, None, stop_price)
    StopLimit order: order_value(sid, value, limit_price, stop_price)

    Returns whatever order() returns.
    """
    amount = self._calculate_order_value_amount(sid, value)
    return self.order(sid, amount,
                      limit_price=limit_price,
                      stop_price=stop_price,
                      style=style)
||
| 991 | |||
@property
def recorded_vars(self):
    """A shallow copy of the recorded variables (``_recorded_vars``)."""
    snapshot = copy(self._recorded_vars)
    return snapshot
||
| 995 | |||
@property
def portfolio(self):
    """The current portfolio, refreshed lazily by updated_portfolio()."""
    current = self.updated_portfolio()
    return current
||
| 999 | |||
def updated_portfolio(self):
    """
    Return the cached portfolio, refreshing it first if it is stale.

    When ``portfolio_needs_update`` is set, the portfolio is re-fetched
    from the performance tracker (forwarding ``performance_needs_update``)
    and both flags are cleared. Otherwise the cached ``_portfolio`` is
    returned as-is.
    """
    if not self.portfolio_needs_update:
        return self._portfolio

    tracker = self.perf_tracker
    self._portfolio = tracker.get_portfolio(self.performance_needs_update)
    self.portfolio_needs_update = False
    self.performance_needs_update = False
    return self._portfolio
||
| 1007 | |||
@property
def account(self):
    """The current account, refreshed lazily by updated_account()."""
    current = self.updated_account()
    return current
||
| 1011 | |||
def updated_account(self):
    """
    Return the cached account, refreshing it first if it is stale.

    When ``account_needs_update`` is set, the account is re-fetched from
    the performance tracker (forwarding ``performance_needs_update``) and
    both flags are cleared. Otherwise the cached ``_account`` is returned
    as-is.
    """
    if not self.account_needs_update:
        return self._account

    tracker = self.perf_tracker
    self._account = tracker.get_account(self.performance_needs_update)
    self.account_needs_update = False
    self.performance_needs_update = False
    return self._account
||
| 1019 | |||
def set_logger(self, logger):
    """Attach ``logger``; a truthy logger receives debug messages."""
    self.logger = logger
||
| 1022 | |||
def on_dt_changed(self, dt):
    """
    Advance the algorithm's clock to ``dt``.

    Called by the simulation loop whenever the current dt changes. Any
    logic that should happen exactly once at the start of each datetime
    group belongs here. ``dt`` is asserted to be a UTC ``datetime``. It
    is stored on the algorithm and then forwarded to the performance
    tracker and the blotter, in that order.
    """
    assert isinstance(dt, datetime), \
        "Attempt to set algorithm's current time with non-datetime"
    assert dt.tzinfo == pytz.utc, \
        "Algorithm expects a utc datetime"

    self.datetime = dt
    for component in (self.perf_tracker, self.blotter):
        component.set_date(dt)
||
| 1039 | |||
@api_method
def get_datetime(self, tz=None):
    """
    Returns the simulation datetime.

    Parameters
    ----------
    tz : str or tzinfo, optional
        If given, the (UTC) simulation time is converted to this timezone.
        A string is resolved with ``pytz.timezone``.

    Returns
    -------
    datetime
        The current simulation time; UTC unless ``tz`` was passed.
    """
    current = self.datetime
    assert current.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

    if tz is None:
        # datetime.datetime objects are immutable, so no copy is needed.
        return current

    zone = pytz.timezone(tz) if isinstance(tz, string_types) else tz
    return current.astimezone(zone)
||
| 1055 | |||
def set_transact(self, transact):
    """
    Install ``transact`` on the blotter as the callable used to create a
    transaction from open orders and trade events.
    """
    blotter = self.blotter
    blotter.transact = transact
||
| 1062 | |||
def update_dividends(self, dividend_frame):
    """
    Pass ``dividend_frame`` to the performance tracker for dividend
    processing. Its columns should contain at least the entries in
    zp.DIVIDEND_FIELDS.
    """
    tracker = self.perf_tracker
    tracker.update_dividends(dividend_frame)
||
| 1069 | |||
@api_method
def set_slippage(self, slippage):
    """
    Set the algorithm's slippage model.

    Raises
    ------
    UnsupportedSlippageModel
        If ``slippage`` is not a ``SlippageModel`` instance.
    SetSlippagePostInit
        If called after the algorithm has been initialized.
    """
    is_supported = isinstance(slippage, SlippageModel)
    if not is_supported:
        raise UnsupportedSlippageModel()
    if self.initialized:
        raise SetSlippagePostInit()
    self.slippage = slippage
||
| 1077 | |||
@api_method
def set_commission(self, commission):
    """
    Set the algorithm's commission model.

    Raises
    ------
    UnsupportedCommissionModel
        If ``commission`` is not a PerShare, PerTrade or PerDollar model.
    SetCommissionPostInit
        If called after the algorithm has been initialized.
    """
    accepted = (PerShare, PerTrade, PerDollar)
    if not isinstance(commission, accepted):
        raise UnsupportedCommissionModel()

    if self.initialized:
        raise SetCommissionPostInit()
    self.commission = commission
||
| 1086 | |||
@api_method
def set_symbol_lookup_date(self, dt):
    """
    Set the date for which symbols will be resolved to their sids
    (symbols may map to different firms or underlying assets at
    different times).

    ``dt`` is converted with ``pd.Timestamp(dt, tz='UTC')``. A
    ``ValueError`` from that conversion is reported as
    ``UnsupportedDatetimeFormat``, and the lookup date is left unchanged.
    """
    try:
        lookup_date = pd.Timestamp(dt, tz='UTC')
    except ValueError:
        raise UnsupportedDatetimeFormat(input=dt,
                                        method='set_symbol_lookup_date')
    self._symbol_lookup_date = lookup_date
||
| 1099 | |||
def set_sources(self, sources):
    """Replace the algorithm's data sources; ``sources`` must be a list."""
    assert isinstance(sources, list)
    self.sources = sources
||
| 1103 | |||
# Kept for backwards compatibility: the value lives on sim_params.
@property
def data_frequency(self):
    """Data frequency of the simulation, read from ``sim_params``."""
    params = self.sim_params
    return params.data_frequency

@data_frequency.setter
def data_frequency(self, value):
    """Set the simulation's data frequency ('daily' or 'minute')."""
    assert value in ('daily', 'minute')
    self.sim_params.data_frequency = value
||
| 1113 | |||
@api_method
def order_percent(self, sid, percent,
                  limit_price=None, stop_price=None, style=None):
    """
    Place an order in the specified asset corresponding to the given
    percent of the current portfolio value.

    Note that percent must be expressed as a decimal (0.50 means 50%).

    The target value ``portfolio_value * percent`` is passed to
    order_value(), whose result is returned.
    """
    value = self.portfolio.portfolio_value * percent
    return self.order_value(sid, value,
                            limit_price=limit_price,
                            stop_price=stop_price,
                            style=style)
||
| 1128 | |||
@api_method
def order_target(self, sid, target,
                 limit_price=None, stop_price=None, style=None):
    """
    Place an order to adjust a position to a target number of shares. If
    the position doesn't already exist, this is equivalent to placing a new
    order. If the position does exist, this is equivalent to placing an
    order for the difference between the target number of shares and the
    current number of shares.

    Returns whatever order() returns.
    """
    held = self.portfolio.positions
    if sid in held:
        # Only the gap between the target and current holdings is traded.
        shares = target - held[sid].amount
    else:
        shares = target
    return self.order(sid, shares,
                      limit_price=limit_price,
                      stop_price=stop_price,
                      style=style)
||
| 1151 | |||
@api_method
def order_target_value(self, sid, target,
                       limit_price=None, stop_price=None, style=None):
    """
    Place an order to adjust a position to a target value. If
    the position doesn't already exist, this is equivalent to placing a new
    order. If the position does exist, this is equivalent to placing an
    order for the difference between the target value and the
    current value.
    If the Asset being ordered is a Future, the 'target value' calculated
    is actually the target exposure, as Futures have no 'value'.

    Returns whatever order_target() returns.
    """
    # NOTE(review): _calculate_order_value_amount returns 0 when the last
    # price is 0 ("Don't place any order"), but order_target() treats a
    # target of 0 as "trade down to zero shares". Confirm this is intended.
    amount_target = self._calculate_order_value_amount(sid, target)
    return self.order_target(
        sid, amount_target,
        limit_price=limit_price,
        stop_price=stop_price,
        style=style,
    )
||
| 1169 | |||
@api_method
def order_target_percent(self, sid, target,
                         limit_price=None, stop_price=None, style=None):
    """
    Place an order to adjust a position to a target percent of the
    current portfolio value. If the position doesn't already exist, this is
    equivalent to placing a new order. If the position does exist, this is
    equivalent to placing an order for the difference between the target
    percent and the current percent.

    Note that target must be expressed as a decimal (0.50 means 50%).

    The target value ``portfolio_value * target`` is passed to
    order_target_value(), whose result is returned.
    """
    target_value = self.portfolio.portfolio_value * target
    return self.order_target_value(sid, target_value,
                                   limit_price=limit_price,
                                   stop_price=stop_price,
                                   style=style)
||
| 1187 | |||
@api_method
def get_open_orders(self, sid=None):
    """
    Return open orders as API objects.

    With no ``sid``, returns a dict that maps each key of the blotter's
    ``open_orders`` having at least one order to the list of those
    orders. With a ``sid``, returns the list of that sid's open orders,
    or ``[]`` when the blotter has none recorded for it.
    """
    open_orders = self.blotter.open_orders
    if sid is not None:
        if sid not in open_orders:
            return []
        return [o.to_api_obj() for o in open_orders[sid]]

    return {
        key: [o.to_api_obj() for o in orders]
        for key, orders in iteritems(open_orders)
        if orders
    }
||
| 1200 | |||
@api_method
def get_order(self, order_id):
    """
    Look up an order by id.

    Returns the order's API object, or None when the blotter has no order
    with ``order_id``.
    """
    orders = self.blotter.orders
    if order_id not in orders:
        return None
    return orders[order_id].to_api_obj()
||
| 1205 | |||
@api_method
def cancel_order(self, order_param):
    """
    Ask the blotter to cancel an order.

    ``order_param`` is either an order id or a ``zipline.protocol.Order``,
    in which case its ``id`` is used.
    """
    if isinstance(order_param, zipline.protocol.Order):
        order_id = order_param.id
    else:
        order_id = order_param
    self.blotter.cancel(order_id)
||
| 1213 | |||
@api_method
def add_history(self, bar_count, frequency, field, ffill=True):
    """
    Register a history specification.

    The spec is stored in ``self.history_specs`` under its ``key_str``.
    Once the algorithm is initialized, the spec is also applied
    immediately. An existing history container is extended via
    ``ensure_spec``; otherwise a new container is built from all
    registered specs, starting at ``sim_params.first_open``.
    """
    spec = HistorySpec(bar_count, frequency, field, ffill,
                       data_frequency=self.sim_params.data_frequency,
                       env=self.trading_environment)
    self.history_specs[spec.key_str] = spec

    if not self.initialized:
        return

    if self.history_container:
        self.history_container.ensure_spec(
            spec, self.datetime, self._most_recent_data,
        )
    else:
        # NOTE(review): unlike get_history_spec, this path starts the new
        # container at sim_params.first_open and passes no bar_data --
        # confirm that is intended when called mid-simulation.
        self.history_container = self.history_container_class(
            self.history_specs,
            self.current_universe(),
            self.sim_params.first_open,
            self.sim_params.data_frequency,
            env=self.trading_environment,
        )
||
| 1234 | |||
def get_history_spec(self, bar_count, frequency, field, ffill):
    """
    Return the HistorySpec for these parameters, creating it on demand.

    A newly created spec is registered in ``self.history_specs``. If no
    history container exists yet, one is created at the current datetime
    and seeded with the most recent bar data. The container is then asked
    to ``ensure_spec`` the new spec.
    """
    key = HistorySpec.spec_key(bar_count, frequency, field, ffill)
    if key in self.history_specs:
        return self.history_specs[key]

    spec = HistorySpec(
        bar_count,
        frequency,
        field,
        ffill,
        data_frequency=self.sim_params.data_frequency,
        env=self.trading_environment,
    )
    self.history_specs[key] = spec

    if not self.history_container:
        self.history_container = self.history_container_class(
            self.history_specs,
            self.current_universe(),
            self.datetime,
            self.sim_params.data_frequency,
            bar_data=self._most_recent_data,
            env=self.trading_environment,
        )
    self.history_container.ensure_spec(
        spec, self.datetime, self._most_recent_data,
    )
    return self.history_specs[key]
||
| 1261 | |||
@api_method
@require_initialized(HistoryInInitialize())
def history(self, bar_count, frequency, field, ffill=True):
    """
    Return trailing history for ``field``.

    The matching HistorySpec is looked up (or lazily registered) via
    ``get_history_spec`` and read from the history container at the
    current simulation datetime. Guarded so it cannot run before
    initialization.
    """
    spec = self.get_history_spec(bar_count, frequency, field, ffill)
    container = self.history_container
    return container.get_history(spec, self.datetime)
||
| 1272 | |||
####################
# Account Controls #
####################

def register_account_control(self, control):
    """
    Register a new AccountControl to be checked on each bar.

    Raises RegisterAccountControlPostInit once the algorithm has been
    initialized.
    """
    if self.initialized:
        raise RegisterAccountControlPostInit()
    controls = self.account_controls
    controls.append(control)
||
| 1284 | |||
def validate_account_controls(self):
    """
    Run every registered account control against the current state.

    Each control's ``validate`` receives the updated portfolio, the
    updated account, the current datetime and the trading client's
    current data. Controls presumably raise when violated.
    """
    for account_control in self.account_controls:
        account_control.validate(
            self.updated_portfolio(),
            self.updated_account(),
            self.get_datetime(),
            self.trading_client.current_data,
        )
||
| 1291 | |||
@api_method
def set_max_leverage(self, max_leverage=None):
    """
    Set a limit on the maximum leverage of the algorithm.

    Registers a MaxLeverage account control, so this must be called before
    the algorithm finishes initializing.
    """
    self.register_account_control(MaxLeverage(max_leverage))
||
| 1299 | |||
####################
# Trading Controls #
####################

def register_trading_control(self, control):
    """
    Register a new TradingControl to be checked prior to order calls.

    Raises RegisterTradingControlPostInit once the algorithm has been
    initialized.
    """
    if self.initialized:
        raise RegisterTradingControlPostInit()
    controls = self.trading_controls
    controls.append(control)
||
| 1311 | |||
@api_method
def set_max_position_size(self,
                          sid=None,
                          max_shares=None,
                          max_notional=None):
    """
    Set a limit on the number of shares and/or dollar value held for the
    given sid. Limits are treated as absolute values and are enforced at
    the time that the algo attempts to place an order for sid. This means
    that it's possible to end up with more than the max number of shares
    due to splits/dividends, and more than the max notional due to price
    improvement.

    If an algorithm attempts to place an order that would result in
    increasing the absolute value of shares/dollar value exceeding one of
    these limits, raise a TradingControlException.
    """
    self.register_trading_control(
        MaxPositionSize(asset=sid,
                        max_shares=max_shares,
                        max_notional=max_notional)
    )
||
| 1333 | |||
@api_method
def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
    """
    Set a limit on the number of shares and/or dollar value of any single
    order placed for sid. Limits are treated as absolute values and are
    enforced at the time that the algo attempts to place an order for sid.

    If an algorithm attempts to place an order that would result in
    exceeding one of these limits, raise a TradingControlException.
    """
    self.register_trading_control(
        MaxOrderSize(asset=sid,
                     max_shares=max_shares,
                     max_notional=max_notional)
    )
||
| 1348 | |||
@api_method
def set_max_order_count(self, max_count):
    """
    Set a limit on the number of orders that can be placed within the given
    time interval (registered as a MaxOrderCount trading control).
    """
    self.register_trading_control(MaxOrderCount(max_count))
||
| 1357 | |||
@api_method
def set_do_not_order_list(self, restricted_list):
    """
    Set a restriction on which sids can be ordered (registered as a
    RestrictedListOrder trading control).
    """
    self.register_trading_control(RestrictedListOrder(restricted_list))
||
| 1365 | |||
@api_method
def set_long_only(self):
    """
    Set a rule specifying that this algorithm cannot take short positions
    (registered as a LongOnly trading control).
    """
    long_only = LongOnly()
    self.register_trading_control(long_only)
||
| 1372 | |||
##############
# Pipeline API
##############
@api_method
@require_not_initialized(AttachPipelineAfterInitialize())
def attach_pipeline(self, pipeline, name, chunksize=None):
    """
    Register a pipeline to be computed at the start of each day.

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to attach.
    name : str
        Name under which ``pipeline_output`` looks the pipeline up.
    chunksize : int, optional
        Number of trading days to compute per batch. When omitted, the
        first batch spans 5 days and later batches span 126 days.

    Returns
    -------
    pipeline
        The same object, to allow ``p = attach_pipeline(Pipeline(), 'n')``.

    Raises
    ------
    NotImplementedError
        If a pipeline is already attached (only one is supported).
    """
    if self._pipelines:
        raise NotImplementedError("Multiple pipelines are not supported.")

    if chunksize is not None:
        chunks = repeat(int(chunksize))
    else:
        # Make the first chunk smaller to get more immediate results:
        # (one week, then every half year)
        chunks = chain([5], repeat(126))

    self._pipelines[name] = pipeline, chunks
    return pipeline
||
| 1395 | |||
@api_method
@require_initialized(PipelineOutputDuringInitialize())
def pipeline_output(self, name):
    """
    Get the results of pipeline with name `name`.

    Parameters
    ----------
    name : str
        Name of the pipeline for which results are requested.

    Returns
    -------
    results : pd.DataFrame
        DataFrame containing the results of the requested pipeline for
        the current simulation date.

    Raises
    ------
    NoSuchPipeline
        Raised when no pipeline with the name `name` has been registered.

    See Also
    --------
    :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
    """
    # NOTE: We don't currently support multiple pipelines, but we plan to
    # in the future.
    try:
        entry = self._pipelines[name]
    except KeyError:
        raise NoSuchPipeline(
            name=name,
            valid=list(self._pipelines.keys()),
        )
    registered_pipeline, chunk_sizes = entry
    return self._pipeline_output(registered_pipeline, chunk_sizes)
||
| 1432 | |||
def _pipeline_output(self, pipeline, chunks):
    """
    Internal implementation of `pipeline_output`.

    Results are cached in ``self._pipeline_cache``. When the cache has
    expired for today's normalized date, the pipeline is recomputed for
    the next chunk of days (sized by ``next(chunks)``) and the cache is
    replaced. Returns today's rows, or an empty DataFrame with the same
    columns when no assets passed the screen today.
    """
    today = normalize_date(self.get_datetime())
    try:
        results = self._pipeline_cache.unwrap(today)
    except Expired:
        results, valid_until = self._run_pipeline(
            pipeline, today, next(chunks),
        )
        self._pipeline_cache = CachedObject(results, valid_until)

    # Now that we have a cached result, try to return the data for today.
    try:
        return results.loc[today]
    except KeyError:
        # This happens if no assets passed the pipeline screen on a given
        # day.
        return pd.DataFrame(index=[], columns=results.columns)
||
| 1453 | |||
def _run_pipeline(self, pipeline, start_date, chunksize):
    """
    Compute `pipeline`, providing values for at least `start_date`.

    Produces a DataFrame containing data for days between `start_date` and
    `end_date`, where `end_date` is defined by:

    `end_date = min(start_date + chunksize trading days,
                    simulation_end)`

    Parameters
    ----------
    pipeline : Pipeline
        The pipeline to run.
    start_date : pd.Timestamp
        First day to compute; must be one of the environment's trading days.
    chunksize : int
        Maximum number of trading days past `start_date` to compute.

    Returns
    -------
    (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp)
        `valid_until` is `end_date`.

    See Also
    --------
    PipelineEngine.run_pipeline
    """
    days = self.trading_environment.trading_days

    # Load data starting from ``start_date`` itself...
    start_date_loc = days.get_loc(start_date)

    # ...continuing until either the (normalized) simulation end date, or
    # until ``chunksize`` trading days past ``start_date``, whichever is
    # earlier.
    sim_end = self.sim_params.last_close.normalize()
    end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end))
    end_date = days[end_loc]

    data = self.engine.run_pipeline(pipeline, start_date, end_date)
    return data, end_date
||
| 1485 | |||
##################
# End Pipeline API
##################

def current_universe(self):
    """Return the algorithm's stored current universe."""
    universe = self._current_universe
    return universe
||
| 1492 | |||
@classmethod
def all_api_methods(cls):
    """
    Return a list of all the TradingAlgorithm API methods.

    These are the attributes defined directly on ``cls`` (taken from
    ``vars(cls)``, so inherited ones are excluded) whose ``is_api_method``
    attribute is truthy, in ``vars(cls)`` order.
    """
    api_methods = []
    for attr in vars(cls).values():
        if getattr(attr, 'is_api_method', False):
            api_methods.append(attr)
    return api_methods
||
| 1502 |