| Total Complexity | 106 |
| Total Lines | 576 |
| Duplicated Lines | 0 % |
Complex classes like AnniesLasso.BaseCannonModel often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | #!/usr/bin/env python |
||
class BaseCannonModel(object):
    """
    An abstract Cannon model object that implements convenience functions.

    :param labels:
        A table with columns as labels, and stars as rows.

    :type labels:
        :class:`~astropy.table.Table` or numpy structured array

    :param fluxes:
        An array of fluxes for stars in the training set, given as shape
        `(num_stars, num_pixels)`. The `num_stars` should match the number of
        rows in `labels`.

    :type fluxes:
        :class:`np.ndarray`

    :param flux_uncertainties:
        An array of 1-sigma flux uncertainties for stars in the training set.
        The shape of the `flux_uncertainties` should match `fluxes`.

    :type flux_uncertainties:
        :class:`np.ndarray`

    :param dispersion: [optional]
        The dispersion values corresponding to the given pixels. If provided,
        this should have length `num_pixels`. If not provided, the pixel
        indices `0 .. num_pixels - 1` are used.

    :param threads: [optional]
        The number of parallel processes to use. When this is greater than 1
        and no `pool` is given, a :class:`multiprocessing.Pool` is created.

    :param pool: [optional]
        An existing pool to use. It is only kept when `threads` is greater
        than 1; otherwise no pool is used.

    :param live_dangerously: [optional]
        If enabled then the label names are not checked for forbidden
        characters. Labels containing those characters may not be expressible
        in a human-readable label vector description.
    """

    # Attribute names (with leading underscore) that describe the model. They
    # are set to None on instantiation, copied to each sub-model during
    # cross-validation, and written by `save`.
    _descriptive_attributes = ["_label_vector"]
    # Attributes produced by training. They are cleared by `reset` and written
    # by `save`.
    _trained_attributes = []
    # Attributes holding the training data. `save` always stores a hash of
    # them, and stores the data themselves only on request.
    _data_attributes = []
    # Characters that are rejected in training label names unless
    # `live_dangerously` is enabled.
    _forbidden_label_characters = "^*"
