Total Complexity | 106 |
Total Lines | 576 |
Duplicated Lines | 0 % |
Complex classes like AnniesLasso.BaseCannonModel often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | #!/usr/bin/env python |
||
class BaseCannonModel(object):
    """
    An abstract Cannon model object that implements convenience functions.

    :param labels:
        A table with columns as labels, and stars as rows.

    :type labels:
        :class:`~astropy.table.Table` or numpy structured array

    :param fluxes:
        An array of fluxes for stars in the training set, given as shape
        `(num_stars, num_pixels)`. The `num_stars` should match the number of
        rows in `labels`.

    :type fluxes:
        :class:`np.ndarray`

    :param flux_uncertainties:
        An array of 1-sigma flux uncertainties for stars in the training set.
        The shape of the `flux_uncertainties` should match `fluxes`.

    :type flux_uncertainties:
        :class:`np.ndarray`

    :param dispersion: [optional]
        The dispersion values corresponding to the given pixels. If provided,
        this should have length `num_pixels`. If not provided, the integer
        pixel indices `0 .. num_pixels - 1` are used.

    :param threads: [optional]
        The number of parallel workers to use. A new pool (`mp.Pool`) is only
        created when this is greater than one and no `pool` is given.

    :param pool: [optional]
        An existing worker pool to use. If given, it is used regardless of
        `threads`.

    :param live_dangerously: [optional]
        If enabled then label names are not checked for the characters in
        `_forbidden_label_characters`. Label names containing those
        characters may not be usable in human-readable forms of the label
        vector.
    """

    # Attribute names describing the model (e.g. the label vector). These are
    # saved/restored by `save`/`load` and copied to models in
    # `cross_validate`.
    _descriptive_attributes = ["_label_vector"]
    # Attribute names produced by training; cleared by `reset`.
    _trained_attributes = []
    # Attribute names holding training data; hashed (and optionally saved) by
    # `save`.
    _data_attributes = []
    # Characters that may not appear in label names unless `live_dangerously`
    # is enabled.
    _forbidden_label_characters = "^*"

    def __init__(self, labels, fluxes, flux_uncertainties, dispersion=None,
        threads=1, pool=None, live_dangerously=False):

        self._training_labels = labels
        self._training_fluxes = np.atleast_2d(fluxes)
        self._training_flux_uncertainties = np.atleast_2d(flux_uncertainties)

        # Derive the default from the stored 2-D array (not the raw argument)
        # so that list-like or single-star (1-D) inputs also work.
        if dispersion is None:
            dispersion = np.arange(self._training_fluxes.shape[1], dtype=int)
        self._dispersion = dispersion

        for attribute in self._descriptive_attributes:
            setattr(self, attribute, None)

        # The training data must be checked, but users can live dangerously if
        # they think they can correctly specify the label vector description.
        self._verify_training_data()
        if not live_dangerously:
            self._verify_labels_available()

        self.reset()
        self.threads = threads

        # An explicitly supplied pool always wins; otherwise only create a new
        # pool when more than one worker is requested.
        if pool is not None:
            self.pool = pool
        elif threads > 1:
            self.pool = mp.Pool(threads)
        else:
            self.pool = None

    def reset(self):
        """
        Clear any attributes that have been trained upon, and mark the model
        as untrained.
        """
        self._trained = False
        for attribute in self._trained_attributes:
            setattr(self, attribute, None)
        return None

    def __str__(self):
        return "<{module}.{name} {trained}using a training set of {N} stars "\
               "with {K} available labels and {M} pixels each>".format(
                    module=self.__module__,
                    name=type(self).__name__,
                    trained="trained " if self.is_trained else "",
                    N=len(self.training_labels),
                    K=len(self.labels_available),
                    M=len(self.dispersion))

    def __repr__(self):
        return "<{0}.{1} object at {2}>".format(
            self.__module__, type(self).__name__, hex(id(self)))

    # Attributes related to the training data.
    @property
    def dispersion(self):
        """
        Return the dispersion points for all pixels.
        """
        return self._dispersion

    @dispersion.setter
    def dispersion(self, dispersion):
        """
        Set the dispersion values for all the pixels.

        :param dispersion:
            A list-like of finite numeric values, one per pixel.

        :raises TypeError:
            If `dispersion` has no length.

        :raises ValueError:
            If the length does not match the number of pixels, or the values
            are not numeric or not finite.
        """
        try:
            len(dispersion)
        except TypeError:
            raise TypeError("dispersion provided must be an array or list-like")

        if len(dispersion) != self.training_fluxes.shape[1]:
            raise ValueError("dispersion provided does not match the number "
                             "of pixels per star ({0} != {1})".format(
                                len(dispersion), self.training_fluxes.shape[1]))

        dispersion = np.array(dispersion)
        # Accept signed/unsigned integers and floats only.
        if dispersion.dtype.kind not in "iuf":
            raise ValueError("dispersion values are not float-like")

        if not np.all(np.isfinite(dispersion)):
            raise ValueError("dispersion values must be finite")

        self._dispersion = dispersion
        return None

    @property
    def training_labels(self):
        """ The table of labels for the training set. """
        return self._training_labels

    @property
    def training_fluxes(self):
        """ The 2-D array of training set fluxes. """
        return self._training_fluxes

    @property
    def training_flux_uncertainties(self):
        """ The 2-D array of training set flux uncertainties. """
        return self._training_flux_uncertainties

    # Verifying the training data.
    def _verify_labels_available(self):
        """
        Verify the label names provided do not include forbidden characters.

        :raises ValueError:
            If any label name contains a character from
            `_forbidden_label_characters`.
        """
        if self._forbidden_label_characters is None:
            return True

        for label in self.training_labels.dtype.names:
            for character in self._forbidden_label_characters:
                if character in label:
                    raise ValueError(
                        "forbidden character '{char}' is in potential "
                        "label '{label}' - you can disable this verification "
                        "by enabling `live_dangerously`".format(
                            char=character, label=label))
        return None

    def _verify_training_data(self):
        """
        Verify the training data for the appropriate shape and content.

        :raises ValueError:
            If the flux and uncertainty shapes differ, no named labels are
            given, the number of label rows does not match the number of flux
            rows, or the dispersion length does not match the number of
            pixels.
        """
        if self.training_fluxes.shape != self.training_flux_uncertainties.shape:
            raise ValueError(
                "the training flux and uncertainty arrays should "
                "have the same shape")

        if len(self.training_labels) == 0 \
        or self.training_labels.dtype.names is None:
            raise ValueError("no named labels provided for the training set")

        if len(self.training_labels) != self.training_fluxes.shape[0]:
            raise ValueError(
                "the first axes of the training flux array should "
                "have the same shape as the number of rows in the label table "
                "(N_stars, N_pixels)")

        if self.dispersion is not None:
            dispersion = np.atleast_1d(self.dispersion).flatten()
            if dispersion.size != self.training_fluxes.shape[1]:
                raise ValueError(
                    "mis-match between the number of wavelength "
                    "points ({0}) and flux values ({1})".format(
                        dispersion.size, self.training_fluxes.shape[1]))
        return None

    @property
    def is_trained(self):
        """ Whether the model has been trained (or loaded from disk). """
        return self._trained

    # Attributes related to the labels and the label vector description.
    @property
    def labels_available(self):
        """
        All of the available labels for each star in the training set.
        """
        return self.training_labels.dtype.names

    @property
    def label_vector(self):
        """ The label vector for all pixels. """
        return self._label_vector

    @label_vector.setter
    def label_vector(self, label_vector_description):
        """
        Set a label vector.

        Setting a label vector that differs from the current one (or setting
        `None`) resets any trained attributes.

        :param label_vector_description:
            A structured or human-readable version of the label vector
            description.

        :raises ValueError:
            If the description refers to labels not in the training set.
        """

        if label_vector_description is None:
            self._label_vector = None
            self.reset()
            return None

        label_vector = utils.parse_label_vector(label_vector_description)

        # Need to actually verify that the parameters listed in the label vector
        # are actually present in the training labels.
        missing = \
            set(self._get_labels(label_vector)).difference(self.labels_available)
        if missing:
            raise ValueError("the following labels parsed from the label "
                             "vector description are missing in the training "
                             "set of labels: {0}".format(", ".join(missing)))

        # If this is really a new label vector description,
        # then we are no longer trained.
        if not hasattr(self, "_label_vector") \
        or label_vector != self._label_vector:
            self._label_vector = label_vector
            self.reset()

        return None

    @property
    def human_readable_label_vector(self):
        """ Return a human-readable form of the label vector. """
        return utils.human_readable_label_vector(self.label_vector)

    @property
    def labels(self):
        """ The labels that contribute to the label vector. """
        return self._get_labels(self.label_vector)

    def _get_labels(self, label_vector):
        """
        Return the labels that contribute to the structured label vector
        provided.

        :param label_vector:
            A structured label vector: a sequence of terms, each a sequence of
            `(label, power)` pairs, or `None`.

        :returns:
            An empty tuple if `label_vector` is `None`, otherwise a list of the
            unique labels with a non-zero power, in order of first appearance.
        """
        return () if label_vector is None else \
            list(OrderedDict.fromkeys([label for term in label_vector \
                for label, power in term if power != 0]))

    def _get_lowest_order_label_indices(self):
        """
        Get the indices for the lowest power label terms in the label vector.

        :returns:
            For each label in `self.labels`, the index of the single-label
            term with the lowest power, or `None` if that label never appears
            alone in a term.
        """
        indices = OrderedDict()
        for i, term in enumerate(self.label_vector):
            # Only consider terms made of a single label (no cross terms).
            if len(term) > 1: continue
            label, order = term[0]
            if order < indices.get(label, [None, np.inf])[-1]:
                indices[label] = (i, order)
        return [indices.get(label, [None])[0] for label in self.labels]

    # Trained attributes that subclasses are likely to use.
    @property
    def coefficients(self):
        """ The label vector coefficients for each pixel. """
        return self._coefficients

    @coefficients.setter
    def coefficients(self, coefficients):
        """
        Set the label vector coefficients for each pixel. This assumes a
        'standard' model where the label vector is common to all pixels.

        :param coefficients:
            A 2-D array of coefficients of shape
            (`N_pixels`, `N_label_vector_terms`), where the number of terms
            includes one extra term on top of `len(self.label_vector)`.

        :raises ValueError:
            If the array is not 2-D or its shape does not match the pixels
            and label vector.
        """

        if coefficients is None:
            self._coefficients = None
            return None

        coefficients = np.atleast_2d(coefficients)
        if len(coefficients.shape) > 2:
            raise ValueError("coefficients must be a 2D array")

        P, Q = coefficients.shape
        if P != len(self.dispersion):
            raise ValueError("axis 0 of coefficients array does not match the "
                             "number of pixels ({0} != {1})".format(
                                P, len(self.dispersion)))
        if Q != 1 + len(self.label_vector):
            raise ValueError("axis 1 of coefficients array does not match the "
                             "number of label vector terms ({0} != {1})".format(
                                Q, 1 + len(self.label_vector)))
        self._coefficients = coefficients
        return None

    @property
    def scatter(self):
        """ The scatter term for each pixel. """
        return self._scatter

    @scatter.setter
    def scatter(self, scatter):
        """
        Set the scatter values for each pixel.

        :param scatter:
            A 1-D array of non-negative scatter terms, one per pixel.

        :raises ValueError:
            If the number of values does not match the number of pixels, or
            any value is negative.
        """

        if scatter is None:
            self._scatter = None
            return None

        scatter = np.array(scatter).flatten()
        if scatter.size != len(self.dispersion):
            raise ValueError("number of scatter values does not match "
                             "the number of pixels ({0} != {1})".format(
                                scatter.size, len(self.dispersion)))
        if np.any(scatter < 0):
            raise ValueError("scatter terms must be non-negative")
        self._scatter = scatter
        return None

    # Methods which must be implemented or updated by the subclasses.
    def pixel_label_vector(self, pixel_index):
        """ The label vector for a given pixel. """
        return self.label_vector

    def train(self, *args, **kwargs):
        raise NotImplementedError("The train method must be "
                                  "implemented by subclasses")

    def predict(self, *args, **kwargs):
        raise NotImplementedError("The predict method must be "
                                  "implemented by subclasses")

    def fit(self, *args, **kwargs):
        raise NotImplementedError("The fit method must be "
                                  "implemented by subclasses")

    # I/O
    @requires_training_wheels
    def save(self, filename, include_training_data=False, overwrite=False):
        """
        Serialise the trained model and save it to disk. This will save all
        relevant training attributes, and optionally, the training data.

        :param filename:
            The path to save the model to.

        :param include_training_data: [optional]
            Save the training data (labels, fluxes, uncertainties) used to train
            the model.

        :param overwrite: [optional]
            Overwrite the existing file path, if it already exists.

        :raises IOError:
            If `filename` exists and `overwrite` is not enabled.

        :raises ValueError:
            If any attribute would be stored under the reserved key
            'metadata'.
        """

        if path.exists(filename) and not overwrite:
            raise IOError("filename already exists: {0}".format(filename))

        attributes = list(self._descriptive_attributes) \
                   + list(self._trained_attributes) \
                   + list(self._data_attributes)
        # Attributes are stored without their leading underscores, so check
        # the stored names (e.g. '_metadata' would also clash).
        if "metadata" in [attr.lstrip("_") for attr in attributes]:
            raise ValueError("'metadata' is a protected attribute and cannot "
                             "be used in the _*_attributes in a class")

        # Store up all the trained attributes and a hash of the training set.
        contents = OrderedDict([
            (attr.lstrip("_"), getattr(self, attr)) for attr in \
            (self._descriptive_attributes + self._trained_attributes)])
        contents["training_set_hash"] = utils.short_hash(getattr(self, attr) \
            for attr in self._data_attributes)

        if include_training_data:
            contents.update([(attr.lstrip("_"), getattr(self, attr)) \
                for attr in self._data_attributes])

        contents["metadata"] = {
            "version": code_version,
            "model_name": type(self).__name__,
            "modified": str(datetime.now()),
            "data_attributes": \
                [_.lstrip("_") for _ in self._data_attributes],
            "trained_attributes": \
                [_.lstrip("_") for _ in self._trained_attributes],
            "descriptive_attributes": \
                [_.lstrip("_") for _ in self._descriptive_attributes]
        }

        with open(filename, "wb") as fp:
            pickle.dump(contents, fp, -1)

        return None

    def load(self, filename, verify_training_data=False):
        """
        Load a saved model from disk.

        NOTE: this unpickles the file, which can execute arbitrary code; only
        load model files from trusted sources.

        :param filename:
            The path where to load the model from.

        :param verify_training_data: [optional]
            If there is training data in the saved model, verify its contents.
            Otherwise if no training data is saved, verify that the data used
            to train the model is the same data provided when this model was
            instantiated.

        :raises ValueError:
            If verification is requested and the data hash does not match the
            saved training set hash.
        """

        with open(filename, "rb") as fp:
            contents = pickle.load(fp)

        metadata = contents["metadata"]
        assert metadata["model_name"] == type(self).__name__, \
            "saved model is a {0}, not a {1}".format(
                metadata["model_name"], type(self).__name__)

        data_attributes = metadata["data_attributes"]
        # Training data is only present if saved with `include_training_data`;
        # a model may also declare no data attributes at all.
        has_data = len(data_attributes) > 0 \
            and all(attr in contents for attr in data_attributes)

        expected_hash = contents.get("training_set_hash")
        if verify_training_data and expected_hash is not None:
            if has_data:
                # Check the saved data against the hash saved alongside it.
                data_hash = utils.short_hash(
                    contents[attr] for attr in data_attributes)
            else:
                # No saved data: check the data this model was created with.
                data_hash = utils.short_hash(
                    getattr(self, "_{}".format(attr)) for attr in data_attributes)

            if data_hash != expected_hash:
                raise ValueError("expected hash for the training data is "
                                 "different to the actual data hash "
                                 "({0} != {1})".format(
                                    expected_hash, data_hash))

        # Set the data attributes.
        if has_data:
            for attribute in data_attributes:
                setattr(self, "_{}".format(attribute), contents[attribute])

        # Set descriptive and trained attributes.
        self.reset()
        for attribute in metadata["descriptive_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        for attribute in metadata["trained_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        self._trained = True

        return None

    # Properties and attributes related to training, etc.
    @property
    @requires_model_description
    def labels_array(self):
        """
        Return an array containing just the training labels, given the label
        vector.

        The first row returned by `_build_label_vector_rows` is dropped and
        the result is transposed; `cross_validate` treats the result as
        `(num_stars, num_labels)`.
        """
        return _build_label_vector_rows(
            [[(label, 1)] for label in self.labels], self.training_labels)[1:].T

    @property
    @requires_model_description
    def label_vector_array(self):
        """
        Build the label vector array for the training set.

        A warning is logged if any entry is non-finite.
        """

        lva = _build_label_vector_rows(self.label_vector, self.training_labels)

        if not np.all(np.isfinite(lva)):
            logger.warning("Non-finite labels in the label vector array!")
        return lva

    # Residuals in labels in the training data set.
    @requires_training_wheels
    def get_training_label_residuals(self):
        """
        Return the residuals (model - training) between the parameters that the
        model returns for each star, and the training set value.
        """

        optimised_labels = self.fit(self.training_fluxes,
            self.training_flux_uncertainties, full_output=False)

        return optimised_labels - self.labels_array

    def _format_input_labels(self, args=None, **kwargs):
        """
        Format input labels either from a list or dictionary into a common form.

        :param args: [optional]
            A sequence of label values, ordered as `self.labels`. If given,
            keyword arguments are ignored.

        :returns:
            A dictionary mapping label names to list-like values; scalar
            values are wrapped in a single-element list.
        """

        # We want labels in a dictionary.
        labels = kwargs if args is None else dict(zip(self.labels, args))
        return { k: (v if isinstance(v, (list, tuple, np.ndarray)) else [v]) \
            for k, v in labels.items() }

    # Put Cross-validation functions in here.
    @requires_model_description
    def cross_validate(self, pre_train=None, **kwargs):
        """
        Perform leave-one-out cross-validation on the training set.

        For each star, a new model of the same class is built from all other
        training stars. It is given this model's descriptive attributes,
        trained, and then used to fit the left-out star.

        :param pre_train: [optional]
            A callable invoked as `pre_train(self, model)` after the
            descriptive attributes are copied to each new model and before it
            is trained.

        :param N: [optional, keyword]
            Only cross-validate the first `N` stars (default: all of them).

        :param debug: [optional, keyword]
            Re-raise exceptions raised while fitting a left-out star instead
            of logging them.

        Any other keyword arguments are passed to the constructor of each new
        model.

        :returns:
            An array with one row per cross-validated star and one column per
            label, holding the inferred labels. Rows whose fit raised an
            exception (with `debug` disabled) are left as NaN.
        """

        inferred = np.nan * np.ones_like(self.labels_array)
        N_training_set = inferred.shape[0]
        N_stop_at = kwargs.pop("N", N_training_set)

        debug = kwargs.pop("debug", False)

        # Share this model's pixels and worker pool with every leave-one-out
        # model, so they see the same dispersion and no new pool is spawned
        # (and leaked) per star. Caller kwargs still take precedence.
        kwds = {
            "dispersion": self.dispersion,
            "threads": self.threads,
            "pool": self.pool,
        }
        kwds.update(kwargs)

        # Only rows [:N_stop_at] are returned, so don't train any more models.
        for i in range(min(N_stop_at, N_training_set)):

            training_set = np.ones(N_training_set, dtype=bool)
            training_set[i] = False

            # Create a clean model to use so we don't overwrite self.
            model = self.__class__(
                self.training_labels[training_set],
                self.training_fluxes[training_set],
                self.training_flux_uncertainties[training_set],
                **kwds)

            # Copy the model description (via the public properties), then
            # run any pre-training function.
            for _attribute in self._descriptive_attributes:
                setattr(model, _attribute[1:], getattr(self, _attribute[1:]))

            if pre_train is not None:
                pre_train(self, model)

            # Train and solve.
            model.train()

            try:
                inferred[i, :] = model.fit(self.training_fluxes[i],
                    self.training_flux_uncertainties[i], full_output=False)

            except Exception:
                logger.exception("Exception during cross-validation on object "
                                 "with index {0}:".format(i))
                if debug: raise

        return inferred[:N_stop_at, :]
668 |