Total Complexity | 157 |
Total Lines | 1239 |
Duplicated Lines | 0 % |
Complex classes like zipline.TradingAlgorithm often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # |
||
109 | class TradingAlgorithm(object): |
||
110 | """ |
||
111 | Base class for trading algorithms. Inherit and overload |
||
112 | initialize() and handle_data(data). |
||
113 | |||
114 | A new algorithm could look like this: |
||
115 | ``` |
||
116 | from zipline.api import order, symbol |
||
117 | |||
118 | def initialize(context): |
||
119 | context.sid = symbol('AAPL') |
||
120 | context.amount = 100 |
||
121 | |||
122 | def handle_data(context, data): |
||
123 | sid = context.sid |
||
124 | amount = context.amount |
||
125 | order(sid, amount) |
||
126 | ``` |
||
127 | To then to run this algorithm pass these functions to |
||
128 | TradingAlgorithm: |
||
129 | |||
130 | my_algo = TradingAlgorithm(initialize, handle_data) |
||
131 | stats = my_algo.run(data) |
||
132 | |||
133 | """ |
||
134 | |||
def __init__(self, *args, **kwargs):
    """Initialize sids and other state variables.

    :Arguments:
    :Optional:
        initialize : function
            Function that is called with a single
            argument at the beginning of the simulation.
        handle_data : function
            Function that is called with 2 arguments
            (context and data) on every bar.
        script : str
            Algoscript that contains initialize and
            handle_data function definition. Mutually exclusive with
            ``initialize``/``handle_data``.
        data_frequency : {'daily', 'minute'}
           The duration of the bars.
        capital_base : float <default: 1.0e5>
           How much capital to start with.
        instant_fill : bool <default: False>
           Whether to fill orders immediately or on next bar.
        asset_finder : An AssetFinder object
            A new AssetFinder object to be used in this TradingEnvironment
            NOTE(review): this keyword is not popped by the constructor, so
            it is forwarded to initialize(); confirm it is still supported.
        equities_metadata : can be either:
                            - dict
                            - pandas.DataFrame
                            - object with 'read' property
            If dict is provided, it must have the following structure:
            * keys are the identifiers
            * values are dicts containing the metadata, with the metadata
              field name as the key
            If pandas.DataFrame is provided, it must have the
            following structure:
            * column names must be the metadata fields
            * index must be the different asset identifiers
            * array contents should be the metadata value
            If an object with a 'read' property is provided, 'read' must
            return rows containing at least one of 'sid' or 'symbol' along
            with the other metadata fields.
        identifiers : List
            Any asset identifiers that are not provided in the
            equities_metadata, but will be traded by this TradingAlgorithm

    Other keywords consumed here: ``env``, ``sim_params`` (``start`` and
    ``end`` are only consumed when ``sim_params`` is absent),
    ``futures_metadata``, ``namespace``, ``platform``, ``blotter``,
    ``get_pipeline_loader``, ``create_event_context``, ``algo_filename``
    (only with ``script``), ``before_trading_start``, ``analyze`` and
    ``benchmark_sid``. Every remaining positional and keyword argument is
    stored and forwarded to initialize() on the first run.

    :Raises:
        ValueError
            If ``script`` is combined with ``initialize``/``handle_data``,
            or if ``script`` does not define ``handle_data``.
    """
    self.sources = []
    self.clock = None

    # List of trading controls to be used to validate orders.
    self.trading_controls = []

    # List of account controls to be checked on each bar.
    self.account_controls = []

    self._recorded_vars = {}
    self.namespace = kwargs.pop('namespace', {})

    self._platform = kwargs.pop('platform', 'zipline')

    self.logger = None

    self.benchmark_source = None

    self.instant_fill = kwargs.pop('instant_fill', False)

    # If an env has been provided, pop it; otherwise build a default one.
    self.trading_environment = kwargs.pop('env', None)
    if self.trading_environment is None:
        self.trading_environment = TradingEnvironment()

    self.data_portal = None

    # Update the TradingEnvironment with the provided asset metadata.
    self.trading_environment.write_data(
        equities_data=kwargs.pop('equities_metadata', {}),
        equities_identifiers=kwargs.pop('identifiers', []),
        futures_data=kwargs.pop('futures_metadata', {}),
    )

    # Set the capital base. `start`/`end` are only consumed when no
    # sim_params object was supplied.
    self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
    self.sim_params = kwargs.pop('sim_params', None)
    if self.sim_params is None:
        self.sim_params = create_simulation_parameters(
            capital_base=self.capital_base,
            start=kwargs.pop('start', None),
            end=kwargs.pop('end', None),
            env=self.trading_environment,
        )
    else:
        self.sim_params.update_internal_from_env(self.trading_environment)

    self.perf_tracker = None
    # Pull in the environment's new AssetFinder for quick reference.
    self.asset_finder = self.trading_environment.asset_finder

    # Initialize Pipeline API data.
    self.init_engine(kwargs.pop('get_pipeline_loader', None))
    self._pipelines = {}
    # Create an always-expired cache so that we compute the first time data
    # is requested.
    self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))

    self.blotter = kwargs.pop('blotter', None)
    if not self.blotter:
        self.blotter = Blotter(
            slippage_func=VolumeShareSlippage(),
            commission=PerShare()
        )

    # The symbol lookup date specifies the date to use when resolving
    # symbols to sids, and can be set using set_symbol_lookup_date().
    self._symbol_lookup_date = None

    self._portfolio = None
    self._account = None

    # If a string is passed in, execute it and get references to its
    # functions.
    self.algoscript = kwargs.pop('script', None)

    self._initialize = None
    self._before_trading_start = None
    self._analyze = None

    self.event_manager = EventManager(
        create_context=kwargs.pop('create_event_context', None),
    )

    # A script and explicit initialize/handle_data callables are mutually
    # exclusive. This check previously lived inside the non-script branch,
    # where `algoscript` is always None, so it could never fire.
    if self.algoscript is not None and (
            kwargs.get('initialize') or kwargs.get('handle_data')):
        raise ValueError('You can not set script and initialize/handle_data.')

    if self.algoscript is not None:
        filename = kwargs.pop('algo_filename', None)
        if filename is None:
            filename = '<string>'
        # Executing the algorithm script is the intended behavior; the
        # script runs with full interpreter privileges, so never pass
        # untrusted source here.
        code = compile(self.algoscript, filename, 'exec')
        exec_(code, self.namespace)
        self._initialize = self.namespace.get('initialize')
        if 'handle_data' not in self.namespace:
            raise ValueError('You must define a handle_data function.')
        else:
            self._handle_data = self.namespace['handle_data']

        self._before_trading_start = \
            self.namespace.get('before_trading_start')
        # Optional analyze function, gets called after run.
        self._analyze = self.namespace.get('analyze')

    elif kwargs.get('handle_data'):
        # `initialize` is optional (see docstring); a missing one becomes a
        # no-op below instead of leaking `handle_data` into initialize().
        self._initialize = kwargs.pop('initialize', None)
        self._handle_data = kwargs.pop('handle_data')
        self._before_trading_start = kwargs.pop('before_trading_start',
                                                None)
        self._analyze = kwargs.pop('analyze', None)

    self.event_manager.add_event(
        zipline.utils.events.Event(
            zipline.utils.events.Always(),
            # We pass handle_data.__func__ to get the unbound method.
            # We will explicitly pass the algorithm to bind it again.
            self.handle_data.__func__,
        ),
        prepend=True,
    )

    # If method not defined, NOOP.
    if self._initialize is None:
        self._initialize = lambda x: None

    # Alternative way of setting data_frequency for backwards
    # compatibility.
    if 'data_frequency' in kwargs:
        self.data_frequency = kwargs.pop('data_frequency')

    # benchmark_sid must not be forwarded to initialize(); popping it before
    # storing kwargs is equivalent since the same dict object is stored.
    self.benchmark_sid = kwargs.pop('benchmark_sid', None)

    # Prepare the algo for initialization.
    self.initialized = False
    self.initialize_args = args
    self.initialize_kwargs = kwargs
