| Conditions | 18 | 
| Total Lines | 140 | 
| Lines | 0 | 
| Ratio | 0 % | 
| Changes | 1 | ||
| Bugs | 1 | Features | 0 | 
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
- Extract Method
If many parameters/temporary variables are present:
- Replace Method with Method Object
- Introduce Parameter Object
Complex code units like theta() (a function here, but the same applies to classes) often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields, methods, or variables that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | #!/usr/bin/env python | ||
def _resolve_theta_indices(model, indices, label_terms):
    """
    Return the array of theta column indices that should be plotted.

    :param model:
        A trained model. ``model.theta.shape[1]`` is used as the number of
        available coefficients.

    :param indices:
        Explicit coefficient indices (a scalar or any array-like), or None.

    :param label_terms:
        Label terms to resolve into indices. Not supported yet.

    :returns:
        An integer array of indices. When neither selector is given, every
        coefficient is returned in order. Duplicated indices are removed
        (and the result sorted) with a warning.

    :raises NotImplementedError:
        If ``label_terms`` is given.

    :raises ValueError:
        If the selection is empty.
    """
    if indices is None and label_terms is None:
        return np.arange(model.theta.shape[1])

    if label_terms is not None:
        raise NotImplementedError(
            "selecting coefficients by label_terms is not supported yet")

    label_indices = np.array(indices).astype(int).flatten()
    if label_indices.size == 0:
        # plt.subplots(0) would otherwise fail with a less helpful message.
        raise ValueError("no theta coefficients were selected to plot")

    if len(set(label_indices)) < label_indices.size:
        logger.warning("Removing duplicate label indices")
        label_indices = np.unique(label_indices)

    return label_indices


def _plot_theta_coefficient(ax, x, y, label_index, normalize, plot_kwds):
    """
    Plot one theta coefficient on ``ax`` and set its left-hand y-axis label.

    Coefficients other than the first (index 0, mean flux) are divided by
    their maximum absolute value when ``normalize`` is true.
    """
    normalized = normalize and label_index != 0
    scale = np.max(np.abs(y)) if normalized else 1.0
    if not scale:
        # An all-zero coefficient would otherwise be divided by zero and
        # plotted as all-NaN.
        scale = 1.0

    ax.plot(x, y / scale, **plot_kwds)

    if normalized:
        ax.set_ylim(-1.2, 1.2)
        ax.set_yticks([-1, 1])
        ylabel = r"$\theta_{{{0}}}/\max{{|\theta_{{{0}}}|}}$".format(label_index)
    else:
        ylabel = r"$\theta_{{{0}}}$".format(label_index)
        ax.yaxis.set_major_locator(MaxNLocator(3))

    ax.set_ylabel(ylabel, rotation=0, verticalalignment="center")
    ax.yaxis.labelpad = 30


def _add_theta_label_term(ax, model, label_index, label_names, as_latex):
    """
    Show the human-readable label term for ``label_index`` on a right-hand
    twin y-axis of ``ax``.
    """
    # Raw string: '\c' is not a valid Python escape sequence.
    rhs_ylabel = model.vectorizer.get_human_readable_label_term(
        label_index, label_names=label_names, mul=r"\cdot", pow="^")
    if as_latex:
        rhs_ylabel = r"${}$".format(rhs_ylabel)

    ax_rhs = ax.twinx()
    ax_rhs.set_ylabel(rhs_ylabel, rotation=0, verticalalignment="center")
    ax_rhs.yaxis.labelpad = 30
    ax_rhs.set_yticks([])


def _format_theta_xaxis(ax, is_bottom, pixel_axis, xlim):
    """
    Label the x-axis of the bottom panel, hide tick labels on the others,
    and apply the shared x-limits.
    """
    if is_bottom:
        if pixel_axis:
            xlabel = r"${\rm Pixel}$"
        else:
            xlabel = r"${\rm Wavelength},$ $({\rm AA})$"
        ax.set_xlabel(xlabel)
    else:
        ax.set_xticklabels([])

    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.set_xlim(xlim)


