#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
The Cannon.
"""

from __future__ import (division, print_function, absolute_import,
                        unicode_literals)

__all__ = ["CannonModel"]

import logging
import multiprocessing as mp
import numpy as np
import os
import pickle
from datetime import datetime
from functools import wraps
from sys import version_info
from scipy.spatial import Delaunay

from .vectorizer.base import BaseVectorizer
from . import (censoring, fitting, utils, vectorizer as vectorizer_module,
               __version__)


logger = logging.getLogger(__name__)


def requires_training(method):
    """
    A decorator for model methods that require training before being run.

    :param method:
        A method belonging to CannonModel.
    """
    @wraps(method)
    def wrapper(model, *args, **kwargs):
        if not model.is_trained:
            raise TypeError("the model requires training first")
        return method(model, *args, **kwargs)
    return wrapper
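
# Usage sketch (illustrative only; `predict` is a hypothetical method name):
#
#     class ExampleModel(CannonModel):
#         @requires_training
#         def predict(self, labels):
#             return self(labels)
#
# Calling `predict` on an untrained model raises TypeError before the method
# body runs; see `__call__` and `test` below for real uses.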


class CannonModel(object):
    """
    A model for The Cannon which includes L1 regularization and pixel censoring.

    :param training_set_labels:
        A set of objects with labels known to high fidelity. This can be
        given as a numpy structured array, or an astropy table.

    :param training_set_flux:
        An array of normalised fluxes for stars in the labelled set, given
        as shape `(num_stars, num_pixels)`. The `num_stars` should match the
        number of rows in `training_set_labels`.

    :param training_set_ivar:
        An array of inverse variances on the normalised fluxes for stars in
        the training set. The shape of the `training_set_ivar` array should
        match that of `training_set_flux`.

    :param vectorizer:
        A vectorizer to take input labels and produce a design matrix. This
        should be a sub-class of `vectorizer.BaseVectorizer`.

    :param dispersion: [optional]
        The dispersion values corresponding to the given pixels. If provided,
        this should have a size of `num_pixels`.

    :param regularization: [optional]
        The strength of the L1 regularization. This should either be `None`,
        a float-type value for a single regularization strength for all pixels,
        or a float-like array of length `num_pixels`.

    :param censors: [optional]
        A dictionary containing label names as keys and boolean censoring
        masks as values.
    """

    _data_attributes = \
        ("training_set_labels", "training_set_flux", "training_set_ivar")

    # Descriptive attributes are needed to train *and* test the model.
    _descriptive_attributes = \
        ("vectorizer", "censors", "regularization", "dispersion")

    # Trained attributes are set only at training time.
    _trained_attributes = ("theta", "s2")

    def __init__(self, training_set_labels, training_set_flux,
                 training_set_ivar, vectorizer, dispersion=None,
                 regularization=None, censors=None, **kwargs):

        # Save the vectorizer.
        if not isinstance(vectorizer, BaseVectorizer):
            raise TypeError(
                "vectorizer must be a sub-class of vectorizer.BaseVectorizer")

        self._vectorizer = vectorizer

        if training_set_flux is None and training_set_ivar is None:

            # Must be reading in a model that does not have the training set
            # spectra saved.
            self._training_set_flux = None
            self._training_set_ivar = None
            self._training_set_labels = training_set_labels

        else:
            self._training_set_flux = np.atleast_2d(training_set_flux)
            self._training_set_ivar = np.atleast_2d(training_set_ivar)

            if isinstance(training_set_labels, np.ndarray) \
            and training_set_labels.ndim == 2 \
            and training_set_labels.shape[0] == self._training_set_flux.shape[0] \
            and training_set_labels.shape[1] == len(vectorizer.label_names):
                # A valid array was given as the training set labels, not a
                # table. (The `ndim` check keeps structured arrays, which are
                # one-dimensional, out of this branch.)
                self._training_set_labels = training_set_labels
            else:
                self._training_set_labels = np.array(
                    [training_set_labels[ln] for ln in vectorizer.label_names]).T

            # Check that the flux and ivar are valid.
            self._verify_training_data(**kwargs)

        # Set regularization, censoring, dispersion.
        self.regularization = regularization
        self.censors = censors
        self.dispersion = dispersion

        # Set useful private attributes.
        __scale_labels_function = kwargs.get("__scale_labels_function",
            lambda l: np.ptp(np.percentile(l, [2.5, 97.5], axis=0), axis=0))
        __fiducial_labels_function = kwargs.get("__fiducial_labels_function",
            lambda l: np.percentile(l, 50, axis=0))

        self._scales = __scale_labels_function(self.training_set_labels)
        self._fiducials = __fiducial_labels_function(self.training_set_labels)
        self._design_matrix = vectorizer(
            (self.training_set_labels - self._fiducials)/self._scales).T

        self.reset()

        return None


    # Representations.


    def __str__(self):
        return "<{module}.{name} of {K} labels {trained}with a training set "\
               "of {N} stars each with {M} pixels>".format(
                   module=self.__module__,
                   name=type(self).__name__,
                   trained="trained " if self.is_trained else "",
                   K=self.training_set_labels.shape[1],
                   N=self.training_set_labels.shape[0],
                   M=self.training_set_flux.shape[1])


    def __repr__(self):
        return "<{0}.{1} object at {2}>".format(self.__module__,
            type(self).__name__, hex(id(self)))


    # Model attributes that cannot (well, should not) be changed.


    @property
    def training_set_labels(self):
        """ Return the labels in the training set. """
        return self._training_set_labels


    @property
    def training_set_flux(self):
        """ Return the training set fluxes. """
        return self._training_set_flux


    @property
    def training_set_ivar(self):
        """ Return the inverse variances of the training set fluxes. """
        return self._training_set_ivar


    @property
    def vectorizer(self):
        """ Return the vectorizer for this model. """
        return self._vectorizer


    @property
    def design_matrix(self):
        """ Return the design matrix for this model. """
        return self._design_matrix


    def _censored_design_matrix(self, pixel_index, fill_value=np.nan):
        """
        Return the design matrix for the given pixel index, with scaled label
        values set to `fill_value` for any label that is censored at this
        pixel.

        :param pixel_index:
            The zero-indexed pixel.

        :param fill_value: [optional]
            The value to assign to censored (scaled) labels at this pixel.

        :returns:
            The design matrix for this pixel. If no censoring applies to this
            pixel, the full design matrix is returned.
        """

        if not self.censors \
        or len(set(self.censors).intersection(self.vectorizer.label_names)) == 0:
            return self.design_matrix

        data = (self.training_set_labels.copy() - self._fiducials)/self._scales
        for i, label_name in enumerate(self.vectorizer.label_names):
            try:
                use = self.censors[label_name][pixel_index]

            except KeyError:
                continue

            if not use:
                data[:, i] = fill_value

        return self.vectorizer(data).T


    @property
    def theta(self):
        """ Return the theta coefficients (spectral model derivatives). """
        return self._theta


    @property
    def s2(self):
        """ Return the intrinsic variance (s^2) for all pixels. """
        return self._s2


    # Model attributes that can be changed after initialisation.


    @property
    def censors(self):
        """ Return the wavelength censor masks for the labels. """
        return self._censors


    @censors.setter
    def censors(self, censors):
        """
        Set label censoring masks for each pixel.

        :param censors:
            A dictionary-like object with label names as keys, and boolean
            arrays as values.
        """

        censors = {} if censors is None else censors
        if isinstance(censors, censoring.Censors):
            # Could be a censoring dictionary from a different model,
            # with different label names and pixels.
            # But more likely: we are loading a model from disk.
            self._censors = censors

        elif isinstance(censors, dict):
            self._censors = censoring.Censors(
                self.vectorizer.label_names, self.training_set_flux.shape[1],
                censors)

        else:
            raise TypeError(
                "censors must be a dictionary or a censoring.Censors object")


    @property
    def dispersion(self):
        """ Return the dispersion points for all pixels. """
        return self._dispersion


    @dispersion.setter
    def dispersion(self, dispersion):
        """
        Set the dispersion values for all the pixels.

        :param dispersion:
            An array of the dispersion values.
        """
        if dispersion is None:
            self._dispersion = None
            return None

        dispersion = np.array(dispersion).flatten()
        if self.training_set_flux is not None \
        and dispersion.size != self.training_set_flux.shape[1]:
            raise ValueError("dispersion provided does not match the number "
                             "of pixels per star ({0} != {1})".format(
                                 dispersion.size,
                                 self.training_set_flux.shape[1]))

        if dispersion.dtype.kind not in "iuf":
            raise ValueError("dispersion values are not float-like")

        if not np.all(np.isfinite(dispersion)):
            raise ValueError("dispersion values must be finite")

        self._dispersion = dispersion
        return None


    @property
    def regularization(self):
        """ Return the strength of the L1 regularization for this model. """
        return self._regularization


    @regularization.setter
    def regularization(self, regularization):
        """
        Specify the strength of the regularization for the model, either as a
        single value for all pixels, or a different strength for each pixel.

        :param regularization:
            The L1-regularization strength for the model.
        """

        if regularization is None:
            self._regularization = None
            return None

        regularization = np.array(regularization).flatten()
        if regularization.size == 1:
            regularization = regularization[0]
            if 0 > regularization or not np.isfinite(regularization):
                raise ValueError("regularization must be positive and finite")

        elif regularization.size != self.training_set_flux.shape[1]:
            raise ValueError("regularization array must be of size `num_pixels`")

        elif np.any(0 > regularization) \
        or not np.all(np.isfinite(regularization)):
            raise ValueError("regularization must be positive and finite")

        self._regularization = regularization
        return None
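
    # Both forms are accepted (sketch; P is the number of pixels):
    #
    #     >>> model.regularization = 100.0              # one strength for all pixels
    #     >>> model.regularization = np.full(P, 100.0)  # a strength for each pixel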


    # Convenient functions and properties.


    @property
    def is_trained(self):
        """ Return true or false for whether the model is trained. """
        return all(getattr(self, attr, None) is not None
                   for attr in self._trained_attributes)


    def reset(self):
        """ Clear any attributes that have been trained. """
        for attribute in self._trained_attributes:
            setattr(self, "_{}".format(attribute), None)
        return None


    def _pixel_access(self, array, index, default=None):
        """
        Safely access a (potentially per-pixel) attribute of the model.

        :param array:
            Either `None`, a float value, or an array the size of the
            dispersion array.

        :param index:
            The zero-indexed pixel to attempt to access.

        :param default: [optional]
            The default value to return if `array` is None.
        """

        if array is None:
            return default
        try:
            return array[index]
        except (IndexError, TypeError):
            return array
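
    # Behaviour sketch:
    #
    #     >>> model._pixel_access(None, 5, default=0.0)  # -> 0.0
    #     >>> model._pixel_access(10.0, 5)               # -> 10.0 (a scalar applies to every pixel)
    #     >>> model._pixel_access(np.arange(10), 5)      # -> 5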


    def _verify_training_data(self, rho_warning=0.90):
        """
        Verify the training data for the appropriate shape and content.

        :param rho_warning: [optional]
            Maximum correlation value between labels before a warning is given.
        """

        if self.training_set_flux.shape != self.training_set_ivar.shape:
            raise ValueError("the training set flux and inverse variance arrays"
                             " for the labelled set must have the same shape")

        if len(self.training_set_labels) != self.training_set_flux.shape[0]:
            raise ValueError(
                "the first axis of the training set flux array should "
                "have the same length as the number of rows in the "
                "labelled set (N_stars, N_pixels)")

        if not np.all(np.isfinite(self.training_set_labels)):
            raise ValueError("training set labels are not all finite")

        if not np.all(np.isfinite(self.training_set_flux)):
            raise ValueError("training set fluxes are not all finite")

        if not np.all(self.training_set_ivar >= 0) \
        or not np.all(np.isfinite(self.training_set_ivar)):
            raise ValueError("training set ivars are not all positive and finite")

        # Look for very high correlation coefficients between labels, which
        # can make training very slow.
        rho = np.corrcoef(self.training_set_labels.T)

        # Set the diagonal indices to zero.
        K = rho.shape[0]
        rho[np.diag_indices(K)] = 0.0
        indices = np.argsort(rho.flatten())[::-1]

        for index in indices:
            x, y = (index % K, index // K)
            rho_xy = rho[x, y]
            if rho_xy >= rho_warning:
                if x > y:  # One warning per correlated label pair.
                    logger.warning(
                        "Labels '{X}' and '{Y}' are highly correlated "
                        "(rho = {rho_xy:.2}). This may cause very slow training "
                        "times. Are both labels needed?".format(
                            X=self.vectorizer.label_names[x],
                            Y=self.vectorizer.label_names[y],
                            rho_xy=rho_xy))
            else:
                break
        return None


    def in_convex_hull(self, labels):
        """
        Return whether the provided labels are inside the convex hull
        constructed from the labelled set.

        :param labels:
            A `NxK` array of `N` sets of `K` labels, where `K` is the number of
            labels that make up the vectorizer.

        :returns:
            A boolean array as to whether the points are in the convex hull of
            the labelled set.
        """

        labels = np.atleast_2d(labels)
        if labels.shape[1] != self.training_set_labels.shape[1]:
            raise ValueError("expected {} labels; got {}".format(
                self.training_set_labels.shape[1], labels.shape[1]))

        hull = Delaunay(self.training_set_labels)
        return hull.find_simplex(labels) >= 0
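
    # Sketch: flag test-step label estimates that extrapolate outside the
    # training set (`test_labels` is a hypothetical `(N, K)` array):
    #
    #     >>> inside = model.in_convex_hull(test_labels)
    #     >>> extrapolated = ~inside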


    def write(self, path, include_training_set_spectra=False, overwrite=False,
              protocol=-1):
        """
        Serialise the trained model and save it to disk. This will save all
        relevant training attributes, and optionally, the training data.

        :param path:
            The path to save the model to.

        :param include_training_set_spectra: [optional]
            Save the labelled set, normalised flux and inverse variance used to
            train the model.

        :param overwrite: [optional]
            Overwrite the existing file path, if it already exists.

        :param protocol: [optional]
            The Python pickling protocol to employ. Use 2 for compatibility with
            previous Python releases, -1 for performance.
        """

        if os.path.exists(path) and not overwrite:
            raise IOError("path already exists: {0}".format(path))

        attributes = list(self._descriptive_attributes) \
                   + list(self._trained_attributes) \
                   + list(self._data_attributes)

        if "metadata" in attributes:
            logger.warning("'metadata' is a protected attribute. Ignoring.")
            attributes.remove("metadata")

        # Store up all the trained attributes and a hash of the training set.
        state = {}
        for attribute in attributes:

            value = getattr(self, attribute)

            try:
                # If it's a vectorizer or censoring dict, etc, get the state.
                value = value.__getstate__()
            except AttributeError:
                pass

            state[attribute] = value

        # Create a metadata dictionary.
        state["metadata"] = dict(
            version=__version__,
            model_class=type(self).__name__,
            modified=str(datetime.now()),
            data_attributes=self._data_attributes,
            descriptive_attributes=self._descriptive_attributes,
            trained_attributes=self._trained_attributes,
            training_set_hash=utils.short_hash(
                getattr(self, attr) for attr in self._data_attributes),
        )

        if not include_training_set_spectra:
            state.pop("training_set_flux")
            state.pop("training_set_ivar")

            if not self.is_trained:
                logger.warning(
                    "The training set spectra won't be saved, and this model "
                    "is not already trained. The saved model will not be "
                    "able to be trained when loaded!")

        with open(path, "wb") as fp:
            pickle.dump(state, fp, protocol)
        return None


    @classmethod
    def read(cls, path, **kwargs):
        """
        Read a saved model from disk.

        :param path:
            The path to load the model from.
        """

        encodings = ("utf-8", "latin-1")
        for encoding in encodings:
            kwds = {"encoding": encoding} if version_info[0] >= 3 else {}
            try:
                with open(path, "rb") as fp:
                    state = pickle.load(fp, **kwds)

            except UnicodeDecodeError:
                if encoding == encodings[-1]:
                    raise

            else:
                break

        # Parse the state.
        metadata = state.get("metadata", {})
        version_saved = metadata.get("version", "0.1.0")
        if version_saved >= "0.2.0":  # Refactor'd.

            init_attributes = list(metadata["data_attributes"]) \
                            + list(metadata["descriptive_attributes"])

            kwds = dict([(a, state.get(a, None)) for a in init_attributes])

            # Instantiate the vectorizer.
            vectorizer_class, vectorizer_kwds = kwds["vectorizer"]
            klass = getattr(vectorizer_module, vectorizer_class)
            kwds["vectorizer"] = klass(**vectorizer_kwds)

            # Instantiate the censors.
            kwds["censors"] = censoring.Censors(**kwds["censors"])

            model = cls(**kwds)

            # Set training attributes.
            for attr in metadata["trained_attributes"]:
                setattr(model, "_{}".format(attr), state.get(attr, None))

            return model

        else:
            raise NotImplementedError(
                "Cannot auto-convert old model files yet; "
                "contact Andy Casey <[email protected]> if you need this")


    def train(self, threads=None, op_method=None, op_strict=True, op_kwds=None,
              **kwargs):
        """
        Train the model.

        :param threads: [optional]
            The number of parallel threads to use.

        :param op_method: [optional]
            The optimization algorithm to use: l_bfgs_b (default) and powell
            are available.

        :param op_strict: [optional]
            Default to Powell's optimization method if BFGS fails.

        :param op_kwds: [optional]
            Keyword arguments to provide directly to the optimization function.

        :returns:
            A three-length tuple containing the spectral coefficients `theta`,
            the squared scatter term at each pixel `s2`, and metadata related to
            the training of each pixel.
        """

        kwds = dict(op_method=op_method, op_strict=op_strict, op_kwds=op_kwds)
        kwds.update(kwargs)

        if self.training_set_flux is None or self.training_set_ivar is None:
            raise TypeError(
                "cannot train: training set spectra not saved with the model")

        S, P = self.training_set_flux.shape
        T = self.design_matrix.shape[1]

        logger.info("Training {0}-label {1} with {2} stars and {3} pixels/star"
                    .format(len(self.vectorizer.label_names),
                            type(self).__name__, S, P))

        # Parallelise out.
        if threads in (1, None):
            mapper, pool = (map, None)

        else:
            pool = mp.Pool(threads)
            mapper = pool.map

        func = utils.wrapper(fitting.fit_pixel_fixed_scatter, None, kwds, P)

        meta = []
        theta = np.nan * np.ones((P, T))
        s2 = np.nan * np.ones(P)

        for pixel, (flux, ivar) \
        in enumerate(zip(self.training_set_flux.T, self.training_set_ivar.T)):

            args = (
                flux, ivar,
                self._initial_theta(pixel),
                self._censored_design_matrix(pixel),
                self._pixel_access(self.regularization, pixel, 0.0),
                None
            )
            (pixel_theta, pixel_s2, pixel_meta), = mapper(func, [args])

            meta.append(pixel_meta)
            theta[pixel], s2[pixel] = (pixel_theta, pixel_s2)

        self._theta, self._s2 = (theta, s2)

        if pool is not None:
            pool.close()
            pool.join()

        return (theta, s2, meta)
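
    # Typical training call (sketch):
    #
    #     >>> theta, s2, meta = model.train(threads=4)
    #     >>> assert model.is_trained
    #     >>> theta.shape   # -> (num_pixels, num_terms)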


    @requires_training
    def __call__(self, labels):
        """
        Return spectral fluxes, given the labels.

        :param labels:
            An array of stellar labels.
        """

        # Scale and offset the labels.
        scaled_labels = (np.atleast_2d(labels) - self._fiducials)/self._scales
        flux = np.dot(self.theta, self.vectorizer(scaled_labels)).T
        return flux[0] if flux.shape[0] == 1 else flux
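
    # Sketch: generate a model spectrum at the fiducial (median) labels, using
    # the private `_fiducials` attribute purely for illustration:
    #
    #     >>> flux = model(model._fiducials)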


    @requires_training
    def test(self, flux, ivar, initial_labels=None, threads=None,
             use_derivatives=True, op_kwds=None):
        """
        Run the test step on spectra.

        :param flux:
            The (pseudo-continuum-normalised) spectral flux.

        :param ivar:
            The inverse variance values for the spectral fluxes.

        :param initial_labels: [optional]
            The initial labels to try for each spectrum. This can be a single
            set of initial values, or one set of initial values for each star.

        :param threads: [optional]
            The number of parallel threads to use.

        :param use_derivatives: [optional]
            `True` indicates to use analytic derivatives provided by the
            vectorizer, `None` to calculate derivatives on the fly, or a
            callable function to calculate your own derivatives.

        :param op_kwds: [optional]
            Optimization keywords that get passed to `scipy.optimize.leastsq`.
        """

        if flux is None or ivar is None:
            raise ValueError("flux and ivar must not be None")

        if op_kwds is None:
            op_kwds = dict()

        if threads in (1, None):
            mapper, pool = (map, None)

        else:
            pool = mp.Pool(threads)
            mapper = pool.map

        flux, ivar = (np.atleast_2d(flux), np.atleast_2d(ivar))
        S, P = flux.shape

        if ivar.shape != flux.shape:
            raise ValueError("flux and ivar arrays must be the same shape")

        if initial_labels is None:
            initial_labels = self._fiducials

        initial_labels = np.atleast_2d(initial_labels)
        if initial_labels.shape[0] != S and len(initial_labels.shape) == 2:
            initial_labels = np.tile(initial_labels.flatten(), S)\
                .reshape(S, -1, len(self._fiducials))

        args = (self.vectorizer, self.theta, self.s2, self._fiducials,
                self._scales)
        kwargs = dict(use_derivatives=use_derivatives, op_kwds=op_kwds)

        func = utils.wrapper(fitting.fit_spectrum, args, kwargs, S,
                             message="Running test step on {} spectra".format(S))

        labels, cov, meta = zip(*mapper(func, zip(flux, ivar, initial_labels)))

        if pool is not None:
            pool.close()
            pool.join()

        return (np.array(labels), np.array(cov), meta)
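
    # Test-step sketch (`obs_flux` and `obs_ivar` are hypothetical arrays with
    # `num_pixels` columns):
    #
    #     >>> labels, cov, meta = model.test(obs_flux, obs_ivar, threads=4)
    #     >>> labels.shape   # -> (num_spectra, num_labels)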


    def _initial_theta(self, pixel_index, **kwargs):
        """
        Return a list of guesses of the spectral coefficients for the given
        pixel index. Initial values are sourced in the following preference
        order:

        (1) a previously trained `theta` value for this pixel,
        (2) an estimate of `theta` using linear algebra,
        (3) a neighbouring pixel's `theta` value,
        (4) the fiducial value of [1, 0, ..., 0].

        :param pixel_index:
            The zero-indexed integer of the pixel.

        :returns:
            A list of initial theta guesses, and the source of each guess.
        """

        guesses = []

        if self.theta is not None:
            # Previously trained theta value.
            if np.all(np.isfinite(self.theta[pixel_index])):
                guesses.append((self.theta[pixel_index], "previously_trained"))

        # Estimate from linear algebra.
        theta, cov = fitting.fit_theta_by_linalg(
            self.training_set_flux[:, pixel_index],
            self.training_set_ivar[:, pixel_index],
            s2=kwargs.get("s2", 0.0), design_matrix=self.design_matrix)

        if np.all(np.isfinite(theta)):
            guesses.append((theta, "linear_algebra"))

        if self.theta is not None:
            # Neighbouring pixels' theta values.
            for neighbour_pixel_index in set(np.clip(
                    [pixel_index - 1, pixel_index + 1],
                    0, self.training_set_flux.shape[1] - 1)):

                if np.all(np.isfinite(self.theta[neighbour_pixel_index])):
                    guesses.append(
                        (self.theta[neighbour_pixel_index], "neighbour_pixel"))

        # Fiducial value.
        fiducial = np.hstack([1.0, np.zeros(len(self.vectorizer.terms))])
        guesses.append((fiducial, "fiducial"))

        return guesses