Completed
Push — master ( 0c6327...53c124 )
by Andy
01:07
created

AnniesLasso.BaseCannonModel   F

Complexity

Total Complexity 106

Size/Duplication

Total Lines 576
Duplicated Lines 0 %
Metric Value
wmc 106
dl 0
loc 576
rs 1.5789

30 Methods

Rating   Name   Duplication   Size   Complexity  
A is_trained() 0 3 1
A human_readable_label_vector() 0 4 1
A training_flux_uncertainties() 0 3 1
B _verify_labels_available() 0 16 5
B __init__() 0 21 5
F save() 0 54 12
C cross_validate() 0 50 7
A get_training_label_residuals() 0 11 1
A labels_array() 0 9 2
A labels() 0 4 1
B _get_lowest_order_label_indices() 0 11 5
A train() 0 2 1
A coefficients() 0 3 1
A reset() 0 10 2
F load() 0 48 12
B _get_labels() 0 8 5
A _format_input_labels() 0 9 4
B label_vector() 0 34 5
B dispersion() 0 6 5
C _verify_training_data() 0 27 7
A __str__() 0 9 2
A predict() 0 2 1
A training_fluxes() 0 3 1
A __repr__() 0 3 1
A pixel_label_vector() 0 3 1
A label_vector_array() 0 12 2
A scatter() 0 22 1
A labels_available() 0 6 1
A fit() 0 2 1
A training_labels() 0 3 1

How to fix   Complexity   

Complex Class

Complex classes like AnniesLasso.BaseCannonModel often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
An abstract model class for The Cannon.