||
314 | |||
def init_engine(self, get_loader):
    """
    Build the PipelineEngine for this algorithm and store it on
    ``self.engine``.

    A SimplePipelineEngine is built from ``get_loader``, the environment's
    trading days and the asset finder; when ``get_loader`` is None a
    NoOpPipelineEngine is used instead.
    """
    if get_loader is None:
        self.engine = NoOpPipelineEngine()
        return

    self.engine = SimplePipelineEngine(
        get_loader,
        self.trading_environment.trading_days,
        self.asset_finder,
    )
||
329 | |||
def initialize(self, *args, **kwargs):
    """
    Run the user's ``_initialize`` callable inside a ZiplineAPI context.

    The algorithm is passed as the first argument, followed by any extra
    positional and keyword arguments.
    """
    api_context = ZiplineAPI(self)
    with api_context:
        self._initialize(self, *args, **kwargs)
||
337 | |||
def before_trading_start(self, data):
    """
    Forward ``data`` to the user's before_trading_start callable.

    Does nothing when no such callable was supplied.
    """
    callback = self._before_trading_start
    if callback is not None:
        callback(self, data)
||
343 | |||
def handle_data(self, data):
    """
    Invoke the user's handle_data callable for the current bar, then run
    the per-bar account-control checks.
    """
    user_handler = self._handle_data
    user_handler(self, data)

    # Trading controls only matter when an order is placed, but account
    # controls can change on any bar, so they are validated every bar
    # whether or not the algorithm ordered anything.
    self.validate_account_controls()
||
351 | |||
def analyze(self, perf):
    """
    Hand the final performance results to the user's analyze callable.

    The call runs inside a ZiplineAPI context. Nothing happens when no
    analyze callable was supplied.
    """
    if self._analyze is not None:
        with ZiplineAPI(self):
            self._analyze(self, perf)
||
358 | |||
def __repr__(self):
    """
    Return a readable summary of the algorithm's configuration.

    N.B. this does not yet represent a string that can be used
    to instantiate an exact copy of an algorithm.

    However, it is getting close, and provides some value as something
    that can be inspected interactively.
    """
    # Every field is followed by a comma; capital_base previously lacked
    # one, making the output inconsistent and not call-like.
    template = (
        "{class_name}(\n"
        "    capital_base={capital_base},\n"
        "    sim_params={sim_params},\n"
        "    initialized={initialized},\n"
        "    slippage={slippage},\n"
        "    commission={commission},\n"
        "    blotter={blotter},\n"
        "    recorded_vars={recorded_vars})"
    )
    return template.format(
        class_name=self.__class__.__name__,
        capital_base=self.capital_base,
        sim_params=repr(self.sim_params),
        initialized=self.initialized,
        slippage=repr(self.blotter.slippage_func),
        commission=repr(self.blotter.commission),
        blotter=repr(self.blotter),
        recorded_vars=repr(self.recorded_vars),
    )
||
384 | |||
def ensure_clock(self):
    """
    Create ``self.clock`` from the simulation's data frequency if it has
    not been set yet.

    Minute data uses a MinuteSimulationClock when the emission rate is
    'daily' and a MinuteEmissionClock otherwise; both receive the session
    market opens/closes as int64 nanosecond timestamps. Daily data uses a
    DailySimulationClock. Any other data frequency leaves the clock unset.
    """
    if self.clock is not None:
        return

    sim_params = self.sim_params
    if sim_params.data_frequency == 'minute':
        env = self.trading_environment
        # `.ix` is deprecated (removed in pandas 1.0); the lookup keys are
        # trading-day labels, so `.loc` is the label-based equivalent.
        trading_o_and_c = env.open_and_closes.loc[sim_params.trading_days]
        market_opens = trading_o_and_c['market_open'].values.astype(
            'datetime64[ns]').astype(np.int64)
        market_closes = trading_o_and_c['market_close'].values.astype(
            'datetime64[ns]').astype(np.int64)

        if sim_params.emission_rate == "daily":
            clock_cls = MinuteSimulationClock
        else:
            clock_cls = MinuteEmissionClock

        self.clock = clock_cls(
            sim_params.trading_days,
            market_opens,
            market_closes,
            self.data_portal,
            env.trading_days,
        )

    elif sim_params.data_frequency == 'daily':
        self.clock = DailySimulationClock(sim_params.trading_days)
