fit_pixel_fixed_scatter()   F
last analyzed

Complexity

Conditions 17

Size

Total Lines 179

Duplication

Lines 0
Ratio 0 %

Importance

Changes 11
Bugs 1 Features 0
Metric Value
cc 17
c 11
b 1
f 0
dl 0
loc 179
rs 2

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

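Commonly applied refactorings include Extract Method; when many parameters or temporary variables get in the way, Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object can help as well.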

Complexity

Complex functions like fit_pixel_fixed_scatter() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods (or, inside a function, variables and statements) that share the same prefixes or suffixes, or that work toward the same sub-task.

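Once you have determined the fields that belong together, you can apply the Extract Class refactoring; for a single function such as this one, the closest equivalent is Extract Method (moving the cohesive part into its own helper function). If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.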

1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
Fitting functions for use in The Cannon.
6
"""
7
8
from __future__ import (division, print_function, absolute_import,
9
                        unicode_literals)
10
11
# Public API of this module (the names exported by `from ... import *`).
__all__ = [
    "fit_spectrum",
    "fit_pixel_fixed_scatter",
    "fit_theta_by_linalg",
    "chi_sq",
    "L1Norm_variation",
]
13
14
import logging
15
import numpy as np
16
import scipy.optimize as op
17
from time import time
18
19
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
20
21
22
def fit_spectrum(flux, ivar, initial_labels, vectorizer, theta, s2, fiducials,
    scales, dispersion=None, use_derivatives=True, op_kwds=None):
    """
    Fit a single spectrum by least-squared fitting.

    The labels are optimized in a scaled space, `(labels - fiducials)/scales`,
    with `scipy.optimize.leastsq`, and are de-scaled before being returned.

    :param flux:
        The normalized flux values.

    :param ivar:
        The inverse variance array for the normalized fluxes.

    :param initial_labels:
        The point(s) to initialize optimization from. Each point is optimized
        independently and the result with the lowest chi-squared is kept.

    :param vectorizer:
        The vectorizer to use when fitting the data. It must provide
        `label_names`, be callable on (scaled) labels, and may provide
        `get_label_vector_derivative`.

    :param theta:
        The theta coefficients (spectral derivatives) of the trained model.

    :param s2:
        The pixel scatter (s^2) array for each pixel.

    :param fiducials:
        The fiducial label values that are subtracted before optimization.

    :param scales:
        The label scales that the fiducial-subtracted labels are divided by
        before optimization.

    :param dispersion: [optional]
        The dispersion (e.g., wavelength) points for the normalized fluxes.
        Not used by this function.

    :param use_derivatives: [optional]
        If truthy, use the analytic label vector derivatives provided by the
        vectorizer (falling back to numerical derivatives when the vectorizer
        raises `NotImplementedError`). `None` or `False` lets `leastsq`
        estimate the derivatives numerically.
        NOTE(review): a callable is not invoked; it is treated like `True`.

    :param op_kwds: [optional]
        Optimization keywords that get passed to `scipy.optimize.leastsq`.
        Keys that are not part of the default keyword set are ignored.

    :returns:
        A three-length tuple containing: the optimized labels, the covariance
        matrix, and metadata associated with the optimization. If no pixel
        carries information, or every optimization raised a `RuntimeError`,
        the labels are NaN, the covariance is `None` and the metadata only
        holds a `fail_message`. If the optimized labels equal the initial
        point, they are set to NaN and a `fail_message` is added to the
        metadata.
    """

    adjusted_ivar = ivar/(1. + ivar * s2)

    # Exclude non-finite points (e.g., points with zero inverse variance
    # or non-finite flux values, but the latter shouldn't exist anyway).
    use = np.isfinite(flux * adjusted_ivar) * (adjusted_ivar > 0)
    L = len(vectorizer.label_names)

    if not np.any(use):
        logger.warning("No information in spectrum!")
        return (np.nan * np.ones(L), None, {
                "fail_message": "Pixels contained no information"})

    # Splice the arrays we will use most.
    flux = flux[use]
    weights = np.sqrt(adjusted_ivar[use]) # --> 1.0 / sigma
    use_theta = theta[use]

    initial_labels = np.atleast_2d(initial_labels)

    # Check the vectorizer whether it has a derivative built in.
    if use_derivatives not in (None, False):
        try:
            vectorizer.get_label_vector_derivative(initial_labels[0])

        except NotImplementedError:
            Dfun = None
            logger.warning("No label vector derivatives available in {}!".format(
                vectorizer))

        except Exception:
            logger.exception("Exception raised when trying to calculate the "
                             "label vector derivative at the fiducial values:")
            raise

        else:
            # Column-wise Jacobian of the weighted residuals, matching
            # `col_deriv=True` below.
            def Dfun(parameters):
                return weights * np.dot(use_theta,
                    vectorizer.get_label_vector_derivative(parameters)).T

    else:
        Dfun = None

    def func(parameters):
        """Model flux at the used pixels for the given (scaled) labels."""
        return np.dot(use_theta, vectorizer(parameters))[:, 0]

    def residuals(parameters):
        """Inverse-sigma weighted residuals minimized by `leastsq`."""
        return weights * (func(parameters) - flux)

    kwds = {
        "func": residuals,
        "Dfun": Dfun,
        "col_deriv": True,

        # These get passed through to leastsq:
        "ftol": 7./3 - 4./3 - 1, # Machine precision.
        "xtol": 7./3 - 4./3 - 1, # Machine precision.
        "gtol": 0.0,
        "maxfev": 100000, # MAGIC
        "epsfcn": None,
        "factor": 1.0,
    }

    # Only update the keywords with things that op.curve_fit/op.leastsq expects.
    if op_kwds is not None:
        for key in set(op_kwds).intersection(kwds):
            kwds[key] = op_kwds[key]

    results = []
    for x0 in initial_labels:

        try:
            op_labels, cov, meta, mesg, ier = op.leastsq(
                x0=(x0 - fiducials)/scales, full_output=True, **kwds)

        except RuntimeError:
            logger.exception("Exception in fitting from {}".format(x0))
            continue

        meta.update(
            dict(x0=x0, chi_sq=np.sum(meta["fvec"]**2), ier=ier, mesg=mesg))
        results.append((op_labels, cov, meta))

    if len(results) == 0:
        logger.warning("No results found!")
        return (np.nan * np.ones(L), None, dict(fail_message="No results found"))

    best_result_index = np.nanargmin([m["chi_sq"] for (o, c, m) in results])
    op_labels, cov, meta = results[best_result_index]

    # De-scale the optimized labels.
    meta["model_flux"] = func(op_labels)
    op_labels = op_labels * scales + fiducials

    if np.allclose(op_labels, meta["x0"]):
        logger.warning(
            "Discarding optimized result because it is exactly the same as the "
            "initial value!")

        # We are in dire straits. We should not trust the result.
        op_labels *= np.nan
        meta["fail_message"] = "Optimized result same as initial value."

    # NOTE(review): `cov` is returned as computed by leastsq, i.e. in the
    # scaled label space; it is not rescaled by `scales` here. Confirm that
    # callers expect this.
    if cov is None:
        cov = np.ones((len(op_labels), len(op_labels)))

    if not np.any(np.isfinite(cov)):
        logger.warning("Non-finite covariance matrix returned!")

    # Save additional information.
    meta.update({
        "method": "leastsq",
        "label_names": vectorizer.label_names,
        "best_result_index": best_result_index,
        "derivatives_used": Dfun is not None,
        "snr": np.nanmedian(flux * weights),
        "r_chi_sq": meta["chi_sq"]/(use.sum() - L - 1),
    })
    for key in ("ftol", "xtol", "gtol", "maxfev", "factor", "epsfcn"):
        meta[key] = kwds[key]

    return (op_labels, cov, meta)
182
183
184
185
def fit_theta_by_linalg(flux, ivar, s2, design_matrix):
    """
    Fit theta coefficients to a set of normalized fluxes for a single pixel.

    Solves the weighted linear least-squares problem using inverse variances
    that are adjusted by the pixel scatter: `ivar / (1 + ivar * s2)`.

    :param flux:
        The normalized fluxes for a single pixel (across many stars).

    :param ivar:
        The inverse variance of the normalized flux values for a single pixel
        across many stars.

    :param s2:
        The noise residual (squared scatter term) to adopt in the pixel.

    :param design_matrix:
        The model design matrix.

    :returns:
        The label vector coefficients for the pixel, and the inverse of the
        weighted normal matrix `(A^T C^-1 A)^-1`. If that matrix is singular,
        the fiducial coefficients `[1, 0, ..., 0]` are returned together
        with a diagonal matrix of infinite variances.
    """

    adjusted_ivar = ivar/(1. + ivar * s2)
    # C^-1 A: scale each row (star) of the design matrix by its adjusted ivar.
    CiA = design_matrix * np.reshape(adjusted_ivar, (-1, 1))
    try:
        ATCiAinv = np.linalg.inv(np.dot(design_matrix.T, CiA))
    except np.linalg.LinAlgError:
        N = design_matrix.shape[1]
        # Infinite variance on the diagonal only: `np.inf * np.eye(N)` would
        # put NaN (inf * 0) in every off-diagonal entry.
        return (np.hstack([1, np.zeros(N - 1)]), np.diag(np.full(N, np.inf)))

    ATY = np.dot(design_matrix.T, flux * adjusted_ivar)
    theta = np.dot(ATCiAinv, ATY)

    return (theta, ATCiAinv)
219
220
221
222
# TODO: This logic should probably go somewhere else.
223
224
225
def chi_sq(theta, design_matrix, flux, ivar, axis=None, gradient=True):
    """
    Compute the inverse-variance weighted chi-squared of the model.

    The model is `theta . design_matrix^T`; the chi-squared is the sum of
    `ivar * (model - flux)**2`.

    :param theta:
        The theta coefficients.

    :param design_matrix:
        The model design matrix.

    :param flux:
        The normalized flux values.

    :param ivar:
        The inverse variances of the normalized flux values.

    :param axis: [optional]
        The axis to sum the chi-squared values across (`None` sums all).

    :param gradient: [optional]
        If true, also return the derivative of the chi-squared with respect
        to `theta`.

    :returns:
        The chi-squared value, or a two-length tuple of the chi-squared value
        and its gradient when `gradient` is true.
    """
    resid = np.dot(theta, design_matrix.T) - flux
    weighted_resid = ivar * resid
    value = np.sum(weighted_resid * resid, axis=axis)

    if gradient:
        jacobian = 2.0 * np.dot(design_matrix.T, weighted_resid)
        return (value, jacobian)

    return value
260
261
262
def L1Norm_variation(theta):
    """
    Compute the L1 norm of all theta coefficients but the first, with its
    derivative.

    :param theta:
        An array of finite values.

    :returns:
        A two-length tuple: the sum of absolute values of `theta[1:]`, and
        the derivative of that norm with respect to `theta` (zero for the
        first entry, the sign of each remaining entry otherwise).
    """
    tail = theta[1:]
    norm = np.sum(np.abs(tail))
    derivative = np.hstack([0.0, np.sign(tail)])
    return (norm, derivative)
275
276
277
def _pixel_objective_function_fixed_scatter(theta, design_matrix, flux, ivar,
    regularization, gradient=True):
    """
    The objective function for a single regularized pixel with fixed scatter.

    The objective is the chi-squared of the model plus `regularization` times
    the L1 norm of all theta coefficients except the first.

    :param theta:
        The spectral coefficients.

    :param design_matrix:
        The design matrix for the model.

    :param flux:
        The normalized flux values for a single pixel across many stars.

    :param ivar:
        The inverse variance of the normalized flux values for a single pixel
        across many stars. No scatter term is added here, so any scatter
        should already be included by the caller.

    :param regularization:
        The regularization term to scale the L1 norm of theta with.

    :param gradient: [optional]
        Also return the analytic derivative of the objective function.

    :returns:
        The objective value, or a two-length tuple of the objective value and
        its gradient with respect to `theta` when `gradient` is true.
    """

    if gradient:
        csq, d_csq = chi_sq(theta, design_matrix, flux, ivar, gradient=True)
        L1, d_L1 = L1Norm_variation(theta)

        f = csq + regularization * L1
        g = d_csq + regularization * d_L1

        return (f, g)

    csq = chi_sq(theta, design_matrix, flux, ivar, gradient=False)
    # The derivative of the L1 norm is not needed without a gradient.
    L1, _ = L1Norm_variation(theta)

    return csq + regularization * L1
317
318
319
def _scatter_objective_function(scatter, residuals_squared, ivar):
    """
    Objective for fitting the pixel scatter.

    :param scatter:
        The scatter (not squared) to evaluate.

    :param residuals_squared:
        Squared residuals between the data and the model.

    :param ivar:
        The inverse variances of the flux values.

    :returns:
        `(median(residuals_squared * adjusted_ivar) - 1)**2`, where
        `adjusted_ivar = ivar / (1 + ivar * scatter**2)`.
    """
    variance_inflation = 1.0 + ivar * scatter**2
    weighted_residuals = residuals_squared * (ivar / variance_inflation)
    return (np.median(weighted_residuals) - 1.0)**2
323
324
325
def _remove_forbidden_op_kwds(op_method, op_kwds):
    """
    Remove optimization keywords that the chosen optimizer cannot accept.

    :param op_method:
        The optimization algorithm to use: `l_bfgs_b` or `powell`.

    :param op_kwds:
        Optimization keywords. This dictionary is modified in place.

    :raises KeyError:
        If `op_method` is not one of the known optimization methods.

    :returns:
        `None`. The dictionary of `op_kwds` will be updated.
    """
    all_allowed_keys = dict(
        # Keywords of scipy.optimize.fmin_l_bfgs_b, except `func`, `fprime`
        # and `approx_grad`, which the caller passes explicitly.
        l_bfgs_b=("x0", "args", "bounds", "m", "factr", "pgtol", "epsilon",
            "iprint", "maxfun", "maxiter", "disp", "callback", "maxls"),
        # Keywords of scipy.optimize.fmin_powell, except `func` and
        # `full_output` (passed explicitly by the caller) and `retall` (it
        # changes the number of returned values the caller unpacks).
        # `initial_simplex` belongs to fmin (Nelder-Mead) and would make
        # fmin_powell raise a TypeError.
        powell=("x0", "args", "xtol", "ftol", "maxiter", "maxfun",
            "disp", "callback", "direc"))

    # Sorted so that the warning message is deterministic.
    forbidden_keys = sorted(set(op_kwds).difference(all_allowed_keys[op_method]))
    if forbidden_keys:
        logger.warning("Ignoring forbidden optimization keywords for {}: {}".format(
            op_method, ", ".join(forbidden_keys)))
        for key in forbidden_keys:
            del op_kwds[key]

    return None
352
            
353
354
355
def _fit_pixel_l_bfgs_b(base_op_kwds, user_op_kwds, censored_theta, m):
    """Run L-BFGS-B on the pixel objective; return (op_params, metadata)."""
    op_kwds = dict(base_op_kwds)
    op_kwds.update(m=m, maxls=20, factr=10.0, pgtol=1e-6)
    op_kwds.update(user_op_kwds)

    # If bounds are given and we are censoring some theta terms, only provide
    # the bounds of the uncensored terms.
    if "bounds" in op_kwds and np.any(censored_theta):
        op_kwds["bounds"] = [b for b, is_censored in
            zip(op_kwds["bounds"], censored_theta) if not is_censored]

    # Just-in-time to remove forbidden keywords.
    _remove_forbidden_op_kwds("l_bfgs_b", op_kwds)

    op_params, fopt, metadata = op.fmin_l_bfgs_b(
        _pixel_objective_function_fixed_scatter,
        fprime=None, approx_grad=None, **op_kwds)
    metadata.update(dict(fopt=fopt))
    return (op_params, metadata)


def _fit_pixel_powell(base_op_kwds, user_op_kwds):
    """Run Powell's method on the pixel objective; return (op_params, metadata)."""
    op_kwds = dict(base_op_kwds)
    op_kwds.update(xtol=1e-6, ftol=1e-6)
    op_kwds.update(user_op_kwds)

    # Append `gradient=False` to the objective's arguments so that it returns
    # only the objective value, because fmin_powell doesn't want a gradient.
    op_kwds["args"] = tuple(op_kwds["args"]) + (False,)

    # Just-in-time to remove forbidden keywords.
    _remove_forbidden_op_kwds("powell", op_kwds)

    op_params, fopt, direc, n_iter, n_funcs, warnflag = op.fmin_powell(
        _pixel_objective_function_fixed_scatter,
        full_output=True, **op_kwds)
    metadata = dict(fopt=fopt, direc=direc, n_iter=n_iter,
        n_funcs=n_funcs, warnflag=warnflag)
    return (op_params, metadata)


def fit_pixel_fixed_scatter(flux, ivar, initial_thetas, design_matrix,
    regularization, censoring_mask, **kwargs):
    """
    Fit theta coefficients and noise residual for a single pixel, using
    an initially fixed scatter value.

    :param flux:
        The normalized flux values.

    :param ivar:
        The inverse variance array for the normalized fluxes.

    :param initial_thetas:
        A list of initial theta values to start from, and their source. For
        example: `[(theta_0, "guess"), (theta_1, "old_theta")]`. The entry
        with the lowest objective value is used as the starting point.

    :param design_matrix:
        The model design matrix. Columns without any finite entry are treated
        as censored theta terms. The caller's array is not modified.

    :param regularization:
        The regularization strength to apply during optimization (Lambda).

    :param censoring_mask:
        A per-label censoring mask for each pixel. Not used by this function;
        censoring is inferred from non-finite design matrix columns.

    :keyword op_method:
        The optimization method to use. Valid options are: `l_bfgs_b`, `powell`.

    :keyword op_kwds:
        A dictionary of arguments that will be provided to the optimizer.

    :keyword op_strict:
        If true (default) and `l_bfgs_b` reports a warning, optimize again
        with `powell`, starting from the `l_bfgs_b` result.

    :keyword __theta_0:
        HIGHLY EXPERIMENTAL. If given, the first theta coefficient is fixed to
        this value.

    :raises ValueError:
        If `op_method` is not a known optimization method.

    :returns:
        The optimized theta coefficients, the noise residual `s2`, and
        metadata related to the optimization process. If the pixel carries
        too little information, the fiducial theta `[1, 0, ..., 0]` and an
        infinite `s2` are returned.
    """

    # Too little information: the mean inverse variance is below one.
    if np.sum(ivar) < 1.0 * ivar.size: # MAGIC
        metadata = dict(message="No pixel information.", op_time=0.0)
        fiducial = np.hstack([1.0, np.zeros(design_matrix.shape[1] - 1)])
        return (fiducial, np.inf, metadata) # MAGIC

    # Determine if any theta coefficients will be censored.
    censored_theta = ~np.any(np.isfinite(design_matrix), axis=0)
    if np.any(censored_theta):
        # Make the design matrix safe to use, without touching the caller's
        # array.
        design_matrix = np.copy(design_matrix)
        design_matrix[:, censored_theta] = 0

    # Start from the initial theta with the lowest objective value.
    feval = [
        _pixel_objective_function_fixed_scatter(
            theta_guess, design_matrix, flux, ivar, regularization, False)
        for theta_guess, _ in initial_thetas]
    initial_theta, initial_theta_source = initial_thetas[np.nanargmin(feval)]

    # The flux and design matrix that the optimizer actually works with.
    fit_flux, fit_design_matrix = flux, design_matrix

    theta_0 = kwargs.get("__theta_0", None)
    if theta_0 is not None:
        logger.warning("FIXING theta_0. HIGHLY EXPERIMENTAL.")

        # Subtract theta_0 from the flux and zero the first design matrix
        # column; theta[0] is set to theta_0 after optimization.
        fit_flux = flux - theta_0
        fit_design_matrix = np.copy(design_matrix)
        fit_design_matrix[:, 0] = 0.0

    x0 = initial_theta
    if np.any(censored_theta):
        # Don't bother solving for the censored parameters. This is applied
        # on top of any theta_0 adjustment above rather than replacing it.
        x0 = np.array(initial_theta)[~censored_theta]
        fit_design_matrix = fit_design_matrix[:, ~censored_theta]

    base_op_kwds = dict(x0=x0,
        args=(fit_design_matrix, fit_flux, ivar, regularization),
        disp=False, maxfun=np.inf, maxiter=np.inf)

    # Allow either l_bfgs_b or powell.
    default_op_method = "l_bfgs_b"
    op_method = kwargs.get("op_method", default_op_method) or default_op_method
    op_method = op_method.lower()
    if op_method not in ("l_bfgs_b", "powell"):
        raise ValueError("unknown optimization method '{}' -- "
                         "powell or l_bfgs_b are available".format(op_method))

    op_strict = kwargs.get("op_strict", True)
    user_op_kwds = kwargs.get("op_kwds", {}) or {}

    t_init = time()
    if op_method == "l_bfgs_b":
        op_params, metadata = _fit_pixel_l_bfgs_b(
            base_op_kwds, user_op_kwds, censored_theta, design_matrix.shape[1])

        warnflag = metadata.get("warnflag", -1)
        if warnflag > 0:
            reason = ("too many function evaluations or too many iterations"
                      if warnflag == 1 else metadata["task"])
            logger.warning("Optimization warning (l_bfgs_b): {}".format(reason))

            if op_strict:
                # Do optimization again with Powell's method, starting from
                # where l_bfgs_b stopped.
                op_method = "powell"
                base_op_kwds.update(x0=op_params)

    if op_method == "powell":
        # NOTE: op_time only covers the Powell run when it is used.
        t_init = time()
        op_params, metadata = _fit_pixel_powell(base_op_kwds, user_op_kwds)

    # Additional metadata common to both optimizers.
    metadata.update(dict(op_method=op_method, op_time=time() - t_init,
        initial_theta=initial_theta, initial_theta_source=initial_theta_source))

    # De-censor the optimized parameters.
    if np.any(censored_theta):
        theta = np.zeros(censored_theta.size)
        theta[~censored_theta] = op_params

    else:
        theta = op_params

    if theta_0 is not None:
        theta[0] = theta_0

    # Fit the scatter (with scipy.optimize.fmin's default tolerances).
    residuals_squared = (flux - np.dot(theta, design_matrix.T))**2
    scatter = op.fmin(_scatter_objective_function, 0.0,
        args=(residuals_squared, ivar), disp=False)

    return (theta, scatter**2, metadata)
534