theta()   F
last analyzed

Complexity

Conditions 18

Size

Total Lines 140

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 1 Features 0
Metric  Value
cc      18
c       1
b       1
f       0
dl      0
loc     140
rs      2

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method; the comment is a good starting point for naming the new method.

Commonly applied refactorings include Extract Method.
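As a minimal sketch of Extract Method (all names below are hypothetical, not from the code under review), the commented step inside a long function becomes a named helper, and the comment becomes the name:

```python
# Before: a comment marks a distinct step inside a longer function.
def report(orders):
    # Compute the total including a 10% service fee.
    total = sum(order["amount"] for order in orders) * 1.10
    return "Total due: {:.2f}".format(total)


# After: the commented step is extracted; the comment became the name.
def total_with_service_fee(orders, fee=0.10):
    return sum(order["amount"] for order in orders) * (1.0 + fee)


def report_refactored(orders):
    return "Total due: {:.2f}".format(total_with_service_fee(orders))
```

Both versions produce the same result, but the refactored one gives the fee calculation a name and a testable seam.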

Complexity

Complex classes and functions like theta() often do a lot of different things. To break such code down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
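A hypothetical sketch of Extract Class (none of these names come from the code under review): fields sharing the `billing_` prefix are moved into a class of their own.

```python
# Before: one class carries a cluster of fields with a shared prefix.
class Customer:
    def __init__(self, name, billing_street, billing_city):
        self.name = name
        self.billing_street = billing_street
        self.billing_city = billing_city


# After: the prefixed fields become a cohesive component of their own.
class Address:
    def __init__(self, street, city):
        self.street = street
        self.city = city

    def label(self):
        return "{}, {}".format(self.street, self.city)


class CustomerRefactored:
    def __init__(self, name, billing_address):
        self.name = name
        self.billing_address = billing_address
```

Behaviour that only concerns the address (such as formatting it) now has an obvious home on the extracted class.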

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Plotting utilities for The Cannon.
"""

from __future__ import (division, print_function, absolute_import,
                        unicode_literals)

__all__ = ["theta", "scatter", "one_to_one"]

import logging
import numpy as np

logger = logging.getLogger(__name__)

try:
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

except ImportError:
    logger.warning("Could not import matplotlib; plotting functionality disabled")


def theta(model, indices=None, label_terms=None, show_label_terms=True,
    normalize=True, common_axis=False, latex_label_names=None, xlim=None,
    **kwargs):
    r"""
    Plot the spectral derivatives (:math:`\boldsymbol{\theta}` coefficients)
    from a trained model.

    :param model:
        A trained CannonModel object.

    :param indices: [optional]
        The indices of :math:`\boldsymbol{\theta}` to plot. By default all
        coefficients will be shown.

    :param label_terms: [optional]
        Specify the label terms to show coefficients for. This is similar to
        specifying the `indices`, except you don't have to calculate the
        position of each label name.

        For example, specifying ``indices=0`` and ``label_terms=['TEFF', 'MG_H']``
        would show the first :math:`\theta` value (mean flux), as well as the
        :math:`\theta` coefficients that correspond to the linear terms of
        ``'TEFF'`` and ``'MG_H'``.

        Note that `label_terms` is specific to the model vectorizer. The
        vectorizer must be able to identify the label term by the inputs
        provided (e.g., a polynomial vectorizer will recognize ``'TEFF'`` as
        the linear coefficient of ``'TEFF'``, but ``'TEFF'`` on its own may
        not be recognizable to a vectorizer that uses sine and cosine
        functions).

    :param show_label_terms: [optional]
        Show the label terms on the right-hand side of each axis.

    :param normalize: [optional]
        Normalize each coefficient to the range [-1, 1], except for the first
        theta coefficient (mean flux).

    :param common_axis: [optional]
        Show all spectral derivatives on a single axes.

    :param latex_label_names: [optional]
        A list containing the label names as LaTeX representations.

    :param xlim: [optional]
        The x-limits to apply to all axes.

    :returns:
        A figure showing the spectral derivatives.
    """

    if not model.is_trained:
        raise ValueError("model needs to be trained first")

    if latex_label_names is None:
        label_names = model.vectorizer.label_names
    else:
        label_names = latex_label_names

    if indices is None and label_terms is None:
        label_indices = np.arange(model.theta.shape[1])
    else:
        label_indices = []
        if indices is not None:
            label_indices.extend(np.array(indices).astype(int).flatten())
        if label_terms is not None:
            raise NotImplementedError

    label_indices = np.array(label_indices)

    if len(set(label_indices)) < label_indices.size:
        logger.warning("Removing duplicate label indices")
        label_indices = np.unique(label_indices)

    K = len(label_indices)

    fig, axes = plt.subplots(K)
    axes = np.array([axes]).flatten()

    if common_axis:
        raise NotImplementedError

    if model.dispersion is None:
        x = np.arange(model.theta.shape[0])
    else:
        x = model.dispersion

    plot_kwds = dict(c="b", lw=1)
    plot_kwds.update(kwargs.get("plot_kwds", {}))

    for i, (ax, label_index) in enumerate(zip(axes, label_indices)):

        y = model.theta.T[label_index].copy()
        scale = np.max(np.abs(y)) if normalize and label_index != 0 else 1.0

        ax.plot(x, y/scale, **plot_kwds)

        if normalize and label_index != 0:
            ax.set_ylim(-1.2, 1.2)
            ax.set_yticks([-1, 1])
            ylabel = r"$\theta_{{{0}}}/\max{{|\theta_{{{0}}}|}}$".format(label_index)

        else:
            ylabel = r"$\theta_{{{0}}}$".format(label_index)
            ax.yaxis.set_major_locator(MaxNLocator(3))

        ax.set_ylabel(ylabel, rotation=0, verticalalignment="center")
        ax.yaxis.labelpad = 30

        # Show the human-readable label term on the right-hand side.
        if show_label_terms:
            rhs_ylabel = model.vectorizer.get_human_readable_label_term(
                label_index, label_names=label_names, mul=r"\cdot", pow="^")
            ax_rhs = ax.twinx()
            if latex_label_names is not None:
                rhs_ylabel = r"${}$".format(rhs_ylabel)

            ax_rhs.set_ylabel(rhs_ylabel, rotation=0, verticalalignment="center")
            ax_rhs.yaxis.labelpad = 30
            ax_rhs.set_yticks([])

        if ax.is_last_row():
            if model.dispersion is None:
                xlabel = r"${\rm Pixel}$"
            else:
                xlabel = r"${\rm Wavelength}$ $[{\rm \AA}]$"
            ax.set_xlabel(xlabel)

        else:
            ax.set_xticklabels([])

        ax.xaxis.set_major_locator(MaxNLocator(6))

        ax.set_xlim(xlim)

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.10)

    return fig


def scatter(model, ax=None, **kwargs):
    r"""
    Plot the noise residuals (:math:`s`) at each pixel.

    :param model:
        A trained CannonModel object.

    :param ax: [optional]
        The axes to plot on. If `None` is given, a new figure and axes will
        be created.

    :returns:
        A figure showing the noise residuals at every pixel.
    """

    if not model.is_trained:
        raise ValueError("model needs to be trained first")

    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    if model.dispersion is None:
        x = np.arange(model.s2.size)
    else:
        x = model.dispersion

    plot_kwds = dict(lw=1, c="b")
    plot_kwds.update(kwargs.pop("plot_kwds", {}))

    ax.plot(x, model.s2**0.5, **plot_kwds)

    if model.dispersion is None:
        ax.set_xlabel(r"${\rm Pixel}$")
    else:
        ax.set_xlabel(r"${\rm Wavelength}$ $[{\rm \AA}]$")

    ax.set_ylim(0, ax.get_ylim()[1])
    ax.set_ylabel(r"${\rm Scatter},$ $s$")

    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.yaxis.set_major_locator(MaxNLocator(6))

    if fig is not None:
        fig.tight_layout()
    else:
        fig = ax.figure

    return fig


def one_to_one(model, test_labels, cov=None, latex_label_names=None,
    show_statistics=True, **kwargs):
    r"""
    Plot a one-to-one comparison of the training set labels and the test set
    labels inferred from the training set spectra.

    :param model:
        A trained CannonModel object.

    :param test_labels:
        An array of test labels, inferred from the training set spectra.

    :param cov: [optional]
        The covariance matrix returned for all test labels.

    :param latex_label_names: [optional]
        A list of label names in LaTeX representation.

    :param show_statistics: [optional]
        Show the median and standard deviation of the residuals in each axis.

    :returns:
        A figure showing the one-to-one comparison for each label.
    """

    if model.training_set_labels.shape != test_labels.shape:
        raise ValueError(
            "test labels must have the same shape as training set labels")

    N, K = test_labels.shape
    if cov is not None and cov.shape != (N, K, K):
        raise ValueError(
            "shape mis-match in covariance matrix ({N}, {K}, {K}) != {shape}"\
            .format(N=N, K=K, shape=cov.shape))

    factor = 2.0
    lbdim = 0.30 * factor
    tdim = 0.25 * factor
    rdim = 0.10 * factor
    wspace = 0.05
    hspace = 0.35
    yspace = factor * K + factor * (K - 1.) * hspace
    xspace = factor

    xdim = lbdim + xspace + rdim
    ydim = lbdim + yspace + tdim

    fig, axes = plt.subplots(K, figsize=(xdim, ydim))

    l, b = (lbdim / xdim, lbdim / ydim)
    t, r = ((lbdim + yspace) / ydim, (lbdim + xspace) / xdim)

    fig.subplots_adjust(left=l, bottom=b, right=r, top=t, wspace=wspace,
        hspace=hspace)

    axes = np.array([axes]).flatten()

    scatter_kwds = dict(s=1, c="k", alpha=0.5)
    scatter_kwds.update(kwargs.get("scatter_kwds", {}))

    errorbar_kwds = dict(fmt="none", ecolor="k", alpha=0.5, capsize=0)
    errorbar_kwds.update(kwargs.get("errorbar_kwds", {}))

    for i, ax in enumerate(axes):

        x = model.training_set_labels[:, i]
        y = test_labels[:, i]

        ax.scatter(x, y, **scatter_kwds)
        if cov is not None:
            yerr = cov[:, i, i]**0.5
            ax.errorbar(x, y, yerr=yerr, **errorbar_kwds)

        # Set the x- and y-axis limits to be the same.
        limits = np.array([ax.get_xlim(), ax.get_ylim()])
        limits = (np.min(limits), np.max(limits))

        ax.plot(limits, limits, c="#666666", linestyle=":", zorder=-1)
        ax.set_xlim(limits)
        ax.set_ylim(limits)

        label_name = model.vectorizer.label_names[i]

        if latex_label_names is not None:
            try:
                label_name = r"${}$".format(latex_label_names[i])
            except (IndexError, TypeError):
                logger.warning(
                    "Could not access latex label name for index {} ({})"\
                    .format(i, label_name))

        ax.set_title(label_name)

        ax.xaxis.set_major_locator(MaxNLocator(4))
        ax.yaxis.set_major_locator(MaxNLocator(4))

        # Show the median and standard deviation of the residuals.
        if show_statistics:
            diff = y - x
            mu = np.median(diff)
            sigma = np.std(diff)
            ax.text(0.05, 0.85, r"$\mu = {0:.2f}$".format(mu),
                transform=ax.transAxes)
            ax.text(0.05, 0.75, r"$\sigma = {0:.2f}$".format(sigma),
                transform=ax.transAxes)

        ax.set_aspect(1.0)

    return fig
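The per-coefficient normalization inside theta() can be sketched in isolation: each row of the theta matrix is scaled by its maximum absolute value so it fits in [-1, 1], except the first row (the mean flux), which is left untouched. The helper below is hypothetical, not part of this module, and exists only to illustrate that step without matplotlib.

```python
import numpy as np

def normalized_rows(theta, skip_first=True):
    """Scale each coefficient row by its max |value|, as theta() does,
    leaving the first (mean flux) row untouched when skip_first is True."""
    out = []
    for index, row in enumerate(theta):
        if skip_first and index == 0:
            scale = 1.0
        else:
            scale = np.max(np.abs(row))
        out.append(np.asarray(row, dtype=float) / scale)
    return np.array(out)
```

Plotting `y/scale` with fixed y-limits of (-1.2, 1.2), as theta() does, then makes the shapes of different coefficients directly comparable across panels.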