| Conditions | 9 |
| Total Lines | 105 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 1 | ||
| Bugs | 1 | Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
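Commonly applied refactorings include: Extract Method.
If many parameters/temporary variables are present: Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object.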
| 1 | #!/usr/bin/env python |
||
def _one_to_one_figure(K):
    """Create a figure with `K` vertically stacked panels; return (fig, flat axes array)."""
    factor = 2.0
    lbdim = 0.30 * factor  # left/bottom margin (inches)
    tdim = 0.25 * factor   # top margin (inches)
    rdim = 0.10 * factor   # right margin (inches)
    wspace = 0.05
    hspace = 0.35
    # Total height of the K panels plus the vertical gaps between them.
    yspace = factor * K + factor * (K - 1.) * hspace
    xspace = factor

    xdim = lbdim + xspace + rdim
    ydim = lbdim + yspace + tdim

    fig, axes = plt.subplots(K, figsize=(xdim, ydim))
    fig.subplots_adjust(
        left=lbdim / xdim, bottom=lbdim / ydim,
        right=(lbdim + xspace) / xdim, top=(lbdim + yspace) / ydim,
        wspace=wspace, hspace=hspace)

    # plt.subplots returns a bare Axes (not an array) when K == 1, so
    # normalise to a flat array that can always be iterated.
    return fig, np.array([axes]).flatten()


def _one_to_one_label_name(model, latex_label_names, index):
    """Return the panel title for label `index`, preferring the LaTeX name if given."""
    label_name = model.vectorizer.label_names[index]
    if latex_label_names is None:
        return label_name

    try:
        return r"${}$".format(latex_label_names[index])
    except (IndexError, KeyError, TypeError):
        # Best effort: fall back to the plain label name rather than failing
        # the whole plot because one LaTeX name is missing.
        logger.warning(
            "Could not access latex label name for index {} ({})"\
            .format(index, label_name))
        return label_name


def _one_to_one_statistics(ax, x, y):
    """Annotate `ax` with the mean and standard deviation of the residuals y - x."""
    diff = y - x
    ax.text(0.05, 0.85, r"$\mu = {0:.2f}$".format(np.mean(diff)),
        transform=ax.transAxes)
    ax.text(0.05, 0.75, r"$\sigma = {0:.2f}$".format(np.std(diff)),
        transform=ax.transAxes)


def one_to_one(model, test_labels, cov=None, latex_label_names=None,
    show_statistics=True, **kwargs):
    """
    Plot a one-to-one comparison of the training set labels, and the test set
    labels inferred from the training set spectra.

    :param model:
        A trained CannonModel object.

    :param test_labels:
        An array of test labels, inferred from the training set spectra. Must
        have the same shape as `model.training_set_labels`.

    :param cov: [optional]
        The covariance matrix returned for all test labels, with shape
        (N, K, K). The square root of each diagonal entry is drawn as a
        y-axis error bar.

    :param latex_label_names: [optional]
        A list of label names in LaTeX representation. These are wrapped in
        `$...$` and used as panel titles. Any missing entry falls back to the
        model's own label name, and a warning is logged.

    :param show_statistics: [optional]
        Show the mean and standard deviation of residuals in each axis.

    :param scatter_kwds: [optional, via ``**kwargs``]
        A dict of keyword arguments that override the defaults passed to
        `ax.scatter`.

    :param errorbar_kwds: [optional, via ``**kwargs``]
        A dict of keyword arguments that override the defaults passed to
        `ax.errorbar` (only used when `cov` is given).

    :returns:
        The matplotlib figure, which has one panel per label.

    :raises ValueError:
        If `test_labels` and the training set labels differ in shape, or if
        `cov` does not have shape (N, K, K).
    """

    training_set_labels = model.training_set_labels
    if training_set_labels.shape != test_labels.shape:
        raise ValueError(
            "test labels must have the same shape as training set labels")

    N, K = test_labels.shape
    if cov is not None and cov.shape != (N, K, K):
        raise ValueError(
            "shape mis-match in covariance matrix ({N}, {K}, {K}) != {shape}"\
            .format(N=N, K=K, shape=cov.shape))

    fig, axes = _one_to_one_figure(K)

    scatter_kwds = dict(s=1, c="k", alpha=0.5)
    scatter_kwds.update(kwargs.get("scatter_kwds", {}))

    # The string "none" suppresses the data markers. Modern matplotlib
    # rejects fmt=None.
    errorbar_kwds = dict(fmt="none", ecolor="k", alpha=0.5, capsize=0)
    errorbar_kwds.update(kwargs.get("errorbar_kwds", {}))

    for i, ax in enumerate(axes):
        x = training_set_labels[:, i]
        y = test_labels[:, i]

        ax.scatter(x, y, **scatter_kwds)
        if cov is not None:
            ax.errorbar(x, y, yerr=cov[:, i, i]**0.5, **errorbar_kwds)

        # Use identical x and y limits so the dotted 1:1 line is the diagonal.
        limits = np.array([ax.get_xlim(), ax.get_ylim()])
        limits = (np.min(limits), np.max(limits))
        ax.plot(limits, limits, c="#666666", linestyle=":", zorder=-1)
        ax.set_xlim(limits)
        ax.set_ylim(limits)

        ax.set_title(_one_to_one_label_name(model, latex_label_names, i))

        ax.xaxis.set_major_locator(MaxNLocator(4))
        ax.yaxis.set_major_locator(MaxNLocator(4))

        if show_statistics:
            _one_to_one_statistics(ax, x, y)

        ax.set_aspect(1.0)

    return fig
||
| 320 |