| Metric | Value |
| --- | --- |
| Total Complexity | 124 |
| Total Lines | 699 |
| Duplicated Lines | 0 % |
Complex classes like AnniesLasso.BaseCannonModel often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring, as sketched below. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often the faster option.
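For example, a minimal Extract Class sketch (the `TrainingSet` name and its interface are hypothetical, not part of AnniesLasso) could pull the `training_*` fields and their shape checks out of `BaseCannonModel`:

```python
import numpy as np

class TrainingSet(object):
    """Hypothetical extracted class: owns the training data and its checks."""

    def __init__(self, labels, fluxes, flux_uncertainties):
        self.labels = labels
        self.fluxes = np.atleast_2d(fluxes)
        self.flux_uncertainties = np.atleast_2d(flux_uncertainties)
        self.verify()

    def verify(self):
        # Consolidates the shape checks from BaseCannonModel._verify_training_data.
        if self.fluxes.shape != self.flux_uncertainties.shape:
            raise ValueError("flux and uncertainty arrays must share a shape")
        if len(self.labels) != self.fluxes.shape[0]:
            raise ValueError("one row of labels is required per star")

class BaseCannonModel(object):
    def __init__(self, labels, fluxes, flux_uncertainties, **kwargs):
        # The model now delegates data ownership to the extracted component.
        self.training_set = TrainingSet(labels, fluxes, flux_uncertainties)
```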
```python
#!/usr/bin/env python

# (Lines 2-54 are not shown in this report; the names used below, such as
# np, mp, utils, and logger, are presumably imported there.)

class BaseCannonModel(object):
    """
    An abstract Cannon model object that implements convenience functions.

    :param labels:
        A table with columns as labels, and stars as rows.

    :type labels:
        :class:`~astropy.table.Table` or numpy structured array

    :param fluxes:
        An array of fluxes for stars in the training set, given as shape
        `(num_stars, num_pixels)`. The `num_stars` should match the number of
        rows in `labels`.

    :type fluxes:
        :class:`np.ndarray`

    :param flux_uncertainties:
        An array of 1-sigma flux uncertainties for stars in the training set.
        The shape of `flux_uncertainties` should match that of `fluxes`.

    :type flux_uncertainties:
        :class:`np.ndarray`

    :param dispersion: [optional]
        The dispersion values corresponding to the given pixels. If provided,
        this should have length `num_pixels`.

    :param live_dangerously: [optional]
        If enabled, no checks will be made on the label names. Label names
        containing forbidden characters will then prevent the user from
        specifying the label vector in a human-readable form.
    """

    _descriptive_attributes = ["_label_vector"]
    _trained_attributes = ["_scatter", "_coefficients", "_pivots"]
    _data_attributes = []
    _forbidden_label_characters = "^*"

    def __init__(self, labels, fluxes, flux_uncertainties, dispersion=None,
        threads=1, pool=None, live_dangerously=False):

        self._training_labels = labels
        self._training_fluxes = np.atleast_2d(fluxes)
        self._training_flux_uncertainties = np.atleast_2d(flux_uncertainties)
        self._dispersion = np.arange(fluxes.shape[1], dtype=int) \
            if dispersion is None else dispersion

        for attribute in self._descriptive_attributes:
            setattr(self, attribute, None)

        # The training data must be checked, but users can live dangerously if
        # they think they can correctly specify the label vector description.
        self._verify_training_data()
        if not live_dangerously:
            self._verify_labels_available()

        self.reset()
        self.threads = threads
        self.pool = pool or mp.Pool(threads) if threads > 1 else None

    def reset(self):
        """
        Clear any attributes that have been trained upon.
        """

        self._trained = False
        for attribute in self._trained_attributes:
            setattr(self, attribute, None)

        return None

    def __str__(self):
        return "<{module}.{name} {trained}using a training set of {N} stars "\
               "with {K} available labels and {M} pixels each>".format(
                   module=self.__module__,
                   name=type(self).__name__,
                   trained="trained " if self.is_trained else "",
                   N=len(self.training_labels),
                   K=len(self.labels_available),
                   M=len(self.dispersion))

    def __repr__(self):
        return "<{0}.{1} object at {2}>".format(
            self.__module__, type(self).__name__, hex(id(self)))

    # Attributes related to the training data.
    @property
    def dispersion(self):
        """
        Return the dispersion points for all pixels.
        """
        return self._dispersion

    @dispersion.setter
    def dispersion(self, dispersion):
        """
        Set the dispersion values for all the pixels.
        """
        try:
            len(dispersion)
        except TypeError:
            raise TypeError("dispersion provided must be an array or list-like")

        if len(dispersion) != self.training_fluxes.shape[1]:
            raise ValueError("dispersion provided does not match the number "
                             "of pixels per star ({0} != {1})".format(
                                 len(dispersion), self.training_fluxes.shape[1]))

        dispersion = np.array(dispersion)
        if dispersion.dtype.kind not in "iuf":
            raise ValueError("dispersion values are not float-like")

        if not np.all(np.isfinite(dispersion)):
            raise ValueError("dispersion values must be finite")

        self._dispersion = dispersion
        return None

    @property
    def training_labels(self):
        return self._training_labels

    @property
    def training_fluxes(self):
        return self._training_fluxes

    @property
    def training_flux_uncertainties(self):
        return self._training_flux_uncertainties

    # Verifying the training data.
    def _verify_labels_available(self):
        """
        Verify that the label names provided do not include forbidden
        characters.
        """
        if self._forbidden_label_characters is None:
            return True

        for label in self.training_labels.dtype.names:
            for character in self._forbidden_label_characters:
                if character in label:
                    raise ValueError(
                        "forbidden character '{char}' is in potential "
                        "label '{label}' - you can disable this verification "
                        "by enabling `live_dangerously`".format(
                            char=character, label=label))
        return None

    def _verify_training_data(self):
        """
        Verify that the training data have the appropriate shape and content.
        """
        if self.training_fluxes.shape != self.training_flux_uncertainties.shape:
            raise ValueError(
                "the training flux and uncertainty arrays should "
                "have the same shape")

        if len(self.training_labels) == 0 \
        or self.training_labels.dtype.names is None:
            raise ValueError("no named labels provided for the training set")

        if len(self.training_labels) != self.training_fluxes.shape[0]:
            raise ValueError(
                "the first axis of the training flux array should match the "
                "number of rows in the label table (N_stars, N_pixels)")

        if self.dispersion is not None:
            dispersion = np.atleast_1d(self.dispersion).flatten()
            if dispersion.size != self.training_fluxes.shape[1]:
                raise ValueError(
                    "mis-match between the number of wavelength "
                    "points ({0}) and flux values ({1})".format(
                        self.training_fluxes.shape[1], dispersion.size))
        return None

    @property
    def is_trained(self):
        return self._trained

    # Attributes related to the labels and the label vector description.
    @property
    def labels_available(self):
        """
        All of the available labels for each star in the training set.
        """
        return self.training_labels.dtype.names

    @property
    def label_vector(self):
        """ The label vector for all pixels. """
        return self._label_vector

    @label_vector.setter
    def label_vector(self, label_vector_description):
        """
        Set a label vector.

        :param label_vector_description:
            A structured or human-readable version of the label vector
            description.
        """

        if label_vector_description is None:
            self._label_vector = None
            self.reset()
            return None

        label_vector = utils.parse_label_vector(label_vector_description)

        # Verify that the labels listed in the label vector are present in
        # the training labels.
        missing = \
            set(self._get_labels(label_vector)).difference(self.labels_available)
        if missing:
            raise ValueError("the following labels parsed from the label "
                             "vector description are missing in the training "
                             "set of labels: {0}".format(", ".join(missing)))

        # If this is really a new label vector description,
        # then we are no longer trained.
        if not hasattr(self, "_label_vector") \
        or label_vector != self._label_vector:
            self._label_vector = label_vector
            self.reset()

        self.pivots = \
            np.array([np.nanmean(self.training_labels[l]) for l in self.labels])

        return None

    @property
    def human_readable_label_vector(self):
        """ Return a human-readable form of the label vector. """
        return utils.human_readable_label_vector(self.label_vector)

    @property
    def labels(self):
        """ The labels that contribute to the label vector. """
        return self._get_labels(self.label_vector)

    def _get_labels(self, label_vector):
        """
        Return the labels that contribute to the structured label vector
        provided.
        """
        return () if label_vector is None else \
            list(OrderedDict.fromkeys([label for term in label_vector \
                for label, power in term if power != 0]))

    def _get_lowest_order_label_indices(self):
        """
        Get the indices of the lowest-power label terms in the label vector.
        """
        indices = OrderedDict()
        for i, term in enumerate(self.label_vector):
            if len(term) > 1: continue
            label, order = term[0]
            if order < indices.get(label, [None, np.inf])[-1]:
                indices[label] = (i, order)
        return [indices.get(label, [None])[0] for label in self.labels]

    # Trained attributes that subclasses are likely to use.
    @property
    def coefficients(self):
        return self._coefficients

    @coefficients.setter
    def coefficients(self, coefficients):
        """
        Set the label vector coefficients for each pixel. This assumes a
        'standard' model where the label vector is common to all pixels.

        :param coefficients:
            A 2-D array of coefficients of shape
            (`N_pixels`, `N_label_vector_terms`).
        """

        if coefficients is None:
            self._coefficients = None
            return None

        coefficients = np.atleast_2d(coefficients)
        if len(coefficients.shape) > 2:
            raise ValueError("coefficients must be a 2D array")

        P, Q = coefficients.shape
        if P != len(self.dispersion):
            raise ValueError("axis 0 of coefficients array does not match the "
                             "number of pixels ({0} != {1})".format(
                                 P, len(self.dispersion)))
        if Q != 1 + len(self.label_vector):
            raise ValueError("axis 1 of coefficients array does not match the "
                             "number of label vector terms ({0} != {1})".format(
                                 Q, 1 + len(self.label_vector)))
        self._coefficients = coefficients
        return None

    @property
    def scatter(self):
        return self._scatter

    @scatter.setter
    def scatter(self, scatter):
        """
        Set the scatter values for each pixel.

        :param scatter:
            A 1-D array of scatter terms.
        """

        if scatter is None:
            self._scatter = None
            return None

        scatter = np.array(scatter).flatten()
        if scatter.size != len(self.dispersion):
            raise ValueError("number of scatter values does not match "
                             "the number of pixels ({0} != {1})".format(
                                 scatter.size, len(self.dispersion)))
        if np.any(scatter < 0):
            raise ValueError("scatter terms must be non-negative")
        self._scatter = scatter
        return None

    @property
    def pivots(self):
        """
        Return the mean values of the unique labels in the label vector.
        """
        return self._pivots

    @pivots.setter
    def pivots(self, pivots):
        """
        Set the pivot values for each unique label in the label vector.

        :param pivots:
            A list of pivot values for the corresponding terms in `self.labels`.
        """

        if pivots is None:
            self._pivots = None
            return None

        # The following triple-quoted block is a bare string literal and is
        # never executed (a dict-based validation left disabled in the source).
        """
        if not isinstance(pivots, dict):
            raise TypeError("pivots must be a dictionary")

        missing = set(self.labels).difference(pivots)
        if any(missing):
            raise ValueError("pivot values for the following labels "
                             "are missing: {}".format(", ".join(list(missing))))

        if not np.all(np.isfinite(pivots.values())):
            raise ValueError("pivot values must be finite")

        self._pivots = pivots
        """

        pivots = np.array(pivots).flatten()
        N_labels = len(self.labels)
        if pivots.size != N_labels:
            raise ValueError("number of pivot values does not match the "
                             "number of unique labels in the label vector "
                             "({0} != {1})".format(pivots.size, N_labels))

        if not np.all(np.isfinite(pivots)):
            raise ValueError("pivot values must be finite")

        self._pivots = pivots
        return None

    # Methods which must be implemented or updated by the subclasses.
    def pixel_label_vector(self, pixel_index):
        """ The label vector for a given pixel. """
        return self.label_vector

    def train(self, *args, **kwargs):
        raise NotImplementedError("The train method must be "
                                  "implemented by subclasses")

    def predict(self, *args, **kwargs):
        raise NotImplementedError("The predict method must be "
                                  "implemented by subclasses")

    def fit(self, *args, **kwargs):
        raise NotImplementedError("The fit method must be "
                                  "implemented by subclasses")

    # I/O
    @requires_training_wheels
    def save(self, filename, include_training_data=False, overwrite=False):
        """
        Serialise the trained model and save it to disk. This will save all
        relevant training attributes, and optionally, the training data.

        :param filename:
            The path to save the model to.

        :param include_training_data: [optional]
            Save the training data (labels, fluxes, uncertainties) used to
            train the model.

        :param overwrite: [optional]
            Overwrite the file if it already exists.
        """

        if path.exists(filename) and not overwrite:
            raise IOError("filename already exists: {0}".format(filename))

        attributes = list(self._descriptive_attributes) \
            + list(self._trained_attributes) \
            + list(self._data_attributes)
        if "metadata" in attributes:
            raise ValueError("'metadata' is a protected attribute and cannot "
                             "be used in the _*_attributes in a class")

        # Store up all the trained attributes and a hash of the training set.
        contents = OrderedDict([
            (attr.lstrip("_"), getattr(self, attr)) for attr in \
            (self._descriptive_attributes + self._trained_attributes)])
        contents["training_set_hash"] = utils.short_hash(getattr(self, attr) \
            for attr in self._data_attributes)

        if include_training_data:
            contents.update([(attr.lstrip("_"), getattr(self, attr)) \
                for attr in self._data_attributes])

        contents["metadata"] = {
            "version": code_version,
            "model_name": type(self).__name__,
            "modified": str(datetime.now()),
            "data_attributes":
                [_.lstrip("_") for _ in self._data_attributes],
            "trained_attributes":
                [_.lstrip("_") for _ in self._trained_attributes],
            "descriptive_attributes":
                [_.lstrip("_") for _ in self._descriptive_attributes]
        }

        with open(filename, "wb") as fp:
            pickle.dump(contents, fp, -1)

        return None

    def load(self, filename, verify_training_data=False):
        """
        Load a saved model from disk.

        :param filename:
            The path to load the model from.

        :param verify_training_data: [optional]
            If there is training data in the saved model, verify its contents.
            Otherwise, if no training data is saved, verify that the data used
            to train the model is the same data provided when this model was
            instantiated.
        """

        with open(filename, "rb") as fp:
            contents = pickle.load(fp)

        assert contents["metadata"]["model_name"] == type(self).__name__

        # If data exists, deal with that first.
        has_data = (contents["metadata"]["data_attributes"][0] in contents)
        if has_data:

            if verify_training_data:
                data_hash = utils.short_hash(contents[attr] \
                    for attr in contents["metadata"]["data_attributes"])
                if contents["training_set_hash"] is not None \
                and data_hash != contents["training_set_hash"]:
                    raise ValueError("expected hash for the training data is "
                                     "different to the actual data hash "
                                     "({0} != {1})".format(
                                         contents["training_set_hash"],
                                         data_hash))

            # Set the data attributes.
            for attribute in contents["metadata"]["data_attributes"]:
                if attribute in contents:
                    setattr(self, "_{}".format(attribute), contents[attribute])

        # Set descriptive and trained attributes.
        self.reset()
        for attribute in contents["metadata"]["descriptive_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        for attribute in contents["metadata"]["trained_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        self._trained = True

        return None

    # Properties and attributes related to training, etc.
    @property
    @requires_model_description
    def labels_array(self):
        """
        Return an array containing just the training labels, given the label
        vector. This array does not account for any pivoting.
        """
        return _build_label_vector_rows([[(label, 1)] for label in self.labels],
            self.training_labels)[1:].T

    @property
    @requires_model_description
    def label_vector_array(self):
        """
        Build the label vector array.
        """

        lva = _build_label_vector_rows(self.label_vector, self.training_labels,
            dict(zip(self.labels, self.pivots)))

        if not np.all(np.isfinite(lva)):
            logger.warning("Non-finite labels in the label vector array!")
        return lva

    # Residuals of labels in the training data set.
    @requires_training_wheels
    def get_training_label_residuals(self):
        """
        Return the residuals (model - training) between the labels the model
        returns for each star and the corresponding training set values.
        """

        optimised_labels = self.fit(self.training_fluxes,
            self.training_flux_uncertainties, full_output=False)

        return optimised_labels - self.labels_array

    def _format_input_labels(self, args=None, **kwargs):
        """
        Format input labels, either from a list or a dictionary, into a
        common form.
        """

        # We want labels in a dictionary.
        labels = kwargs if args is None else dict(zip(self.labels, args))
        return { k: (v if isinstance(v, (list, tuple, np.ndarray)) else [v])
            for k, v in labels.items() }

    @requires_model_description
    def cross_validate(self, pre_train=None, **kwargs):
        """
        Perform leave-one-out cross-validation on the training set.
        """

        inferred = np.nan * np.ones_like(self.labels_array)
        N_training_set, N_labels = inferred.shape
        N_stop_at = kwargs.pop("N", N_training_set)

        debug = kwargs.pop("debug", False)

        kwds = { "threads": self.threads }
        kwds.update(kwargs)

        for i in range(N_training_set):

            training_set = np.ones(N_training_set, dtype=bool)
            training_set[i] = False

            # Create a clean model to use so we don't overwrite self.
            model = self.__class__(
                self.training_labels[training_set],
                self.training_fluxes[training_set],
                self.training_flux_uncertainties[training_set],
                **kwds)

            # Initialise and run any pre-training function.
            for _attribute in self._descriptive_attributes:
                setattr(model, _attribute[1:], getattr(self, _attribute[1:]))

            if pre_train is not None:
                pre_train(self, model)

            # Train and solve.
            model.train()

            try:
                inferred[i, :] = model.fit(self.training_fluxes[i],
                    self.training_flux_uncertainties[i], full_output=False)

            except:
                logger.exception("Exception during cross-validation on object "
                                 "with index {0}:".format(i))
                if debug: raise

            if i == N_stop_at + 1:
                break

        return inferred[:N_stop_at, :]

    @requires_training_wheels
    def define_continuum_mask(self, baseline_flux=None, tolerances=None,
        percentiles=None, absolute_percentiles=None):
        """
        Define a continuum mask based on constraints on the baseline flux
        values and on the percentiles or absolute percentiles of the theta
        coefficients. The resulting continuum mask contains whichever pixels
        meet all of the given constraints.

        :param baseline_flux: [optional]
            The `(lower, upper)` range of acceptable baseline flux values to
            be considered as continuum.

        :param tolerances: [optional]
            A dictionary containing label vector descriptions as keys and
            acceptable `(lower, upper)` ranges on the corresponding model
            coefficients as values.

        :param percentiles: [optional]
            A dictionary containing label vector descriptions as keys and
            acceptable percentile ranges `(lower, upper)` for each
            corresponding label vector term.

        :param absolute_percentiles: [optional]
            The same as `percentiles`, except these are calculated on the
            absolute values of the model coefficients.
        """

        mask = np.ones_like(self.dispersion, dtype=bool)
        if baseline_flux is not None:
            if len(baseline_flux) != 2:
                raise ValueError("baseline flux constraints must be given as "
                                 "(lower, upper)")
            mask *= (max(baseline_flux) >= self.coefficients[:, 0]) \
                  * (self.coefficients[:, 0] >= min(baseline_flux))

        for term, constraints in (tolerances or {}).items():
            if len(constraints) != 2:
                raise ValueError("{} tolerance must be given as (lower, upper)"\
                    .format(term))

            p_term = utils.parse_label_vector(term)[0]
            if p_term not in self.label_vector:
                logger.warning("Term {0} ({1}) is not in the label vector, "
                               "and is therefore being ignored".format(
                                   term, p_term))
                continue

            a = self.coefficients[:, 1 + self.label_vector.index(p_term)]
            mask *= (max(constraints) >= a) * (a >= min(constraints))

        for qs, use_abs in zip([percentiles, absolute_percentiles], [0, 1]):
            if qs is None: continue

            for term, constraints in qs.items():
                if len(constraints) != 2:
                    raise ValueError("{} constraints must be given as "
                                     "(lower, upper)".format(term))

                p_term = utils.parse_label_vector(term)[0]
                if p_term not in self.label_vector:
                    logger.warning("Term {0} ({1}) is not in the label vector, "
                                   "and is therefore being ignored".format(
                                       term, p_term))
                    continue

                a = self.coefficients[:, 1 + self.label_vector.index(p_term)]
                if use_abs: a = np.abs(a)

                p = np.percentile(a, constraints)
                mask *= (max(p) >= a) * (a >= min(p))

        return mask
```
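For orientation, here is a minimal usage sketch. The subclass, the label names, the data shapes, and the label vector description string below are hypothetical; `BaseCannonModel` itself is abstract, so `train`, `predict`, and `fit` must come from a subclass:

```python
import numpy as np
from astropy.table import Table

class ToyCannonModel(BaseCannonModel):
    """Hypothetical subclass that fills in the abstract methods with stubs."""

    def train(self, *args, **kwargs):
        # A real model would solve for per-pixel coefficients and scatter;
        # here we only satisfy the shape checks of the property setters.
        self.coefficients = np.zeros((len(self.dispersion),
                                      1 + len(self.label_vector)))
        self.scatter = np.zeros(len(self.dispersion))
        self._trained = True

    def predict(self, *args, **kwargs):
        pass  # omitted in this sketch

    def fit(self, *args, **kwargs):
        pass  # omitted in this sketch

# Assumed shapes: 10 stars, 100 pixels each.
labels = Table({"TEFF": np.random.uniform(4000, 6000, 10),
                "LOGG": np.random.uniform(1, 5, 10)})
fluxes = np.random.normal(1.0, 0.01, size=(10, 100))
flux_uncertainties = 0.01 * np.ones_like(fluxes)

model = ToyCannonModel(labels, fluxes, flux_uncertainties)
model.label_vector = "TEFF^2 TEFF LOGG"   # assumed human-readable format
model.train()
print(model.is_trained)                   # True
```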