| Metric | Value |
| --- | --- |
| Conditions | 18 |
| Total Lines | 140 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 1 |
| Bugs | 1 |
| Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
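Applied to theta(), one candidate is the block that works out which θ coefficients to plot. The sketch below illustrates the idea under stated assumptions and is not code from the project. The helper name resolve_label_indices is hypothetical. It takes the number of terms (model.theta.shape[1]) instead of the model, and it leaves out the label_terms branch, which theta() does not implement.

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def resolve_label_indices(n_terms, indices=None):
    """Return the theta column indices to plot (hypothetical helper).

    `n_terms` stands in for model.theta.shape[1]; the unimplemented
    `label_terms` option of theta() is deliberately left out.
    """
    if indices is None:
        # Default: plot every theta coefficient.
        return np.arange(n_terms)

    # Accept a scalar or any nested sequence of indices.
    label_indices = np.array(indices).astype(int).flatten()
    if len(set(label_indices)) < label_indices.size:
        logger.warning("Removing duplicate label indices")
        # np.unique drops the duplicates and sorts the result.
        label_indices = np.unique(label_indices)
    return label_indices


# Usage: duplicates are removed; otherwise the caller's order is kept.
print(resolve_label_indices(4))             # [0 1 2 3]
print(resolve_label_indices(4, [2, 0, 2]))  # [0 2]
```

Once extracted, the helper gets its own name and can be tested without building a figure or a trained model.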
Commonly applied refactorings include Extract Method.
If many parameters/temporary variables are present:
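theta() takes eight named parameters plus **kwargs, and most of them only control presentation. One common remedy, chosen here for illustration, is Introduce Parameter Object: bundle the display options into one immutable value. The sketch below assumes a hypothetical ThetaPlotOptions dataclass whose defaults mirror theta()'s keyword defaults. It is not part of the existing code.

```python
from dataclasses import dataclass, replace
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class ThetaPlotOptions:
    """Hypothetical parameter object for theta()'s presentation options.

    Defaults mirror the keyword defaults in theta()'s signature.
    """

    show_label_terms: bool = True
    normalize: bool = True
    common_axis: bool = False
    latex_label_names: Optional[Sequence[str]] = None
    xlim: Optional[Tuple[float, float]] = None


# A frozen instance is safe to share as a default argument value, so a
# refactored signature could read: theta(model, indices=None, options=DEFAULTS).
DEFAULTS = ThetaPlotOptions()

# Callers derive variants without mutating the shared defaults.
zoomed = replace(DEFAULTS, normalize=False, xlim=(15000.0, 17000.0))
```

Freezing the dataclass means a single default instance cannot be modified by accident, and dataclasses.replace gives callers a modified copy.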
Complex classes like theta() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
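In theta(), the names ending in label (ylabel, rhs_ylabel, xlabel) all build axis text, which makes them a plausible component to extract. The class below is a hypothetical sketch of that Extract Class step and covers only the y-axis labels. ThetaAxisLabels and its methods are illustrative names. The format strings are copied from theta().

```python
class ThetaAxisLabels:
    """Hypothetical component grouping theta()'s y-axis label logic."""

    def __init__(self, normalize=True, latex=False):
        self.normalize = normalize  # same meaning as theta(normalize=...)
        self.latex = latex  # True when latex_label_names was supplied

    def ylabel(self, label_index):
        """Left-hand label; theta_0 (mean flux) is never normalized."""
        if self.normalize and label_index != 0:
            return r"$\theta_{{{0}}}/\max{{|\theta_{{{0}}}|}}$".format(label_index)
        return r"$\theta_{{{0}}}$".format(label_index)

    def rhs_ylabel(self, label_term):
        """Right-hand label; LaTeX label terms are wrapped in math mode."""
        return r"${}$".format(label_term) if self.latex else label_term


labels = ThetaAxisLabels(normalize=True, latex=True)
print(labels.ylabel(0))  # $\theta_{0}$
print(labels.ylabel(2))  # $\theta_{2}/\max{|\theta_{2}|}$
```

The plotting loop would then ask one object for its label text instead of rebuilding the format strings inline.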
#!/usr/bin/env python

import logging

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

logger = logging.getLogger(__name__)


def theta(model, indices=None, label_terms=None, show_label_terms=True,
          normalize=True, common_axis=False, latex_label_names=None, xlim=None,
          **kwargs):
    r"""
    Plot the spectral derivatives (:math:`\boldsymbol{\theta}` coefficients)
    from a trained model.

    :param model:
        A trained CannonModel object.

    :param indices: [optional]
        The indices of :math:`\boldsymbol{\theta}` to plot. By default all
        coefficients will be shown.

    :param label_terms: [optional]
        Specify the label terms to show coefficients for. This is similar to
        specifying the `indices`, except you don't have to calculate the position
        of each label name. Not yet implemented: passing any value raises
        ``NotImplementedError``.

        For example, specifying ``indices=0`` and ``label_terms=['TEFF', 'MG_H']``
        would show the first :math:`\theta` value (mean flux), as well as the
        :math:`\theta` coefficients that correspond to the linear terms of
        ``'TEFF'`` and ``'MG_H'``.

        Note that label_terms is specific to the model vectorizer.
        The vectorizer must be able to identify the label term by the inputs
        provided (e.g., a polynomial vectorizer will recognize ``'TEFF'`` as the
        linear coefficient of ``'TEFF'``, but ``'TEFF'`` on its own may not be
        recognisable to a vectorizer that uses sine and cosine functions).

    :param show_label_terms: [optional]
        Show the label terms on the right-hand side of each axis.

    :param normalize: [optional]
        Normalize each coefficient to the range [-1, 1], except for the first
        theta coefficient (mean flux).

    :param common_axis: [optional]
        Show all spectral derivatives on a single set of axes. Not yet
        implemented: passing ``True`` raises ``NotImplementedError``.

    :param latex_label_names: [optional]
        A list containing the label names as LaTeX representations.

    :param xlim: [optional]
        The x-limits to apply to all axes.

    :returns:
        A figure showing the spectral derivatives.
    """

    if not model.is_trained:
        raise ValueError("model needs to be trained first")

    # Fail before any figure is created for options that are not supported yet.
    if common_axis:
        raise NotImplementedError("common_axis=True is not implemented")

    if latex_label_names is None:
        label_names = model.vectorizer.label_names
    else:
        label_names = latex_label_names

    if indices is None and label_terms is None:
        label_indices = np.arange(model.theta.shape[1])
    else:
        label_indices = []
        if indices is not None:
            label_indices.extend(np.array(indices).astype(int).flatten())
        if label_terms is not None:
            raise NotImplementedError("label_terms is not implemented")

    label_indices = np.array(label_indices)

    if len(set(label_indices)) < label_indices.size:
        logger.warning("Removing duplicate label indices")
        label_indices = np.unique(label_indices)

    K = len(label_indices)

    fig, axes = plt.subplots(K)
    # plt.subplots returns a bare Axes when K == 1; always work with an array.
    axes = np.array([axes]).flatten()

    # Use pixel indices when the model has no dispersion (wavelength) array.
    if model.dispersion is None:
        x = np.arange(model.theta.shape[0])
    else:
        x = model.dispersion

    plot_kwds = dict(c="b", lw=1)
    plot_kwds.update(kwargs.get("plot_kwds", {}))

    for i, (ax, label_index) in enumerate(zip(axes, label_indices)):

        # Normalize every coefficient except theta_0 (the mean flux).
        y = model.theta.T[label_index].copy()
        scale = np.max(np.abs(y)) if normalize and label_index != 0 else 1.0

        ax.plot(x, y / scale, **plot_kwds)

        if normalize and label_index != 0:
            ax.set_ylim(-1.2, 1.2)
            ax.set_yticks([-1, 1])
            ylabel = r"$\theta_{{{0}}}/\max{{|\theta_{{{0}}}|}}$".format(label_index)

        else:
            ylabel = r"$\theta_{{{0}}}$".format(label_index)
            ax.yaxis.set_major_locator(MaxNLocator(3))

        ax.set_ylabel(ylabel, rotation=0, verticalalignment="center")
        ax.yaxis.labelpad = 30

        # Set the RHS label: the human-readable label term for this coefficient.
        if show_label_terms:
            # Raw string: "\c" is not a valid escape sequence in a plain string.
            rhs_ylabel = model.vectorizer.get_human_readable_label_term(
                label_index, label_names=label_names, mul=r"\cdot", pow="^")
            ax_rhs = ax.twinx()
            if latex_label_names is not None:
                rhs_ylabel = r"${}$".format(rhs_ylabel)

            ax_rhs.set_ylabel(rhs_ylabel, rotation=0, verticalalignment="center")
            ax_rhs.yaxis.labelpad = 30
            ax_rhs.set_yticks([])

        # Axes.is_last_row() was removed in Matplotlib 3.6; the loop index
        # identifies the bottom panel instead.
        if i == K - 1:
            if model.dispersion is None:
                xlabel = r"${\rm Pixel}$"
            else:
                xlabel = r"${\rm Wavelength}$ $(\AA)$"
            ax.set_xlabel(xlabel)

        else:
            # Hide tick labels without pinning them to a fixed locator.
            ax.tick_params(labelbottom=False)

        # Limit the number of x-axis ticks.
        ax.xaxis.set_major_locator(MaxNLocator(6))

        ax.set_xlim(xlim)

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.10)

    return fig
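The normalization inside the loop reduces to one formula: for every index k other than 0, the plotted curve is θ_k / max|θ_k|, so its values lie within [-1, 1], while θ_0 (the mean flux) is plotted unscaled. A standalone sketch follows, using the hypothetical helper name normalized_coefficient. It adds one assumption that theta() does not make: an all-zero column is returned unscaled instead of being divided by zero.

```python
import numpy as np


def normalized_coefficient(theta, label_index, normalize=True):
    """Return column `label_index` of `theta` scaled the way theta() plots it.

    `theta` has shape (n_pixels, n_terms), like model.theta.
    """
    y = np.asarray(theta, dtype=float)[:, label_index].copy()
    if not normalize or label_index == 0:
        # theta_0 (mean flux) is always shown on its own scale.
        return y
    scale = np.max(np.abs(y))
    # Assumption: leave an all-zero column unscaled instead of dividing by 0.
    return y / scale if scale > 0 else y


theta = np.array([[1.0, 2.0, -0.5],
                  [1.0, -4.0, 0.25],
                  [1.0, 1.0, 0.0]])
# Column 1 is [2, -4, 1]; its largest magnitude is 4, so it becomes [0.5, -1, 0.25].
scaled = normalized_coefficient(theta, 1)
```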