||
417 | |||
def create_benchmark_source(self):
    """
    Build a BenchmarkSource for ``self.benchmark_sid`` over the simulation's
    trading days, emitting at the simulation's emission rate.
    """
    params = self.sim_params
    return BenchmarkSource(
        self.benchmark_sid,
        self.trading_environment,
        params.trading_days,
        self.data_portal,
        emission_rate=params.emission_rate,
    )
||
426 | |||
def _create_generator(self, sim_params):
    """
    Create the simulation generator for this algorithm.

    Parameters
    ----------
    sim_params : SimulationParameters or None
        When not None, replaces ``self.sim_params`` before the generator is
        built.

    Returns
    -------
    The generator returned by ``AlgorithmSimulator.transform()``.

    Side effects: builds ``self.perf_tracker`` when it is None (and moves
    the algorithm's dt to the period start), runs initialize() once, makes
    sure a clock exists, and stores the simulator on
    ``self.trading_client``.
    """
    if sim_params is not None:
        self.sim_params = sim_params

    if self.perf_tracker is None:
        # HACK: When running with the `run` method, we set perf_tracker to
        # None so that it will be rebuilt here.
        self.perf_tracker = PerformanceTracker(
            sim_params=self.sim_params,
            env=self.trading_environment,
            data_portal=self.data_portal,
        )
        # Set the dt initially to the period start by forcing it to change.
        # This has to target the tracker that is kept: the tracker used to
        # be rebuilt unconditionally right after this call, which dropped
        # the date set here and ignored any existing tracker.
        self.on_dt_changed(self.sim_params.period_start)

    if not self.initialized:
        self.initialize(*self.initialize_args, **self.initialize_kwargs)
        self.initialized = True

    self.ensure_clock()

    # Use the effective parameters; the `sim_params` argument may be None.
    self.trading_client = AlgorithmSimulator(
        self,
        self.sim_params,
        self.data_portal,
        self.clock,
        self.create_benchmark_source(),
    )

    return self.trading_client.transform()
||
470 | |||
def get_generator(self):
    """
    Return the simulation generator.

    Subclasses may override this to customize construction; the default
    delegates to ``_create_generator`` with the current ``sim_params``.
    """
    current_params = self.sim_params
    return self._create_generator(current_params)
||
478 | |||
def run(self, data_portal=None):
    """Run the algorithm.

    :Arguments:
        data_portal : DataPortal, optional
            Used only when the algorithm does not already have a data
            portal; an existing ``self.data_portal`` is kept as-is.

    :Returns:
        daily_stats : pandas.DataFrame
          Daily performance metrics such as returns, alpha etc.
    """
    if self.data_portal is None:
        self.data_portal = data_portal

    self.ensure_clock()

    # Force a reset of the performance tracker, in case this is a repeat
    # run of the algorithm.
    self.perf_tracker = None

    # Create the simulation generator; each iteration yields a perf dict.
    self.gen = self.get_generator()
    perfs = list(self.gen)

    # Convert the perf dicts to a pandas DataFrame.
    daily_stats = self._create_daily_stats(perfs)

    self.analyze(daily_stats)

    return daily_stats
||
514 | |||
def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
    """
    Make sure every identifier has an Asset, then map identifiers to sids.

    Assets and integers are looked up by sid. Any identifier that cannot be
    resolved that way, including every non-Asset, non-integer identifier,
    is written to the trading environment as a new equity identifier. The
    asset finder's caches are then reset and its
    ``map_identifier_index_to_sids(identifiers, as_of_date)`` is returned.
    """
    finder = self.asset_finder

    def _lookup(identifier):
        # Only Assets and integral sids can be resolved directly.
        if isinstance(identifier, Asset):
            return finder.retrieve_asset(sid=identifier.sid,
                                         default_none=True)
        if isinstance(identifier, Integral):
            return finder.retrieve_asset(sid=identifier,
                                         default_none=True)
        return None

    unresolved = [ident for ident in identifiers if _lookup(ident) is None]
    self.trading_environment.write_data(equities_identifiers=unresolved)

    # Cache misses recorded during the lookups above must be cleared. The
    # real fix is to not construct an AssetFinder until `run()`, when all
    # the needed data is available.
    self.asset_finder._reset_caches()

    return self.asset_finder.map_identifier_index_to_sids(
        identifiers, as_of_date,
    )
||
543 | |||
def _create_daily_stats(self, perfs):
    """
    Build the daily-stats DataFrame from the perf packets of a simulation.

    Each packet containing ``'daily_perf'`` becomes one row. The packet's
    recorded variables and cumulative risk metrics are merged into that row
    (mutating the packet), and the row is indexed by its ``'period_close'``
    as a UTC timestamp. Any other packet is stored on ``self.risk_report``;
    the last such packet wins.

    Parameters
    ----------
    perfs : iterable of dict

    Returns
    -------
    pandas.DataFrame
        One row per daily packet, indexed by a UTC DatetimeIndex.
    """
    daily_perfs = []
    # TODO: the updates below could overwrite expected properties of
    # daily_perf. Could potentially raise or log a warning.
    for perf in perfs:
        if 'daily_perf' in perf:
            daily_perf = perf['daily_perf']
            daily_perf.update(daily_perf.pop('recorded_vars'))
            daily_perf.update(perf['cumulative_risk_metrics'])
            daily_perfs.append(daily_perf)
        else:
            self.risk_report = perf

    # `np.datetime64` has no `utc` keyword (older NumPy silently ignored
    # it; newer versions reject it). Build the intended tz-aware UTC index
    # with pandas instead.
    daily_dts = pd.DatetimeIndex(
        [daily_perf['period_close'] for daily_perf in daily_perfs],
        tz='UTC',
    )
    return pd.DataFrame(daily_perfs, index=daily_dts)
||
566 | |||
@api_method
def get_environment(self, field='platform'):
    """
    Query the execution environment.

    ``field`` selects one of 'arena', 'data_frequency', 'start', 'end',
    'capital_base' or 'platform'; '*' returns all of them as a dict. An
    unknown field raises KeyError.
    """
    params = self.sim_params
    env = dict(
        arena=params.arena,
        data_frequency=params.data_frequency,
        start=params.first_open,
        end=params.last_close,
        capital_base=params.capital_base,
        platform=self._platform,
    )
    return env if field == '*' else env[field]
