Completed
Pull Request — master (#881)
by
unknown
01:24
created

zipline.pipeline.filters.CustomFilter   A

Complexity

Total Complexity 2

Size/Duplication

Total Lines 10
Duplicated Lines 0 %
Metric Value
dl 0
loc 10
rs 10
wmc 2

1 Method

Rating   Name   Duplication   Size   Complexity  
A _validate() 0 4 2
1
"""
2
filter.py
3
"""
4
from numpy import (
5
    bool_,
6
    float64,
7
    nan,
8
    nanpercentile,
9
)
10
from itertools import chain
11
from operator import attrgetter
12
13
from zipline.errors import (
14
    BadPercentileBounds,
15
    UnsupportedDataType,
16
)
17
from zipline.pipeline.term import (
18
    CompositeTerm,
19
    CustomTermMixin,
20
    RequiredWindowLengthMixin,
21
    SingleInputMixin,
22
)
23
from zipline.pipeline.expression import (
24
    BadBinaryOperator,
25
    FILTER_BINOPS,
26
    method_name_for_op,
27
    NumericalExpression,
28
)
29
from zipline.utils.control_flow import nullctx
30
31
32
def concat_tuples(*tuples):
    """
    Join any number of tuples, in order, into one flat tuple.

    ``concat_tuples((1, 2), (3,))`` gives ``(1, 2, 3)``; calling it with no
    arguments gives ``()``.
    """
    return tuple(chain.from_iterable(tuples))
37
38
39
def binary_operator(op):
    """
    Build a binary-operator method (such as ``__and__`` or ``__or__``) for a
    Filter subclass.

    Parameters
    ----------
    op : str
        Operator symbol spliced into the generated numexpr expression.

    Returns
    -------
    binary_operator : function
        A method ``(self, other)`` that combines ``self`` with ``other``.
        It returns a ``NumExprFilter``, or whatever ``other``'s reflected
        method returns when ``other`` is a NumericalExpression.  It raises
        ``BadBinaryOperator`` when ``other`` is none of: NumericalExpression,
        Filter, ``int``.
    """
    # When ``other`` is a NumericalExpression, hand the work to its
    # reflected ("commuted") method so it can merge its own inputs.
    get_reflected = attrgetter(method_name_for_op(op, commute=True))

    def binary_operator(self, other):
        # NOTE(review): the code-review tool flagged this first branch as
        # duplicated elsewhere in the project; consider a shared helper.
        if isinstance(self, NumericalExpression):
            left, right, merged_inputs = self.build_binary_op(op, other)
            combined = "({left}) {op} ({right})".format(
                left=left,
                op=op,
                right=right,
            )
            return NumExprFilter(combined, merged_inputs)

        if isinstance(other, NumericalExpression):
            # Let ``other`` absorb ``self`` as one of its inputs.
            return get_reflected(other)(self)

        if isinstance(other, Filter):
            if other is self:
                # The same term on both sides is bound only once.
                return NumExprFilter(
                    "x_0 {op} x_0".format(op=op),
                    (self,),
                )
            return NumExprFilter(
                "x_0 {op} x_1".format(op=op),
                (self, other),
            )

        if isinstance(other, int):  # ``bool`` is a subclass of ``int``.
            return NumExprFilter(
                "x_0 {op} ({constant})".format(op=op, constant=int(other)),
                binds=(self,),
            )

        raise BadBinaryOperator(op, self, other)

    binary_operator.__doc__ = "Binary Operator: '%s'" % op
    return binary_operator
88
89
90
def unary_operator(op):
    """
    Build a unary-operator method (currently only ``__invert__``) for Filters.

    Parameters
    ----------
    op : str
        Operator symbol; only ``'~'`` is accepted.

    Returns
    -------
    unary_operator : function
        A method ``(self)`` returning a ``NumExprFilter`` that applies ``op``
        to ``self``.

    Raises
    ------
    ValueError
        If ``op`` is not a supported unary operator.
    """
    # NOTE(review): the code-review tool flagged this function as duplicated
    # elsewhere in the project; consider a shared helper.
    if op not in {'~'}:
        raise ValueError("Invalid unary operator %s." % op)

    def unary_operator(self):
        # ``NumExprFilter`` is resolved when this method is *called*, not when
        # the factory runs.  The factory is invoked from inside the ``Filter``
        # class body, before ``NumExprFilter`` has been defined.
        if not isinstance(self, NumericalExpression):
            return NumExprFilter("{op}x_0".format(op=op), (self,))
        # Already an expression: wrap its expression text and reuse its
        # inputs rather than binding ``self`` as a new input.
        return NumExprFilter(
            "{op}({expr})".format(op=op, expr=self._expr),
            self.inputs,
        )

    unary_operator.__doc__ = "Unary Operator: '%s'" % op
    return unary_operator
112
113
114
class Filter(CompositeTerm):
    """
    Pipeline API expression producing boolean-valued outputs.

    Every operator in ``FILTER_BINOPS`` is installed in both its normal form
    and its reflected form, so ``filter OP other`` and ``other OP filter``
    both dispatch here.  ``~`` is also supported.
    """
    dtype = bool_

    # ``locals()`` at class scope is the namespace of the class being built,
    # so updating it installs the generated operator methods on ``Filter``.
    clsdict = locals()
    clsdict.update(
        {
            method_name_for_op(op): binary_operator(op)
            for op in FILTER_BINOPS
        }
    )
    # Reflected forms (e.g. ``__rand__``/``__ror__``).  ``binary_operator``
    # defers to ``other.<reflected method>`` whenever ``other`` is a
    # NumericalExpression.  That path therefore needs NumExprFilter (a Filter
    # subclass) to provide these methods.  Without them,
    # ``plain_filter & numexpr_filter`` cannot dispatch (unless
    # NumericalExpression, not visible here, defines them itself; its
    # definition would still take precedence in the MRO).  These entries
    # also make ``True & some_filter`` work.
    # Reusing the non-reflected implementation is only correct for
    # commutative operators.  Presumably FILTER_BINOPS holds only ``&`` and
    # ``|``; TODO confirm if that set ever grows.
    clsdict.update(
        {
            method_name_for_op(op, commute=True): binary_operator(op)
            for op in FILTER_BINOPS
        }
    )
    __invert__ = unary_operator('~')
128
129
130
class NumExprFilter(NumericalExpression, Filter):
    """
    A Filter whose values come from evaluating a numexpr expression.

    The evaluated result is ANDed with the pipeline ``mask`` before being
    returned.
    """

    def _compute(self, arrays, dates, assets, mask):
        """
        Evaluate the expression with numexpr via the parent implementation,
        then re-apply `mask`.
        """
        evaluated = super(NumExprFilter, self)._compute(
            arrays, dates, assets, mask,
        )
        # Re-apply the mask so that positions outside it are cleared,
        # assuming `mask` is a boolean array -- TODO confirm.
        return evaluated & mask
145
146
147
class PercentileFilter(SingleInputMixin, Filter):
    """
    A Filter representing assets falling between percentile bounds of a Factor.

    Parameters
    ----------
    factor : zipline.pipeline.factor.Factor
        The factor over which to compute percentile bounds.
    min_percentile : float [0.0, 100.0]
        The minimum percentile rank of an asset that will pass the filter.
        Percentiles use numpy's 0-100 scale.
    max_percentile : float [0.0, 100.0]
        The maximum percentile rank of an asset that will pass the filter.
        Must be strictly greater than ``min_percentile``.
    mask
        Forwarded to the base term.  In ``_compute``, only positions where
        the mask is True contribute to the bounds or can pass the filter.
        (Presumably a Filter -- TODO confirm against callers.)
    """
    window_length = 0

    def __new__(cls, factor, min_percentile, max_percentile, mask):
        # The percentile bounds travel as extra keyword arguments; they are
        # consumed by ``_init`` and ``static_identity`` below.
        return super(PercentileFilter, cls).__new__(
            cls,
            inputs=(factor,),
            mask=mask,
            min_percentile=min_percentile,
            max_percentile=max_percentile,
        )

    def _init(self, min_percentile, max_percentile, *args, **kwargs):
        """
        Store the percentile bounds, then run the base initializer.
        """
        self._min_percentile = min_percentile
        self._max_percentile = max_percentile
        return super(PercentileFilter, self)._init(*args, **kwargs)

    @classmethod
    def static_identity(cls, min_percentile, max_percentile, *args, **kwargs):
        """
        Extend the base identity with the percentile bounds, so filters that
        differ only in their bounds get distinct identities.
        """
        return (
            super(PercentileFilter, cls).static_identity(*args, **kwargs),
            min_percentile,
            max_percentile,
        )

    def _validate(self):
        """
        Ensure that our percentile bounds are well-formed.

        Raises
        ------
        BadPercentileBounds
            Unless ``0.0 <= min_percentile < max_percentile <= 100.0``.
        """
        if not 0.0 <= self._min_percentile < self._max_percentile <= 100.0:
            raise BadPercentileBounds(
                min_percentile=self._min_percentile,
                max_percentile=self._max_percentile,
            )
        return super(PercentileFilter, self)._validate()

    def _compute(self, arrays, dates, assets, mask):
        """
        For each row in the input, compute a mask of all values falling between
        the given percentiles.  Both bounds are inclusive.

        Values outside `mask` are replaced with NaN before the bounds are
        computed.  They therefore neither influence the bounds nor pass the
        filter, because comparisons against NaN are False.  If a row has no
        non-NaN values, ``nanpercentile`` returns NaN for it (and numpy emits
        a RuntimeWarning), so nothing in that row passes.
        """
        # TODO: Review whether there's a better way of handling small numbers
        # of columns.
        # ``astype`` returns a newly allocated array by default, so the input
        # is never mutated and a separate ``copy()`` is unnecessary.
        data = arrays[0].astype(float64)
        data[~mask] = nan

        # FIXME: np.nanpercentile **should** support computing multiple bounds
        # at once, but there's a bug in the logic for multiple bounds in numpy
        # 1.9.2.  It will be fixed in 1.10.
        # c.f. https://github.com/numpy/numpy/pull/5981
        # ``keepdims=True`` keeps one column per row so the bounds broadcast
        # against ``data`` row-wise in the comparison below.
        lower_bounds = nanpercentile(
            data,
            self._min_percentile,
            axis=1,
            keepdims=True,
        )
        upper_bounds = nanpercentile(
            data,
            self._max_percentile,
            axis=1,
            keepdims=True,
        )
        return (lower_bounds <= data) & (data <= upper_bounds)
222
223
224
class CustomFilter(RequiredWindowLengthMixin, CustomTermMixin, Filter):
    """
    Filter analog to ``CustomFactor``.

    Instances must be boolean-valued; any other ``dtype`` is rejected during
    validation with ``UnsupportedDataType``.
    """
    # Class-level context object from ``nullctx()``; presumably a no-op
    # context used by ``CustomTermMixin`` during compute -- TODO confirm.
    ctx = nullctx()

    def _validate(self):
        # Only boolean outputs make sense for a Filter; accept those and
        # defer to the base validation, reject everything else.
        if self.dtype == bool_:
            return super(CustomFilter, self)._validate()
        raise UnsupportedDataType(dtype=self.dtype)
234