RestrictedCannonModel   A
last analyzed

Complexity

Total Complexity 12

Size/Duplication

Total Lines 127
Duplicated Lines 0 %

Importance

Changes 5
Bugs 0 Features 0
Metric Value
c 5
b 0
f 0
dl 0
loc 127
rs 10
wmc 12

3 Methods

Rating   Name   Duplication   Size   Complexity  
A __init__() 0 11 1
D theta_bounds() 0 4 8
B train() 0 24 2
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
A restricted Cannon model where bounds are placed on theta coefficients in order
6
to make the model more physically realistic and limit information propagated
7
through abundance correlations.
8
"""
9
10
from __future__ import (division, print_function, absolute_import,
11
                        unicode_literals)
12
13
__all__ = ["RestrictedCannonModel"]
14
15
import logging
16
from .model import CannonModel
17
18
logger = logging.getLogger(__name__)
19
20
21
class RestrictedCannonModel(CannonModel):
    """
    A model for The Cannon which includes L1 regularization, pixel censoring,
    and is capable of placing bounds on theta coefficients in order to make the
    model more physically realistic and limit information propagated through
    abundance correlations.

    :param training_set_labels:
        A set of objects with labels known to high fidelity. This can be
        given as a numpy structured array, or an astropy table.

    :param training_set_flux:
        An array of normalized fluxes for stars in the training set, given
        as shape `(num_stars, num_pixels)`. The `num_stars` should match the
        number of rows in `training_set_labels`.

    :param training_set_ivar:
        An array of inverse variances on the normalized fluxes for stars in
        the training set. The shape of the `training_set_ivar` array should
        match that of `training_set_flux`.

    :param vectorizer:
        A vectorizer to take input labels and produce a design matrix. This
        should be a sub-class of `vectorizer.BaseVectorizer`.

    :param dispersion: [optional]
        The dispersion values corresponding to the given pixels. If provided,
        this should have a size of `num_pixels`.

    :param regularization: [optional]
        The strength of the L1 regularization. This should either be `None`,
        a float-type value for single regularization strength for all pixels,
        or a float-like array of length `num_pixels`.

    :param censors: [optional]
        A dictionary containing label names as keys and boolean censoring
        masks as values.

    :param theta_bounds: [optional]
        A dictionary whose keys are terms of the vectorizer's human-readable
        label vector (e.g., `"FE_H"` or `"TEFF^3"`) and whose values are
        two-length `(min, max)` tuples giving the acceptable range of the
        corresponding theta coefficient. Specify `None` for either element to
        leave that side unbounded. Terms that are not in the label vector are
        ignored with a logged warning.
    """

    def __init__(self, training_set_labels, training_set_flux,
                 training_set_ivar, vectorizer, dispersion=None,
                 regularization=None, censors=None, theta_bounds=None,
                 **kwargs):
        super(RestrictedCannonModel, self).__init__(
            training_set_labels, training_set_flux, training_set_ivar,
            vectorizer, dispersion=dispersion, regularization=regularization,
            censors=censors, **kwargs)

        # Goes through the validating setter below. The setter reads
        # `self.vectorizer`, so this must run after the parent constructor
        # (which is expected to set it).
        self.theta_bounds = theta_bounds
        return None


    @property
    def theta_bounds(self):
        """
        Return the validated boundaries placed on theta coefficients, as a
        dictionary mapping label-vector terms to `(min, max)` pairs.
        """
        return self._theta_bounds


    @theta_bounds.setter
    def theta_bounds(self, theta_bounds):
        """
        Set lower and upper boundaries on specific theta coefficients.

        :param theta_bounds:
            A dictionary containing vectorizer terms as keys and two-length
            tuples as values, indicating acceptable minimum and maximum values.
            Specify `None` to indicate no limit on a boundary. For example:
            `theta_bounds={"FE_H": (None, 0), "TEFF^3": (None, None)}`
            Passing `None` clears all boundaries. Keys are converted with
            `str()`; keys that are not terms of the label vector are dropped
            with a logged warning.

        :raises TypeError:
            If `theta_bounds` is neither `None` nor a `dict`.

        :raises ValueError:
            If a boundary does not have exactly two elements, or if both
            elements are given and the maximum is smaller than the minimum.
        """
        if theta_bounds is None:
            theta_bounds = {}
        if not isinstance(theta_bounds, dict):
            raise TypeError("theta_bounds must be a dictionary-like object")

        label_vector = self.vectorizer.human_readable_label_vector
        # Terms are the " + "-separated pieces of the human-readable vector.
        terms = set(label_vector.split(" + "))

        checked_bounds = {}
        for term, bounds in theta_bounds.items():
            term = str(term)

            if term not in terms:
                logger.warning("Boundary on term '{}' ignored because it is "
                               "not in the label vector: {}".format(
                                   term, label_vector))
                continue

            if len(bounds) != 2:
                raise ValueError("bounds must be a two-length tuple")
            lower, upper = bounds
            # Only an ordering violation when both sides are actually bounded.
            if lower is not None and upper is not None and upper < lower:
                raise ValueError("bounds must be in (min, max) order")

            checked_bounds[term] = bounds

        self._theta_bounds = checked_bounds


    def train(self, threads=None, op_kwds=None):
        """
        Train the model, constraining each theta coefficient to the bounds in
        `self.theta_bounds`.

        Training is delegated to `CannonModel.train` with
        `op_method="l_bfgs_b"` and `op_strict=False`. One `(min, max)` pair
        per label-vector term is passed to the optimizer as
        `op_kwds["bounds"]`, and terms without a boundary get `(None, None)`.

        :param threads: [optional]
            The number of parallel threads to use.

        :param op_kwds: [optional]
            Keyword arguments to provide directly to the optimization function.
            The given dictionary is copied and left unmodified. Any `bounds`
            entry it contains is replaced by the bounds derived from
            `theta_bounds`.

        :returns:
            The result of `CannonModel.train`: a three-length tuple containing
            the spectral coefficients `theta`, the squared scatter term at each
            pixel `s2`, and metadata related to the training of each pixel.
        """
        # One bound pair per term, in the same order as the label vector.
        terms = self.vectorizer.human_readable_label_vector.split(" + ")
        op_bounds = [self.theta_bounds.get(term, (None, None))
                     for term in terms]

        # Copy before adding `bounds` so the caller's dictionary is not
        # mutated as a side effect of training.
        op_kwds = dict(op_kwds or {})
        op_kwds["bounds"] = op_bounds

        return super(RestrictedCannonModel, self).train(
            threads=threads, op_method="l_bfgs_b", op_strict=False,
            op_kwds=op_kwds)