||
581 | |||
@api_method
def fetch_csv(self, url,
              pre_func=None,
              post_func=None,
              date_column='date',
              date_format=None,
              timezone=pytz.utc.zone,
              symbol=None,
              mask=True,
              symbol_column=None,
              special_params_checker=None,
              **kwargs):
    """
    Load an external CSV through PandasRequestsCSV for the simulation
    period and register its frame with the data portal.

    Returns the PandasRequestsCSV source object.
    """
    params = self.sim_params
    source = PandasRequestsCSV(
        url, pre_func, post_func,
        self.trading_environment,
        params.period_start, params.period_end,
        date_column, date_format, timezone,
        symbol, mask, symbol_column,
        data_frequency=self.data_frequency,
        special_params_checker=special_params_checker,
        **kwargs
    )

    # Make the fetched frame available through the data portal.
    self.data_portal.handle_extra_source(source.df)
    return source
||
618 | |||
def add_event(self, rule=None, callback=None):
    """
    Register ``callback`` to fire according to ``rule`` on this
    algorithm's EventManager.
    """
    event = zipline.utils.events.Event(rule, callback)
    self.event_manager.add_event(event)
||
626 | |||
@api_method
def schedule_function(self,
                      func,
                      date_rule=None,
                      time_rule=None,
                      half_days=True):
    """
    Schedule ``func`` to be called according to a date rule and a time
    rule.

    A falsy ``date_rule`` defaults to every day. With minute data a falsy
    ``time_rule`` defaults to market open; with any other data frequency
    the time rule is ignored and replaced by an Always rule.
    """
    if not date_rule:
        date_rule = DateRuleFactory.every_day()

    if self.sim_params.data_frequency == 'minute':
        if not time_rule:
            time_rule = TimeRuleFactory.market_open()
    else:
        # If we are in daily mode the time_rule is ignored.
        time_rule = zipline.utils.events.Always()

    self.add_event(make_eventrule(date_rule, time_rule, half_days), func)
||
646 | |||
@api_method
def record(self, *args, **kwargs):
    """
    Track and record values each day.

    Positional arguments are consumed as alternating name/value pairs,
    e.g. ``record('a', 1, 'b', 2)``. Keyword arguments are recorded as
    name=value. A trailing positional name without a value is dropped.
    """
    recorded = self._recorded_vars
    # Even-indexed items are names, odd-indexed items are their values.
    for name, value in zip(args[::2], args[1::2]):
        recorded[name] = value
    for name, value in iteritems(kwargs):
        recorded[name] = value
||
662 | |||
@api_method
def set_benchmark(self, benchmark_sid):
    """
    Set the sid used as the benchmark.

    Raises SetBenchmarkOutsideInitialize once the algorithm has been marked
    initialized.
    """
    if not self.initialized:
        self.benchmark_sid = benchmark_sid
        return
    raise SetBenchmarkOutsideInitialize()
||
669 | |||
@api_method
@preprocess(symbol_str=ensure_upper_case)
def symbol(self, symbol_str):
    """
    Look up an Asset by ticker symbol, for sources that map symbols
    directly to Assets (e.g. yahoo finance).

    The symbol is upper-cased by the preprocess decorator. It is resolved as
    of the date set with set_symbol_lookup_date(), or the simulation's
    period_end when no lookup date has been set.
    """
    lookup_date = self._symbol_lookup_date
    if lookup_date is None:
        lookup_date = self.sim_params.period_end

    return self.asset_finder.lookup_symbol(
        symbol_str,
        as_of_date=lookup_date,
    )
||
686 | |||
@api_method
def symbols(self, *args):
    """
    Look up several Assets by symbol.

    Returns a list in argument order, using the same resolution rules as
    symbol().
    """
    return list(map(self.symbol, args))
||
694 | |||
@api_method
def sid(self, a_sid):
    """
    Look up an Asset by its integer sid through the asset finder.
    """
    finder = self.asset_finder
    return finder.retrieve_asset(a_sid)
||
702 | |||
@api_method
@preprocess(symbol=ensure_upper_case)
def future_symbol(self, symbol):
    """Look up a futures contract with a given symbol.

    Parameters
    ----------
    symbol : str
        The symbol of the desired contract; upper-cased before lookup.

    Returns
    -------
    Future
        A Future object.

    Raises
    ------
    SymbolNotFound
        Raised when no contract named 'symbol' is found.
    """
    finder = self.asset_finder
    return finder.lookup_future_symbol(symbol)
||
725 | |||
@api_method
@preprocess(root_symbol=ensure_upper_case)
def future_chain(self, root_symbol, as_of_date=None):
    """Look up a future chain with the specified parameters.

    Parameters
    ----------
    root_symbol : str
        The root symbol of a future chain (upper-cased before lookup).
    as_of_date : datetime.datetime or pandas.Timestamp or str, optional
        Date at which the chain determination is rooted. I.e. the
        existing contract whose notice date is first after this date is
        the primary contract, etc. A truthy value is parsed as a UTC
        Timestamp; a ValueError while parsing is re-raised as
        UnsupportedDatetimeFormat.

    Returns
    -------
    FutureChain
        The future chain matching the specified parameters.

    Raises
    ------
    RootSymbolNotFound
        If a future chain could not be found for the given root symbol.
    """
    chain_date = as_of_date
    if chain_date:
        try:
            chain_date = pd.Timestamp(chain_date, tz='UTC')
        except ValueError:
            raise UnsupportedDatetimeFormat(input=as_of_date,
                                            method='future_chain')

    return FutureChain(
        asset_finder=self.asset_finder,
        get_datetime=self.get_datetime,
        root_symbol=root_symbol,
        as_of_date=chain_date,
    )
