| Conditions | 18 |
| Total Lines | 140 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 1 | ||
| Bugs | 1 | Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a sign that the commented part should be extracted into a new method. Use the comment as a starting point when choosing a good name for that new method.
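Commonly applied refactorings include Extract Method.
If many parameters/temporary variables are present, also consider Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object.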
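Complex code units like theta() often do a lot of different things. Note that theta() is a function rather than a class, so the class-oriented advice below applies to it only by analogy. To break such a unit down, we need to identify a cohesive component within it. For a class, a common approach is to look for fields/methods that share the same prefixes or suffixes. For a function such as theta(), look for stretches of the body that work on the same local variables, for example selecting which coefficient indices to plot, or labelling each axis.
Once you have determined the fields (or statements) that belong together, you can apply the Extract Class refactoring (for a function, Extract Method). If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.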
| 1 | #!/usr/bin/env python |
||
| 26 | def theta(model, indices=None, label_terms=None, show_label_terms=True, |
||
| 27 | normalize=True, common_axis=False, latex_label_names=None, xlim=None, |
||
| 28 | **kwargs): |
||
| 29 | """ |
||
| 30 | Plot the spectral derivates (:math:`\boldsymbol{\theta}` coefficiets) from a |
||
| 31 | trained model. |
||
| 32 | |||
| 33 | :param model: |
||
| 34 | A trained CannonModel object. |
||
| 35 | |||
| 36 | :param indices: [optional] |
||
| 37 | The indices of :math:`\boldsymbol{\theta}` to plot. By default all |
||
| 38 | coefficients will be shown. |
||
| 39 | |||
| 40 | :param label_terms: [optional]: |
||
| 41 | Specify the label terms to show coefficients for. This is similar to |
||
| 42 | specifying the `indices`, except you don't have to calculate the position |
||
| 43 | of each label name. |
||
| 44 | |||
| 45 | For example, specifying ``indices=0`` and ``label_terms=['TEFF', 'MG_H']`` |
||
| 46 | would show the first :math:`\theta` value (mean flux), as well as the |
||
| 47 | :math:`\theta` coefficients that correspond to the linear terms of |
||
| 48 | ``'TEFF'`` and ``'MG_H'``. |
||
| 49 | |||
| 50 | Note that label_terms is specific to the model vectorizer. |
||
| 51 | The vectorizer must be able to identify the label term by the inputs |
||
| 52 | provided (e.g., a polynomial vectorizer will recognize ``'TEFF'`` is the |
||
| 53 | linear coefficient of ``'TEFF'``, but ``'TEFF'`` on its own may not be |
||
| 54 | recognisable to a vectorizer that uses sine and cosine functions.) |
||
| 55 | |||
| 56 | :param show_label_terms: [optional] |
||
| 57 | Show the label terms on the right hand side of each axis. |
||
| 58 | |||
| 59 | :param normalize: [optional] |
||
| 60 | Normalize each coefficient between [-1, 1], except for the first theta |
||
| 61 | coefficient (mean flux). |
||
| 62 | |||
| 63 | :param common_axis: [optional] |
||
| 64 | Show all spectral derivatives on a single axes. |
||
| 65 | |||
| 66 | :param latex_label_names: [optional] |
||
| 67 | A list containing the label names as LaTeX representations. |
||
| 68 | |||
| 69 | :param xlim: [optional] |
||
| 70 | The x-limits to apply to all axes. |
||
| 71 | |||
| 72 | :returns: |
||
| 73 | A figure showing the spectral derivatives. |
||
| 74 | """ |
||
| 75 | |||
| 76 | if not model.is_trained: |
||
| 77 | raise ValueError("model needs to be trained first") |
||
| 78 | |||
| 79 | if latex_label_names is None: |
||
| 80 | label_names = model.vectorizer.label_names |
||
| 81 | else: |
||
| 82 | label_names = latex_label_names |
||
| 83 | |||
| 84 | if indices is None and label_terms is None: |
||
| 85 | label_indices = np.arange(model.theta.shape[1]) |
||
| 86 | else: |
||
| 87 | label_indices = [] |
||
| 88 | if indices is not None: |
||
| 89 | label_indices.extend(np.array(indices).astype(int).flatten()) |
||
| 90 | if label_terms is not None: |
||
| 91 | raise NotImplementedError |
||
| 92 | |||
| 93 | label_indices = np.array(label_indices) |
||
| 94 | |||
| 95 | if len(set(label_indices)) < label_indices.size: |
||
| 96 | logger.warn("Removing duplicate label indices") |
||
| 97 | label_indices = np.unique(label_indices) |
||
| 98 | |||
| 99 | K = len(label_indices) |
||
| 100 | |||
| 101 | fig, axes = plt.subplots(K) |
||
| 102 | axes = np.array([axes]).flatten() |
||
| 103 | |||
| 104 | if common_axis: |
||
| 105 | raise NotImplementedError |
||
| 106 | |||
| 107 | if model.dispersion is None: |
||
| 108 | x = np.arange(model.theta.shape[0]) |
||
| 109 | else: |
||
| 110 | x = model.dispersion |
||
| 111 | |||
| 112 | plot_kwds = dict(c="b", lw=1) |
||
| 113 | plot_kwds.update(kwargs.get("plot_kwds", {})) |
||
| 114 | |||
| 115 | for i, (ax, label_index) in enumerate(zip(axes, label_indices)): |
||
| 116 | |||
| 117 | y = model.theta.T[label_index].copy() |
||
| 118 | scale = np.max(np.abs(y)) if normalize and label_index != 0 else 1.0 |
||
| 119 | |||
| 120 | ax.plot(x, y/scale, **plot_kwds) |
||
| 121 | |||
| 122 | if normalize and label_index != 0: |
||
| 123 | ax.set_ylim(-1.2, 1.2) |
||
| 124 | ax.set_yticks([-1, 1]) |
||
| 125 | ylabel = r"$\theta_{{{0}}}/\max{{|\theta_{{{0}}}|}}$".format(label_index) |
||
| 126 | |||
| 127 | else: |
||
| 128 | ylabel = r"$\theta_{{{0}}}$".format(label_index) |
||
| 129 | ax.yaxis.set_major_locator(MaxNLocator(3)) |
||
| 130 | |||
| 131 | |||
| 132 | ax.set_ylabel(ylabel, rotation=0, verticalalignment="center") |
||
| 133 | ax.yaxis.labelpad = 30 |
||
| 134 | |||
| 135 | if show_label_terms: |
||
| 136 | rhs_ylabel = model.vectorizer.get_human_readable_label_term(label_index, |
||
| 137 | label_names=label_names, mul='\cdot', pow='^') |
||
| 138 | ax_rhs = ax.twinx() |
||
| 139 | if latex_label_names is not None: |
||
| 140 | rhs_ylabel = r"${}$".format(rhs_ylabel) |
||
| 141 | |||
| 142 | ax_rhs.set_ylabel(rhs_ylabel, rotation=0, verticalalignment="center") |
||
| 143 | ax_rhs.yaxis.labelpad = 30 |
||
| 144 | ax_rhs.set_yticks([]) |
||
| 145 | |||
| 146 | |||
| 147 | if ax.is_last_row(): |
||
| 148 | if model.dispersion is None: |
||
| 149 | xlabel = r"${\rm Pixel}$" |
||
| 150 | else: |
||
| 151 | xlabel = r"${\rm Wavelength},$ $({\rm AA})$" |
||
| 152 | ax.set_xlabel(xlabel) |
||
| 153 | |||
| 154 | else: |
||
| 155 | ax.set_xticklabels([]) |
||
| 156 | |||
| 157 | # Set RHS label. |
||
| 158 | ax.xaxis.set_major_locator(MaxNLocator(6)) |
||
| 159 | |||
| 160 | ax.set_xlim(xlim) |
||
| 161 | |||
| 162 | fig.tight_layout() |
||
| 163 | fig.subplots_adjust(hspace=0.10) |
||
| 164 | |||
| 165 | return fig |
||
| 166 | |||
| 320 |