Completed
Push — master ( 4fc32f...027495 )
by Andy
01:22
created

plot_test_scalar_metrics()   F

Complexity

Conditions 10

Size

Total Lines 82

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
dl 0
loc 82
rs 3.7113
c 2
b 0
f 0
cc 10

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method, as described above.

Complexity

Complex units of code like plot_test_scalar_metrics() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods (or, inside a function, local variables) that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
Plot some quality metrics for gridsearch models.
3
"""
4
5
import matplotlib
6
matplotlib.rcParams["text.usetex"] = True
7
import matplotlib.pyplot as plt
8
from matplotlib.ticker import MaxNLocator
9
import numpy as np
10
import logging
11
12
from glob import glob
13
from six.moves import cPickle as pickle
14
15
16
17
def _parse_grid_filename(filename):
    """
    Parse the scale factor and regularization strength from a filename.

    The filename is split on "-". The second piece is read as the scale
    factor, and the third piece (up to the first ".model") is read as
    log10(Lambda).

    :param filename:
        The path of a grid-search model file.

    :returns:
        A two-length tuple of floats: `(scale_factor, log10_Lambda)`.

    :raises IndexError:
        If the filename has fewer than three "-"-separated pieces.

    :raises ValueError:
        If either piece cannot be converted to a float.
    """
    pieces = filename.split("-")
    return (float(pieces[1]), float(pieces[2].split(".model")[0]))


def _collect_metrics(metric_function, filenames, debug):
    """
    Evaluate the metric on the contents of every parseable file.

    :param metric_function:
        A callable that is given the unpickled contents of each file as
        positional arguments, and returns something convertible to a float.

    :param filenames:
        An iterable of paths to pickled files.

    :param debug:
        If true, re-raise any exception encountered when calculating a
        metric. Otherwise the exception is logged and the file is skipped.

    :returns:
        A three-length tuple of equal-length arrays:
        `(Lambdas, scale_factors, metrics)`.
    """
    Lambdas, scale_factors, metrics = ([], [], [])

    for filename in filenames:

        # Filenames that do not follow the expected pattern are skipped
        # rather than treated as fatal.
        try:
            scale_factor, log10_Lambda = _parse_grid_filename(filename)

        except (AttributeError, IndexError, TypeError, ValueError):
            print("Skipping filename {}".format(filename))
            continue

        # NOTE: unpickling can execute arbitrary code, so only point this at
        # files that come from a trusted source.
        with open(filename, "rb") as fp:
            contents = pickle.load(fp, encoding="latin-1")

        # Calculate the metric.
        try:
            metric = float(metric_function(*contents)) # Must be a float

        except Exception:
            logging.exception(
                "Failed to calculate metric for {}".format(filename))
            if debug: raise

        else:
            metrics.append(metric)
            Lambdas.append(10**log10_Lambda)
            scale_factors.append(scale_factor)

    return (np.array(Lambdas), np.array(scale_factors), np.array(metrics))


def plot_test_scalar_metrics(metric_function, filenames, metric_label=None,
    debug=True, **kwargs):
    """
    Plot a scalar metric for each model in a grid search, as a function of
    regularization strength and scale factor.

    :param metric_function:
        A callable that is given the unpickled contents of each file as
        positional arguments, and returns something convertible to a float.

    :param filenames:
        An iterable of paths to pickled files. The scale factor and
        log10(Lambda) are parsed from each filename; filenames that cannot be
        parsed are skipped.

    :param metric_label: [optional]
        The label for the colorbar. Defaults to "Metric".

    :param debug: [optional]
        If true, re-raise any exception encountered when calculating a
        metric. Otherwise the exception is logged and the file is skipped.

    :param kwargs:
        Keyword arguments passed to `ax.scatter`. The `cmap`, `vmin` and
        `vmax` keywords default to the plasma colormap, 0.04 and 0.11
        respectively, and can be overridden here.

    :returns:
        The matplotlib figure.

    :raises ValueError:
        If no metric could be calculated for any of the given files.
    """

    Lambdas, scale_factors, metrics = _collect_metrics(
        metric_function, filenames, debug)

    # Fail early (and before creating a figure) if there is nothing to show.
    if metrics.size == 0:
        raise ValueError(
            "no metrics could be calculated from the given filenames")

    # The fixed colorbar ticks only make sense for the default color limits.
    use_default_limits = ("vmin" not in kwargs and "vmax" not in kwargs)
    scatter_kwargs = dict(cmap=plt.cm.plasma, vmin=0.04, vmax=0.11)
    scatter_kwargs.update(kwargs)

    # Make the figure
    fig, ax = plt.subplots()

    # scale factors are non-linear, so lets show them as indices then we will
    # adjust the y-ticks and labels as necessary
    unique_scale_factors, scale_factor_indices = np.unique(
        scale_factors, return_inverse=True)

    # Scale the points so that the best metric has s=250.
    unity = 250 * np.min(metrics)
    scat = ax.scatter(Lambdas, scale_factor_indices, c=metrics,
        s=unity/metrics, **scatter_kwargs)
    ax.set_yticks(np.arange(len(unique_scale_factors)))
    ax.set_yticklabels([r"${0:.1f}$".format(_) for _ in unique_scale_factors])
    ax.set_ylim(-1, len(unique_scale_factors))

    ax.semilogx()

    # Faint separators between adjacent scale factor rows.
    for index in range(len(unique_scale_factors) - 1):
        ax.axhline(index + 0.5, c="#EEEEEE", zorder=-1)

    cbar = plt.colorbar(scat)
    cbar.set_label(metric_label or r"Metric")
    if use_default_limits:
        cbar.set_ticks([0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11])

    ax.set_xlabel(r"$\rm{Regularization},$ $\Lambda$")
    ax.set_ylabel(r"$\rm{Scale}$ $\rm{factor},$ $f$")
    ax.yaxis.set_tick_params(width=0)

    fig.tight_layout()
    return fig
99
100
if __name__ == "__main__":

    # Script entry point: compute one metric across all abundance columns for
    # every snapshot file and save the resulting figure as a PDF.
    snapshot_filenames = glob("../*.individual_visits")

    # Colorbar labels and short names, one per label. These are only used by
    # the disabled per-label loop below.
    metric_labels = [
        r"$\rm{median}\{|T_{\rm eff,combined} - T_{\rm eff,individual}|\}$",
        r"$\rm{median}\{|\log{g}_{\rm combined} - \log{g}_{\rm individual}|\}$",
        r"$\rm{median}\{|[\rm{Al}/\rm{H}]_{\rm combined} - [\rm{Al}/\rm{H}]_{\rm individual}|\}$", # AL
        r"$\rm{median}\{|[\rm{Ca}/\rm{H}]_{\rm combined} - [\rm{Ca}/\rm{H}]_{\rm individual}|\}$", #'CA',
        r"$\rm{median}\{|[\rm{C}/\rm{H}]_{\rm combined} - [\rm{C}/\rm{H}]_{\rm individual}|\}$", #'C',
        r"$\rm{median}\{|[\rm{Fe}/\rm{H}]_{\rm combined} - [\rm{Fe}/\rm{H}]_{\rm individual}|\}$", #'FE',
        r"$\rm{median}\{|[\rm{K}/\rm{H}]_{\rm combined} - [\rm{K}/\rm{H}]_{\rm individual}|\}$", #'K',
        r"$\rm{median}\{|[\rm{Mg}/\rm{H}]_{\rm combined} - [\rm{Mg}/\rm{H}]_{\rm individual}|\}$", #'MG',
        r"$\rm{median}\{|[\rm{Mn}/\rm{H}]_{\rm combined} - [\rm{Mn}/\rm{H}]_{\rm individual}|\}$", #'MN',
        r"$\rm{median}\{|[\rm{Na}/\rm{H}]_{\rm combined} - [\rm{Na}/\rm{H}]_{\rm individual}|\}$", #'NA',
        r"$\rm{median}\{|[\rm{Ni}/\rm{H}]_{\rm combined} - [\rm{Ni}/\rm{H}]_{\rm individual}|\}$", #'NI',
        r"$\rm{median}\{|[\rm{N}/\rm{H}]_{\rm combined} - [\rm{N}/\rm{H}]_{\rm individual}|\}$", #'N',
        r"$\rm{median}\{|[\rm{O}/\rm{H}]_{\rm combined} - [\rm{O}/\rm{H}]_{\rm individual}|\}$", #'O',
        r"$\rm{median}\{|[\rm{Si}/\rm{H}]_{\rm combined} - [\rm{Si}/\rm{H}]_{\rm individual}|\}$", #'SI',
        r"$\rm{median}\{|[\rm{S}/\rm{H}]_{\rm combined} - [\rm{S}/\rm{H}]_{\rm individual}|\}$", #'S',
        r"$\rm{median}\{|[\rm{Ti}/\rm{H}]_{\rm combined} - [\rm{Ti}/\rm{H}]_{\rm individual}|\}$", #'TI',
        r"$\rm{median}\{|[\rm{V}/\rm{H}]_{\rm combined} - [\rm{V}/\rm{H}]_{\rm individual}|\}$", #'V'
    ]
    label_names = [
        "TEFF",
        "LOGG",
        "AL",
        "CA",
        "C",
        "FE",
        "K",
        "MG",
        "MN",
        "NA",
        "NI",
        "N",
        "O",
        "SI",
        "S",
        "TI",
        "V"
    ]

    # Disabled: one figure per label.
    #
    # figures = []
    # for i, (metric_label, label_name) in enumerate(zip(metric_labels, label_names)):
    #
    #     def metric(snrs, high_snr_expected, high_snr_inferred,
    #         differences_expected, differences_inferred, single_visit_inferred):
    #         return np.sum(np.abs(differences_expected[:, i]))
    #
    #     fig = plot_test_scalar_metrics(metric, snapshot_filenames, metric_label)
    #
    #     fig.savefig("gs-mad-{0}.png".format(label_name), dpi=300)

    def metric(snrs, high_snr_expected, high_snr_inferred,
        differences_expected, differences_inferred, single_visit_inferred):
        # Median absolute value over every column of `differences_expected`
        # from index 2 onwards.
        return np.median(np.abs(differences_expected[:, 2:]))

    metric_label = r"${\rm median}\left(\left|[\rm{X}/\rm{H}]_{\rm combined} - [\rm{X}/\rm{H}]_{\rm individual}\right|\right)$"
    fig = plot_test_scalar_metrics(metric, snapshot_filenames, metric_label)
    fig.savefig("gs-mad-all-elements.pdf", dpi=300)