||
| 89 | |||
| 90 | def __init__(self, labels, fluxes, flux_uncertainties, dispersion=None, |
||
| 91 | threads=1, pool=None, live_dangerously=False): |
||
| 92 | |||
| 93 | self._training_labels = labels |
||
| 94 | self._training_fluxes = np.atleast_2d(fluxes) |
||
| 95 | self._training_flux_uncertainties = np.atleast_2d(flux_uncertainties) |
||
| 96 | self._dispersion = np.arange(fluxes.shape[1], dtype=int) \ |
||
| 97 | if dispersion is None else dispersion |
||
| 98 | |||
| 99 | for attribute in self._descriptive_attributes: |
||
| 100 | setattr(self, attribute, None) |
||
| 101 | |||
| 102 | # The training data must be checked, but users can live dangerously if |
||
| 103 | # they think they can correctly specify the label vector description. |
||
| 104 | self._verify_training_data() |
||
| 105 | if not live_dangerously: |
||
| 106 | self._verify_labels_available() |
||
| 107 | |||
| 108 | self.reset() |
||
| 109 | self.threads = threads |
||
| 110 | self.pool = pool or mp.Pool(threads) if threads > 1 else None |
||
| 111 | |||
| 112 | |||
| 113 | def reset(self): |
||
| 114 | """ |
||
| 115 | Clear any attributes that have been trained upon. |
||
| 116 | """ |
||
| 117 | |||
| 118 | self._trained = False |
||
| 119 | for attribute in self._trained_attributes: |
||
| 120 | setattr(self, attribute, None) |
||
| 121 | |||
| 122 | return None |
||
| 123 | |||
| 124 | |||
| 125 | def __str__(self): |
||
| 126 | return "<{module}.{name} {trained}using a training set of {N} stars "\ |
||
| 127 | "with {K} available labels and {M} pixels each>".format( |
||
| 128 | module=self.__module__, |
||
| 129 | name=type(self).__name__, |
||
| 130 | trained="trained " if self.is_trained else "", |
||
| 131 | N=len(self.training_labels), |
||
| 132 | K=len(self.labels_available), |
||
| 133 | M=len(self.dispersion)) |
||
| 134 | |||
| 135 | |||
| 136 | def __repr__(self): |
||
| 137 | return "<{0}.{1} object at {2}>".format( |
||
| 138 | self.__module__, type(self).__name__, hex(id(self))) |
||
| 139 | |||
| 140 | |||
| 141 | # Attributes related to the training data. |
||
| 142 | @property |
||
| 143 | def dispersion(self): |
||
| 144 | """ |
||
| 145 | Return the dispersion points for all pixels. |
||
| 146 | """ |
||
| 147 | return self._dispersion |
||
| 148 | |||
| 149 | |||
| 150 | @dispersion.setter |
||
| 151 | def dispersion(self, dispersion): |
||
| 152 | """ |
||
| 153 | Set the dispersion values for all the pixels. |
||
| 154 | """ |
||
| 155 | try: |
||
| 156 | len(dispersion) |
||
| 157 | except TypeError: |
||
| 158 | raise TypeError("dispersion provided must be an array or list-like") |
||
| 159 | |||
| 160 | if len(dispersion) != self.training_fluxes.shape[1]: |
||
| 161 | raise ValueError("dispersion provided does not match the number " |
||
| 162 | "of pixels per star ({0} != {1})".format( |
||
| 163 | len(dispersion), self.training_fluxes.shape[1])) |
||
| 164 | |||
| 165 | dispersion = np.array(dispersion) |
||
| 166 | if dispersion.dtype.kind not in "iuf": |
||
| 167 | raise ValueError("dispersion values are not float-like") |
||
| 168 | |||
| 169 | if not np.all(np.isfinite(dispersion)): |
||
| 170 | raise ValueError("dispersion values must be finite") |
||
| 171 | |||
| 172 | self._dispersion = dispersion |
||
| 173 | return None |
||
| 174 | |||
| 175 | |||
| 176 | @property |
||
| 177 | def training_labels(self): |
||
| 178 | return self._training_labels |
||
| 179 | |||
| 180 | |||
| 181 | @property |
||
| 182 | def training_fluxes(self): |
||
| 183 | return self._training_fluxes |
||
| 184 | |||
| 185 | |||
| 186 | @property |
||
| 187 | def training_flux_uncertainties(self): |
||
| 188 | return self._training_flux_uncertainties |
||
| 189 | |||
| 190 | |||
| 191 | # Verifying the training data. |
||
| 192 | def _verify_labels_available(self): |
||
| 193 | """ |
||
| 194 | Verify the label names provided do not include forbidden characters. |
||
| 195 | """ |
||
| 196 | if self._forbidden_label_characters is None: |
||
| 197 | return True |
||
| 198 | |||
| 199 | for label in self.training_labels.dtype.names: |
||
| 200 | for character in self._forbidden_label_characters: |
||
| 201 | if character in label: |
||
| 202 | raise ValueError( |
||
| 203 | "forbidden character '{char}' is in potential " |
||
| 204 | "label '{label}' - you can disable this verification " |
||
| 205 | "by enabling `live_dangerously`".format( |
||
| 206 | char=character, label=label)) |
||
| 207 | return None |
||
| 208 | |||
| 209 | |||
| 210 | def _verify_training_data(self): |
||
| 211 | """ |
||
| 212 | Verify the training data for the appropriate shape and content. |
||
| 213 | """ |
||
| 214 | if self.training_fluxes.shape != self.training_flux_uncertainties.shape: |
||
| 215 | raise ValueError( |
||
| 216 | "the training flux and uncertainty arrays should " |
||
| 217 | "have the same shape") |
||
| 218 | |||
| 219 | if len(self.training_labels) == 0 \ |
||
| 220 | or self.training_labels.dtype.names is None: |
||
| 221 | raise ValueError("no named labels provided for the training set") |
||
| 222 | |||
| 223 | if len(self.training_labels) != self.training_fluxes.shape[0]: |
||
| 224 | raise ValueError( |
||
| 225 | "the first axes of the training flux array should " |
||
| 226 | "have the same shape as the nuber of rows in the label table " |
||
| 227 | "(N_stars, N_pixels)") |
||
| 228 | |||
| 229 | if self.dispersion is not None: |
||
| 230 | dispersion = np.atleast_1d(self.dispersion).flatten() |
||
| 231 | if dispersion.size != self.training_fluxes.shape[1]: |
||
| 232 | raise ValueError( |
||
| 233 | "mis-match between the number of wavelength " |
||
| 234 | "points ({0}) and flux values ({1})".format( |
||
| 235 | self.training_fluxes.shape[1], dispersion.size)) |
||
| 236 | return None |
||
| 237 | |||
| 238 | |||
| 239 | @property |
||
| 240 | def is_trained(self): |
||
| 241 | return self._trained |
||
| 242 | |||
| 243 | |||
| 244 | # Attributes related to the labels and the label vector description. |
||
| 245 | @property |
||
| 246 | def labels_available(self): |
||
| 247 | """ |
||
| 248 | All of the available labels for each star in the training set. |
||
| 249 | """ |
||
| 250 | return self.training_labels.dtype.names |
||
| 251 | |||
| 252 | |||
| 253 | @property |
||
| 254 | def label_vector(self): |
||
| 255 | """ The label vector for all pixels. """ |
||
| 256 | return self._label_vector |
||
| 257 | |||
| 258 | |||
| 259 | @label_vector.setter |
||
| 260 | def label_vector(self, label_vector_description): |
||
| 261 | """ |
||
| 262 | Set a label vector. |
||
| 263 | |||
| 264 | :param label_vector_description: |
||
| 265 | A structured or human-readable version of the label vector |
||
| 266 | description. |
||
| 267 | """ |
||
| 268 | |||
| 269 | if label_vector_description is None: |
||
| 270 | self._label_vector = None |
||
| 271 | self.reset() |
||
| 272 | return None |
||
| 273 | |||
| 274 | label_vector = utils.parse_label_vector(label_vector_description) |
||
| 275 | |||
| 276 | # Need to actually verify that the parameters listed in the label vector |
||
| 277 | # are actually present in the training labels. |
||
| 278 | missing = \ |
||
| 279 | set(self._get_labels(label_vector)).difference(self.labels_available) |
||
| 280 | if missing: |
||
| 281 | raise ValueError("the following labels parsed from the label " |
||
| 282 | "vector description are missing in the training " |
||
| 283 | "set of labels: {0}".format(", ".join(missing))) |
||
| 284 | |||
| 285 | # If this is really a new label vector description, |
||
| 286 | # then we are no longer trained. |
||
| 287 | if not hasattr(self, "_label_vector") \ |
||
| 288 | or label_vector != self._label_vector: |
||
| 289 | self._label_vector = label_vector |
||
| 290 | self.reset() |
||
| 291 | |||
| 292 | return None |
||
| 293 | |||
| 294 | |||
| 295 | @property |
||
| 296 | def human_readable_label_vector(self): |
||
| 297 | """ Return a human-readable form of the label vector. """ |
||
| 298 | return utils.human_readable_label_vector(self.label_vector) |
||
| 299 | |||
| 300 | |||
| 301 | @property |
||
| 302 | def labels(self): |
||
| 303 | """ The labels that contribute to the label vector. """ |
||
| 304 | return self._get_labels(self.label_vector) |
||
| 305 | |||
| 306 | |||
| 307 | def _get_labels(self, label_vector): |
||
| 308 | """ |
||
| 309 | Return the labels that contribute to the structured label vector |
||
| 310 | provided. |
||
| 311 | """ |
||
| 312 | return () if label_vector is None else \ |
||
| 313 | list(OrderedDict.fromkeys([label for term in label_vector \ |
||
| 314 | for label, power in term if power != 0])) |
||
| 315 | |||
| 316 | |||
| 317 | def _get_lowest_order_label_indices(self): |
||
| 318 | """ |
||
| 319 | Get the indices for the lowest power label terms in the label vector. |
||
| 320 | """ |
||
| 321 | indices = OrderedDict() |
||
| 322 | for i, term in enumerate(self.label_vector): |
||
| 323 | if len(term) > 1: continue |
||
| 324 | label, order = term[0] |
||
| 325 | if order < indices.get(label, [None, np.inf])[-1]: |
||
| 326 | indices[label] = (i, order) |
||
| 327 | return [indices.get(label, [None])[0] for label in self.labels] |
||
| 328 | |||
| 329 | |||
| 330 | # Trained attributes that subclasses are likely to use. |
||
| 331 | @property |
||
| 332 | def coefficients(self): |
||
| 333 | return self._coefficients |
||
| 334 | |||
| 335 | |||
| 336 | @coefficients.setter |
||
| 337 | def coefficients(self, coefficients): |
||
| 338 | """ |
||
| 339 | Set the label vector coefficients for each pixel. This assumes a |
||
| 340 | 'standard' model where the label vector is common to all pixels. |
||
| 341 | |||
| 342 | :param coefficients: |
||
| 343 | A 2-D array of coefficients of shape |
||
| 344 | (`N_pixels`, `N_label_vector_terms`). |
||
| 345 | """ |
||
| 346 | |||
| 347 | if coefficients is None: |
||
| 348 | self._coefficients = None |
||
| 349 | return None |
||
| 350 | |||
| 351 | coefficients = np.atleast_2d(coefficients) |
||
| 352 | if len(coefficients.shape) > 2: |
||
| 353 | raise ValueError("coefficients must be a 2D array") |
||
| 354 | |||
| 355 | P, Q = coefficients.shape |
||
| 356 | if P != len(self.dispersion): |
||
| 357 | raise ValueError("axis 0 of coefficients array does not match the " |
||
| 358 | "number of pixels ({0} != {1})".format( |
||
| 359 | P, len(self.dispersion))) |
||
| 360 | if Q != 1 + len(self.label_vector): |
||
| 361 | raise ValueError("axis 1 of coefficients array does not match the " |
||
| 362 | "number of label vector terms ({0} != {1})".format( |
||
| 363 | Q, 1 + len(self.label_vector))) |
||
| 364 | self._coefficients = coefficients |
||
| 365 | return None |
||
| 366 | |||
| 367 | |||
| 368 | @property |
||
| 369 | def scatter(self): |
||
| 370 | return self._scatter |
||
| 371 | |||
| 372 | |||
| 373 | @scatter.setter |
||
| 374 | def scatter(self, scatter): |
||
| 375 | """ |
||
| 376 | Set the scatter values for each pixel. |
||
| 377 | |||
| 378 | :param scatter: |
||
| 379 | A 1-D array of scatter terms. |
||
| 380 | """ |
||
| 381 | |||
| 382 | if scatter is None: |
||
| 383 | self._scatter = None |
||
| 384 | return None |
||
| 385 | |||
| 386 | scatter = np.array(scatter).flatten() |
||
| 387 | if scatter.size != len(self.dispersion): |
||
| 388 | raise ValueError("number of scatter values does not match " |
||
| 389 | "the number of pixels ({0} != {1})".format( |
||
| 390 | scatter.size, len(self.dispersion))) |
||
| 391 | if np.any(scatter < 0): |
||
| 392 | raise ValueError("scatter terms must be positive") |
||
| 393 | self._scatter = scatter |
||
| 394 | return None |
||
| 395 | |||
| 396 | |||
| 397 | # Methods which must be implemented or updated by the subclasses. |
||
| 398 | def pixel_label_vector(self, pixel_index): |
||
| 399 | """ The label vector for a given pixel. """ |
||
| 400 | return self.label_vector |
||
| 401 | |||
| 402 | |||
| 403 | def train(self, *args, **kwargs): |
||
| 404 | raise NotImplementedError("The train method must be " |
||
| 405 | "implemented by subclasses") |
||
| 406 | |||
| 407 | |||
| 408 | def predict(self, *args, **kwargs): |
||
| 409 | raise NotImplementedError("The predict method must be " |
||
| 410 | "implemented by subclasses") |
||
| 411 | |||
| 412 | |||
| 413 | def fit(self, *args, **kwargs): |
||
| 414 | raise NotImplementedError("The fit method must be " |
||
| 415 | "implemented by subclasses") |
||
| 416 | |||
| 417 | |||
| 418 | # I/O |
||
| 419 | @requires_training_wheels |
||
| 420 | def save(self, filename, include_training_data=False, overwrite=False): |
||
| 421 | """ |
||
| 422 | Serialise the trained model and save it to disk. This will save all |
||
| 423 | relevant training attributes, and optionally, the training data. |
||
| 424 | |||
| 425 | :param filename: |
||
| 426 | The path to save the model to. |
||
| 427 | |||
| 428 | :param include_training_data: [optional] |
||
| 429 | Save the training data (labels, fluxes, uncertainties) used to train |
||
| 430 | the model. |
||
| 431 | |||
| 432 | :param overwrite: [optional] |
||
| 433 | Overwrite the existing file path, if it already exists. |
||
| 434 | """ |
||
| 435 | |||
| 436 | if path.exists(filename) and not overwrite: |
||
| 437 | raise IOError("filename already exists: {0}".format(filename)) |
||
| 438 | |||
| 439 | attributes = list(self._descriptive_attributes) \ |
||
| 440 | + list(self._trained_attributes) \ |
||
| 441 | + list(self._data_attributes) |
||
| 442 | if "metadata" in attributes: |
||
| 443 | raise ValueError("'metadata' is a protected attribute and cannot " |
||
| 444 | "be used in the _*_attributes in a class") |
||
| 445 | |||
| 446 | # Store up all the trained attributes and a hash of the training set. |
||
| 447 | contents = OrderedDict([ |
||
| 448 | (attr.lstrip("_"), getattr(self, attr)) for attr in \ |
||
| 449 | (self._descriptive_attributes + self._trained_attributes)]) |
||
| 450 | contents["training_set_hash"] = utils.short_hash(getattr(self, attr) \ |
||
| 451 | for attr in self._data_attributes) |
||
| 452 | |||
| 453 | if include_training_data: |
||
| 454 | contents.update([(attr.lstrip("_"), getattr(self, attr)) \ |
||
| 455 | for attr in self._data_attributes]) |
||
| 456 | |||
| 457 | contents["metadata"] = { |
||
| 458 | "version": code_version, |
||
| 459 | "model_name": type(self).__name__, |
||
| 460 | "modified": str(datetime.now()), |
||
| 461 | "data_attributes": \ |
||
| 462 | [_.lstrip("_") for _ in self._data_attributes], |
||
| 463 | "trained_attributes": \ |
||
| 464 | [_.lstrip("_") for _ in self._trained_attributes], |
||
| 465 | "descriptive_attributes": \ |
||
| 466 | [_.lstrip("_") for _ in self._descriptive_attributes] |
||
| 467 | } |
||
| 468 | |||
| 469 | with open(filename, "wb") as fp: |
||
| 470 | pickle.dump(contents, fp, -1) |
||
| 471 | |||
| 472 | return None |
||
| 473 | |||
| 474 | |||
| 475 | def load(self, filename, verify_training_data=False): |
||
| 476 | """ |
||
| 477 | Load a saved model from disk. |
||
| 478 | |||
| 479 | :param filename: |
||
| 480 | The path where to load the model from. |
||
| 481 | |||
| 482 | :param verify_training_data: [optional] |
||
| 483 | If there is training data in the saved model, verify its contents. |
||
| 484 | Otherwise if no training data is saved, verify that the data used |
||
| 485 | to train the model is the same data provided when this model was |
||
| 486 | instantiated. |
||
| 487 | """ |
||
| 488 | |||
| 489 | with open(filename, "rb") as fp: |
||
| 490 | contents = pickle.load(fp) |
||
| 491 | |||
| 492 | assert contents["metadata"]["model_name"] == type(self).__name__ |
||
| 493 | |||
| 494 | # If data exists, deal with that first. |
||
| 495 | has_data = (contents["metadata"]["data_attributes"][0] in contents) |
||
| 496 | if has_data: |
||
| 497 | |||
| 498 | if verify_training_data: |
||
| 499 | data_hash = utils.short_hash(contents[attr] \ |
||
| 500 | for attr in contents["metadata"]["data_attributes"]) |
||
| 501 | if contents["training_set_hash"] is not None \ |
||
| 502 | and data_hash != contents["training_set_hash"]: |
||
| 503 | raise ValueError("expected hash for the training data is " |
||
| 504 | "different to the actual data hash " |
||
| 505 | "({0} != {1})".format( |
||
| 506 | contents["training_set_hash"], |
||
| 507 | data_hash)) |
||
| 508 | |||
| 509 | # Set the data attributes. |
||
| 510 | for attribute in contents["metadata"]["data_attributes"]: |
||
| 511 | if attribute in contents: |
||
| 512 | setattr(self, "_{}".format(attribute), contents[attribute]) |
||
| 513 | |||
| 514 | # Set descriptive and trained attributes. |
||
| 515 | self.reset() |
||
| 516 | for attribute in contents["metadata"]["descriptive_attributes"]: |
||
| 517 | setattr(self, "_{}".format(attribute), contents[attribute]) |
||
| 518 | for attribute in contents["metadata"]["trained_attributes"]: |
||
| 519 | setattr(self, "_{}".format(attribute), contents[attribute]) |
||
| 520 | self._trained = True |
||
| 521 | |||
| 522 | return None |
||
| 523 | |||
| 524 | |||
| 525 | # Properties and attribuets related to training, etc. |
||
| 526 | @property |
||
| 527 | @requires_label_vector |
||
| 528 | def labels_array(self): |
||
| 529 | """ |
||
| 530 | Return an array containing just the training labels, given the label |
||
| 531 | vector. |
||
| 532 | """ |
||
| 533 | return _build_label_vector_rows( |
||
| 534 | [[(label, 1)] for label in self.labels], self.training_labels)[1:].T |
||
| 535 | |||
| 536 | |||
| 537 | @property |
||
| 538 | @requires_label_vector |
||
| 539 | def label_vector_array(self): |
||
| 540 | """ |
||
| 541 | Build the label vector array. |
||
| 542 | """ |
||
| 543 | |||
| 544 | lva = _build_label_vector_rows(self.label_vector, self.training_labels) |
||
| 545 | |||
| 546 | if not np.all(np.isfinite(lva)): |
||
| 547 | logger.warn("Non-finite labels in the label vector array!") |
||
| 548 | return lva |
||
| 549 | |||
| 550 | |||
| 551 | # Residuals in labels in the training data set. |
||
| 552 | @requires_training_wheels |
||
| 553 | def get_training_label_residuals(self): |
||
| 554 | """ |
||
| 555 | Return the residuals (model - training) between the parameters that the |
||
| 556 | model returns for each star, and the training set value. |
||
| 557 | """ |
||
| 558 | |||
| 559 | optimised_labels = self.fit(self.training_fluxes, |
||
| 560 | self.training_flux_uncertainties, full_output=False) |
||
| 561 | |||
| 562 | return optimised_labels - self.labels_array |
||
| 563 | |||
| 564 | |||
| 565 | def _format_input_labels(self, args=None, **kwargs): |
||
| 566 | """ |
||
| 567 | Format input labels either from a list or dictionary into a common form. |
||
| 568 | """ |
||
| 569 | |||
| 570 | # We want labels in a dictionary. |
||
| 571 | labels = kwargs if args is None else dict(zip(self.labels, args)) |
||
| 572 | return { k: (v if isinstance(v, (list, tuple, np.ndarray)) else [v]) \ |
||
| 573 | for k, v in labels.items() } |
||
| 574 | |||
| 575 | |||
| 576 | # Put Cross-validation functions in here. |
||
| 577 | @requires_label_vector |
||
| 578 | def cross_validate(self, pre_train=None, **kwargs): |
||
| 579 | """ |
||
| 580 | Perform leave-one-out cross-validation on the training set. |
||
| 581 | """ |
||
| 582 | |||
| 583 | inferred = np.nan * np.ones_like(self.labels_array) |
||
| 584 | N_training_set, N_labels = inferred.shape |
||
| 585 | N_stop_at = kwargs.pop("N", N_training_set) |
||
| 586 | |||
| 587 | debug = kwargs.pop("debug", False) |
||
| 588 | |||
| 589 | kwds = { "threads": self.threads } |
||
| 590 | kwds.update(kwargs) |
||
| 591 | |||
| 592 | for i in range(N_training_set): |
||
| 593 | |||
| 594 | training_set = np.ones(N_training_set, dtype=bool) |
||
| 595 | training_set[i] = False |
||
| 596 | |||
| 597 | # Create a clean model to use so we don't overwrite self. |
||
| 598 | model = self.__class__( |
||
| 599 | self.training_labels[training_set], |
||
| 600 | self.training_fluxes[training_set], |
||
| 601 | self.training_flux_uncertainties[training_set], |
||
| 602 | **kwds) |
||
| 603 | |||
| 604 | # Initialise and run any pre-training function. |
||
| 605 | for _attribute in self._descriptive_attributes: |
||
| 606 | setattr(model, _attribute[1:], getattr(self, _attribute[1:])) |
||
| 607 | |||
| 608 | if pre_train is not None: |
||
| 609 | pre_train(self, model) |
||
| 610 | |||
| 611 | # Train and solve. |
||
| 612 | model.train() |
||
| 613 | |||
| 614 | try: |
||
| 615 | inferred[i, :] = model.fit(self.training_fluxes[i], |
||
| 616 | self.training_flux_uncertainties[i], full_output=False) |
||
| 617 | |||
| 618 | except: |
||
| 619 | logger.exception("Exception during cross-validation on object " |
||
| 620 | "with index {0}:".format(i)) |
||
| 621 | if debug: raise |
||
| 622 | |||
| 623 | if i == N_stop_at + 1: |
||
| 624 | break |
||
| 625 | |||
| 626 | return inferred[:N_stop_at, :] |
||
| 627 | |||
| 664 |