Total Complexity | 48 |
Total Lines | 248 |
Duplicated Lines | 0 % |
Complex classes like AnniesLasso.tests.TestBaseCannonModel often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | #!/usr/bin/env python |
||
43 | class TestBaseCannonModel(unittest.TestCase): |
||
44 | |||
45 | def setUp(self): |
||
46 | # Initialise some faux data and labels. |
||
47 | labels = "ABCDE" |
||
48 | N_labels = len(labels) |
||
49 | N_stars = np.random.randint(10, 500) |
||
50 | N_pixels = np.random.randint(1, 10000) |
||
51 | shape = (N_stars, N_pixels) |
||
52 | |||
53 | self.valid_training_labels = np.rec.array( |
||
54 | np.random.uniform(low=0.5, high=1.5, size=(N_stars, N_labels)), |
||
55 | dtype=[(label, '<f8') for label in labels]) |
||
56 | |||
57 | self.valid_fluxes = np.random.uniform(low=0.5, high=1.5, size=shape) |
||
58 | self.valid_flux_uncertainties = np.random.uniform(low=0.5, high=1.5, |
||
59 | size=shape) |
||
60 | |||
61 | def runTest(self): |
||
62 | None |
||
63 | |||
64 | def get_model(self, **kwargs): |
||
65 | return model.BaseCannonModel( |
||
66 | self.valid_training_labels, self.valid_fluxes, |
||
67 | self.valid_flux_uncertainties, **kwargs) |
||
68 | |||
69 | def test_repr(self): |
||
70 | self.runTest() # Just for that 100%, baby. |
||
71 | m = self.get_model() |
||
72 | print("{0} {1}".format(m.__str__(), m.__repr__())) |
||
73 | |||
74 | def test_get_dispersion(self): |
||
75 | m = self.get_model() |
||
76 | self.assertSequenceEqual( |
||
77 | tuple(m.dispersion), |
||
78 | tuple(np.arange(self.valid_fluxes.shape[1]))) |
||
79 | |||
80 | def test_set_dispersion(self): |
||
81 | m = self.get_model() |
||
82 | for item in (None, False, True): |
||
83 | # Incorrect data type (not an iterable) |
||
84 | with self.assertRaises(TypeError): |
||
85 | m.dispersion = item |
||
86 | |||
87 | for item in ("", {}, [], (), set()): |
||
88 | # These are iterable but have the wrong lengths. |
||
89 | with self.assertRaises(ValueError): |
||
90 | m.dispersion = item |
||
91 | |||
92 | with self.assertRaises(ValueError): |
||
93 | m.dispersion = [3,4,2,1] |
||
94 | |||
95 | # These should work. |
||
96 | m.dispersion = 10 + np.arange(self.valid_fluxes.shape[1]) |
||
97 | m.dispersion = -100 + np.arange(self.valid_fluxes.shape[1]) |
||
98 | m.dispersion = 520938.4 + np.arange(self.valid_fluxes.shape[1]) |
||
99 | |||
100 | # Disallow non-finite numbers. |
||
101 | with self.assertRaises(ValueError): |
||
102 | d = np.arange(self.valid_fluxes.shape[1], dtype=float) |
||
103 | d[0] = np.nan |
||
104 | m.dispersion = d |
||
105 | |||
106 | with self.assertRaises(ValueError): |
||
107 | d = np.arange(self.valid_fluxes.shape[1], dtype=float) |
||
108 | d[0] = np.inf |
||
109 | m.dispersion = d |
||
110 | |||
111 | with self.assertRaises(ValueError): |
||
112 | d = np.arange(self.valid_fluxes.shape[1], dtype=float) |
||
113 | d[0] = -np.inf |
||
114 | m.dispersion = d |
||
115 | |||
116 | # Disallow non-float like things. |
||
117 | with self.assertRaises(ValueError): |
||
118 | d = np.array([""] * self.valid_fluxes.shape[1]) |
||
119 | m.dispersion = d |
||
120 | |||
121 | with self.assertRaises(ValueError): |
||
122 | d = np.array([None] * self.valid_fluxes.shape[1]) |
||
123 | m.dispersion = d |
||
124 | |||
125 | def test_get_training_data(self): |
||
126 | m = self.get_model() |
||
127 | self.assertIsNotNone(m.training_labels) |
||
128 | self.assertIsNotNone(m.training_fluxes) |
||
129 | self.assertIsNotNone(m.training_flux_uncertainties) |
||
130 | |||
131 | def test_invalid_label_names(self): |
||
132 | m = self.get_model() |
||
133 | for character in m._forbidden_label_characters: |
||
134 | |||
135 | invalid_labels = [] + list(m.labels_available) |
||
136 | invalid_labels[0] = "HELLO_{}".format(character) |
||
137 | |||
138 | N_stars = len(self.valid_training_labels) |
||
139 | N_labels = len(invalid_labels) |
||
140 | invalid_training_labels = np.rec.array( |
||
141 | np.random.uniform(size=(N_stars, N_labels)), |
||
142 | dtype=[(l, '<f8') for l in invalid_labels]) |
||
143 | |||
144 | m = model.BaseCannonModel(invalid_training_labels, |
||
145 | self.valid_fluxes, self.valid_flux_uncertainties, |
||
146 | live_dangerously=True) |
||
147 | |||
148 | m._forbidden_label_characters = None |
||
149 | self.assertTrue(m._verify_labels_available()) |
||
150 | |||
151 | with self.assertRaises(ValueError): |
||
152 | m = model.BaseCannonModel(invalid_training_labels, |
||
153 | self.valid_fluxes, self.valid_flux_uncertainties) |
||
154 | |||
155 | def test_get_label_vector(self): |
||
156 | m = self.get_model() |
||
157 | m.label_vector = "A + B + C" |
||
158 | self.assertEqual(m.pixel_label_vector(1), [ |
||
159 | [("A", 1)], |
||
160 | [("B", 1)], |
||
161 | [("C", 1)] |
||
162 | ]) |
||
163 | |||
164 | def test_set_label_vector(self): |
||
165 | m = self.get_model() |
||
166 | label_vector = "A + B + C + D + E" |
||
167 | |||
168 | m.label_vector = label_vector |
||
169 | self.assertEqual(m.label_vector, utils.parse_label_vector(label_vector)) |
||
170 | self.assertEqual("1 + A + B + C + D + E", m.human_readable_label_vector) |
||
171 | |||
172 | with self.assertRaises(ValueError): |
||
173 | m.label_vector = "A + G" |
||
174 | |||
175 | m.label_vector = None |
||
176 | self.assertIsNone(m.label_vector) |
||
177 | |||
178 | for item in (True, False, 0, 1.0): |
||
179 | with self.assertRaises(TypeError): |
||
180 | m.label_vector = item |
||
181 | |||
182 | def test_label_getsetters(self): |
||
183 | |||
184 | m = self.get_model() |
||
185 | self.assertEqual((), m.labels) |
||
186 | |||
187 | m.label_vector = "A + B + C" |
||
188 | self.assertSequenceEqual(("A", "B", "C"), tuple(m.labels)) |
||
189 | |||
190 | with self.assertRaises(AttributeError): |
||
191 | m.labels = None |
||
192 | |||
193 | def test_inheritence(self): |
||
194 | m = self.get_model() |
||
195 | m.label_vector = "A + B + C" |
||
196 | with self.assertRaises(NotImplementedError): |
||
197 | m.train() |
||
198 | m._trained = True |
||
199 | with self.assertRaises(NotImplementedError): |
||
200 | m.predict() |
||
201 | with self.assertRaises(NotImplementedError): |
||
202 | m.fit() |
||
203 | |||
204 | def test_get_label_indices(self): |
||
205 | m = self.get_model() |
||
206 | m.label_vector = "A^5 + A^2 + B^3 + C + C*D + D^6" |
||
207 | self.assertEqual([1, 2, 3, 5], m._get_lowest_order_label_indices()) |
||
208 | |||
209 | def test_data_verification(self): |
||
210 | m = self.get_model() |
||
211 | m._training_fluxes = m._training_fluxes.reshape(1, -1) |
||
212 | with self.assertRaises(ValueError): |
||
213 | m._verify_training_data() |
||
214 | |||
215 | m._training_flux_uncertainties = \ |
||
216 | m._training_flux_uncertainties.reshape(m._training_fluxes.shape) |
||
217 | with self.assertRaises(ValueError): |
||
218 | m._verify_training_data() |
||
219 | |||
220 | with self.assertRaises(ValueError): |
||
221 | m = self.get_model(dispersion=[1,2,3]) |
||
222 | |||
223 | N_labels = 2 |
||
224 | N_stars, N_pixels = self.valid_fluxes.shape |
||
225 | invalid_training_labels = np.random.uniform( |
||
226 | low=0.5, high=1.5, size=(N_stars, N_labels)) |
||
227 | with self.assertRaises(ValueError): |
||
228 | m = model.BaseCannonModel(invalid_training_labels, |
||
229 | self.valid_fluxes, self.valid_flux_uncertainties) |
||
230 | |||
231 | def test_labels_array(self): |
||
232 | m = self.get_model() |
||
233 | m.label_vector = "A^2 + B^3 + C^5" |
||
234 | |||
235 | for i, label in enumerate("ABC"): |
||
236 | foo = m.labels_array[:, i] |
||
237 | bar = np.array(m.training_labels[label]).flatten() |
||
238 | self.assertTrue(np.allclose(foo, bar)) |
||
239 | |||
240 | with self.assertRaises(AttributeError): |
||
241 | m.labels_array = None |
||
242 | |||
243 | def test_label_vector_array(self): |
||
244 | m = self.get_model() |
||
245 | m.label_vector = "A^2.0 + B^3.4 + C^5" |
||
246 | |||
247 | self.assertTrue(np.allclose( |
||
248 | np.array(m.training_labels["A"]**2).flatten(), |
||
249 | m.label_vector_array[1], |
||
250 | )) |
||
251 | self.assertTrue(np.allclose( |
||
252 | np.array(m.training_labels["B"]**3.4).flatten(), |
||
253 | m.label_vector_array[2] |
||
254 | )) |
||
255 | self.assertTrue(np.allclose( |
||
256 | np.array(m.training_labels["C"]**5).flatten(), |
||
257 | m.label_vector_array[3] |
||
258 | )) |
||
259 | |||
260 | m.training_labels["A"][0] = np.nan |
||
261 | m.label_vector_array # For Coveralls. |
||
262 | |||
263 | kwd1 = { |
||
264 | "A": float(m.training_labels["A"][1]), |
||
265 | "B": float(m.training_labels["B"][1]), |
||
266 | "C": float(m.training_labels["C"][1]) |
||
267 | } |
||
268 | kwd2 = { |
||
269 | "A": [m.training_labels["A"][1][0]], |
||
270 | "B": [m.training_labels["B"][1][0]], |
||
271 | "C": [m.training_labels["C"][1][0]] |
||
272 | } |
||
273 | self.assertTrue(np.allclose( |
||
274 | model._build_label_vector_rows(m.label_vector, kwd1), |
||
275 | model._build_label_vector_rows(m.label_vector, kwd2) |
||
276 | )) |
||
277 | |||
278 | def test_format_input_labels(self): |
||
279 | m = self.get_model() |
||
280 | m.label_vector = "A^2.0 + B^3.4 + C^5" |
||
281 | |||
282 | kwds = {"A": [5], "B": [3], "C": [0.43]} |
||
283 | for k, v in m._format_input_labels(None, **kwds).items(): |
||
284 | self.assertEqual(kwds[k], v) |
||
285 | for k, v in m._format_input_labels([5, 3, 0.43]).items(): |
||
286 | self.assertEqual(kwds[k], v) |
||
287 | |||
288 | kwds_input = {k: v[0] for k, v in kwds.items() } |
||
289 | for k, v in m._format_input_labels(None, **kwds_input).items(): |
||
290 | self.assertEqual(kwds[k], v) |
||
291 | |||
296 |