Completed
Pull Request — master (#948)
by
unknown
01:23
created

zipline.pipeline.factors.binary_operator()   C

Complexity

Conditions 7

Size

Total Lines 60

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 7
dl 0
loc 60
rs 6.3864

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

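Commonly applied refactorings include Extract Method: moving a commented or self-contained part of the body into its own well-named method.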

1
"""
2
factor.py
3
"""
4
from operator import attrgetter
5
from numbers import Number
6
7
from numpy import float64, inf
8
from toolz import curry
9
10
from zipline.errors import (
11
    UnknownRankMethod,
12
    UnsupportedDataType,
13
)
14
from zipline.lib.rank import masked_rankdata_2d
15
from zipline.pipeline.mixins import (
16
    CustomTermMixin,
17
    PositiveWindowLengthMixin,
18
    SingleInputMixin,
19
)
20
from zipline.pipeline.term import CompositeTerm, NotSpecified
21
from zipline.pipeline.expression import (
22
    BadBinaryOperator,
23
    COMPARISONS,
24
    is_comparison,
25
    MATH_BINOPS,
26
    method_name_for_op,
27
    NumericalExpression,
28
    NUMEXPR_MATH_FUNCS,
29
    UNARY_OPS,
30
    unary_op_name,
31
)
32
from zipline.pipeline.filters import (
33
    NumExprFilter,
34
    PercentileFilter,
35
)
36
from zipline.utils.control_flow import nullctx
37
from zipline.utils.numpy_utils import (
38
    bool_dtype,
39
    datetime64ns_dtype,
40
    float64_dtype,
41
)
42
from zipline.utils.preprocess import preprocess
43
44
45
# Names of the tie-handling methods that ``Rank._validate`` accepts; any other
# value for ``method`` raises ``UnknownRankMethod``.
_RANK_METHODS = frozenset(['average', 'min', 'max', 'dense', 'ordinal'])
46
47
48
def numbers_to_float64(func, argname, argvalue):
    """
    Preprocessor that coerces numeric arguments to ``numpy.float64``.

    Applied to the ``other`` argument of the binary operator constructors for
    Factor so that `2 + Factor()` has the same behavior as `2.0 + Factor()`.

    Parameters
    ----------
    func
        The function whose argument is being preprocessed.  Not used here.
    argname
        The name of the argument being preprocessed.  Not used here.
    argvalue : object
        The value supplied for the argument.

    Returns
    -------
    object
        ``float64(argvalue)`` when ``argvalue`` is a ``numbers.Number``;
        otherwise ``argvalue`` itself, untouched.
    """
    return float64(argvalue) if isinstance(argvalue, Number) else argvalue
58
59
60
@curry
def set_attribute(name, value):
    """
    Build a decorator that assigns ``value`` to attribute ``name`` of the
    function it decorates.

    The decorated function is returned as-is (not wrapped), so its behavior
    is unchanged.  Because this factory is curried, calling it with only
    ``name`` yields a callable that still expects ``value``.

    Usage
    -----
    >>> @set_attribute('__name__', 'foo')
    ... def bar():
    ...     return 3
    ...
    >>> bar()
    3
    >>> bar.__name__
    'foo'
    """
    def _tag(target):
        setattr(target, name, value)
        return target
    return _tag
82
83
84
# Decorators for setting the __name__ and __doc__ properties of a decorated
85
# function.
86
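# Example: ``@with_name('foo')`` sets ``__name__`` to 'foo' on the function it
# decorates; ``@with_doc('text')`` does the same for ``__doc__``.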
87
# Partially applied ``set_attribute``: ``with_name(value)`` returns a decorator
# that sets ``__name__`` on the decorated function to ``value``.
with_name = set_attribute('__name__')
88
# Partially applied ``set_attribute``: ``with_doc(value)`` returns a decorator
# that sets ``__doc__`` on the decorated function to ``value``.
with_doc = set_attribute('__doc__')
89
90
91
def binop_return_type(op):
    """
    Pick the expression class produced by applying binary operator ``op``.

    Parameters
    ----------
    op : str
        Operator symbol, (e.g. '+', '<', ...).

    Returns
    -------
    type
        ``NumExprFilter`` when ``is_comparison(op)`` is true, otherwise
        ``NumExprFactor``.
    """
    return NumExprFilter if is_comparison(op) else NumExprFactor
96
97
98
def binop_return_dtype(op, left, right):
    """
    Determine the dtype that results from ``left <op> right``.

    Parameters
    ----------
    op : str
        Operator symbol, (e.g. '+', '-', ...).
    left : numpy.dtype
        Dtype of the left-hand operand.
    right : numpy.dtype
        Dtype of the right-hand operand.

    Returns
    -------
    outdtype : numpy.dtype
        ``bool_dtype`` for comparison operators, ``float64_dtype`` for every
        other operator.

    Raises
    ------
    TypeError
        If ``op`` is a comparison and the two dtypes differ, or if ``op`` is
        not a comparison and either dtype is not ``float64_dtype``.
    """
    if not is_comparison(op):
        # Arithmetic is only defined when both operands are float64.
        if left != float64_dtype or right != float64_dtype:
            message = (
                "Don't know how to compute {left} {op} {right}.\n"
                "Arithmetic operators are only supported on Factors of "
                "dtype 'float64'."
            )
            raise TypeError(
                message.format(left=left.name, op=op, right=right.name)
            )
        return float64_dtype

    # Comparisons need matching operand dtypes and always produce booleans.
    if left != right:
        message = (
            "Don't know how to compute {left} {op} {right}.\n"
            "Comparisons are only supported between Factors of equal "
            "dtypes."
        )
        raise TypeError(message.format(left=left, op=op, right=right))
    return bool_dtype
136
137
138
def binary_operator(op):
    """
    Factory function for making binary operator methods on a Factor subclass.

    Parameters
    ----------
    op : str
        Operator symbol, (e.g. '+', '-', ...).

    Returns
    -------
    callable
        A function, "binary_operator", suitable for implementing methods like
        __add__.  Its ``__name__`` is ``method_name_for_op(op)`` and numeric
        ``other`` arguments are converted to float64 before it runs.
    """
    # Looks up the commuted (reflected) method for ``op`` on another object.
    # Used to hand the work to a NumExprFactor on the right-hand side.
    commuted_method_getter = attrgetter(method_name_for_op(op, commute=True))

    @preprocess(other=numbers_to_float64)
    @with_doc("Binary Operator: '%s'" % op)
    @with_name(method_name_for_op(op))
    def binary_operator(self, other):
        # Resolved on every call rather than once in the factory: the classes
        # returned by binop_return_type aren't defined yet when the factory
        # runs inside the class body of Factor.
        result_cls = binop_return_type(op)

        if isinstance(self, NumExprFactor):
            # ``self`` is already an expression: extend it with ``other``.
            lhs, rhs, merged_inputs = self.build_binary_op(op, other)
            expr = "({left}) {op} ({right})".format(left=lhs, op=op, right=rhs)
            return result_cls(
                expr,
                merged_inputs,
                dtype=binop_return_dtype(op, self.dtype, other.dtype),
            )

        if isinstance(other, NumExprFactor):
            # NumericalExpression overrides ops to correctly handle merging
            # of inputs, so call its reflected operator with ourself as the
            # argument.
            return commuted_method_getter(other)(self)

        if isinstance(other, Factor):
            # A factor combined with itself needs only a single input slot.
            if self is other:
                expr, inputs = "x_0 {op} x_0".format(op=op), (self,)
            else:
                expr, inputs = "x_0 {op} x_1".format(op=op), (self, other)
            return result_cls(
                expr,
                inputs,
                dtype=binop_return_dtype(op, self.dtype, other.dtype),
            )

        if isinstance(other, Number):
            # Numeric literals arrive here as float64 (see the preprocess
            # decorator), which is why ``other.dtype`` is available.
            return result_cls(
                "x_0 {op} ({constant})".format(op=op, constant=other),
                binds=(self,),
                dtype=binop_return_dtype(op, self.dtype, other.dtype)
            )

        raise BadBinaryOperator(op, self, other)

    return binary_operator
198
199
200
def reflected_binary_operator(op):
    """
    Factory function for making reflected binary operator methods on a Factor.

    Parameters
    ----------
    op : str
        Operator symbol, (e.g. '+', '-', ...).  Must not be a comparison.

    Returns
    -------
    callable
        A function, "reflected_binary_operator", suitable for implementing
        methods like __radd__.  Its ``__name__`` is
        ``method_name_for_op(op, commute=True)`` and numeric ``other``
        arguments are converted to float64 before it runs.
    """
    assert not is_comparison(op)

    @preprocess(other=numbers_to_float64)
    # Give the generated method a docstring, as binary_operator and
    # unary_operator do for theirs.
    @with_doc("Reflected Binary Operator: '%s'" % op)
    @with_name(method_name_for_op(op, commute=True))
    def reflected_binary_operator(self, other):

        if isinstance(self, NumericalExpression):
            self_expr, other_expr, new_inputs = self.build_binary_op(
                op, other
            )
            # Reflected: ``other`` is the left operand and ``self`` the right.
            return NumExprFactor(
                "({left}) {op} ({right})".format(
                    left=other_expr,
                    right=self_expr,
                    op=op,
                ),
                new_inputs,
                dtype=binop_return_dtype(op, other.dtype, self.dtype)
            )

        # Only have to handle the numeric case because in all other valid cases
        # the corresponding left-binding method will be called.
        elif isinstance(other, Number):
            return NumExprFactor(
                "{constant} {op} x_0".format(op=op, constant=other),
                binds=(self,),
                dtype=binop_return_dtype(op, other.dtype, self.dtype),
            )
        raise BadBinaryOperator(op, other, self)
    return reflected_binary_operator
237
238
239
def unary_operator(op):
    """
    Factory function for making unary operator methods for Factors.

    Parameters
    ----------
    op : str
        Operator symbol.  Only '-' is accepted.

    Returns
    -------
    callable
        A method named ``unary_op_name(op)`` that builds a float64
        NumExprFactor applying ``op`` to the factor it is called on.

    Raises
    ------
    ValueError
        If ``op`` is not a supported unary operator.
    """
    # Only negate is currently supported.
    valid_ops = {'-'}
    if op not in valid_ops:
        raise ValueError("Invalid unary operator %s." % op)

    @with_doc("Unary Operator: '%s'" % op)
    @with_name(unary_op_name(op))
    def unary_operator(self):
        if self.dtype != float64_dtype:
            message = (
                "Can't apply unary operator {op!r} to instance of "
                "{typename!r} with dtype {dtypename!r}.\n"
                "{op!r} is only supported for Factors of dtype "
                "'float64'."
            )
            raise TypeError(
                message.format(
                    op=op,
                    typename=type(self).__name__,
                    dtypename=self.dtype.name,
                )
            )

        # NumExprFactor is referenced here, at call time, because the class is
        # defined further down the module than this factory.
        if isinstance(self, NumericalExpression):
            # Wrap the existing expression and keep its inputs.
            expr = "{op}({expr})".format(op=op, expr=self._expr)
            inputs = self.inputs
        else:
            # Plain factor: it becomes the sole input of a new expression.
            expr = "{op}x_0".format(op=op)
            inputs = (self,)
        return NumExprFactor(expr, inputs, dtype=float64_dtype)
    return unary_operator
279
280
281
def function_application(func):
    """
    Factory function for producing function application methods for Factor
    subclasses.

    Parameters
    ----------
    func : str
        Name of a function listed in ``NUMEXPR_MATH_FUNCS``.

    Returns
    -------
    callable
        A method named ``func`` that builds a float64 NumExprFactor applying
        ``func`` to the factor it is called on.

    Raises
    ------
    ValueError
        If ``func`` is not in ``NUMEXPR_MATH_FUNCS``.
    """
    if func not in NUMEXPR_MATH_FUNCS:
        raise ValueError("Unsupported mathematical function '%s'" % func)

    @with_name(func)
    def mathfunc(self):
        if isinstance(self, NumericalExpression):
            # Wrap the existing expression and keep its inputs.
            expr = "{func}({expr})".format(func=func, expr=self._expr)
            inputs = self.inputs
        else:
            # Plain factor: it becomes the sole input of a new expression.
            expr = "{func}(x_0)".format(func=func)
            inputs = (self,)
        return NumExprFactor(expr, inputs, dtype=float64_dtype)
    return mathfunc
304
305
306
# Dtypes a Factor may declare; ``Factor._validate`` raises
# ``UnsupportedDataType`` for any dtype not in this set.
FACTOR_DTYPES = frozenset([datetime64ns_dtype, float64_dtype])
307
308
309
class Factor(CompositeTerm):
    """
    Pipeline API expression producing numerically-valued outputs.

    Operator methods, unary operators and math-function methods are not
    written out by hand; they are generated by the module-level factory
    functions and inserted into the class namespace below.
    """
    # Dynamically add functions for creating NumExprFactor/NumExprFilter
    # instances.
    clsdict = locals()

    # Binary operators (arithmetic and comparisons).
    clsdict.update(
        (method_name_for_op(op), binary_operator(op))
        # Don't override __eq__ because it breaks comparisons on tuples of
        # Factors.
        for op in MATH_BINOPS.union(COMPARISONS - {'=='})
    )
    # Reflected arithmetic operators.
    clsdict.update(
        (method_name_for_op(op, commute=True), reflected_binary_operator(op))
        for op in MATH_BINOPS
    )
    # Unary operators.
    clsdict.update(
        (unary_op_name(op), unary_operator(op))
        for op in UNARY_OPS
    )
    # Named math functions.
    clsdict.update(
        (funcname, function_application(funcname))
        for funcname in NUMEXPR_MATH_FUNCS
    )

    # True division reuses the generated division methods.
    __truediv__ = clsdict['__div__']
    __rtruediv__ = clsdict['__rdiv__']

    # Explicit equality method, since __eq__ is deliberately not generated.
    eq = binary_operator('==')

    def _validate(self):
        """
        Run superclass validation, then check that ``self.dtype`` is one of
        ``FACTOR_DTYPES``.

        Raises
        ------
        UnsupportedDataType
            If ``self.dtype`` is not in ``FACTOR_DTYPES``.
        """
        # Do superclass validation first so that `NotSpecified` dtypes get
        # handled.
        result = super(Factor, self)._validate()
        if self.dtype in FACTOR_DTYPES:
            return result
        raise UnsupportedDataType(
            typename=type(self).__name__,
            dtype=self.dtype
        )

    def rank(self, method='ordinal', ascending=True, mask=NotSpecified):
        """
        Build a Factor giving, for each row, the sorted rank of every column.

        Parameters
        ----------
        method : str, {'ordinal', 'min', 'max', 'dense', 'average'}
            How ranks are assigned to tied elements. See
            `scipy.stats.rankdata` for a full description of the semantics
            of each ranking method. Default is 'ordinal'.
        ascending : bool, optional
            Whether ranks are returned in ascending or descending order.
            Default is True.
        mask : zipline.pipeline.Filter, optional
            A Filter representing assets to consider when computing ranks.
            When supplied, ranks are computed ignoring any asset/date pairs
            for which `mask` produces a value of False.

        Returns
        -------
        ranks : zipline.pipeline.factors.Rank
            A new factor that will compute the ranking of the data produced
            by `self`.

        Notes
        -----
        The default value for `method` differs from the default for
        `scipy.stats.rankdata`.  See that function's documentation for a
        full description of the valid inputs to `method`.

        Missing or non-existent data on a given day will cause an asset to be
        given a rank of NaN for that day.

        See Also
        --------
        scipy.stats.rankdata
        zipline.lib.rank.masked_rankdata_2d
        zipline.pipeline.factors.factor.Rank
        """
        return Rank(self, method=method, ascending=ascending, mask=mask)

    def top(self, N, mask=NotSpecified):
        """
        Build a Filter matching the top N asset values of self each day.

        Parameters
        ----------
        N : int
            Number of assets passing the returned filter each day.
        mask : zipline.pipeline.Filter, optional
            A Filter representing assets to consider when computing ranks.
            When supplied, top values are computed ignoring any asset/date
            pairs for which `mask` produces a value of False.

        Returns
        -------
        filter : zipline.pipeline.filters.Filter
        """
        descending_rank = self.rank(ascending=False, mask=mask)
        return descending_rank <= N

    def bottom(self, N, mask=NotSpecified):
        """
        Build a Filter matching the bottom N asset values of self each day.

        Parameters
        ----------
        N : int
            Number of assets passing the returned filter each day.
        mask : zipline.pipeline.Filter, optional
            A Filter representing assets to consider when computing ranks.
            When supplied, bottom values are computed ignoring any
            asset/date pairs for which `mask` produces a value of False.

        Returns
        -------
        filter : zipline.pipeline.Filter
        """
        ascending_rank = self.rank(ascending=True, mask=mask)
        return ascending_rank <= N

    def percentile_between(self,
                           min_percentile,
                           max_percentile,
                           mask=NotSpecified):
        """
        Build a Filter selecting the entries of this Factor's output that
        fall within the percentile range defined by min_percentile and
        max_percentile.

        Parameters
        ----------
        min_percentile : float [0.0, 100.0]
            Return True for assets falling above this percentile in the data.
        max_percentile : float [0.0, 100.0]
            Return True for assets falling below this percentile in the data.
        mask : zipline.pipeline.Filter, optional
            A Filter representing assets to consider when computing
            percentile thresholds.  When supplied, percentile cutoffs are
            computed each day using only assets for which `mask` returns
            True, and assets not passing `mask` will produce False in the
            output of this filter as well.

        Returns
        -------
        out : zipline.pipeline.filters.PercentileFilter
            A new filter that will compute the specified percentile-range
            mask.

        See Also
        --------
        zipline.pipeline.filters.filter.PercentileFilter
        """
        return PercentileFilter(
            self,
            min_percentile=min_percentile,
            max_percentile=max_percentile,
            mask=mask,
        )

    def isnan(self):
        """
        A Filter producing True for all values where this Factor is NaN.

        Implemented as ``self != self``.

        Returns
        -------
        nanfilter : zipline.pipeline.filters.Filter
        """
        return self != self

    def notnan(self):
        """
        A Filter producing True for values where this Factor is not NaN.

        Implemented as the inverse of ``isnan()``.

        Returns
        -------
        nanfilter : zipline.pipeline.filters.Filter
        """
        nan_filter = self.isnan()
        return ~nan_filter

    def isfinite(self):
        """
        A Filter producing True for values where this Factor is anything but
        NaN, inf, or -inf.
        """
        above_negative_inf = -inf < self
        below_positive_inf = self < inf
        return above_negative_inf & below_positive_inf
504
505
506
class NumExprFactor(NumericalExpression, Factor):
    """
    A Factor whose values are computed from a numexpr expression.

    The class adds nothing of its own; it only combines its two base classes.

    Parameters
    ----------
    expr : string
       A string suitable for passing to numexpr.  Every variable in 'expr'
       should have the form "x_i", where i is the index of the corresponding
       factor input in 'binds'.
    binds : tuple
       A tuple of factors to use as inputs.

    Notes
    -----
    NumExprFactors are constructed by numerical operators like `+` and `-`.
    Users should rarely need to construct a NumExprFactor directly.
    """
525
526
527
class Rank(SingleInputMixin, Factor):
    """
    A Factor representing the row-wise rank data of another Factor.

    Parameters
    ----------
    factor : zipline.pipeline.factors.Factor
        The factor on which to compute ranks.
    method : str, {'average', 'min', 'max', 'dense', 'ordinal'}
        The method used to assign ranks to tied elements.  See
        `scipy.stats.rankdata` for a full description of the semantics for each
        ranking method.
    ascending : bool
        Passed through to the ranking routine along with ``method``.
    mask
        Forwarded to the superclass constructor as ``mask``.

    See Also
    --------
    scipy.stats.rankdata : Underlying ranking algorithm.
    zipline.factors.Factor.rank : Method-style interface to same functionality.

    Notes
    -----
    Most users should call Factor.rank rather than directly construct an
    instance of this class.
    """
    window_length = 0
    dtype = float64_dtype

    def __new__(cls, factor, method, ascending, mask):
        # The ranked factor is this term's only input.
        return super(Rank, cls).__new__(
            cls,
            inputs=(factor,),
            method=method,
            ascending=ascending,
            mask=mask,
        )

    def _init(self, method, ascending, *args, **kwargs):
        # Keep the ranking options; everything else goes to the superclass.
        self._method, self._ascending = method, ascending
        return super(Rank, self)._init(*args, **kwargs)

    @classmethod
    def static_identity(cls, method, ascending, *args, **kwargs):
        # Extend the superclass identity with the ranking options.
        base_identity = super(Rank, cls).static_identity(*args, **kwargs)
        return (base_identity, method, ascending)

    def _validate(self):
        """
        Verify that the stored rank method is valid.

        Raises
        ------
        UnknownRankMethod
            If the stored method is not one of ``_RANK_METHODS``.
        """
        if self._method in _RANK_METHODS:
            return super(Rank, self)._validate()
        raise UnknownRankMethod(
            method=self._method,
            choices=set(_RANK_METHODS),
        )

    def _compute(self, arrays, dates, assets, mask):
        """
        For each row in the input, compute a like-shaped array of per-row
        ranks.

        Only the first entry of ``arrays`` is ranked; ``dates`` and
        ``assets`` are not used.
        """
        data = arrays[0]
        missing_value = self.inputs[0].missing_value
        return masked_rankdata_2d(
            data,
            mask,
            missing_value,
            self._method,
            self._ascending,
        )

    def __repr__(self):
        template = "{type}({input_}, method='{method}', mask={mask})"
        return template.format(
            type=type(self).__name__,
            input_=self.inputs[0],
            method=self._method,
            mask=self.mask,
        )
606
607
608
class CustomFactor(PositiveWindowLengthMixin, CustomTermMixin, Factor):
    '''
    Base class for user-defined Factors.

    Parameters
    ----------
    inputs : iterable, optional
        An iterable of `BoundColumn` instances (e.g. USEquityPricing.close),
        describing the data to load and pass to `self.compute`.  If this
        argument is not passed to the CustomFactor constructor, we look for a
        class-level attribute named `inputs`.
    window_length : int, optional
        Number of rows to pass for each input.  If this argument is not passed
        to the CustomFactor constructor, we look for a class-level attribute
        named `window_length`.

    Notes
    -----
    Users implementing their own Factors should subclass CustomFactor and
    implement a method named `compute` with the following signature:

    .. code-block:: python

        def compute(self, today, assets, out, *inputs):
           ...

    On each simulation date, ``compute`` will be called with the current date,
    an array of sids, an output array, and an input array for each expression
    passed as inputs to the CustomFactor constructor.

    The specific types of the values passed to `compute` are as follows::

        today : np.datetime64[ns]
            Row label for the last row of all arrays passed as `inputs`.
        assets : np.array[int64, ndim=1]
            Column labels for `out` and `inputs`.
        out : np.array[self.dtype, ndim=1]
            Output array of the same shape as `assets`.  `compute` should write
            its desired return values into `out`.
        *inputs : tuple of np.array
            Raw data arrays corresponding to the values of `self.inputs`.

    ``compute`` functions should expect to be passed NaN values for dates on
    which no data was available for an asset.  This may include dates on which
    an asset did not yet exist.

    For example, if a CustomFactor requires 10 rows of close price data, and
    asset A started trading on Monday June 2nd, 2014, then on Tuesday, June
    3rd, 2014, the column of input data for asset A will have 9 leading NaNs
    for the preceding days on which data was not yet available.

    Examples
    --------

    A CustomFactor with pre-declared defaults:

    .. code-block:: python

        class TenDayRange(CustomFactor):
            """
            Computes the difference between the highest high in the last 10
            days and the lowest low.

            Pre-declares high and low as default inputs and `window_length` as
            10.
            """

            inputs = [USEquityPricing.high, USEquityPricing.low]
            window_length = 10

            def compute(self, today, assets, out, highs, lows):
                from numpy import nanmin, nanmax

                highest_highs = nanmax(highs, axis=0)
                lowest_lows = nanmin(lows, axis=0)
                out[:] = highest_highs - lowest_lows


        # Doesn't require passing inputs or window_length because they're
        # pre-declared as defaults for the TenDayRange class.
        ten_day_range = TenDayRange()

    A CustomFactor without defaults:

    .. code-block:: python

        class MedianValue(CustomFactor):
            """
            Computes the median value of an arbitrary single input over an
            arbitrary window.

            Does not declare any defaults, so values for `window_length` and
            `inputs` must be passed explicitly on every construction.
            """

            def compute(self, today, assets, out, data):
                from numpy import nanmedian
                out[:] = nanmedian(data, axis=0)

        # Values for `inputs` and `window_length` must be passed explicitly to
        # MedianValue.
        median_close10 = MedianValue([USEquityPricing.close], window_length=10)
        median_low15 = MedianValue([USEquityPricing.low], window_length=15)
    '''
    # Output dtype used by CustomFactor subclasses unless they override it.
    dtype = float64_dtype
    # NOTE(review): presumably the context manager the mixins enter around
    # ``compute``; its use is not visible in this file - confirm.
    ctx = nullctx()
714