Conditions | 17 |
Total Lines | 149 |
Lines | 0 |
Ratio | 0 % |
Changes | 2 | ||
Bugs | 1 | Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
If many parameters/temporary variables are present, consider the Introduce Parameter Object or Replace Temp with Query refactorings.
Complex methods like fit_spectrum() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to find such a component is to look for variables/statements that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | #!/usr/bin/env python |
||
def fit_spectrum(flux, ivar, initial_labels, vectorizer, theta, s2, fiducials,
    scales, dispersion=None, **kwargs):
    """
    Fit a single spectrum by least-squared fitting.

    :param flux:
        The normalized flux values.

    :param ivar:
        The inverse variance array for the normalized fluxes.

    :param initial_labels:
        The point(s) to initialize optimization from.

    :param vectorizer:
        The vectorizer to use when fitting the data.

    :param theta:
        The theta coefficients (spectral derivatives) of the trained model.

    :param s2:
        The pixel scatter (s^2) array for each pixel.

    :param fiducials:
        The fiducial (offset) values used to scale the labels.

    :param scales:
        The scale values used to normalize the labels.

    :param dispersion: [optional]
        The dispersion (e.g., wavelength) points for the normalized fluxes.
        NOTE(review): not used by the fitting itself in this implementation;
        kept for interface compatibility.

    :returns:
        A three-length tuple containing: the optimized labels, the covariance
        matrix (None if it could not be estimated), and metadata associated
        with the optimization.
    """

    # Down-weight each pixel's inverse variance by the model scatter s^2.
    adjusted_ivar = ivar/(1. + ivar * s2)

    # Exclude non-finite points (e.g., points with zero inverse variance
    # or non-finite flux values, but the latter shouldn't exist anyway).
    use = np.isfinite(flux * adjusted_ivar) * (adjusted_ivar > 0)
    L = len(vectorizer.label_names)

    if not np.any(use):
        # BUGFIX: logger.warn is a deprecated alias for logger.warning.
        logger.warning("No information in spectrum!")
        return (np.nan * np.ones(L), None, {
            "fail_message": "Pixels contained no information"})

    # Splice the arrays we will use most.
    flux = flux[use]
    weights = np.sqrt(adjusted_ivar[use]) # --> 1.0 / sigma
    use_theta = theta[use]

    initial_labels = np.atleast_2d(initial_labels)

    # Check the vectorizer whether it has a derivative built in.
    Dfun = kwargs.pop("Dfun", True)
    if Dfun not in (None, False):
        try:
            vectorizer.get_label_vector_derivative(initial_labels[0])

        except NotImplementedError:
            Dfun = None
            logger.warning("No label vector derivatives available in {}!".format(
                vectorizer))

        # BUGFIX: was a bare `except:`, which also swallows KeyboardInterrupt
        # and SystemExit on the way through. Narrow to Exception; the failure
        # is still logged and re-raised.
        except Exception:
            logger.exception("Exception raised when trying to calculate the "\
                             "label vector derivative at the fiducial values:")
            raise

        else:
            # Use the analytic label vector derivative. With col_deriv=True
            # (below), leastsq expects shape (n_labels, n_pixels), hence .T
            Dfun = lambda parameters: weights * np.dot(use_theta,
                vectorizer.get_label_vector_derivative(parameters)).T

    else:
        Dfun = None

    def func(parameters):
        # Model flux at the given (scaled) labels.
        return np.dot(use_theta, vectorizer(parameters))[:, 0]

    def residuals(parameters):
        # Weighted residuals (--> chi) minimized by op.leastsq.
        return weights * (func(parameters) - flux)

    kwds = {
        "func": residuals,
        "Dfun": Dfun,
        "col_deriv": True,

        # These get passed through to leastsq:
        "ftol": 7./3 - 4./3 - 1, # Machine precision.
        "xtol": 7./3 - 4./3 - 1, # Machine precision.
        "gtol": 0.0,
        "maxfev": 100000, # MAGIC
        "epsfcn": None,
        "factor": 1.0,
    }

    # Only update the keywords with things that op.curve_fit/op.leastsq expects.
    for key in set(kwargs).intersection(kwds):
        kwds[key] = kwargs[key]

    # Optimize from every initial guess and keep all successful results.
    results = []
    for x0 in initial_labels:

        try:
            # Optimization is performed in the scaled label space.
            op_labels, cov, meta, mesg, ier = op.leastsq(
                x0=(x0 - fiducials)/scales, full_output=True, **kwds)

        except RuntimeError:
            logger.exception("Exception in fitting from {}".format(x0))
            continue

        meta.update(
            dict(x0=x0, chi_sq=np.sum(meta["fvec"]**2), ier=ier, mesg=mesg))
        results.append((op_labels, cov, meta))

    if len(results) == 0:
        logger.warning("No results found!")
        return (np.nan * np.ones(L), None, dict(fail_message="No results found"))

    # Keep the result with the lowest chi-squared value.
    best_result_index = np.nanargmin([m["chi_sq"] for (o, c, m) in results])
    op_labels, cov, meta = results[best_result_index]

    # Evaluate the model flux while op_labels are still in scaled units
    # (func expects scaled labels), *then* de-scale the optimized labels.
    meta["model_flux"] = func(op_labels)
    op_labels = op_labels * scales + fiducials

    if np.allclose(op_labels, meta["x0"]):
        logger.warning(
            "Discarding optimized result because it is exactly the same as the "
            "initial value!")

        # We are in dire straits. We should not trust the result.
        op_labels *= np.nan
        meta["fail_message"] = "Optimized result same as initial value."

    # BUGFIX: op.leastsq returns cov=None when the Jacobian is singular, and
    # np.isfinite(None) raises a TypeError. Guard the None case explicitly.
    if cov is None or not np.any(np.isfinite(cov)):
        logger.warning("Non-finite covariance matrix returned!")

    # Save additional information.
    meta.update({
        "method": "leastsq",
        "label_names": vectorizer.label_names,
        "best_result_index": best_result_index,
        "derivatives_used": Dfun is not None,
        "snr": np.nanmedian(flux * weights),
        "r_chi_sq": meta["chi_sq"]/(use.sum() - L - 1),
    })
    for key in ("ftol", "xtol", "gtol", "maxfev", "factor", "epsfcn"):
        meta[key] = kwds[key]

    return (op_labels, cov, meta)
||
171 | |||
433 |