Completed
Pull Request — master (#836)
by
unknown
01:28
created

zipline.pipeline.CustomTermMixin.__init__()   A

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 4
rs 10
1
"""
2
Base class for Filters, Factors and Classifiers
3
"""
4
from abc import ABCMeta, abstractproperty
5
from weakref import WeakValueDictionary
6
7
from numpy import bool_, full, nan
8
from six import with_metaclass
9
10
from zipline.errors import (
11
    DTypeNotSpecified,
12
    InputTermNotAtomic,
13
    TermInputsNotSpecified,
14
    WindowLengthNotPositive,
15
    WindowLengthNotSpecified,
16
)
17
from zipline.utils.memoize import lazyval
18
from zipline.utils.sentinel import sentinel
19
20
21
# Singleton sentinel meaning "this Term parameter was not supplied".
# Throughout this module it is compared by identity (``x is NotSpecified``)
# to decide whether to fall back to a class-level default or, in
# ``_validate``, to report a missing required parameter.
NotSpecified = sentinel(
    'NotSpecified',
    'Singleton sentinel value used for Term defaults.',
)
25
26
27
class Term(with_metaclass(ABCMeta, object)):
    """
    Base class for terms in a Pipeline API compute graph.

    Term instances are memoized: constructing a Term whose
    ``static_identity`` matches that of a live cached instance returns the
    cached instance instead of building a new one.  Subclasses customize
    construction by overriding ``_init`` and ``static_identity`` rather than
    ``__init__``.
    """
    # These are NotSpecified because a subclass is required to provide them.
    dtype = NotSpecified
    domain = NotSpecified

    # One cache shared by every Term subclass.  Keys are the values returned
    # by ``static_identity``; the base implementation includes ``cls`` in the
    # key.  Values are held weakly, so an entry disappears once nothing else
    # references the Term.
    _term_cache = WeakValueDictionary()

    def __new__(cls,
                domain=NotSpecified,
                dtype=NotSpecified,
                *args,
                **kwargs):
        """
        Memoized constructor for Terms.

        Caching previously-constructed Terms is useful because it allows us to
        only compute equivalent sub-expressions once when traversing a Pipeline
        dependency graph.

        Caching previously-constructed Terms is **sane** because terms and
        their inputs are both conceptually immutable.
        """
        # Class-level attributes can be used to provide defaults for Term
        # subclasses.
        if domain is NotSpecified:
            domain = cls.domain

        if dtype is NotSpecified:
            dtype = cls.dtype

        identity = cls.static_identity(
            domain=domain,
            dtype=dtype,
            *args, **kwargs
        )

        try:
            return cls._term_cache[identity]
        except KeyError:
            # Cache miss.  The new instance is built *outside* this handler
            # so that any validation error raised by ``_init`` is not
            # implicitly chained to this internal KeyError in tracebacks.
            pass

        # ``_init`` returns the initialized instance; it is only stored in
        # the cache if initialization (including validation) succeeds.
        new_instance = cls._term_cache[identity] = \
            super(Term, cls).__new__(cls)._init(
                domain=domain,
                dtype=dtype,
                *args, **kwargs
            )
        return new_instance

    def __init__(self, *args, **kwargs):
        """
        Noop constructor to play nicely with our caching __new__.  Subclasses
        should implement _init instead of this method.

        When a class' __new__ returns an instance of that class, Python will
        automatically call __init__ on the object, even if a new object wasn't
        actually constructed.  Because we memoize instances, we often return an
        object that was already initialized from __new__, in which case we
        don't want to call __init__ again.

        Subclasses that need to initialize new instances should override _init,
        which is guaranteed to be called only once.
        """
        pass

    def _init(self, domain, dtype):
        """
        Store ``domain`` and ``dtype``, validate, and return ``self``.

        Called exactly once per new instance, from ``__new__``.
        """
        self.domain = domain
        self.dtype = dtype

        self._validate()
        return self

    @classmethod
    def static_identity(cls, domain, dtype):
        """
        Return the identity of the Term that would be constructed from the
        given arguments.

        Identities that compare equal will cause us to return a cached instance
        rather than constructing a new one.  We do this primarily because it
        makes dependency resolution easier.

        This is a classmethod so that it can be called from Term.__new__ to
        determine whether to produce a new instance.
        """
        return (cls, domain, dtype)

    def _validate(self):
        """
        Assert that this term is well-formed.  This should be called exactly
        once, at the end of Term._init().

        Raises
        ------
        DTypeNotSpecified
            If no dtype was supplied and the class provides no default.
        """
        if self.dtype is NotSpecified:
            raise DTypeNotSpecified(termname=type(self).__name__)

    @abstractproperty
    def inputs(self):
        """
        A tuple of other Terms that this Term requires for computation.
        """
        raise NotImplementedError()

    @abstractproperty
    def mask(self):
        """
        A 2D Filter representing asset/date pairs to include while
        computing this Term. (True means include; False means exclude.)
        """
        raise NotImplementedError()

    @lazyval
    def dependencies(self):
        """
        All Terms this Term depends on: its ``inputs`` followed by its
        ``mask``.
        """
        return self.inputs + (self.mask,)

    @lazyval
    def atomic(self):
        """
        Whether this term has no dependencies other than ``AssetExists()``.

        Falsy entries in ``dependencies`` (e.g. ``None``) are ignored.
        """
        return not any(dep for dep in self.dependencies
                       if dep is not AssetExists())
148
149
class AssetExists(Term):
    """
    Pseudo-filter telling whether an asset existed on a given day.

    Any term constructed without an explicit mask falls back to this one.

    Conceptually it behaves like a Filter, because it yields a boolean for
    every (asset, date) pair.  It deliberately does not subclass Filter: the
    PipelineEngine computes `AssetExists` itself.

    See Also
    --------
    zipline.assets.AssetFinder.lifetimes
    """
    # Produced directly by the engine, so this term has no inputs, no
    # dependencies, no mask of its own and needs no extra lookback rows.
    inputs = ()
    dependencies = ()
    mask = None
    extra_input_rows = 0
    dataset = None

    # One boolean per (asset, date).
    dtype = bool_

    def __repr__(self):
        name = "AssetExists"
        return name + "()"
172
173
174
# TODO: Move mixins to a separate file?
class SingleInputMixin(object):
    """
    Validation mixin requiring that a term has exactly one input.

    Intended to be mixed into a term class that defines ``inputs`` and a
    ``_validate`` chain (e.g. a ``CompositeTerm`` subclass).
    """

    def _validate(self):
        """
        Raise ValueError unless ``self.inputs`` has exactly one element, then
        continue the ``_validate`` chain.
        """
        # If no inputs were supplied at all, ``len`` on the NotSpecified
        # sentinel would raise a confusing TypeError.  Skip the count check
        # and let the rest of the chain report the missing inputs
        # (CompositeTerm._validate raises TermInputsNotSpecified).
        if self.inputs is not NotSpecified:
            num_inputs = len(self.inputs)
            if num_inputs != 1:
                raise ValueError(
                    "{typename} expects only one input, "
                    "but received {num_inputs} instead.".format(
                        typename=type(self).__name__,
                        num_inputs=num_inputs
                    )
                )
        return super(SingleInputMixin, self)._validate()
188
189
190
class RequiredWindowLengthMixin(object):
    """
    Validation mixin requiring a positive ``window_length``.

    Relies on ``window_length`` and ``windowed`` being provided by the term
    class it is mixed into (e.g. a ``CompositeTerm`` subclass).
    """

    def _validate(self):
        """
        Raise WindowLengthNotPositive if a window_length was given but is not
        positive, then continue the ``_validate`` chain.
        """
        # A window_length that was never supplied is a different error from a
        # non-positive one.  Defer it to the rest of the chain
        # (CompositeTerm._validate raises WindowLengthNotSpecified) instead of
        # raising WindowLengthNotPositive with the sentinel as its value.
        if self.window_length is not NotSpecified and not self.windowed:
            raise WindowLengthNotPositive(window_length=self.window_length)
        return super(RequiredWindowLengthMixin, self)._validate()
195
196
197
class CustomTermMixin(object):
    """
    Mixin for user-defined rolling-window Terms.

    Implements `_compute` in terms of a user-defined `compute` function, which
    is mapped over the input windows.

    Used by CustomFactor, CustomFilter, CustomClassifier, etc.
    """

    def __new__(cls, inputs=NotSpecified, window_length=NotSpecified,
                mask=NotSpecified):
        """
        Construct (or fetch a cached) custom term.

        Parameters
        ----------
        inputs : iterable of Term, optional
            Terms whose windows are passed to ``compute``.
        window_length : int, optional
            Number of rows in each window passed to ``compute``.
        mask : Filter, optional
            Asset/date pairs to include.  All three arguments are forwarded
            to the next ``__new__`` in the MRO; ``CompositeTerm.__new__``
            resolves NotSpecified values to class-level defaults (and, for
            ``mask``, finally to ``AssetExists()``).
        """
        return super(CustomTermMixin, cls).__new__(
            cls,
            inputs=inputs,
            window_length=window_length,
            mask=mask,
        )

    def __init__(self, inputs=NotSpecified, window_length=NotSpecified,
                 mask=NotSpecified):
        # Python calls __init__ with the same arguments as __new__, so the
        # signatures must match.  Real initialization happens in ``_init``
        # (see Term.__init__); this only keeps the cooperative chain intact.
        super(CustomTermMixin, self).__init__(
            inputs=inputs,
            window_length=window_length,
            mask=mask,
        )

    def compute(self, today, assets, out, *arrays):
        """
        Override this method with a function that writes a value into `out`.
        """
        raise NotImplementedError()

    def _compute(self, windows, dates, assets, mask):
        """
        Call the user's `compute` function on each window with a pre-built
        output array.

        Parameters
        ----------
        windows : iterable of iterators
            One iterator per input; each ``next()`` yields that input's
            window for the next date.
        dates : sequence
            Dates to compute, one output row per date.
        assets : sequence
            Passed through to ``compute`` unchanged.
        mask : ndarray of bool
            Entries that are False are reset to NaN after computation.

        Returns
        -------
        out : ndarray
            Array with ``mask``'s shape and this term's dtype.
        """
        # TODO: Make mask available to user's `compute`.
        compute = self.compute
        # NOTE(review): filling with NaN only has a "missing" meaning for
        # float dtypes; for bool/int dtypes NaN is cast to some other value.
        # Confirm this is acceptable for CustomFilter/CustomClassifier.
        out = full(mask.shape, nan, dtype=self.dtype)
        # ``self.ctx`` is not defined in this mixin; presumably the concrete
        # subclass supplies a context manager (e.g. numpy error state).
        with self.ctx:
            # TODO: Consider pre-filtering columns that are all-nan at each
            # time-step?
            for idx, date in enumerate(dates):
                compute(
                    date,
                    assets,
                    out[idx],
                    *(next(w) for w in windows)
                )
        out[~mask] = nan
        return out

    def short_repr(self):
        """Short representation: class name plus window length."""
        return type(self).__name__ + '(%d)' % self.window_length
250
251
252
class CompositeTerm(Term):
    """
    A Term computed from other Terms (its ``inputs``), optionally over a
    trailing window of ``window_length`` rows and restricted by ``mask``.

    Each of ``inputs``, ``window_length`` and ``mask`` may be passed to the
    constructor or provided as a class-level default on a subclass.
    """
    inputs = NotSpecified
    window_length = NotSpecified
    mask = NotSpecified

    def __new__(cls, inputs=NotSpecified, window_length=NotSpecified,
                mask=NotSpecified, *args, **kwargs):
        # Fall back to class-level defaults for anything not passed in.
        inputs = cls.inputs if inputs is NotSpecified else inputs
        # A still-missing ``inputs`` is reported later by _validate.  Any
        # other value (class defaults may be lists) becomes a tuple so that it
        # is hashable and can participate in the term cache key.
        if inputs is not NotSpecified:
            inputs = tuple(inputs)

        mask = cls.mask if mask is NotSpecified else mask
        if mask is NotSpecified:
            # Neither the caller nor the class chose a mask.
            mask = AssetExists()

        window_length = (
            cls.window_length if window_length is NotSpecified
            else window_length
        )

        return super(CompositeTerm, cls).__new__(
            cls,
            inputs=inputs,
            mask=mask,
            window_length=window_length,
            *args, **kwargs
        )

    def _init(self, inputs, window_length, mask, *args, **kwargs):
        # Record our own parameters, then hand the remaining ones to the next
        # ``_init`` in the MRO (Term._init stores domain/dtype and validates).
        self.inputs = inputs
        self.mask = mask
        self.window_length = window_length
        return super(CompositeTerm, self)._init(*args, **kwargs)

    @classmethod
    def static_identity(cls, inputs, window_length, mask, *args, **kwargs):
        base_identity = super(CompositeTerm, cls).static_identity(
            *args, **kwargs
        )
        return (base_identity, inputs, window_length, mask)

    def _validate(self):
        """
        Check that inputs, window_length and mask are all set and that a
        windowed term only consumes atomic inputs, then continue the
        ``_validate`` chain.  Called once, from the end of Term._init().
        """
        termname = type(self).__name__
        if self.inputs is NotSpecified:
            raise TermInputsNotSpecified(termname=termname)
        if self.window_length is NotSpecified:
            raise WindowLengthNotSpecified(termname=termname)
        if self.mask is NotSpecified:
            # This isn't user error, this is a bug in our code.
            raise AssertionError("{term} has no mask".format(term=self))

        if self.window_length:
            # Windowed computations may only read atomic terms; report the
            # first offending input.
            offender = next(
                (child for child in self.inputs if not child.atomic), None
            )
            if offender is not None:
                raise InputTermNotAtomic(parent=self, child=offender)

        return super(CompositeTerm, self)._validate()

    def _compute(self, inputs, dates, assets, mask):
        """
        Perform the actual computation; concrete subclasses must override.

        Named `_compute` because `compute` is the hook reserved for
        user-written functions in CustomFactor.
        """
        raise NotImplementedError()

    @lazyval
    def windowed(self):
        """
        True when a window_length was given and it is positive.

        The engine is expected to feed windowed terms AdjustedArray windows
        and non-windowed terms plain np.ndarray inputs.
        """
        length = self.window_length
        if length is NotSpecified:
            return False
        return length > 0

    @lazyval
    def extra_input_rows(self):
        """
        How many rows beyond the requested range each input must supply:
        ``window_length - 1``, never negative.
        """
        rows = self.window_length - 1
        return rows if rows > 0 else 0

    def __repr__(self):
        return "{0}({1}, window_length={2})".format(
            type(self).__name__,
            self.inputs,
            self.window_length,
        )
356