Completed
Pull Request — master (#858)
by Eddie
02:02
created

zipline.sources.PandasCSV.__iter__()   F

Complexity

Conditions 13

Size

Total Lines 54

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 13
dl 0
loc 54
rs 3.5512

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: moving a cohesive fragment of a long method into its own well-named method.

Complexity

Complex classes like zipline.sources.PandasCSV.__iter__() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#
2
# Copyright 2014 Quantopian, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
16
from copy import copy
17
import six
18
19
import numpy as np
20
from datetime import timedelta
21
import pandas as pd
22
23
from zipline.sources.data_source import DataSource
24
from zipline.utils import tradingcalendar as calendar_nyse
25
from zipline.gens.utils import hash_args
26
27
28
class RandomWalkSource(DataSource):
    """Data source that emits events whose prices follow a random walk.

    Event datetimes are taken from the open/close sessions of the
    supplied calendar, and events can be emitted at a user-selected
    frequency (once per day or once per minute).

    """
    VALID_FREQS = frozenset(('daily', 'minute'))

    def __init__(self, start_prices=None, freq='minute', start=None,
                 end=None, drift=0.1, sd=0.1, calendar=calendar_nyse):
        """
        :Arguments:
            start_prices : dict
                 sid -> starting price.
                 Default: {0: 100, 1: 500}
            freq : str <default='minute'>
                 Frequency at which events are emitted.
                 Can be 'daily' or 'minute'.
            start : datetime <default=start of calendar>
                 First dt for which events are emitted.
            end : datetime <default=end of calendar>
                 Last dt up to which events are emitted.
            drift: float <default=0.1>
                 Constant drift added to the price at every step.
            sd: float <default=0.1>
                 Standard deviation of the per-step random shock.
            calendar : calendar object <default: NYSE>
                 Calendar to use.
                 See zipline.utils for different choices.

        :Raises:
            ValueError if ``freq`` is not one of ``VALID_FREQS``.

        :Example:
            # Assumes you have instantiated your Algorithm
            # as myalgo.
            myalgo = MyAlgo()
            source = RandomWalkSource()
            myalgo.run(source)

        """
        # Hash of the constructor arguments, used for downstream sorting.
        # Computed from the raw arguments, before defaults are filled in.
        self.arg_string = hash_args(start_prices, freq, start, end,
                                    calendar.__name__)

        if freq not in self.VALID_FREQS:
            raise ValueError('%s not in %s' % (freq, self.VALID_FREQS))
        self.freq = freq

        self.start_prices = (
            {0: 100, 1: 500} if start_prices is None else start_prices
        )

        # Missing bounds fall back to the calendar's own range.
        self.calendar = calendar
        self.start = calendar.start if start is None else start
        self.end = calendar.end_base if end is None else end

        self.drift = drift
        self.sd = sd

        self.sids = self.start_prices.keys()

        # Market sessions restricted to the requested [start, end] window.
        self.open_and_closes = calendar.open_and_closes[self.start:self.end]

        # Event generator; created on first access of ``raw_data``.
        self._raw_data = None

    @property
    def instance_hash(self):
        """Hash string computed from the constructor arguments."""
        return self.arg_string

    @property
    def mapping(self):
        """Map each output field to a (converter, source field) pair."""
        def passthrough(value):
            return value

        return {
            'dt': (passthrough, 'dt'),
            'sid': (passthrough, 'sid'),
            'price': (float, 'price'),
            'volume': (int, 'volume'),
            'open_price': (float, 'open_price'),
            'high': (float, 'high'),
            'low': (float, 'low'),
        }

    def _gen_next_step(self, x):
        """Advance price ``x`` by one random-walk step, floored at 0.1."""
        shock = np.random.randn() * self.sd + self.drift
        x += shock
        return max(x, 0.1)

    def _gen_events(self, cur_prices, current_dt):
        """Step every price in ``cur_prices`` and yield one event per sid.

        ``cur_prices`` is updated in place so that the walk continues
        from the new prices on the next call.
        """
        for sid in six.iterkeys(cur_prices):
            stepped = self._gen_next_step(cur_prices[sid])
            cur_prices[sid] = stepped

            yield {
                'dt': current_dt,
                'sid': sid,
                'price': stepped,
                'volume': np.random.randint(1e5, 1e6),
                'open_price': stepped,
                'high': stepped + .1,
                'low': stepped - .1,
            }

    def raw_data_gen(self):
        """Yield event dicts for every session in ``open_and_closes``."""
        # Work on a copy so the configured starting prices are kept.
        prices = copy(self.start_prices)
        one_minute = timedelta(minutes=1)

        for _, session in self.open_and_closes.iterrows():
            open_dt, close_dt = session

            if self.freq == 'minute':
                # One batch of events per minute, open through close
                # (both endpoints included).
                tick = copy(open_dt)
                while tick <= close_dt:
                    for event in self._gen_events(prices, tick):
                        yield event
                    tick += one_minute
                continue

            if self.freq == 'daily':
                # One batch of events per day, stamped with the
                # normalized date of the session close.
                day = pd.tslib.normalize_date(close_dt)
                for event in self._gen_events(prices, day):
                    yield event

    @property
    def raw_data(self):
        """The event generator, created once and then reused."""
        if self._raw_data:
            return self._raw_data
        self._raw_data = self.raw_data_gen()
        return self._raw_data
158