Completed
Push — master ( ebb4fb...323695 )
by
unknown
01:25
created

test_boolean_binops()   B

Complexity

Conditions 2

Size

Total Lines 24

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 24
rs 8.9714
1
from operator import (
2
    add,
3
    and_,
4
    ge,
5
    gt,
6
    le,
7
    lt,
8
    methodcaller,
9
    mul,
10
    ne,
11
    or_,
12
)
13
from unittest import TestCase
14
15
import numpy
16
from numpy import (
17
    arange,
18
    eye,
19
    float64,
20
    full,
21
    isnan,
22
    zeros,
23
)
24
from pandas import (
25
    DataFrame,
26
    date_range,
27
    Int64Index,
28
)
29
30
from zipline.pipeline import Factor
31
from zipline.pipeline.expression import (
32
    NumericalExpression,
33
    NUMEXPR_MATH_FUNCS,
34
)
35
36
from zipline.utils.numpy_utils import datetime64ns_dtype, float64_dtype
37
from zipline.utils.test_utils import check_arrays
38
39
40
class F(Factor):
    """Input-free float64 factor whose values are supplied by test fixtures."""
    inputs = ()
    window_length = 0
    dtype = float64_dtype
44
45
46
class G(Factor):
    """Second input-free float64 factor whose values come from test fixtures."""
    inputs = ()
    window_length = 0
    dtype = float64_dtype
50
51
52
class H(Factor):
    """Third input-free float64 factor whose values come from test fixtures."""
    inputs = ()
    window_length = 0
    dtype = float64_dtype
56
57
58
class DateFactor(Factor):
    """Input-free factor of dtype datetime64[ns], used to test dtype errors."""
    inputs = ()
    window_length = 0
    dtype = datetime64ns_dtype
