1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
|
4
|
|
|
""" |
5
|
|
|
Unit tests for the Cannon model class and associated functions. |
6
|
|
|
""" |
7
|
|
|
|
8
|
|
|
import numpy as np |
9
|
|
|
import sys |
10
|
|
|
import unittest |
11
|
|
|
from six.moves import cPickle as pickle |
12
|
|
|
from os import path, remove |
13
|
|
|
from tempfile import mkstemp |
14
|
|
|
|
15
|
|
|
from AnniesLasso import cannon, utils |
16
|
|
|
|
17
|
|
|
|
18
|
|
|
class TestCannonModel(unittest.TestCase): |
19
|
|
|
|
20
|
|
|
def setUp(self): |
|
|
|
|
21
|
|
|
# Initialise some faux data and labels. |
22
|
|
|
labels = "ABCDE" |
23
|
|
|
N_labels = len(labels) |
24
|
|
|
N_stars = np.random.randint(1, 500) |
25
|
|
|
N_pixels = np.random.randint(1, 10000) |
26
|
|
|
shape = (N_stars, N_pixels) |
27
|
|
|
|
28
|
|
|
self.valid_training_labels = np.rec.array( |
29
|
|
|
np.random.uniform(size=(N_stars, N_labels)), |
30
|
|
|
dtype=[(label, '<f8') for label in labels]) |
31
|
|
|
|
32
|
|
|
self.valid_fluxes = np.random.uniform(size=shape) |
33
|
|
|
self.valid_flux_uncertainties = np.random.uniform(size=shape) |
34
|
|
|
|
35
|
|
|
def get_model(self): |
36
|
|
|
return cannon.CannonModel( |
37
|
|
|
self.valid_training_labels, self.valid_fluxes, |
38
|
|
|
self.valid_flux_uncertainties) |
39
|
|
|
|
40
|
|
|
def test_init(self): |
41
|
|
|
self.assertIsNotNone(self.get_model()) |
42
|
|
|
|
43
|
|
|
|
44
|
|
|
# The test_data_set.pkl contains:
# (training_labels, training_fluxes, training_flux_uncertainties, coefficients,
#  scatter, pivots, label_vector)
# The training labels are not named, but they are: (TEFF, LOGG, PARAM_M_H)
48
|
|
|
|
49
|
|
|
class TestCannonModelRealistically(unittest.TestCase):
    """
    Exercise a CannonModel end-to-end (train, predict, fit, cross-validate,
    save/load) using the bundled test data set, in both serial and parallel
    (threads=2) configurations.
    """

    def setUp(self):
        # Set up serial and parallel models using the test data set.
        here = path.dirname(path.realpath(__file__))
        # Python 3 needs an explicit encoding to unpickle Python-2 data.
        kwds = { "encoding": "latin1" } \
            if sys.version_info[0] >= 3 else {}
        with open(path.join(here, "test_data_set.pkl"), "rb") as fp:
            contents = pickle.load(fp, **kwds)

        # Unpack it all.
        training_labels, training_fluxes, training_flux_uncertainties, \
            coefficients, scatter, pivots, label_vector = contents

        training_labels = np.core.records.fromarrays(training_labels,
            names="TEFF,LOGG,PARAM_M_H", formats="f8,f8,f8")

        self.test_data_set = {
            "training_labels": training_labels,
            "training_fluxes": training_fluxes,
            "training_flux_uncertainties": training_flux_uncertainties,
            "coefficients": coefficients,
            "scatter": scatter,
            "pivots": pivots,
            "label_vector": label_vector
        }
        self.model_serial = cannon.CannonModel(training_labels, training_fluxes,
            training_flux_uncertainties)
        self.model_parallel = cannon.CannonModel(training_labels,
            training_fluxes, training_flux_uncertainties, threads=2)

        self.models = (self.model_serial, self.model_parallel)

    def _assert_trained_attributes_match(self):
        """
        Check that every trained attribute agrees between the serial and
        parallel models, and nearly matches the expected test data set.
        (Extracted: this check was previously duplicated verbatim in
        do_training and do_io.)
        """
        for _attribute in self.model_serial._trained_attributes:

            # Attribute names are underscore-prefixed on the model; strip the
            # prefix to look up the expected value in the test data set.
            expected = self.test_data_set[_attribute[1:]]
            if isinstance(expected, dict):
                for key in expected:
                    self.assertEqual(
                        getattr(self.model_serial, _attribute)[key],
                        getattr(self.model_parallel, _attribute)[key]
                    )
                    self.assertEqual(expected[key],
                        getattr(self.model_serial, _attribute)[key])
            else:
                self.assertTrue(np.allclose(
                    getattr(self.model_serial, _attribute),
                    getattr(self.model_parallel, _attribute)
                ))

                self.assertTrue(np.allclose(
                    expected,
                    getattr(self.model_serial, _attribute)))
                    #rtol=0.5, atol=1e-8))

    def do_training(self):
        """Train both models and verify the trained attributes agree."""
        for model in self.models:
            model.reset()
            model.label_vector = self.test_data_set["label_vector"]
            self.assertIsNotNone(model.train())

        # Check that the trained attributes in both models are equal,
        # and nearly as we expected.
        self._assert_trained_attributes_match()

    def do_residuals(self):
        """Training label residuals must agree between serial and parallel."""
        serial = self.model_serial.get_training_label_residuals()
        parallel = self.model_parallel.get_training_label_residuals()
        self.assertTrue(np.allclose(serial, parallel))

    def ruin_the_trained_coefficients(self):
        """Verify validation of the scatter and coefficients setters."""
        # None is allowed (clears the trained state).
        self.model_serial.scatter = None
        self.assertIsNone(self.model_serial.scatter)

        # Wrong size must be rejected.
        with self.assertRaises(ValueError):
            self.model_parallel.scatter = [3]

        # Scalars (including bools) must be rejected.
        for item in (0., False, True):
            with self.assertRaises(ValueError):
                self.model_parallel.scatter = item

        # Negative scatter values must be rejected.
        with self.assertRaises(ValueError):
            self.model_parallel.scatter = \
                -1 * np.ones_like(self.model_parallel.dispersion)

        # A valid positive array must round-trip through the setter.
        _ = np.array(self.model_parallel.scatter).copy()
        _ += 1.
        self.model_parallel.scatter = _
        self.assertTrue(np.allclose(_, self.model_parallel.scatter))

        # Same exercise for the coefficients.
        self.model_serial.coefficients = None
        self.assertIsNone(self.model_serial.coefficients)

        # Wrong dimensionality must be rejected.
        with self.assertRaises(ValueError):
            self.model_parallel.coefficients = np.arange(12).reshape((3, 2, 2))

        # Transposed (wrong) shape must be rejected.
        with self.assertRaises(ValueError):
            _ = np.ones_like(self.model_parallel.coefficients)
            self.model_parallel.coefficients = _.T

        # Wrong number of label vector terms must be rejected.
        with self.assertRaises(ValueError):
            _ = np.ones_like(self.model_parallel.coefficients)
            self.model_parallel.coefficients = _[:, :-1]

        # A valid array must round-trip through the setter.
        _ = np.array(self.model_parallel.coefficients).copy()
        _ += 0.5
        self.model_parallel.coefficients = _
        self.assertTrue(np.allclose(_, self.model_parallel.coefficients))

    def do_io(self):
        """Save and reload the model; verify state survives the round trip."""
        _, temp_filename = mkstemp()
        remove(temp_filename)
        self.model_serial.save(temp_filename, include_training_data=False)
        with self.assertRaises(IOError):
            self.model_serial.save(temp_filename, overwrite=False)

        # Temporarily corrupt each attribute list; save must then fail.
        names = ("_data_attributes", "_trained_attributes",
            "_descriptive_attributes")
        attrs = (
            self.model_serial._data_attributes,
            self.model_serial._trained_attributes,
            self.model_serial._descriptive_attributes
        )
        for name, item in zip(names, attrs):
            _ = [] + list(item)
            _.append("metadata")
            setattr(self.model_serial, name, _)
            with self.assertRaises(ValueError):
                self.model_serial.save(temp_filename, overwrite=True)
            # Restore the original attribute list.
            setattr(self.model_serial, name, _[:-1])

        self.model_serial.save(temp_filename, include_training_data=True,
            overwrite=True)

        self.model_parallel.reset()
        self.model_parallel.load(temp_filename, verify_training_data=True)

        # Check that the trained attributes in both models are equal,
        # and nearly as we expected.
        self._assert_trained_attributes_match()

        # Check that the data attributes in both models are equal.
        # NOTE: this previously used assertTrue(hash_a, hash_b), which treated
        # the second hash as a failure message and never compared them.
        for _attribute in self.model_serial._data_attributes:
            self.assertEqual(
                utils.short_hash(getattr(self.model_serial, _attribute)),
                utils.short_hash(getattr(self.model_parallel, _attribute))
            )

        # Alter the hash and expect failure on load.
        kwds = { "encoding": "latin1" } if sys.version_info[0] >= 3 else {}
        with open(temp_filename, "rb") as fp:
            contents = pickle.load(fp, **kwds)

        contents["training_set_hash"] = ""
        with open(temp_filename, "wb") as fp:
            pickle.dump(contents, fp, -1)

        with self.assertRaises(ValueError):
            self.model_serial.load(temp_filename, verify_training_data=True)

        # Clean up the temporary file.
        if path.exists(temp_filename):
            remove(temp_filename)

    def do_cv(self):
        """Cross-validate, with and without a pre-train hook."""
        self.model_parallel.cross_validate(N=1, debug=True)

        def choo_choo(old, new):
            # A deliberately no-op pre-train hook.
            pass

        self.model_parallel.cross_validate(N=1, debug=True, pre_train=choo_choo)

    def do_predict(self):
        """Positional and keyword label arguments must predict identically."""
        _ = [self.model_serial.training_labels[label][0] \
            for label in self.model_serial.labels]
        self.assertTrue(np.allclose(
            self.model_serial.predict(_),
            self.model_serial.predict(**dict(zip(self.model_serial.labels, _)))))

    def do_fit(self):
        """Fitting a training spectrum with full output must return something."""
        self.assertIsNotNone(
            self.model_serial.fit(self.model_serial.training_fluxes[0],
                self.model_serial.training_flux_uncertainties[0],
                full_output=True))

    def do_edge_cases(self):
        """Train and fit with unusual label vectors and degenerate data."""
        self.model_serial.reset()

        # This label vector only contains one term in cross-terms (PARAM_M_H)
        self.model_serial.label_vector = \
            "TEFF^3 + TEFF^2 + TEFF + LOGG + PARAM_M_H*LOGG"
        self.assertIn(None, self.model_serial._get_lowest_order_label_indices())

        # Set large uncertainties for one pixel.
        self.model_serial._training_flux_uncertainties[:, 0] = 10.
        self.model_serial._training_fluxes[:, 1] = \
            np.random.uniform(low=-0.5, high=0.5,
                size=self.model_serial._training_fluxes.shape[0])

        # Train and fit using this unusual label vector.
        self.model_serial.train()
        self.model_serial.fit(self.model_serial._training_fluxes[1],
            self.model_serial._training_flux_uncertainties[1])

        # See if we can make things break or warn.
        self.model_serial._training_fluxes[10] = 1000.
        self.model_serial._training_flux_uncertainties[10] = 0.
        self.model_serial.reset()
        self.model_serial.label_vector = "TEFF^5 + LOGG^3 + PARAM_M_H^5"
        for label in self.model_serial.labels:
            self.model_serial._training_labels[label] = 0.
        self.model_serial.train()

        # TODO: Force things to break
        #with self.assertRaises(np.linalg.linalg.LinAlgError):
        #    self.model_serial.train(debug=True)

        #with self.assertRaises(np.linalg.linalg.LinAlgError):
        #    self.model_serial.cross_validate(N=1, debug=True)

    def runTest(self):
        """Run the full workflow in order (training must happen first)."""

        # Train all.
        self.do_training()

        self.do_residuals()

        self.ruin_the_trained_coefficients()

        # Train again.
        self.do_training()

        # Predict stuff.
        self.do_predict()

        self.do_fit()

        # Do cross-validation.
        self.do_cv()

        # Try I/O.
        self.do_io()

        # Do edge cases.
        self.do_edge_cases()
Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.
You can also find more detailed suggestions in the “Code” section of your repository.