Completed
Push — master ( ebb4fb...323695 )
by
unknown
01:25
created

zipline.pipeline.Term.mask()   A

Complexity

Conditions 1

Size

Total Lines 7

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 7
rs 9.4286
1
"""
Base classes for terms in the Pipeline API compute graph: the foundation of
Filters, Factors and Classifiers.
"""
4
from abc import ABCMeta, abstractproperty
5
from weakref import WeakValueDictionary
6
7
from numpy import dtype as dtype_class
8
from six import with_metaclass
9
10
from zipline.errors import (
11
    DTypeNotSpecified,
12
    InputTermNotAtomic,
13
    InvalidDType,
14
    TermInputsNotSpecified,
15
    WindowLengthNotSpecified,
16
)
17
from zipline.utils.memoize import lazyval
18
from zipline.utils.numpy_utils import bool_dtype, default_fillvalue_for_dtype
19
from zipline.utils.sentinel import sentinel
20
21
22
# Sentinel used as the default for Term attributes and constructor arguments
# that must be supplied either by a subclass or by the caller.  Compared with
# `is`, never `==`.
NotSpecified = sentinel('NotSpecified',
                        'Singleton sentinel value used for Term defaults.')
26
27
28
class Term(with_metaclass(ABCMeta, object)):
    """
    Base class for terms in a Pipeline API compute graph.

    Construction is memoized: ``Term.__new__`` computes an identity for the
    requested term via ``static_identity`` and returns a previously-built
    instance with the same identity when one is still alive.
    """
    # These are NotSpecified because a subclass is required to provide them.
    dtype = NotSpecified
    domain = NotSpecified

    # Subclasses aren't required to provide `params`.  The default behavior is
    # no params.
    params = ()

    # Memoization cache mapping identity -> term.  Values are held weakly, so
    # an entry vanishes once nothing else references that term.
    _term_cache = WeakValueDictionary()

    def __new__(cls,
                domain=domain,
                dtype=dtype,
                # params is explicitly not allowed to be passed to an instance.
                *args,
                **kwargs):
        """
        Memoized constructor for Terms.

        Caching previously-constructed Terms is useful because it allows us to
        only compute equivalent sub-expressions once when traversing a Pipeline
        dependency graph.

        Caching previously-constructed Terms is **sane** because terms and
        their inputs are both conceptually immutable.

        Values for the names listed in ``cls.params`` must be passed as
        keyword arguments; they are removed from ``kwargs`` by ``_pop_params``
        before the remaining arguments reach ``static_identity``/``_init``.
        """
        # Class-level attributes can be used to provide defaults for Term
        # subclasses.
        if domain is NotSpecified:
            domain = cls.domain

        dtype = cls._validate_dtype(dtype)
        params = cls._pop_params(kwargs)

        identity = cls.static_identity(
            domain=domain,
            dtype=dtype,
            params=params,
            *args, **kwargs
        )

        try:
            return cls._term_cache[identity]
        except KeyError:
            # _init runs before the cache assignment, so a term that fails
            # validation is never cached.
            new_instance = cls._term_cache[identity] = \
                super(Term, cls).__new__(cls)._init(
                    domain=domain,
                    dtype=dtype,
                    params=params,
                    *args, **kwargs
                )
            return new_instance

    @classmethod
    def _pop_params(cls, kwargs):
        """
        Pop entries from the `kwargs` passed to cls.__new__ based on the values
        in `cls.params`.

        Parameters
        ----------
        kwargs : dict
            The kwargs passed to cls.__new__.  Mutated in place: every key
            named in `cls.params` is removed.

        Returns
        -------
        params : tuple[(str, object)]
            A tuple of (name, value) pairs, in the order of `cls.params`.

        Raises
        ------
        TypeError
            Raised if any parameter values are not passed or not hashable.
        """
        param_values = []
        for key in cls.params:
            try:
                value = kwargs.pop(key)
            except KeyError:
                raise TypeError(
                    "{typename} expected a keyword parameter {name!r}.".format(
                        typename=cls.__name__,
                        name=key
                    )
                )
            try:
                # Check the *value* is hashable so that we fail here, with a
                # clear message, instead of later when the params tuple is
                # hashed as part of this term's cache identity.
                hash(value)
            except TypeError:
                raise TypeError(
                    "{typename} expected a hashable value for parameter "
                    "{name!r}, but got {value!r} instead.".format(
                        typename=cls.__name__,
                        name=key,
                        value=value,
                    )
                )
            param_values.append(value)
        return tuple(zip(cls.params, param_values))

    @classmethod
    def _validate_dtype(cls, passed_dtype):
        """
        Validate a `dtype` passed to Term.__new__.

        If passed_dtype is NotSpecified, then we try to fall back to a
        class-level attribute.  If a value is found at that point, we pass it
        to np.dtype so that users can pass `float` or `bool` and have them
        coerce to the appropriate numpy types.

        Returns
        -------
        validated : np.dtype
            The dtype to use for the new term.

        Raises
        ------
        DTypeNotSpecified
            When no dtype was passed to the instance, and the class doesn't
            provide a default.
        InvalidDType
            When either the class or the instance provides a value not
            coercible to a numpy dtype.
        """
        dtype = passed_dtype
        if dtype is NotSpecified:
            dtype = cls.dtype
        if dtype is NotSpecified:
            raise DTypeNotSpecified(termname=cls.__name__)
        try:
            dtype = dtype_class(dtype)
        except TypeError:
            raise InvalidDType(dtype=dtype, termname=cls.__name__)
        return dtype

    def __init__(self, *args, **kwargs):
        """
        Noop constructor to play nicely with our caching __new__.  Subclasses
        should implement _init instead of this method.

        When a class' __new__ returns an instance of that class, Python will
        automatically call __init__ on the object, even if a new object wasn't
        actually constructed.  Because we memoize instances, we often return an
        object that was already initialized from __new__, in which case we
        don't want to call __init__ again.

        Subclasses that need to initialize new instances should override _init,
        which is guaranteed to be called only once.
        """
        pass

    @classmethod
    def static_identity(cls, domain, dtype, params):
        """
        Return the identity of the Term that would be constructed from the
        given arguments.

        Identities that compare equal will cause us to return a cached instance
        rather than constructing a new one.  We do this primarily because it
        makes dependency resolution easier.

        This is a classmethod so that it can be called from Term.__new__ to
        determine whether to produce a new instance.
        """
        return (cls, domain, dtype, params)

    def _init(self, domain, dtype, params):
        """
        Initialize a freshly-allocated term.  Called once per cached instance,
        from Term.__new__.

        Parameters
        ----------
        domain : object
            Unused placeholder.
        dtype : np.dtype
            Dtype of this term's output.
        params : tuple[(str, hashable)]
            Tuple of key/value pairs of additional parameters.

        Returns
        -------
        self : Term
            The initialized instance (Term.__new__ caches and returns it).

        Raises
        ------
        TypeError
            If a parameter name collides with an existing attribute.
        AssertionError
            If a subclass overrode _validate without calling super().
        """
        self.domain = domain
        self.dtype = dtype

        for name, _ in params:
            if hasattr(self, name):
                raise TypeError(
                    "Parameter {name!r} conflicts with already-present "
                    "attribute with value {value!r}.".format(
                        name=name,
                        value=getattr(self, name),
                    )
                )
            # TODO: Consider setting these values as attributes and replacing
            # the boilerplate in NumericalExpression, Rank, and
            # PercentileFilter.

        self.params = dict(params)

        # Make sure that subclasses call super() in their _validate() methods
        # by setting this flag.  The base class implementation of _validate
        # sets this flag to True; if it is still False afterwards, some
        # override in the MRO skipped the super() call.
        self._subclass_called_super_validate = False
        self._validate()
        if not self._subclass_called_super_validate:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError(
                "Term._validate() was not called.\n"
                "This probably means that you overrode _validate"
                " without calling super()."
            )
        del self._subclass_called_super_validate

        return self

    def _validate(self):
        """
        Assert that this term is well-formed.  This should be called exactly
        once, at the end of Term._init().
        """
        # mark that we got here to enforce that subclasses overriding _validate
        # call super().
        self._subclass_called_super_validate = True

    @abstractproperty
    def inputs(self):
        """
        A tuple of other Terms that this Term requires for computation.
        """
        raise NotImplementedError()

    @abstractproperty
    def mask(self):
        """
        A 2D Filter representing asset/date pairs to include while
        computing this Term. (True means include; False means exclude.)
        """
        raise NotImplementedError()

    @lazyval
    def dependencies(self):
        """
        All terms this term depends on: its inputs followed by its mask.
        """
        return self.inputs + (self.mask,)

    @lazyval
    def atomic(self):
        """
        True when every (truthy) dependency of this term is the memoized
        AssetExists() term, i.e. it has no other terms upstream.
        """
        return not any(dep for dep in self.dependencies
                       if dep is not AssetExists())

    @lazyval
    def missing_value(self):
        """
        Fill value for this term's dtype, as chosen by
        default_fillvalue_for_dtype.
        """
        return default_fillvalue_for_dtype(self.dtype)
