Completed
Pull Request — master (#906)
by Eddie
01:40
created

_apply_adjustments_to_window()   D

Complexity

Conditions 8

Size

Total Lines 32

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 8
dl 0
loc 32
rs 4
1
#
2
# Copyright 2015 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
import bcolz
17
from logbook import Logger
18
19
import numpy as np
20
import pandas as pd
21
from pandas.tslib import normalize_date
22
23
from zipline.assets import Future, Equity
24
from zipline.data.us_equity_pricing import NoDataOnDate
25
26
from zipline.utils import tradingcalendar
27
from zipline.errors import (
28
    NoTradeDataAvailableTooEarly,
29
    NoTradeDataAvailableTooLate
30
)
31
32
log = Logger('DataPortal')
33
34
# Maps every field name accepted by the public API onto the name of the
# column that actually stores the data.  Several public names are aliases
# of the same stored column (e.g. 'price' and 'close_price' both read
# 'close').
BASE_FIELDS = {
    # open
    'open': 'open',
    'open_price': 'open',
    # high / low
    'high': 'high',
    'low': 'low',
    # close
    'close': 'close',
    'close_price': 'close',
    # volume
    'volume': 'volume',
    # 'price' is served from the close column
    'price': 'close',
}
44
45
46
class DataPortal(object):
47
    def __init__(self,
48
                 env,
49
                 sim_params=None,
50
                 equity_daily_reader=None,
51
                 equity_minute_reader=None,
52
                 future_daily_reader=None,
53
                 future_minute_reader=None,
54
                 adjustment_reader=None):
55
        self.env = env
56
57
        # Internal pointers to the current dt (can be minute) and current day.
58
        # In daily mode, they point to the same thing. In minute mode, it's
59
        # useful to have separate pointers to the current day and to the
60
        # current minute.  These pointers are updated by the
61
        # AlgorithmSimulator's transform loop.
62
        self.current_dt = None
63
        self.current_day = None
64
65
        self.views = {}
66
67
        self._asset_finder = env.asset_finder
68
69
        self._carrays = {
70
            'open': {},
71
            'high': {},
72
            'low': {},
73
            'close': {},
74
            'volume': {},
75
            'sid': {},
76
            'dt': {},
77
        }
78
79
        self._adjustment_reader = adjustment_reader
80
81
        # caches of sid -> adjustment list
82
        self._splits_dict = {}
83
        self._mergers_dict = {}
84
        self._dividends_dict = {}
85
86
        # Cache of sid -> the first trading day of an asset, even if that day
87
        # is before 1/2/2002.
88
        self._asset_start_dates = {}
89
        self._asset_end_dates = {}
90
91
        # Handle extra sources, like Fetcher.
92
        self._augmented_sources_map = {}
93
        self._extra_source_df = None
94
95
        if self._sim_params is not None:
96
            self._data_frequency = self._sim_params.data_frequency
97
        else:
98
            self._data_frequency = "minute"
99
100
        self.MINUTE_PRICE_ADJUSTMENT_FACTOR = 0.001
101
102
        self._equity_daily_reader = equity_daily_reader
103
        self._equity_minute_reader = equity_minute_reader
104
        self._future_daily_reader = future_daily_reader
105
        self._future_minute_reader = future_minute_reader
106
107
    def _open_minute_file(self, field, asset):
108
        sid_str = str(int(asset))
109
110
        try:
111
            carray = self._carrays[field][sid_str]
112
        except KeyError:
113
            carray = self._carrays[field][sid_str] = \
114
                self._get_ctable(asset)[field]
115
116
        return carray
117
118
    def _get_ctable(self, asset):
        """
        Open, read-only, the bcolz table that stores minute data for
        ``asset``.

        Futures are resolved through the future minute reader and equities
        through the equity minute reader.  If the chosen reader supplies a
        ``sid_path_func`` it is called with ``(rootdir, sid)`` to build the
        path; otherwise the path is ``"<rootdir>/<sid>.bcolz"``.

        Parameters
        ----------
        asset : Asset
            The asset whose table should be opened.

        Returns
        -------
        The object returned by ``bcolz.open(path, mode='r')``.

        Raises
        ------
        TypeError
            If ``asset`` is neither a Future nor an Equity.  (Previously
            this case failed with an UnboundLocalError on ``path``.)
        """
        sid = int(asset)

        # Select the reader once so the path logic is written a single time.
        if isinstance(asset, Future):
            reader = self._future_minute_reader
        elif isinstance(asset, Equity):
            reader = self._equity_minute_reader
        else:
            raise TypeError(
                "Cannot locate minute data for asset of type {0}".format(
                    type(asset).__name__))

        if reader.sid_path_func is not None:
            path = reader.sid_path_func(reader.rootdir, sid)
        else:
            path = "{0}/{1}.bcolz".format(reader.rootdir, sid)

        return bcolz.open(path, mode='r')
139
140
    def get_spot_value(self, asset, field, dt=None):
        """
        Public API method returning one scalar: the value of ``field`` for
        ``asset`` at ``dt``, or at the portal's current day/dt when ``dt``
        is not supplied.

        Parameters
        ----------
        asset : Asset or int
            The asset whose data is desired.  A plain int is resolved to an
            asset through the asset finder.

        field : string
            The desired field of the asset.  Valid values are "open",
            "open_price", "high", "low", "close", "close_price", "volume",
            and "price".

        dt : pd.Timestamp
            (Optional) The timestamp for the desired value.

        Returns
        -------
        The value of the desired field at the desired time.

        Raises
        ------
        KeyError
            If ``field`` is not one of the valid values.
        NoTradeDataAvailableTooEarly, NoTradeDataAvailableTooLate
            Propagated from the liveness check when the date falls outside
            the asset's start/end dates.
        """
        if field not in BASE_FIELDS:
            raise KeyError("Invalid column: " + str(field))
        column = BASE_FIELDS[field]

        if isinstance(asset, int):
            asset = self._asset_finder.retrieve_asset(asset)

        self._check_is_currently_alive(asset, dt)

        # Daily mode: look up by normalized day and finish early.
        if self._data_frequency == "daily":
            session = normalize_date(dt or self.current_day)
            return self._get_daily_data(asset, column, session)

        # Minute mode: futures and everything else use different lookups.
        moment = dt or self.current_dt
        if isinstance(asset, Future):
            lookup = self._get_minute_spot_value_future
        else:
            lookup = self._get_minute_spot_value
        return lookup(asset, column, moment)
186
187
    def _get_minute_spot_value_future(self, asset, column, dt):
188
        # Futures bcolz files have 1440 bars per day (24 hours), 7 days a week.
189
        # The file attributes contain the "start_dt" and "last_dt" fields,
190
        # which represent the time period for this bcolz file.
191
192
        # The start_dt is midnight of the first day that this future started
193
        # trading.
194
195
        # figure out the # of minutes between dt and this asset's start_dt
196
        start_date = self._get_asset_start_date(asset)
197
        minute_offset = int((dt - start_date).total_seconds() / 60)
198
199
        if minute_offset < 0:
200
            # asking for a date that is before the asset's start date, no dice
201
            return 0.0
202
203
        # then just index into the bcolz carray at that offset
204
        carray = self._open_minute_file(column, asset)
205
        result = carray[minute_offset]
206
207
        # if there's missing data, go backwards until we run out of file
208
        while result == 0 and minute_offset > 0:
209
            minute_offset -= 1
210
            result = carray[minute_offset]
211
212
        if column != 'volume':
213
            return result * self.MINUTE_PRICE_ADJUSTMENT_FACTOR
214
        else:
215
            return result
216
217
    def _get_minute_spot_value(self, asset, column, dt):
218
        # if dt is before the first market minute, minute_index
219
        # will be 0.  if it's after the last market minute, it'll
220
        # be len(minutes_for_day)
221
        given_day = pd.Timestamp(dt.date(), tz='utc')
222
        day_index = self._equity_minute_reader.trading_days.searchsorted(
223
            given_day)
224
225
        # if dt is before the first market minute, minute_index
226
        # will be 0.  if it's after the last market minute, it'll
227
        # be len(minutes_for_day)
228
        minute_index = self.env.market_minutes_for_day(given_day).\
229
            searchsorted(dt)
230
231
        minute_offset_to_use = (day_index * 390) + minute_index
232
233
        carray = self._equity_minute_reader._open_minute_file(column, asset)
234
        result = carray[minute_offset_to_use]
235
236
        if result == 0:
237
            # if the given minute doesn't have data, we need to seek
238
            # backwards until we find data. This makes the data
239
            # forward-filled.
240
241
            # get this asset's start date, so that we don't look before it.
242
            start_date = self._get_asset_start_date(asset)
243
            start_date_idx = self._equity_minute_reader.trading_days.\
244
                searchsorted(start_date)
245
            start_day_offset = start_date_idx * 390
246
247
            original_start = minute_offset_to_use
248
249
            while result == 0 and minute_offset_to_use > start_day_offset:
250
                minute_offset_to_use -= 1
251
                result = carray[minute_offset_to_use]
252
253
            # once we've found data, we need to check whether it needs
254
            # to be adjusted.
255
            if result != 0:
256
                minutes = self.env.market_minute_window(
257
                    start=(dt or self.current_dt),
258
                    count=(original_start - minute_offset_to_use + 1),
259
                    step=-1
260
                ).order()
261
262
                # only need to check for adjustments if we've gone back
263
                # far enough to cross the day boundary.
264
                if minutes[0].date() != minutes[-1].date():
265
                    # create a np array of size minutes, fill it all with
266
                    # the same value.  and adjust the array.
267
                    arr = np.array([result] * len(minutes),
268
                                   dtype=np.float64)
269
                    self._apply_all_adjustments(
270
                        data=arr,
271
                        asset=asset,
272
                        dts=minutes,
273
                        field=column
274
                    )
275
276
                    # The first value of the adjusted array is the value
277
                    # we want.
278
                    result = arr[0]
279
280
        if column != 'volume':
281
            return result * self.MINUTE_PRICE_ADJUSTMENT_FACTOR
282
        else:
283
            return result
284
285
    def _get_daily_data(self, asset, column, dt):
286
        while True:
287
            try:
288
                value = self._equity_daily_reader.spot_price(
289
                    asset, dt, column)
290
                if value != -1:
291
                    return value
292
                else:
293
                    dt -= tradingcalendar.trading_day
294
            except NoDataOnDate:
295
                return 0
296
297
    def _apply_all_adjustments(self, data, asset, dts, field,
298
                               price_adj_factor=1.0):
299
        """
300
        Internal method that applies all the necessary adjustments on the
301
        given data array.
302
303
        The adjustments are:
304
        - splits
305
        - if field != "volume":
306
            - mergers
307
            - dividends
308
            - * 0.001
309
            - any zero fields replaced with NaN
310
        - all values rounded to 3 digits after the decimal point.
311
312
        Parameters
313
        ----------
314
        data : np.array
315
            The data to be adjusted.
316
317
        asset: Asset
318
            The asset whose data is being adjusted.
319
320
        dts: pd.DateTimeIndex
321
            The list of minutes or days representing the desired window.
322
323
        field: string
324
            The field whose values are in the data array.
325
326
        price_adj_factor: float
327
            Factor with which to adjust OHLC values.
328
        Returns
329
        -------
330
        None.  The data array is modified in place.
331
        """
332
        self._apply_adjustments_to_window(
333
            self._get_adjustment_list(
334
                asset, self._splits_dict, "SPLITS"
335
            ),
336
            data,
337
            dts,
338
            field != 'volume'
339
        )
340
341
        if field != 'volume':
342
            self._apply_adjustments_to_window(
343
                self._get_adjustment_list(
344
                    asset, self._mergers_dict, "MERGERS"
345
                ),
346
                data,
347
                dts,
348
                True
349
            )
350
351
            self._apply_adjustments_to_window(
352
                self._get_adjustment_list(
353
                    asset, self._dividends_dict, "DIVIDENDS"
354
                ),
355
                data,
356
                dts,
357
                True
358
            )
359
360
            data *= price_adj_factor
361
362
            # if anything is zero, it's a missing bar, so replace it with NaN.
363
            # we only want to do this for non-volume fields, because a missing
364
            # volume should be 0.
365
            data[data == 0] = np.NaN
366
367
        np.around(data, 3, out=data)
368
369
    @staticmethod
370
    def _apply_adjustments_to_window(adjustments_list, window_data,
371
                                     dts_in_window, multiply):
372
        if len(adjustments_list) == 0:
373
            return
374
375
        # advance idx to the correct spot in the adjustments list, based on
376
        # when the window starts
377
        idx = 0
378
379
        while idx < len(adjustments_list) and dts_in_window[0] >\
380
                adjustments_list[idx][0]:
381
            idx += 1
382
383
        # if we've advanced through all the adjustments, then there's nothing
384
        # to do.
385
        if idx == len(adjustments_list):
386
            return
387
388
        while idx < len(adjustments_list):
389
            adjustment_to_apply = adjustments_list[idx]
390
391
            if adjustment_to_apply[0] > dts_in_window[-1]:
392
                break
393
394
            range_end = dts_in_window.searchsorted(adjustment_to_apply[0])
395
            if multiply:
396
                window_data[0:range_end] *= adjustment_to_apply[1]
397
            else:
398
                window_data[0:range_end] /= adjustment_to_apply[1]
399
400
            idx += 1
401
402
    def _get_adjustment_list(self, asset, adjustments_dict, table_name):
403
        """
404
        Internal method that returns a list of adjustments for the given sid.
405
406
        Parameters
407
        ----------
408
        asset : Asset
409
            The asset for which to return adjustments.
410
411
        adjustments_dict: dict
412
            A dictionary of sid -> list that is used as a cache.
413
414
        table_name: string
415
            The table that contains this data in the adjustments db.
416
417
        Returns
418
        -------
419
        adjustments: list
420
            A list of [multiplier, pd.Timestamp], earliest first
421
422
        """
423
        if self._adjustment_reader is None:
424
            return []
425
426
        sid = int(asset)
427
428
        try:
429
            adjustments = adjustments_dict[sid]
430
        except KeyError:
431
            adjustments = adjustments_dict[sid] = self._adjustment_reader.\
432
                get_adjustments_for_sid(table_name, sid)
433
434
        return adjustments
435
436
    def _check_is_currently_alive(self, asset, dt):
437
        if dt is None:
438
            dt = self.current_day
439
440
        sid = int(asset)
441
442
        if sid not in self._asset_start_dates:
443
            self._get_asset_start_date(asset)
444
445
        start_date = self._asset_start_dates[sid]
446
        if self._asset_start_dates[sid] > dt:
447
            raise NoTradeDataAvailableTooEarly(
448
                sid=sid,
449
                dt=dt,
450
                start_dt=start_date
451
            )
452
453
        end_date = self._asset_end_dates[sid]
454
        if self._asset_end_dates[sid] < dt:
455
            raise NoTradeDataAvailableTooLate(
456
                sid=sid,
457
                dt=dt,
458
                end_dt=end_date
459
            )
460
461
    def _get_asset_start_date(self, asset):
462
        self._ensure_asset_dates(asset)
463
        return self._asset_start_dates[asset]
464
465
    def _get_asset_end_date(self, asset):
466
        self._ensure_asset_dates(asset)
467
        return self._asset_end_dates[asset]
468
469
    def _ensure_asset_dates(self, asset):
470
        sid = int(asset)
471
472
        if sid not in self._asset_start_dates:
473
            self._asset_start_dates[sid] = asset.start_date
474
            self._asset_end_dates[sid] = asset.end_date
475