Completed
Pull Request — master (#906)
by Eddie
01:25
created

zipline.data.DataPortal.get_spot_value()   B

Complexity

Conditions 5

Size

Total Lines 46

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 5
dl 0
loc 46
rs 8.1277
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import bcolz
17
from logbook import Logger
18
19
import numpy as np
20
import pandas as pd
21
from pandas.tslib import normalize_date
22
23
from zipline.assets import Future, Equity
24
from zipline.data.us_equity_pricing import NoDataOnDate
25
26
from zipline.utils import tradingcalendar
27
from zipline.errors import (
28
    NoTradeDataAvailableTooEarly,
29
    NoTradeDataAvailableTooLate
30
)
31
32
# Module-level logbook logger named 'DataPortal' (not referenced in the
# visible portion of this file).
log = Logger('DataPortal')
33
34
# Every field name accepted by DataPortal.get_spot_value, mapped to the
# underlying pricing column it is read from. "open_price", "close_price" and
# "price" are aliases for the "open"/"close" columns.
BASE_FIELDS = dict(
    open='open',
    open_price='open',
    high='high',
    low='low',
    close='close',
    close_price='close',
    volume='volume',
    price='close',
)
44
45
46
class DataPortal(object):
47
    def __init__(self,
48
                 env,
49
                 sim_params=None,
50
                 equity_daily_reader=None,
51
                 equity_minute_reader=None,
52
                 future_daily_reader=None,
53
                 future_minute_reader=None,
54
                 adjustment_reader=None):
55
        self.env = env
56
57
        # Internal pointers to the current dt (can be minute) and current day.
58
        # In daily mode, they point to the same thing. In minute mode, it's
59
        # useful to have separate pointers to the current day and to the
60
        # current minute.  These pointers are updated by the
61
        # AlgorithmSimulator's transform loop.
62
        self.current_dt = None
63
        self.current_day = None
64
65
        self.views = {}
66
67
        self._asset_finder = env.asset_finder
68
69
        self._carrays = {
70
            'open': {},
71
            'high': {},
72
            'low': {},
73
            'close': {},
74
            'volume': {},
75
            'sid': {},
76
            'dt': {},
77
        }
78
79
        self._adjustment_reader = adjustment_reader
80
81
        # caches of sid -> adjustment list
82
        self._splits_dict = {}
83
        self._mergers_dict = {}
84
        self._dividends_dict = {}
85
86
        # Cache of sid -> the first trading day of an asset, even if that day
87
        # is before 1/2/2002.
88
        self._asset_start_dates = {}
89
        self._asset_end_dates = {}
90
91
        # Handle extra sources, like Fetcher.
92
        self._augmented_sources_map = {}
93
        self._extra_source_df = None
94
95
        self._sim_params = sim_params
96
        if self._sim_params is not None:
97
            self._data_frequency = self._sim_params.data_frequency
98
        else:
99
            self._data_frequency = "minute"
100
101
        self.MINUTE_PRICE_ADJUSTMENT_FACTOR = 0.001
102
103
        self._equity_daily_reader = equity_daily_reader
104
        self._equity_minute_reader = equity_minute_reader
105
        self._future_daily_reader = future_daily_reader
106
        self._future_minute_reader = future_minute_reader
107
108
    def _open_minute_file(self, field, asset):
109
        sid_str = str(int(asset))
110
111
        try:
112
            carray = self._carrays[field][sid_str]
113
        except KeyError:
114
            carray = self._carrays[field][sid_str] = \
115
                self._get_ctable(asset)[field]
116
117
        return carray
118
119
    def _get_ctable(self, asset):
        """
        Open the per-sid minute bcolz ctable for ``asset`` read-only.

        The minute reader is picked by asset type: futures use
        ``self._future_minute_reader`` and equities use
        ``self._equity_minute_reader``. If that reader has a
        ``sid_path_func``, the path is ``sid_path_func(rootdir, sid)``;
        otherwise it defaults to ``"<rootdir>/<sid>.bcolz"``.

        Raises
        ------
        TypeError
            If ``asset`` is neither a Future nor an Equity.
        """
        sid = int(asset)

        if isinstance(asset, Future):
            reader = self._future_minute_reader
        elif isinstance(asset, Equity):
            reader = self._equity_minute_reader
        else:
            # Without this, ``path`` below would be unbound and the caller
            # would get an opaque UnboundLocalError.
            raise TypeError(
                "No minute reader for asset of type {0}".format(
                    type(asset).__name__))

        if reader.sid_path_func is not None:
            path = reader.sid_path_func(reader.rootdir, sid)
        else:
            path = "{0}/{1}.bcolz".format(reader.rootdir, sid)

        return bcolz.open(path, mode='r')
140
141
    def get_spot_value(self, asset, field, dt=None):
        """
        Public API method that returns a scalar value representing the value
        of the desired asset's field at the given dt or, when dt is None, at
        this data portal's current position (``current_day`` in daily mode,
        ``current_dt`` in minute mode).

        Parameters
        ----------
        asset : Asset or int
            The asset whose data is desired. An integer (Python or numpy) is
            treated as a sid and resolved through the asset finder.

        field : string
            The desired field of the asset.  Valid values are "open",
            "open_price", "high", "low", "close", "close_price", "volume", and
            "price".

        dt : pd.Timestamp, optional
            The timestamp for the desired value.

        Returns
        -------
        The value of the desired field at the desired time.

        Raises
        ------
        KeyError
            If ``field`` is not one of the valid values listed above.
        NoTradeDataAvailableTooEarly, NoTradeDataAvailableTooLate
            Raised by ``_check_is_currently_alive`` when the requested time
            falls outside the asset's start/end dates.
        """
        if field not in BASE_FIELDS:
            raise KeyError("Invalid column: " + str(field))

        column_to_use = BASE_FIELDS[field]

        # numpy integer sids are not ``int`` instances on Python 3, so accept
        # them too and hand the finder a plain int.
        if isinstance(asset, (int, np.integer)):
            asset = self._asset_finder.retrieve_asset(int(asset))

        self._check_is_currently_alive(asset, dt)

        if self._data_frequency == "daily":
            day_to_use = dt if dt is not None else self.current_day
            return self._get_daily_data(
                asset, column_to_use, normalize_date(day_to_use))

        dt_to_use = dt if dt is not None else self.current_dt

        if isinstance(asset, Future):
            return self._get_minute_spot_value_future(
                asset, column_to_use, dt_to_use)

        return self._get_minute_spot_value(asset, column_to_use, dt_to_use)
187
188
    def _get_minute_spot_value_future(self, asset, column, dt):
189
        # Futures bcolz files have 1440 bars per day (24 hours), 7 days a week.
190
        # The file attributes contain the "start_dt" and "last_dt" fields,
191
        # which represent the time period for this bcolz file.
192
193
        # The start_dt is midnight of the first day that this future started
194
        # trading.
195
196
        # figure out the # of minutes between dt and this asset's start_dt
197
        start_date = self._get_asset_start_date(asset)
198
        minute_offset = int((dt - start_date).total_seconds() / 60)
199
200
        if minute_offset < 0:
201
            # asking for a date that is before the asset's start date, no dice
202
            return 0.0
203
204
        # then just index into the bcolz carray at that offset
205
        carray = self._open_minute_file(column, asset)
206
        result = carray[minute_offset]
207
208
        # if there's missing data, go backwards until we run out of file
209
        while result == 0 and minute_offset > 0:
210
            minute_offset -= 1
211
            result = carray[minute_offset]
212
213
        if column != 'volume':
214
            return result * self.MINUTE_PRICE_ADJUSTMENT_FACTOR
215
        else:
216
            return result
217
218
    def _get_minute_spot_value(self, asset, column, dt):
219
        # if dt is before the first market minute, minute_index
220
        # will be 0.  if it's after the last market minute, it'll
221
        # be len(minutes_for_day)
222
        given_day = pd.Timestamp(dt.date(), tz='utc')
223
        day_index = self._equity_minute_reader.trading_days.searchsorted(
224
            given_day)
225
226
        # if dt is before the first market minute, minute_index
227
        # will be 0.  if it's after the last market minute, it'll
228
        # be len(minutes_for_day)
229
        minute_index = self.env.market_minutes_for_day(given_day).\
230
            searchsorted(dt)
231
232
        minute_offset_to_use = (day_index * 390) + minute_index
233
234
        carray = self._equity_minute_reader._open_minute_file(column, asset)
235
        result = carray[minute_offset_to_use]
236
237
        if result == 0:
238
            # if the given minute doesn't have data, we need to seek
239
            # backwards until we find data. This makes the data
240
            # forward-filled.
241
242
            # get this asset's start date, so that we don't look before it.
243
            start_date = self._get_asset_start_date(asset)
244
            start_date_idx = self._equity_minute_reader.trading_days.\
245
                searchsorted(start_date)
246
            start_day_offset = start_date_idx * 390
247
248
            original_start = minute_offset_to_use
249
250
            while result == 0 and minute_offset_to_use > start_day_offset:
251
                minute_offset_to_use -= 1
252
                result = carray[minute_offset_to_use]
253
254
            # once we've found data, we need to check whether it needs
255
            # to be adjusted.
256
            if result != 0:
257
                minutes = self.env.market_minute_window(
258
                    start=(dt or self.current_dt),
259
                    count=(original_start - minute_offset_to_use + 1),
260
                    step=-1
261
                ).order()
262
263
                # only need to check for adjustments if we've gone back
264
                # far enough to cross the day boundary.
265
                if minutes[0].date() != minutes[-1].date():
266
                    # create a np array of size minutes, fill it all with
267
                    # the same value.  and adjust the array.
268
                    arr = np.array([result] * len(minutes),
269
                                   dtype=np.float64)
270
                    self._apply_all_adjustments(
271
                        data=arr,
272
                        asset=asset,
273
                        dts=minutes,
274
                        field=column
275
                    )
276
277
                    # The first value of the adjusted array is the value
278
                    # we want.
279
                    result = arr[0]
280
281
        if column != 'volume':
282
            return result * self.MINUTE_PRICE_ADJUSTMENT_FACTOR
283
        else:
284
            return result
285
286
    def _get_daily_data(self, asset, column, dt):
287
        while True:
288
            try:
289
                value = self._equity_daily_reader.spot_price(
290
                    asset, dt, column)
291
                if value != -1:
292
                    return value
293
                else:
294
                    dt -= tradingcalendar.trading_day
295
            except NoDataOnDate:
296
                return 0
297
298
    def _apply_all_adjustments(self, data, asset, dts, field,
299
                               price_adj_factor=1.0):
300
        """
301
        Internal method that applies all the necessary adjustments on the
302
        given data array.
303
304
        The adjustments are:
305
        - splits
306
        - if field != "volume":
307
            - mergers
308
            - dividends
309
            - * 0.001
310
            - any zero fields replaced with NaN
311
        - all values rounded to 3 digits after the decimal point.
312
313
        Parameters
314
        ----------
315
        data : np.array
316
            The data to be adjusted.
317
318
        asset: Asset
319
            The asset whose data is being adjusted.
320
321
        dts: pd.DateTimeIndex
322
            The list of minutes or days representing the desired window.
323
324
        field: string
325
            The field whose values are in the data array.
326
327
        price_adj_factor: float
328
            Factor with which to adjust OHLC values.
329
        Returns
330
        -------
331
        None.  The data array is modified in place.
332
        """
333
        self._apply_adjustments_to_window(
334
            self._get_adjustment_list(
335
                asset, self._splits_dict, "SPLITS"
336
            ),
337
            data,
338
            dts,
339
            field != 'volume'
340
        )
341
342
        if field != 'volume':
343
            self._apply_adjustments_to_window(
344
                self._get_adjustment_list(
345
                    asset, self._mergers_dict, "MERGERS"
346
                ),
347
                data,
348
                dts,
349
                True
350
            )
351
352
            self._apply_adjustments_to_window(
353
                self._get_adjustment_list(
354
                    asset, self._dividends_dict, "DIVIDENDS"
355
                ),
356
                data,
357
                dts,
358
                True
359
            )
360
361
            data *= price_adj_factor
362
363
            # if anything is zero, it's a missing bar, so replace it with NaN.
364
            # we only want to do this for non-volume fields, because a missing
365
            # volume should be 0.
366
            data[data == 0] = np.NaN
367
368
        np.around(data, 3, out=data)
369
370
    @staticmethod
371
    def _apply_adjustments_to_window(adjustments_list, window_data,
372
                                     dts_in_window, multiply):
373
        if len(adjustments_list) == 0:
374
            return
375
376
        # advance idx to the correct spot in the adjustments list, based on
377
        # when the window starts
378
        idx = 0
379
380
        while idx < len(adjustments_list) and dts_in_window[0] >\
381
                adjustments_list[idx][0]:
382
            idx += 1
383
384
        # if we've advanced through all the adjustments, then there's nothing
385
        # to do.
386
        if idx == len(adjustments_list):
387
            return
388
389
        while idx < len(adjustments_list):
390
            adjustment_to_apply = adjustments_list[idx]
391
392
            if adjustment_to_apply[0] > dts_in_window[-1]:
393
                break
394
395
            range_end = dts_in_window.searchsorted(adjustment_to_apply[0])
396
            if multiply:
397
                window_data[0:range_end] *= adjustment_to_apply[1]
398
            else:
399
                window_data[0:range_end] /= adjustment_to_apply[1]
400
401
            idx += 1
402
403
    def _get_adjustment_list(self, asset, adjustments_dict, table_name):
404
        """
405
        Internal method that returns a list of adjustments for the given sid.
406
407
        Parameters
408
        ----------
409
        asset : Asset
410
            The asset for which to return adjustments.
411
412
        adjustments_dict: dict
413
            A dictionary of sid -> list that is used as a cache.
414
415
        table_name: string
416
            The table that contains this data in the adjustments db.
417
418
        Returns
419
        -------
420
        adjustments: list
421
            A list of [multiplier, pd.Timestamp], earliest first
422
423
        """
424
        if self._adjustment_reader is None:
425
            return []
426
427
        sid = int(asset)
428
429
        try:
430
            adjustments = adjustments_dict[sid]
431
        except KeyError:
432
            adjustments = adjustments_dict[sid] = self._adjustment_reader.\
433
                get_adjustments_for_sid(table_name, sid)
434
435
        return adjustments
436
437
    def _check_is_currently_alive(self, asset, dt):
438
        if dt is None:
439
            dt = self.current_day
440
441
        sid = int(asset)
442
443
        if sid not in self._asset_start_dates:
444
            self._get_asset_start_date(asset)
445
446
        start_date = self._asset_start_dates[sid]
447
        if self._asset_start_dates[sid] > dt:
448
            raise NoTradeDataAvailableTooEarly(
449
                sid=sid,
450
                dt=dt,
451
                start_dt=start_date
452
            )
453
454
        end_date = self._asset_end_dates[sid]
455
        if self._asset_end_dates[sid] < dt:
456
            raise NoTradeDataAvailableTooLate(
457
                sid=sid,
458
                dt=dt,
459
                end_dt=end_date
460
            )
461
462
    def _get_asset_start_date(self, asset):
463
        self._ensure_asset_dates(asset)
464
        return self._asset_start_dates[asset]
465
466
    def _get_asset_end_date(self, asset):
467
        self._ensure_asset_dates(asset)
468
        return self._asset_end_dates[asset]
469
470
    def _ensure_asset_dates(self, asset):
471
        sid = int(asset)
472
473
        if sid not in self._asset_start_dates:
474
            self._asset_start_dates[sid] = asset.start_date
475
            self._asset_end_dates[sid] = asset.end_date
476