Completed: Push — master (83b528...b1912a), by Andy, 58s

do_io() (rated F)

Metric                        Value
Complexity (conditions, cc)   11
Size (total lines, loc)       64
Duplication (lines, dl)       0
Duplication ratio             0 %
rs                            3.913

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And if a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment as a starting point for the new method's name.

The most commonly applied refactoring here is Extract Method, sketched below.
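
A minimal sketch of Extract Method, with invented names (nothing here is part of the AnniesLasso API):

# Before: a commented step buried inside a long method.
def save_model(model, filename):
    # Check that no attribute name collides with the reserved "metadata" key.
    for name in model.attribute_names:  # hypothetical attribute
        if name == "metadata":
            raise ValueError("'metadata' is a reserved attribute name")
    # ... the rest of the long method ...

# After: the comment becomes the name of the extracted method.
def _check_reserved_attribute_names(model):
    """Raise if any attribute name collides with the reserved 'metadata' key."""
    for name in model.attribute_names:
        if name == "metadata":
            raise ValueError("'metadata' is a reserved attribute name")

def save_model(model, filename):
    _check_reserved_attribute_names(model)
    # ... the rest of the long method ...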

Complexity

Complex code like AnniesLasso.tests.TestCannonModelRealistically.do_io() often does a lot of different things. To break such a class or method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring, sketched below. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
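
A minimal sketch of Extract Class under the same caveat (all names invented, loosely echoing the "training_" prefix in the listing below):

# Before: one class owns every "training_" field directly.
class Model:
    def __init__(self, training_labels, training_fluxes,
        training_flux_uncertainties):
        self.training_labels = training_labels
        self.training_fluxes = training_fluxes
        self.training_flux_uncertainties = training_flux_uncertainties

# After Extract Class: the prefixed fields form a cohesive component.
class TrainingSet:
    def __init__(self, labels, fluxes, flux_uncertainties):
        self.labels = labels
        self.fluxes = fluxes
        self.flux_uncertainties = flux_uncertainties

class Model:
    def __init__(self, training_set):
        self.training_set = training_set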

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Unit tests for the Cannon model class and associated functions.
"""

import numpy as np
import sys
import unittest
from six.moves import cPickle as pickle
from os import path, remove
from tempfile import mkstemp

from AnniesLasso import cannon, utils


class TestCannonModel(unittest.TestCase):

    def setUp(self):
        # Initialise some faux data and labels.
        labels = "ABCDE"
        N_labels = len(labels)
        N_stars = np.random.randint(1, 500)
        N_pixels = np.random.randint(1, 10000)
        shape = (N_stars, N_pixels)

        self.valid_training_labels = np.rec.array(
            np.random.uniform(size=(N_stars, N_labels)),
            dtype=[(label, '<f8') for label in labels])

        self.valid_fluxes = np.random.uniform(size=shape)
        self.valid_flux_uncertainties = np.random.uniform(size=shape)

    def get_model(self):
        return cannon.CannonModel(
            self.valid_training_labels, self.valid_fluxes,
            self.valid_flux_uncertainties)

    def test_init(self):
        self.assertIsNotNone(self.get_model())


# The test_data_set.pkl contains:
# (training_labels, training_fluxes, training_flux_uncertainties, coefficients,
#  scatter, pivots, label_vector)
# The training labels are not named, but they are: (TEFF, LOGG, PARAM_M_H)

class TestCannonModelRealistically(unittest.TestCase):

    def setUp(self):
        # Set up a model using the test data set.
        here = path.dirname(path.realpath(__file__))
        kwds = {"encoding": "latin1"} \
            if sys.version_info[0] >= 3 else {}
        with open(path.join(here, "test_data_set.pkl"), "rb") as fp:
            contents = pickle.load(fp, **kwds)

        # Unpack it all.
        training_labels, training_fluxes, training_flux_uncertainties, \
            coefficients, scatter, pivots, label_vector = contents

        training_labels = np.core.records.fromarrays(training_labels,
            names="TEFF,LOGG,PARAM_M_H", formats="f8,f8,f8")

        self.test_data_set = {
            "training_labels": training_labels,
            "training_fluxes": training_fluxes,
            "training_flux_uncertainties": training_flux_uncertainties,
            "coefficients": coefficients,
            "scatter": scatter,
            "pivots": pivots,
            "label_vector": label_vector
        }
        self.model_serial = cannon.CannonModel(training_labels, training_fluxes,
            training_flux_uncertainties)
        self.model_parallel = cannon.CannonModel(training_labels,
            training_fluxes, training_flux_uncertainties, threads=2)

        self.models = (self.model_serial, self.model_parallel)

    def do_training(self):
        for model in self.models:
            model.reset()
            model.label_vector = self.test_data_set["label_vector"]
            self.assertIsNotNone(model.train())

        # Check that the trained attributes in both models are equal,
        # and nearly match the expected values in the test data set.
        for _attribute in self.model_serial._trained_attributes:
            self.assertTrue(np.allclose(
                getattr(self.model_serial, _attribute),
                getattr(self.model_parallel, _attribute)))

            self.assertTrue(np.allclose(
                self.test_data_set[_attribute[1:]],
                getattr(self.model_serial, _attribute)))
                #rtol=0.5, atol=1e-8))

    def do_residuals(self):
        serial = self.model_serial.get_training_label_residuals()
        parallel = self.model_parallel.get_training_label_residuals()
        self.assertTrue(np.allclose(serial, parallel))

    def ruin_the_trained_coefficients(self):
        # Scatter can be cleared, but invalid values must be rejected.
        self.model_serial.scatter = None
        self.assertIsNone(self.model_serial.scatter)

        with self.assertRaises(ValueError):
            self.model_parallel.scatter = [3]

        for item in (0., False, True):
            with self.assertRaises(ValueError):
                self.model_parallel.scatter = item

        with self.assertRaises(ValueError):
            self.model_parallel.scatter = \
                -1 * np.ones_like(self.model_parallel.dispersion)

        # A valid scatter array of the right shape should be accepted.
        _ = np.array(self.model_parallel.scatter).copy()
        _ += 1.
        self.model_parallel.scatter = _
        self.assertTrue(np.allclose(_, self.model_parallel.scatter))

        # Coefficients can be cleared, but invalid shapes must be rejected.
        self.model_serial.coefficients = None
        self.assertIsNone(self.model_serial.coefficients)

        with self.assertRaises(ValueError):
            self.model_parallel.coefficients = np.arange(12).reshape((3, 2, 2))

        with self.assertRaises(ValueError):
            _ = np.ones_like(self.model_parallel.coefficients)
            self.model_parallel.coefficients = _.T

        with self.assertRaises(ValueError):
            _ = np.ones_like(self.model_parallel.coefficients)
            self.model_parallel.coefficients = _[:, :-1]

        # A valid coefficients array of the right shape should be accepted.
        _ = np.array(self.model_parallel.coefficients).copy()
        _ += 0.5
        self.model_parallel.coefficients = _
        self.assertTrue(np.allclose(_, self.model_parallel.coefficients))

    def do_io(self):
        _, temp_filename = mkstemp()
        remove(temp_filename)
        self.model_serial.save(temp_filename, include_training_data=False)
        with self.assertRaises(IOError):
            self.model_serial.save(temp_filename, overwrite=False)

        # Appending the reserved "metadata" name to any attribute list
        # should make save() fail.
        names = ("_data_attributes", "_trained_attributes",
            "_descriptive_attributes")
        attrs = (
            self.model_serial._data_attributes,
            self.model_serial._trained_attributes,
            self.model_serial._descriptive_attributes
            )
        for name, item in zip(names, attrs):
            _ = [] + list(item)
            _.append("metadata")
            setattr(self.model_serial, name, _)
            with self.assertRaises(ValueError):
                self.model_serial.save(temp_filename, overwrite=True)
            setattr(self.model_serial, name, _[:-1])

        self.model_serial.save(temp_filename, include_training_data=True,
            overwrite=True)

        self.model_parallel.reset()
        self.model_parallel.load(temp_filename, verify_training_data=True)

        # Check that the trained attributes in both models are equal,
        # and nearly match the expected values in the test data set.
        for _attribute in self.model_serial._trained_attributes:
            self.assertTrue(np.allclose(
                getattr(self.model_serial, _attribute),
                getattr(self.model_parallel, _attribute)))

            self.assertTrue(np.allclose(
                self.test_data_set[_attribute[1:]],
                getattr(self.model_serial, _attribute)))
                #rtol=0.5, atol=1e-8))

        # Check that the data attributes in both models are equal.
        for _attribute in self.model_serial._data_attributes:
            self.assertEqual(
                utils.short_hash(getattr(self.model_serial, _attribute)),
                utils.short_hash(getattr(self.model_parallel, _attribute)))

        # Alter the stored hash and expect the load to fail.
        kwds = {"encoding": "latin1"} if sys.version_info[0] >= 3 else {}
        with open(temp_filename, "rb") as fp:
            contents = pickle.load(fp, **kwds)

        contents["training_set_hash"] = ""
        with open(temp_filename, "wb") as fp:
            pickle.dump(contents, fp, -1)

        with self.assertRaises(ValueError):
            self.model_serial.load(temp_filename, verify_training_data=True)

        if path.exists(temp_filename):
            remove(temp_filename)

    def do_cv(self):
        self.model_parallel.cross_validate(N=1, debug=True)

        # A no-op pre-train hook should be accepted.
        def choo_choo(old, new):
            pass

        self.model_parallel.cross_validate(N=1, debug=True, pre_train=choo_choo)

    def do_predict(self):
        _ = [self.model_serial.training_labels[label][0]
            for label in self.model_serial.labels]
        self.assertTrue(np.allclose(
            self.model_serial.predict(_),
            self.model_serial.predict(**dict(zip(self.model_serial.labels, _)))))

    def do_fit(self):
        self.assertIsNotNone(
            self.model_serial.fit(self.model_serial.training_fluxes[0],
                self.model_serial.training_flux_uncertainties[0],
                full_output=True))

    def do_edge_cases(self):
        self.model_serial.reset()

        # In this label vector, PARAM_M_H appears only in a cross-term.
        self.model_serial.label_vector = \
            "TEFF^3 + TEFF^2 + TEFF + LOGG + PARAM_M_H*LOGG"
        self.assertIn(None, self.model_serial._get_lowest_order_label_indices())

        # Set large uncertainties for one pixel, and random fluxes for another.
        self.model_serial._training_flux_uncertainties[:, 0] = 10.
        self.model_serial._training_fluxes[:, 1] = \
            np.random.uniform(low=-0.5, high=0.5,
                size=self.model_serial._training_fluxes.shape[0])

        # Train and fit using this unusual label vector.
        self.model_serial.train()
        self.model_serial.fit(self.model_serial._training_fluxes[1],
            self.model_serial._training_flux_uncertainties[1])

        # See if we can make things break or warn.
        self.model_serial._training_fluxes[10] = 1000.
        self.model_serial._training_flux_uncertainties[10] = 0.
        self.model_serial.reset()
        self.model_serial.label_vector = "TEFF^5 + LOGG^3 + PARAM_M_H^5"
        for label in self.model_serial.labels:
            self.model_serial._training_labels[label] = 0.
        self.model_serial.train()

        # TODO: Force things to break.
        #with self.assertRaises(np.linalg.linalg.LinAlgError):
        #    self.model_serial.train(debug=True)

        #with self.assertRaises(np.linalg.linalg.LinAlgError):
        #    self.model_serial.cross_validate(N=1, debug=True)

    def runTest(self):

        # Train all models.
        self.do_training()

        self.do_residuals()

        self.ruin_the_trained_coefficients()

        # Train again.
        self.do_training()

        # Predict and fit.
        self.do_predict()

        self.do_fit()

        # Do cross-validation.
        self.do_cv()

        # Try I/O.
        self.do_io()

        # Edge cases.
        self.do_edge_cases()
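
Applied to the flagged method itself, one possible direction is to extract each commented block of do_io() into a helper on TestCannonModelRealistically. This is a hedged sketch, not the project's actual code: the helper names are invented, and only two of the blocks are shown.

    def _check_save_guards(self, temp_filename):
        # Saving over an existing file without overwrite=True must fail.
        self.model_serial.save(temp_filename, include_training_data=False)
        with self.assertRaises(IOError):
            self.model_serial.save(temp_filename, overwrite=False)

    def _check_tampered_hash(self, temp_filename):
        # Loading a file whose training set hash was altered must fail.
        kwds = {"encoding": "latin1"} if sys.version_info[0] >= 3 else {}
        with open(temp_filename, "rb") as fp:
            contents = pickle.load(fp, **kwds)
        contents["training_set_hash"] = ""
        with open(temp_filename, "wb") as fp:
            pickle.dump(contents, fp, -1)
        with self.assertRaises(ValueError):
            self.model_serial.load(temp_filename, verify_training_data=True)

    def do_io(self):
        _, temp_filename = mkstemp()
        remove(temp_filename)
        self._check_save_guards(temp_filename)
        # ... round-trip and attribute checks, extracted the same way ...
        self._check_tampered_hash(temp_filename)
        if path.exists(temp_filename):
            remove(temp_filename)

Each extracted helper carries only one or two branches, so the cyclomatic complexity of do_io() itself drops well below the reported 11, and each helper's name documents what its block verifies.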