Completed
Push — master ( eb70ed...446da2 )
by Andy
59s
created

AnniesLasso.tests.TestCannonModel   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 24
Duplicated Lines 0 %
Metric Value
wmc 4
dl 0
loc 24
rs 10

3 Methods

Rating   Name   Duplication   Size   Complexity  
A get_model() 0 4 1
A setUp() 0 14 2
A test_init() 0 2 1
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
Unit tests for the Cannon model class and associated functions.
6
"""
7
8
import numpy as np
9
import sys
10
import unittest
11
from six.moves import cPickle as pickle
12
from os import path, remove
13
from tempfile import mkstemp
14
15
from AnniesLasso import cannon, utils
16
17
18
class TestCannonModel(unittest.TestCase):
    """Sanity checks for constructing a CannonModel from random faux data."""

    def setUp(self):
        # Build a randomly-sized faux training set with five labels (A-E).
        label_names = "ABCDE"
        num_labels = len(label_names)
        num_stars = np.random.randint(1, 500)
        num_pixels = np.random.randint(1, 10000)

        # Record array of per-star label values, one float64 field per label.
        self.valid_training_labels = np.rec.array(
            np.random.uniform(size=(num_stars, num_labels)),
            dtype=[(name, '<f8') for name in label_names])

        # Matching (star, pixel) grids of fluxes and their uncertainties.
        data_shape = (num_stars, num_pixels)
        self.valid_fluxes = np.random.uniform(size=data_shape)
        self.valid_flux_uncertainties = np.random.uniform(size=data_shape)

    def get_model(self):
        """Construct a CannonModel from the faux data created in setUp."""
        model_args = (self.valid_training_labels, self.valid_fluxes,
                      self.valid_flux_uncertainties)
        return cannon.CannonModel(*model_args)

    def test_init(self):
        # The constructor should accept valid faux data without complaint.
        self.assertIsNotNone(self.get_model())
42
43
44
# The test_data_set.pkl contains:
45
# (training_labels, training_fluxes, training_flux_uncertainties, coefficients,
46
#  scatter, label_vector)
47
# The training labels are not named, but they are: (TEFF, LOGG, PARAM_M_H)
48
49
class TestCannonModelRealistically(unittest.TestCase):
    """
    End-to-end tests of CannonModel against the pickled test data set,
    exercising training, residuals, attribute validation, prediction,
    fitting, cross-validation, I/O, and edge cases with both a serial
    model and a two-thread (parallel) model. The individual do_* steps
    are order-dependent and are driven in sequence by runTest().
    """

    def setUp(self):
        # Load the pickled test data set that lives alongside this file.
        here = path.dirname(path.realpath(__file__))
        # Python 3 needs latin-1 decoding to read a Python-2-era pickle.
        kwds = { "encoding": "latin1" } \
            if sys.version_info[0] >= 3 else {}
        with open(path.join(here, "test_data_set.pkl"), "rb") as fp:
            contents = pickle.load(fp, **kwds)

        # Unpack the six pickled components (see the comment above the class).
        training_labels, training_fluxes, training_flux_uncertainties, \
            coefficients, scatter, label_vector = contents

        # The pickled labels are unnamed; attach the known column names.
        training_labels = np.core.records.fromarrays(training_labels,
            names="TEFF,LOGG,PARAM_M_H", formats="f8,f8,f8")

        # Keep the raw data around so later steps can compare trained
        # attributes against the expected (pickled) values.
        self.test_data_set = {
            "training_labels": training_labels,
            "training_fluxes": training_fluxes,
            "training_flux_uncertainties": training_flux_uncertainties,
            "coefficients": coefficients,
            "scatter": scatter,
            "label_vector": label_vector

        }
        self.model_serial = cannon.CannonModel(training_labels, training_fluxes,
            training_flux_uncertainties)
        self.model_parallel = cannon.CannonModel(training_labels,
            training_fluxes, training_flux_uncertainties, threads=2)

        self.models = (self.model_serial, self.model_parallel)

    def do_training(self):
        """Train both models and check they agree with each other and
        (approximately) with the expected pickled values."""
        for model in self.models:
            model.reset()
            model.label_vector = self.test_data_set["label_vector"]
            self.assertIsNotNone(model.train())

        # Check that the trained attributes in both models are equal.
        for _attribute in self.model_serial._trained_attributes:
            self.assertTrue(np.allclose(
                getattr(self.model_serial, _attribute),
                getattr(self.model_parallel, _attribute)
                ))

            # And nearly as we expected. The [1:] strips the leading
            # underscore to index into the pickled test_data_set dict.
            self.assertTrue(np.allclose(
                self.test_data_set[_attribute[1:]],
                getattr(self.model_serial, _attribute),
                rtol=0.5, atol=1e-8))

    def do_residuals(self):
        # Serial and parallel training should yield identical residuals.
        serial = self.model_serial.get_training_label_residuals()
        parallel = self.model_parallel.get_training_label_residuals()
        self.assertTrue(np.allclose(serial, parallel))

    def ruin_the_trained_coefficients(self):
        """Deliberately corrupt trained attributes, checking that the
        scatter/coefficients setters validate their input."""
        # Clearing scatter is allowed...
        self.model_serial.scatter = None
        self.assertIsNone(self.model_serial.scatter)

        # ...but a wrong-sized scatter should be rejected...
        with self.assertRaises(ValueError):
            self.model_parallel.scatter = [3]

        # ...as should scalar or boolean values...
        for item in (0., False, True):
            with self.assertRaises(ValueError):
                self.model_parallel.scatter = item

        # ...and negative scatter values.
        with self.assertRaises(ValueError):
            self.model_parallel.scatter = \
                -1 * np.ones_like(self.model_parallel.dispersion)

        # A valid (positive, correctly-shaped) scatter must round-trip.
        _ = np.array(self.model_parallel.scatter).copy()
        _ += 1.
        self.model_parallel.scatter = _
        self.assertTrue(np.allclose(_, self.model_parallel.scatter))


        # Clearing coefficients is allowed...
        self.model_serial.coefficients = None
        self.assertIsNone(self.model_serial.coefficients)

        # ...but the wrong rank should be rejected...
        with self.assertRaises(ValueError):
            self.model_parallel.coefficients = np.arange(12).reshape((3, 2, 2))

        # ...as should a transposed shape...
        with self.assertRaises(ValueError):
            _ = np.ones_like(self.model_parallel.coefficients)
            self.model_parallel.coefficients = _.T

        # ...or a truncated one.
        with self.assertRaises(ValueError):
            _ = np.ones_like(self.model_parallel.coefficients)
            self.model_parallel.coefficients = _[:, :-1]
        
        # A valid coefficients array must round-trip.
        _ = np.array(self.model_parallel.coefficients).copy()
        _ += 0.5
        self.model_parallel.coefficients = _
        self.assertTrue(np.allclose(_, self.model_parallel.coefficients))

    def do_io(self):
        """Save/load round-trip, attribute-name validation, and detection
        of a tampered training-set hash."""
        # mkstemp opens the file; remove it so save() creates it fresh.
        _, temp_filename = mkstemp()
        remove(temp_filename)
        self.model_serial.save(temp_filename, include_training_data=False)
        # Saving over an existing file without overwrite=True must fail.
        with self.assertRaises(IOError):
            self.model_serial.save(temp_filename, overwrite=False)

        # Temporarily append a bogus attribute name to each attribute list;
        # save() should refuse, then restore the original list.
        names = ("_data_attributes", "_trained_attributes",
            "_descriptive_attributes")
        attrs = (
            self.model_serial._data_attributes,
            self.model_serial._trained_attributes,
            self.model_serial._descriptive_attributes
            )
        for name, item in zip(names, attrs):
            _ = [] + list(item)
            _.append("metadata")
            setattr(self.model_serial, name, _)
            with self.assertRaises(ValueError):
                self.model_serial.save(temp_filename, overwrite=True)
            setattr(self.model_serial, name, _[:-1])

        # Full save including the training data, then load it into the
        # (reset) parallel model and verify the training data hash.
        self.model_serial.save(temp_filename, include_training_data=True,
            overwrite=True)

        self.model_parallel.reset()
        self.model_parallel.load(temp_filename, verify_training_data=True)

        # Check that the trained attributes in both models are equal.
        for _attribute in self.model_serial._trained_attributes:
            self.assertTrue(np.allclose(
                getattr(self.model_serial, _attribute),
                getattr(self.model_parallel, _attribute)
                ))

            # And nearly as we expected.
            self.assertTrue(np.allclose(
                self.test_data_set[_attribute[1:]],
                getattr(self.model_serial, _attribute),
                rtol=0.5, atol=1e-8))

        # Check that the data attributes in both models are equal.
        for _attribute in self.model_serial._data_attributes:
            self.assertTrue(
                utils.short_hash(getattr(self.model_serial, _attribute)),
                utils.short_hash(getattr(self.model_parallel, _attribute))
            )


        # Alter the hash and expect failure on verified load.
        kwds = { "encoding": "latin1" } if sys.version_info[0] >= 3 else {}
        with open(temp_filename, "rb") as fp:
            contents = pickle.load(fp, **kwds)

        contents["training_set_hash"] = ""
        with open(temp_filename, "wb") as fp:
            pickle.dump(contents, fp, -1)

        with self.assertRaises(ValueError):
            self.model_serial.load(temp_filename, verify_training_data=True)

        # Clean up the temporary file.
        if path.exists(temp_filename):
            remove(temp_filename)

    def do_cv(self):
        # Cross-validation with and without a pre-train callback.
        self.model_parallel.cross_validate(N=1, debug=True)

        # No-op callback, just to exercise the pre_train hook.
        def choo_choo(old, new):
            None

        self.model_parallel.cross_validate(N=1, debug=True, pre_train=choo_choo)

    def do_predict(self):
        # Positional and keyword label arguments must predict identically.
        _ = [self.model_serial.training_labels[label][0] \
            for label in self.model_serial.labels]
        self.assertTrue(np.allclose(
            self.model_serial.predict(_),
            self.model_serial.predict(**dict(zip(self.model_serial.labels, _)))))

    def do_fit(self):
        # Fitting a training spectrum should return something (full output).
        self.assertIsNotNone(
            self.model_serial.fit(self.model_serial.training_fluxes[0],
                self.model_serial.training_flux_uncertainties[0],
                full_output=True))

    def do_edge_cases(self):
        """Exercise unusual label vectors and degenerate training data."""
        self.model_serial.reset()

        # This label vector only contains one term in cross-terms (PARAM_M_H)
        self.model_serial.label_vector = \
            "TEFF^3 + TEFF^2 + TEFF + LOGG + PARAM_M_H*LOGG"
        self.assertIn(None, self.model_serial._get_lowest_order_label_indices())

        # Set large uncertainties for one pixel.
        self.model_serial._training_flux_uncertainties[:, 0] = 10.
        self.model_serial._training_fluxes[:, 1] = \
            np.random.uniform(low=-0.5, high=0.5,
                size=self.model_serial._training_fluxes.shape[0])

        # Train and fit using this unusual label vector.
        self.model_serial.train()
        self.model_serial.fit(self.model_serial._training_fluxes[1],
            self.model_serial._training_flux_uncertainties[1])

        # See if we can make things break or warn.
        self.model_serial._training_fluxes[10] = 10.
        self.model_serial._training_flux_uncertainties[10] = 0.99

        # Zero uncertainties plus constant (all-zero) labels should make
        # the design matrix singular, so training in debug mode must raise.
        self.model_serial._training_flux_uncertainties[11] = 0.
        self.model_serial.reset()
        self.model_serial.label_vector = "TEFF^5 + LOGG^3 + PARAM_M_H^5"
        for label in self.model_serial.labels:
            self.model_serial._training_labels[label] = 0.
        self.model_serial.train()

        with self.assertRaises(np.linalg.linalg.LinAlgError):
            self.model_serial.train(debug=True)

        with self.assertRaises(np.linalg.linalg.LinAlgError):
            self.model_serial.cross_validate(N=1, debug=True)

    def runTest(self):
        """Run the full sequence of steps; order matters because the steps
        mutate and retrain the shared models."""

        # Train all.
        self.do_training()

        self.do_residuals()

        self.ruin_the_trained_coefficients()

        # Train again.
        self.do_training()

        # Predict stuff.
        self.do_predict()

        self.do_fit()

        # Do cross-validation.
        self.do_cv()

        # Try I/O.
        self.do_io()

        # Edge cases last: they deliberately corrupt the training data.
        self.do_edge_cases()
293
294