Completed
Push — master ( 53c124...83b528 ) by Andy, created 50s ago

do_training()   B

Complexity:   Conditions 5
Size:         Total Lines 29
Duplication:  Lines 0, Ratio 0 %

Metric    Value
dl        0
loc       29
rs        8.0896
cc        5
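
The loc and cc rows match the "Total Lines 29" and "Conditions 5" figures above, i.e. a raw line count and a cyclomatic-complexity count for do_training(). As a rough cross-check, comparable numbers can be computed locally; the sketch below assumes the radon package and a hypothetical file name, and is not necessarily how this report's values were produced.

# Sketch only: line-count and cyclomatic-complexity metrics via radon.
# Assumes `pip install radon`; "test_cannon.py" is a hypothetical file name.
from radon.complexity import cc_visit
from radon.raw import analyze

with open("test_cannon.py") as fp:
    source = fp.read()

print("file loc:", analyze(source).loc)      # raw line count for the whole file
for block in cc_visit(source):               # one block per function/method
    # Compare block.complexity for do_training with the cc value in the report.
    print(block.name, "cc =", block.complexity)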
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Unit tests for the Cannon model class and associated functions.
"""

import numpy as np
import sys
import unittest
from six.moves import cPickle as pickle
from os import path, remove
from tempfile import mkstemp

from AnniesLasso import cannon, utils


class TestCannonModel(unittest.TestCase):

    def setUp(self):
        # Initialise some faux data and labels.
        labels = "ABCDE"
        N_labels = len(labels)
        N_stars = np.random.randint(1, 500)
        N_pixels = np.random.randint(1, 10000)
        shape = (N_stars, N_pixels)

        self.valid_training_labels = np.rec.array(
            np.random.uniform(size=(N_stars, N_labels)),
            dtype=[(label, '<f8') for label in labels])

        self.valid_fluxes = np.random.uniform(size=shape)
        self.valid_flux_uncertainties = np.random.uniform(size=shape)

    def get_model(self):
        return cannon.CannonModel(
            self.valid_training_labels, self.valid_fluxes,
            self.valid_flux_uncertainties)

    def test_init(self):
        self.assertIsNotNone(self.get_model())


# The test_data_set.pkl contains:
# (training_labels, training_fluxes, training_flux_uncertainties, coefficients,
#  scatter, pivots, label_vector)
# The training labels are not named, but they are: (TEFF, LOGG, PARAM_M_H)

class TestCannonModelRealistically(unittest.TestCase):

    def setUp(self):
        # Set up a model using the test data set.
        here = path.dirname(path.realpath(__file__))
        kwds = { "encoding": "latin1" } \
            if sys.version_info[0] >= 3 else {}
        with open(path.join(here, "test_data_set.pkl"), "rb") as fp:
            contents = pickle.load(fp, **kwds)

        # Unpack it all.
        training_labels, training_fluxes, training_flux_uncertainties, \
            coefficients, scatter, pivots, label_vector = contents

        training_labels = np.core.records.fromarrays(training_labels,
            names="TEFF,LOGG,PARAM_M_H", formats="f8,f8,f8")

        self.test_data_set = {
            "training_labels": training_labels,
            "training_fluxes": training_fluxes,
            "training_flux_uncertainties": training_flux_uncertainties,
            "coefficients": coefficients,
            "scatter": scatter,
            "pivots": pivots,
            "label_vector": label_vector
        }
        self.model_serial = cannon.CannonModel(training_labels, training_fluxes,
            training_flux_uncertainties)
        self.model_parallel = cannon.CannonModel(training_labels,
            training_fluxes, training_flux_uncertainties, threads=2)

        self.models = (self.model_serial, self.model_parallel)

    def do_training(self):
        for model in self.models:
            model.reset()
            model.label_vector = self.test_data_set["label_vector"]
            self.assertIsNotNone(model.train())

        # Check that the trained attributes in both models are equal.
        for _attribute in self.model_serial._trained_attributes:

            # And nearly as we expected.
            expected = self.test_data_set[_attribute[1:]]
            if isinstance(expected, dict):
                for key in expected:
                    self.assertEqual(
                        getattr(self.model_serial, _attribute)[key],
                        getattr(self.model_parallel, _attribute)[key]
                    )
                    self.assertEqual(expected[key],
                        getattr(self.model_serial, _attribute)[key])
            else:

                self.assertTrue(np.allclose(
                    getattr(self.model_serial, _attribute),
                    getattr(self.model_parallel, _attribute)
                    ))

                self.assertTrue(np.allclose(
                    expected,
                    getattr(self.model_serial, _attribute)))
                    #rtol=0.5, atol=1e-8))

    def do_residuals(self):
        serial = self.model_serial.get_training_label_residuals()
        parallel = self.model_parallel.get_training_label_residuals()
        self.assertTrue(np.allclose(serial, parallel))

    def ruin_the_trained_coefficients(self):
        self.model_serial.scatter = None
        self.assertIsNone(self.model_serial.scatter)

        with self.assertRaises(ValueError):
            self.model_parallel.scatter = [3]

        for item in (0., False, True):
            with self.assertRaises(ValueError):
                self.model_parallel.scatter = item

        with self.assertRaises(ValueError):
            self.model_parallel.scatter = \
                -1 * np.ones_like(self.model_parallel.dispersion)

        _ = np.array(self.model_parallel.scatter).copy()
        _ += 1.
        self.model_parallel.scatter = _
        self.assertTrue(np.allclose(_, self.model_parallel.scatter))

        self.model_serial.coefficients = None
        self.assertIsNone(self.model_serial.coefficients)

        with self.assertRaises(ValueError):
            self.model_parallel.coefficients = np.arange(12).reshape((3, 2, 2))

        with self.assertRaises(ValueError):
            _ = np.ones_like(self.model_parallel.coefficients)
            self.model_parallel.coefficients = _.T

        with self.assertRaises(ValueError):
            _ = np.ones_like(self.model_parallel.coefficients)
            self.model_parallel.coefficients = _[:, :-1]

        _ = np.array(self.model_parallel.coefficients).copy()
        _ += 0.5
        self.model_parallel.coefficients = _
        self.assertTrue(np.allclose(_, self.model_parallel.coefficients))

    def do_io(self):

        _, temp_filename = mkstemp()
        remove(temp_filename)
        self.model_serial.save(temp_filename, include_training_data=False)
        with self.assertRaises(IOError):
            self.model_serial.save(temp_filename, overwrite=False)

        names = ("_data_attributes", "_trained_attributes",
            "_descriptive_attributes")
        attrs = (
            self.model_serial._data_attributes,
            self.model_serial._trained_attributes,
            self.model_serial._descriptive_attributes
            )
        for name, item in zip(names, attrs):
            _ = [] + list(item)
            _.append("metadata")
            setattr(self.model_serial, name, _)
            with self.assertRaises(ValueError):
                self.model_serial.save(temp_filename, overwrite=True)
            setattr(self.model_serial, name, _[:-1])

        self.model_serial.save(temp_filename, include_training_data=True,
            overwrite=True)

        self.model_parallel.reset()
        self.model_parallel.load(temp_filename, verify_training_data=True)

        # Check that the trained attributes in both models are equal.
        for _attribute in self.model_serial._trained_attributes:

            # And nearly as we expected.
            expected = self.test_data_set[_attribute[1:]]
            if isinstance(expected, dict):
                for key in expected:
                    self.assertEqual(
                        getattr(self.model_serial, _attribute)[key],
                        getattr(self.model_parallel, _attribute)[key]
                    )
                    self.assertEqual(expected[key],
                        getattr(self.model_serial, _attribute)[key])
            else:

                self.assertTrue(np.allclose(
                    getattr(self.model_serial, _attribute),
                    getattr(self.model_parallel, _attribute)
                    ))

                self.assertTrue(np.allclose(
                    expected,
                    getattr(self.model_serial, _attribute)))
                    #rtol=0.5, atol=1e-8))

        # Check that the data attributes in both models are equal.
        for _attribute in self.model_serial._data_attributes:
            self.assertTrue(
                utils.short_hash(getattr(self.model_serial, _attribute)),
                utils.short_hash(getattr(self.model_parallel, _attribute))
            )

        # Alter the hash and expect failure.
        kwds = { "encoding": "latin1" } if sys.version_info[0] >= 3 else {}
        with open(temp_filename, "rb") as fp:
            contents = pickle.load(fp, **kwds)

        contents["training_set_hash"] = ""
        with open(temp_filename, "wb") as fp:
            pickle.dump(contents, fp, -1)

        with self.assertRaises(ValueError):
            self.model_serial.load(temp_filename, verify_training_data=True)

        if path.exists(temp_filename):
            remove(temp_filename)

    def do_cv(self):
        self.model_parallel.cross_validate(N=1, debug=True)

        def choo_choo(old, new):
            pass

        self.model_parallel.cross_validate(N=1, debug=True, pre_train=choo_choo)

    def do_predict(self):
        _ = [self.model_serial.training_labels[label][0] \
            for label in self.model_serial.labels]
        self.assertTrue(np.allclose(
            self.model_serial.predict(_),
            self.model_serial.predict(**dict(zip(self.model_serial.labels, _)))))

    def do_fit(self):
        self.assertIsNotNone(
            self.model_serial.fit(self.model_serial.training_fluxes[0],
                self.model_serial.training_flux_uncertainties[0],
                full_output=True))

    def do_edge_cases(self):
        self.model_serial.reset()

        # This label vector only contains one term in cross-terms (PARAM_M_H).
        self.model_serial.label_vector = \
            "TEFF^3 + TEFF^2 + TEFF + LOGG + PARAM_M_H*LOGG"
        self.assertIn(None, self.model_serial._get_lowest_order_label_indices())

        # Set large uncertainties for one pixel.
        self.model_serial._training_flux_uncertainties[:, 0] = 10.
        self.model_serial._training_fluxes[:, 1] = \
            np.random.uniform(low=-0.5, high=0.5,
                size=self.model_serial._training_fluxes.shape[0])

        # Train and fit using this unusual label vector.
        self.model_serial.train()
        self.model_serial.fit(self.model_serial._training_fluxes[1],
            self.model_serial._training_flux_uncertainties[1])

        # See if we can make things break or warn.
        self.model_serial._training_fluxes[10] = 1000.
        self.model_serial._training_flux_uncertainties[10] = 0.
        self.model_serial.reset()
        self.model_serial.label_vector = "TEFF^5 + LOGG^3 + PARAM_M_H^5"
        for label in self.model_serial.labels:
            self.model_serial._training_labels[label] = 0.
        self.model_serial.train()

        # TODO: Force things to break.
        #with self.assertRaises(np.linalg.linalg.LinAlgError):
        #    self.model_serial.train(debug=True)

        #with self.assertRaises(np.linalg.linalg.LinAlgError):
        #    self.model_serial.cross_validate(N=1, debug=True)

    def runTest(self):

        # Train all models.
        self.do_training()

        self.do_residuals()

        self.ruin_the_trained_coefficients()

        # Train again.
        self.do_training()

        # Predict stuff.
        self.do_predict()

        self.do_fit()

        # Do cross-validation.
        self.do_cv()

        # Try I/O.
        self.do_io()

        # Do edge cases.
        self.do_edge_cases()
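
One structural note: the trained-attribute comparison loop appears verbatim in both do_training() and do_io(). It could be extracted into a shared helper on TestCannonModelRealistically; a minimal sketch is below, where the method name _assert_trained_attributes_equal is hypothetical and the body simply mirrors the loop used above.

    # Sketch only: a hypothetical helper that both do_training() and do_io()
    # could call instead of repeating the comparison loop.
    def _assert_trained_attributes_equal(self):
        for _attribute in self.model_serial._trained_attributes:
            expected = self.test_data_set[_attribute[1:]]
            serial = getattr(self.model_serial, _attribute)
            parallel = getattr(self.model_parallel, _attribute)
            if isinstance(expected, dict):
                for key in expected:
                    self.assertEqual(serial[key], parallel[key])
                    self.assertEqual(expected[key], serial[key])
            else:
                self.assertTrue(np.allclose(serial, parallel))
                self.assertTrue(np.allclose(expected, serial))

Each caller would then replace its copy of the loop with a single call to self._assert_trained_attributes_equal().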