| 1 |  |  | #!/usr/bin/env python | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | # -*- coding: utf-8 -*- | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | An abstract model class for The Cannon. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | from __future__ import (division, print_function, absolute_import, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  |                         unicode_literals) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | __all__ = [ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  |     "BaseCannonModel", "requires_training_wheels", "requires_model_description"] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | import logging | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  | import multiprocessing as mp | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  | from collections import OrderedDict | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  | from datetime import datetime | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  | from os import path | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  | from six.moves import cPickle as pickle | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  | from . import (utils, __version__ as code_version) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  | logger = logging.getLogger(__name__) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  | def requires_training_wheels(method): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |     A decorator for model methods that require training before being run. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |     def wrapper(model, *args, **kwargs): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |         if not model.is_trained: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |             raise TypeError("the model needs training first") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |         return method(model, *args, **kwargs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |     return wrapper | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  | def requires_model_description(method): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |     A decorator for model methods that require a full model description. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |     (That is, none of the _descriptive_attributes are None) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |     def wrapper(model, *args, **kwargs): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |         for descriptive_attribute in model._descriptive_attributes: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |             if getattr(model, descriptive_attribute) is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |                 raise TypeError("the model requires a {} term".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |                     descriptive_attribute.lstrip("_"))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |         return method(model, *args, **kwargs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |     return wrapper | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  | class BaseCannonModel(object): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |     An abstract Cannon model object that implements convenience functions. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |     :param labels: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |         A table with columns as labels, and stars as rows. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |     :type labels: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |         :class:`~astropy.table.Table` or numpy structured array | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |     :param fluxes: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |         An array of fluxes for stars in the training set, given as shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |         `(num_stars, num_pixels)`. The `num_stars` should match the number of | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         rows in `labels`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |     :type fluxes: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |         :class:`np.ndarray` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |     :param flux_uncertainties: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |         An array of 1-sigma flux uncertainties for stars in the training set, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |         The shape of the `flux_uncertainties` should match `fluxes`.  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |     :type flux_uncertainties: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |         :class:`np.ndarray` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |     :param dispersion: [optional] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |         The dispersion values corresponding to the given pixels. If provided,  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |         this should have length `num_pixels`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |     :param live_dangerously: [optional] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |         If enabled then no checks will be made on the label names, prohibiting | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |         the user from entering human-readable forms of the label vector. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |     _descriptive_attributes = ["_label_vector"] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |     _trained_attributes = ["_scatter", "_coefficients", "_pivots"] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |     _data_attributes = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |     _forbidden_label_characters = "^*" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |      | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |     def __init__(self, labels, fluxes, flux_uncertainties, dispersion=None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |         threads=1, pool=None, live_dangerously=False): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |         self._training_labels = labels | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |         self._training_fluxes = np.atleast_2d(fluxes) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |         self._training_flux_uncertainties = np.atleast_2d(flux_uncertainties) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |         self._dispersion = np.arange(fluxes.shape[1], dtype=int) \ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |             if dispersion is None else dispersion | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |          | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |         for attribute in self._descriptive_attributes: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |             setattr(self, attribute, None) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |          | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |         # The training data must be checked, but users can live dangerously if | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |         # they think they can correctly specify the label vector description. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |         self._verify_training_data() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |         if not live_dangerously: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |             self._verify_labels_available() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |         self.reset() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |         self.threads = threads | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |         self.pool = pool or mp.Pool(threads) if threads > 1 else None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |     def reset(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |         Clear any attributes that have been trained upon. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  |         self._trained = False | 
            
                                                                                                            
                            
            
                                    
            
            
                | 123 |  |  |         for attribute in self._trained_attributes: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 124 |  |  |             setattr(self, attribute, None) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 125 |  |  |          | 
            
                                                                                                            
                            
            
                                    
            
            
                | 126 |  |  |         return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |     def __str__(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  |         return "<{module}.{name} {trained}using a training set of {N} stars "\ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |                "with {K} available labels and {M} pixels each>".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  |                     module=self.__module__, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  |                     name=type(self).__name__, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  |                     trained="trained " if self.is_trained else "", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  |                     N=len(self.training_labels), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |                     K=len(self.labels_available), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 137 |  |  |                     M=len(self.dispersion)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 138 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 139 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  |     def __repr__(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  |         return "<{0}.{1} object at {2}>".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 142 |  |  |             self.__module__, type(self).__name__, hex(id(self))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 145 |  |  |     # Attributes related to the training data. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 146 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |     def dispersion(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 148 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 149 |  |  |         Return the dispersion points for all pixels. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 150 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  |         return self._dispersion | 
            
                                                                                                            
                            
            
                                    
            
            
                | 152 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 153 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 154 |  |  |     @dispersion.setter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 155 |  |  |     def dispersion(self, dispersion): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 157 |  |  |         Set the dispersion values for all the pixels. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 158 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 159 |  |  |         try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 160 |  |  |             len(dispersion) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 161 |  |  |         except TypeError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 162 |  |  |             raise TypeError("dispersion provided must be an array or list-like") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 163 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 164 |  |  |         if len(dispersion) != self.training_fluxes.shape[1]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |             raise ValueError("dispersion provided does not match the number " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 166 |  |  |                              "of pixels per star ({0} != {1})".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 167 |  |  |                                 len(dispersion), self.training_fluxes.shape[1])) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 168 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 169 |  |  |         dispersion = np.array(dispersion) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 170 |  |  |         if dispersion.dtype.kind not in "iuf": | 
            
                                                                                                            
                            
            
                                    
            
            
                | 171 |  |  |             raise ValueError("dispersion values are not float-like") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 172 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 173 |  |  |         if not np.all(np.isfinite(dispersion)): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 174 |  |  |             raise ValueError("dispersion values must be finite") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 175 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 176 |  |  |         self._dispersion = dispersion | 
            
                                                                                                            
                            
            
                                    
            
            
                | 177 |  |  |         return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 178 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 179 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 180 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 181 |  |  |     def training_labels(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 182 |  |  |         return self._training_labels | 
            
                                                                                                            
                            
            
                                    
            
            
                | 183 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 184 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 185 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 186 |  |  |     def training_fluxes(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 187 |  |  |         return self._training_fluxes | 
            
                                                                                                            
                            
            
                                    
            
            
                | 188 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 189 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 190 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 191 |  |  |     def training_flux_uncertainties(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 192 |  |  |         return self._training_flux_uncertainties | 
            
                                                                                                            
                            
            
                                    
            
            
                | 193 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 194 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 195 |  |  |     # Verifying the training data. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 196 |  |  |     def _verify_labels_available(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 197 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 198 |  |  |         Verify the label names provided do not include forbidden characters. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 199 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 200 |  |  |         if self._forbidden_label_characters is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 201 |  |  |             return True | 
            
                                                                                                            
                            
            
                                    
            
            
                | 202 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 203 |  |  |         for label in self.training_labels.dtype.names: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 204 |  |  |             for character in self._forbidden_label_characters: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 205 |  |  |                 if character in label: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 206 |  |  |                     raise ValueError( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 207 |  |  |                         "forbidden character '{char}' is in potential " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 208 |  |  |                         "label '{label}' - you can disable this verification " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 209 |  |  |                         "by enabling `live_dangerously`".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 210 |  |  |                             char=character, label=label)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 211 |  |  |         return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 212 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 214 |  |  |     def _verify_training_data(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 215 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 216 |  |  |         Verify the training data for the appropriate shape and content. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 217 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 218 |  |  |         if self.training_fluxes.shape != self.training_flux_uncertainties.shape: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 219 |  |  |             raise ValueError( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 220 |  |  |                 "the training flux and uncertainty arrays should " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 221 |  |  |                 "have the same shape") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 222 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 223 |  |  |         if len(self.training_labels) == 0 \ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 224 |  |  |         or self.training_labels.dtype.names is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 225 |  |  |             raise ValueError("no named labels provided for the training set") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 226 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 227 |  |  |         if len(self.training_labels) != self.training_fluxes.shape[0]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 228 |  |  |             raise ValueError( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 229 |  |  |                 "the first axes of the training flux array should " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 230 |  |  |                 "have the same shape as the nuber of rows in the label table " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 231 |  |  |                 "(N_stars, N_pixels)") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 232 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 233 |  |  |         if self.dispersion is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 234 |  |  |             dispersion = np.atleast_1d(self.dispersion).flatten() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 235 |  |  |             if dispersion.size != self.training_fluxes.shape[1]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 236 |  |  |                 raise ValueError( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 237 |  |  |                     "mis-match between the number of wavelength " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 238 |  |  |                     "points ({0}) and flux values ({1})".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 239 |  |  |                         self.training_fluxes.shape[1], dispersion.size)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 240 |  |  |         return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 241 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 242 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @property
                def is_trained(self):
                    """
                    Return the value of the private `_trained` attribute.

                    NOTE(review): presumably a boolean flag marking whether the model
                    has been trained -- confirm where `_trained` is assigned.
                    """
                    trained = self._trained
                    return trained
            
                                                                                                            
                            
            
                                    
            
            
                | 246 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 247 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 248 |  |  |     # Attributes related to the labels and the label vector description. | 
            
                                                                                                            
                            
            
                                    
            
            
                @property
                def labels_available(self):
                    """
                    The names of every label field available in the training set, taken
                    from the field names of the training labels' dtype.
                    """
                    label_dtype = self.training_labels.dtype
                    return label_dtype.names
            
                                                                                                            
                            
            
                                    
            
            
                | 255 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 256 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @property
                def label_vector(self):
                    """
                    The label vector for all pixels, as stored in the private
                    `_label_vector` attribute.
                    """
                    stored = self._label_vector
                    return stored
            
                                                                                                            
                            
            
                                    
            
            
                | 261 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 262 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @label_vector.setter
                def label_vector(self, label_vector_description):
                    """
                    Set a label vector.

                    :param label_vector_description:
                        A structured or human-readable version of the label vector
                        description. If `None`, the stored label vector is cleared
                        and `reset()` is called.

                    :raises ValueError:
                        If the parsed label vector refers to labels that are not
                        among the labels available in the training set.
                    """

                    if label_vector_description is None:
                        self._label_vector = None
                        self.reset()
                        return None

                    label_vector = utils.parse_label_vector(label_vector_description)

                    # Every label referenced by the parsed label vector must be present
                    # in the training labels.
                    missing = \
                        set(self._get_labels(label_vector)).difference(self.labels_available)
                    if missing:
                        # Sort the names: set iteration order is arbitrary, and a
                        # deterministic message is easier to read and to compare.
                        raise ValueError("the following labels parsed from the label "
                                         "vector description are missing in the training "
                                         "set of labels: {0}".format(
                                            ", ".join(sorted(missing))))

                    # Only store the label vector and call reset() when it is really a
                    # new description (or none has been set on this instance yet).
                    # NOTE(review): presumably reset() discards the trained state --
                    # confirm against its definition.
                    if not hasattr(self, "_label_vector") \
                    or label_vector != self._label_vector:
                        self._label_vector = label_vector
                        self.reset()

                    # One pivot per contributing label: the NaN-ignoring mean of that
                    # label over the training set.
                    self.pivots = np.array(
                        [np.nanmean(self.training_labels[label]) for label in self.labels])

                    return None
            
                                                                                                            
                            
            
                                    
            
            
                | 300 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 301 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @property
                def human_readable_label_vector(self):
                    """
                    The current label vector, rendered by
                    `utils.human_readable_label_vector`.
                    """
                    structured = self.label_vector
                    return utils.human_readable_label_vector(structured)
            
                                                                                                            
                            
            
                                    
            
            
                | 306 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 307 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @property
                def labels(self):
                    """
                    The labels that contribute to the current label vector, as
                    returned by `_get_labels` for `self.label_vector`.
                    """
                    current_vector = self.label_vector
                    return self._get_labels(current_vector)
            
                                                                                                            
                            
            
                                    
            
            
                | 312 |  |  |      | 
            
                                                                                                            
                            
            
                                    
            
            
                | 313 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 314 |  |  |     def _get_labels(self, label_vector): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 315 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 316 |  |  |         Return the labels that contribute to the structured label vector | 
            
                                                                                                            
                            
            
                                    
            
            
                | 317 |  |  |         provided. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 318 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 319 |  |  |         return () if label_vector is None else \ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 320 |  |  |             list(OrderedDict.fromkeys([label for term in label_vector \ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 321 |  |  |                 for label, power in term if power != 0])) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 322 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 323 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 324 |  |  |     def _get_lowest_order_label_indices(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 325 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 326 |  |  |         Get the indices for the lowest power label terms in the label vector. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 327 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 328 |  |  |         indices = OrderedDict() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 329 |  |  |         for i, term in enumerate(self.label_vector): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 330 |  |  |             if len(term) > 1: continue | 
            
                                                                                                            
                            
            
                                    
            
            
                | 331 |  |  |             label, order = term[0] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 332 |  |  |             if order < indices.get(label, [None, np.inf])[-1]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 333 |  |  |                 indices[label] = (i, order) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 334 |  |  |         return [indices.get(label, [None])[0] for label in self.labels] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 335 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 336 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 337 |  |  |     # Trained attributes that subclasses are likely to use. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 338 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 339 |  |  |     def coefficients(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 340 |  |  |         return self._coefficients | 
            
                                                                                                            
                            
            
                                    
            
            
                | 341 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 342 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 343 |  |  |     @coefficients.setter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 344 |  |  |     def coefficients(self, coefficients): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 345 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 346 |  |  |         Set the label vector coefficients for each pixel. This assumes a | 
            
                                                                                                            
                            
            
                                    
            
            
                | 347 |  |  |         'standard' model where the label vector is common to all pixels. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 348 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 349 |  |  |         :param coefficients: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 350 |  |  |             A 2-D array of coefficients of shape  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 351 |  |  |             (`N_pixels`, `N_label_vector_terms`). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 352 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 353 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 354 |  |  |         if coefficients is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 355 |  |  |             self._coefficients = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 356 |  |  |             return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 357 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 358 |  |  |         coefficients = np.atleast_2d(coefficients) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 359 |  |  |         if len(coefficients.shape) > 2: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 360 |  |  |             raise ValueError("coefficients must be a 2D array") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 361 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 362 |  |  |         P, Q = coefficients.shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 363 |  |  |         if P != len(self.dispersion): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 364 |  |  |             raise ValueError("axis 0 of coefficients array does not match the " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 365 |  |  |                              "number of pixels ({0} != {1})".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 366 |  |  |                                 P, len(self.dispersion))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 367 |  |  |         if Q != 1 + len(self.label_vector): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 368 |  |  |             raise ValueError("axis 1 of coefficients array does not match the " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 369 |  |  |                              "number of label vector terms ({0} != {1})".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 370 |  |  |                                 Q, 1 + len(self.label_vector))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 371 |  |  |         self._coefficients = coefficients | 
            
                                                                                                            
                            
            
                                    
            
            
                | 372 |  |  |         return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 373 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 374 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 375 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 376 |  |  |     def scatter(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 377 |  |  |         return self._scatter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 378 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 379 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 380 |  |  |     @scatter.setter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 381 |  |  |     def scatter(self, scatter): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 382 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 383 |  |  |         Set the scatter values for each pixel. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 384 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 385 |  |  |         :param scatter: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 386 |  |  |             A 1-D array of scatter terms. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 387 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 388 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 389 |  |  |         if scatter is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 390 |  |  |             self._scatter = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 391 |  |  |             return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 392 |  |  |          | 
            
                                                                                                            
                            
            
                                    
            
            
                | 393 |  |  |         scatter = np.array(scatter).flatten() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 394 |  |  |         if scatter.size != len(self.dispersion): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 395 |  |  |             raise ValueError("number of scatter values does not match " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 396 |  |  |                              "the number of pixels ({0} != {1})".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 397 |  |  |                                 scatter.size, len(self.dispersion))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 398 |  |  |         if np.any(scatter < 0): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 399 |  |  |             raise ValueError("scatter terms must be positive") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 400 |  |  |         self._scatter = scatter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 401 |  |  |         return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 402 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 403 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 404 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 405 |  |  |     def pivots(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 406 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 407 |  |  |         Return the mean values of the unique labels in the label vector. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 408 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 409 |  |  |         return self._pivots | 
            
                                                                                                            
                            
            
                                    
            
            
                | 410 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 411 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 412 |  |  |     @pivots.setter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 413 |  |  |     def pivots(self, pivots): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 414 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 415 |  |  |         Return the pivot values for each unique label in the label vector. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 416 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 417 |  |  |         :param pivots: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 418 |  |  |             A list of pivot values for the corresponding terms in `self.labels`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 419 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 420 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 421 |  |  |         if pivots is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 422 |  |  |             self._pivots = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 423 |  |  |             return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 424 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 425 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 426 |  |  |         if not isinstance(pivots, dict): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 427 |  |  |             raise TypeError("pivots must be a dictionary") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 428 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 429 |  |  |         missing = set(self.labels).difference(pivots) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 430 |  |  |         if any(missing): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 431 |  |  |             raise ValueError("pivot values for the following labels " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 432 |  |  |                              "are missing: {}".format(", ".join(list(missing)))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 433 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 434 |  |  |         if not np.all(np.isfinite(pivots.values())): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 435 |  |  |             raise ValueError("pivot values must be finite") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 436 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 437 |  |  |         self._pivots = pivots | 
            
                                                                                                            
                            
            
                                    
            
            
                | 438 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 439 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 440 |  |  |         pivots = np.array(pivots).flatten() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 441 |  |  |         N_labels = len(self.labels) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 442 |  |  |         if pivots.size != N_labels: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 443 |  |  |             raise ValueError("number of pivot values does not match the " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 444 |  |  |                              "number of unique labels in the label vector " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 445 |  |  |                              "({0} != {1})".format(pivots.size, N_labels)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 446 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 447 |  |  |         if not np.all(np.isfinite(pivots)): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 448 |  |  |             raise ValueError("pivot values must be finite") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 449 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 450 |  |  |         self._pivots = pivots | 
            
                                                                                                            
                            
            
                                    
            
            
                | 451 |  |  |         return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 452 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 453 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 454 |  |  |     # Methods which must be implemented or updated by the subclasses. | 
            
                                                                                                            
                            
            
                                    
            
            
                def pixel_label_vector(self, pixel_index):
                    """
                    Return the label vector to use for a given pixel.

                    This implementation does not use `pixel_index`: it returns the
                    `label_vector` attribute of the model as-is, so every pixel gets
                    the same label vector.

                    :param pixel_index:
                        The index of the pixel. Ignored by this implementation.

                    :returns:
                        The value of `self.label_vector`.
                    """
                    shared_label_vector = self.label_vector
                    return shared_label_vector
            
                                                                                                            
                            
            
                                    
            
            
                | 458 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 459 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def train(self, *args, **kwargs):
                    """
                    Placeholder for the training step.

                    Any positional or keyword arguments are accepted and ignored.

                    :raises NotImplementedError:
                        Always; this implementation does nothing else.
                    """
                    message = "The train method must be implemented by subclasses"
                    raise NotImplementedError(message)
            
                                                                                                            
                            
            
                                    
            
            
                | 463 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 464 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def predict(self, *args, **kwargs):
                    """
                    Placeholder for the prediction step.

                    Any positional or keyword arguments are accepted and ignored.

                    :raises NotImplementedError:
                        Always; this implementation does nothing else.
                    """
                    message = "The predict method must be implemented by subclasses"
                    raise NotImplementedError(message)
            
                                                                                                            
                            
            
                                    
            
            
                | 468 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 469 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def fit(self, *args, **kwargs):
                    """
                    Placeholder for the fitting step.

                    Any positional or keyword arguments are accepted and ignored.

                    :raises NotImplementedError:
                        Always; this implementation does nothing else.
                    """
                    message = "The fit method must be implemented by subclasses"
                    raise NotImplementedError(message)
            
                                                                                                            
                            
            
                                    
            
            
                | 473 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 474 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 475 |  |  |     # I/O | 
            
                                                                                                            
                            
            
                                    
            
            
                @requires_training_wheels
                def save(self, filename, include_training_data=False, overwrite=False):
                    """
                    Serialise the trained model and save it to disk. This will save all
                    relevant training attributes, and optionally, the training data.

                    The pickled object is an ordered dictionary holding the descriptive
                    and trained attributes (keyed by their names with leading
                    underscores removed), a `training_set_hash` entry computed from the
                    data attributes, optionally the data attributes themselves, and a
                    `metadata` dictionary (code version, model class name, modification
                    time, and the three lists of stored attribute names).

                    :param filename:
                        The path to save the model to.

                    :param include_training_data: [optional]
                        Save the training data (labels, fluxes, uncertainties) used to train
                        the model.

                    :param overwrite: [optional]
                        Overwrite the existing file path, if it already exists.

                    :raises IOError:
                        If `filename` already exists and `overwrite` is not set.

                    :raises ValueError:
                        If any attribute listed in the `_*_attributes` of the class would
                        be stored under the protected key 'metadata'.
                    """

                    if path.exists(filename) and not overwrite:
                        raise IOError("filename already exists: {0}".format(filename))

                    # Coerce once so that tuples and lists can be mixed freely below.
                    descriptive_attributes = list(self._descriptive_attributes)
                    trained_attributes = list(self._trained_attributes)
                    data_attributes = list(self._data_attributes)

                    def as_key(attr):
                        # The name an attribute is stored under in the saved dictionary.
                        return attr.lstrip("_")

                    # Values are stored under their underscore-stripped names, so the
                    # protected-name check has to look at the stripped names as well.
                    # Otherwise an attribute like '_metadata' would pass the check and
                    # then be silently overwritten by the metadata dictionary below.
                    attributes = descriptive_attributes + trained_attributes \
                        + data_attributes
                    if "metadata" in [as_key(attr) for attr in attributes]:
                        raise ValueError("'metadata' is a protected attribute and cannot "
                                         "be used in the _*_attributes in a class")

                    # Store up all the trained attributes and a hash of the training set.
                    contents = OrderedDict([
                        (as_key(attr), getattr(self, attr))
                        for attr in descriptive_attributes + trained_attributes])
                    contents["training_set_hash"] = utils.short_hash(
                        getattr(self, attr) for attr in data_attributes)

                    if include_training_data:
                        contents.update(
                            [(as_key(attr), getattr(self, attr))
                             for attr in data_attributes])

                    contents["metadata"] = {
                        "version": code_version,
                        "model_name": type(self).__name__,
                        "modified": str(datetime.now()),
                        "data_attributes":
                            [as_key(attr) for attr in data_attributes],
                        "trained_attributes":
                            [as_key(attr) for attr in trained_attributes],
                        "descriptive_attributes":
                            [as_key(attr) for attr in descriptive_attributes]
                    }

                    # Protocol -1 selects the highest pickle protocol available.
                    with open(filename, "wb") as fp:
                        pickle.dump(contents, fp, -1)

                    return None
            
                                                                                                            
                            
            
                                    
            
            
                | 530 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 531 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 532 |  |  |     def load(self, filename, verify_training_data=False): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 533 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 534 |  |  |         Load a saved model from disk. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 535 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 536 |  |  |         :param filename: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 537 |  |  |             The path where to load the model from. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 538 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 539 |  |  |         :param verify_training_data: [optional] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 540 |  |  |             If there is training data in the saved model, verify its contents. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 541 |  |  |             Otherwise if no training data is saved, verify that the data used | 
            
                                                                                                            
                            
            
                                    
            
            
                | 542 |  |  |             to train the model is the same data provided when this model was | 
            
                                                                                                            
                            
            
                                    
            
            
                | 543 |  |  |             instantiated. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 544 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 545 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 546 |  |  |         with open(filename, "rb") as fp: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 547 |  |  |             contents = pickle.load(fp) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 548 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 549 |  |  |         assert contents["metadata"]["model_name"] == type(self).__name__ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 550 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 551 |  |  |         # If data exists, deal with that first. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 552 |  |  |         has_data = (contents["metadata"]["data_attributes"][0] in contents) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 553 |  |  |         if has_data: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 554 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 555 |  |  |             if verify_training_data: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 556 |  |  |                 data_hash = utils.short_hash(contents[attr] \ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 557 |  |  |                     for attr in contents["metadata"]["data_attributes"]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 558 |  |  |                 if contents["training_set_hash"] is not None \ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 559 |  |  |                 and data_hash != contents["training_set_hash"]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 560 |  |  |                     raise ValueError("expected hash for the training data is " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 561 |  |  |                                      "different to the actual data hash " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 562 |  |  |                                      "({0} != {1})".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 563 |  |  |                                         contents["training_set_hash"], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 564 |  |  |                                         data_hash)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 565 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 566 |  |  |             # Set the data attributes. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 567 |  |  |             for attribute in contents["metadata"]["data_attributes"]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 568 |  |  |                 if attribute in contents: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 569 |  |  |                     setattr(self, "_{}".format(attribute), contents[attribute]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 570 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 571 |  |  |         # Set descriptive and trained attributes. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 572 |  |  |         self.reset() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 573 |  |  |         for attribute in contents["metadata"]["descriptive_attributes"]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 574 |  |  |             setattr(self, "_{}".format(attribute), contents[attribute]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 575 |  |  |         for attribute in contents["metadata"]["trained_attributes"]: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 576 |  |  |             setattr(self, "_{}".format(attribute), contents[attribute]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 577 |  |  |         self._trained = True | 
            
                                                                                                            
                            
            
                                    
            
            
                | 578 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 579 |  |  |         return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 580 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 581 |  |  |  | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 582 |  |  |     # Properties and attributes related to training, etc. | 
            
                                                                        
                            
            
                                    
            
            
                | 583 |  |  |     @property | 
            
                                                                        
                            
            
                                    
            
            
                | 584 |  |  |     @requires_model_description | 
            
                                                                        
                            
            
                                    
            
            
                | 585 |  |  |     def labels_array(self): | 
            
                                                                        
                            
            
                                    
            
            
                | 586 |  |  |         """ | 
            
                                                                        
                            
            
                                    
            
            
                | 587 |  |  |         Return an array containing just the training labels, given the label | 
            
                                                                        
                            
            
                                    
            
            
                | 588 |  |  |         vector. This array does not account for any pivoting. | 
            
                                                                        
                            
            
                                    
            
            
                | 589 |  |  |         """ | 
            
                                                                        
                            
            
                                    
            
            
                | 590 |  |  |         return _build_label_vector_rows([[(label, 1)] for label in self.labels],  | 
            
                                                                        
                            
            
                                    
            
            
                | 591 |  |  |             self.training_labels)[1:].T | 
            
                                                                                                            
                            
            
                                    
            
            
                | 592 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 593 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 594 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 595 |  |  |     @requires_model_description | 
            
                                                                                                            
                            
            
                                    
            
            
                | 596 |  |  |     def label_vector_array(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 597 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 598 |  |  |         Build the label vector array. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 599 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 600 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 601 |  |  |         lva = _build_label_vector_rows(self.label_vector, self.training_labels,  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 602 |  |  |             dict(zip(self.labels, self.pivots))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 603 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 604 |  |  |         if not np.all(np.isfinite(lva)): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 605 |  |  |             logger.warn("Non-finite labels in the label vector array!") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 606 |  |  |         return lva | 
            
                                                                                                            
                            
            
                                    
            
            
                | 607 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 608 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 609 |  |  |     # Residuals in labels in the training data set. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 610 |  |  |     @requires_training_wheels | 
            
                                                                                                            
                            
            
                                    
            
            
                | 611 |  |  |     def get_training_label_residuals(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 612 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 613 |  |  |         Return the residuals (model - training) between the parameters that the | 
            
                                                                                                            
                            
            
                                    
            
            
                | 614 |  |  |         model returns for each star, and the training set value. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 615 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 616 |  |  |          | 
            
                                                                                                            
                            
            
                                    
            
            
                | 617 |  |  |         optimised_labels = self.fit(self.training_fluxes, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 618 |  |  |             self.training_flux_uncertainties, full_output=False) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 619 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 620 |  |  |         return optimised_labels - self.labels_array | 
            
                                                                                                            
                            
            
                                    
            
            
                | 621 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 622 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 623 |  |  |     def _format_input_labels(self, args=None, **kwargs): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 624 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 625 |  |  |         Format input labels either from a list or dictionary into a common form. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 626 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 627 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 628 |  |  |         # We want labels in a dictionary. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 629 |  |  |         labels = kwargs if args is None else dict(zip(self.labels, args)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 630 |  |  |         return { k: (v if isinstance(v, (list, tuple, np.ndarray)) else [v]) \ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 631 |  |  |             for k, v in labels.items() } | 
            
                                                                                                            
                            
            
                                    
            
            
                | 632 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 633 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 634 |  |  |     @requires_model_description | 
            
                                                                                                            
                            
            
                                    
            
            
                | 635 |  |  |     def cross_validate(self, pre_train=None, **kwargs): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 636 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 637 |  |  |         Perform leave-one-out cross-validation on the training set. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 638 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 639 |  |  |          | 
            
                                                                                                            
                            
            
                                    
            
            
                | 640 |  |  |         inferred = np.nan * np.ones_like(self.labels_array) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 641 |  |  |         N_training_set, N_labels = inferred.shape | 
            
                                                                                                            
                            
            
                                    
            
            
                | 642 |  |  |         N_stop_at = kwargs.pop("N", N_training_set) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 643 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 644 |  |  |         debug = kwargs.pop("debug", False) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 645 |  |  |          | 
            
                                                                                                            
                            
            
                                    
            
            
                | 646 |  |  |         kwds = { "threads": self.threads } | 
            
                                                                                                            
                            
            
                                    
            
            
                | 647 |  |  |         kwds.update(kwargs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 648 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 649 |  |  |         for i in range(N_training_set): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 650 |  |  |              | 
            
                                                                                                            
                            
            
                                    
            
            
                | 651 |  |  |             training_set = np.ones(N_training_set, dtype=bool) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 652 |  |  |             training_set[i] = False | 
            
                                                                                                            
                            
            
                                    
            
            
                | 653 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 654 |  |  |             # Create a clean model to use so we don't overwrite self. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 655 |  |  |             model = self.__class__( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 656 |  |  |                 self.training_labels[training_set], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 657 |  |  |                 self.training_fluxes[training_set], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 658 |  |  |                 self.training_flux_uncertainties[training_set], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 659 |  |  |                 **kwds) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 660 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 661 |  |  |             # Initialise and run any pre-training function. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 662 |  |  |             for _attribute in self._descriptive_attributes: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 663 |  |  |                 setattr(model, _attribute[1:], getattr(self, _attribute[1:])) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 664 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 665 |  |  |             if pre_train is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 666 |  |  |                 pre_train(self, model) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 667 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 668 |  |  |             # Train and solve. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 669 |  |  |             model.train() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 670 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 671 |  |  |             try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 672 |  |  |                 inferred[i, :] = model.fit(self.training_fluxes[i], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 673 |  |  |                     self.training_flux_uncertainties[i], full_output=False) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 674 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 675 |  |  |             except: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 676 |  |  |                 logger.exception("Exception during cross-validation on object " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 677 |  |  |                                  "with index {0}:".format(i)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 678 |  |  |                 if debug: raise | 
            
                                                                                                            
                            
            
                                    
            
            
                | 679 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 680 |  |  |             if i == N_stop_at + 1: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 681 |  |  |                 break | 
            
                                                                                                            
                            
            
                                    
            
            
                | 682 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 683 |  |  |         return inferred[:N_stop_at, :] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 684 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 685 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @requires_training_wheels
                def define_continuum_mask(self, baseline_flux=None, tolerances=None,
                    percentiles=None, absolute_percentiles=None):
                    """
                    Define a continuum mask based on constraints on the baseline flux
                    values, on the values of individual theta coefficients, and on the
                    percentiles or absolute percentiles of theta coefficients. The
                    resulting continuum mask is `True` only for pixels that meet all of
                    the given constraints.

                    Every constraint is a pair of bounds. The bounds are inclusive, and
                    the order in which the two values are given does not matter because
                    the smaller one is always used as the lower bound.

                    :param baseline_flux: [optional]
                        The `(lower, upper)` range of acceptable baseline flux values
                        (the first column of the model coefficients) to be considered
                        as continuum.

                    :param tolerances: [optional]
                        A dictionary containing the label vector description as keys and
                        acceptable `(lower, upper)` ranges for the value of the
                        coefficient of each corresponding label vector term.

                    :param percentiles: [optional]
                        A dictionary containing the label vector description as keys and
                        acceptable percentile ranges `(lower, upper)` for each
                        corresponding label vector term.

                    :param absolute_percentiles: [optional]
                        The same as `percentiles`, except these are calculated on the
                        absolute values of the model coefficients.

                    :returns:
                        A boolean array with the same shape as the model dispersion,
                        where `True` marks a pixel that satisfies every constraint.

                    :raises ValueError:
                        If any supplied constraint does not have exactly two entries.
                    """

                    def _coefficients_for(term):
                        # Resolve a human-readable term to its column of coefficients.
                        # Returns None (after warning) when the term is not part of the
                        # label vector so that the caller can skip that constraint.
                        p_term = utils.parse_label_vector(term)[0]
                        if p_term not in self.label_vector:
                            logger.warning("Term {0} ({1}) is not in the label vector, "
                                           "and is therefore being ignored".format(
                                               term, p_term))
                            return None

                        # Column zero holds the baseline flux, so the label vector terms
                        # are offset by one.
                        return self.coefficients[:, 1 + self.label_vector.index(p_term)]

                    mask = np.ones_like(self.dispersion, dtype=bool)

                    if baseline_flux is not None:
                        if len(baseline_flux) != 2:
                            raise ValueError("baseline flux constraints must be given as "
                                             "(lower, upper)")
                        baseline = self.coefficients[:, 0]
                        mask &= (max(baseline_flux) >= baseline) \
                              & (baseline >= min(baseline_flux))

                    # Constraints on the coefficient values themselves.
                    for term, constraints in (tolerances or {}).items():
                        if len(constraints) != 2:
                            raise ValueError("{} tolerance must be given as (lower, upper)"\
                                .format(term))

                        a = _coefficients_for(term)
                        if a is None:
                            continue

                        mask &= (max(constraints) >= a) & (a >= min(constraints))

                    # Constraints on the percentiles of the (absolute) coefficient values.
                    for qs, use_abs in ((percentiles, False), (absolute_percentiles, True)):
                        if qs is None:
                            continue

                        for term, constraints in qs.items():
                            if len(constraints) != 2:
                                raise ValueError("{} constraints must be given as "
                                                 "(lower, upper)".format(term))

                            a = _coefficients_for(term)
                            if a is None:
                                continue

                            if use_abs:
                                a = np.abs(a)

                            # Convert the percentile bounds into coefficient values.
                            p = np.percentile(a, constraints)
                            mask &= (max(p) >= a) & (a >= min(p))

                    return mask
            
                                                                                                            
                            
            
                                    
            
            
                | 754 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 755 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 756 |  |  | def _build_label_vector_rows(label_vector, training_labels, pivots=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 757 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 758 |  |  |     Build a label vector row from a description of the label vector (as indices | 
            
                                                                                                            
                            
            
                                    
            
            
                | 759 |  |  |     and orders to the power of) and the label values themselves. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 760 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 761 |  |  |     For example: if the first item of `labels` is `A`, and the label vector | 
            
                                                                                                            
                            
            
                                    
            
            
                | 762 |  |  |     description is `A^3` then the first item of `label_vector` would be: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 763 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 764 |  |  |     `[[(0, 3)], ...` | 
            
                                                                                                            
                            
            
                                    
            
            
                | 765 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 766 |  |  |     This indicates the first label item (index `0`) to the power `3`. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 767 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 768 |  |  |     :param label_vector: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 769 |  |  |         An `(index, order)` description of the label vector.  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 770 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 771 |  |  |     :param training_labels: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 772 |  |  |         The values of the corresponding training labels. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 773 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 774 |  |  |     :returns: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 775 |  |  |         The corresponding label vector row. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 776 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 777 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 778 |  |  |     pivots = pivots or {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 779 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 780 |  |  |     columns = [np.ones(len(training_labels), dtype=float)] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 781 |  |  |     for term in label_vector: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 782 |  |  |         column = 1. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 783 |  |  |         for label, order in term: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 784 |  |  |             column *= (np.array(training_labels[label]).flatten() \ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 785 |  |  |                 - pivots.get(label, 0))**order | 
            
                                                                                                            
                            
            
                                    
            
            
                | 786 |  |  |         columns.append(column) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 787 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 788 |  |  |     try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 789 |  |  |         return np.vstack(columns) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 790 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 791 |  |  |     except ValueError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 792 |  |  |         columns[0] = np.ones(1) | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 793 |  |  |         return np.vstack(columns) | 
            
                                                        
            
                                    
            
            
                | 794 |  |  |  |