Completed
Push — master ( 4fc32f...027495 )
by Andy
01:22
created

AnniesLasso.tests.TestRequiresLabelVector   A

Complexity

Total Complexity 5

Size/Duplication

Total Lines 13
Duplicated Lines 0 %
Metric Value
wmc 5
dl 0
loc 13
rs 10

2 Methods

Rating   Name   Duplication   Size   Complexity  
A TestRequiresLabelVector.test_with_label_vector() 0 5 2
A TestRequiresLabelVector.test_without_label_vector() 0 6 3
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
Unit tests for the base model class and associated functions.
6
"""
7
8
import numpy as np
9
import unittest
10
from AnniesLasso import model, utils
11
12
13
class NullObject(object):
    """Empty attribute container used as a stand-in model object.

    The decorator tests attach just the attributes they need (for example
    ``is_trained``, ``_descriptive_attributes`` or ``label_vector``) to an
    instance of this class.
    """
15
16
17
class TestRequiresTrainingWheels(unittest.TestCase):
    """Tests for the ``model.requires_training_wheels`` decorator."""

    @staticmethod
    def _stub(is_trained):
        # Stand-in model carrying only the ``is_trained`` flag.
        stand_in = NullObject()
        stand_in.is_trained = is_trained
        return stand_in

    def test_not_trained(self):
        # An untrained object must be rejected with a TypeError.
        untrained = self._stub(False)
        with self.assertRaises(TypeError):
            guarded = model.requires_training_wheels(lambda _: None)
            guarded(untrained)

    def test_is_trained(self):
        # A trained object is expected to reach the wrapped function, whose
        # None result comes back to the caller.
        trained = self._stub(True)
        guarded = model.requires_training_wheels(lambda _: None)
        self.assertIsNone(guarded(trained))
28
29
30
class TestRequiresLabelVector(unittest.TestCase):
    """Tests for the ``model.requires_model_description`` decorator."""

    @staticmethod
    def _stub(label_vector):
        # Stand-in model whose only descriptive attribute is ``label_vector``.
        stand_in = NullObject()
        stand_in._descriptive_attributes = ["label_vector"]
        stand_in.label_vector = label_vector
        return stand_in

    def test_with_label_vector(self):
        # An empty-string label vector is expected to count as "described",
        # so the call succeeds and returns the wrapped function's None.
        described = self._stub("")
        guarded = model.requires_model_description(lambda _: None)
        self.assertIsNone(guarded(described))

    def test_without_label_vector(self):
        # A label vector of None must be rejected with a TypeError.
        undescribed = self._stub(None)
        with self.assertRaises(TypeError):
            guarded = model.requires_model_description(lambda _: None)
            guarded(undescribed)
43
44
45
class TestBaseCannonModel(unittest.TestCase):
    """Tests for ``model.BaseCannonModel`` built from random faux data.

    Every test builds a fresh model from the arrays created in ``setUp``.
    Trained attributes and I/O functions are exercised by the sub-class tests.
    """

    def setUp(self):
        # Initialise some faux data and labels with random sizes.
        labels = "ABCDE"
        N_labels = len(labels)
        N_stars = np.random.randint(10, 500)
        # At least 5 pixels: test_set_dispersion and test_data_verification
        # assign hard-coded dispersion arrays of length 4 and 3 that must
        # always have the wrong length, otherwise those tests become flaky.
        N_pixels = np.random.randint(5, 10000)
        shape = (N_stars, N_pixels)

        self.valid_training_labels = np.rec.array(
            np.random.uniform(low=0.5, high=1.5, size=(N_stars, N_labels)),
            dtype=[(label, '<f8') for label in labels])

        self.valid_fluxes = np.random.uniform(low=0.5, high=1.5, size=shape)
        self.valid_flux_uncertainties = np.random.uniform(
            low=0.5, high=1.5, size=shape)

    def runTest(self):
        # Intentional no-op: lets this TestCase be instantiated with
        # unittest's default method name ("runTest").
        pass

    def get_model(self, **kwargs):
        """Return a BaseCannonModel built from the valid faux data.

        Any keyword arguments are forwarded to the model constructor.
        """
        return model.BaseCannonModel(
            self.valid_training_labels, self.valid_fluxes,
            self.valid_flux_uncertainties, **kwargs)

    def test_repr(self):
        # Also call the no-op runTest so that it is covered.
        self.runTest()
        m = self.get_model()
        self.assertIsInstance(str(m), str)
        self.assertIsInstance(repr(m), str)

    def test_get_dispersion(self):
        # Without an explicit dispersion, the model uses pixel indices.
        m = self.get_model()
        self.assertSequenceEqual(
            tuple(m.dispersion),
            tuple(np.arange(self.valid_fluxes.shape[1])))

    def test_set_dispersion(self):
        m = self.get_model()
        N_pixels = self.valid_fluxes.shape[1]

        for item in (None, False, True):
            # Incorrect data type (not an iterable).
            with self.assertRaises(TypeError):
                m.dispersion = item

        for item in ("", {}, [], (), set()):
            # These are iterable but have the wrong lengths.
            with self.assertRaises(ValueError):
                m.dispersion = item

        # Too short for the data (setUp guarantees at least 5 pixels).
        with self.assertRaises(ValueError):
            m.dispersion = [3, 4, 2, 1]

        # These should work.
        for offset in (10, -100, 520938.4):
            m.dispersion = offset + np.arange(N_pixels)

        # Disallow non-finite numbers. Each array is built outside the
        # assertRaises block so that only the setter can raise ValueError.
        for bad_value in (np.nan, np.inf, -np.inf):
            d = np.arange(N_pixels, dtype=float)
            d[0] = bad_value
            with self.assertRaises(ValueError):
                m.dispersion = d

        # Disallow non-float like things.
        for d in (np.array([""] * N_pixels), np.array([None] * N_pixels)):
            with self.assertRaises(ValueError):
                m.dispersion = d

    def test_get_training_data(self):
        m = self.get_model()
        self.assertIsNotNone(m.training_labels)
        self.assertIsNotNone(m.training_fluxes)
        self.assertIsNotNone(m.training_flux_uncertainties)

    def test_invalid_label_names(self):
        m = self.get_model()
        valid_labels = list(m.labels_available)
        N_stars = len(self.valid_training_labels)
        N_labels = len(valid_labels)

        for character in m._forbidden_label_characters:
            # Replace the first label name with one containing the
            # forbidden character.
            invalid_labels = list(valid_labels)
            invalid_labels[0] = "HELLO_{}".format(character)

            invalid_training_labels = np.rec.array(
                np.random.uniform(size=(N_stars, N_labels)),
                dtype=[(l, '<f8') for l in invalid_labels])

            # live_dangerously=True lets the model be built with these labels.
            dangerous = model.BaseCannonModel(invalid_training_labels,
                self.valid_fluxes, self.valid_flux_uncertainties,
                live_dangerously=True)

            # With the forbidden-character list cleared, verification is
            # expected to pass.
            dangerous._forbidden_label_characters = None
            self.assertTrue(dangerous._verify_labels_available())

            # Without live_dangerously the same labels must be rejected.
            with self.assertRaises(ValueError):
                model.BaseCannonModel(invalid_training_labels,
                    self.valid_fluxes, self.valid_flux_uncertainties)

    def test_get_label_vector(self):
        m = self.get_model()
        m.label_vector = "A + B + C"
        self.assertEqual(m.pixel_label_vector(1), [
            [("A", 1)],
            [("B", 1)],
            [("C", 1)]
        ])

    def test_set_label_vector(self):
        m = self.get_model()
        label_vector = "A + B + C + D + E"

        m.label_vector = label_vector
        self.assertEqual(m.label_vector, utils.parse_label_vector(label_vector))
        self.assertEqual("1 + A + B + C + D + E", m.human_readable_label_vector)

        # "G" is not one of the available labels.
        with self.assertRaises(ValueError):
            m.label_vector = "A + G"

        m.label_vector = None
        self.assertIsNone(m.label_vector)

        for item in (True, False, 0, 1.0):
            with self.assertRaises(TypeError):
                m.label_vector = item

    def test_label_getsetters(self):
        m = self.get_model()
        self.assertEqual((), m.labels)

        m.label_vector = "A + B + C"
        self.assertSequenceEqual(("A", "B", "C"), tuple(m.labels))

        # labels is read-only.
        with self.assertRaises(AttributeError):
            m.labels = None

    def test_inheritence(self):
        # The base class leaves train/predict/fit to sub-classes.
        m = self.get_model()
        m.label_vector = "A + B + C"
        with self.assertRaises(NotImplementedError):
            m.train()
        m._trained = True
        with self.assertRaises(NotImplementedError):
            m.predict()
        with self.assertRaises(NotImplementedError):
            m.fit()

    def test_get_label_indices(self):
        m = self.get_model()
        m.label_vector = "A^5 + A^2 + B^3 + C + C*D + D^6"
        self.assertEqual([1, 2, 3, 5], m._get_lowest_order_label_indices())

    def test_data_verification(self):
        m = self.get_model()
        # Flatten the fluxes to a single row so that they match neither the
        # uncertainties nor the number of stars.
        m._training_fluxes = m._training_fluxes.reshape(1, -1)
        with self.assertRaises(ValueError):
            m._verify_training_data()

        # Fluxes and uncertainties now agree, but there is still one row for
        # at least 10 stars.
        m._training_flux_uncertainties = \
            m._training_flux_uncertainties.reshape(m._training_fluxes.shape)
        with self.assertRaises(ValueError):
            m._verify_training_data()

        # Dispersion of the wrong length (setUp guarantees at least 5 pixels).
        with self.assertRaises(ValueError):
            self.get_model(dispersion=[1, 2, 3])

        # A plain 2-D float array (no named label fields) as training labels.
        N_labels = 2
        N_stars = self.valid_fluxes.shape[0]
        invalid_training_labels = np.random.uniform(
            low=0.5, high=1.5, size=(N_stars, N_labels))
        with self.assertRaises(ValueError):
            model.BaseCannonModel(invalid_training_labels,
                self.valid_fluxes, self.valid_flux_uncertainties)

    def test_labels_array(self):
        m = self.get_model()
        m.label_vector = "A^2 + B^3 + C^5"

        # Column i of labels_array holds the values of the i-th label.
        for i, label in enumerate("ABC"):
            foo = m.labels_array[:, i]
            bar = np.array(m.training_labels[label]).flatten()
            self.assertTrue(np.allclose(foo, bar))

        # labels_array is read-only.
        with self.assertRaises(AttributeError):
            m.labels_array = None

    def test_label_vector_array(self):
        m = self.get_model()
        m.label_vector = "A^2.0 + B^3.4 + C^5"
        m.pivots = np.zeros(len(m.labels))

        # With zero pivots, rows 1-3 are the raw label powers. Row 0 is
        # presumably the constant term (cf. "1 + ..." in the readable form).
        for row, label, power in ((1, "A", 2), (2, "B", 3.4), (3, "C", 5)):
            expected = np.array(m.training_labels[label]**power).flatten()
            self.assertTrue(np.allclose(expected, m.label_vector_array[row]))

        # Building the array with a non-finite label must not raise
        # (evaluated for coverage only).
        m.training_labels["A"][0] = np.nan
        m.label_vector_array

        # Scalar and single-element-list label values should build the same
        # rows. Index [1][0] extracts a true scalar: float() on a size-1
        # array is deprecated in recent NumPy.
        kwd1 = {
            label: float(m.training_labels[label][1][0]) for label in "ABC"
        }
        kwd2 = {
            label: [m.training_labels[label][1][0]] for label in "ABC"
        }
        self.assertTrue(np.allclose(
            model._build_label_vector_rows(m.label_vector, kwd1),
            model._build_label_vector_rows(m.label_vector, kwd2)
        ))

    def test_format_input_labels(self):
        m = self.get_model()
        m.label_vector = "A^2.0 + B^3.4 + C^5"

        # Labels supplied as keyword lists...
        kwds = {"A": [5], "B": [3], "C": [0.43]}
        for k, v in m._format_input_labels(None, **kwds).items():
            self.assertEqual(kwds[k], v)
        # ...positionally, in label order...
        for k, v in m._format_input_labels([5, 3, 0.43]).items():
            self.assertEqual(kwds[k], v)

        # ...or as keyword scalars are all expected to come back as lists.
        kwds_input = {k: v[0] for k, v in kwds.items()}
        for k, v in m._format_input_labels(None, **kwds_input).items():
            self.assertEqual(kwds[k], v)

    # The trained attributes and I/O functions will be tested in the sub-classes
299