275
276
class AssetExists(Term):
    """
    Pseudo-filter that is True for an (asset, date) pair when the asset
    existed on that date.

    Terms that are not given an explicit mask fall back to this one.  It is
    conceptually a Filter (one boolean per asset per date), but it does not
    subclass Filter because `AssetExists` is computed directly by the
    PipelineEngine.

    See Also
    --------
    zipline.assets.AssetFinder.lifetimes
    """
    # One boolean per asset per date.
    dtype = bool_dtype

    # Nothing upstream: no dataset, no inputs or dependencies, no mask and no
    # lookback.  Plain class attributes here also satisfy Term's abstract
    # `inputs` / `mask` properties and shadow its lazy `dependencies`.
    dataset = None
    inputs = ()
    dependencies = ()
    mask = None
    extra_input_rows = 0

    def __repr__(self):
        return "AssetExists()"
299
300
301
class CompositeTerm(Term):
    """
    A Term computed from other Terms (`inputs`), optionally over a trailing
    window of `window_length` rows, restricted to the asset/date pairs
    selected by `mask`.
    """
    # Must be supplied by a subclass or the caller; a mask that is still
    # missing at construction time defaults to AssetExists().
    inputs = NotSpecified
    window_length = NotSpecified
    mask = NotSpecified

    def __new__(cls,
                inputs=inputs,
                window_length=window_length,
                mask=mask,
                *args, **kwargs):
        # Inputs: argument, else class default.  A value that is still
        # missing is reported by _validate, so only real values are
        # normalized here -- lists become tuples to keep identities hashable.
        if inputs is NotSpecified:
            inputs = cls.inputs
        if inputs is not NotSpecified:
            inputs = tuple(inputs)

        # Mask: argument, else class default, else AssetExists().
        if mask is NotSpecified:
            mask = cls.mask
            if mask is NotSpecified:
                mask = AssetExists()

        # Window length: argument, else class default.
        if window_length is NotSpecified:
            window_length = cls.window_length

        parent = super(CompositeTerm, cls)
        return parent.__new__(
            cls,
            inputs=inputs,
            mask=mask,
            window_length=window_length,
            *args, **kwargs
        )

    def _init(self, inputs, window_length, mask, *args, **kwargs):
        # Record the composite-specific state first; Term._init then handles
        # domain/dtype/params and runs _validate, which reads these.
        self.inputs = inputs
        self.window_length = window_length
        self.mask = mask
        return super(CompositeTerm, self)._init(*args, **kwargs)

    @classmethod
    def static_identity(cls, inputs, window_length, mask, *args, **kwargs):
        # Extend the base identity with the fields that distinguish
        # composite terms from one another.
        base_identity = super(CompositeTerm, cls).static_identity(
            *args, **kwargs
        )
        return (base_identity, inputs, window_length, mask)

    def _validate(self):
        """
        Check that inputs, window_length and mask were all resolved, and that
        a windowed term only consumes atomic inputs.  Called exactly once, at
        the end of Term._init().
        """
        termname = type(self).__name__
        if self.inputs is NotSpecified:
            raise TermInputsNotSpecified(termname=termname)
        if self.window_length is NotSpecified:
            raise WindowLengthNotSpecified(termname=termname)
        if self.mask is NotSpecified:
            # __new__ always fills in a mask, so getting here is a bug in our
            # code rather than user error.
            raise AssertionError("{term} has no mask".format(term=self))

        if self.window_length:
            # Windowed terms may only take atomic terms as inputs.
            for input_term in self.inputs:
                if not input_term.atomic:
                    raise InputTermNotAtomic(parent=self, child=input_term)

        return super(CompositeTerm, self)._validate()

    def _compute(self, inputs, dates, assets, mask):
        """
        Hook for subclasses to perform the actual computation.

        Named `_compute` rather than `compute` because `compute` is reserved
        for user-supplied functions in CustomFactor.
        """
        raise NotImplementedError()

    @lazyval
    def windowed(self):
        """
        Whether this term represents a trailing window computation, i.e. its
        window_length is specified and positive.

        A truthy value means the term's inputs are supplied as AdjustedArray
        windows; a falsey value means plain np.ndarray inputs.
        NOTE(review): the consumers of this flag live outside this module.
        """
        length = self.window_length
        if length is NotSpecified:
            return False
        return length > 0

    @lazyval
    def extra_input_rows(self):
        """
        Number of rows, beyond the requested dates, that each input must
        provide for this term's trailing window (never negative).
        """
        lookback = self.window_length - 1
        return max(0, lookback)

    def __repr__(self):
        template = "{type}({inputs}, window_length={window_length})"
        return template.format(
            type=type(self).__name__,
            inputs=self.inputs,
            window_length=self.window_length,
        )
409