1
|
|
|
""" |
2
|
|
|
filter.py |
3
|
|
|
""" |
4
|
|
|
from numpy import ( |
5
|
|
|
bool_, |
6
|
|
|
float64, |
7
|
|
|
nan, |
8
|
|
|
nanpercentile, |
9
|
|
|
) |
10
|
|
|
from itertools import chain |
11
|
|
|
from operator import attrgetter |
12
|
|
|
|
13
|
|
|
from zipline.errors import ( |
14
|
|
|
BadPercentileBounds, |
15
|
|
|
UnsupportedDataType, |
16
|
|
|
) |
17
|
|
|
from zipline.pipeline.term import ( |
18
|
|
|
CompositeTerm, |
19
|
|
|
CustomTermMixin, |
20
|
|
|
RequiredWindowLengthMixin, |
21
|
|
|
SingleInputMixin, |
22
|
|
|
) |
23
|
|
|
from zipline.pipeline.expression import ( |
24
|
|
|
BadBinaryOperator, |
25
|
|
|
FILTER_BINOPS, |
26
|
|
|
method_name_for_op, |
27
|
|
|
NumericalExpression, |
28
|
|
|
) |
29
|
|
|
from zipline.utils.control_flow import nullctx |
30
|
|
|
|
31
|
|
|
|
32
|
|
|
def concat_tuples(*tuples):
    """
    Join any number of tuples end to end into a single tuple.

    >>> concat_tuples((1, 2), (3,))
    (1, 2, 3)
    """
    return tuple(chain.from_iterable(tuples))
37
|
|
|
|
38
|
|
|
|
39
|
|
|
def binary_operator(op):
    """
    Build the implementation of a binary operator method (such as
    ``__and__`` or ``__or__``) for Filter subclasses.

    Parameters
    ----------
    op : str
        Operator symbol embedded verbatim in the generated numexpr string.

    Returns
    -------
    function
        A method ``(self, other)`` that returns a NumExprFilter, or raises
        BadBinaryOperator when ``other`` is not a supported operand type.
    """
    # When the right-hand operand is a NumericalExpression we hand control to
    # its commuted (reflected) method so that it can merge the inputs itself.
    reflected = attrgetter(method_name_for_op(op, commute=True))

    def binary_operator(self, other):
        # Checked first: a NumExprFilter is also a Filter, and its existing
        # expression/inputs must be merged rather than re-bound as x_0.
        if isinstance(self, NumericalExpression):
            left_expr, right_expr, merged = self.build_binary_op(op, other)
            combined = "({left}) {op} ({right})".format(
                left=left_expr,
                op=op,
                right=right_expr,
            )
            return NumExprFilter(combined, merged)

        if isinstance(other, NumericalExpression):
            return reflected(other)(self)

        if isinstance(other, Filter):
            # Combining a filter with itself needs only a single binding.
            if other is self:
                return NumExprFilter("x_0 {op} x_0".format(op=op), (self,))
            return NumExprFilter("x_0 {op} x_1".format(op=op), (self, other))

        # bool is a subclass of int, so True/False also take this branch.
        if isinstance(other, int):
            return NumExprFilter(
                "x_0 {op} ({constant})".format(op=op, constant=int(other)),
                binds=(self,),
            )

        raise BadBinaryOperator(op, self, other)

    binary_operator.__doc__ = "Binary Operator: '%s'" % op
    return binary_operator
88
|
|
|
|
89
|
|
|
|
90
|
|
|
def unary_operator(op):
    """
    Build the implementation of a unary operator method for Filters.

    Only ``'~'`` is accepted; any other symbol raises ``ValueError`` as soon
    as the factory is called.
    """
    if op not in {'~'}:
        raise ValueError("Invalid unary operator %s." % op)

    def unary_operator(self):
        # The type test must run per call: it depends on ``self``, and
        # NumExprFilter is only defined further down in this module.
        if not isinstance(self, NumericalExpression):
            return NumExprFilter("{op}x_0".format(op=op), (self,))
        # Parenthesize the existing expression so the operator applies to
        # all of it, and keep the expression's existing inputs.
        wrapped = "{op}({expr})".format(op=op, expr=self._expr)
        return NumExprFilter(wrapped, self.inputs)

    unary_operator.__doc__ = "Unary Operator: '%s'" % op
    return unary_operator
112
|
|
|
|
113
|
|
|
|
114
|
|
|
class Filter(CompositeTerm):
    """
    Pipeline API expression producing boolean-valued outputs.

    Filters can be combined with the operators in ``FILTER_BINOPS`` (in either
    operand order) and negated with ``~``; each combination produces a new
    NumExprFilter.
    """
    dtype = bool_

    clsdict = locals()
    clsdict.update(
        {
            method_name_for_op(op): binary_operator(op)
            for op in FILTER_BINOPS
        }
    )
    # Also install the reflected methods (e.g. ``__rand__``/``__ror__``).
    # ``binary_operator`` handles ``Filter <op> NumericalExpression`` by
    # calling the right operand's commuted method, so that method must exist
    # on every Filter (including NumExprFilter). Without it that path raises
    # AttributeError. It also allows a constant on the left (``True & f``).
    # Reusing the forward implementation is valid because the boolean filter
    # operators commute.
    # NOTE(review): assumes FILTER_BINOPS holds only commutative operators
    # (e.g. '&' and '|') -- confirm if new operators are ever added there.
    clsdict.update(
        {
            method_name_for_op(op, commute=True): binary_operator(op)
            for op in FILTER_BINOPS
        }
    )
    __invert__ = unary_operator('~')