||
762 | |||
def _calculate_order_value_amount(self, asset, value):
    """
    Convert a target order value into a share/contract count.

    Divides ``value`` by the asset's last price from the trading client's
    current data, scaled by ``contract_multiplier`` for Futures. If the
    last price is (tolerantly) zero, a debug message is logged when a
    logger is set and 0 is returned so no order is placed.
    """
    last_price = self.trading_client.current_data[asset].price

    if not tolerant_equals(last_price, 0):
        multiplier = (asset.contract_multiplier
                      if isinstance(asset, Future) else 1)
        return value / (last_price * multiplier)

    zero_message = "Price of 0 for {psid}; can't infer value".format(
        psid=asset
    )
    if self.logger:
        self.logger.debug(zero_message)
    # Don't place any order.
    return 0
||
785 | |||
@api_method
def order(self, sid, amount,
          limit_price=None,
          stop_price=None,
          style=None):
    """
    Place an order for ``amount`` shares/contracts of ``sid``.

    ``amount`` is first converted to an integer: values within .0001 of an
    integer round to it, others are truncated toward zero (3.9999 -> 4,
    5.5 -> 5, -5.5 -> -5). The parameters are validated, the deprecated
    ``limit_price``/``stop_price`` arguments are converted to an
    ExecutionStyle, and the blotter's order result is returned.
    """
    share_count = int(round_if_near_integer(amount))

    # Raises a ZiplineError if invalid parameters are detected.
    self.validate_order_params(sid, share_count,
                               limit_price, stop_price, style)

    execution_style = self.__convert_order_params_for_blotter(
        limit_price, stop_price, style)
    return self.blotter.order(sid, share_count, execution_style)
||
812 | |||
def validate_order_params(self,
                          asset,
                          amount,
                          limit_price,
                          stop_price,
                          style):
    """
    Validate the parameters passed to the order API function.

    Raises
    ------
    OrderDuringInitialize
        If the algorithm is not yet marked initialized.
    UnsupportedOrderParameters
        If ``style`` is combined with a truthy ``limit_price`` or
        ``stop_price``, or if ``asset`` is not an Asset.

    Each registered trading control's ``validate`` may also raise.
    """
    if not self.initialized:
        raise OrderDuringInitialize(
            msg="order() can only be called from within handle_data()"
        )

    if style and limit_price:
        raise UnsupportedOrderParameters(
            msg="Passing both limit_price and style is not supported."
        )
    if style and stop_price:
        raise UnsupportedOrderParameters(
            msg="Passing both stop_price and style is not supported."
        )

    if not isinstance(asset, Asset):
        raise UnsupportedOrderParameters(
            msg="Passing non-Asset argument to 'order()' is not supported."
                " Use 'sid()' or 'symbol()' methods to look up an Asset."
        )

    # Portfolio, datetime and current data are read per control, so none
    # of them is touched when there are no controls.
    for control in self.trading_controls:
        control.validate(
            asset,
            amount,
            self.portfolio,
            self.get_datetime(),
            self.trading_client.current_data,
        )
||
853 | |||
@staticmethod
def __convert_order_params_for_blotter(limit_price, stop_price, style):
    """
    Translate the deprecated ``limit_price``/``stop_price`` arguments into
    an ExecutionStyle.

    A truthy ``style`` is returned unchanged and must not be combined with
    prices (asserted). Otherwise: both prices give a StopLimitOrder, only a
    limit price gives a LimitOrder, only a stop price gives a StopOrder,
    and neither gives a MarketOrder. Prices are tested for truthiness.
    """
    # TODO_SS: DeprecationWarning for usage of limit_price and stop_price.
    if style:
        assert (limit_price, stop_price) == (None, None)
        return style

    if limit_price:
        if stop_price:
            return StopLimitOrder(limit_price, stop_price)
        return LimitOrder(limit_price)

    if stop_price:
        return StopOrder(stop_price)
    return MarketOrder()
||
875 | |||
@api_method
def order_value(self, sid, value,
                limit_price=None, stop_price=None, style=None):
    """
    Place an order by desired value rather than desired number of shares.

    The requested value is divided by the asset's last price to get the
    number of shares to transact. For a Future, the 'value' is really the
    exposure (price times contract multiplier), as Futures have no 'value'.

    value > 0 :: Buy/Cover
    value < 0 :: Sell/Short
    Market order:    order_value(sid, value)
    Limit order:     order_value(sid, value, limit_price)
    Stop order:      order_value(sid, value, None, stop_price)
    StopLimit order: order_value(sid, value, limit_price, stop_price)
    """
    amount = self._calculate_order_value_amount(sid, value)
    return self.order(
        sid,
        amount,
        limit_price=limit_price,
        stop_price=stop_price,
        style=style,
    )
||
898 | |||
@property
def recorded_vars(self):
    """A shallow copy of the variables captured with record()."""
    snapshot = copy(self._recorded_vars)
    return snapshot
||
902 | |||
@property
def portfolio(self):
    """The current portfolio, as returned by updated_portfolio()."""
    current = self.updated_portfolio()
    return current
||
906 | |||
def updated_portfolio(self):
    """
    Return the cached portfolio.

    If nothing is cached, the portfolio is first fetched from the
    performance tracker for ``self.datetime``. Returns None when nothing is
    cached and there is no performance tracker.
    """
    needs_refresh = self._portfolio is None and self.perf_tracker is not None
    if needs_refresh:
        self._portfolio = self.perf_tracker.get_portfolio(self.datetime)
    return self._portfolio
||
912 | |||
@property
def account(self):
    """The current account, as returned by updated_account()."""
    current = self.updated_account()
    return current
||
916 | |||
def updated_account(self):
    """
    Return the cached account.

    If nothing is cached, the account is first fetched from the performance
    tracker for ``self.datetime``. Returns None when nothing is cached and
    there is no performance tracker.
    """
    if self._account is not None or self.perf_tracker is None:
        return self._account
    self._account = self.perf_tracker.get_account(self.datetime)
    return self._account
||
922 | |||
def set_logger(self, logger):
    """Attach ``logger``; when set, it receives debug messages."""
    setattr(self, 'logger', logger)