62
63
64
class NumericalExpressionTestCase(TestCase):
    """Tests for validating and evaluating NumericalExpression terms.

    Expressions are built from the input-free test factors (F, G, H and
    DateFactor), evaluated through ``NumericalExpression._compute`` against
    the arrays in ``self.fake_raw_data``, and compared with the equivalent
    computation done directly with NumPy / plain Python arithmetic.
    """

    def setUp(self):
        self.dates = date_range('2014-01-01', periods=5, freq='D')
        self.assets = Int64Index(range(5))
        self.f = F()
        self.g = G()
        self.h = H()
        self.d = DateFactor()
        # Give the float factors an explicit dtype: since NumPy 1.12, ``full``
        # infers its dtype from the fill value, so ``full(shape, 3)`` would
        # produce int64 data for factors that are declared as float64.
        self.fake_raw_data = {
            self.f: full((5, 5), 3, dtype=float64),
            self.g: full((5, 5), 2, dtype=float64),
            self.h: full((5, 5), 1, dtype=float64),
            self.d: full((5, 5), 0, dtype='datetime64[ns]'),
        }
        # All-True mask over the 5 dates x 5 assets used by every test.
        self.mask = DataFrame(True, index=self.dates, columns=self.assets)

    def check_output(self, expr, expected):
        """Evaluate ``expr`` on ``self.fake_raw_data`` and compare to ``expected``.

        The input arrays are looked up by ``expr.inputs``; dates, assets and
        the boolean mask come from ``self.mask``.
        """
        result = expr._compute(
            [self.fake_raw_data[input_] for input_ in expr.inputs],
            self.mask.index,
            self.mask.columns,
            self.mask.values,
        )
        check_arrays(result, expected)

    def check_constant_output(self, expr, expected):
        """Check that ``expr`` evaluates to the scalar ``expected`` everywhere.

        ``expected`` is asserted to be non-NaN up front. Presumably this guards
        against a NaN expectation (e.g. from a function evaluated outside its
        domain) letting a NaN result pass unnoticed.
        """
        self.assertFalse(isnan(expected))
        return self.check_output(expr, full((5, 5), expected))

    def test_validate_good(self):
        f = self.f
        g = self.g

        NumericalExpression("x_0", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 ", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 + x_0", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 + 2", (f,), dtype=float64_dtype)
        NumericalExpression("2 * x_0", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 + x_1", (f, g), dtype=float64_dtype)
        NumericalExpression("x_0 + x_1 + x_0", (f, g), dtype=float64_dtype)
        NumericalExpression("x_0 + 1 + x_1", (f, g), dtype=float64_dtype)

    def test_validate_bad(self):
        f, g, h = self.f, self.g, self.h

        # Too few inputs.
        with self.assertRaises(ValueError):
            NumericalExpression("x_0", (), dtype=float64_dtype)
        with self.assertRaises(ValueError):
            NumericalExpression("x_0 + x_1", (f,), dtype=float64_dtype)

        # Too many inputs.
        with self.assertRaises(ValueError):
            NumericalExpression("x_0", (f, g), dtype=float64_dtype)
        with self.assertRaises(ValueError):
            NumericalExpression("x_0 + x_1", (f, g, h), dtype=float64_dtype)

        # Invalid variable name.
        with self.assertRaises(ValueError):
            NumericalExpression("x_0x_1", (f,), dtype=float64_dtype)
        with self.assertRaises(ValueError):
            NumericalExpression("x_0x_1", (f, g), dtype=float64_dtype)

        # Variable index must start at 0.
        with self.assertRaises(ValueError):
            NumericalExpression("x_1", (f,), dtype=float64_dtype)

        # Scalar operands must be numeric.
        with self.assertRaises(TypeError):
            "2" + f
        with self.assertRaises(TypeError):
            f + "2"
        with self.assertRaises(TypeError):
            f > "2"

        # Arithmetic and comparison operators reject Filter operands
        # (a Filter here is the result of a comparison such as ``f > 2``).
        with self.assertRaises(TypeError):
            f + (f > 2)
        with self.assertRaises(TypeError):
            (f > f) > f

    def test_combine_datetimes(self):
        # Check both + and * so we confirm the operator symbol in the error
        # message tracks the operator that was actually applied.
        for op, sym in ((add, '+'), (mul, '*')):
            with self.assertRaises(TypeError) as e:
                op(self.d, self.d)
            message = e.exception.args[0]
            expected = (
                "Don't know how to compute datetime64[ns] {sym} "
                "datetime64[ns].\n"
                "Arithmetic operators are only supported on Factors of dtype "
                "'float64'."
            ).format(sym=sym)
            self.assertEqual(message, expected)

    def test_combine_datetime_with_float(self):
        # NOTE(review): only a float64 Factor operand is exercised here.
        # Scalar operands (e.g. ``float64(1.0)`` and ``1.0``) are not covered;
        # the exact error they produce has not been confirmed. TODO add them.
        for op, sym in ((add, '+'), (mul, '*')):
            with self.assertRaises(TypeError) as e:
                op(self.f, self.d)
            message = e.exception.args[0]
            expected = (
                "Don't know how to compute float64 {sym} datetime64[ns].\n"
                "Arithmetic operators are only supported on Factors of "
                "dtype 'float64'."
            ).format(sym=sym)
            self.assertEqual(message, expected)

            with self.assertRaises(TypeError) as e:
                op(self.d, self.f)
            message = e.exception.args[0]
            expected = (
                "Don't know how to compute datetime64[ns] {sym} float64.\n"
                "Arithmetic operators are only supported on Factors of "
                "dtype 'float64'."
            ).format(sym=sym)
            self.assertEqual(message, expected)

    def test_negate_datetime(self):
        with self.assertRaises(TypeError) as e:
            -self.d

        message = e.exception.args[0]
        expected = (
            "Can't apply unary operator '-' to instance of "
            "'DateFactor' with dtype 'datetime64[ns]'.\n"
            "'-' is only supported for Factors of dtype 'float64'."
        )
        self.assertEqual(message, expected)

    def test_negate(self):
        f, g = self.f, self.g

        self.check_constant_output(-f, -3.0)
        self.check_constant_output(--f, 3.0)
        self.check_constant_output(---f, -3.0)

        self.check_constant_output(-(f + f), -6.0)
        self.check_constant_output(-f + -f, -6.0)
        self.check_constant_output(-(-f + -f), 6.0)

        self.check_constant_output(f + -g, 1.0)
        self.check_constant_output(f - -g, 5.0)

        self.check_constant_output(-(f + g) + (f + g), 0.0)
        self.check_constant_output((f + g) + -(f + g), 0.0)
        self.check_constant_output(-(f + g) + -(f + g), -10.0)

    def test_add(self):
        f, g = self.f, self.g

        self.check_constant_output(f + g, 5.0)

        self.check_constant_output((1 + f) + g, 6.0)
        self.check_constant_output(1 + (f + g), 6.0)
        self.check_constant_output((f + 1) + g, 6.0)
        self.check_constant_output(f + (1 + g), 6.0)
        self.check_constant_output((f + g) + 1, 6.0)
        self.check_constant_output(f + (g + 1), 6.0)

        self.check_constant_output((f + f) + f, 9.0)
        self.check_constant_output(f + (f + f), 9.0)

        self.check_constant_output((f + g) + f, 8.0)
        self.check_constant_output(f + (g + f), 8.0)

        self.check_constant_output((f + g) + (f + g), 10.0)
        self.check_constant_output((f + g) + (g + f), 10.0)
        self.check_constant_output((g + f) + (f + g), 10.0)
        self.check_constant_output((g + f) + (g + f), 10.0)

    def test_subtract(self):
        f, g = self.f, self.g

        self.check_constant_output(f - g, 1.0)  # 3 - 2

        self.check_constant_output((1 - f) - g, -4.)   # (1 - 3) - 2
        self.check_constant_output(1 - (f - g), 0.0)   # 1 - (3 - 2)
        self.check_constant_output((f - 1) - g, 0.0)   # (3 - 1) - 2
        self.check_constant_output(f - (1 - g), 4.0)   # 3 - (1 - 2)
        self.check_constant_output((f - g) - 1, 0.0)   # (3 - 2) - 1
        self.check_constant_output(f - (g - 1), 2.0)   # 3 - (2 - 1)

        self.check_constant_output((f - f) - f, -3.)   # (3 - 3) - 3
        self.check_constant_output(f - (f - f), 3.0)   # 3 - (3 - 3)

        self.check_constant_output((f - g) - f, -2.)   # (3 - 2) - 3
        self.check_constant_output(f - (g - f), 4.0)   # 3 - (2 - 3)

        self.check_constant_output((f - g) - (f - g), 0.0)  # (3 - 2) - (3 - 2)
        self.check_constant_output((f - g) - (g - f), 2.0)  # (3 - 2) - (2 - 3)
        self.check_constant_output((g - f) - (f - g), -2.)  # (2 - 3) - (3 - 2)
        self.check_constant_output((g - f) - (g - f), 0.0)  # (2 - 3) - (2 - 3)

    def test_multiply(self):
        f, g = self.f, self.g

        self.check_constant_output(f * g, 6.0)

        self.check_constant_output((2 * f) * g, 12.0)
        self.check_constant_output(2 * (f * g), 12.0)
        self.check_constant_output((f * 2) * g, 12.0)
        self.check_constant_output(f * (2 * g), 12.0)
        self.check_constant_output((f * g) * 2, 12.0)
        self.check_constant_output(f * (g * 2), 12.0)

        self.check_constant_output((f * f) * f, 27.0)
        self.check_constant_output(f * (f * f), 27.0)

        self.check_constant_output((f * g) * f, 18.0)
        self.check_constant_output(f * (g * f), 18.0)

        self.check_constant_output((f * g) * (f * g), 36.0)
        self.check_constant_output((f * g) * (g * f), 36.0)
        self.check_constant_output((g * f) * (f * g), 36.0)
        self.check_constant_output((g * f) * (g * f), 36.0)

        self.check_constant_output(f * f * f * 0 * f * f, 0.0)

    def test_divide(self):
        f, g = self.f, self.g

        self.check_constant_output(f / g, 3.0 / 2.0)

        self.check_constant_output(
            (2 / f) / g,
            (2 / 3.0) / 2.0
        )
        self.check_constant_output(
            2 / (f / g),
            2 / (3.0 / 2.0),
        )
        self.check_constant_output(
            (f / 2) / g,
            (3.0 / 2) / 2.0,
        )
        self.check_constant_output(
            f / (2 / g),
            3.0 / (2 / 2.0),
        )
        self.check_constant_output(
            (f / g) / 2,
            (3.0 / 2.0) / 2,
        )
        self.check_constant_output(
            f / (g / 2),
            3.0 / (2.0 / 2),
        )
        self.check_constant_output(
            (f / f) / f,
            (3.0 / 3.0) / 3.0
        )
        self.check_constant_output(
            f / (f / f),
            3.0 / (3.0 / 3.0),
        )
        self.check_constant_output(
            (f / g) / f,
            (3.0 / 2.0) / 3.0,
        )
        self.check_constant_output(
            f / (g / f),
            3.0 / (2.0 / 3.0),
        )

        self.check_constant_output(
            (f / g) / (f / g),
            (3.0 / 2.0) / (3.0 / 2.0),
        )
        self.check_constant_output(
            (f / g) / (g / f),
            (3.0 / 2.0) / (2.0 / 3.0),
        )
        self.check_constant_output(
            (g / f) / (f / g),
            (2.0 / 3.0) / (3.0 / 2.0),
        )
        self.check_constant_output(
            (g / f) / (g / f),
            (2.0 / 3.0) / (2.0 / 3.0),
        )

    def test_pow(self):
        f, g = self.f, self.g

        self.check_constant_output(f ** g, 3.0 ** 2)
        self.check_constant_output(2 ** f, 2.0 ** 3)
        self.check_constant_output(f ** 2, 3.0 ** 2)

        self.check_constant_output((f + g) ** 2, (3.0 + 2.0) ** 2)
        self.check_constant_output(2 ** (f + g), 2 ** (3.0 + 2.0))

        self.check_constant_output(f ** (f ** g), 3.0 ** (3.0 ** 2.0))
        self.check_constant_output((f ** f) ** g, (3.0 ** 3.0) ** 2.0)

        self.check_constant_output((f ** g) ** (f ** g), 9.0 ** 9.0)
        self.check_constant_output((f ** g) ** (g ** f), 9.0 ** 8.0)
        self.check_constant_output((g ** f) ** (f ** g), 8.0 ** 9.0)
        self.check_constant_output((g ** f) ** (g ** f), 8.0 ** 8.0)

    def test_mod(self):
        f, g = self.f, self.g

        self.check_constant_output(f % g, 3.0 % 2.0)
        self.check_constant_output(f % 2.0, 3.0 % 2.0)
        self.check_constant_output(g % f, 2.0 % 3.0)

        self.check_constant_output((f + g) % 2, (3.0 + 2.0) % 2)
        self.check_constant_output(2 % (f + g), 2 % (3.0 + 2.0))

        self.check_constant_output(f % (f % g), 3.0 % (3.0 % 2.0))
        self.check_constant_output((f % f) % g, (3.0 % 3.0) % 2.0)

        self.check_constant_output((f + g) % (f * g), 5.0 % 6.0)

    def test_math_functions(self):
        f, g = self.f, self.g

        fake_raw_data = self.fake_raw_data
        alt_fake_raw_data = {
            self.f: full((5, 5), .5),
            self.g: full((5, 5), -.5),
        }

        for funcname in NUMEXPR_MATH_FUNCS:
            method = methodcaller(funcname)
            func = getattr(numpy, funcname)

            # arcsin and arccos are only defined on [-1, 1], and arctanh on
            # (-1, 1). These functions get alternate inputs (0.5 and -0.5) that
            # keep every argument below, including ``f + .25`` and ``f + g``,
            # inside those domains.
            if funcname in ('arcsin', 'arccos', 'arctanh'):
                self.fake_raw_data = alt_fake_raw_data
            else:
                self.fake_raw_data = fake_raw_data

            f_val = self.fake_raw_data[f][0, 0]
            g_val = self.fake_raw_data[g][0, 0]

            self.check_constant_output(method(f), func(f_val))
            self.check_constant_output(method(g), func(g_val))

            self.check_constant_output(method(f) + 1, func(f_val) + 1)
            self.check_constant_output(1 + method(f), 1 + func(f_val))

            self.check_constant_output(method(f + .25), func(f_val + .25))
            self.check_constant_output(method(.25 + f), func(.25 + f_val))

            self.check_constant_output(
                method(f) + method(g),
                func(f_val) + func(g_val),
            )
            self.check_constant_output(
                method(f + g),
                func(f_val + g_val),
            )

    def test_comparisons(self):
        f, g, h = self.f, self.g, self.h
        # float64 data to match the factors' declared dtype; g differs from f
        # only on the diagonal, where it is exactly 1 smaller.
        self.fake_raw_data = {
            f: arange(25, dtype=float64).reshape(5, 5),
            g: arange(25, dtype=float64).reshape(5, 5) - eye(5),
            h: full((5, 5), 5, dtype=float64),
        }
        f_data = self.fake_raw_data[f]
        g_data = self.fake_raw_data[g]

        cases = [
            # Sanity Check with hand-computed values.
            (f, g, eye(5), zeros((5, 5))),
            (f, 10, f_data, 10),
            (10, f, 10, f_data),
            (f, f, f_data, f_data),
            (f + 1, f, f_data + 1, f_data),
            (1 + f, f, 1 + f_data, f_data),
            (f, g, f_data, g_data),
            (f + 1, g, f_data + 1, g_data),
            (f, g + 1, f_data, g_data + 1),
            (f + 1, g + 1, f_data + 1, g_data + 1),
            ((f + g) / 2, f ** 2, (f_data + g_data) / 2, f_data ** 2),
        ]
        for op in (gt, ge, lt, le, ne):
            for expr_lhs, expr_rhs, expected_lhs, expected_rhs in cases:
                self.check_output(
                    op(expr_lhs, expr_rhs),
                    op(expected_lhs, expected_rhs),
                )

    def test_boolean_binops(self):
        f, g, h = self.f, self.g, self.h
        # float64 data to match the factors' declared dtype.
        self.fake_raw_data = {
            f: arange(25, dtype=float64).reshape(5, 5),
            g: arange(25, dtype=float64).reshape(5, 5) - eye(5),
            h: full((5, 5), 5, dtype=float64),
        }

        # Should be True on the diagonal.
        eye_filter = f > g
        # Should be True in the first row only.
        first_row_filter = f < h

        eye_mask = eye(5, dtype=bool)
        first_row_mask = zeros((5, 5), dtype=bool)
        first_row_mask[0] = 1

        self.check_output(eye_filter, eye_mask)
        self.check_output(first_row_filter, first_row_mask)

        for op in (and_, or_):  # NumExpr doesn't support xor.
            self.check_output(
                op(eye_filter, first_row_filter),
                op(eye_mask, first_row_mask),
            )
486