| Total Complexity | 48 |
| Total Lines | 249 |
| Duplicated Lines | 0 % |
Complex classes like AnniesLasso.tests.TestBaseCannonModel often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | #!/usr/bin/env python |
||
| 45 | class TestBaseCannonModel(unittest.TestCase): |
||
| 46 | |||
def setUp(self):
    """Create random faux training labels, fluxes, and flux uncertainties."""
    label_names = "ABCDE"
    num_labels = len(label_names)
    num_stars = np.random.randint(10, 500)
    num_pixels = np.random.randint(1, 10000)

    # Structured record array: one little-endian float64 field per label.
    self.valid_training_labels = np.rec.array(
        np.random.uniform(low=0.5, high=1.5,
                          size=(num_stars, num_labels)),
        dtype=[(name, '<f8') for name in label_names])

    data_shape = (num_stars, num_pixels)
    self.valid_fluxes = np.random.uniform(
        low=0.5, high=1.5, size=data_shape)
    self.valid_flux_uncertainties = np.random.uniform(
        low=0.5, high=1.5, size=data_shape)
||
| 62 | |||
def runTest(self):
    """No-op so this TestCase can be invoked directly (see test_repr)."""
    # The original body was a bare `None` expression statement;
    # `pass` is the idiomatic no-op.
    pass
||
| 65 | |||
def get_model(self, **kwargs):
    """Construct a BaseCannonModel from the fixture data in setUp."""
    fixture_data = (self.valid_training_labels, self.valid_fluxes,
                    self.valid_flux_uncertainties)
    return model.BaseCannonModel(*fixture_data, **kwargs)
||
| 70 | |||
def test_repr(self):
    """Smoke-test the model's string representations."""
    self.runTest()  # Just for that 100%, baby.
    m = self.get_model()
    # Idiomatic: use str()/repr() rather than calling the dunder
    # methods directly; behavior is identical.
    print("{0} {1}".format(str(m), repr(m)))
||
| 75 | |||
def test_get_dispersion(self):
    """By default the dispersion is a simple pixel-index array."""
    mdl = self.get_model()
    expected = np.arange(self.valid_fluxes.shape[1])
    self.assertSequenceEqual(tuple(mdl.dispersion), tuple(expected))
||
| 81 | |||
def test_set_dispersion(self):
    """Validate the dispersion setter: type, length, ordering,
    finiteness, and float-castability checks."""
    m = self.get_model()
    N_pixels = self.valid_fluxes.shape[1]

    # Incorrect data type (not an iterable).
    for item in (None, False, True):
        with self.assertRaises(TypeError):
            m.dispersion = item

    # These are iterable but have the wrong lengths.
    for item in ("", {}, [], (), set()):
        with self.assertRaises(ValueError):
            m.dispersion = item

    # Unsorted dispersion values are rejected.
    with self.assertRaises(ValueError):
        m.dispersion = [3, 4, 2, 1]

    # These should work: any finite offset of a sorted pixel array.
    for offset in (10, -100, 520938.4):
        m.dispersion = offset + np.arange(N_pixels)

    # Disallow non-finite numbers (one loop instead of three
    # copy-pasted blocks; same assertions).
    for bad_value in (np.nan, np.inf, -np.inf):
        with self.assertRaises(ValueError):
            d = np.arange(N_pixels, dtype=float)
            d[0] = bad_value
            m.dispersion = d

    # Disallow non-float like things.
    for filler in ("", None):
        with self.assertRaises(ValueError):
            m.dispersion = np.array([filler] * N_pixels)
||
| 126 | |||
def test_get_training_data(self):
    """Training labels, fluxes and uncertainties are all exposed."""
    mdl = self.get_model()
    for attribute in ("training_labels", "training_fluxes",
                      "training_flux_uncertainties"):
        self.assertIsNotNone(getattr(mdl, attribute))