128
|
|
|
|
129
|
|
|
|
130
|
|
|
class NumExprFilter(NumericalExpression, Filter):
    """
    Filter whose values come from evaluating a numexpr expression.
    """

    def _compute(self, arrays, dates, assets, mask):
        """
        Evaluate the expression via the parent implementation, then AND the
        result with `mask` so masked-out entries cannot pass (assuming `mask`
        is boolean).
        """
        raw = super(NumExprFilter, self)._compute(arrays, dates, assets, mask)
        return raw & mask
145
|
|
|
|
146
|
|
|
|
147
|
|
|
class PercentileFilter(SingleInputMixin, Filter):
    """
    A Filter representing assets falling between percentile bounds of a Factor.

    Parameters
    ----------
    factor : zipline.pipeline.factor.Factor
        The factor over which to compute percentile bounds.
    min_percentile : float [0.0, 100.0]
        The minimum percentile rank of an asset that will pass the filter.
    max_percentile : float [0.0, 100.0]
        The maximum percentile rank of an asset that will pass the filter.
        Must be strictly greater than ``min_percentile``.
    mask : Filter
        Passed through to the base term. Entries where the computed mask is
        False are excluded from the percentile calculation and never pass.

    Raises
    ------
    BadPercentileBounds
        From validation, unless ``0.0 <= min_percentile < max_percentile
        <= 100.0``.
    """
    window_length = 0

    def __new__(cls, factor, min_percentile, max_percentile, mask):
        return super(PercentileFilter, cls).__new__(
            cls,
            inputs=(factor,),
            mask=mask,
            min_percentile=min_percentile,
            max_percentile=max_percentile,
        )

    def _init(self, min_percentile, max_percentile, *args, **kwargs):
        self._min_percentile = min_percentile
        self._max_percentile = max_percentile
        return super(PercentileFilter, self)._init(*args, **kwargs)

    @classmethod
    def static_identity(cls, min_percentile, max_percentile, *args, **kwargs):
        # The bounds are part of the identity, so filters that differ only in
        # their bounds are treated as distinct terms.
        return (
            super(PercentileFilter, cls).static_identity(*args, **kwargs),
            min_percentile,
            max_percentile,
        )

    def _validate(self):
        """
        Ensure that our percentile bounds are well-formed.

        Bounds are on the 0-100 scale used by ``numpy.nanpercentile``.
        """
        if not 0.0 <= self._min_percentile < self._max_percentile <= 100.0:
            raise BadPercentileBounds(
                min_percentile=self._min_percentile,
                max_percentile=self._max_percentile,
            )
        return super(PercentileFilter, self)._validate()

    def _compute(self, arrays, dates, assets, mask):
        """
        For each row in the input, compute a mask of all values falling
        between the given percentiles (inclusive on both ends).

        Entries where ``mask`` is False are replaced by NaN before the row-wise
        percentiles are computed. They therefore do not influence the bounds,
        and they never pass, because comparisons against NaN are False. A row
        with no unmasked values gets NaN bounds and so passes nothing.
        """
        # TODO: Review whether there's a better way of handling small numbers
        # of columns.
        # ``astype`` returns a new array by default, so the in-place NaN
        # assignment below cannot mutate the caller's input. An extra
        # ``.copy()`` would only allocate a second, redundant buffer.
        data = arrays[0].astype(float64)
        data[~mask] = nan

        # FIXME: np.nanpercentile **should** support computing multiple bounds
        # at once, but there's a bug in the logic for multiple bounds in numpy
        # 1.9.2. It will be fixed in 1.10.
        # c.f. https://github.com/numpy/numpy/pull/5981
        lower_bounds = nanpercentile(
            data,
            self._min_percentile,
            axis=1,
            keepdims=True,
        )
        upper_bounds = nanpercentile(
            data,
            self._max_percentile,
            axis=1,
            keepdims=True,
        )
        return (lower_bounds <= data) & (data <= upper_bounds)
222
|
|
|
|
223
|
|
|
|
224
|
|
|
class CustomFilter(RequiredWindowLengthMixin, CustomTermMixin, Filter):
    """
    Filter analog to ``CustomFactor``.

    The declared dtype must be ``bool_``; any other dtype is rejected during
    validation with ``UnsupportedDataType``.
    """
    # No-op context manager; presumably consumed by CustomTermMixin around
    # the user computation -- confirm against CustomTermMixin.
    ctx = nullctx()

    def _validate(self):
        """
        Reject non-boolean dtypes, then defer to the parent classes'
        validation.
        """
        declared = self.dtype
        if declared != bool_:
            raise UnsupportedDataType(dtype=declared)
        return super(CustomFilter, self)._validate()
234
|
|
|
|
Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.
You can also find more detailed suggestions in the “Code” section of your repository.