Defines :class:`BaseCannonModel`, which stores a training set and provides
shared bookkeeping, persistence (save/load) and cross-validation helpers, as
well as the `requires_training_wheels` and `requires_model_description`
method decorators used by model implementations.
"""

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

__all__ = [
    "BaseCannonModel",
    "requires_training_wheels",
    "requires_model_description",
]

# Standard library.
import logging
import multiprocessing as mp
from collections import OrderedDict
from datetime import datetime
from os import path

# Third-party.
import numpy as np
from six.moves import cPickle as pickle

# Package-local.
from . import utils, __version__ as code_version

logger = logging.getLogger(__name__)
26
27
28
def requires_training_wheels(method):
    """
    A decorator for model methods that require training before being run.

    :param method:
        The method to wrap. Its first positional argument is the model, which
        must expose an `is_trained` attribute.

    :returns:
        A wrapper that keeps the name and docstring of `method`.

    :raises TypeError:
        If the wrapped method is called on a model that is not trained.
    """
    # Local import keeps this decorator self-contained.
    from functools import wraps

    # `wraps` preserves the wrapped method's name and docstring, which matters
    # for introspection and generated documentation.
    @wraps(method)
    def wrapper(model, *args, **kwargs):
        if not model.is_trained:
            raise TypeError("the model needs training first")
        return method(model, *args, **kwargs)
    return wrapper
38
39
40
def requires_model_description(method):
    """
    A decorator for model methods that require a full model description.
    (That is, none of the _descriptive_attributes are None)

    :param method:
        The method to wrap. Its first positional argument is the model, which
        must expose `_descriptive_attributes` (a list of attribute names).

    :returns:
        A wrapper that keeps the name and docstring of `method`. This matters
        when the decorator is stacked under `@property`, because the property
        takes its docstring from the wrapped getter.

    :raises TypeError:
        If any descriptive attribute of the model is None.
    """
    # Local import keeps this decorator self-contained.
    from functools import wraps

    @wraps(method)
    def wrapper(model, *args, **kwargs):
        for descriptive_attribute in model._descriptive_attributes:
            if getattr(model, descriptive_attribute) is None:
                # Report the public name, e.g. "_label_vector" -> "label_vector".
                raise TypeError("the model requires a {} term".format(
                    descriptive_attribute.lstrip("_")))
        return method(model, *args, **kwargs)
    return wrapper
53
54
55
class BaseCannonModel(object):
    """
    An abstract Cannon model object that implements convenience functions.

    :param labels:
        A table with columns as labels, and stars as rows.

    :type labels:
        :class:`~astropy.table.Table` or numpy structured array

    :param fluxes:
        An array of fluxes for stars in the training set, given as shape
        `(num_stars, num_pixels)`. The `num_stars` should match the number of
        rows in `labels`.

    :type fluxes:
        :class:`np.ndarray`

    :param flux_uncertainties:
        An array of 1-sigma flux uncertainties for stars in the training set.
        The shape of the `flux_uncertainties` should match `fluxes`.

    :type flux_uncertainties:
        :class:`np.ndarray`

    :param dispersion: [optional]
        The dispersion values corresponding to the given pixels. If provided,
        this should have length `num_pixels`. If not provided, the integer
        pixel indices `0 .. num_pixels - 1` are used.

    :param threads: [optional]
        The number of parallel processes to use. When greater than one and no
        `pool` is given, a :class:`multiprocessing.Pool` is created.

    :param pool: [optional]
        An existing pool to use for parallel work. When given, it is used as-is
        regardless of `threads`.

    :param live_dangerously: [optional]
        If enabled then no checks will be made on the label names for the
        characters in `_forbidden_label_characters`. Those characters have
        meaning in human-readable label vector descriptions, so labels that
        contain them may not be usable in that form.
    """

    # Attributes that together describe the model. All must be non-None
    # before methods decorated with `requires_model_description` can run.
    _descriptive_attributes = ["_label_vector"]
    # Attributes learned by training; cleared by `reset()`.
    _trained_attributes = []
    # Attributes holding the training data; hashed (and optionally stored)
    # by `save()`.
    _data_attributes = []
    # Characters that may not appear in label names unless `live_dangerously`.
    _forbidden_label_characters = "^*"

    def __init__(self, labels, fluxes, flux_uncertainties, dispersion=None,
        threads=1, pool=None, live_dangerously=False):

        self._training_labels = labels
        self._training_fluxes = np.atleast_2d(fluxes)
        self._training_flux_uncertainties = np.atleast_2d(flux_uncertainties)

        # Use the normalised (at least 2D) flux array so that list inputs and
        # a single 1D spectrum work when deriving the default dispersion.
        self._dispersion = \
            np.arange(self._training_fluxes.shape[1], dtype=int) \
            if dispersion is None else dispersion

        for attribute in self._descriptive_attributes:
            setattr(self, attribute, None)

        # The training data must be checked, but users can live dangerously if
        # they think they can correctly specify the label vector description.
        # Remember the choice so derived models (cross-validation) inherit it.
        self._live_dangerously = live_dangerously
        self._verify_training_data()
        if not live_dangerously:
            self._verify_labels_available()

        self.reset()
        self.threads = threads
        # An explicitly supplied pool always wins; otherwise only create a
        # pool when more than one thread is requested.
        if pool is not None:
            self.pool = pool
        else:
            self.pool = mp.Pool(threads) if threads > 1 else None


    def reset(self):
        """
        Clear any attributes that have been trained upon.
        """

        self._trained = False
        for attribute in self._trained_attributes:
            setattr(self, attribute, None)

        return None


    def __str__(self):
        return "<{module}.{name} {trained}using a training set of {N} stars "\
               "with {K} available labels and {M} pixels each>".format(
                    module=self.__module__,
                    name=type(self).__name__,
                    trained="trained " if self.is_trained else "",
                    N=len(self.training_labels),
                    K=len(self.labels_available),
                    M=len(self.dispersion))


    def __repr__(self):
        return "<{0}.{1} object at {2}>".format(
            self.__module__, type(self).__name__, hex(id(self)))


    # Attributes related to the training data.
    @property
    def dispersion(self):
        """
        Return the dispersion points for all pixels.
        """
        return self._dispersion


    @dispersion.setter
    def dispersion(self, dispersion):
        """
        Set the dispersion values for all the pixels.

        :param dispersion:
            A list-like of finite, numeric values with one entry per pixel.

        :raises TypeError:
            If `dispersion` has no length.

        :raises ValueError:
            If the length does not match the number of pixels, or the values
            are not numeric or not finite.
        """
        try:
            len(dispersion)
        except TypeError:
            raise TypeError("dispersion provided must be an array or list-like")

        if len(dispersion) != self.training_fluxes.shape[1]:
            raise ValueError("dispersion provided does not match the number "
                             "of pixels per star ({0} != {1})".format(
                                len(dispersion), self.training_fluxes.shape[1]))

        dispersion = np.array(dispersion)
        # Accept signed/unsigned integers and floats only.
        if dispersion.dtype.kind not in "iuf":
            raise ValueError("dispersion values are not float-like")

        if not np.all(np.isfinite(dispersion)):
            raise ValueError("dispersion values must be finite")

        self._dispersion = dispersion
        return None


    @property
    def training_labels(self):
        """ The labels of the training set, as provided at instantiation. """
        return self._training_labels


    @property
    def training_fluxes(self):
        """ The training set fluxes, as a (at least) 2D array. """
        return self._training_fluxes


    @property
    def training_flux_uncertainties(self):
        """ The training set flux uncertainties, as a (at least) 2D array. """
        return self._training_flux_uncertainties


    # Verifying the training data.
    def _verify_labels_available(self):
        """
        Verify the label names provided do not include forbidden characters.

        :raises ValueError:
            If any label name contains a forbidden character.
        """
        if self._forbidden_label_characters is None:
            return None

        for label in self.training_labels.dtype.names:
            for character in self._forbidden_label_characters:
                if character in label:
                    raise ValueError(
                        "forbidden character '{char}' is in potential "
                        "label '{label}' - you can disable this verification "
                        "by enabling `live_dangerously`".format(
                            char=character, label=label))
        return None


    def _verify_training_data(self):
        """
        Verify the training data for the appropriate shape and content.

        :raises ValueError:
            If the flux and uncertainty shapes differ, no named labels were
            given, the number of label rows differs from the number of spectra,
            or the dispersion length differs from the number of pixels.
        """
        if self.training_fluxes.shape != self.training_flux_uncertainties.shape:
            raise ValueError(
                "the training flux and uncertainty arrays should "
                "have the same shape")

        if len(self.training_labels) == 0 \
        or self.training_labels.dtype.names is None:
            raise ValueError("no named labels provided for the training set")

        if len(self.training_labels) != self.training_fluxes.shape[0]:
            raise ValueError(
                "the first axes of the training flux array should "
                "have the same shape as the number of rows in the label table "
                "(N_stars, N_pixels)")

        if self.dispersion is not None:
            dispersion = np.atleast_1d(self.dispersion).flatten()
            if dispersion.size != self.training_fluxes.shape[1]:
                # Report each count next to its own description.
                raise ValueError(
                    "mis-match between the number of wavelength "
                    "points ({0}) and flux values ({1})".format(
                        dispersion.size, self.training_fluxes.shape[1]))
        return None


    @property
    def is_trained(self):
        """ Whether the model has been trained (or loaded from disk). """
        return self._trained


    # Attributes related to the labels and the label vector description.
    @property
    def labels_available(self):
        """
        All of the available labels for each star in the training set.
        """
        return self.training_labels.dtype.names


    @property
    def label_vector(self):
        """ The label vector for all pixels. """
        return self._label_vector


    @label_vector.setter
    def label_vector(self, label_vector_description):
        """
        Set a label vector.

        :param label_vector_description:
            A structured or human-readable version of the label vector
            description, or None to clear it.

        :raises ValueError:
            If the description refers to labels not in the training set.
        """

        if label_vector_description is None:
            self._label_vector = None
            self.reset()
            return None

        label_vector = utils.parse_label_vector(label_vector_description)

        # Need to actually verify that the parameters listed in the label vector
        # are actually present in the training labels.
        missing = \
            set(self._get_labels(label_vector)).difference(self.labels_available)
        if missing:
            raise ValueError("the following labels parsed from the label "
                             "vector description are missing in the training "
                             "set of labels: {0}".format(", ".join(missing)))

        # If this is really a new label vector description,
        # then we are no longer trained.
        if not hasattr(self, "_label_vector") \
        or label_vector != self._label_vector:
            self._label_vector = label_vector
            self.reset()

        return None


    @property
    def human_readable_label_vector(self):
        """ Return a human-readable form of the label vector. """
        return utils.human_readable_label_vector(self.label_vector)


    @property
    def labels(self):
        """ The labels that contribute to the label vector. """
        return self._get_labels(self.label_vector)


    def _get_labels(self, label_vector):
        """
        Return the labels that contribute to the structured label vector
        provided, in order of first appearance and without duplicates.
        Labels that only appear with power 0 are excluded. Returns an empty
        tuple when `label_vector` is None.
        """
        return () if label_vector is None else \
            list(OrderedDict.fromkeys([label for term in label_vector \
                for label, power in term if power != 0]))


    def _get_lowest_order_label_indices(self):
        """
        Get the indices for the lowest power label terms in the label vector.

        Only single-label terms are considered. For each label in `labels`,
        the index of its lowest-power single-label term is returned, or None
        if there is no such term.
        """
        indices = OrderedDict()
        for i, term in enumerate(self.label_vector):
            if len(term) > 1:
                # Skip cross terms.
                continue
            label, order = term[0]
            # Keep the term with the smallest power seen so far.
            if order < indices.get(label, [None, np.inf])[-1]:
                indices[label] = (i, order)
        return [indices.get(label, [None])[0] for label in self.labels]


    # Trained attributes that subclasses are likely to use.
    @property
    def coefficients(self):
        """ The label vector coefficients for each pixel. """
        return self._coefficients


    @coefficients.setter
    def coefficients(self, coefficients):
        """
        Set the label vector coefficients for each pixel. This assumes a
        'standard' model where the label vector is common to all pixels.

        :param coefficients:
            A 2-D array of coefficients of shape
            (`N_pixels`, `N_label_vector_terms`), where the label vector
            terms include the leading constant term.
        """

        if coefficients is None:
            self._coefficients = None
            return None

        coefficients = np.atleast_2d(coefficients)
        if len(coefficients.shape) > 2:
            raise ValueError("coefficients must be a 2D array")

        P, Q = coefficients.shape
        if P != len(self.dispersion):
            raise ValueError("axis 0 of coefficients array does not match the "
                             "number of pixels ({0} != {1})".format(
                                P, len(self.dispersion)))
        # The extra column is for the constant term.
        if Q != 1 + len(self.label_vector):
            raise ValueError("axis 1 of coefficients array does not match the "
                             "number of label vector terms ({0} != {1})".format(
                                Q, 1 + len(self.label_vector)))
        self._coefficients = coefficients
        return None


    @property
    def scatter(self):
        """ The scatter term for each pixel. """
        return self._scatter


    @scatter.setter
    def scatter(self, scatter):
        """
        Set the scatter values for each pixel.

        :param scatter:
            A 1-D array of scatter terms (zero is allowed; negative values are
            rejected), or None to clear it.
        """

        if scatter is None:
            self._scatter = None
            return None

        scatter = np.array(scatter).flatten()
        if scatter.size != len(self.dispersion):
            raise ValueError("number of scatter values does not match "
                             "the number of pixels ({0} != {1})".format(
                                scatter.size, len(self.dispersion)))
        if np.any(scatter < 0):
            raise ValueError("scatter terms must be positive")
        self._scatter = scatter
        return None


    # Methods which must be implemented or updated by the subclasses.
    def pixel_label_vector(self, pixel_index):
        """
        The label vector for a given pixel. By default every pixel shares the
        same label vector.
        """
        return self.label_vector


    def train(self, *args, **kwargs):
        raise NotImplementedError("The train method must be "
                                  "implemented by subclasses")


    def predict(self, *args, **kwargs):
        raise NotImplementedError("The predict method must be "
                                  "implemented by subclasses")


    def fit(self, *args, **kwargs):
        raise NotImplementedError("The fit method must be "
                                  "implemented by subclasses")


    # I/O
    @requires_training_wheels
    def save(self, filename, include_training_data=False, overwrite=False):
        """
        Serialise the trained model and save it to disk. This will save all
        relevant training attributes, and optionally, the training data.

        :param filename:
            The path to save the model to.

        :param include_training_data: [optional]
            Save the training data (labels, fluxes, uncertainties) used to train
            the model.

        :param overwrite: [optional]
            Overwrite the existing file path, if it already exists.

        :raises IOError:
            If `filename` exists and `overwrite` is False.

        :raises ValueError:
            If any attribute would be stored under the reserved key 'metadata'.
        """

        if path.exists(filename) and not overwrite:
            raise IOError("filename already exists: {0}".format(filename))

        descriptive_attributes = list(self._descriptive_attributes)
        trained_attributes = list(self._trained_attributes)
        data_attributes = list(self._data_attributes)

        # Attributes are stored under their names with leading underscores
        # stripped, so check the stripped names against the reserved key.
        stored_names = [attr.lstrip("_") for attr in
            descriptive_attributes + trained_attributes + data_attributes]
        if "metadata" in stored_names:
            raise ValueError("'metadata' is a protected attribute and cannot "
                             "be used in the _*_attributes in a class")

        # Store up all the trained attributes and a hash of the training set.
        contents = OrderedDict([
            (attr.lstrip("_"), getattr(self, attr))
            for attr in descriptive_attributes + trained_attributes])
        contents["training_set_hash"] = utils.short_hash(
            getattr(self, attr) for attr in data_attributes)

        if include_training_data:
            contents.update([(attr.lstrip("_"), getattr(self, attr))
                for attr in data_attributes])

        contents["metadata"] = {
            "version": code_version,
            "model_name": type(self).__name__,
            "modified": str(datetime.now()),
            "data_attributes": [_.lstrip("_") for _ in data_attributes],
            "trained_attributes": [_.lstrip("_") for _ in trained_attributes],
            "descriptive_attributes": \
                [_.lstrip("_") for _ in descriptive_attributes]
        }

        with open(filename, "wb") as fp:
            pickle.dump(contents, fp, -1)

        return None


    def load(self, filename, verify_training_data=False):
        """
        Load a saved model from disk.

        :param filename:
            The path where to load the model from. The file is unpickled, so
            it must come from a trusted source.

        :param verify_training_data: [optional]
            If there is training data in the saved model, verify its contents.
            Otherwise if no training data is saved, verify that the data used
            to train the model is the same data provided when this model was
            instantiated.

        :raises ValueError:
            If the saved model is of a different class, or the training data
            hash does not match when `verify_training_data` is enabled.
        """

        # NOTE: unpickling can execute arbitrary code; never load untrusted
        # files.
        with open(filename, "rb") as fp:
            contents = pickle.load(fp)

        metadata = contents["metadata"]
        if metadata["model_name"] != type(self).__name__:
            raise ValueError("the saved model is a {0} but this model is a "
                             "{1}".format(metadata["model_name"],
                                          type(self).__name__))

        data_attributes = metadata["data_attributes"]
        # A model may declare no data attributes at all; guard the lookup.
        has_data = bool(data_attributes) and data_attributes[0] in contents

        if verify_training_data:
            if has_data:
                data = (contents[attr] for attr in data_attributes)
            else:
                # No saved data: compare against this model's own data.
                data = (getattr(self, "_{}".format(attr))
                    for attr in data_attributes)
            data_hash = utils.short_hash(data)
            expected_hash = contents.get("training_set_hash", None)
            if expected_hash is not None and data_hash != expected_hash:
                raise ValueError("expected hash for the training data is "
                                 "different to the actual data hash "
                                 "({0} != {1})".format(
                                    expected_hash, data_hash))

        # Set the data attributes.
        if has_data:
            for attribute in data_attributes:
                if attribute in contents:
                    setattr(self, "_{}".format(attribute), contents[attribute])

        # Set descriptive and trained attributes.
        self.reset()
        for attribute in metadata["descriptive_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        for attribute in metadata["trained_attributes"]:
            setattr(self, "_{}".format(attribute), contents[attribute])
        self._trained = True

        return None


    # Properties and attributes related to training, etc.
    @property
    @requires_model_description
    def labels_array(self):
        """
        Return an array containing just the training labels, given the label
        vector. The array has one row per training star and one column per
        label in `labels`.
        """
        # Drop the leading constant row, then transpose to (stars, labels).
        return _build_label_vector_rows(
            [[(label, 1)] for label in self.labels], self.training_labels)[1:].T


    @property
    @requires_model_description
    def label_vector_array(self):
        """
        Build the label vector array, including the leading row of ones for
        the constant term. A warning is logged if any value is non-finite.
        """

        lva = _build_label_vector_rows(self.label_vector, self.training_labels)

        if not np.all(np.isfinite(lva)):
            logger.warning("Non-finite labels in the label vector array!")
        return lva


    # Residuals in labels in the training data set.
    @requires_training_wheels
    def get_training_label_residuals(self):
        """
        Return the residuals (model - training) between the parameters that the
        model returns for each star, and the training set value.
        """

        optimised_labels = self.fit(self.training_fluxes,
            self.training_flux_uncertainties, full_output=False)

        return optimised_labels - self.labels_array


    def _format_input_labels(self, args=None, **kwargs):
        """
        Format input labels either from a list or dictionary into a common form.

        :param args: [optional]
            A sequence of label values in the order of `labels`. If given,
            keyword arguments are ignored.

        :returns:
            A dictionary mapping label names to lists (or arrays) of values;
            scalar values are wrapped in a one-element list.
        """

        # We want labels in a dictionary.
        labels = kwargs if args is None else dict(zip(self.labels, args))
        return { k: (v if isinstance(v, (list, tuple, np.ndarray)) else [v]) \
            for k, v in labels.items() }


    # Put Cross-validation functions in here.
    @requires_model_description
    def cross_validate(self, pre_train=None, **kwargs):
        """
        Perform leave-one-out cross-validation on the training set.

        For each held-out star, a new model of the same class is built from
        all other training stars, given this model's description, trained,
        and then fit to the held-out star's spectrum.

        :param pre_train: [optional]
            A callable invoked as `pre_train(self, model)` after the new model
            has been given the model description and before it is trained.

        :param N: [optional]
            Only cross-validate the first `N` stars of the training set.

        :param debug: [optional]
            Re-raise exceptions raised while fitting a held-out star instead
            of only logging them.

        Any other keyword arguments are passed to the constructor of each
        cross-validation model.

        :returns:
            An array with one row per cross-validated star and one column per
            label. Rows for stars whose fit failed are NaN.
        """

        inferred = np.full(self.labels_array.shape, np.nan)
        N_training_set = inferred.shape[0]
        N_stop_at = kwargs.pop("N", N_training_set)

        debug = kwargs.pop("debug", False)

        # Share the dispersion, pool and label-checking choice with every fold
        # so folds match this model. Sharing the pool also stops each fold
        # from spawning (and leaking) its own process pool.
        kwds = {
            "dispersion": self.dispersion,
            "threads": self.threads,
            "pool": self.pool,
            "live_dangerously": self._live_dangerously,
        }
        kwds.update(kwargs)

        for i in range(min(N_stop_at, N_training_set)):

            training_set = np.ones(N_training_set, dtype=bool)
            training_set[i] = False

            # Create a clean model to use so we don't overwrite self.
            model = self.__class__(
                self.training_labels[training_set],
                self.training_fluxes[training_set],
                self.training_flux_uncertainties[training_set],
                **kwds)

            # Initialise and run any pre-training function.
            for _attribute in self._descriptive_attributes:
                setattr(model, _attribute[1:], getattr(self, _attribute[1:]))

            if pre_train is not None:
                pre_train(self, model)

            # Train and solve.
            model.train()

            try:
                inferred[i, :] = model.fit(self.training_fluxes[i],
                    self.training_flux_uncertainties[i], full_output=False)

            except Exception:
                logger.exception("Exception during cross-validation on object "
                                 "with index {0}:".format(i))
                if debug:
                    raise

        return inferred[:N_stop_at, :]
631
632
633
def _build_label_vector_rows(label_vector, training_labels):
634
    """
635
    Build a label vector row from a description of the label vector (as indices
636
    and orders to the power of) and the label values themselves.
637
638
    For example: if the first item of `labels` is `A`, and the label vector
639
    description is `A^3` then the first item of `label_vector` would be:
640
641
    `[[(0, 3)], ...`
642
643
    This indicates the first label item (index `0`) to the power `3`.
644
645
    :param label_vector:
646
        An `(index, order)` description of the label vector. 
647
648
    :param training_labels:
649
        The values of the corresponding training labels.
650
651
    :returns:
652
        The corresponding label vector row.
653
    """
654
655
    columns = [np.ones(len(training_labels), dtype=float)]
656
    for term in label_vector:
657
        column = 1.
658
        for label, order in term:
659
            column *= np.array(training_labels[label]).flatten()**order
660
        columns.append(column)
661
662
    try:
663
        return np.vstack(columns)
664
665
    except ValueError:
666
        columns[0] = np.ones(1)
667
        return np.vstack(columns)
668