plot_first_order_derivatives()   F
last analyzed

Complexity

Conditions 25

Size

Total Lines 106

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 25
c 1
b 0
f 0
dl 0
loc 106
rs 2

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.

Complexity

Complex code like plot_first_order_derivatives() often does a lot of different things. To break it down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
2
3
"""
4
Plot first-order element coefficients as a function of lambda.
5
"""
6
7
8
import matplotlib.pyplot as plt
9
#plt.rcParams["text.usetex"] = True
10
#plt.rcParams["text.latex.preamble"] = [r"\usepackage{amsmath}"]
11
import matplotlib.colors as cm
12
import numpy as np
13
import os
14
from scipy import optimize as op
15
from matplotlib.ticker import MaxNLocator
16
17
import colormaps as cmaps
18
19
import AnniesLasso as tc
20
21
22
def _show_xlim_changes(fig, diag=0.015, xtol=0):
23
24
    N = len(fig.axes)
25
    if 2 > N: return
26
    for i, ax in enumerate(fig.axes):
27
28
        if i > 0:
29
            # Put LHS break marks in.
30
            kwargs = dict(transform=ax.transAxes, color="k", clip_on=False)
31
            ax.plot((-diag, +diag), (  - diag,   + diag), **kwargs)
32
            ax.plot((-diag, +diag), (1 - diag, 1 + diag), **kwargs)
33
34
        if i != N - 1:
35
            # Put RHS break marks in.
36
            kwargs = dict(transform=ax.transAxes, color="k", clip_on=False)
37
            ax.plot((1 - diag, 1 + diag), (1 - diag, 1 + diag), **kwargs) 
38
            ax.plot((1 - diag, 1 + diag), (  - diag,   + diag), **kwargs)
39
40
        # Control spines depending on which axes it is
41
        if i == 0:
42
            ax.yaxis.tick_left() 
43
            ax.spines["right"].set_visible(False)
44
            ax.set_xlim(ax.get_xlim()[0], ax.get_xlim()[1] - xtol)
45
46
        elif i > 0 and i != N - 1:
47
            ax.set_xlim(ax.get_xlim()[0] + xtol, ax.get_xlim()[1] - xtol)
48
            ax.yaxis.set_tick_params(size=0)
49
            ax.tick_params(labelleft='off')
50
            ax.spines["left"].set_visible(False)
51
            ax.spines["right"].set_visible(False)
52
53
        else:
54
            ax.set_xlim(ax.get_xlim()[0] + xtol, ax.get_xlim()[1])
55
            ax.yaxis.tick_right()
56
            ax.tick_params(labelleft='off')
57
            ax.spines["left"].set_visible(False)
58
59
    return None
60
61
62
63
def _first_order_colour_lookup(label_names, colors=None):
    """Return a callable that maps a label position to a plotting colour."""

    if colors is not None:
        # Explicit colours take priority and are cycled when there are fewer
        # colours than labels.
        return lambda position: colors[position % len(colors)]

    if len(label_names) > 1:
        # pyplot.get_cmap takes the same (name, lut) arguments as
        # matplotlib.cm.get_cmap, which newer Matplotlib no longer provides.
        return plt.get_cmap("Set1", len(label_names))

    return lambda position: "k"


def _mark_label_wavelengths(axes, label_wavelengths, label_names, cmap,
    label_yvalue=1.0):
    """Draw a vertical tick at ``label_yvalue`` for every listed wavelength."""

    known_names = list(label_names)
    for label_name, wavelengths in label_wavelengths.items():
        try:
            color = cmap(known_names.index(label_name))
        except (IndexError, ValueError):
            # Labels that are not being plotted are marked in black.
            color = "k"

        for ax in axes:
            ax.plot(wavelengths, label_yvalue * np.ones_like(wavelengths),
                "|", markersize=20, markeredgewidth=2, c=color)


