Total Complexity | 43 |
Total Lines | 421 |
Duplicated Lines | 0 % |
Complex classes like tests.pipeline.NumericalExpressionTestCase often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
1 | from operator import ( |
||
64 | class NumericalExpressionTestCase(TestCase): |
||
65 | |||
def setUp(self):
    """Build the fixtures shared by the tests in this case.

    Creates one instance each of F, G, H and DateFactor and maps them to
    constant 5x5 arrays filled with 3, 2, 1 and (as datetime64[ns]) 0
    respectively.  Also builds an all-True mask frame indexed by five
    daily dates starting 2014-01-01, with five integer asset columns.
    """
    shape = (5, 5)

    self.dates = date_range('2014-01-01', periods=5, freq='D')
    self.assets = Int64Index(range(5))

    self.f, self.g, self.h = F(), G(), H()
    self.d = DateFactor()

    # Constant raw data keyed by term; tests look inputs up here.
    raw_data = {}
    raw_data[self.f] = full(shape, 3)
    raw_data[self.g] = full(shape, 2)
    raw_data[self.h] = full(shape, 1)
    raw_data[self.d] = full(shape, 0, dtype='datetime64[ns]')
    self.fake_raw_data = raw_data

    self.mask = DataFrame(True, index=self.dates, columns=self.assets)
||
80 | |||
def check_output(self, expr, expected):
    """Compute ``expr`` on the fake raw data and compare with ``expected``.

    The input arrays are looked up in ``self.fake_raw_data`` in the order
    given by ``expr.inputs``.  The mask's index, columns and values are
    passed through to ``expr._compute`` and the comparison itself is
    delegated to ``check_arrays``.
    """
    mask = self.mask
    input_arrays = [self.fake_raw_data[term] for term in expr.inputs]
    actual = expr._compute(
        input_arrays,
        mask.index,
        mask.columns,
        mask.values,
    )
    check_arrays(actual, expected)
||
89 | |||
def check_constant_output(self, expr, expected):
    """Assert that ``expr`` computes a 5x5 array filled with ``expected``.

    ``expected`` must not be NaN; that is asserted before the comparison.
    Returns whatever ``check_output`` returns.
    """
    self.assertFalse(isnan(expected))
    expected_array = full((5, 5), expected)
    return self.check_output(expr, expected_array)
||
93 | |||
def test_validate_good(self):
    # Every (expression, inputs) pair below must construct without raising.
    f, g = self.f, self.g

    valid_cases = [
        ("x_0", (f,)),
        ("x_0 ", (f,)),
        ("x_0 + x_0", (f,)),
        ("x_0 + 2", (f,)),
        ("2 * x_0", (f,)),
        ("x_0 + x_1", (f, g)),
        ("x_0 + x_1 + x_0", (f, g)),
        ("x_0 + 1 + x_1", (f, g)),
    ]
    for expression, binds in valid_cases:
        NumericalExpression(expression, binds, dtype=float64_dtype)
||
106 | |||
def test_validate_bad(self):
    # Malformed expressions must raise ValueError at construction time;
    # ill-typed operator usage on factors must raise TypeError.
    f, g, h = self.f, self.g, self.h

    invalid_expressions = [
        # Too few inputs.
        ("x_0", ()),
        ("x_0 + x_1", (f,)),
        # Too many inputs.
        ("x_0", (f, g)),
        ("x_0 + x_1", (f, g, h)),
        # Invalid variable name.
        ("x_0x_1", (f,)),
        ("x_0x_1", (f, g)),
        # Variable index must start at 0.
        ("x_1", (f,)),
    ]
    for expression, binds in invalid_expressions:
        with self.assertRaises(ValueError):
            NumericalExpression(expression, binds, dtype=float64_dtype)

    # Each operation is wrapped in a thunk so that it is only evaluated
    # inside the assertRaises context.
    ill_typed_operations = [
        # Scalar operands must be numeric.
        lambda: "2" + f,
        lambda: f + "2",
        lambda: f > "2",
        # Boolean binary operators must be between filters.
        lambda: f + (f > 2),
        lambda: (f > f) > f,
    ]
    for operation in ill_typed_operations:
        with self.assertRaises(TypeError):
            operation()
||
145 | |||
def test_combine_datetimes(self):
    # Combining two datetime factors must raise a TypeError whose message
    # names the operator that was actually used.
    template = (
        "Don't know how to compute datetime64[ns] {sym} datetime64[ns].\n"
        "Arithmetic operators are only supported on Factors of dtype "
        "'float64'."
    )

    with self.assertRaises(TypeError) as ctx:
        self.d + self.d
    self.assertEqual(ctx.exception.args[0], template.format(sym='+'))

    # Confirm that * shows up in the error instead of +.
    with self.assertRaises(TypeError) as ctx:
        self.d * self.d
    self.assertEqual(ctx.exception.args[0], template.format(sym='*'))
||
167 | |||
def test_combine_datetime_with_float(self):
    # Test with both float-type factors and numeric values.
    #
    # NOTE(review): the loop variable ``float_value`` is never used; every
    # iteration combines self.f with self.d, so the numeric-scalar cases
    # are not actually exercised.  Confirm whether ``float_value`` was
    # meant to replace ``self.f`` in the operations below (the expected
    # messages may differ for scalars).
    float_first = "Don't know how to compute float64 {sym} datetime64[ns].\n"
    date_first = "Don't know how to compute datetime64[ns] {sym} float64.\n"
    suffix = (
        "Arithmetic operators are only supported on Factors of "
        "dtype 'float64'."
    )

    for float_value in (self.f, float64(1.0), 1.0):
        for op, sym in ((add, '+'), (mul, '*')):
            # Float operand on the left.
            with self.assertRaises(TypeError) as ctx:
                op(self.f, self.d)
            self.assertEqual(
                ctx.exception.args[0],
                float_first.format(sym=sym) + suffix,
            )

            # Datetime operand on the left.
            with self.assertRaises(TypeError) as ctx:
                op(self.d, self.f)
            self.assertEqual(
                ctx.exception.args[0],
                date_first.format(sym=sym) + suffix,
            )
||
191 | |||
def test_negate_datetime(self):
    # Unary minus on the datetime factor must raise a TypeError carrying
    # exactly this message.
    expected_message = (
        "Can't apply unary operator '-' to instance of "
        "'DateFactor' with dtype 'datetime64[ns]'.\n"
        "'-' is only supported for Factors of dtype 'float64'."
    )

    with self.assertRaises(TypeError) as ctx:
        -self.d

    self.assertEqual(ctx.exception.args[0], expected_message)
||
203 | |||
def test_negate(self):
    # f is filled with 3 and g with 2 (see setUp).  Checks unary minus
    # alone, nested, and mixed with addition and subtraction.
    check = self.check_constant_output
    f, g = self.f, self.g

    # Repeated negation of a single factor.
    check(-f, -3.0)
    check(--f, 3.0)
    check(---f, -3.0)

    # Negation inside and outside a sum of the same factor.
    check(-(f + f), -6.0)
    check(-f + -f, -6.0)
    check(-(-f + -f), 6.0)

    # Negated right-hand operand.
    check(f + -g, 1.0)
    check(f - -g, 5.0)

    # Negated compound expressions.
    check(-(f + g) + (f + g), 0.0)
    check((f + g) + -(f + g), 0.0)
    check(-(f + g) + -(f + g), -10.0)
||
221 | |||
def test_add(self):
    # f is filled with 3 and g with 2 (see setUp).  Checks addition of
    # factors and scalars under every grouping and operand order.
    check = self.check_constant_output
    f, g = self.f, self.g

    check(f + g, 5.0)

    # A scalar in each possible position.
    check((1 + f) + g, 6.0)
    check(1 + (f + g), 6.0)
    check((f + 1) + g, 6.0)
    check(f + (1 + g), 6.0)
    check((f + g) + 1, 6.0)
    check(f + (g + 1), 6.0)

    # The same factor repeated.
    check((f + f) + f, 9.0)
    check(f + (f + f), 9.0)

    # A factor appearing on both sides of a nested sum.
    check((f + g) + f, 8.0)
    check(f + (g + f), 8.0)

    # Two compound expressions combined.
    check((f + g) + (f + g), 10.0)
    check((f + g) + (g + f), 10.0)
    check((g + f) + (f + g), 10.0)
    check((g + f) + (g + f), 10.0)
||
244 | |||
def test_subtract(self):
    # f is filled with 3 and g with 2 (see setUp).  The trailing comment
    # on each check shows the equivalent scalar arithmetic.
    check = self.check_constant_output
    f, g = self.f, self.g

    check(f - g, 1.0)  # 3 - 2

    # A scalar in each possible position.
    check((1 - f) - g, -4.0)  # (1 - 3) - 2
    check(1 - (f - g), 0.0)  # 1 - (3 - 2)
    check((f - 1) - g, 0.0)  # (3 - 1) - 2
    check(f - (1 - g), 4.0)  # 3 - (1 - 2)
    check((f - g) - 1, 0.0)  # (3 - 2) - 1
    check(f - (g - 1), 2.0)  # 3 - (2 - 1)

    # The same factor repeated.
    check((f - f) - f, -3.0)  # (3 - 3) - 3
    check(f - (f - f), 3.0)  # 3 - (3 - 3)

    # A factor appearing on both sides of a nested difference.
    check((f - g) - f, -2.0)  # (3 - 2) - 3
    check(f - (g - f), 4.0)  # 3 - (2 - 3)

    # Two compound expressions combined.
    check((f - g) - (f - g), 0.0)  # (3 - 2) - (3 - 2)
    check((f - g) - (g - f), 2.0)  # (3 - 2) - (2 - 3)
    check((g - f) - (f - g), -2.0)  # (2 - 3) - (3 - 2)
    check((g - f) - (g - f), 0.0)  # (2 - 3) - (2 - 3)
||
267 | |||
def test_multiply(self):
    # f is filled with 3 and g with 2 (see setUp).  Checks multiplication
    # of factors and scalars under every grouping and operand order.
    check = self.check_constant_output
    f, g = self.f, self.g

    check(f * g, 6.0)

    # A scalar in each possible position.
    check((2 * f) * g, 12.0)
    check(2 * (f * g), 12.0)
    check((f * 2) * g, 12.0)
    check(f * (2 * g), 12.0)
    check((f * g) * 2, 12.0)
    check(f * (g * 2), 12.0)

    # The same factor repeated.
    check((f * f) * f, 27.0)
    check(f * (f * f), 27.0)

    # A factor appearing on both sides of a nested product.
    check((f * g) * f, 18.0)
    check(f * (g * f), 18.0)

    # Two compound expressions combined.
    check((f * g) * (f * g), 36.0)
    check((f * g) * (g * f), 36.0)
    check((g * f) * (f * g), 36.0)
    check((g * f) * (g * f), 36.0)

    # A zero anywhere in a chain of products yields zero.
    check(f * f * f * 0 * f * f, 0.0)
||
292 | |||
def test_divide(self):
    # f is filled with 3 and g with 2 (see setUp).  Expected values are
    # written as the same arithmetic on plain numbers so that both sides
    # go through the same sequence of divisions.
    check = self.check_constant_output
    f, g = self.f, self.g

    check(f / g, 3.0 / 2.0)

    # A scalar in each possible position.
    check((2 / f) / g, (2 / 3.0) / 2.0)
    check(2 / (f / g), 2 / (3.0 / 2.0))
    check((f / 2) / g, (3.0 / 2) / 2.0)
    check(f / (2 / g), 3.0 / (2 / 2.0))
    check((f / g) / 2, (3.0 / 2.0) / 2)
    check(f / (g / 2), 3.0 / (2.0 / 2))

    # The same factor repeated, and a factor on both sides of a nested
    # quotient.
    check((f / f) / f, (3.0 / 3.0) / 3.0)
    check(f / (f / f), 3.0 / (3.0 / 3.0))
    check((f / g) / f, (3.0 / 2.0) / 3.0)
    check(f / (g / f), 3.0 / (2.0 / 3.0))

    # Two compound expressions combined.
    check((f / g) / (f / g), (3.0 / 2.0) / (3.0 / 2.0))
    check((f / g) / (g / f), (3.0 / 2.0) / (2.0 / 3.0))
    check((g / f) / (f / g), (2.0 / 3.0) / (3.0 / 2.0))
    check((g / f) / (g / f), (2.0 / 3.0) / (2.0 / 3.0))
||
355 | |||
def test_pow(self):
    # f is filled with 3 and g with 2 (see setUp).  Checks exponentiation
    # between factors, scalars and compound expressions.
    check = self.check_constant_output
    f, g = self.f, self.g

    # Factor and scalar as base or exponent.
    check(f ** g, 3.0 ** 2)
    check(2 ** f, 2.0 ** 3)
    check(f ** 2, 3.0 ** 2)

    # A compound expression as base or exponent.
    check((f + g) ** 2, (3.0 + 2.0) ** 2)
    check(2 ** (f + g), 2 ** (3.0 + 2.0))

    # Explicit grouping in both directions.
    check(f ** (f ** g), 3.0 ** (3.0 ** 2.0))
    check((f ** f) ** g, (3.0 ** 3.0) ** 2.0)

    # Two power expressions combined.
    check((f ** g) ** (f ** g), 9.0 ** 9.0)
    check((f ** g) ** (g ** f), 9.0 ** 8.0)
    check((g ** f) ** (f ** g), 8.0 ** 9.0)
    check((g ** f) ** (g ** f), 8.0 ** 8.0)
||
373 | |||
def test_mod(self):
    # f is filled with 3 and g with 2 (see setUp).  Checks the modulo
    # operator between factors, scalars and compound expressions.
    check = self.check_constant_output
    f, g = self.f, self.g

    # Factor and scalar operands.
    check(f % g, 3.0 % 2.0)
    check(f % 2.0, 3.0 % 2.0)
    check(g % f, 2.0 % 3.0)

    # A compound expression on either side.
    check((f + g) % 2, (3.0 + 2.0) % 2)
    check(2 % (f + g), 2 % (3.0 + 2.0))

    # Explicit grouping in both directions.
    check(f % (f % g), 3.0 % (3.0 % 2.0))
    check((f % f) % g, (3.0 % 3.0) % 2.0)

    # Compound expressions on both sides.
    check((f + g) % (f * g), 5.0 % 6.0)
||
388 | |||
def test_math_functions(self):
    # For every name in NUMEXPR_MATH_FUNCS, calling the method of that name
    # on a factor must agree with the numpy function of the same name,
    # both on its own and combined with arithmetic.
    f, g = self.f, self.g
    check = self.check_constant_output

    default_raw_data = self.fake_raw_data
    restricted_raw_data = {
        self.f: full((5, 5), .5),
        self.g: full((5, 5), -.5),
    }
    restricted_domain_funcs = ('arcsin', 'arccos', 'arctanh')

    for funcname in NUMEXPR_MATH_FUNCS:
        apply_method = methodcaller(funcname)
        numpy_func = getattr(numpy, funcname)

        # arcsin and arccos are only defined on [-1, 1] and arctanh on
        # (-1, 1), so those functions get alternate inputs (0.5 and -0.5)
        # that lie inside the domain.
        if funcname in restricted_domain_funcs:
            self.fake_raw_data = restricted_raw_data
        else:
            self.fake_raw_data = default_raw_data

        # Every array is constant, so one element gives the scalar value.
        f_val = self.fake_raw_data[f][0, 0]
        g_val = self.fake_raw_data[g][0, 0]

        # Function applied to a bare factor.
        check(apply_method(f), numpy_func(f_val))
        check(apply_method(g), numpy_func(g_val))

        # Function result combined with a scalar.
        check(apply_method(f) + 1, numpy_func(f_val) + 1)
        check(1 + apply_method(f), 1 + numpy_func(f_val))

        # Function applied to a factor-plus-scalar expression.
        check(apply_method(f + .25), numpy_func(f_val + .25))
        check(apply_method(.25 + f), numpy_func(.25 + f_val))

        # Sum of two function results, and function of a sum.
        check(
            apply_method(f) + apply_method(g),
            numpy_func(f_val) + numpy_func(g_val),
        )
        check(apply_method(f + g), numpy_func(f_val + g_val))
||
429 | |||
def test_comparisons(self):
    # Comparison operators applied to factor expressions must give the
    # same result as the same operator applied to the underlying arrays.
    f, g, h = self.f, self.g, self.h

    # g equals f everywhere except on the diagonal, where it is 1 lower.
    f_data = arange(25).reshape(5, 5)
    g_data = arange(25).reshape(5, 5) - eye(5)
    self.fake_raw_data = {
        f: f_data,
        g: g_data,
        h: full((5, 5), 5),
    }

    # Each case: (lhs expression, rhs expression, lhs data, rhs data).
    cases = [
        # Sanity Check with hand-computed values.
        (f, g, eye(5), zeros((5, 5))),
        (f, 10, f_data, 10),
        (10, f, 10, f_data),
        (f, f, f_data, f_data),
        (f + 1, f, f_data + 1, f_data),
        (1 + f, f, 1 + f_data, f_data),
        (f, g, f_data, g_data),
        (f + 1, g, f_data + 1, g_data),
        (f, g + 1, f_data, g_data + 1),
        (f + 1, g + 1, f_data + 1, g_data + 1),
        ((f + g) / 2, f ** 2, (f_data + g_data) / 2, f_data ** 2),
    ]

    for compare in (gt, ge, lt, le, ne):
        for lhs_expr, rhs_expr, lhs_data, rhs_data in cases:
            comparison = compare(lhs_expr, rhs_expr)
            expected = compare(lhs_data, rhs_data)
            self.check_output(comparison, expected)
||
460 | |||
def test_boolean_binops(self):
    # ``&`` and ``|`` between two filters must match the same operators
    # applied to the corresponding boolean masks.
    f, g, h = self.f, self.g, self.h

    # g equals f everywhere except on the diagonal, where it is 1 lower;
    # h is 5 everywhere, so only the first row of f (0..4) is below it.
    raw_data = {}
    raw_data[f] = arange(25).reshape(5, 5)
    raw_data[g] = arange(25).reshape(5, 5) - eye(5)
    raw_data[h] = full((5, 5), 5)
    self.fake_raw_data = raw_data

    # Should be True on the diagonal.
    diagonal_filter = f > g
    diagonal_mask = eye(5, dtype=bool)

    # Should be True in the first row only.
    top_row_filter = f < h
    top_row_mask = zeros((5, 5), dtype=bool)
    top_row_mask[0] = 1

    self.check_output(diagonal_filter, diagonal_mask)
    self.check_output(top_row_filter, top_row_mask)

    for combine in (and_, or_):  # NumExpr doesn't support xor.
        combined_filter = combine(diagonal_filter, top_row_filter)
        combined_mask = combine(diagonal_mask, top_row_mask)
        self.check_output(combined_filter, combined_mask)
||
486 |