def theta(model, indices=None, label_terms=None, show_label_terms=True,
          normalize=True, common_axis=False, latex_label_names=None,
          xlim=None, **kwargs):
    r"""
    Plot the spectral derivatives (:math:`\boldsymbol{\theta}` coefficients)
    from a trained model.

    :param model:
        A trained CannonModel object.

    :param indices: [optional]
        The indices of :math:`\boldsymbol{\theta}` to plot. By default all
        coefficients will be shown. Duplicate indices are removed (and the
        remaining indices sorted) with a warning.

    :param label_terms: [optional]
        Specify the label terms to show coefficients for. This is similar to
        specifying the `indices`, except you don't have to calculate the
        position of each label name.

        For example, specifying ``indices=0`` and
        ``label_terms=['TEFF', 'MG_H']`` would show the first :math:`\theta`
        value (mean flux), as well as the :math:`\theta` coefficients that
        correspond to the linear terms of ``'TEFF'`` and ``'MG_H'``.

        Note that label_terms is specific to the model vectorizer.
        The vectorizer must be able to identify the label term by the inputs
        provided (e.g., a polynomial vectorizer will recognize ``'TEFF'`` is
        the linear coefficient of ``'TEFF'``, but ``'TEFF'`` on its own may
        not be recognisable to a vectorizer that uses sine and cosine
        functions.)

        This option is not implemented yet and raises NotImplementedError.

    :param show_label_terms: [optional]
        Show the label terms on the right hand side of each axis.

    :param normalize: [optional]
        Normalize each coefficient between [-1, 1], except for the first theta
        coefficient (mean flux).

    :param common_axis: [optional]
        Show all spectral derivatives on a single axis. Not implemented yet;
        ``True`` raises NotImplementedError.

    :param latex_label_names: [optional]
        A list containing the label names as LaTeX representations.

    :param xlim: [optional]
        The x-limits to apply to all axes.

    :param plot_kwds: [optional, via ``**kwargs``]
        Keyword arguments passed to ``Axes.plot``, overriding the defaults
        ``c="b", lw=1``.

    :returns:
        A figure showing the spectral derivatives.

    :raises ValueError:
        If the model is not trained, or if no coefficients are selected.

    :raises NotImplementedError:
        If ``label_terms`` is given or ``common_axis`` is true.
    """
    if not model.is_trained:
        raise ValueError("model needs to be trained first")

    if common_axis:
        # Reject unsupported options before a figure is created, so an empty
        # figure is not left behind.
        raise NotImplementedError("common_axis=True is not supported yet")

    if latex_label_names is None:
        label_names = model.vectorizer.label_names
    else:
        label_names = latex_label_names

    label_indices = _resolve_theta_indices(model, indices, label_terms)
    K = len(label_indices)

    # One panel per coefficient in a single column. np.array([...]).flatten()
    # also copes with plt.subplots returning a bare Axes when K == 1.
    fig, axes = plt.subplots(K)
    axes = np.array([axes]).flatten()

    pixel_axis = model.dispersion is None
    if pixel_axis:
        x = np.arange(model.theta.shape[0])
    else:
        x = model.dispersion

    plot_kwds = dict(c="b", lw=1)
    plot_kwds.update(kwargs.get("plot_kwds", {}))

    for i, (ax, label_index) in enumerate(zip(axes, label_indices)):
        _plot_theta_coefficient(ax, x, model.theta.T[label_index],
                                label_index, normalize, plot_kwds)

        if show_label_terms:
            _add_theta_label_term(ax, model, label_index, label_names,
                                  as_latex=latex_label_names is not None)

        # Single-column layout: the last axis is the bottom row. This avoids
        # Axes.is_last_row(), which was removed in matplotlib 3.6.
        _format_theta_xaxis(ax, is_bottom=(i == K - 1),
                            pixel_axis=pixel_axis, xlim=xlim)

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.10)

    return fig
| 166 | |||
| 320 |