Completed
Pull Request — master (#1)
by Andy
01:12
created

AnniesLasso.BaseCannonModel.training_fluxes()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
dl 0
loc 3
rs 10
cc 1
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
An abstract model class for The Cannon.
6
"""
7
8
from __future__ import (division, print_function, absolute_import,
9
                        unicode_literals)
10
11
__all__ = \
12
    ["BaseCannonModel", "requires_training_wheels", "requires_label_vector"]
13
14
import logging
15
import numpy as np
16
import multiprocessing as mp
17
from collections import OrderedDict
18
from datetime import datetime
19
from os import path
20
from six.moves import cPickle as pickle
21
22
from . import (utils, __version__ as code_version)
23
24
logger = logging.getLogger(__name__)
25
26
27
def requires_training_wheels(method):
    """
    A decorator for model methods that require training before being run.

    :param method:
        The model method to guard. Its first positional argument must be the
        model instance, which is expected to expose an `is_trained` attribute.

    :returns:
        A wrapper with the same name and docstring as `method`. The wrapper
        raises a `TypeError` if `model.is_trained` is false, and otherwise
        returns whatever `method` returns.
    """

    # Imported here so the decorator does not depend on a module-level import.
    from functools import wraps

    # `wraps` keeps the name and docstring of the decorated method visible to
    # `help()`, documentation tools and tracebacks.
    @wraps(method)
    def wrapper(model, *args, **kwargs):
        if not model.is_trained:
            raise TypeError("the model needs training first")
        return method(model, *args, **kwargs)
    return wrapper
37
38
39
def requires_label_vector(method):
    """
    A decorator for model methods that require a label vector description.

    :param method:
        The model method to guard. Its first positional argument must be the
        model instance, which is expected to expose a `label_vector`
        attribute.

    :returns:
        A wrapper with the same name and docstring as `method`. The wrapper
        raises a `TypeError` if `model.label_vector` is `None`, and otherwise
        returns whatever `method` returns.
    """

    # Imported here so the decorator does not depend on a module-level import.
    from functools import wraps

    # `wraps` matters in particular when the result is passed to `property`,
    # because `property` takes its docstring from the function it is given.
    @wraps(method)
    def wrapper(model, *args, **kwargs):
        if model.label_vector is None:
            raise TypeError("the model requires a label vector description")
        return method(model, *args, **kwargs)
    return wrapper
49
50
51
class BaseCannonModel(object):
    """
    An abstract Cannon model object that implements convenience functions.

    :param labels:
        A table with columns as labels, and stars as rows.

    :type labels:
        :class:`~astropy.table.Table` or numpy structured array

    :param fluxes:
        An array of fluxes for stars in the training set, given as shape
        `(num_stars, num_pixels)`. The `num_stars` should match the number of
        rows in `labels`.

    :type fluxes:
        :class:`np.ndarray`

    :param flux_uncertainties:
        An array of 1-sigma flux uncertainties for stars in the training set,
        The shape of the `flux_uncertainties` should match `fluxes`.

    :type flux_uncertainties:
        :class:`np.ndarray`

    :param dispersion: [optional]
        The dispersion values corresponding to the given pixels. If provided,
        this should have length `num_pixels`. If not provided, the pixel
        indices `0 .. num_pixels - 1` are used.

    :param threads: [optional]
        The number of threads to use. If this is greater than one and no
        `pool` is given, a `multiprocessing.Pool` of this size is created.

    :param pool: [optional]
        An existing pool object to use instead of creating a new one.

    :param live_dangerously: [optional]
        If enabled then no checks will be made on the label names, prohibiting
        the user to input human-readable forms of the label vector.
    """

    # Names of the (underscore-prefixed) attributes that describe the model,
    # that are produced by training, and that hold the training data. These
    # lists drive `reset`, `save`, `load` and `cross_validate`; subclasses are
    # expected to extend them.
    _descriptive_attributes = ["_label_vector"]
    _trained_attributes = []
    _data_attributes = []
    _forbidden_label_characters = "^*"

    def __init__(self, labels, fluxes, flux_uncertainties, dispersion=None,
        threads=1, pool=None, live_dangerously=False):

        self._training_labels = labels
        self._training_fluxes = np.atleast_2d(fluxes)
        self._training_flux_uncertainties = np.atleast_2d(flux_uncertainties)

        # The pixel count is taken from the 2-D array rather than the raw
        # input, so that lists and 1-D arrays of fluxes are accepted too.
        if dispersion is None:
            self._dispersion = np.arange(
                self._training_fluxes.shape[1], dtype=int)
        else:
            self._dispersion = dispersion

        for attribute in self._descriptive_attributes:
            setattr(self, attribute, None)

        # The training data must be checked, but users can live dangerously if
        # they think they can correctly specify the label vector description.
        self._verify_training_data()
        if not live_dangerously:
            self._verify_labels_available()

        self.reset()
        self.threads = threads

        # A pool supplied by the caller is always kept; a new pool is only
        # created when more than one thread was requested.
        if pool is not None:
            self.pool = pool
        elif threads > 1:
            self.pool = mp.Pool(threads)
        else:
            self.pool = None


    def reset(self):
        """
        Clear any attributes that have been trained upon.
        """

        self._trained = False
        for attribute in self._trained_attributes:
            setattr(self, attribute, None)

        return None


    def __str__(self):
        return "<{module}.{name} {trained}using a training set of {N} stars "\
               "with {K} available labels and {M} pixels each>".format(
                    module=self.__module__,
                    name=type(self).__name__,
                    trained="trained " if self.is_trained else "",
                    N=len(self.training_labels),
                    K=len(self.labels_available),
                    M=len(self.dispersion))


    def __repr__(self):
        return "<{0}.{1} object at {2}>".format(
            self.__module__, type(self).__name__, hex(id(self)))


    # Attributes related to the training data.
    @property
    def dispersion(self):
        """
        Return the dispersion points for all pixels.
        """
        return self._dispersion


    @dispersion.setter
    def dispersion(self, dispersion):
        """
        Set the dispersion values for all the pixels.

        :param dispersion:
            An array or list-like of finite numeric values, with one entry per
            pixel in the training fluxes.

        :raises TypeError:
            If `dispersion` has no length.

        :raises ValueError:
            If the length does not match the number of pixels, or the values
            are not numeric, or are not all finite.
        """
        try:
            len(dispersion)
        except TypeError:
            raise TypeError("dispersion provided must be an array or list-like")

        if len(dispersion) != self.training_fluxes.shape[1]:
            raise ValueError("dispersion provided does not match the number "
                             "of pixels per star ({0} != {1})".format(
                                len(dispersion), self.training_fluxes.shape[1]))

        dispersion = np.array(dispersion)
        # Only signed/unsigned integer and floating point kinds are accepted.
        if dispersion.dtype.kind not in "iuf":
            raise ValueError("dispersion values are not float-like")

        if not np.all(np.isfinite(dispersion)):
            raise ValueError("dispersion values must be finite")

        self._dispersion = dispersion
        return None


    @property
    def training_labels(self):
        """ The table of labels given for the training set. """
        return self._training_labels


    @property
    def training_fluxes(self):
        """ The (at least 2-D) array of training set fluxes. """
        return self._training_fluxes


    @property
    def training_flux_uncertainties(self):
        """ The (at least 2-D) array of training set flux uncertainties. """
        return self._training_flux_uncertainties


    # Verifying the training data.
    def _verify_labels_available(self):
        """
        Verify the label names provided do not include forbidden characters.

        :raises ValueError:
            If any column name in the training labels contains one of the
            characters in `_forbidden_label_characters`.
        """
        if self._forbidden_label_characters is None:
            return True

        for label in self.training_labels.dtype.names:
            for character in self._forbidden_label_characters:
                if character in label:
                    raise ValueError(
                        "forbidden character '{char}' is in potential "
                        "label '{label}' - you can disable this verification "
                        "by enabling `live_dangerously`".format(
                            char=character, label=label))
        return None


    def _verify_training_data(self):
        """
        Verify the training data for the appropriate shape and content.

        :raises ValueError:
            If the fluxes and uncertainties differ in shape, if there are no
            named labels, if the number of label rows differs from the number
            of flux rows, or if the dispersion size differs from the number of
            pixels.
        """
        if self.training_fluxes.shape != self.training_flux_uncertainties.shape:
            raise ValueError(
                "the training flux and uncertainty arrays should "
                "have the same shape")

        if len(self.training_labels) == 0 \
        or self.training_labels.dtype.names is None:
            raise ValueError("no named labels provided for the training set")

        if len(self.training_labels) != self.training_fluxes.shape[0]:
            raise ValueError(
                "the first axes of the training flux array should "
                "have the same shape as the number of rows in the label table "
                "(N_stars, N_pixels)")

        if self.dispersion is not None:
            dispersion = np.atleast_1d(self.dispersion).flatten()
            if dispersion.size != self.training_fluxes.shape[1]:
                # Argument order follows the message: wavelength points first,
                # then flux values.
                raise ValueError(
                    "mis-match between the number of wavelength "
                    "points ({0}) and flux values ({1})".format(
                        dispersion.size, self.training_fluxes.shape[1]))
        return None


    @property
    def is_trained(self):
        """ Whether the model is currently flagged as trained. """
        return self._trained


    # Attributes related to the labels and the label vector description.
    @property
    def labels_available(self):
        """
        All of the available labels for each star in the training set.
        """
        return self.training_labels.dtype.names


    @property
    def label_vector(self):
        """ The label vector for all pixels. """
        return self._label_vector


    @label_vector.setter
    def label_vector(self, label_vector_description):
        """
        Set a label vector.

        :param label_vector_description:
            A structured or human-readable version of the label vector
            description. Use `None` to clear the label vector.

        :raises ValueError:
            If the description refers to labels that are not columns of the
            training labels.
        """

        if label_vector_description is None:
            self._label_vector = None
            self.reset()
            return None

        label_vector = utils.parse_label_vector(label_vector_description)

        # Need to actually verify that the parameters listed in the label vector
        # are actually present in the training labels.
        missing = \
            set(self._get_labels(label_vector)).difference(self.labels_available)
        if missing:
            raise ValueError("the following labels parsed from the label "
                             "vector description are missing in the training "
                             "set of labels: {0}".format(", ".join(missing)))

        # If this is really a new label vector description,
        # then we are no longer trained.
        if not hasattr(self, "_label_vector") \
        or label_vector != self._label_vector:
            self._label_vector = label_vector
            self.reset()

        return None


    @property
    def human_readable_label_vector(self):
        """ Return a human-readable form of the label vector. """
        return utils.human_readable_label_vector(self.label_vector)


    @property
    def labels(self):
        """ The labels that contribute to the label vector. """
        return self._get_labels(self.label_vector)


    def _get_labels(self, label_vector):
        """
        Return the labels that contribute to the structured label vector
        provided.

        :param label_vector:
            A sequence of terms, where each term is a sequence of
            `(label, power)` pairs, or `None`.

        :returns:
            An empty tuple if `label_vector` is `None`, otherwise a list of
            the unique labels with a non-zero power, in order of first
            appearance.
        """
        if label_vector is None:
            return ()

        # An OrderedDict is used as an ordered set of label names.
        return list(OrderedDict.fromkeys(
            label for term in label_vector
                for label, power in term if power != 0))


    def _get_lowest_order_label_indices(self):
        """
        Get the indices for the lowest power label terms in the label vector.

        :returns:
            A list with one entry per label in `self.labels`: the index of the
            single-label term with the lowest power for that label, or `None`
            if the label never appears in a term on its own.
        """
        indices = OrderedDict()
        for i, term in enumerate(self.label_vector):
            # Cross-terms (more than one label) are ignored.
            if len(term) > 1: continue
            label, order = term[0]
            if order < indices.get(label, [None, np.inf])[-1]:
                indices[label] = (i, order)
        return [indices.get(label, [None])[0] for label in self.labels]


    # Trained attributes that subclasses are likely to use.
    @property
    def coefficients(self):
        """ The label vector coefficients for each pixel. """
        return self._coefficients


    @coefficients.setter
    def coefficients(self, coefficients):
        """
        Set the label vector coefficients for each pixel. This assumes a
        'standard' model where the label vector is common to all pixels.

        :param coefficients:
            A 2-D array of coefficients of shape
            (`N_pixels`, `N_label_vector_terms`), where the number of terms
            is one more than the length of the label vector. Use `None` to
            clear the coefficients.

        :raises ValueError:
            If the array has more than two dimensions or either axis has the
            wrong length.
        """

        if coefficients is None:
            self._coefficients = None
            return None

        coefficients = np.atleast_2d(coefficients)
        if len(coefficients.shape) > 2:
            raise ValueError("coefficients must be a 2D array")

        P, Q = coefficients.shape
        if P != len(self.dispersion):
            raise ValueError("axis 0 of coefficients array does not match the "
                             "number of pixels ({0} != {1})".format(
                                P, len(self.dispersion)))
        if Q != 1 + len(self.label_vector):
            raise ValueError("axis 1 of coefficients array does not match the "
                             "number of label vector terms ({0} != {1})".format(
                                Q, 1 + len(self.label_vector)))
        self._coefficients = coefficients
        return None


    @property
    def scatter(self):
        """ The scatter values for each pixel. """
        return self._scatter


    @scatter.setter
    def scatter(self, scatter):
        """
        Set the scatter values for each pixel.

        :param scatter:
            A 1-D array of scatter terms, one per pixel. Use `None` to clear
            the scatter.

        :raises ValueError:
            If the number of values does not match the number of pixels, or
            any value is negative.
        """

        if scatter is None:
            self._scatter = None
            return None

        scatter = np.array(scatter).flatten()
        if scatter.size != len(self.dispersion):
            raise ValueError("number of scatter values does not match "
                             "the number of pixels ({0} != {1})".format(
                                scatter.size, len(self.dispersion)))
        if np.any(scatter < 0):
            raise ValueError("scatter terms must be positive")
        self._scatter = scatter
        return None


    # Methods which must be implemented or updated by the subclasses.
    def pixel_label_vector(self, pixel_index):
        """ The label vector for a given pixel. """
        return self.label_vector


    def train(self, *args, **kwargs):
        """ Train the model. Must be implemented by subclasses. """
        raise NotImplementedError("The train method must be "
                                  "implemented by subclasses")


    def predict(self, *args, **kwargs):
        """ Predict from the model. Must be implemented by subclasses. """
        raise NotImplementedError("The predict method must be "
                                  "implemented by subclasses")


    def fit(self, *args, **kwargs):
        """ Fit data with the model. Must be implemented by subclasses. """
        raise NotImplementedError("The fit method must be "
                                  "implemented by subclasses")


    # I/O
    @requires_training_wheels
    def save(self, filename, include_training_data=False, overwrite=False):
        """
        Serialise the trained model and save it to disk. This will save all
        relevant training attributes, and optionally, the training data.

        :param filename:
            The path to save the model to.

        :param include_training_data: [optional]
            Save the training data (labels, fluxes, uncertainties) used to train
            the model.

        :param overwrite: [optional]
            Overwrite the existing file path, if it already exists.

        :raises IOError:
            If `filename` exists and `overwrite` is not enabled.

        :raises ValueError:
            If any of the class attribute lists would produce a 'metadata'
            entry, which is reserved.
        """

        if path.exists(filename) and not overwrite:
            raise IOError("filename already exists: {0}".format(filename))

        descriptive = list(self._descriptive_attributes)
        trained = list(self._trained_attributes)
        data = list(self._data_attributes)

        # Attributes are stored under their name without leading underscores,
        # so the comparison has to be made on the stripped name.
        if any(attr.lstrip("_") == "metadata"
               for attr in descriptive + trained + data):
            raise ValueError("'metadata' is a protected attribute and cannot "
                             "be used in the _*_attributes in a class")

        # Store up all the trained attributes and a hash of the training set.
        contents = OrderedDict([
            (attr.lstrip("_"), getattr(self, attr)) \
            for attr in descriptive + trained])
        contents["training_set_hash"] = utils.short_hash(
            getattr(self, attr) for attr in data)

        if include_training_data:
            contents.update(
                [(attr.lstrip("_"), getattr(self, attr)) for attr in data])

        contents["metadata"] = {
            "version": code_version,
            "model_name": type(self).__name__,
            "modified": str(datetime.now()),
            "data_attributes": [_.lstrip("_") for _ in data],
            "trained_attributes": [_.lstrip("_") for _ in trained],
            "descriptive_attributes": [_.lstrip("_") for _ in descriptive]
        }

        with open(filename, "wb") as fp:
            pickle.dump(contents, fp, -1)

        return None


    def load(self, filename, verify_training_data=False):
        """
        Load a saved model from disk.

        .. warning::
            The file is read with `pickle`, which can execute arbitrary code
            while loading. Only load files from a trusted source.

        :param filename:
            The path where to load the model from.

        :param verify_training_data: [optional]
            If there is training data in the saved model, verify that its hash
            matches the hash stored alongside it. If no training data was
            saved, no verification is performed.

        :raises AssertionError:
            If the file was saved by a model class with a different name.

        :raises ValueError:
            If verification is requested and the training data hash differs
            from the stored hash.
        """

        # SECURITY: unpickling untrusted data is unsafe; see the warning above.
        with open(filename, "rb") as fp:
            contents = pickle.load(fp)

        metadata = contents["metadata"]

        # Raised explicitly rather than through an `assert` statement so the
        # check still runs when Python is started with -O.
        if metadata["model_name"] != type(self).__name__:
            raise AssertionError(
                "the saved model is a {0}, not a {1}".format(
                    metadata["model_name"], type(self).__name__))

        # If data exists, deal with that first. A model without any data
        # attributes has, by definition, no saved data.
        data_attributes = metadata["data_attributes"]
        has_data = len(data_attributes) > 0 and data_attributes[0] in contents
        if has_data:

            if verify_training_data:
                data_hash = utils.short_hash(
                    contents[attr] for attr in data_attributes)
                if contents["training_set_hash"] is not None \
                and data_hash != contents["training_set_hash"]:
                    raise ValueError("expected hash for the training data is "
                                     "different to the actual data hash "
                                     "({0} != {1})".format(
                                        contents["training_set_hash"],
                                        data_hash))

            # Set the data attributes.
            for attribute in data_attributes:
                if attribute in contents:
                    setattr(self, "_{}".format(attribute), contents[attribute])

        # Set descriptive and trained attributes.
        self.reset()
        for attribute in metadata["descriptive_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        for attribute in metadata["trained_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        self._trained = True

        return None


    # Properties and attributes related to training, etc.
    @property
    @requires_label_vector
    def labels_array(self):
        """
        Return an array containing just the training labels, given the label
        vector.
        """
        # Each contributing label is built as a first-order term; the leading
        # row of ones is dropped and the result is transposed.
        return _build_label_vector_rows(
            [[(label, 1)] for label in self.labels], self.training_labels)[1:].T


    @property
    @requires_label_vector
    def label_vector_array(self):
        """
        Build the label vector array.

        A warning is logged if the array contains any non-finite values.
        """

        lva = _build_label_vector_rows(self.label_vector, self.training_labels)

        if not np.all(np.isfinite(lva)):
            logger.warning("Non-finite labels in the label vector array!")
        return lva


    # Residuals in labels in the training data set.
    @requires_training_wheels
    def get_training_label_residuals(self):
        """
        Return the residuals (model - training) between the parameters that the
        model returns for each star, and the training set value.
        """

        optimised_labels = self.fit(self.training_fluxes,
            self.training_flux_uncertainties, full_output=False)

        return optimised_labels - self.labels_array


    def _format_input_labels(self, args=None, **kwargs):
        """
        Format input labels either from a list or dictionary into a common form.

        :param args: [optional]
            Label values given in the same order as `self.labels`. If this is
            `None` then the keyword arguments are used instead.

        :returns:
            A dictionary of label names to values, where any value that is not
            a list, tuple or numpy array has been wrapped in a list.
        """

        # We want labels in a dictionary.
        labels = kwargs if args is None else dict(zip(self.labels, args))
        return { k: (v if isinstance(v, (list, tuple, np.ndarray)) else [v]) \
            for k, v in labels.items() }


    # Put Cross-validation functions in here.
    @requires_label_vector
    def cross_validate(self, pre_train=None, **kwargs):
        """
        Perform leave-one-out cross-validation on the training set.

        :param pre_train: [optional]
            A callable that is called as `pre_train(self, model)` for each
            new model, before that model is trained.

        :param N: [optional]
            Only cross-validate the first `N` stars of the training set.

        :param debug: [optional]
            Re-raise any exception that occurs while fitting a star, instead
            of just logging it.

        Any other keyword arguments are passed on when each new model is
        created.

        :returns:
            An array of the inferred labels with one row per cross-validated
            star. Rows for stars that could not be fit are left as NaN.
        """

        inferred = np.nan * np.ones_like(self.labels_array)
        N_training_set = inferred.shape[0]
        N_stop_at = kwargs.pop("N", N_training_set)

        debug = kwargs.pop("debug", False)

        kwds = { "threads": self.threads }
        kwds.update(kwargs)

        # Only the first N_stop_at rows are returned, so no models are trained
        # beyond that point.
        for i in range(min(N_stop_at, N_training_set)):

            training_set = np.ones(N_training_set, dtype=bool)
            training_set[i] = False

            # Create a clean model to use so we don't overwrite self.
            model = self.__class__(
                self.training_labels[training_set],
                self.training_fluxes[training_set],
                self.training_flux_uncertainties[training_set],
                **kwds)

            # Copy the descriptive attributes across using their public
            # (no leading underscore) names.
            for _attribute in self._descriptive_attributes:
                setattr(model, _attribute[1:], getattr(self, _attribute[1:]))

            # Initialise and run any pre-training function.
            if pre_train is not None:
                pre_train(self, model)

            # Train and solve.
            model.train()

            try:
                inferred[i, :] = model.fit(self.training_fluxes[i],
                    self.training_flux_uncertainties[i], full_output=False)

            except Exception:
                # A failed fit leaves this row as NaN. KeyboardInterrupt and
                # SystemExit are deliberately not caught here.
                logger.exception("Exception during cross-validation on object "
                                 "with index {0}:".format(i))
                if debug: raise

        return inferred[:N_stop_at, :]
627
628
629
def _build_label_vector_rows(label_vector, training_labels):
    """
    Evaluate every term of a label vector for the given label values.

    The label vector is a sequence of terms, and each term is a sequence of
    `(label, order)` pairs. The value of a term is the product over its pairs
    of `training_labels[label]` raised to the power `order`.

    For example, the term `[("A", 3)]` evaluates to `training_labels["A"]**3`.

    :param label_vector:
        The structured description of the label vector.

    :param training_labels:
        An object that supports `len()` and can be indexed by label to give
        the values for that label.

    :returns:
        A 2-D array with one row per term, preceded by a row of ones.
    """

    # The leading row of ones is sized from `len(training_labels)`.
    rows = [np.ones(len(training_labels), dtype=float)]

    for term in label_vector:
        product = 1.
        for label, order in term:
            values = np.array(training_labels[label]).flatten()
            product *= values**order
        rows.append(product)

    try:
        stacked = np.vstack(rows)

    except ValueError:
        # The rows differ in length: retry with a single leading one.
        rows[0] = np.ones(1)
        stacked = np.vstack(rows)

    return stacked
664