Completed
Pull Request — master (#858)
by Eddie
02:50 queued 01:15
created

zipline.gens.AlgorithmSimulator   A

Complexity

Total Complexity 26

Size/Duplication

Total Lines 193
Duplicated Lines 0 %
Metric Value
dl 0
loc 193
rs 10
wmc 26

10 Methods

Rating   Name   Duplication   Size   Complexity  
A inject_algo_dt() 0 3 2
B __init__() 0 43 3
A get_simulation_dt() 0 2 1
B every_bar() 0 41 6
A _get_daily_message() 0 8 1
A once_a_day() 0 21 3
A _get_minute_message() 0 14 2
A _create_bar_data() 0 5 1
A handle_benchmark() 0 3 1
F transform() 0 108 18
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from logbook import Logger, Processor
16
from pandas.tslib import normalize_date
17
from zipline.protocol import BarData
18
from zipline.utils.api_support import ZiplineAPI
19
20
from zipline.gens.sim_engine import (
21
    BAR,
22
    DAY_START,
23
    DAY_END,
24
    MINUTE_END
25
)
26
27
# Module-level logbook logger for this module, named 'Trade Simulation'.
log = Logger(name='Trade Simulation')
28
29
30
class AlgorithmSimulator(object):
    """
    Drive a trading algorithm through a simulation clock.

    ``transform()`` is the entry point: a generator that consumes
    ``(dt, action)`` events from ``clock``, invokes the algorithm's
    callbacks, and yields the perf/risk messages produced by the
    algorithm's ``perf_tracker``.
    """

    EMISSION_TO_PERF_KEY_MAP = {
        'minute': 'minute_perf',
        'daily': 'daily_perf'
    }

    def __init__(self, algo, sim_params, data_portal, clock, benchmark_source):
        """
        Parameters
        ----------
        algo : object
            The algorithm being simulated. ``trading_environment`` is read
            here; ``transform()`` additionally uses its ``event_manager``,
            ``perf_tracker``, ``blotter`` and lifecycle hooks.
        sim_params : object
            Simulation parameters; ``first_open`` and ``data_frequency``
            are read.
        data_portal : object
            Data source handed to ``BarData``, to the algorithm and its
            position tracker, and used to look up splits.
        clock : iterable of (dt, action) pairs
            Event stream; ``action`` is compared against ``BAR``,
            ``DAY_START``, ``DAY_END`` and ``MINUTE_END``.
        benchmark_source : object
            Must provide ``get_value(dt)``; used to record benchmark
            returns at each ``DAY_END`` / ``MINUTE_END``.
        """
        # ==============
        # Simulation
        # Param Setup
        # ==============
        self.sim_params = sim_params
        self.env = algo.trading_environment
        self.data_portal = data_portal

        # ==============
        # Algo Setup
        # ==============
        self.algo = algo
        self.algo_start = normalize_date(self.sim_params.first_open)

        # ==============
        # Snapshot Setup
        # ==============

        # The algorithm's data as of our most recent event.
        # BarData is given a callable (get_simulation_dt) rather than a
        # fixed dt, presumably so this one object can be reused as
        # simulation_dt advances.
        self.current_data = self._create_bar_data()

        # We don't have a datetime for the current snapshot until we
        # receive a message.
        self.simulation_dt = None

        self.clock = clock

        self.benchmark_source = benchmark_source

        # =============
        # Logging Setup
        # =============

        # Processor function for injecting the algo_dt into
        # user prints/logs. A record that already carries an algo_dt
        # is left untouched.
        def inject_algo_dt(record):
            if 'algo_dt' not in record.extra:
                record.extra['algo_dt'] = self.simulation_dt
        self.processor = Processor(inject_algo_dt)

    def get_simulation_dt(self):
        """
        Return the datetime of the most recent simulation event, or None
        if no event has been processed yet.
        """
        return self.simulation_dt

    def _create_bar_data(self):
        """Build the BarData object passed to the algorithm's callbacks."""
        return BarData(
            data_portal=self.data_portal,
            simulation_dt_func=self.get_simulation_dt,
            data_frequency=self.sim_params.data_frequency,
        )

    @staticmethod
    def _sids_with_exposure(positions, open_orders):
        """
        Return the distinct keys (sids) of ``positions`` and ``open_orders``.

        Both arguments only need to expose ``keys()``. The order of the
        returned list is unspecified.
        """
        return list(set(positions.keys()) | set(open_orders.keys()))

    def transform(self):
        """
        Main generator work loop.

        Consumes ``(dt, action)`` events from ``self.clock`` and yields:

        * at ``DAY_END``: the daily perf message for ``dt``;
        * at ``MINUTE_END``: the minute perf message for ``dt``, followed
          by a daily perf message when the tracker returned one;
        * once the clock is exhausted: the message returned by
          ``perf_tracker.handle_simulation_end()``.

        ``BAR`` events first book fills from orders placed earlier, then
        call the algorithm's ``handle_data``. ``DAY_START`` events call
        ``before_trading_start`` and then apply any relevant splits.
        """
        algo = self.algo
        algo.data_portal = self.data_portal
        handle_data = algo.event_manager.handle_data
        current_data = self.current_data

        data_portal = self.data_portal

        # can't cache a pointer to algo.perf_tracker because we're not
        # guaranteed that the algo doesn't swap out perf trackers during
        # its lifetime.
        # likewise, we can't cache a pointer to the blotter.

        algo.perf_tracker.position_tracker.data_portal = data_portal

        def every_bar(dt_to_use):
            # called every tick (minute or day).

            self.simulation_dt = dt_to_use
            algo.on_dt_changed(dt_to_use)

            blotter = algo.blotter
            perf_tracker = algo.perf_tracker

            # handle any transactions and commissions coming out of new
            # orders placed in the last bar
            new_transactions, new_commissions = \
                blotter.get_transactions(current_data)

            for transaction in new_transactions:
                perf_tracker.process_transaction(transaction)

                # since this order was modified, record it
                order = blotter.orders[transaction.order_id]
                perf_tracker.process_order(order)

            # Guard kept: get_transactions' commissions may be empty/None.
            if new_commissions:
                for commission in new_commissions:
                    perf_tracker.process_commission(commission)

            handle_data(algo, current_data, dt_to_use)

            # grab any new orders from the blotter, then clear the list.
            # this includes cancelled orders.
            new_orders = blotter.new_orders
            blotter.new_orders = []

            # if we have any new orders, record them so that we know
            # in what perf period they were placed.
            if new_orders:
                for new_order in new_orders:
                    perf_tracker.process_order(new_order)

            # Mark the algorithm's portfolio, account and performance
            # state as needing an update after this bar.
            algo.portfolio_needs_update = True
            algo.account_needs_update = True
            algo.performance_needs_update = True

        def once_a_day(midnight_dt):
            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            # call before trading start
            algo.before_trading_start(current_data)

            perf_tracker = algo.perf_tracker

            # handle any splits that impact any positions or any open orders.
            sids_we_care_about = self._sids_with_exposure(
                perf_tracker.position_tracker.positions,
                algo.blotter.open_orders,
            )

            if sids_we_care_about:
                splits = data_portal.get_splits(sids_we_care_about,
                                                midnight_dt)
                # len() rather than truthiness: the splits container's
                # type is not visible here and may not define __bool__.
                if len(splits) > 0:
                    algo.blotter.process_splits(splits)
                    perf_tracker.position_tracker.handle_splits(splits)

        def handle_benchmark(date):
            algo.perf_tracker.all_benchmark_returns[date] = \
                self.benchmark_source.get_value(date)

        with self.processor, ZiplineAPI(self.algo):
            for dt, action in self.clock:
                if action == BAR:
                    every_bar(dt)
                elif action == DAY_START:
                    once_a_day(dt)
                elif action == DAY_END:
                    # End of the day: benchmark returns are keyed by the
                    # normalized (midnight) date.
                    handle_benchmark(normalize_date(dt))
                    yield self._get_daily_message(dt, algo, algo.perf_tracker)
                elif action == MINUTE_END:
                    handle_benchmark(dt)
                    minute_msg, daily_msg = \
                        self._get_minute_message(dt, algo, algo.perf_tracker)

                    yield minute_msg

                    if daily_msg:
                        yield daily_msg

        risk_message = algo.perf_tracker.handle_simulation_end()
        yield risk_message

    @staticmethod
    def _get_daily_message(dt, algo, perf_tracker):
        """
        Build the end-of-day perf message for ``dt``.

        Calls ``perf_tracker.handle_market_close_daily(dt)`` and attaches
        ``algo.recorded_vars`` under ``['daily_perf']['recorded_vars']``.
        """
        perf_message = perf_tracker.handle_market_close_daily(dt)
        perf_message['daily_perf']['recorded_vars'] = algo.recorded_vars
        return perf_message

    @staticmethod
    def _get_minute_message(dt, algo, perf_tracker):
        """
        Build the minute perf message for ``dt``.

        Calls ``perf_tracker.handle_minute_close(dt)``. The algorithm's
        recorded variables are attached to the minute message and, if the
        tracker also returned a truthy daily message, to that one too.

        Returns
        -------
        (minute_message, daily_message)
            ``daily_message`` is returned as given by the tracker, so it
            may be falsy.
        """
        rvars = algo.recorded_vars

        minute_message, daily_message = perf_tracker.handle_minute_close(dt)
        minute_message['minute_perf']['recorded_vars'] = rvars

        if daily_message:
            daily_message["daily_perf"]["recorded_vars"] = rvars

        return minute_message, daily_message
223