def plot_first_order_derivatives(model, label_names=None, scaled=True,
    show_clipped_region=False, colors=None, zorders=None,
    clip_less_than=None, label_wavelengths=None, latex_label_names=None,
    wavelength_regions=None, show_legend=True, **kwargs):
    """
    Plot the first-order coefficients of a model against its dispersion.

    One panel is drawn per wavelength region, all in a single row. Every
    requested label is drawn in every panel; each panel is then restricted to
    its own region with the x-limits.

    :param model:
        An object providing ``dispersion``, ``theta`` (indexed as
        ``theta[:, k]``) and ``vectorizer.label_names``.

    :param label_names:
        The labels to plot. Defaults to ``model.vectorizer.label_names``.

    :param scaled:
        Divide each coefficient vector by its largest absolute value and fix
        the y-limits to (-1.2, 1.2).

    :param show_clipped_region:
        Shade the band between -``clip_less_than`` and +``clip_less_than``
        (divided by the same scale as the coefficients) for every label.
        Only used when ``clip_less_than`` is given.

    :param colors:
        Optional sequence of colours, indexed by label position and cycled
        when shorter than ``label_names``.

    :param zorders:
        Optional sequence giving the zorder for each label (default 1).

    :param clip_less_than:
        Coefficients with an absolute value below this are plotted as zero.
        The model's own ``theta`` array is left untouched.

    :param label_wavelengths:
        Optional dictionary mapping a label name to wavelengths that are
        marked with vertical ticks at y = 1. Names that are not in
        ``label_names`` are marked in black.

    :param latex_label_names:
        Optional dictionary mapping a label name to the text shown in the
        legend.

    :param wavelength_regions:
        List of ``(lower, upper)`` x-limits, one panel each. Defaults to the
        full range of ``model.dispersion``.

    :param show_legend:
        Add a legend to the first panel.

    :param kwargs:
        Only ``legend_ncol`` (number of legend columns) is read.

    :returns:
        The figure.
    """

    if wavelength_regions is None:
        wavelength_regions = [(model.dispersion[0], model.dispersion[-1])]

    if label_names is None:
        label_names = model.vectorizer.label_names

    if latex_label_names is None:
        latex_label_names = {}

    fig, axes = plt.subplots(1, len(wavelength_regions), figsize=(15, 3.5))
    axes = np.array(axes).flatten()

    cmap = _first_order_colour_lookup(label_names, colors)

    for i, label_name in enumerate(label_names):

        # First order derivatives are always indexed first.
        index = 1 + model.vectorizer.label_names.index(label_name)

        # Work on a copy: slicing theta gives a view, so clipping in place
        # would silently zero out coefficients in the model itself.
        y = np.array(model.theta[:, index])
        if clip_less_than is not None:
            y[np.abs(y) < clip_less_than] = 0

        scale = 1.
        if scaled:
            peak = np.nanmax(np.abs(y))
            # A fully clipped (all-zero) or all-NaN vector has no usable peak;
            # leave it unscaled instead of dividing by zero.
            if np.isfinite(peak) and peak > 0:
                scale = peak
        y = y / scale

        c = cmap(i)
        zorder = 1 if zorders is None else zorders[i]
        legend_text = latex_label_names.get(label_name, label_name)

        for ax in axes:
            ax.plot(model.dispersion, y, c=c, zorder=zorder,
                label=legend_text)

            if clip_less_than is not None and show_clipped_region:
                ax.axhspan(-clip_less_than/scale, +clip_less_than/scale,
                    xmin=-1, xmax=+2,
                    facecolor=c, edgecolor=c, zorder=-100, alpha=0.1)

    # Plot any wavelengths.
    if label_wavelengths is not None:
        _mark_label_wavelengths(axes, label_wavelengths, label_names, cmap)

    for position, (ax, wavelength_region) in enumerate(
        zip(axes, wavelength_regions)):

        ax.set_xlim(wavelength_region)
        if scaled:
            ax.set_ylim(-1.2, 1.2)

        # The panels form a single row, so the first panel is the first
        # column. This avoids Axes.is_first_col(), which is not available in
        # newer Matplotlib releases.
        if position == 0:
            if scaled:
                ax.set_ylabel(r"$\theta/{\max|\theta|}$")
            else:
                ax.set_ylabel(r"$\theta$")
        else:
            ax.set_yticklabels([])

        ax.xaxis.set_major_locator(MaxNLocator(4))

    xlabel = r"$\lambda$ $({\rm\AA})$"
    if len(wavelength_regions) == 1:
        axes[0].set_xlabel(xlabel)

    else:
        _show_xlim_changes(fig)

        # A single x-label shared by all panels, drawn on an invisible axes
        # that spans the whole figure.
        overlay = fig.add_axes([0, 0, 1, 1])
        overlay.set_axis_off()
        overlay.set_xlim(0, 1)
        overlay.set_ylim(0, 1)
        overlay.text(0.5, 0.05, xlabel, rotation='horizontal',
            horizontalalignment='center', verticalalignment='center')

    fig.tight_layout()

    if show_legend:
        # NOTE(review): the modulo default is kept from the original; only the
        # zero case (label count a multiple of 7) is guarded, because a legend
        # cannot have zero columns.
        default_ncol = (len(label_names) % 7) or 1
        axes[0].legend(loc="upper right",
            ncol=kwargs.get("legend_ncol", default_ncol), frameon=False)
    fig.subplots_adjust(wspace=0.01, bottom=0.20)

    return fig