||
925 | |||
def on_dt_changed(self, dt):
    """
    Callback triggered by the simulation loop whenever the current dt
    changes.

    Logic that must run exactly once at the start of each datetime group
    belongs here. This stores ``dt`` on ``self.datetime``, passes it to the
    performance tracker and then the blotter, and invalidates the cached
    portfolio and account. ``dt`` must be a UTC ``datetime``; this is
    checked with asserts.
    """
    assert isinstance(dt, datetime), \
        "Attempt to set algorithm's current time with non-datetime"
    assert dt.tzinfo == pytz.utc, \
        "Algorithm expects a utc datetime"

    self.datetime = dt
    for component in (self.perf_tracker, self.blotter):
        component.set_date(dt)

    # Drop cached snapshots so they are recomputed for the new dt.
    self._portfolio = self._account = None
||
945 | |||
@api_method
def get_datetime(self, tz=None):
    """
    Return the current simulation datetime.

    The stored datetime is asserted to be UTC. When ``tz`` is given, as a
    timezone name or a tzinfo object, the datetime is converted to it.
    """
    current = self.datetime
    assert current.tzinfo == pytz.utc, "Algorithm should have a utc datetime"

    if tz is None:
        return current
    if isinstance(tz, string_types):
        tz = pytz.timezone(tz)
    # datetime.datetime objects are immutable, so no copy is needed.
    return current.astimezone(tz)
||
961 | |||
def update_dividends(self, dividend_frame):
    """
    Hand the DataFrame used to process dividends to the performance
    tracker. Its columns should contain at least the entries in
    zp.DIVIDEND_FIELDS.
    """
    tracker = self.perf_tracker
    tracker.update_dividends(dividend_frame)
||
968 | |||
@api_method
def set_slippage(self, slippage):
    """
    Install ``slippage`` as the blotter's slippage model.

    Raises
    ------
    UnsupportedSlippageModel
        If ``slippage`` is not a SlippageModel instance.
    OverrideSlippagePostInit
        If the algorithm has already been initialized.
    """
    is_supported = isinstance(slippage, SlippageModel)
    if not is_supported:
        raise UnsupportedSlippageModel()
    if self.initialized:
        raise OverrideSlippagePostInit()
    self.blotter.slippage_func = slippage
||
976 | |||
@api_method
def set_commission(self, commission):
    """
    Install ``commission`` as the blotter's commission model.

    Raises
    ------
    UnsupportedCommissionModel
        If ``commission`` is not a PerShare, PerTrade or PerDollar instance.
    OverrideCommissionPostInit
        If the algorithm has already been initialized.
    """
    supported_models = (PerShare, PerTrade, PerDollar)
    if not isinstance(commission, supported_models):
        raise UnsupportedCommissionModel()

    if self.initialized:
        raise OverrideCommissionPostInit()

    self.blotter.commission = commission
||
985 | |||
@api_method
def set_symbol_lookup_date(self, dt):
    """
    Set the date for which symbols will be resolved to their sids.

    Symbols may map to different firms or underlying assets at different
    times, so lookups need a reference date. ``dt`` is converted with
    ``pd.Timestamp(dt, tz='UTC')``.

    Raises
    ------
    UnsupportedDatetimeFormat
        If ``pd.Timestamp`` rejects ``dt`` with a ValueError.
    """
    try:
        lookup_date = pd.Timestamp(dt, tz='UTC')
    except ValueError:
        raise UnsupportedDatetimeFormat(input=dt,
                                        method='set_symbol_lookup_date')
    self._symbol_lookup_date = lookup_date
||
998 | |||
# Retained for backwards compatibility: proxies sim_params.data_frequency.
@property
def data_frequency(self):
    """Bar frequency of the simulation, as stored on ``self.sim_params``."""
    params = self.sim_params
    return params.data_frequency

@data_frequency.setter
def data_frequency(self, value):
    # Only 'daily' and 'minute' are accepted. This is an assert, so the
    # check disappears under ``python -O``.
    allowed = ('daily', 'minute')
    assert value in allowed
    self.sim_params.data_frequency = value
||
1008 | |||
@api_method
def order_percent(self, sid, percent,
                  limit_price=None, stop_price=None, style=None):
    """
    Place an order in the specified asset corresponding to the given
    percent of the current portfolio value.

    Note that percent must be expressed as a decimal (0.50 means 50%).

    The percent is turned into a value as
    ``self.portfolio.portfolio_value * percent`` and delegated to
    ``order_value`` with the same ``limit_price``/``stop_price``/``style``;
    ``order_value``'s return value is passed through.
    """
    # NOTE: the docstring previously spelled "50\%", an invalid string
    # escape (DeprecationWarning / SyntaxWarning on modern Pythons).
    value = self.portfolio.portfolio_value * percent
    return self.order_value(sid, value,
                            limit_price=limit_price,
                            stop_price=stop_price,
                            style=style)
||
1023 | |||
@api_method
def order_target(self, sid, target,
                 limit_price=None, stop_price=None, style=None):
    """
    Place an order to adjust a position to a target number of shares.

    If the position doesn't already exist, this is equivalent to ordering
    ``target`` shares. If it does exist, the order is for the difference
    between ``target`` and the currently held amount. The result of
    ``self.order`` is returned.
    """
    positions = self.portfolio.positions
    if sid in positions:
        shares_to_order = target - positions[sid].amount
    else:
        shares_to_order = target
    return self.order(sid, shares_to_order,
                      limit_price=limit_price,
                      stop_price=stop_price,
                      style=style)
||
1046 | |||
@api_method
def order_target_value(self, sid, target,
                       limit_price=None, stop_price=None, style=None):
    """
    Place an order to adjust a position to a target value.

    ``target`` is converted by ``_calculate_order_value_amount`` and then
    passed to ``order_target``. A missing position therefore means a fresh
    order, and an existing one means ordering the difference.

    If the Asset being ordered is a Future, the 'target value' calculated
    is actually the target exposure, as Futures have no 'value'.
    """
    share_target = self._calculate_order_value_amount(sid, target)
    order_kwargs = dict(limit_price=limit_price,
                        stop_price=stop_price,
                        style=style)
    return self.order_target(sid, share_target, **order_kwargs)
