Completed
Push — master ( 4fc32f...027495 )
by Andy
01:22
created

plot_cluster_comparison()   F

Complexity

Conditions 41

Size

Total Lines 240

Duplication

Lines 48
Ratio 20 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
dl 48
loc 240
rs 2
c 2
b 0
f 0
cc 41

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include: Extract Method.

Complexity

Complex code units like plot_cluster_comparison() (a function, in this case) often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods (or, inside a function, variables and statements) that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
Plot some cluster abundance stuff compared to ASPCAP.
3
"""
4
5
6
import matplotlib
7
matplotlib.rcParams["text.usetex"] = True
8
9
10
import matplotlib.pyplot as plt
11
import numpy as np
12
13
from astropy.io import fits
14
from matplotlib.ticker import MaxNLocator
15
16
17
# x vs y
18
# membership criteria.
19
20
# C vs N
21
22
# O vs Na
23
# O vs Mg
24
# O vs Al
25
# O vs S
26
27
# C+N+O vs FE_H?
28
# Mg vs Al
29
30
31
32
33
def _abundance_ratio(data, elements, suffix, wrt_fe, with_errors=False):
    """
    Build the abundance values (and optionally errors) for one panel axis.

    :param data:
        Table-like object indexed by column name. The columns read are
        ``<EL>_H<suffix>``, ``FE_H<suffix>`` and, when errors are requested,
        ``E_<EL>_H<suffix>`` and ``E_FE_H<suffix>``.

    :param elements:
        A single element name (e.g. ``"Mg"``) or several names separated by
        commas (e.g. ``"C,N,O"``), in which case the abundances are summed.

    :param suffix:
        String appended to every column name.

    :param wrt_fe:
        If True, ``FE_H<suffix>`` is subtracted once per element, so the
        result is the sum of the individual [X/Fe] ratios.

    :param with_errors: [optional]
        If True, also return the errors added in quadrature.

    :returns:
        A two-length tuple ``(values, errors)``; ``errors`` is None unless
        ``with_errors`` is True.
    """

    names = [each.upper() for each in elements.split(",")]
    fe_column = "FE_H{0}".format(suffix)

    values = 0
    for name in names:
        values = values + data["{0}_H{1}".format(name, suffix)]
    if wrt_fe:
        values = values - len(names) * data[fe_column]

    if not with_errors:
        return (values, None)

    if len(names) == 1 and not wrt_fe:
        # Nothing to combine: hand back the catalog error column untouched.
        return (values, data["E_{0}_H{1}".format(names[0], suffix)])

    variance = 0
    for name in names:
        variance = variance + data["E_{0}_H{1}".format(name, suffix)]**2
    if wrt_fe:
        # FE_H enters the value len(names) times, so its error must carry the
        # same factor (assuming independent errors between columns).
        variance = variance + (len(names) * data["E_" + fe_column])**2

    return (values, np.sqrt(variance))


def _abundance_label(elements, wrt_fe):
    """
    Return the LaTeX axis label for an element (or comma-separated elements).
    """

    reference = "Fe" if wrt_fe else "H"
    if "," in elements:
        return r"$[(\rm{{{0}}})/{{{1}}}\rm{{{2}}}]$".format(
            elements.replace(",", "+"), elements.count(",") + 1, reference)
    return r"$[\rm{{{0}}}/\rm{{{1}}}]$".format(elements.title(), reference)


def _padded_limits(limits, fraction=0.20):
    """
    Widen a (low, high) pair about its centre so that the total range grows by
    `fraction` (i.e. half of `fraction` is added on each side).
    """

    centre = np.mean(limits)
    half_width = (np.ptp(limits) * (1 + fraction))/2.
    return (centre - half_width, centre + half_width)


def _annotate_count(ax, count, color, y):
    """
    Write the number of points in the upper-left corner of `ax`, at axes
    fraction height `y`.
    """

    ax.text(0.05, y, r"${:,}$".format(count), color=color,
        verticalalignment="top", horizontalalignment="left",
        transform=ax.transAxes)


def plot_cluster_comparison(data, cluster_name, membership, x_elements,
    y_elements, used_cannon_for_target_selection=True, vel_lim=None,
    xlims=None, ylims=None):
    """
    Plot the membership selection for a cluster field, followed by one row of
    abundance panels per element pair (left panel: columns with the "_ASPCAP"
    suffix; right panel: columns without a suffix, titled "The Cannon").

    :param data:
        Table-like object indexed by column name and then by boolean mask.
        Columns read: FIELD, VHELIO_AVG, VERR, FE_H, E_FE_H, and for each
        requested element <EL>_H, E_<EL>_H and <EL>_H_ASPCAP (plus
        FE_H_ASPCAP, and E_FE_H_ASPCAP when
        `used_cannon_for_target_selection` is False).

    :param cluster_name:
        Value of the FIELD column that marks candidate stars. Also used in
        the title of the top panel.

    :param membership:
        Boolean mask selecting the adopted members. It should be the same
        length as `data`.

    :param x_elements:
        Element names for the x-axis of each row. "Fe" is shown as [Fe/H];
        anything else is shown relative to Fe.

    :param y_elements:
        Element names for the y-axis of each row, paired in order with
        `x_elements`. An entry may contain several comma-separated elements
        (e.g. "C,N,O"), which are summed.

    :param used_cannon_for_target_selection: [optional]
        If True the top panel shows FE_H, otherwise FE_H_ASPCAP.

    :param vel_lim: [optional]
        x-axis limits for the top (velocity) panel.

    :param xlims: [optional]
        x-axis limits applied to every abundance panel. If None, they are
        taken from the right-hand panel of each row and padded by 20%.

    :param ylims: [optional]
        As `xlims`, for the y-axis.

    :returns:
        The matplotlib figure.
    """

    candidate_color, membership_color = ("#BBBBBB", "#3498DB")
    tc_suffix, aspcap_suffix = ("", "_ASPCAP")

    candidates = data["FIELD"] == cluster_name

    membership_kwds = {"s": 50, "lw": 1.5}
    candidate_kwds = {"s": 30, "marker": "+", "lw": 1.5}

    # Drawn in this order, so members are plotted on top of candidates.
    samples = (
        (candidates, candidate_color, candidate_kwds),
        (membership, membership_color, membership_kwds),
    )

    # One row for the membership panel, then one row per element pair. The
    # figure height keeps the original 16 inches for six rows.
    element_pairs = list(zip(x_elements, y_elements))
    n_rows = len(element_pairs) + 1

    fig, axes = plt.subplots(n_rows, 2, figsize=(5.1, 16.0 * n_rows / 6))
    axes = np.array(axes).flatten()

    # The first row is replaced by a single full-width axes.
    axes[0].set_visible(False)
    axes[1].set_visible(False)

    top_ax = plt.subplot(n_rows, 1, 1)

    # Heliocentric velocity against the metallicity used for the selection.
    suffix = tc_suffix if used_cannon_for_target_selection else aspcap_suffix
    top_ax.scatter(
        data["VHELIO_AVG"][candidates], data["FE_H" + suffix][candidates],
        facecolor=candidate_color, rasterized=True,
        label=r"$\texttt{{FIELD = {0}}}$".format(cluster_name),
        **candidate_kwds)
    top_ax.scatter(
        data["VHELIO_AVG"][membership], data["FE_H" + suffix][membership],
        facecolor=membership_color, rasterized=True, **membership_kwds)
    # fmt="none" draws the error bars only (no markers or connecting line).
    top_ax.errorbar(
        data["VHELIO_AVG"][membership], data["FE_H" + suffix][membership],
        xerr=data["VERR"][membership], yerr=data["E_FE_H" + suffix][membership],
        rasterized=True,
        fmt="none", ecolor="k", zorder=-1)

    N, M = len(data["VHELIO_AVG"][candidates]), len(data["VHELIO_AVG"][membership])
    _annotate_count(top_ax, N, candidate_color, 0.95)
    _annotate_count(top_ax, M, membership_color, 0.95 - 0.11)

    top_ax.set_xlabel(r"$V_{\rm helio}$ $(\rm{km}$ $\rm{s}^{-1})$")
    if used_cannon_for_target_selection:
        top_ax.set_ylabel(r"$[\rm{Fe}/\rm{H}]$ $(\rm{The}$ $\rm{Cannon})$")
    else:
        top_ax.set_ylabel(r"$[\rm{Fe}/\rm{H}]$ $(\rm{ASPCAP})$")

    top_ax.set_title(r"$\rm{{{0}}}$ $\rm{{membership}}$ $\rm{{selection}}$".format(
        cluster_name))

    top_ax.xaxis.set_major_locator(MaxNLocator(4))
    top_ax.yaxis.set_major_locator(MaxNLocator(4))

    for j, (element_x, element_y) in enumerate(element_pairs):

        # Iron itself is shown as [Fe/H]; everything else relative to Fe.
        x_wrt_fe = element_x.lower() != "fe"
        y_wrt_fe = element_y.lower() != "fe"

        aspcap_ax, tc_ax = axes[2*j + 2], axes[2*j + 2 + 1]

        # X/Y for The Cannon. The values do not depend on the mask, so they
        # are computed once per row.
        x, xerr = _abundance_ratio(
            data, element_x, tc_suffix, x_wrt_fe, with_errors=True)
        y, yerr = _abundance_ratio(
            data, element_y, tc_suffix, y_wrt_fe, with_errors=True)

        for i, (mask, color, kwds) in enumerate(samples):
            tc_ax.scatter(x[mask], y[mask], facecolor=color, rasterized=True,
                **kwds)

            # Error bars are only drawn for the adopted members.
            if color == membership_color:
                tc_ax.errorbar(x[mask], y[mask],
                    xerr=xerr[mask], yerr=yerr[mask],
                    fmt="none", ecolor="k", zorder=-1, rasterized=True)

            # Quote the number of points.
            _annotate_count(tc_ax, len(x[mask]), color, 0.95 - i * 0.10)

        # Both panels in the row share limits taken from The Cannon panel,
        # expanded a little, unless the caller fixed them.
        tc_xlims = _padded_limits(tc_ax.get_xlim()) if xlims is None else xlims
        tc_ylims = _padded_limits(tc_ax.get_ylim()) if ylims is None else ylims

        # X/Y for ASPCAP.
        x, _ = _abundance_ratio(data, element_x, aspcap_suffix, x_wrt_fe)
        y, _ = _abundance_ratio(data, element_y, aspcap_suffix, y_wrt_fe)

        for i, (mask, color, kwds) in enumerate(samples):
            aspcap_ax.scatter(x[mask], y[mask], facecolor=color,
                rasterized=True, **kwds)

            # Only count the points that fall inside the displayed limits.
            visible = (tc_xlims[1] > x[mask]) & (x[mask] > tc_xlims[0]) \
                    & (tc_ylims[1] > y[mask]) & (y[mask] > tc_ylims[0])
            _annotate_count(aspcap_ax, int(np.sum(visible)), color,
                0.95 - i * 0.10)

        if j == 0:
            aspcap_ax.set_title(r"${\rm ASPCAP}$", y=1.05)
            tc_ax.set_title(r"${\rm The}$ ${\rm Cannon}$", y=1.05)

        for ax in (aspcap_ax, tc_ax):
            ax.set_xlim(tc_xlims)
            ax.set_ylim(tc_ylims)

            ax.xaxis.set_major_locator(MaxNLocator(4))
            ax.yaxis.set_major_locator(MaxNLocator(4))

            ax.set_xlabel(_abundance_label(element_x, x_wrt_fe))

        # The y-axis is labelled on the left panel only.
        aspcap_ax.set_ylabel(_abundance_label(element_y, y_wrt_fe))
        tc_ax.yaxis.set_ticklabels([])

    # Make every abundance panel square on the page.
    for ax in axes[2:]:
        ax.set(adjustable="box", aspect=np.ptp(ax.get_xlim())/np.ptp(ax.get_ylim()))

    fig.tight_layout()

    if vel_lim is not None:
        top_ax.set_xlim(vel_lim)

    fig.subplots_adjust(hspace=-0.0, bottom=0.03)
    # Shrink the top panel from below to leave room for its x-axis label.
    pos = top_ax.get_position()
    top_ax.set_position([pos.x0, pos.y0 + 0.06, pos.width, pos.height - 0.06])

    return fig
273
274
275
if __name__ == "__main__":

    # To speed up the development cycle, the catalog is only loaded when a
    # `data` object does not already exist (e.g. when re-running this script
    # in an interactive session).
    try:
        data
    except NameError:
        # Table is not among the module-level imports, so bring it in here.
        from astropy.table import Table

        catalog = Table.read("../tc-cse-regularized-apogee-catalog.fits.gz")
        ok = catalog["OK"] * (catalog["R_CHI_SQ"] < 3) * (catalog["TEFF"] > 4000) * (catalog["TEFF"] < 5500)
        data = catalog[ok]

    # Other field names noted for reference (the numbers are presumably star
    # counts per field -- TODO confirm):
    #
    #   M67 186
    #   N6791 173
    #   N2243 187
    #   N188 252
    #   N1333 380
    #   N6819 307
    #   N7789 327
    #
    #   M35N2158 189
    #   M54SGRC1 386
    #   M5PAL5 317

    x_elements = ["C", "O", "Mg", "Ca", "Fe"]
    y_elements = ["N", "Na", "Al", "S", "C,N,O"]

    # (field, (min, max) VHELIO_AVG, (min, max) FE_H, extra plotting keywords)
    # All bounds are exclusive; None means no bound on that side.
    cluster_selections = (
        ("M15", (-130, 80), (None, -1.7), {"vel_lim": (-400, 150)}),
        ("M13", (-265, -220), (None, -1.2), {}),
        ("M92", (-140, -100), (None, -1.7), {}),
        ("M53", (-80, -40), (None, -1.5), {}),
        ("M3", (-165, -120), (-1.65, -1.2), {}),
    )

    for field, (v_min, v_max), (fe_h_min, fe_h_max), plot_kwds in cluster_selections:

        members = (data["FIELD"] == field) \
                * (data["VHELIO_AVG"] > v_min) * (data["VHELIO_AVG"] < v_max) \
                * (data["FE_H"] < fe_h_max)
        if fe_h_min is not None:
            members = members * (data["FE_H"] > fe_h_min)

        figure = plot_cluster_comparison(data, field, members,
            x_elements, y_elements, **plot_kwds)
        figure.savefig("{0}_comparison.pdf".format(field), dpi=300)
354
355