Completed
Push — master ( 53c124...83b528 )
by Andy
50s
created

AnniesLasso.BaseCannonModel.label_vector()   A

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 4
rs 10
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
An abstract model class for The Cannon.
6
"""
7
8
from __future__ import (division, print_function, absolute_import,
9
                        unicode_literals)
10
11
__all__ = [
12
    "BaseCannonModel", "requires_training_wheels", "requires_model_description"]
13
14
15
import logging
16
import numpy as np
17
import multiprocessing as mp
18
from collections import OrderedDict
19
from datetime import datetime
20
from os import path
21
from six.moves import cPickle as pickle
22
23
from . import (utils, __version__ as code_version)
24
25
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
26
27
28
def requires_training_wheels(method):
    """
    A decorator for model methods that require training before being run.

    :param method:
        The model method to guard. Its first positional argument must be the
        model instance, which is expected to expose an `is_trained` attribute.

    :returns:
        A wrapped method that raises a `TypeError` if the model is not trained,
        and otherwise calls `method` with the arguments it was given.
    """

    # Imported locally so this decorator does not depend on the module-level
    # import block.
    from functools import wraps

    # `wraps` keeps the name and docstring of the decorated method so that
    # introspection and generated documentation still see the real method.
    @wraps(method)
    def wrapper(model, *args, **kwargs):
        if not model.is_trained:
            raise TypeError("the model needs training first")
        return method(model, *args, **kwargs)
    return wrapper
38
39
40
def requires_model_description(method):
    """
    A decorator for model methods that require a full model description.
    (That is, none of the _descriptive_attributes are None)

    :param method:
        The model method to guard. Its first positional argument must be the
        model instance, which is expected to expose a `_descriptive_attributes`
        sequence of attribute names.

    :returns:
        A wrapped method that raises a `TypeError` naming the first descriptive
        attribute that is `None`, and otherwise calls `method` with the
        arguments it was given.
    """

    # Imported locally so this decorator does not depend on the module-level
    # import block.
    from functools import wraps

    # `wraps` keeps the name and docstring of the decorated method. This
    # matters when the result is stacked under `@property`, which takes its
    # documentation from the getter it is given.
    @wraps(method)
    def wrapper(model, *args, **kwargs):
        for descriptive_attribute in model._descriptive_attributes:
            if getattr(model, descriptive_attribute) is None:
                # Report the public name (without leading underscores).
                raise TypeError("the model requires a {} term".format(
                    descriptive_attribute.lstrip("_")))
        return method(model, *args, **kwargs)
    return wrapper
53
54
55
class BaseCannonModel(object):
    """
    An abstract Cannon model object that implements convenience functions.

    :param labels:
        A table with columns as labels, and stars as rows.

    :type labels:
        :class:`~astropy.table.Table` or numpy structured array

    :param fluxes:
        An array of fluxes for stars in the training set, given as shape
        `(num_stars, num_pixels)`. The `num_stars` should match the number of
        rows in `labels`.

    :type fluxes:
        :class:`np.ndarray`

    :param flux_uncertainties:
        An array of 1-sigma flux uncertainties for stars in the training set,
        The shape of the `flux_uncertainties` should match `fluxes`.

    :type flux_uncertainties:
        :class:`np.ndarray`

    :param dispersion: [optional]
        The dispersion values corresponding to the given pixels. If provided,
        this should have length `num_pixels`. If not provided, the pixel
        indices `0 .. num_pixels - 1` are used.

    :param threads: [optional]
        The number of threads to use. If this is greater than one and no
        `pool` is given, a :class:`multiprocessing.Pool` with this many
        workers is created.

    :param pool: [optional]
        An existing pool to use. When given, no new pool is created.

    :param live_dangerously: [optional]
        If enabled then no checks will be made on the label names, prohibiting
        the user to input human-readable forms of the label vector.
    """

    # Names of the attributes that describe the model, that are set by
    # training, and that hold the training data. Subclasses may extend these.
    _descriptive_attributes = ["_label_vector"]
    _trained_attributes = ["_scatter", "_coefficients", "_pivots"]
    _data_attributes = []

    # Characters that may not appear in a label name (checked unless the model
    # is created with `live_dangerously`).
    _forbidden_label_characters = "^*"

    def __init__(self, labels, fluxes, flux_uncertainties, dispersion=None,
        threads=1, pool=None, live_dangerously=False):

        self._training_labels = labels
        self._training_fluxes = np.atleast_2d(fluxes)
        self._training_flux_uncertainties = np.atleast_2d(flux_uncertainties)

        # Count pixels from the 2-D array rather than the raw argument, so
        # that lists and single (1-D) spectra are accepted as well.
        if dispersion is None:
            dispersion = np.arange(self._training_fluxes.shape[1], dtype=int)
        self._dispersion = dispersion

        for attribute in self._descriptive_attributes:
            setattr(self, attribute, None)

        # The training data must be checked, but users can live dangerously if
        # they think they can correctly specify the label vector description.
        self._verify_training_data()
        if not live_dangerously:
            self._verify_labels_available()

        self.reset()
        self.threads = threads

        # Honour a pool supplied by the caller regardless of `threads`; only
        # create a new pool when none was supplied and threading was requested.
        if not pool and threads > 1:
            pool = mp.Pool(threads)
        self.pool = pool or None


    def reset(self):
        """
        Clear any attributes that have been trained upon.
        """

        self._trained = False
        for attribute in self._trained_attributes:
            setattr(self, attribute, None)

        return None


    def __str__(self):
        return "<{module}.{name} {trained}using a training set of {N} stars "\
               "with {K} available labels and {M} pixels each>".format(
                    module=self.__module__,
                    name=type(self).__name__,
                    trained="trained " if self.is_trained else "",
                    N=len(self.training_labels),
                    K=len(self.labels_available),
                    M=len(self.dispersion))


    def __repr__(self):
        return "<{0}.{1} object at {2}>".format(
            self.__module__, type(self).__name__, hex(id(self)))


    # Attributes related to the training data.
    @property
    def dispersion(self):
        """
        Return the dispersion points for all pixels.
        """
        return self._dispersion


    @dispersion.setter
    def dispersion(self, dispersion):
        """
        Set the dispersion values for all the pixels.

        :param dispersion:
            An array or list-like of finite numeric values, with one entry per
            pixel in the training fluxes.

        :raises TypeError:
            If `dispersion` has no length.

        :raises ValueError:
            If the length does not match the number of pixels, or the values
            are not numeric, or the values are not all finite.
        """
        try:
            len(dispersion)
        except TypeError:
            raise TypeError("dispersion provided must be an array or list-like")

        if len(dispersion) != self.training_fluxes.shape[1]:
            raise ValueError("dispersion provided does not match the number "
                             "of pixels per star ({0} != {1})".format(
                                len(dispersion), self.training_fluxes.shape[1]))

        dispersion = np.array(dispersion)
        # Accept signed/unsigned integer and floating point kinds only.
        if dispersion.dtype.kind not in "iuf":
            raise ValueError("dispersion values are not float-like")

        if not np.all(np.isfinite(dispersion)):
            raise ValueError("dispersion values must be finite")

        self._dispersion = dispersion
        return None


    @property
    def training_labels(self):
        """ The table of labels given for the training set. """
        return self._training_labels


    @property
    def training_fluxes(self):
        """ The training set fluxes, as an array of at least two dimensions. """
        return self._training_fluxes


    @property
    def training_flux_uncertainties(self):
        """
        The training set flux uncertainties, as an array of at least two
        dimensions.
        """
        return self._training_flux_uncertainties


    # Verifying the training data.
    def _verify_labels_available(self):
        """
        Verify the label names provided do not include forbidden characters.

        :raises ValueError:
            If any label name contains a forbidden character.
        """
        if self._forbidden_label_characters is None:
            return True

        for label in self.training_labels.dtype.names:
            for character in self._forbidden_label_characters:
                if character in label:
                    raise ValueError(
                        "forbidden character '{char}' is in potential "
                        "label '{label}' - you can disable this verification "
                        "by enabling `live_dangerously`".format(
                            char=character, label=label))
        return None


    def _verify_training_data(self):
        """
        Verify the training data for the appropriate shape and content.

        :raises ValueError:
            If the fluxes and uncertainties differ in shape, if there are no
            named labels, if the number of label rows does not match the
            number of flux rows, or if the dispersion does not have one value
            per pixel.
        """
        if self.training_fluxes.shape != self.training_flux_uncertainties.shape:
            raise ValueError(
                "the training flux and uncertainty arrays should "
                "have the same shape")

        if len(self.training_labels) == 0 \
        or self.training_labels.dtype.names is None:
            raise ValueError("no named labels provided for the training set")

        if len(self.training_labels) != self.training_fluxes.shape[0]:
            raise ValueError(
                "the first axes of the training flux array should "
                "have the same shape as the number of rows in the label table "
                "(N_stars, N_pixels)")

        if self.dispersion is not None:
            dispersion = np.atleast_1d(self.dispersion).flatten()
            if dispersion.size != self.training_fluxes.shape[1]:
                # Report the counts in the same order as the message names
                # them: dispersion points first, then flux values.
                raise ValueError(
                    "mis-match between the number of wavelength "
                    "points ({0}) and flux values ({1})".format(
                        dispersion.size, self.training_fluxes.shape[1]))
        return None


    @property
    def is_trained(self):
        """ Whether the model is currently flagged as trained. """
        return self._trained


    # Attributes related to the labels and the label vector description.
    @property
    def labels_available(self):
        """
        All of the available labels for each star in the training set.
        """
        return self.training_labels.dtype.names


    @property
    def label_vector(self):
        """ The label vector for all pixels. """
        return self._label_vector


    @label_vector.setter
    def label_vector(self, label_vector_description):
        """
        Set a label vector.

        Setting a label vector that differs from the current one clears any
        trained attributes. The pivots are then set to the NaN-ignoring mean
        of each contributing training label.

        :param label_vector_description:
            A structured or human-readable version of the label vector
            description, or `None` to clear the label vector.

        :raises ValueError:
            If the description refers to labels that are not available in the
            training set.
        """

        if label_vector_description is None:
            self._label_vector = None
            self.reset()
            return None

        label_vector = utils.parse_label_vector(label_vector_description)

        # Need to actually verify that the parameters listed in the label vector
        # are actually present in the training labels.
        missing = \
            set(self._get_labels(label_vector)).difference(self.labels_available)
        if missing:
            raise ValueError("the following labels parsed from the label "
                             "vector description are missing in the training "
                             "set of labels: {0}".format(", ".join(missing)))

        # If this is really a new label vector description,
        # then we are no longer trained.
        if not hasattr(self, "_label_vector") \
        or label_vector != self._label_vector:
            self._label_vector = label_vector
            self.reset()

        self.pivots = \
            { l: np.nanmean(self.training_labels[l]) for l in self.labels }

        return None


    @property
    def human_readable_label_vector(self):
        """ Return a human-readable form of the label vector. """
        return utils.human_readable_label_vector(self.label_vector)


    @property
    def labels(self):
        """ The labels that contribute to the label vector. """
        return self._get_labels(self.label_vector)


    def _get_labels(self, label_vector):
        """
        Return the labels that contribute to the structured label vector
        provided.

        :param label_vector:
            A sequence of terms, where each term is a sequence of
            `(label, power)` pairs, or `None`.

        :returns:
            An empty tuple if `label_vector` is `None`, otherwise a list of
            the labels with a non-zero power, without duplicates and in order
            of first appearance.
        """
        return () if label_vector is None else \
            list(OrderedDict.fromkeys([label for term in label_vector \
                for label, power in term if power != 0]))


    def _get_lowest_order_label_indices(self):
        """
        Get the indices for the lowest power label terms in the label vector.

        Only terms made of a single `(label, power)` pair are considered.

        :returns:
            A list with one entry per label in `self.labels`: the index of the
            lowest power single-label term for that label, or `None` if the
            label has no such term.
        """
        indices = OrderedDict()
        for i, term in enumerate(self.label_vector):
            # Skip cross-terms.
            if len(term) > 1: continue
            label, order = term[0]
            if order < indices.get(label, [None, np.inf])[-1]:
                indices[label] = (i, order)
        return [indices.get(label, [None])[0] for label in self.labels]


    # Trained attributes that subclasses are likely to use.
    @property
    def coefficients(self):
        """ The label vector coefficients, or `None` if not set. """
        return self._coefficients


    @coefficients.setter
    def coefficients(self, coefficients):
        """
        Set the label vector coefficients for each pixel. This assumes a
        'standard' model where the label vector is common to all pixels.

        :param coefficients:
            A 2-D array of coefficients of shape
            (`N_pixels`, `1 + N_label_vector_terms`), or `None` to clear them.

        :raises ValueError:
            If the array has more than two dimensions or either axis has the
            wrong length.
        """

        if coefficients is None:
            self._coefficients = None
            return None

        coefficients = np.atleast_2d(coefficients)
        if len(coefficients.shape) > 2:
            raise ValueError("coefficients must be a 2D array")

        P, Q = coefficients.shape
        if P != len(self.dispersion):
            raise ValueError("axis 0 of coefficients array does not match the "
                             "number of pixels ({0} != {1})".format(
                                P, len(self.dispersion)))
        # One column per label vector term, plus one extra leading column.
        if Q != 1 + len(self.label_vector):
            raise ValueError("axis 1 of coefficients array does not match the "
                             "number of label vector terms ({0} != {1})".format(
                                Q, 1 + len(self.label_vector)))
        self._coefficients = coefficients
        return None


    @property
    def scatter(self):
        """ The scatter values for each pixel, or `None` if not set. """
        return self._scatter


    @scatter.setter
    def scatter(self, scatter):
        """
        Set the scatter values for each pixel.

        :param scatter:
            A 1-D array of scatter terms (one per pixel), or `None` to clear
            them. Negative values are rejected.

        :raises ValueError:
            If the number of values does not match the number of pixels, or
            any value is negative.
        """

        if scatter is None:
            self._scatter = None
            return None

        scatter = np.array(scatter).flatten()
        if scatter.size != len(self.dispersion):
            raise ValueError("number of scatter values does not match "
                             "the number of pixels ({0} != {1})".format(
                                scatter.size, len(self.dispersion)))
        if np.any(scatter < 0):
            raise ValueError("scatter terms must be positive")
        self._scatter = scatter
        return None


    @property
    def pivots(self):
        """
        Return the pivot values of the unique labels in the label vector.
        """
        return self._pivots


    @pivots.setter
    def pivots(self, pivots):
        """
        Set the pivot values for each unique label in the label vector.

        :param pivots:
            A dictionary of pivot values keyed by label name, with an entry
            for every label in `self.labels`, or `None` to clear the pivots.

        :raises TypeError:
            If `pivots` is not a dictionary.

        :raises ValueError:
            If a label in `self.labels` has no pivot, or any pivot value is
            not finite.
        """

        if pivots is None:
            self._pivots = None
            return None

        if not isinstance(pivots, dict):
            raise TypeError("pivots must be a dictionary")

        missing = set(self.labels).difference(pivots)
        if missing:
            raise ValueError("pivot values for the following labels "
                             "are missing: {}".format(", ".join(list(missing))))

        # Materialise the values first: on Python 3 `dict.values()` is a view
        # that numpy cannot treat as a numeric array.
        if not np.all(np.isfinite(list(pivots.values()))):
            raise ValueError("pivot values must be finite")

        self._pivots = pivots
        return None


    # Methods which must be implemented or updated by the subclasses.
    def pixel_label_vector(self, pixel_index):
        """ The label vector for a given pixel. """
        return self.label_vector


    def train(self, *args, **kwargs):
        """ Train the model. Must be implemented by subclasses. """
        raise NotImplementedError("The train method must be "
                                  "implemented by subclasses")


    def predict(self, *args, **kwargs):
        """ Predict from the model. Must be implemented by subclasses. """
        raise NotImplementedError("The predict method must be "
                                  "implemented by subclasses")


    def fit(self, *args, **kwargs):
        """ Fit with the model. Must be implemented by subclasses. """
        raise NotImplementedError("The fit method must be "
                                  "implemented by subclasses")


    # I/O
    @requires_training_wheels
    def save(self, filename, include_training_data=False, overwrite=False):
        """
        Serialise the trained model and save it to disk. This will save all
        relevant training attributes, and optionally, the training data.

        :param filename:
            The path to save the model to.

        :param include_training_data: [optional]
            Save the training data (the attributes named in
            `_data_attributes`) used to train the model.

        :param overwrite: [optional]
            Overwrite the existing file path, if it already exists.

        :raises IOError:
            If `filename` exists and `overwrite` is not enabled.

        :raises ValueError:
            If any of the `_*_attributes` lists contain 'metadata'.
        """

        if path.exists(filename) and not overwrite:
            raise IOError("filename already exists: {0}".format(filename))

        descriptive = list(self._descriptive_attributes)
        trained = list(self._trained_attributes)
        data = list(self._data_attributes)

        if "metadata" in descriptive + trained + data:
            raise ValueError("'metadata' is a protected attribute and cannot "
                             "be used in the _*_attributes in a class")

        # Store up all the trained attributes and a hash of the training set.
        contents = OrderedDict([
            (attr.lstrip("_"), getattr(self, attr)) \
            for attr in descriptive + trained])
        contents["training_set_hash"] = utils.short_hash(getattr(self, attr) \
            for attr in self._data_attributes)

        if include_training_data:
            contents.update([(attr.lstrip("_"), getattr(self, attr)) \
                for attr in self._data_attributes])

        # Attribute names are stored without their leading underscores.
        contents["metadata"] = {
            "version": code_version,
            "model_name": type(self).__name__,
            "modified": str(datetime.now()),
            "data_attributes": \
                [_.lstrip("_") for _ in self._data_attributes],
            "trained_attributes": \
                [_.lstrip("_") for _ in self._trained_attributes],
            "descriptive_attributes": \
                [_.lstrip("_") for _ in self._descriptive_attributes]
        }

        # Use the highest pickle protocol available.
        with open(filename, "wb") as fp:
            pickle.dump(contents, fp, -1)

        return None


    def load(self, filename, verify_training_data=False):
        """
        Load a saved model from disk.

        :param filename:
            The path where to load the model from.

        :param verify_training_data: [optional]
            If there is training data in the saved model, verify that its hash
            matches the hash stored alongside it. NOTE: if no training data
            was saved, no verification is currently performed.

        :raises AssertionError:
            If the file was saved by a model class with a different name.

        :raises ValueError:
            If verification is requested and the saved training data does not
            match the saved hash.
        """

        # SECURITY: unpickling can execute arbitrary code. Only load model
        # files that come from a trusted source.
        with open(filename, "rb") as fp:
            contents = pickle.load(fp)

        metadata = contents["metadata"]
        assert metadata["model_name"] == type(self).__name__, \
            "saved model is a {0}, not a {1}".format(
                metadata["model_name"], type(self).__name__)

        # If data exists, deal with that first. A model may declare no data
        # attributes at all, in which case there is nothing to look for.
        data_attributes = metadata["data_attributes"]
        has_data = len(data_attributes) > 0 and data_attributes[0] in contents
        if has_data:

            if verify_training_data:
                data_hash = utils.short_hash(contents[attr] \
                    for attr in data_attributes)
                if contents["training_set_hash"] is not None \
                and data_hash != contents["training_set_hash"]:
                    raise ValueError("expected hash for the training data is "
                                     "different to the actual data hash "
                                     "({0} != {1})".format(
                                        contents["training_set_hash"],
                                        data_hash))

            # Set the data attributes.
            for attribute in data_attributes:
                if attribute in contents:
                    setattr(self, "_{}".format(attribute), contents[attribute])

        # Set descriptive and trained attributes. These are written to the
        # private attributes directly, bypassing the property setters.
        self.reset()
        for attribute in metadata["descriptive_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        for attribute in metadata["trained_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        self._trained = True

        return None


    # Properties and attributes related to training, etc.
    @property
    @requires_model_description
    def labels_array(self):
        """
        Return an array containing just the training labels, given the label
        vector. This array does not account for any pivoting.
        """
        # Build one first-order row per label, drop the leading row of ones,
        # and transpose so that stars are along the first axis.
        return _build_label_vector_rows([[(label, 1)] for label in self.labels],
            self.training_labels)[1:].T


    @property
    @requires_model_description
    def label_vector_array(self):
        """
        Build the label vector array from the label vector, the training
        labels and the pivots. A warning is logged if the result contains
        non-finite values.
        """

        lva = _build_label_vector_rows(
            self.label_vector, self.training_labels, self.pivots)

        if not np.all(np.isfinite(lva)):
            logger.warning("Non-finite labels in the label vector array!")
        return lva


    # Residuals in labels in the training data set.
    @requires_training_wheels
    def get_training_label_residuals(self):
        """
        Return the residuals (model - training) between the parameters that the
        model returns for each star, and the training set value.
        """

        optimised_labels = self.fit(self.training_fluxes,
            self.training_flux_uncertainties, full_output=False)

        return optimised_labels - self.labels_array


    def _format_input_labels(self, args=None, **kwargs):
        """
        Format input labels either from a list or dictionary into a common form.

        :param args: [optional]
            Label values in the same order as `self.labels`. If not given, the
            keyword arguments are used as the labels instead.

        :returns:
            A dictionary keyed by label name. Values that are not a list,
            tuple or array are wrapped in a single-item list.
        """

        # We want labels in a dictionary.
        labels = kwargs if args is None else dict(zip(self.labels, args))
        return { k: (v if isinstance(v, (list, tuple, np.ndarray)) else [v]) \
            for k, v in labels.items() }


    @requires_model_description
    def cross_validate(self, pre_train=None, **kwargs):
        """
        Perform leave-one-out cross-validation on the training set.

        :param pre_train: [optional]
            A callable that is called as `pre_train(self, model)` before each
            leave-one-out model is trained.

        :param N: [optional]
            Only cross-validate the first `N` stars of the training set.

        :param debug: [optional]
            Re-raise any exception that occurs when fitting the left-out star,
            instead of only logging it.

        All other keyword arguments are passed to the constructor of each
        leave-one-out model.

        :returns:
            An array of inferred labels with one row per cross-validated star.
            Rows for stars whose fit failed are left as NaN.
        """

        inferred = np.nan * np.ones_like(self.labels_array)
        N_training_set, N_labels = inferred.shape
        N_stop_at = kwargs.pop("N", N_training_set)

        debug = kwargs.pop("debug", False)

        kwds = { "threads": self.threads }
        kwds.update(kwargs)

        # A pool handed in by the caller is theirs to manage. A pool that a
        # leave-one-out model creates for itself is closed here once that
        # model is finished with, so worker processes do not pile up.
        caller_supplied_pool = bool(kwds.get("pool"))

        # Only the first N_stop_at rows are returned, so do not train models
        # for any stars beyond that.
        for i in range(min(N_stop_at, N_training_set)):

            training_set = np.ones(N_training_set, dtype=bool)
            training_set[i] = False

            # Create a clean model to use so we don't overwrite self.
            model = self.__class__(
                self.training_labels[training_set],
                self.training_fluxes[training_set],
                self.training_flux_uncertainties[training_set],
                **kwds)

            try:
                # Initialise and run any pre-training function.
                for _attribute in self._descriptive_attributes:
                    setattr(model, _attribute[1:],
                        getattr(self, _attribute[1:]))

                if pre_train is not None:
                    pre_train(self, model)

                # Train and solve.
                model.train()

                try:
                    inferred[i, :] = model.fit(self.training_fluxes[i],
                        self.training_flux_uncertainties[i], full_output=False)

                except Exception:
                    # Best effort: log and leave this row as NaN. Interrupts
                    # (e.g. Ctrl-C) are deliberately not caught here.
                    logger.exception("Exception during cross-validation on "
                                     "object with index {0}:".format(i))
                    if debug: raise

            finally:
                model_pool = getattr(model, "pool", None)
                if model_pool is not None and not caller_supplied_pool:
                    model_pool.close()
                    model_pool.join()

        return inferred[:N_stop_at, :]
670
671
672
def _build_label_vector_rows(label_vector, training_labels, pivots=None):
673
    """
674
    Build a label vector row from a description of the label vector (as indices
675
    and orders to the power of) and the label values themselves.
676
677
    For example: if the first item of `labels` is `A`, and the label vector
678
    description is `A^3` then the first item of `label_vector` would be:
679
680
    `[[(0, 3)], ...`
681
682
    This indicates the first label item (index `0`) to the power `3`.
683
684
    :param label_vector:
685
        An `(index, order)` description of the label vector. 
686
687
    :param training_labels:
688
        The values of the corresponding training labels.
689
690
    :returns:
691
        The corresponding label vector row.
692
    """
693
694
    pivots = pivots or {}
695
696
    columns = [np.ones(len(training_labels), dtype=float)]
697
    for term in label_vector:
698
        column = 1.
699
        for label, order in term:
700
            column *= (np.array(training_labels[label]).flatten() \
701
                - pivots.get(label, 0))**order
702
        columns.append(column)
703
704
    try:
705
        return np.vstack(columns)
706
707
    except ValueError:
708
        columns[0] = np.ones(1)
709
        return np.vstack(columns)
710