1
|
|
|
from operator import ( |
2
|
|
|
add, |
3
|
|
|
and_, |
4
|
|
|
ge, |
5
|
|
|
gt, |
6
|
|
|
le, |
7
|
|
|
lt, |
8
|
|
|
methodcaller, |
9
|
|
|
mul, |
10
|
|
|
ne, |
11
|
|
|
or_, |
12
|
|
|
) |
13
|
|
|
from unittest import TestCase |
14
|
|
|
|
15
|
|
|
import numpy |
16
|
|
|
from numpy import ( |
17
|
|
|
arange, |
18
|
|
|
eye, |
19
|
|
|
float64, |
20
|
|
|
full, |
21
|
|
|
isnan, |
22
|
|
|
zeros, |
23
|
|
|
) |
24
|
|
|
from pandas import ( |
25
|
|
|
DataFrame, |
26
|
|
|
date_range, |
27
|
|
|
Int64Index, |
28
|
|
|
) |
29
|
|
|
|
30
|
|
|
from zipline.pipeline import Factor |
31
|
|
|
from zipline.pipeline.expression import ( |
32
|
|
|
NumericalExpression, |
33
|
|
|
NUMEXPR_MATH_FUNCS, |
34
|
|
|
) |
35
|
|
|
|
36
|
|
|
from zipline.utils.numpy_utils import datetime64ns_dtype, float64_dtype |
37
|
|
|
from zipline.utils.test_utils import check_arrays |
38
|
|
|
|
39
|
|
|
|
40
|
|
|
class F(Factor):
    """Input-less float64 factor with a zero-length window; a test operand."""
    window_length = 0
    inputs = ()
    dtype = float64_dtype
44
|
|
|
|
45
|
|
|
|
46
|
|
|
class G(Factor):
    """Input-less float64 factor with a zero-length window; a test operand."""
    window_length = 0
    inputs = ()
    dtype = float64_dtype
50
|
|
|
|
51
|
|
|
|
52
|
|
|
class H(Factor):
    """Input-less float64 factor with a zero-length window; a test operand."""
    window_length = 0
    inputs = ()
    dtype = float64_dtype
56
|
|
|
|
57
|
|
|
|
58
|
|
|
class DateFactor(Factor):
    """
    Input-less datetime64[ns] factor with a zero-length window.

    Serves as a non-float operand for checking that arithmetic on it is
    rejected.
    """
    window_length = 0
    inputs = ()
    dtype = datetime64ns_dtype