||
| 132 | |||
def test_invalid_label_names(self):
    """Forbidden characters in label names raise ValueError unless the
    model is built with ``live_dangerously=True``."""
    m = self.get_model()
    for character in m._forbidden_label_characters:

        # Replace the first available label with a name containing
        # the forbidden character.
        invalid_labels = [] + list(m.labels_available)
        invalid_labels[0] = "HELLO_{}".format(character)

        N_stars = len(self.valid_training_labels)
        N_labels = len(invalid_labels)
        invalid_training_labels = np.rec.array(
            np.random.uniform(size=(N_stars, N_labels)),
            dtype=[(l, '<f8') for l in invalid_labels])

        # live_dangerously=True lets construction succeed despite the
        # bad label name.
        m = model.BaseCannonModel(invalid_training_labels,
            self.valid_fluxes, self.valid_flux_uncertainties,
            live_dangerously=True)

        # With no forbidden characters configured, verification passes.
        m._forbidden_label_characters = None
        self.assertTrue(m._verify_labels_available())

        # The default (safe) constructor must reject the bad label name.
        with self.assertRaises(ValueError):
            m = model.BaseCannonModel(invalid_training_labels,
                self.valid_fluxes, self.valid_flux_uncertainties)
||
| 156 | |||
def test_get_label_vector(self):
    """A simple linear label vector parses into one term per label."""
    mdl = self.get_model()
    mdl.label_vector = "A + B + C"
    expected = [[(label, 1)] for label in "ABC"]
    self.assertEqual(mdl.pixel_label_vector(1), expected)
||
| 165 | |||
def test_set_label_vector(self):
    """Setting, clearing, and rejecting label vector descriptions."""
    mdl = self.get_model()
    description = "A + B + C + D + E"

    mdl.label_vector = description
    self.assertEqual(mdl.label_vector,
                     utils.parse_label_vector(description))
    self.assertEqual("1 + A + B + C + D + E",
                     mdl.human_readable_label_vector)

    # Unknown label names are rejected.
    with self.assertRaises(ValueError):
        mdl.label_vector = "A + G"

    # Clearing the label vector is allowed.
    mdl.label_vector = None
    self.assertIsNone(mdl.label_vector)

    # Non-string scalars are the wrong type entirely.
    for bad in (True, False, 0, 1.0):
        with self.assertRaises(TypeError):
            mdl.label_vector = bad
||
| 183 | |||
def test_label_getsetters(self):
    """Labels follow the label vector and cannot be set directly."""
    mdl = self.get_model()

    # No label vector yet: no labels.
    self.assertEqual((), mdl.labels)

    mdl.label_vector = "A + B + C"
    self.assertSequenceEqual(("A", "B", "C"), tuple(mdl.labels))

    # The labels attribute is read-only.
    with self.assertRaises(AttributeError):
        mdl.labels = None
||
| 194 | |||
def test_inheritence(self):
    """Abstract train/predict/fit raise NotImplementedError on the base.

    (Method name keeps the original spelling so test discovery and any
    external selection by name are unchanged.)
    """
    mdl = self.get_model()
    mdl.label_vector = "A + B + C"
    with self.assertRaises(NotImplementedError):
        mdl.train()
    mdl._trained = True
    for method in (mdl.predict, mdl.fit):
        with self.assertRaises(NotImplementedError):
            method()
||
| 205 | |||
def test_get_label_indices(self):
    """Lowest-order term indices for each label of a mixed vector."""
    mdl = self.get_model()
    mdl.label_vector = "A^5 + A^2 + B^3 + C + C*D + D^6"
    self.assertEqual([1, 2, 3, 5],
                     mdl._get_lowest_order_label_indices())