||
1064 | |||
@api_method
def order_target_percent(self, sid, target,
                         limit_price=None, stop_price=None, style=None):
    """
    Place an order to adjust a position to a target percent of the
    current portfolio value. If the position doesn't already exist, this is
    equivalent to placing a new order. If the position does exist, this is
    equivalent to placing an order for the difference between the target
    percent and the current percent.

    Note that target must be expressed as a decimal (0.50 means 50%).

    The percent is turned into a value as
    ``self.portfolio.portfolio_value * target`` and delegated to
    ``order_target_value``, whose return value is passed through.
    """
    # NOTE: the docstring previously spelled "50\%", an invalid string
    # escape (DeprecationWarning / SyntaxWarning on modern Pythons).
    target_value = self.portfolio.portfolio_value * target
    return self.order_target_value(sid, target_value,
                                   limit_price=limit_price,
                                   stop_price=stop_price,
                                   style=style)
||
1082 | |||
@api_method
def get_open_orders(self, sid=None):
    """
    Return open orders, converted with ``to_api_obj()``.

    With ``sid`` None, returns a dict mapping each key of the blotter's
    ``open_orders`` that has a non-empty order collection to the list of its
    converted orders. Otherwise returns that list for ``sid`` alone, or an
    empty list when ``sid`` has no entry.
    """
    open_orders = self.blotter.open_orders
    if sid is not None:
        if sid not in open_orders:
            return []
        return [o.to_api_obj() for o in open_orders[sid]]
    return {
        asset: [o.to_api_obj() for o in asset_orders]
        for asset, asset_orders in iteritems(open_orders)
        if asset_orders
    }
||
1095 | |||
@api_method
def get_order(self, order_id):
    """
    Look up an order by id.

    Returns the order's ``to_api_obj()`` result, or None if the blotter has
    no order with that id.
    """
    known_orders = self.blotter.orders
    if order_id not in known_orders:
        return None
    return known_orders[order_id].to_api_obj()
||
1100 | |||
@api_method
def cancel_order(self, order_param):
    """
    Cancel an order, given either its id or a ``zipline.protocol.Order``.

    An Order object is reduced to its ``id``; anything else is passed to
    ``self.blotter.cancel`` as-is.
    """
    if isinstance(order_param, zipline.protocol.Order):
        order_id = order_param.id
    else:
        order_id = order_param
    self.blotter.cancel(order_id)
||
1108 | |||
@api_method
@require_initialized(HistoryInInitialize())
def history(self, sids, bar_count, frequency, field, ffill=True):
    """
    Fetch a history window from the data portal at the current simulation
    datetime.

    ``sids``, ``bar_count``, ``frequency``, ``field`` and ``ffill`` are
    forwarded, together with ``self.get_datetime()``, to
    ``self.data_portal.get_history_window``, and its result is returned
    unchanged. NOTE(review): the accepted values for ``frequency`` and
    ``field`` are defined by the data portal; confirm there.

    The ``require_initialized(HistoryInInitialize())`` decorator presumably
    rejects calls made before initialization finishes.

    Raises
    ------
    RuntimeError
        If no data portal is attached. RuntimeError subclasses Exception,
        so existing ``except Exception`` handlers still catch it.
    """
    if self.data_portal is None:
        # Previously a bare Exception; a specific subclass is clearer.
        raise RuntimeError("no data portal!")

    return self.data_portal.get_history_window(
        sids,
        self.get_datetime(),
        bar_count,
        frequency,
        field,
        ffill,
    )
||
1123 | #################### |
||
1124 | # Account Controls # |
||
1125 | #################### |
||
1126 | |||
def register_account_control(self, control):
    """
    Register a new AccountControl to be checked on each bar.

    ``control`` is appended to ``self.account_controls``. Registration is
    only allowed before the algorithm is initialized; afterwards
    RegisterAccountControlPostInit is raised.
    """
    if not self.initialized:
        self.account_controls.append(control)
        return
    raise RegisterAccountControlPostInit()
||
1134 | |||
def validate_account_controls(self):
    """
    Call ``validate`` on every registered account control.

    Each control receives the portfolio, the account, the current
    simulation datetime and ``self.trading_client.current_data``. These
    values are read afresh for each control, and not at all when no
    controls are registered. Returns None.
    """
    for account_control in self.account_controls:
        account_control.validate(
            self.portfolio,
            self.account,
            self.get_datetime(),
            self.trading_client.current_data,
        )
||
1141 | |||
@api_method
def set_max_leverage(self, max_leverage=None):
    """
    Set a limit on the maximum leverage of the algorithm.

    Registers a ``MaxLeverage(max_leverage)`` account control.
    """
    self.register_account_control(MaxLeverage(max_leverage))
||
1149 | |||
1150 | #################### |
||
1151 | # Trading Controls # |
||
1152 | #################### |
||
1153 | |||
def register_trading_control(self, control):
    """
    Register a new TradingControl to be checked prior to order calls.

    ``control`` is appended to ``self.trading_controls``.

    Raises
    ------
    RegisterTradingControlPostInit
        If the algorithm has already been initialized.
    """
    already_initialized = self.initialized
    if already_initialized:
        raise RegisterTradingControlPostInit()
    self.trading_controls.append(control)
||
1161 | |||
@api_method
def set_max_position_size(self,
                          sid=None,
                          max_shares=None,
                          max_notional=None):
    """
    Set a limit on the number of shares and/or dollar value held for the
    given sid.

    Limits are treated as absolute values and are enforced at the time the
    algo attempts to place an order for sid. This means it's possible to end
    up with more than the max number of shares due to splits/dividends, and
    more than the max notional due to price improvement.

    If an algorithm attempts to place an order that would result in
    increasing the absolute value of shares/dollar value exceeding one of
    these limits, raise a TradingControlException.

    Implemented by registering a ``MaxPositionSize`` trading control.
    """
    self.register_trading_control(
        MaxPositionSize(asset=sid,
                        max_shares=max_shares,
                        max_notional=max_notional)
    )
||
1183 | |||
@api_method
def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
    """
    Set a limit on the number of shares and/or dollar value of any single
    order placed for sid.

    Limits are treated as absolute values and are enforced at the time the
    algo attempts to place an order for sid. If an algorithm attempts to
    place an order that would exceed one of these limits, raise a
    TradingControlException.

    Implemented by registering a ``MaxOrderSize`` trading control.
    """
    self.register_trading_control(
        MaxOrderSize(asset=sid,
                     max_shares=max_shares,
                     max_notional=max_notional)
    )