62
|
|
|
|
63
|
|
|
|
64
|
|
|
class NumericalExpressionTestCase(TestCase):
    """
    Tests for ``NumericalExpression`` terms.

    The expressions under test are built by combining the test factors
    ``F``, ``G``, ``H`` (float64) and ``DateFactor`` (datetime64[ns]) with
    arithmetic, comparison, boolean and math-function operators.
    """

    def setUp(self):
        self.dates = date_range('2014-01-01', periods=5, freq='D')
        self.assets = Int64Index(range(5))
        self.f = F()
        self.g = G()
        self.h = H()
        self.d = DateFactor()
        # F, G and H declare ``float64_dtype``, so build their fake data with
        # an explicit float64 dtype. Without ``dtype``, ``full`` infers the
        # dtype from the fill value, which yields integer arrays on newer
        # numpy versions.
        self.fake_raw_data = {
            self.f: full((5, 5), 3, dtype=float64),
            self.g: full((5, 5), 2, dtype=float64),
            self.h: full((5, 5), 1, dtype=float64),
            self.d: full((5, 5), 0, dtype='datetime64[ns]'),
        }
        # All-True mask over the 5 dates x 5 assets grid.
        self.mask = DataFrame(True, index=self.dates, columns=self.assets)

    def check_output(self, expr, expected):
        """
        Compute ``expr`` and assert that the result equals ``expected``.

        Each entry of ``expr.inputs`` is looked up in ``self.fake_raw_data``,
        so a test may replace that dict before calling this. The mask's
        index, columns and values are passed through to ``_compute``.
        """
        result = expr._compute(
            [self.fake_raw_data[input_] for input_ in expr.inputs],
            self.mask.index,
            self.mask.columns,
            self.mask.values,
        )
        check_arrays(result, expected)

    def check_constant_output(self, expr, expected):
        """
        Assert that ``expr`` evaluates to a 5x5 array filled with ``expected``.

        ``expected`` is required to be non-NaN. This catches test inputs
        that fall outside a math function's domain.
        """
        self.assertFalse(isnan(expected))
        return self.check_output(expr, full((5, 5), expected))

    def test_validate_good(self):
        f = self.f
        g = self.g

        NumericalExpression("x_0", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 ", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 + x_0", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 + 2", (f,), dtype=float64_dtype)
        NumericalExpression("2 * x_0", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 + x_1", (f, g), dtype=float64_dtype)
        NumericalExpression("x_0 + x_1 + x_0", (f, g), dtype=float64_dtype)
        NumericalExpression("x_0 + 1 + x_1", (f, g), dtype=float64_dtype)

    def test_validate_bad(self):
        f, g, h = self.f, self.g, self.h

        # Too few inputs.
        with self.assertRaises(ValueError):
            NumericalExpression("x_0", (), dtype=float64_dtype)
        with self.assertRaises(ValueError):
            NumericalExpression("x_0 + x_1", (f,), dtype=float64_dtype)

        # Too many inputs.
        with self.assertRaises(ValueError):
            NumericalExpression("x_0", (f, g), dtype=float64_dtype)
        with self.assertRaises(ValueError):
            NumericalExpression("x_0 + x_1", (f, g, h), dtype=float64_dtype)

        # Invalid variable name.
        with self.assertRaises(ValueError):
            NumericalExpression("x_0x_1", (f,), dtype=float64_dtype)
        with self.assertRaises(ValueError):
            NumericalExpression("x_0x_1", (f, g), dtype=float64_dtype)

        # Variable index must start at 0.
        with self.assertRaises(ValueError):
            NumericalExpression("x_1", (f,), dtype=float64_dtype)

        # Scalar operands must be numeric.
        with self.assertRaises(TypeError):
            "2" + f
        with self.assertRaises(TypeError):
            f + "2"
        with self.assertRaises(TypeError):
            f > "2"

        # Boolean binary operators must be between filters.
        with self.assertRaises(TypeError):
            f + (f > 2)
        with self.assertRaises(TypeError):
            (f > f) > f

    def test_combine_datetimes(self):
        with self.assertRaises(TypeError) as e:
            self.d + self.d
        message = e.exception.args[0]
        expected = (
            "Don't know how to compute datetime64[ns] + datetime64[ns].\n"
            "Arithmetic operators are only supported on Factors of dtype "
            "'float64'."
        )
        self.assertEqual(message, expected)

        # Confirm that * shows up in the error instead of +.
        with self.assertRaises(TypeError) as e:
            self.d * self.d
        message = e.exception.args[0]
        expected = (
            "Don't know how to compute datetime64[ns] * datetime64[ns].\n"
            "Arithmetic operators are only supported on Factors of dtype "
            "'float64'."
        )
        self.assertEqual(message, expected)

    def test_combine_datetime_with_float(self):
        # Test with both float-type factors and numeric values. Each
        # ``float_value`` is used as the operand, so the scalar cases are
        # exercised alongside the Factor case.
        for float_value in (self.f, float64(1.0), 1.0):
            for op, sym in ((add, '+'), (mul, '*')):
                # Float operand on the left, datetime factor on the right.
                with self.assertRaises(TypeError) as e:
                    op(float_value, self.d)
                message = e.exception.args[0]
                expected = (
                    "Don't know how to compute float64 {sym} datetime64[ns].\n"
                    "Arithmetic operators are only supported on Factors of "
                    "dtype 'float64'."
                ).format(sym=sym)
                self.assertEqual(message, expected)

                # Datetime factor on the left, float operand on the right.
                with self.assertRaises(TypeError) as e:
                    op(self.d, float_value)
                message = e.exception.args[0]
                expected = (
                    "Don't know how to compute datetime64[ns] {sym} float64.\n"
                    "Arithmetic operators are only supported on Factors of "
                    "dtype 'float64'."
                ).format(sym=sym)
                self.assertEqual(message, expected)

    def test_negate_datetime(self):
        with self.assertRaises(TypeError) as e:
            -self.d

        message = e.exception.args[0]
        expected = (
            "Can't apply unary operator '-' to instance of "
            "'DateFactor' with dtype 'datetime64[ns]'.\n"
            "'-' is only supported for Factors of dtype 'float64'."
        )
        self.assertEqual(message, expected)

    def test_negate(self):
        f, g = self.f, self.g

        self.check_constant_output(-f, -3.0)
        self.check_constant_output(--f, 3.0)
        self.check_constant_output(---f, -3.0)

        self.check_constant_output(-(f + f), -6.0)
        self.check_constant_output(-f + -f, -6.0)
        self.check_constant_output(-(-f + -f), 6.0)

        self.check_constant_output(f + -g, 1.0)
        self.check_constant_output(f - -g, 5.0)

        self.check_constant_output(-(f + g) + (f + g), 0.0)
        self.check_constant_output((f + g) + -(f + g), 0.0)
        self.check_constant_output(-(f + g) + -(f + g), -10.0)

    def test_add(self):
        f, g = self.f, self.g

        self.check_constant_output(f + g, 5.0)

        self.check_constant_output((1 + f) + g, 6.0)
        self.check_constant_output(1 + (f + g), 6.0)
        self.check_constant_output((f + 1) + g, 6.0)
        self.check_constant_output(f + (1 + g), 6.0)
        self.check_constant_output((f + g) + 1, 6.0)
        self.check_constant_output(f + (g + 1), 6.0)

        self.check_constant_output((f + f) + f, 9.0)
        self.check_constant_output(f + (f + f), 9.0)

        self.check_constant_output((f + g) + f, 8.0)
        self.check_constant_output(f + (g + f), 8.0)

        self.check_constant_output((f + g) + (f + g), 10.0)
        self.check_constant_output((f + g) + (g + f), 10.0)
        self.check_constant_output((g + f) + (f + g), 10.0)
        self.check_constant_output((g + f) + (g + f), 10.0)

    def test_subtract(self):
        f, g = self.f, self.g

        self.check_constant_output(f - g, 1.0)  # 3 - 2

        self.check_constant_output((1 - f) - g, -4.)  # (1 - 3) - 2
        self.check_constant_output(1 - (f - g), 0.0)  # 1 - (3 - 2)
        self.check_constant_output((f - 1) - g, 0.0)  # (3 - 1) - 2
        self.check_constant_output(f - (1 - g), 4.0)  # 3 - (1 - 2)
        self.check_constant_output((f - g) - 1, 0.0)  # (3 - 2) - 1
        self.check_constant_output(f - (g - 1), 2.0)  # 3 - (2 - 1)

        self.check_constant_output((f - f) - f, -3.)  # (3 - 3) - 3
        self.check_constant_output(f - (f - f), 3.0)  # 3 - (3 - 3)

        self.check_constant_output((f - g) - f, -2.)  # (3 - 2) - 3
        self.check_constant_output(f - (g - f), 4.0)  # 3 - (2 - 3)

        self.check_constant_output((f - g) - (f - g), 0.0)  # (3 - 2) - (3 - 2)
        self.check_constant_output((f - g) - (g - f), 2.0)  # (3 - 2) - (2 - 3)
        self.check_constant_output((g - f) - (f - g), -2.)  # (2 - 3) - (3 - 2)
        self.check_constant_output((g - f) - (g - f), 0.0)  # (2 - 3) - (2 - 3)

    def test_multiply(self):
        f, g = self.f, self.g

        self.check_constant_output(f * g, 6.0)

        self.check_constant_output((2 * f) * g, 12.0)
        self.check_constant_output(2 * (f * g), 12.0)
        self.check_constant_output((f * 2) * g, 12.0)
        self.check_constant_output(f * (2 * g), 12.0)
        self.check_constant_output((f * g) * 2, 12.0)
        self.check_constant_output(f * (g * 2), 12.0)

        self.check_constant_output((f * f) * f, 27.0)
        self.check_constant_output(f * (f * f), 27.0)

        self.check_constant_output((f * g) * f, 18.0)
        self.check_constant_output(f * (g * f), 18.0)

        self.check_constant_output((f * g) * (f * g), 36.0)
        self.check_constant_output((f * g) * (g * f), 36.0)
        self.check_constant_output((g * f) * (f * g), 36.0)
        self.check_constant_output((g * f) * (g * f), 36.0)

        self.check_constant_output(f * f * f * 0 * f * f, 0.0)

    def test_divide(self):
        f, g = self.f, self.g

        self.check_constant_output(f / g, 3.0 / 2.0)

        self.check_constant_output(
            (2 / f) / g,
            (2 / 3.0) / 2.0
        )
        self.check_constant_output(
            2 / (f / g),
            2 / (3.0 / 2.0),
        )
        self.check_constant_output(
            (f / 2) / g,
            (3.0 / 2) / 2.0,
        )
        self.check_constant_output(
            f / (2 / g),
            3.0 / (2 / 2.0),
        )
        self.check_constant_output(
            (f / g) / 2,
            (3.0 / 2.0) / 2,
        )
        self.check_constant_output(
            f / (g / 2),
            3.0 / (2.0 / 2),
        )
        self.check_constant_output(
            (f / f) / f,
            (3.0 / 3.0) / 3.0
        )
        self.check_constant_output(
            f / (f / f),
            3.0 / (3.0 / 3.0),
        )
        self.check_constant_output(
            (f / g) / f,
            (3.0 / 2.0) / 3.0,
        )
        self.check_constant_output(
            f / (g / f),
            3.0 / (2.0 / 3.0),
        )

        self.check_constant_output(
            (f / g) / (f / g),
            (3.0 / 2.0) / (3.0 / 2.0),
        )
        self.check_constant_output(
            (f / g) / (g / f),
            (3.0 / 2.0) / (2.0 / 3.0),
        )
        self.check_constant_output(
            (g / f) / (f / g),
            (2.0 / 3.0) / (3.0 / 2.0),
        )
        self.check_constant_output(
            (g / f) / (g / f),
            (2.0 / 3.0) / (2.0 / 3.0),
        )

    def test_pow(self):
        f, g = self.f, self.g

        self.check_constant_output(f ** g, 3.0 ** 2)
        self.check_constant_output(2 ** f, 2.0 ** 3)
        self.check_constant_output(f ** 2, 3.0 ** 2)

        self.check_constant_output((f + g) ** 2, (3.0 + 2.0) ** 2)
        self.check_constant_output(2 ** (f + g), 2 ** (3.0 + 2.0))

        self.check_constant_output(f ** (f ** g), 3.0 ** (3.0 ** 2.0))
        self.check_constant_output((f ** f) ** g, (3.0 ** 3.0) ** 2.0)

        self.check_constant_output((f ** g) ** (f ** g), 9.0 ** 9.0)
        self.check_constant_output((f ** g) ** (g ** f), 9.0 ** 8.0)
        self.check_constant_output((g ** f) ** (f ** g), 8.0 ** 9.0)
        self.check_constant_output((g ** f) ** (g ** f), 8.0 ** 8.0)

    def test_mod(self):
        f, g = self.f, self.g

        self.check_constant_output(f % g, 3.0 % 2.0)
        self.check_constant_output(f % 2.0, 3.0 % 2.0)
        self.check_constant_output(g % f, 2.0 % 3.0)

        self.check_constant_output((f + g) % 2, (3.0 + 2.0) % 2)
        self.check_constant_output(2 % (f + g), 2 % (3.0 + 2.0))

        self.check_constant_output(f % (f % g), 3.0 % (3.0 % 2.0))
        self.check_constant_output((f % f) % g, (3.0 % 3.0) % 2.0)

        self.check_constant_output((f + g) % (f * g), 5.0 % 6.0)

    def test_math_functions(self):
        """
        Each function in ``NUMEXPR_MATH_FUNCS`` is exposed as a Factor method.

        Its result must match the numpy function of the same name.
        """
        f, g = self.f, self.g

        fake_raw_data = self.fake_raw_data
        alt_fake_raw_data = {
            self.f: full((5, 5), .5),
            self.g: full((5, 5), -.5),
        }

        for funcname in NUMEXPR_MATH_FUNCS:
            method = methodcaller(funcname)
            func = getattr(numpy, funcname)

            # arcsin and arccos are only defined on [-1, 1], and arctanh only
            # on (-1, 1). The default inputs (3 and 2) fall outside those
            # domains, so use alternate inputs (.5 and -.5) that lie inside
            # them.
            if funcname in ('arcsin', 'arccos', 'arctanh'):
                self.fake_raw_data = alt_fake_raw_data
            else:
                self.fake_raw_data = fake_raw_data

            f_val = self.fake_raw_data[f][0, 0]
            g_val = self.fake_raw_data[g][0, 0]

            self.check_constant_output(method(f), func(f_val))
            self.check_constant_output(method(g), func(g_val))

            self.check_constant_output(method(f) + 1, func(f_val) + 1)
            self.check_constant_output(1 + method(f), 1 + func(f_val))

            self.check_constant_output(method(f + .25), func(f_val + .25))
            self.check_constant_output(method(.25 + f), func(.25 + f_val))

            self.check_constant_output(
                method(f) + method(g),
                func(f_val) + func(g_val),
            )
            self.check_constant_output(
                method(f + g),
                func(f_val + g_val),
            )

    def test_comparisons(self):
        f, g, h = self.f, self.g, self.h
        # f - g is the identity matrix, so f and g differ only on the
        # diagonal.
        self.fake_raw_data = {
            f: arange(25).reshape(5, 5),
            g: arange(25).reshape(5, 5) - eye(5),
            h: full((5, 5), 5),
        }
        f_data = self.fake_raw_data[f]
        g_data = self.fake_raw_data[g]

        # Each case is (lhs expression, rhs expression, lhs data, rhs data).
        # The expected result is ``op`` applied to the raw data.
        cases = [
            # Sanity Check with hand-computed values.
            (f, g, eye(5), zeros((5, 5))),
            (f, 10, f_data, 10),
            (10, f, 10, f_data),
            (f, f, f_data, f_data),
            (f + 1, f, f_data + 1, f_data),
            (1 + f, f, 1 + f_data, f_data),
            (f, g, f_data, g_data),
            (f + 1, g, f_data + 1, g_data),
            (f, g + 1, f_data, g_data + 1),
            (f + 1, g + 1, f_data + 1, g_data + 1),
            ((f + g) / 2, f ** 2, (f_data + g_data) / 2, f_data ** 2),
        ]
        for op in (gt, ge, lt, le, ne):
            for expr_lhs, expr_rhs, expected_lhs, expected_rhs in cases:
                self.check_output(
                    op(expr_lhs, expr_rhs),
                    op(expected_lhs, expected_rhs),
                )

    def test_boolean_binops(self):
        f, g, h = self.f, self.g, self.h
        self.fake_raw_data = {
            f: arange(25).reshape(5, 5),
            g: arange(25).reshape(5, 5) - eye(5),
            h: full((5, 5), 5),
        }

        # Should be True on the diagonal.
        eye_filter = f > g
        # Should be True in the first row only.
        first_row_filter = f < h

        eye_mask = eye(5, dtype=bool)
        first_row_mask = zeros((5, 5), dtype=bool)
        first_row_mask[0] = 1

        self.check_output(eye_filter, eye_mask)
        self.check_output(first_row_filter, first_row_mask)

        for op in (and_, or_):  # NumExpr doesn't support xor.
            self.check_output(
                op(eye_filter, first_row_filter),
                op(eye_mask, first_row_mask),
            )
486
|
|
|
|