169
170
171
172
173
174
175
176
177
178
179
if __name__ == "__main__":

    # Example driver: load a trained model and write two figures of its
    # first-order coefficients (an overview and a zoom) to the papers/ folder.

    PATH, CATALOG, FILE_FORMAT = ("/Volumes/My Book/surveys/apogee/",
        "apogee-rg.fits", "apogee-rg-custom-normalization-{}.memmap")

    model = tc.load_model("gridsearch-20.0-3.0.model")
    # Take the dispersion from the memory-mapped file on disk.
    model._dispersion = np.memmap(
        os.path.join(PATH, FILE_FORMAT).format("dispersion"),
        mode="c", dtype=float)

    # Overview figure: three labels in two wavelength regions.
    fig = plot_first_order_derivatives(model,
        label_names=["AL_H", "S_H", "K_H"],
        clip_less_than=np.std(np.abs(model.theta[:, 6])),
        label_wavelengths={
            "AL_H": [16723.524113765838, 16767.938194147067],
            "K_H": [15172.521340566429],
            "Missing": [15235.7, 16755.1]
        },
        latex_label_names={
            "AL_H": r"$[\rm{Al}/\rm{H}]$",
            "K_H": r"$[\rm{K}/\rm{H}]$",
            "S_H": r"$[\rm{S}/\rm{H}]$",
            "Missing": "Missing/Unknown (Shetrone+ 2015)"
        },
        show_clipped_region=True,
        wavelength_regions=[
            (15152.465463818111, 15400),
            (16601, 16800),
        ])
    fig.savefig("papers/sparse-first-order-coefficients.pdf", dpi=300)
    fig.savefig("papers/sparse-first-order-coefficients.png", dpi=300)

    # Now zoom in around those sections.

    # Two labels are highlighted with Set1 colours; all others are drawn grey.
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap is not available
    # in newer Matplotlib releases.
    cmap = plt.get_cmap("Set1", 3)
    colors = [cmap(0)] + ["#CCCCCC"] * 11 + [cmap(1)] + ["#CCCCCC"] * 2

    fig = plot_first_order_derivatives(model,
        label_names=model.vectorizer.label_names[2:],
        clip_less_than=None,
        scaled=True,
        show_clipped_region=False,
        colors=colors,
        # Draw the two highlighted labels on top of the grey ones.
        zorders=[10] + [0] * 11 + [10] + [0] * 2,
        show_legend=False,
        latex_label_names={
            "AL_H": r"$[\rm{Al}/\rm{H}]$",
            "CA_H": r"$[\rm{Ca}/\rm{H}]$",
            "C_H": r"$[\rm{C}/\rm{H}]$",
            "FE_H": r"$[\rm{Fe}/\rm{H}]$",
            "K_H": r"$[\rm{K}/\rm{H}]$",
            "MG_H": r"$[\rm{Mg}/\rm{H}]$",
            "MN_H": r"$[\rm{Mn}/\rm{H}]$",
            "NA_H": r"$[\rm{Na}/\rm{H}]$",
            "NI_H": r"$[\rm{Ni}/\rm{H}]$",
            "N_H": r"$[\rm{N}/\rm{H}]$",
            "O_H": r"$[\rm{O}/\rm{H}]$",
            "SI_H": r"$[\rm{Si}/\rm{H}]$",
            "S_H": r"$[\rm{S}/\rm{H}]$",
            "TI_H": r"$[\rm{Ti}/\rm{H}]$",
            "V_H": r"$[\rm{V}/\rm{H}]$"
        },
        label_wavelengths={
            "Missing": [15235.7, 16755.1]
        },
        wavelength_regions=[
            (15235.6 - 10, 10 + 15235.6),
            (16755.1 - 10, 10 + 16755.1)
        ])

    # Relabel the x ticks of every axes except the last one in the figure,
    # using whole numbers.
    for ax in fig.axes[:-1]:
        ax.set_xticklabels([r"${0:.0f}$".format(_) for _ in ax.get_xticks()])

    fig.savefig("papers/sparse-first-order-coefficients-zoom.pdf", dpi=300)
    fig.savefig("papers/sparse-first-order-coefficients-zoom.png", dpi=300)
259
260
261