||
1198 | |||
@api_method
def set_max_order_count(self, max_count):
    """
    Set a limit on the number of orders that can be placed within the given
    time interval.

    Implemented by registering ``MaxOrderCount(max_count)`` as a trading
    control.
    """
    self.register_trading_control(MaxOrderCount(max_count))
||
1207 | |||
@api_method
def set_do_not_order_list(self, restricted_list):
    """
    Set a restriction on which sids can be ordered.

    Implemented by registering ``RestrictedListOrder(restricted_list)`` as a
    trading control.
    """
    self.register_trading_control(RestrictedListOrder(restricted_list))
||
1215 | |||
@api_method
def set_long_only(self):
    """
    Set a rule specifying that this algorithm cannot take short positions.

    Implemented by registering a ``LongOnly`` trading control.
    """
    long_only_control = LongOnly()
    self.register_trading_control(long_only_control)
||
1222 | |||
1223 | ############## |
||
1224 | # Pipeline API |
||
1225 | ############## |
||
@api_method
@require_not_initialized(AttachPipelineAfterInitialize())
def attach_pipeline(self, pipeline, name, chunksize=None):
    """
    Register a pipeline to be computed at the start of each day.

    Parameters
    ----------
    pipeline
        The pipeline to register. It is also returned, which allows
        expressions like ``p = attach_pipeline(Pipeline(), 'name')``.
    name : str
        Key under which the pipeline is stored.
    chunksize : int, optional
        Number of trading days to compute per chunk. When omitted, the first
        chunk is 5 days (for more immediate results) and each later chunk
        is 126 days.

    Raises
    ------
    NotImplementedError
        If a pipeline is already attached; multiple pipelines are not
        supported.
    """
    if self._pipelines:
        raise NotImplementedError("Multiple pipelines are not supported.")

    # chain() and repeat() already return iterators; no iter() wrap needed.
    if chunksize is not None:
        chunk_sizes = repeat(int(chunksize))
    else:
        # One week first, then every half year.
        chunk_sizes = chain([5], repeat(126))

    self._pipelines[name] = pipeline, chunk_sizes
    return pipeline
||
1245 | |||
@api_method
@require_initialized(PipelineOutputDuringInitialize())
def pipeline_output(self, name):
    """
    Get the results of pipeline with name `name`.

    Parameters
    ----------
    name : str
        Name of the pipeline for which results are requested.

    Returns
    -------
    results : pd.DataFrame
        DataFrame containing the results of the requested pipeline for
        the current simulation date.

    Raises
    ------
    NoSuchPipeline
        Raised when no pipeline with the name `name` has been registered.

    See Also
    --------
    :meth:`zipline.pipeline.engine.PipelineEngine.run_pipeline`
    """
    # NOTE: Multiple pipelines are not supported yet, but lookups are keyed
    # by name so that support can be added later.
    try:
        registered_pipeline, chunk_sizes = self._pipelines[name]
    except KeyError:
        raise NoSuchPipeline(
            name=name,
            valid=list(self._pipelines.keys()),
        )
    return self._pipeline_output(registered_pipeline, chunk_sizes)
||
1282 | |||
def _pipeline_output(self, pipeline, chunks):
    """
    Internal implementation of `pipeline_output`.

    Uses ``self._pipeline_cache`` while it is valid for today's normalized
    date. When the cache raises ``Expired``, a new chunk is computed with
    ``_run_pipeline``, using the next size drawn from ``chunks``. The result
    is cached as a CachedObject valid until that chunk's end date.

    Returns the rows of the cached frame for today. If ``.loc[today]``
    raises KeyError, returns an empty DataFrame with the same columns.
    """
    today = normalize_date(self.get_datetime())

    try:
        frame = self._pipeline_cache.unwrap(today)
    except Expired:
        frame, valid_until = self._run_pipeline(pipeline, today, next(chunks))
        self._pipeline_cache = CachedObject(frame, valid_until)

    # A cached result now exists; pull out today's slice.
    try:
        return frame.loc[today]
    except KeyError:
        # This happens if no assets passed the pipeline screen on a given
        # day.
        return pd.DataFrame(index=[], columns=frame.columns)
||
1303 | |||
1304 | def _run_pipeline(self, pipeline, start_date, chunksize): |
||
1305 | """ |
||
1306 | Compute `pipeline`, providing values for at least `start_date`. |
||
1307 | |||
1308 | Produces a DataFrame containing data for days between `start_date` and |
||
1309 | `end_date`, where `end_date` is defined by: |
||
1310 | |||
1311 | `end_date = min(start_date + chunksize trading days, |
||
1312 | simulation_end)` |
||
1313 | |||
1314 | Returns |
||
1315 | ------- |
||
1316 | (data, valid_until) : tuple (pd.DataFrame, pd.Timestamp) |
||
1317 | |||
1318 | See Also |
||
1319 | -------- |
||
1320 | PipelineEngine.run_pipeline |
||
1321 | """ |
||
1322 | days = self.trading_environment.trading_days |
||
1323 | |||
1324 | # Load data starting from the previous trading day... |
||
1325 | start_date_loc = days.get_loc(start_date) |
||
1326 | |||
1327 | # ...continuing until either the day before the simulation end, or |
||
1328 | # until chunksize days of data have been loaded. |
||
1329 | sim_end = self.sim_params.last_close.normalize() |
||
1330 | end_loc = min(start_date_loc + chunksize, days.get_loc(sim_end)) |
||
1331 | end_date = days[end_loc] |
||
1332 | |||
1333 | return \ |
||
1334 | self.engine.run_pipeline(pipeline, start_date, end_date), end_date |
||
1335 | |||
1336 | ################## |
||
1337 | # End Pipeline API |
||
1338 | ################## |
||
1339 | |||
@classmethod
def all_api_methods(cls):
    """
    Return a list of all the TradingAlgorithm API methods.

    Scans only the attributes defined directly on ``cls`` (``vars(cls)``),
    not inherited ones, and keeps each value whose ``is_api_method``
    attribute is truthy.
    """
    api_methods = []
    for attr in vars(cls).values():
        if getattr(attr, 'is_api_method', False):
            api_methods.append(attr)
    return api_methods
||
1349 |