zipline.gens.AlgorithmSimulator.transform()   Rating: F

Complexity
  Conditions: 16

Size
  Total Lines: 80

Duplication
  Lines: 0
  Ratio: 0 %
Metric                       Value
cc (cyclomatic complexity)   16
dl (duplicated lines)        0
loc (lines of code)          80
rs                           2.1087

How to fix

Long Method

Small methods make your code easier to understand, especially when paired with a good name. And if a method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for its name.

Commonly applied refactorings include Extract Method, as sketched below.
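For instance, the warmup-period branch of transform() in the listing below walks a snapshot and routes each event by type; pulling that loop body into its own method makes transform() shorter and gives the behavior a name. A minimal sketch of Extract Method applied there (the method name _process_warmup_event is illustrative, not part of zipline):

    def _process_warmup_event(self, event):
        # During warmup we only maintain state: no perf messages are
        # emitted and handle_data is never called.
        if event.type == DATASOURCE_TYPE.SPLIT:
            self.algo.blotter.process_split(event)
        elif event.type == DATASOURCE_TYPE.TRADE:
            self.update_universe(event)
            self.algo.perf_tracker.process_trade(event)
        elif event.type == DATASOURCE_TYPE.CUSTOM:
            self.update_universe(event)

The warmup branch in transform() then collapses to a single, self-describing call:

                if date < self.algo_start:
                    for event in snapshot:
                        self._process_warmup_event(event)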

Complexity

Complex code like zipline.gens.AlgorithmSimulator.transform() often does a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields and methods that belong together, you can apply the Extract Class refactoring, as in the sketch below. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
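For instance, in _process_snapshot() below, the perf_process_* locals all share a prefix and delegate to self.algo.perf_tracker, and the txn.type branching on blotter output is repeated in three loops. A hypothetical Extract Class sketch based on those observations (PerfEventDispatcher is illustrative, not a zipline class):

class PerfEventDispatcher(object):
    # Wraps the perf_tracker routing that _process_snapshot() currently
    # inlines via eight perf_process_* local variables.

    def __init__(self, perf_tracker):
        # Handlers are bound once, keeping the original code's goal of
        # avoiding attribute lookups in the innermost loops.
        self._txn_handlers = {
            DATASOURCE_TYPE.TRANSACTION: perf_tracker.process_transaction,
            DATASOURCE_TYPE.COMMISSION: perf_tracker.process_commission,
        }
        self.process_order = perf_tracker.process_order
        self.process_trade = perf_tracker.process_trade

    def process_fill(self, txn, order):
        # One home for the txn.type branching repeated in the benchmark,
        # trade, and instant-fill loops of _process_snapshot().
        if txn is not None:
            handler = self._txn_handlers.get(txn.type)
            if handler is not None:
                handler(txn)
        if order is not None:
            self.process_order(order)

Each loop in _process_snapshot() would then reduce to dispatcher.process_fill(txn, order), followed by dispatcher.process_trade(trade) where applicable.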

#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib2 import ExitStack

from logbook import Logger, Processor
from pandas.tslib import normalize_date

from zipline.utils.api_support import ZiplineAPI

from zipline.finance.trading import NoFurtherDataError
from zipline.protocol import (
    BarData,
    SIDData,
    DATASOURCE_TYPE
)

log = Logger('Trade Simulation')


class AlgorithmSimulator(object):

    EMISSION_TO_PERF_KEY_MAP = {
        'minute': 'minute_perf',
        'daily': 'daily_perf'
    }

    def __init__(self, algo, sim_params):

        # ==============
        # Simulation
        # Param Setup
        # ==============
        self.sim_params = sim_params

        # ==============
        # Algo Setup
        # ==============
        self.algo = algo
        self.algo_start = normalize_date(self.sim_params.first_open)
        self.env = algo.trading_environment

        # ==============
        # Snapshot Setup
        # ==============

        # The algorithm's data as of our most recent event.
        # We want an object that will have empty objects as default
        # values on missing keys.
        self.current_data = BarData()

        # We don't have a datetime for the current snapshot until we
        # receive a message.
        self.simulation_dt = None

        # =============
        # Logging Setup
        # =============

        # Processor function for injecting the algo_dt into
        # user prints/logs.
        def inject_algo_dt(record):
            if 'algo_dt' not in record.extra:
                record.extra['algo_dt'] = self.simulation_dt
        self.processor = Processor(inject_algo_dt)

    def transform(self, stream_in):
        """
        Main generator work loop.
        """
        # Initialize the market open and close from the perf tracker.
        mkt_open = self.algo.perf_tracker.market_open
        mkt_close = self.algo.perf_tracker.market_close

        # Inject the current algo snapshot time into any log record
        # generated.
        with ExitStack() as stack:
            stack.enter_context(self.processor)
            stack.enter_context(ZiplineAPI(self.algo))

            data_frequency = self.sim_params.data_frequency

            self._call_before_trading_start(mkt_open)

            for date, snapshot in stream_in:

                self.simulation_dt = date
                self.on_dt_changed(date)

                # If we're still in the warmup period, use the event to
                # update our universe, but don't yield any perf messages,
                # and don't send a snapshot to handle_data.
                if date < self.algo_start:
                    for event in snapshot:
                        if event.type == DATASOURCE_TYPE.SPLIT:
                            self.algo.blotter.process_split(event)

                        elif event.type == DATASOURCE_TYPE.TRADE:
                            self.update_universe(event)
                            self.algo.perf_tracker.process_trade(event)
                        elif event.type == DATASOURCE_TYPE.CUSTOM:
                            self.update_universe(event)

                else:
                    messages = self._process_snapshot(
                        date,
                        snapshot,
                        self.algo.instant_fill,
                    )
                    # Perf messages are only emitted if the snapshot contained
                    # a benchmark event.
                    for message in messages:
                        yield message

                    # When emitting minutely, we need to call
                    # before_trading_start before the next trading day begins.
                    if date == mkt_close:
                        if mkt_close <= self.algo.perf_tracker.last_close:
                            before_last_close = \
                                mkt_close < self.algo.perf_tracker.last_close
                            try:
                                mkt_open, mkt_close = \
                                    self.env.next_open_and_close(mkt_close)

                            except NoFurtherDataError:
                                # If at the end of backtest history,
                                # skip advancing market close.
                                pass

                            if before_last_close:
                                self._call_before_trading_start(mkt_open)

                    elif data_frequency == 'daily':
                        next_day = self.env.next_trading_day(date)

                        if next_day is not None and \
                           next_day < self.algo.perf_tracker.last_close:
                            self._call_before_trading_start(next_day)

                    self.algo.portfolio_needs_update = True
                    self.algo.account_needs_update = True
                    self.algo.performance_needs_update = True

            risk_message = self.algo.perf_tracker.handle_simulation_end()
            yield risk_message

    def _process_snapshot(self, dt, snapshot, instant_fill):
        """
        Process a stream of events corresponding to a single datetime, possibly
        returning a perf message to be yielded.

        If @instant_fill = True, we delay processing of events until after the
        user's call to handle_data, and we process the user's placed orders
        before the snapshot's events.  Note that this introduces a lookahead
        bias, since the user is effectively placing orders that are filled
        based on trades that happened prior to the call to handle_data.

        If @instant_fill = False, we process Trade events before calling
        handle_data.  This means that orders are filled based on trades
        occurring in the next snapshot.  This is the more conservative model,
        and as such it is the default behavior in TradingAlgorithm.
        """

        # Flags indicating whether we saw any events of type TRADE and type
        # BENCHMARK.  Respectively, these control whether or not handle_data is
        # called for this snapshot and whether we emit a perf message for this
        # snapshot.
        any_trade_occurred = False
        benchmark_event_occurred = False

        if instant_fill:
            events_to_be_processed = []

        # Assign the event processors to local variables to avoid attribute
        # access in the innermost loops.
        #
        # Done here, to allow for perf_tracker or blotter to be swapped out
        # or changed in between snapshots.
        perf_process_trade = self.algo.perf_tracker.process_trade
        perf_process_transaction = self.algo.perf_tracker.process_transaction
        perf_process_order = self.algo.perf_tracker.process_order
        perf_process_benchmark = self.algo.perf_tracker.process_benchmark
        perf_process_split = self.algo.perf_tracker.process_split
        perf_process_dividend = self.algo.perf_tracker.process_dividend
        perf_process_commission = self.algo.perf_tracker.process_commission
        perf_process_close_position = \
            self.algo.perf_tracker.process_close_position
        blotter_process_trade = self.algo.blotter.process_trade
        blotter_process_benchmark = self.algo.blotter.process_benchmark

        # Containers for the snapshotted events, so that the events are
        # processed in a predictable order, without relying on the sorted order
        # of the individual sources.

        # There is only one benchmark per snapshot; it will be set to the
        # current benchmark iff one occurs.
        benchmark = None
        # trades and customs are initialized as lists since process_snapshot
        # is most often called on market bars, which could contain trades or
        # custom events.
        trades = []
        customs = []
        closes = []

        # splits and dividends are processed once a day.
        #
        # Deferring the creation of these lists is mainly meant to signal
        # that these are the infrequent cases for this method; the
        # performance benefit of deferring the allocation is marginal.
        # The splits list will be allocated when a split occurs in the
        # snapshot.
        splits = None
        # The dividends list will be allocated when a dividend occurs in the
        # snapshot.
        dividends = None

        for event in snapshot:
            if event.type == DATASOURCE_TYPE.TRADE:
                trades.append(event)
            elif event.type == DATASOURCE_TYPE.BENCHMARK:
                benchmark = event
            elif event.type == DATASOURCE_TYPE.SPLIT:
                if splits is None:
                    splits = []
                splits.append(event)
            elif event.type == DATASOURCE_TYPE.CUSTOM:
                customs.append(event)
            elif event.type == DATASOURCE_TYPE.DIVIDEND:
                if dividends is None:
                    dividends = []
                dividends.append(event)
            elif event.type == DATASOURCE_TYPE.CLOSE_POSITION:
                closes.append(event)
            else:
                # Log and skip events of unrecognized types.
                log.warn("Unrecognized event={}".format(event))

        # Handle benchmark first.
        #
        # Internal broker implementation depends on the benchmark being
        # processed first so that transactions and commissions reported from
        # the broker can be injected.
        if benchmark is not None:
            benchmark_event_occurred = True
            perf_process_benchmark(benchmark)
            for txn, order in blotter_process_benchmark(benchmark):
                if txn.type == DATASOURCE_TYPE.TRANSACTION:
                    perf_process_transaction(txn)
                elif txn.type == DATASOURCE_TYPE.COMMISSION:
                    perf_process_commission(txn)
                perf_process_order(order)

        for trade in trades:
            self.update_universe(trade)
            any_trade_occurred = True
            if instant_fill:
                events_to_be_processed.append(trade)
            else:
                for txn, order in blotter_process_trade(trade):
                    if txn.type == DATASOURCE_TYPE.TRANSACTION:
                        perf_process_transaction(txn)
                    elif txn.type == DATASOURCE_TYPE.COMMISSION:
                        perf_process_commission(txn)
                    perf_process_order(order)
                perf_process_trade(trade)

        for custom in customs:
            self.update_universe(custom)

        for close in closes:
            self.update_universe(close)
            perf_process_close_position(close)

        if splits is not None:
            for split in splits:
                # process_split is not assigned to a variable since it is
                # called rarely compared to the other event processors.
                self.algo.blotter.process_split(split)
                perf_process_split(split)

        if dividends is not None:
            for dividend in dividends:
                perf_process_dividend(dividend)

        if any_trade_occurred:
            new_orders = self._call_handle_data()
            for order in new_orders:
                perf_process_order(order)

        if instant_fill:
            # Now that handle_data has been called and orders have been placed,
            # process the event stream to fill user orders based on the events
            # from this snapshot.
            for trade in events_to_be_processed:
                for txn, order in blotter_process_trade(trade):
                    if txn is not None:
                        perf_process_transaction(txn)
                    if order is not None:
                        perf_process_order(order)
                perf_process_trade(trade)

        if benchmark_event_occurred:
            return self.generate_messages(dt)
        else:
            return ()

    def _call_handle_data(self):
        """
        Call the user's handle_data, returning any orders placed by the algo
        during the call.
        """
        self.algo.event_manager.handle_data(
            self.algo,
            self.current_data,
            self.simulation_dt,
        )
        orders = self.algo.blotter.new_orders
        self.algo.blotter.new_orders = []
        return orders

    def _call_before_trading_start(self, dt):
        dt = normalize_date(dt)
        self.simulation_dt = dt
        self.on_dt_changed(dt)
        self.algo.before_trading_start(self.current_data)

    def on_dt_changed(self, dt):
        if self.algo.datetime != dt:
            self.algo.on_dt_changed(dt)

    def generate_messages(self, dt):
        """
        Generator that yields perf messages for the given datetime.
        """
        # Ensure that updated_portfolio has been called at least once for this
        # dt before we emit a perf message.  This is a no-op if
        # updated_portfolio has already been called this dt.
        self.algo.updated_portfolio()
        self.algo.updated_account()

        rvars = self.algo.recorded_vars
        if self.algo.perf_tracker.emission_rate == 'daily':
            perf_message = \
                self.algo.perf_tracker.handle_market_close_daily()
            perf_message['daily_perf']['recorded_vars'] = rvars
            yield perf_message

        elif self.algo.perf_tracker.emission_rate == 'minute':
            # Close the minute in the tracker, and collect the daily message
            # if the minute is the close of the trading day.
            minute_message, daily_message = \
                self.algo.perf_tracker.handle_minute_close(dt)

            # Collect and yield the minute's perf message.
            minute_message['minute_perf']['recorded_vars'] = rvars
            yield minute_message

            # If there was a daily perf message, collect and yield it.
            if daily_message:
                daily_message['daily_perf']['recorded_vars'] = rvars
                yield daily_message

    def update_universe(self, event):
        """
        Update the universe with new event information.
        """
        # Update our knowledge of this event's sid.  Rather than using
        # `if event.sid in ...`, trying and handling the exception is
        # significantly faster.
        try:
            sid_data = self.current_data[event.sid]
        except KeyError:
            sid_data = self.current_data[event.sid] = SIDData(event.sid)

        sid_data.__dict__.update(event.__dict__)