Completed
Push — master ( eb70ed...446da2 )
by Andy
59s
created

AnniesLasso.tests.TestBaseCannonModel.get_model()   A

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %
Metric  Value
dl      0
loc     4
rs      10
cc      1
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
Unit tests for the base model class and associated functions.
6
"""
7
8
import numpy as np
9
import unittest
10
from AnniesLasso import model, utils
11
12
13
class NullObject(object):
    """Empty attribute container used as a stand-in object in the tests below."""
15
16
17
class TestRequiresTrainingWheels(unittest.TestCase):
18
    def test_not_trained(self):
19
        o = NullObject()
20
        o.is_trained = False
21
        with self.assertRaises(TypeError):
22
            model.requires_training_wheels(lambda x: None)(o)
23
24
    def test_is_trained(self):
25
        o = NullObject()
26
        o.is_trained = True
27
        self.assertIsNone(model.requires_training_wheels(lambda x: None)(o))
28
29
30
class TestRequiresLabelVector(unittest.TestCase):
31
    def test_with_label_vector(self):
32
        o = NullObject()
33
        o.label_vector = ""
34
        self.assertIsNone(model.requires_label_vector(lambda x: None)(o))
35
36
    def test_without_label_vector(self):
37
        o = NullObject()
38
        o.label_vector = None
39
        with self.assertRaises(TypeError):
40
            model.requires_label_vector(lambda x: None)(o)
41
42
43
class TestBaseCannonModel(unittest.TestCase):
44
45
    def setUp(self):
46
        # Initialise some faux data and labels.
47
        labels = "ABCDE"
48
        N_labels = len(labels)
49
        N_stars = np.random.randint(10, 500)
50
        N_pixels = np.random.randint(1, 10000)
51
        shape = (N_stars, N_pixels)
52
53
        self.valid_training_labels = np.rec.array(
54
            np.random.uniform(low=0.5, high=1.5, size=(N_stars, N_labels)),
55
            dtype=[(label, '<f8') for label in labels])
56
57
        self.valid_fluxes = np.random.uniform(low=0.5, high=1.5, size=shape)
58
        self.valid_flux_uncertainties = np.random.uniform(low=0.5, high=1.5,
59
            size=shape)
60
61
    def runTest(self):
62
        None
63
    
64
    def get_model(self, **kwargs):
65
        return model.BaseCannonModel(
66
            self.valid_training_labels, self.valid_fluxes,
67
            self.valid_flux_uncertainties, **kwargs)
68
69
    def test_repr(self):
70
        self.runTest() # Just for that 100%, baby.
71
        m = self.get_model()
72
        print("{0} {1}".format(m.__str__(), m.__repr__()))
73
74
    def test_get_dispersion(self):
75
        m = self.get_model()
76
        self.assertSequenceEqual(
77
            tuple(m.dispersion), 
78
            tuple(np.arange(self.valid_fluxes.shape[1])))
79
80
    def test_set_dispersion(self):
81
        m = self.get_model()
82
        for item in (None, False, True):
83
            # Incorrect data type (not an iterable)
84
            with self.assertRaises(TypeError):
85
                m.dispersion = item
86
87
        for item in ("", {}, [], (), set()):
88
            # These are iterable but have the wrong lengths.
89
            with self.assertRaises(ValueError):
90
                m.dispersion = item
91
92
        with self.assertRaises(ValueError):
93
            m.dispersion = [3,4,2,1]
94
95
        # These should work.
96
        m.dispersion = 10 + np.arange(self.valid_fluxes.shape[1])
97
        m.dispersion = -100 + np.arange(self.valid_fluxes.shape[1])
98
        m.dispersion = 520938.4 + np.arange(self.valid_fluxes.shape[1])
99
100
        # Disallow non-finite numbers.
101
        with self.assertRaises(ValueError):
102
            d = np.arange(self.valid_fluxes.shape[1], dtype=float)
103
            d[0] = np.nan
104
            m.dispersion = d
105
106
        with self.assertRaises(ValueError):
107
            d = np.arange(self.valid_fluxes.shape[1], dtype=float)
108
            d[0] = np.inf
109
            m.dispersion = d
110
111
        with self.assertRaises(ValueError):
112
            d = np.arange(self.valid_fluxes.shape[1], dtype=float)
113
            d[0] = -np.inf
114
            m.dispersion = d
115
116
        # Disallow non-float like things.
117
        with self.assertRaises(ValueError):
118
            d = np.array([""] * self.valid_fluxes.shape[1])
119
            m.dispersion = d
120
121
        with self.assertRaises(ValueError):
122
            d = np.array([None] * self.valid_fluxes.shape[1])
123
            m.dispersion = d
124
        
125
    def test_get_training_data(self):
126
        m = self.get_model()
127
        self.assertIsNotNone(m.training_labels)
128
        self.assertIsNotNone(m.training_fluxes)
129
        self.assertIsNotNone(m.training_flux_uncertainties)
130
131
    def test_invalid_label_names(self):
132
        m = self.get_model()
133
        for character in m._forbidden_label_characters:
134
135
            invalid_labels = [] + list(m.labels_available)
136
            invalid_labels[0] = "HELLO_{}".format(character)
137
138
            N_stars = len(self.valid_training_labels)
139
            N_labels = len(invalid_labels)
140
            invalid_training_labels = np.rec.array(
141
                np.random.uniform(size=(N_stars, N_labels)),
142
                dtype=[(l, '<f8') for l in invalid_labels])
143
144
            m = model.BaseCannonModel(invalid_training_labels,
145
                self.valid_fluxes, self.valid_flux_uncertainties,
146
                live_dangerously=True)
147
148
            m._forbidden_label_characters = None
149
            self.assertTrue(m._verify_labels_available())
150
151
            with self.assertRaises(ValueError):
152
                m = model.BaseCannonModel(invalid_training_labels,
153
                    self.valid_fluxes, self.valid_flux_uncertainties)
154
155
    def test_get_label_vector(self):
156
        m = self.get_model()
157
        m.label_vector = "A + B + C"
158
        self.assertEqual(m.pixel_label_vector(1), [
159
            [("A", 1)],
160
            [("B", 1)],
161
            [("C", 1)]
162
        ])
163
164
    def test_set_label_vector(self):
165
        m = self.get_model()
166
        label_vector = "A + B + C + D + E"
167
168
        m.label_vector = label_vector
169
        self.assertEqual(m.label_vector, utils.parse_label_vector(label_vector))
170
        self.assertEqual("1 + A + B + C + D + E", m.human_readable_label_vector)
171
172
        with self.assertRaises(ValueError):
173
            m.label_vector = "A + G"
174
175
        m.label_vector = None
176
        self.assertIsNone(m.label_vector)
177
178
        for item in (True, False, 0, 1.0):
179
            with self.assertRaises(TypeError):
180
                m.label_vector = item
181
182
    def test_label_getsetters(self):
183
184
        m = self.get_model()
185
        self.assertEqual((), m.labels)
186
187
        m.label_vector = "A + B + C"
188
        self.assertSequenceEqual(("A", "B", "C"), tuple(m.labels))
189
190
        with self.assertRaises(AttributeError):
191
            m.labels = None
192
193
    def test_inheritence(self):
194
        m = self.get_model()
195
        m.label_vector = "A + B + C"
196
        with self.assertRaises(NotImplementedError):
197
            m.train()
198
        m._trained = True
199
        with self.assertRaises(NotImplementedError):
200
            m.predict()
201
        with self.assertRaises(NotImplementedError):
202
            m.fit()
203
        
204
    def test_get_label_indices(self):
205
        m = self.get_model()
206
        m.label_vector = "A^5 + A^2 + B^3 + C + C*D + D^6"
207
        self.assertEqual([1, 2, 3, 5], m._get_lowest_order_label_indices())
208
209
    def test_data_verification(self):
210
        m = self.get_model()
211
        m._training_fluxes = m._training_fluxes.reshape(1, -1)
212
        with self.assertRaises(ValueError):
213
            m._verify_training_data()
214
215
        m._training_flux_uncertainties = \
216
            m._training_flux_uncertainties.reshape(m._training_fluxes.shape)
217
        with self.assertRaises(ValueError):
218
            m._verify_training_data()
219
220
        with self.assertRaises(ValueError):
221
            m = self.get_model(dispersion=[1,2,3])
222
223
        N_labels = 2
224
        N_stars, N_pixels = self.valid_fluxes.shape
225
        invalid_training_labels = np.random.uniform(
226
            low=0.5, high=1.5, size=(N_stars, N_labels))
227
        with self.assertRaises(ValueError):
228
            m = model.BaseCannonModel(invalid_training_labels,
229
                self.valid_fluxes, self.valid_flux_uncertainties)
230
231
    def test_labels_array(self):
232
        m = self.get_model()
233
        m.label_vector = "A^2 + B^3 + C^5"
234
235
        for i, label in enumerate("ABC"):
236
            foo = m.labels_array[:, i]
237
            bar = np.array(m.training_labels[label]).flatten()
238
            self.assertTrue(np.allclose(foo, bar))
239
240
        with self.assertRaises(AttributeError):
241
            m.labels_array = None
242
243
    def test_label_vector_array(self):
244
        m = self.get_model()
245
        m.label_vector = "A^2.0 + B^3.4 + C^5"
246
247
        self.assertTrue(np.allclose(
248
            np.array(m.training_labels["A"]**2).flatten(),
249
            m.label_vector_array[1],
250
        ))
251
        self.assertTrue(np.allclose(
252
            np.array(m.training_labels["B"]**3.4).flatten(),
253
            m.label_vector_array[2]
254
        ))
255
        self.assertTrue(np.allclose(
256
            np.array(m.training_labels["C"]**5).flatten(),
257
            m.label_vector_array[3]
258
        ))
259
260
        m.training_labels["A"][0] = np.nan
261
        m.label_vector_array # For Coveralls.
262
263
        kwd1 = {
264
            "A": float(m.training_labels["A"][1]),
265
            "B": float(m.training_labels["B"][1]),
266
            "C": float(m.training_labels["C"][1])
267
        }
268
        kwd2 = {
269
            "A": [m.training_labels["A"][1][0]],
270
            "B": [m.training_labels["B"][1][0]],
271
            "C": [m.training_labels["C"][1][0]]
272
        }
273
        self.assertTrue(np.allclose(
274
            model._build_label_vector_rows(m.label_vector, kwd1),
275
            model._build_label_vector_rows(m.label_vector, kwd2)
276
        ))
277
278
    def test_format_input_labels(self):
279
        m = self.get_model()
280
        m.label_vector = "A^2.0 + B^3.4 + C^5"
281
282
        kwds = {"A": [5], "B": [3], "C": [0.43]}
283
        for k, v in m._format_input_labels(None, **kwds).items():
284
            self.assertEqual(kwds[k], v)
285
        for k, v in m._format_input_labels([5, 3, 0.43]).items():
286
            self.assertEqual(kwds[k], v)
287
288
        kwds_input = {k: v[0] for k, v in kwds.items() }
289
        for k, v in m._format_input_labels(None, **kwds_input).items():
290
            self.assertEqual(kwds[k], v)
291
292
293
294
295
    # The trained attributes and I/O functions will be tested in the sub-classes
296