| Total Complexity | 43 |
| Total Lines | 421 |
| Duplicated Lines | 0 % |
Complex classes like tests.pipeline.NumericalExpressionTestCase often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | from operator import ( |
||
class NumericalExpressionTestCase(TestCase):
    """
    Tests for NumericalExpression validation and for the expressions built by
    applying arithmetic, comparison and boolean operators to pipeline terms.

    Expressions are evaluated against ``self.fake_raw_data``, a dict mapping
    each input term to a 5x5 array, so expected results can be written down
    directly in each test.
    """

    def setUp(self):
        """
        Build the shared fixture: four input terms, constant 5x5 raw data for
        each of them (f=3, g=2, h=1, d=0 as datetime64[ns]) and an all-True
        5x5 mask indexed by five daily dates and five integer assets.
        """
        self.dates = date_range('2014-01-01', periods=5, freq='D')
        self.assets = Int64Index(range(5))
        self.f = F()
        self.g = G()
        self.h = H()
        self.d = DateFactor()
        self.fake_raw_data = {
            self.f: full((5, 5), 3),
            self.g: full((5, 5), 2),
            self.h: full((5, 5), 1),
            self.d: full((5, 5), 0, dtype='datetime64[ns]'),
        }
        self.mask = DataFrame(True, index=self.dates, columns=self.assets)

    def check_output(self, expr, expected):
        """
        Compute ``expr`` from the fake raw data and compare it to ``expected``.

        The arrays passed to ``expr._compute`` are looked up in
        ``self.fake_raw_data`` in the order given by ``expr.inputs``; the
        comparison itself is delegated to ``check_arrays``.
        """
        result = expr._compute(
            [self.fake_raw_data[input_] for input_ in expr.inputs],
            self.mask.index,
            self.mask.columns,
            self.mask.values,
        )
        check_arrays(result, expected)

    def check_constant_output(self, expr, expected):
        """
        Check that ``expr`` evaluates to ``expected`` in every cell.

        ``expected`` must not be NaN: this guards against a test passing only
        because both sides are NaN.
        """
        self.assertFalse(isnan(expected))
        # Use the mask's shape rather than a literal so this helper stays in
        # step with the fixture built in setUp.
        return self.check_output(expr, full(self.mask.shape, expected))

    def _check_bad_arithmetic(self, op, sym, lhs, rhs, lhs_dtype, rhs_dtype):
        """
        Assert that ``op(lhs, rhs)`` raises TypeError with the arithmetic
        dtype message, where ``sym`` is the operator symbol and the dtype
        arguments are the names expected to appear in the message.
        """
        with self.assertRaises(TypeError) as e:
            op(lhs, rhs)
        message = e.exception.args[0]
        expected = (
            "Don't know how to compute {lhs} {sym} {rhs}.\n"
            "Arithmetic operators are only supported on Factors of dtype "
            "'float64'."
        ).format(lhs=lhs_dtype, sym=sym, rhs=rhs_dtype)
        self.assertEqual(message, expected)

    def test_validate_good(self):
        """Well-formed expression strings construct without raising."""
        f = self.f
        g = self.g

        NumericalExpression("x_0", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 ", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 + x_0", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 + 2", (f,), dtype=float64_dtype)
        NumericalExpression("2 * x_0", (f,), dtype=float64_dtype)
        NumericalExpression("x_0 + x_1", (f, g), dtype=float64_dtype)
        NumericalExpression("x_0 + x_1 + x_0", (f, g), dtype=float64_dtype)
        NumericalExpression("x_0 + 1 + x_1", (f, g), dtype=float64_dtype)

    def test_validate_bad(self):
        """Malformed expressions and invalid operands raise."""
        f, g, h = self.f, self.g, self.h

        # Too few inputs.
        with self.assertRaises(ValueError):
            NumericalExpression("x_0", (), dtype=float64_dtype)
        with self.assertRaises(ValueError):
            NumericalExpression("x_0 + x_1", (f,), dtype=float64_dtype)

        # Too many inputs.
        with self.assertRaises(ValueError):
            NumericalExpression("x_0", (f, g), dtype=float64_dtype)
        with self.assertRaises(ValueError):
            NumericalExpression("x_0 + x_1", (f, g, h), dtype=float64_dtype)

        # Invalid variable name.
        with self.assertRaises(ValueError):
            NumericalExpression("x_0x_1", (f,), dtype=float64_dtype)
        with self.assertRaises(ValueError):
            NumericalExpression("x_0x_1", (f, g), dtype=float64_dtype)

        # Variable index must start at 0.
        with self.assertRaises(ValueError):
            NumericalExpression("x_1", (f,), dtype=float64_dtype)

        # The bare expressions below are intentional: merely evaluating the
        # operator is what is expected to raise.

        # Scalar operands must be numeric.
        with self.assertRaises(TypeError):
            "2" + f
        with self.assertRaises(TypeError):
            f + "2"
        with self.assertRaises(TypeError):
            f > "2"

        # Boolean binary operators must be between filters.
        with self.assertRaises(TypeError):
            f + (f > 2)
        with self.assertRaises(TypeError):
            (f > f) > f

    def test_combine_datetimes(self):
        """Arithmetic between two datetime factors is a TypeError."""
        dtype_name = 'datetime64[ns]'
        self._check_bad_arithmetic(
            add, '+', self.d, self.d, dtype_name, dtype_name,
        )
        # Confirm that * shows up in the error instead of +.
        self._check_bad_arithmetic(
            mul, '*', self.d, self.d, dtype_name, dtype_name,
        )

    def test_combine_datetime_with_float(self):
        """Mixing a float64 factor with a datetime factor is a TypeError."""
        # TODO(review): this test used to loop over
        # ``(self.f, float64(1.0), 1.0)`` without ever using the loop
        # variable, so it only repeated the factor/factor assertions three
        # times. Scalar operands are still not covered here; add cases for
        # them once the expected error messages for scalars are confirmed.
        for op, sym in ((add, '+'), (mul, '*')):
            self._check_bad_arithmetic(
                op, sym, self.f, self.d, 'float64', 'datetime64[ns]',
            )
            self._check_bad_arithmetic(
                op, sym, self.d, self.f, 'datetime64[ns]', 'float64',
            )

    def test_negate_datetime(self):
        """Unary minus on a datetime factor is a TypeError."""
        with self.assertRaises(TypeError) as e:
            -self.d

        message = e.exception.args[0]
        expected = (
            "Can't apply unary operator '-' to instance of "
            "'DateFactor' with dtype 'datetime64[ns]'.\n"
            "'-' is only supported for Factors of dtype 'float64'."
        )
        self.assertEqual(message, expected)

    def test_negate(self):
        """Unary minus, alone and nested inside sums (f=3, g=2)."""
        f, g = self.f, self.g

        self.check_constant_output(-f, -3.0)
        self.check_constant_output(--f, 3.0)
        self.check_constant_output(---f, -3.0)

        self.check_constant_output(-(f + f), -6.0)
        self.check_constant_output(-f + -f, -6.0)
        self.check_constant_output(-(-f + -f), 6.0)

        self.check_constant_output(f + -g, 1.0)
        self.check_constant_output(f - -g, 5.0)

        self.check_constant_output(-(f + g) + (f + g), 0.0)
        self.check_constant_output((f + g) + -(f + g), 0.0)
        self.check_constant_output(-(f + g) + -(f + g), -10.0)

    def test_add(self):
        """Addition across every factor/scalar/expression grouping."""
        f, g = self.f, self.g

        self.check_constant_output(f + g, 5.0)

        self.check_constant_output((1 + f) + g, 6.0)
        self.check_constant_output(1 + (f + g), 6.0)
        self.check_constant_output((f + 1) + g, 6.0)
        self.check_constant_output(f + (1 + g), 6.0)
        self.check_constant_output((f + g) + 1, 6.0)
        self.check_constant_output(f + (g + 1), 6.0)

        self.check_constant_output((f + f) + f, 9.0)
        self.check_constant_output(f + (f + f), 9.0)

        self.check_constant_output((f + g) + f, 8.0)
        self.check_constant_output(f + (g + f), 8.0)

        self.check_constant_output((f + g) + (f + g), 10.0)
        self.check_constant_output((f + g) + (g + f), 10.0)
        self.check_constant_output((g + f) + (f + g), 10.0)
        self.check_constant_output((g + f) + (g + f), 10.0)

    def test_subtract(self):
        """Subtraction across every grouping; order of operands matters."""
        f, g = self.f, self.g

        self.check_constant_output(f - g, 1.0)  # 3 - 2

        self.check_constant_output((1 - f) - g, -4.)  # (1 - 3) - 2
        self.check_constant_output(1 - (f - g), 0.0)  # 1 - (3 - 2)
        self.check_constant_output((f - 1) - g, 0.0)  # (3 - 1) - 2
        self.check_constant_output(f - (1 - g), 4.0)  # 3 - (1 - 2)
        self.check_constant_output((f - g) - 1, 0.0)  # (3 - 2) - 1
        self.check_constant_output(f - (g - 1), 2.0)  # 3 - (2 - 1)

        self.check_constant_output((f - f) - f, -3.)  # (3 - 3) - 3
        self.check_constant_output(f - (f - f), 3.0)  # 3 - (3 - 3)

        self.check_constant_output((f - g) - f, -2.)  # (3 - 2) - 3
        self.check_constant_output(f - (g - f), 4.0)  # 3 - (2 - 3)

        self.check_constant_output((f - g) - (f - g), 0.0)  # (3 - 2) - (3 - 2)
        self.check_constant_output((f - g) - (g - f), 2.0)  # (3 - 2) - (2 - 3)
        self.check_constant_output((g - f) - (f - g), -2.)  # (2 - 3) - (3 - 2)
        self.check_constant_output((g - f) - (g - f), 0.0)  # (2 - 3) - (2 - 3)

    def test_multiply(self):
        """Multiplication across every grouping, including a zero factor."""
        f, g = self.f, self.g

        self.check_constant_output(f * g, 6.0)

        self.check_constant_output((2 * f) * g, 12.0)
        self.check_constant_output(2 * (f * g), 12.0)
        self.check_constant_output((f * 2) * g, 12.0)
        self.check_constant_output(f * (2 * g), 12.0)
        self.check_constant_output((f * g) * 2, 12.0)
        self.check_constant_output(f * (g * 2), 12.0)

        self.check_constant_output((f * f) * f, 27.0)
        self.check_constant_output(f * (f * f), 27.0)

        self.check_constant_output((f * g) * f, 18.0)
        self.check_constant_output(f * (g * f), 18.0)

        self.check_constant_output((f * g) * (f * g), 36.0)
        self.check_constant_output((f * g) * (g * f), 36.0)
        self.check_constant_output((g * f) * (f * g), 36.0)
        self.check_constant_output((g * f) * (g * f), 36.0)

        self.check_constant_output(f * f * f * 0 * f * f, 0.0)

    def test_divide(self):
        """Division across every grouping; expected values mirror the
        expression with f replaced by 3.0 and g by 2.0."""
        f, g = self.f, self.g

        self.check_constant_output(f / g, 3.0 / 2.0)

        self.check_constant_output(
            (2 / f) / g,
            (2 / 3.0) / 2.0
        )
        self.check_constant_output(
            2 / (f / g),
            2 / (3.0 / 2.0),
        )
        self.check_constant_output(
            (f / 2) / g,
            (3.0 / 2) / 2.0,
        )
        self.check_constant_output(
            f / (2 / g),
            3.0 / (2 / 2.0),
        )
        self.check_constant_output(
            (f / g) / 2,
            (3.0 / 2.0) / 2,
        )
        self.check_constant_output(
            f / (g / 2),
            3.0 / (2.0 / 2),
        )
        self.check_constant_output(
            (f / f) / f,
            (3.0 / 3.0) / 3.0
        )
        self.check_constant_output(
            f / (f / f),
            3.0 / (3.0 / 3.0),
        )
        self.check_constant_output(
            (f / g) / f,
            (3.0 / 2.0) / 3.0,
        )
        self.check_constant_output(
            f / (g / f),
            3.0 / (2.0 / 3.0),
        )

        self.check_constant_output(
            (f / g) / (f / g),
            (3.0 / 2.0) / (3.0 / 2.0),
        )
        self.check_constant_output(
            (f / g) / (g / f),
            (3.0 / 2.0) / (2.0 / 3.0),
        )
        self.check_constant_output(
            (g / f) / (f / g),
            (2.0 / 3.0) / (3.0 / 2.0),
        )
        self.check_constant_output(
            (g / f) / (g / f),
            (2.0 / 3.0) / (2.0 / 3.0),
        )

    def test_pow(self):
        """Exponentiation with factors, scalars and nested expressions."""
        f, g = self.f, self.g

        self.check_constant_output(f ** g, 3.0 ** 2)
        self.check_constant_output(2 ** f, 2.0 ** 3)
        self.check_constant_output(f ** 2, 3.0 ** 2)

        self.check_constant_output((f + g) ** 2, (3.0 + 2.0) ** 2)
        self.check_constant_output(2 ** (f + g), 2 ** (3.0 + 2.0))

        self.check_constant_output(f ** (f ** g), 3.0 ** (3.0 ** 2.0))
        self.check_constant_output((f ** f) ** g, (3.0 ** 3.0) ** 2.0)

        self.check_constant_output((f ** g) ** (f ** g), 9.0 ** 9.0)
        self.check_constant_output((f ** g) ** (g ** f), 9.0 ** 8.0)
        self.check_constant_output((g ** f) ** (f ** g), 8.0 ** 9.0)
        self.check_constant_output((g ** f) ** (g ** f), 8.0 ** 8.0)

    def test_mod(self):
        """Modulo with factors, scalars and nested expressions."""
        f, g = self.f, self.g

        self.check_constant_output(f % g, 3.0 % 2.0)
        self.check_constant_output(f % 2.0, 3.0 % 2.0)
        self.check_constant_output(g % f, 2.0 % 3.0)

        self.check_constant_output((f + g) % 2, (3.0 + 2.0) % 2)
        self.check_constant_output(2 % (f + g), 2 % (3.0 + 2.0))

        self.check_constant_output(f % (f % g), 3.0 % (3.0 % 2.0))
        self.check_constant_output((f % f) % g, (3.0 % 3.0) % 2.0)

        self.check_constant_output((f + g) % (f * g), 5.0 % 6.0)

    def test_math_functions(self):
        """
        Each name in NUMEXPR_MATH_FUNCS, called as a method on a term, must
        agree with the numpy function of the same name.
        """
        f, g = self.f, self.g

        fake_raw_data = self.fake_raw_data
        alt_fake_raw_data = {
            self.f: full((5, 5), .5),
            self.g: full((5, 5), -.5),
        }

        for funcname in NUMEXPR_MATH_FUNCS:
            method = methodcaller(funcname)
            func = getattr(numpy, funcname)

            # arcsin and arccos are only defined on [-1, 1] and arctanh on
            # (-1, 1), so swap in inputs (0.5 and -0.5) that stay inside
            # those domains for every variation checked below.
            if funcname in ('arcsin', 'arccos', 'arctanh'):
                self.fake_raw_data = alt_fake_raw_data
            else:
                self.fake_raw_data = fake_raw_data

            # The raw data is constant, so any cell gives the input value.
            f_val = self.fake_raw_data[f][0, 0]
            g_val = self.fake_raw_data[g][0, 0]

            self.check_constant_output(method(f), func(f_val))
            self.check_constant_output(method(g), func(g_val))

            self.check_constant_output(method(f) + 1, func(f_val) + 1)
            self.check_constant_output(1 + method(f), 1 + func(f_val))

            self.check_constant_output(method(f + .25), func(f_val + .25))
            self.check_constant_output(method(.25 + f), func(.25 + f_val))

            self.check_constant_output(
                method(f) + method(g),
                func(f_val) + func(g_val),
            )
            self.check_constant_output(
                method(f + g),
                func(f_val + g_val),
            )

    def test_comparisons(self):
        """
        Comparison operators on terms must match the same operator applied
        to the underlying arrays.
        """
        f, g, h = self.f, self.g, self.h
        # f counts 0..24 row by row; g equals f except it is one lower on
        # the diagonal.
        self.fake_raw_data = {
            f: arange(25).reshape(5, 5),
            g: arange(25).reshape(5, 5) - eye(5),
            h: full((5, 5), 5),
        }
        f_data = self.fake_raw_data[f]
        g_data = self.fake_raw_data[g]

        # Each case is (lhs expression, rhs expression, lhs data, rhs data).
        cases = [
            # Sanity Check with hand-computed values.
            (f, g, eye(5), zeros((5, 5))),
            (f, 10, f_data, 10),
            (10, f, 10, f_data),
            (f, f, f_data, f_data),
            (f + 1, f, f_data + 1, f_data),
            (1 + f, f, 1 + f_data, f_data),
            (f, g, f_data, g_data),
            (f + 1, g, f_data + 1, g_data),
            (f, g + 1, f_data, g_data + 1),
            (f + 1, g + 1, f_data + 1, g_data + 1),
            ((f + g) / 2, f ** 2, (f_data + g_data) / 2, f_data ** 2),
        ]
        # NOTE(review): ``eq`` is not exercised here - presumably because
        # ``==`` is not overloaded to build expressions; confirm.
        for op in (gt, ge, lt, le, ne):
            for expr_lhs, expr_rhs, expected_lhs, expected_rhs in cases:
                self.check_output(
                    op(expr_lhs, expr_rhs),
                    op(expected_lhs, expected_rhs),
                )

    def test_boolean_binops(self):
        """``&`` and ``|`` between comparison results match numpy's."""
        f, g, h = self.f, self.g, self.h
        self.fake_raw_data = {
            f: arange(25).reshape(5, 5),
            g: arange(25).reshape(5, 5) - eye(5),
            h: full((5, 5), 5),
        }

        # Should be True on the diagonal.
        eye_filter = f > g
        # Should be True in the first row only.
        first_row_filter = f < h

        eye_mask = eye(5, dtype=bool)
        first_row_mask = zeros((5, 5), dtype=bool)
        first_row_mask[0] = 1

        self.check_output(eye_filter, eye_mask)
        self.check_output(first_row_filter, first_row_mask)

        for op in (and_, or_):  # NumExpr doesn't support xor.
            self.check_output(
                op(eye_filter, first_row_filter),
                op(eye_mask, first_row_mask),
            )
||
| 486 |