zipline.finance.TradingEnvironment (rating: F)

Complexity

Total Complexity 67

Size/Duplication

Total Lines 405
Duplicated Lines 0 %
Metric                              Value
dl (duplicated lines)               0
loc (lines of code)                 405
rs                                  3.0612
wmc (weighted methods per class)    67

27 Methods

Rating  Name                          Duplication  Size  Complexity
A       is_market_hours()             0            6     2
A       opens_in_range()              0            2     1
A       get_open_and_close()          0            4     1
B       market_minute_window()        0            40    6
D       write_data()                  0            66    10
A       utc_dt_in_exchange()          0            2     1
A       previous_market_minute()      0            17    4
B       __init__()                    0            61    6
A       normalize_date()              0            3     1
A       previous_trading_day()        0            10    3
A       exchange_dt_in_utc()          0            2     1
B       add_trading_days()            0            28    5
A       next_trading_day()            0            10    3
A       previous_open_and_close()     0            13    2
A       _write_data_lists()           0            4     1
A       trading_day_distance()        0            14    3
A       next_open_and_close()         0            14    2
A       days_in_range()               0            4     1
A       market_minutes_for_day()      0            3     1
A       closes_in_range()             0            2     1
A       next_market_minute()          0            17    4
A       open_close_window()           0            13    1
A       get_index()                   0            10    2
A       _write_data_dicts()           0            4     1
A       minutes_for_days_in_range()   0            15    2
A       is_trading_day()              0            3     1
A       _write_data_dataframes()      0            4     1
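The two totals reported above (Total Complexity 67 and wmc 67) equal the sum of the Complexity column, which suggests wmc is computed as the usual weighted-methods-per-class sum with each method's complexity as its weight. The short Python sketch below checks that arithmetic under this assumption; the dictionary and function names are illustrative and are not part of the analysis tool.

```python
# Per-method complexity values transcribed from the method table above,
# in table order.
METHOD_COMPLEXITY = {
    "is_market_hours": 2,
    "opens_in_range": 1,
    "get_open_and_close": 1,
    "market_minute_window": 6,
    "write_data": 10,
    "utc_dt_in_exchange": 1,
    "previous_market_minute": 4,
    "__init__": 6,
    "normalize_date": 1,
    "previous_trading_day": 3,
    "exchange_dt_in_utc": 1,
    "add_trading_days": 5,
    "next_trading_day": 3,
    "previous_open_and_close": 2,
    "_write_data_lists": 1,
    "trading_day_distance": 3,
    "next_open_and_close": 2,
    "days_in_range": 1,
    "market_minutes_for_day": 1,
    "closes_in_range": 1,
    "next_market_minute": 4,
    "open_close_window": 1,
    "get_index": 2,
    "_write_data_dicts": 1,
    "minutes_for_days_in_range": 2,
    "is_trading_day": 1,
    "_write_data_dataframes": 1,
}


def weighted_methods_per_class(complexities):
    """Sum the per-method complexities (WMC with complexity as weight)."""
    return sum(complexities.values())


def most_complex(complexities, n=3):
    """Return the n most complex methods, highest first (stable on ties)."""
    ranked = sorted(complexities.items(), key=lambda kv: kv[1], reverse=True)
    return ranked[:n]


if __name__ == "__main__":
    print(weighted_methods_per_class(METHOD_COMPLEXITY))  # 67
    print(most_complex(METHOD_COMPLEXITY))
```

The same sum shows where the weight is concentrated: write_data() alone contributes 10 of the 67 points, followed by market_minute_window() and __init__() with 6 each.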

How to fix: Complexity

Complex Class

Complex classes like zipline.finance.TradingEnvironment often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields or methods that share a common prefix or suffix.
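The prefix/suffix heuristic can be mechanized. As a sketch, assuming that splitting names on underscores is a good enough notion of "prefix" and "suffix", the following Python groups the 27 method names from the table by their first or last token. `group_by_token` is a hypothetical helper, not a feature of the analysis tool.

```python
from collections import defaultdict

# Method names from the table above, in table order.
METHODS = [
    "is_market_hours", "opens_in_range", "get_open_and_close",
    "market_minute_window", "write_data", "utc_dt_in_exchange",
    "previous_market_minute", "__init__", "normalize_date",
    "previous_trading_day", "exchange_dt_in_utc", "add_trading_days",
    "next_trading_day", "previous_open_and_close", "_write_data_lists",
    "trading_day_distance", "next_open_and_close", "days_in_range",
    "market_minutes_for_day", "closes_in_range", "next_market_minute",
    "open_close_window", "get_index", "_write_data_dicts",
    "minutes_for_days_in_range", "is_trading_day", "_write_data_dataframes",
]


def group_by_token(names, position):
    """Group names by one underscore-separated token.

    position=0 groups by prefix and position=-1 by suffix. Leading and
    trailing underscores are stripped first, so "_write_data_lists" has
    the prefix "write". Only groups with two or more members are kept.
    """
    groups = defaultdict(list)
    for name in names:
        tokens = name.strip("_").split("_")
        groups[tokens[position]].append(name)
    return {token: members for token, members in groups.items()
            if len(members) > 1}


if __name__ == "__main__":
    for label, pos in (("prefix", 0), ("suffix", -1)):
        for token, members in sorted(group_by_token(METHODS, pos).items()):
            print(label, token, members)
```

For this class, the prefix groups single out the write_data()/_write_data_*() family and the next_*/previous_* navigation methods, while the suffix groups surface the *_in_range and *_open_and_close families. These are candidate components only; whether they are cohesive still depends on which fields they share.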

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster to apply.
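As an illustration of Extract Class applied to this listing, the sketch below moves the trading-day navigation methods into a separate component that depends only on a DatetimeIndex, and has a slimmed-down environment delegate to it. `TradingDays`, `SlimTradingEnvironment` and `_to_utc_midnight` are hypothetical names. The lookups use `DatetimeIndex.searchsorted` in place of the original day-by-day loops, which is a deliberate substitution, so treat this as a sketch rather than a drop-in replacement.

```python
import pandas as pd


def _to_utc_midnight(dt):
    """Hypothetical helper: coerce dt to a UTC Timestamp at midnight."""
    ts = pd.Timestamp(dt)
    if ts.tzinfo is None:
        ts = ts.tz_localize("UTC")
    else:
        ts = ts.tz_convert("UTC")
    return ts.normalize()


class TradingDays(object):
    """Hypothetical component holding only the trading-day navigation."""

    def __init__(self, trading_days):
        # Assumed: a sorted DatetimeIndex of UTC-midnight sessions, like
        # TradingEnvironment.trading_days.
        self.trading_days = pd.DatetimeIndex(trading_days)

    def is_trading_day(self, dt):
        return _to_utc_midnight(dt) in self.trading_days

    def next_trading_day(self, dt):
        # First session strictly after dt's date, or None past the end.
        day = _to_utc_midnight(dt)
        idx = self.trading_days.searchsorted(day, side="right")
        return self.trading_days[idx] if idx < len(self.trading_days) else None

    def previous_trading_day(self, dt):
        # Last session strictly before dt's date, or None before the start.
        day = _to_utc_midnight(dt)
        idx = self.trading_days.searchsorted(day, side="left") - 1
        return self.trading_days[idx] if idx >= 0 else None

    def get_index(self, dt):
        # Index of dt's session, or of the preceding session otherwise.
        day = _to_utc_midnight(dt)
        return int(self.trading_days.searchsorted(day, side="right")) - 1


class SlimTradingEnvironment(object):
    """Hypothetical environment that delegates to the extracted class."""

    def __init__(self, trading_days):
        self.calendar = TradingDays(trading_days)

    # Thin delegating methods keep the existing public names for callers.
    def is_trading_day(self, dt):
        return self.calendar.is_trading_day(dt)

    def next_trading_day(self, dt):
        return self.calendar.next_trading_day(dt)

    def previous_trading_day(self, dt):
        return self.calendar.previous_trading_day(dt)

    def get_index(self, dt):
        return self.calendar.get_index(dt)
```

Because the delegating methods keep the same names, existing callers would not need to change. The other families surfaced by the prefix and suffix grouping, such as the write_data()/_write_data_*() methods, could be extracted the same way.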

#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect
import logbook
import datetime

import pandas as pd
import numpy as np
from six import string_types
from sqlalchemy import create_engine

from zipline.data.loader import load_market_data
from zipline.utils import tradingcalendar
from zipline.assets import AssetFinder
from zipline.assets.asset_writer import (
    AssetDBWriterFromList,
    AssetDBWriterFromDictionary,
    AssetDBWriterFromDataFrame)
from zipline.errors import (
    NoFurtherDataError
)


log = logbook.Logger('Trading')


# The financial simulations in zipline depend on information
# about the benchmark index and the risk-free rates of return.
# The benchmark index defines the benchmark returns used in
# the calculation of performance metrics such as alpha/beta. Many
# components, including risk, performance, transforms, and
# batch_transforms, need access to a calendar of trading days and
# market hours. The TradingEnvironment maintains two timekeeping
# facilities:
#   - a DatetimeIndex of trading days for calendar calculations
#   - a timezone name, which should be local to the exchange
#   hosting the benchmark index. All dates are normalized to UTC
#   for serialization and storage, and the timezone is used to
#   ensure proper rollover through daylight saving time and so on.
#
# User code will not normally need to use TradingEnvironment
# directly. If you are extending zipline's core financial
# components and need to use the environment, you must import the module and
# build a new TradingEnvironment object, then pass that TradingEnvironment as
# the 'env' arg to your TradingAlgorithm.

class TradingEnvironment(object):

    # Token used as a substitute for pickling objects that contain a
    # reference to a TradingEnvironment
    PERSISTENT_TOKEN = "<TradingEnvironment>"

    def __init__(
        self,
        load=None,
        bm_symbol='^GSPC',
        exchange_tz="US/Eastern",
        max_date=None,
        env_trading_calendar=tradingcalendar,
        asset_db_path=':memory:'
    ):
        """
        @load is a function that returns benchmark_returns and
        treasury_curves. The treasury_curves are expected to be a DataFrame
        with an index of dates and columns of the curve names, e.g.
        '10year', '1month', etc.
        """
        self.trading_day = env_trading_calendar.trading_day.copy()

        # `tc_td` is short for "trading calendar trading days"
        tc_td = env_trading_calendar.trading_days

        if max_date:
            self.trading_days = tc_td[tc_td <= max_date].copy()
        else:
            self.trading_days = tc_td.copy()

        self.first_trading_day = self.trading_days[0]
        self.last_trading_day = self.trading_days[-1]

        self.early_closes = env_trading_calendar.get_early_closes(
            self.first_trading_day, self.last_trading_day)

        self.open_and_closes = env_trading_calendar.open_and_closes.loc[
            self.trading_days]

        self.bm_symbol = bm_symbol
        if not load:
            load = load_market_data

        self.benchmark_returns, self.treasury_curves = \
            load(self.trading_day, self.trading_days, self.bm_symbol)

        if max_date:
            tr_c = self.treasury_curves
            # Mask the treasury curves down to the current date.
            # In the case of live trading, the last date in the treasury
            # curves would be the day before the date considered to be
            # 'today'.
            self.treasury_curves = tr_c[tr_c.index <= max_date]

        self.exchange_tz = exchange_tz

        if isinstance(asset_db_path, string_types):
            asset_db_path = 'sqlite:///%s' % asset_db_path
            self.engine = engine = create_engine(asset_db_path)
            AssetDBWriterFromDictionary().init_db(engine)
        else:
            self.engine = engine = asset_db_path

        if engine is not None:
            self.asset_finder = AssetFinder(engine)
        else:
            self.asset_finder = None

    def write_data(self,
                   engine=None,
                   equities_data=None,
                   futures_data=None,
                   exchanges_data=None,
                   root_symbols_data=None,
                   equities_df=None,
                   futures_df=None,
                   exchanges_df=None,
                   root_symbols_df=None,
                   equities_identifiers=None,
                   futures_identifiers=None,
                   exchanges_identifiers=None,
                   root_symbols_identifiers=None,
                   allow_sid_assignment=True):
        """Write the supplied data to the database.

        Parameters
        ----------
        engine: sqlalchemy engine, optional
            If provided, replaces self.engine as the database written to
        equities_data: dict, optional
            A dictionary of equity metadata
        futures_data: dict, optional
            A dictionary of futures metadata
        exchanges_data: dict, optional
            A dictionary of exchange metadata
        root_symbols_data: dict, optional
            A dictionary of root symbol metadata
        equities_df: pandas.DataFrame, optional
            A pandas.DataFrame of equity metadata
        futures_df: pandas.DataFrame, optional
            A pandas.DataFrame of futures metadata
        exchanges_df: pandas.DataFrame, optional
            A pandas.DataFrame of exchange metadata
        root_symbols_df: pandas.DataFrame, optional
            A pandas.DataFrame of root symbol metadata
        equities_identifiers: list, optional
            A list of equity identifiers (sids, symbols, Assets)
        futures_identifiers: list, optional
            A list of futures identifiers (sids, symbols, Assets)
        exchanges_identifiers: list, optional
            A list of exchange identifiers (ids or names)
        root_symbols_identifiers: list, optional
            A list of root symbol identifiers (ids or symbols)
        allow_sid_assignment: bool, optional
            Passed to AssetDBWriterFromList.write_all when writing the
            *_identifiers lists. Defaults to True.
        """
        if engine:
            self.engine = engine

        # If any pandas.DataFrame data has been provided,
        # write it to the database.
        if (equities_df is not None or futures_df is not None or
                exchanges_df is not None or root_symbols_df is not None):
            self._write_data_dataframes(equities_df, futures_df,
                                        exchanges_df, root_symbols_df)

        if (equities_data is not None or futures_data is not None or
                exchanges_data is not None or root_symbols_data is not None):
            self._write_data_dicts(equities_data, futures_data,
                                   exchanges_data, root_symbols_data)

        # These could be lists or other iterables such as a pandas.Index.
        # For simplicity, don't check whether data has been provided.
        self._write_data_lists(equities_identifiers,
                               futures_identifiers,
                               exchanges_identifiers,
                               root_symbols_identifiers,
                               allow_sid_assignment=allow_sid_assignment)

    def _write_data_lists(self, equities=None, futures=None, exchanges=None,
                          root_symbols=None, allow_sid_assignment=True):
        AssetDBWriterFromList(equities, futures, exchanges, root_symbols)\
            .write_all(self.engine, allow_sid_assignment=allow_sid_assignment)

    def _write_data_dicts(self, equities=None, futures=None, exchanges=None,
                          root_symbols=None):
        AssetDBWriterFromDictionary(equities, futures, exchanges, root_symbols)\
            .write_all(self.engine)

    def _write_data_dataframes(self, equities=None, futures=None,
                               exchanges=None, root_symbols=None):
        AssetDBWriterFromDataFrame(equities, futures, exchanges, root_symbols)\
            .write_all(self.engine)

    def normalize_date(self, test_date):
        test_date = pd.Timestamp(test_date, tz='UTC')
        return pd.tseries.tools.normalize_date(test_date)

    def utc_dt_in_exchange(self, dt):
        return pd.Timestamp(dt).tz_convert(self.exchange_tz)

    def exchange_dt_in_utc(self, dt):
        return pd.Timestamp(dt, tz=self.exchange_tz).tz_convert('UTC')

    def is_market_hours(self, test_date):
        if not self.is_trading_day(test_date):
            return False

        mkt_open, mkt_close = self.get_open_and_close(test_date)
        return test_date >= mkt_open and test_date <= mkt_close

    def is_trading_day(self, test_date):
        dt = self.normalize_date(test_date)
        return (dt in self.trading_days)

    def next_trading_day(self, test_date):
        dt = self.normalize_date(test_date)
        delta = datetime.timedelta(days=1)

        while dt <= self.last_trading_day:
            dt += delta
            if dt in self.trading_days:
                return dt

        return None

    def previous_trading_day(self, test_date):
        dt = self.normalize_date(test_date)
        delta = datetime.timedelta(days=-1)

        while self.first_trading_day < dt:
            dt += delta
            if dt in self.trading_days:
                return dt

        return None

    def add_trading_days(self, n, date):
        """
        Adds n trading days to date. If this would fall outside of the
        trading calendar, a NoFurtherDataError is raised, except when n
        is 1 or -1: those cases delegate to next_trading_day and
        previous_trading_day, which return None instead.

        :Arguments:
            n : int
                The number of trading days to add to date; this can be
                positive or negative.
            date : datetime
                The date to add to.

        :Returns:
            new_date : datetime
                n trading days added to date.
        """
        if n == 1:
            return self.next_trading_day(date)
        if n == -1:
            return self.previous_trading_day(date)

        idx = self.get_index(date) + n
        if idx < 0 or idx >= len(self.trading_days):
            raise NoFurtherDataError(
                msg='Cannot add %d days to %s' % (n, date)
            )

        return self.trading_days[idx]

    def days_in_range(self, start, end):
        mask = ((self.trading_days >= start) &
                (self.trading_days <= end))
        return self.trading_days[mask]

    def opens_in_range(self, start, end):
        return self.open_and_closes.market_open.loc[start:end]

    def closes_in_range(self, start, end):
        return self.open_and_closes.market_close.loc[start:end]

    def minutes_for_days_in_range(self, start, end):
        """
        Get all market minutes for the days between start and end, inclusive.
        """
        start_date = self.normalize_date(start)
        end_date = self.normalize_date(end)

        all_minutes = []
        for day in self.days_in_range(start_date, end_date):
            day_minutes = self.market_minutes_for_day(day)
            all_minutes.append(day_minutes)

        # Concatenate the minutes of all days in the range. Whole days are
        # returned; minutes are not truncated at start or end.
        return pd.DatetimeIndex(
            np.concatenate(all_minutes), copy=False, tz='UTC',
        )

    def next_open_and_close(self, start_date):
        """
        Given the start_date, returns the market open and close of the
        first trading day strictly after start_date.
        """
        next_open = self.next_trading_day(start_date)

        if next_open is None:
            raise NoFurtherDataError(
                msg=("Attempt to backtest beyond available history. "
                     "Last known date: %s" % self.last_trading_day)
            )

        return self.get_open_and_close(next_open)

    def previous_open_and_close(self, start_date):
        """
        Given the start_date, returns the market open and close of the
        last trading day strictly before start_date.
        """
        previous = self.previous_trading_day(start_date)

        if previous is None:
            raise NoFurtherDataError(
                msg=("Attempt to backtest beyond available history. "
                     "First known date: %s" % self.first_trading_day)
            )
        return self.get_open_and_close(previous)

    def next_market_minute(self, start):
        """
        Get the next market minute after @start. This is either the immediate
        next minute, the open of the same day if @start is before the market
        open on a trading day, or the open of the next market day after @start.
        """
        if self.is_trading_day(start):
            market_open, market_close = self.get_open_and_close(start)
            # If start is before market open on a trading day, return
            # market open.
            if start < market_open:
                return market_open
            # If start is during trading hours, then get the next minute.
            elif start < market_close:
                return start + datetime.timedelta(minutes=1)
        # If start is not on a trading day, or is at or after the market
        # close, then return the open of the *next* trading day.
        return self.next_open_and_close(start)[0]

    def previous_market_minute(self, start):
        """
        Get the previous market minute before @start. This is either the
        immediate previous minute, the close of the same day if @start is
        after the close on a trading day, or the close of the market day
        before @start.
        """
        if self.is_trading_day(start):
            market_open, market_close = self.get_open_and_close(start)
            # If start is after the market close, return market close.
            if start > market_close:
                return market_close
            # If start is during trading hours, then get the previous minute.
            if start > market_open:
                return start - datetime.timedelta(minutes=1)
        # If start is not on a trading day, or is at or before the market
        # open, then return the close of the *previous* trading day.
        return self.previous_open_and_close(start)[1]

    def get_open_and_close(self, day):
        index = self.open_and_closes.index.get_loc(day.date())
        todays_minutes = self.open_and_closes.values[index]
        return todays_minutes[0], todays_minutes[1]

    def market_minutes_for_day(self, stamp):
        market_open, market_close = self.get_open_and_close(stamp)
        return pd.date_range(market_open, market_close, freq='T')

    def open_close_window(self, start, count, offset=0, step=1):
        """
        Return a DataFrame containing `count` market opens and closes,
        beginning `offset` trading days after `start` and continuing
        `step` trading days at a time.
        """
        # TODO: Correctly handle end of data.
        start_idx = self.get_index(start) + offset
        stop_idx = start_idx + (count * step)

        index = np.arange(start_idx, stop_idx, step)

        return self.open_and_closes.iloc[index]

    def market_minute_window(self, start, count, step=1):
        """
        Return a DatetimeIndex containing `count` market minutes, starting with
        `start` and continuing `step` minutes at a time.
        """
        if not self.is_market_hours(start):
            raise ValueError("market_minute_window starting at "
                             "non-market time {minute}".format(minute=start))

        all_minutes = []

        current_day_minutes = self.market_minutes_for_day(start)
        first_minute_idx = current_day_minutes.searchsorted(start)
        minutes_in_range = current_day_minutes[first_minute_idx::step]

        # Build up list of lists of days' market minutes until we have count
        # minutes stored altogether.
        while True:

            if len(minutes_in_range) >= count:
                # Truncate off extra minutes
                minutes_in_range = minutes_in_range[:count]

            all_minutes.append(minutes_in_range)
            count -= len(minutes_in_range)
            if count <= 0:
                break

            if step > 0:
                start, _ = self.next_open_and_close(start)
                current_day_minutes = self.market_minutes_for_day(start)
            else:
                _, start = self.previous_open_and_close(start)
                current_day_minutes = self.market_minutes_for_day(start)

            minutes_in_range = current_day_minutes[::step]

        # Concatenate all the accumulated minutes.
        return pd.DatetimeIndex(
            np.concatenate(all_minutes), copy=False, tz='UTC',
        )

    def trading_day_distance(self, first_date, second_date):
        first_date = self.normalize_date(first_date)
        second_date = self.normalize_date(second_date)

        # TODO: May be able to replace the following with searchsorted.
        # Find the leftmost trading day greater than or equal to first_date.
        i = bisect.bisect_left(self.trading_days, first_date)
        if i == len(self.trading_days):  # nothing found
            return None
        j = bisect.bisect_left(self.trading_days, second_date)
        if j == len(self.trading_days):
            return None

        return j - i

    def get_index(self, dt):
        """
        Return the index of the given @dt, or the index of the preceding
        trading day if the given dt is not in the trading calendar.
        """
        ndt = self.normalize_date(dt)
        if ndt in self.trading_days:
            return self.trading_days.searchsorted(ndt)
        else:
            return self.trading_days.searchsorted(ndt) - 1


class SimulationParameters(object):
    def __init__(self, period_start, period_end,
                 capital_base=10e3,
                 emission_rate='daily',
                 data_frequency='daily',
                 env=None):

        self.period_start = period_start
        self.period_end = period_end
        self.capital_base = capital_base

        self.emission_rate = emission_rate
        self.data_frequency = data_frequency

        # copied to algorithm's environment for runtime access
        self.arena = 'backtest'

        if env is not None:
            self.update_internal_from_env(env=env)

    def update_internal_from_env(self, env):

        assert self.period_start <= self.period_end, \
            "Period start falls after period end."

        assert self.period_start <= env.last_trading_day, \
            "Period start falls after the last known trading day."
        assert self.period_end >= env.first_trading_day, \
            "Period end falls before the first known trading day."

        self.first_open = self._calculate_first_open(env)
        self.last_close = self._calculate_last_close(env)

        start_index = env.get_index(self.first_open)
        end_index = env.get_index(self.last_close)

        # take an inclusive slice of the environment's
        # trading_days.
        self.trading_days = env.trading_days[start_index:end_index + 1]

    def _calculate_first_open(self, env):
        """
        Returns the market open of the first trading day on or after
        self.period_start.
        """
        first_open = self.period_start
        one_day = datetime.timedelta(days=1)

        while not env.is_trading_day(first_open):
            first_open = first_open + one_day

        mkt_open, _ = env.get_open_and_close(first_open)
        return mkt_open

    def _calculate_last_close(self, env):
        """
        Returns the market close of the last trading day on or before
        self.period_end.
        """
        last_close = self.period_end
        one_day = datetime.timedelta(days=1)

        while not env.is_trading_day(last_close):
            last_close = last_close - one_day

        _, mkt_close = env.get_open_and_close(last_close)
        return mkt_close

    @property
    def days_in_period(self):
        """Return the number of trading days in the period [start, end]."""
        return len(self.trading_days)

    def __repr__(self):
        return """
{class_name}(
    period_start={period_start},
    period_end={period_end},
    capital_base={capital_base},
    data_frequency={data_frequency},
    emission_rate={emission_rate},
    first_open={first_open},
    last_close={last_close})\
""".format(class_name=self.__class__.__name__,
           period_start=self.period_start,
           period_end=self.period_end,
           capital_base=self.capital_base,
           data_frequency=self.data_frequency,
           emission_rate=self.emission_rate,
           first_open=self.first_open,
           last_close=self.last_close)


def noop_load(*args, **kwargs):
    """
    A function that can be substituted in as the load argument of a
    TradingEnvironment to prevent it from loading benchmark returns and
    treasury curves.

    Accepts any arguments, but returns only a tuple of Nones regardless
    of input.
    """
    return None, None
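The TODO in trading_day_distance() suggests replacing the two bisect.bisect_left calls with searchsorted. DatetimeIndex.searchsorted with side="left" returns the same leftmost insertion point as bisect_left, so the substitution should preserve behavior. A standalone sketch, assuming trading_days is a sorted, UTC-midnight DatetimeIndex as in the environment; `_normalize_utc` and the free-function form of `trading_day_distance` are illustrative, not zipline API.

```python
import pandas as pd


def _normalize_utc(dt):
    """Illustrative helper: coerce dt to a UTC Timestamp at midnight."""
    ts = pd.Timestamp(dt)
    if ts.tzinfo is None:
        ts = ts.tz_localize("UTC")
    else:
        ts = ts.tz_convert("UTC")
    return ts.normalize()


def trading_day_distance(trading_days, first_date, second_date):
    """Number of trading days from first_date to second_date.

    Mirrors the bisect-based method: each date maps to the leftmost
    session on or after it, and None is returned if either date falls
    after the last session.
    """
    n = len(trading_days)
    # side="left" gives the same insertion point as bisect.bisect_left.
    i = trading_days.searchsorted(_normalize_utc(first_date), side="left")
    if i == n:
        return None
    j = trading_days.searchsorted(_normalize_utc(second_date), side="left")
    if j == n:
        return None
    return int(j - i)
```

As in the original, a date after the last session yields None, while a date before the first session maps to index 0. The result is negative when second_date precedes first_date.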