||
| 210 | |||
def test_data_verification(self):
    """Shape mismatches between fluxes, uncertainties, labels and
    dispersion all raise ValueError."""
    m = self.get_model()
    # Reshape the fluxes so they no longer match the uncertainties.
    m._training_fluxes = m._training_fluxes.reshape(1, -1)
    with self.assertRaises(ValueError):
        m._verify_training_data()

    # Fluxes and uncertainties agree again, but the number of stars
    # now disagrees with the training labels.
    m._training_flux_uncertainties = \
        m._training_flux_uncertainties.reshape(m._training_fluxes.shape)
    with self.assertRaises(ValueError):
        m._verify_training_data()

    # A dispersion of the wrong length is also invalid.
    with self.assertRaises(ValueError):
        m = self.get_model(dispersion=[1,2,3])

    # Plain (non-record) label arrays are not acceptable.
    N_labels = 2
    N_stars, N_pixels = self.valid_fluxes.shape
    invalid_training_labels = np.random.uniform(
        low=0.5, high=1.5, size=(N_stars, N_labels))
    with self.assertRaises(ValueError):
        m = model.BaseCannonModel(invalid_training_labels,
            self.valid_fluxes, self.valid_flux_uncertainties)
||
| 232 | |||
def test_labels_array(self):
    """labels_array columns mirror the raw training label columns."""
    mdl = self.get_model()
    mdl.label_vector = "A^2 + B^3 + C^5"

    for index, name in enumerate("ABC"):
        column = mdl.labels_array[:, index]
        raw = np.array(mdl.training_labels[name]).flatten()
        self.assertTrue(np.allclose(column, raw))

    # labels_array is read-only.
    with self.assertRaises(AttributeError):
        mdl.labels_array = None
||
| 244 | |||
def test_label_vector_array(self):
    """label_vector_array rows follow the label vector's power terms,
    and scalar vs. length-one inputs build identical rows."""
    m = self.get_model()
    m.label_vector = "A^2.0 + B^3.4 + C^5"
    m.pivots = np.zeros(len(m.labels))

    # Row 0 is the constant term; rows 1-3 are the label powers.
    # One data-driven loop replaces three copy-pasted assertions.
    for row, (label, power) in enumerate(
            [("A", 2), ("B", 3.4), ("C", 5)], start=1):
        self.assertTrue(np.allclose(
            np.array(m.training_labels[label]**power).flatten(),
            m.label_vector_array[row],
        ))

    # A non-finite label must not crash array construction.
    m.training_labels["A"][0] = np.nan
    m.label_vector_array  # For Coveralls.

    # Scalar keyword inputs...
    kwd1 = {
        "A": float(m.training_labels["A"][1]),
        "B": float(m.training_labels["B"][1]),
        "C": float(m.training_labels["C"][1])
    }
    # ...and equivalent length-one list inputs...
    kwd2 = {
        "A": [m.training_labels["A"][1][0]],
        "B": [m.training_labels["B"][1][0]],
        "C": [m.training_labels["C"][1][0]]
    }
    # ...must produce identical label vector rows.
    self.assertTrue(np.allclose(
        model._build_label_vector_rows(m.label_vector, kwd1),
        model._build_label_vector_rows(m.label_vector, kwd2)
    ))
||
| 280 | |||
def test_format_input_labels(self):
    """Labels may be given as keyword lists, a positional list, or
    keyword scalars — all normalize to the same mapping."""
    mdl = self.get_model()
    mdl.label_vector = "A^2.0 + B^3.4 + C^5"

    expected = {"A": [5], "B": [3], "C": [0.43]}

    # Keyword lists.
    for name, value in mdl._format_input_labels(None, **expected).items():
        self.assertEqual(expected[name], value)

    # Positional list.
    for name, value in mdl._format_input_labels([5, 3, 0.43]).items():
        self.assertEqual(expected[name], value)

    # Keyword scalars.
    scalars = {name: value[0] for name, value in expected.items()}
    for name, value in mdl._format_input_labels(None, **scalars).items():
        self.assertEqual(expected[name], value)
||
| 294 | |||
| 299 |