CannonModel.theta()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 1
c 2
b 0
f 0
dl 0
loc 4
rs 10
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
The Cannon.
6
"""
7
8
from __future__ import (division, print_function, absolute_import,
9
                        unicode_literals)
10
11
__all__ = ["CannonModel"]
12
13
import logging
14
import multiprocessing as mp
15
import numpy as np
16
import os
17
import pickle
18
from datetime import datetime
19
from functools import wraps
20
from sys import version_info
21
from scipy.spatial import Delaunay
22
23
from .vectorizer.base import BaseVectorizer
24
from . import (censoring, fitting, utils, vectorizer as vectorizer_module, __version__)
25
26
27
# Module-level logger used by CannonModel for progress messages and warnings.
logger = logging.getLogger(__name__)
28
29
30
def requires_training(method):
    """
    Guard a CannonModel method so that it only runs on a trained model.

    :param method:
        A method belonging to CannonModel. It is invoked unchanged when the
        model's `is_trained` property is true.

    :raises TypeError:
        If the wrapped method is called before the model has been trained.
    """

    @wraps(method)
    def guarded(model, *args, **kwargs):
        if model.is_trained:
            return method(model, *args, **kwargs)
        raise TypeError("the model requires training first")

    return guarded
43
44
45
class CannonModel(object):
    """
    A model for The Cannon which includes L1 regularization and pixel censoring.

    :param training_set_labels:
        A set of objects with labels known to high fidelity. This can be
        given as a numpy structured array, an astropy table, or a plain
        two-dimensional `(num_stars, num_labels)` array whose columns are in
        the order of `vectorizer.label_names`.

    :param training_set_flux:
        An array of normalised fluxes for stars in the labelled set, given
        as shape `(num_stars, num_pixels)`. The `num_stars` should match the
        number of rows in `training_set_labels`. Pass `None` (together with
        `training_set_ivar=None`) for a model without training set spectra,
        e.g. when reading a saved model.

    :param training_set_ivar:
        An array of inverse variances on the normalized fluxes for stars in
        the training set. The shape of the `training_set_ivar` array should
        match that of `training_set_flux`.

    :param vectorizer:
        A vectorizer to take input labels and produce a design matrix. This
        should be a sub-class of `vectorizer.BaseVectorizer`.

    :param dispersion: [optional]
        The dispersion values corresponding to the given pixels. If provided,
        this should have a size of `num_pixels`.

    :param regularization: [optional]
        The strength of the L1 regularization. This should either be `None`,
        a float-type value for single regularization strength for all pixels,
        or a float-like array of length `num_pixels`.

    :param censors: [optional]
        A dictionary containing label names as keys and boolean censoring
        masks as values.

    :param kwargs: [optional]
        Non-private keywords (e.g. `rho_warning`) are forwarded to the
        training data check. The private keys `__scale_labels_function` and
        `__fiducial_labels_function` replace the functions used to compute
        the label scales and fiducial values from the training set labels.
    """

    _data_attributes = \
        ("training_set_labels", "training_set_flux", "training_set_ivar")

    # Descriptive attributes are needed to train *and* test the model.
    _descriptive_attributes = \
        ("vectorizer", "censors", "regularization", "dispersion")

    # Trained attributes are set only at training time.
    _trained_attributes = ("theta", "s2")

    def __init__(self, training_set_labels, training_set_flux, training_set_ivar,
        vectorizer, dispersion=None, regularization=None, censors=None, **kwargs):

        # Save the vectorizer.
        if not isinstance(vectorizer, BaseVectorizer):
            raise TypeError(
                "vectorizer must be a sub-class of vectorizer.BaseVectorizer")

        self._vectorizer = vectorizer

        if training_set_flux is None and training_set_ivar is None:
            # Must be reading in a model that does not have the training set
            # spectra saved.
            self._training_set_flux = None
            self._training_set_ivar = None
            self._training_set_labels = training_set_labels

        else:
            self._training_set_flux = np.atleast_2d(training_set_flux)
            self._training_set_ivar = np.atleast_2d(training_set_ivar)

            # A plain 2D array (one row per star, one column per label) is
            # used directly. Structured arrays are 1D, so `ndim` must be
            # checked before `shape[1]` is read.
            if isinstance(training_set_labels, np.ndarray) \
            and training_set_labels.ndim == 2 \
            and training_set_labels.shape[0] == self._training_set_flux.shape[0] \
            and training_set_labels.shape[1] == len(vectorizer.label_names):
                self._training_set_labels = training_set_labels
            else:
                # A table or structured array: extract one column per label.
                self._training_set_labels = np.array(
                    [training_set_labels[ln] for ln in vectorizer.label_names]).T

            # Check that the flux and ivar are valid. The private "__" keys are
            # consumed below; `_verify_training_data` does not accept them.
            self._verify_training_data(**dict(
                (key, value) for key, value in kwargs.items()
                if not key.startswith("__")))

        # Set the dispersion first: without training set spectra, the
        # regularization and censoring setters use it for the pixel count.
        self.dispersion = dispersion
        self.regularization = regularization
        self.censors = censors

        # Set useful private attributes.
        scale_labels_function = kwargs.get("__scale_labels_function",
            lambda l: np.ptp(np.percentile(l, [2.5, 97.5], axis=0), axis=0))
        fiducial_labels_function = kwargs.get("__fiducial_labels_function",
            lambda l: np.percentile(l, 50, axis=0))

        self._scales = scale_labels_function(self.training_set_labels)
        self._fiducials = fiducial_labels_function(self.training_set_labels)
        self._design_matrix = vectorizer(
            (self.training_set_labels - self._fiducials)/self._scales).T

        self.reset()

        return None


    # Representations.


    def __str__(self):
        num_pixels = self._num_pixels()
        return "<{module}.{name} of {K} labels {trained}with a training set "\
               "of {N} stars each with {M} pixels>".format(
                    module=self.__module__,
                    name=type(self).__name__,
                    trained="trained " if self.is_trained else "",
                    K=self.training_set_labels.shape[1],
                    N=self.training_set_labels.shape[0],
                    M="?" if num_pixels is None else num_pixels)


    def __repr__(self):
        return "<{0}.{1} object at {2}>".format(self.__module__,
            type(self).__name__, hex(id(self)))


    # Model attributes that cannot (well, should not) be changed.


    @property
    def training_set_labels(self):
        """ Return the labels in the training set. """
        return self._training_set_labels


    @property
    def training_set_flux(self):
        """ Return the training set fluxes (or `None` if not stored). """
        return self._training_set_flux


    @property
    def training_set_ivar(self):
        """ Return the inverse variances of the training set fluxes. """
        return self._training_set_ivar


    @property
    def vectorizer(self):
        """ Return the vectorizer for this model. """
        return self._vectorizer


    @property
    def design_matrix(self):
        """ Return the design matrix for this model. """
        return self._design_matrix


    def _num_pixels(self):
        """
        Return the number of pixels per spectrum, or `None` if it is unknown.

        The count is taken from the training set fluxes when they are stored,
        otherwise from the dispersion array, otherwise from trained `theta`.
        Attributes are read defensively because this is also called while
        the model is still being initialised.
        """
        flux = getattr(self, "_training_set_flux", None)
        if flux is not None:
            return flux.shape[1]

        dispersion = getattr(self, "_dispersion", None)
        if dispersion is not None:
            return dispersion.size

        theta = getattr(self, "_theta", None)
        if theta is not None:
            return np.shape(theta)[0]

        return None


    def _censored_design_matrix(self, pixel_index, fill_value=np.nan):
        """
        Return the design matrix for the given pixel, with censored labels
        replaced by `fill_value` before vectorization.

        :param pixel_index:
            The zero-indexed pixel.

        :param fill_value: [optional]
            The value that replaces the (scaled) label values of any label
            whose censoring mask is false at this pixel.

        :returns:
            The (transposed) design matrix for this pixel. If no censoring
            masks apply to any model label, the uncensored `design_matrix`
            is returned as-is.
        """

        if not self.censors \
        or not set(self.censors).intersection(self.vectorizer.label_names):
            return self.design_matrix

        data = (self.training_set_labels.copy() - self._fiducials)/self._scales
        for i, label_name in enumerate(self.vectorizer.label_names):
            try:
                use = self.censors[label_name][pixel_index]

            except KeyError:
                # No censoring mask for this label.
                continue

            if not use:
                data[:, i] = fill_value

        return self.vectorizer(data).T


    @property
    def theta(self):
        """ Return the theta coefficients (spectral model derivatives). """
        return self._theta


    @property
    def s2(self):
        """ Return the intrinsic variance (s^2) for all pixels. """
        return self._s2


    # Model attributes that can be changed after initiation.


    @property
    def censors(self):
        """ Return the wavelength censor masks for the labels. """
        return self._censors


    @censors.setter
    def censors(self, censors):
        """
        Set label censoring masks for each pixel.

        :param censors:
            `None`, a dictionary with label names as keys and boolean arrays
            as values, or an existing `censoring.Censors` object (used as-is).

        :raises TypeError:
            If `censors` is not a dictionary or `censoring.Censors` object.

        :raises ValueError:
            If a dictionary is given but the number of pixels is unknown.
        """

        censors = {} if censors is None else censors
        if isinstance(censors, censoring.Censors):
            # Could be a censoring dictionary from a different model,
            # with different label names and pixels.

            # But more likely: we are loading a model from disk.
            self._censors = censors

        elif isinstance(censors, dict):
            num_pixels = self._num_pixels()
            if num_pixels is None:
                raise ValueError(
                    "cannot set censors: the number of pixels is unknown "
                    "(no training set spectra, dispersion, or theta)")

            self._censors = censoring.Censors(
                self.vectorizer.label_names, num_pixels, censors)

        else:
            raise TypeError(
                "censors must be a dictionary or a censoring.Censors object")


    @property
    def dispersion(self):
        """ Return the dispersion points for all pixels. """
        return self._dispersion


    @dispersion.setter
    def dispersion(self, dispersion):
        """
        Set the dispersion values for all the pixels.

        :param dispersion:
            An array of the dispersion values, or `None`.

        :raises ValueError:
            If the size does not match the training set pixels, or the values
            are not finite numbers.
        """
        if dispersion is None:
            self._dispersion = None
            return None

        dispersion = np.array(dispersion).flatten()
        if self.training_set_flux is not None \
        and dispersion.size != self.training_set_flux.shape[1]:
            raise ValueError("dispersion provided does not match the number "
                             "of pixels per star ({0} != {1})".format(
                                dispersion.size, self.training_set_flux.shape[1]))

        if dispersion.dtype.kind not in "iuf":
            raise ValueError("dispersion values are not float-like")

        if not np.all(np.isfinite(dispersion)):
            raise ValueError("dispersion values must be finite")

        self._dispersion = dispersion
        return None


    @property
    def regularization(self):
        """ Return the strength of the L1 regularization for this model. """
        return self._regularization


    @regularization.setter
    def regularization(self, regularization):
        """
        Specify the strength of the regularization for the model, either as a
        single value for all pixels, or a different strength for each pixel.

        :param regularization:
            The L1-regularization strength for the model: `None`, a single
            non-negative finite value, or one such value per pixel.

        :raises ValueError:
            If any value is negative or non-finite, or an array does not
            have one value per pixel (when the pixel count is known).
        """

        if regularization is None:
            self._regularization = None
            return None

        regularization = np.array(regularization).flatten()
        if regularization.size == 1:
            regularization = regularization[0]
            if 0 > regularization or not np.isfinite(regularization):
                raise ValueError("regularization must be positive and finite")

            self._regularization = regularization
            return None

        # The pixel count may be unknown for models without training spectra
        # (e.g. when reading from disk); only check the size when it is known.
        num_pixels = self._num_pixels()
        if num_pixels is not None and regularization.size != num_pixels:
            raise ValueError("regularization array must be of size `num_pixels`")

        if np.any(0 > regularization) \
        or not np.all(np.isfinite(regularization)):
            raise ValueError("regularization must be positive and finite")

        self._regularization = regularization
        return None


    # Convenient functions and properties.


    @property
    def is_trained(self):
        """ Return true or false for whether the model is trained. """
        return all(getattr(self, attr, None) is not None \
            for attr in self._trained_attributes)


    def reset(self):
        """ Clear any attributes that have been trained. """
        for attribute in self._trained_attributes:
            setattr(self, "_{}".format(attribute), None)
        return None


    def _pixel_access(self, array, index, default=None):
        """
        Safely access a (potentially per-pixel) attribute of the model.

        :param array:
            Either `None`, a float value, or an array the size of the dispersion
            array.

        :param index:
            The zero-indexed pixel to attempt to access.

        :param default: [optional]
            The default value to return if `array` is None.

        :returns:
            `default` if `array` is None, `array[index]` if it can be indexed,
            otherwise `array` itself (a scalar shared by all pixels).
        """

        if array is None:
            return default
        try:
            return array[index]
        except (IndexError, TypeError):
            return array


    def _verify_training_data(self, rho_warning=0.90):
        """
        Verify the training data for the appropriate shape and content.

        :param rho_warning: [optional]
            Maximum correlation value between labels before a warning is given.

        :raises ValueError:
            If shapes are inconsistent, or labels/fluxes/ivars are invalid.
        """

        if self.training_set_flux.shape != self.training_set_ivar.shape:
            raise ValueError("the training set flux and inverse variance arrays"
                             " for the labelled set must have the same shape")

        if len(self.training_set_labels) != self.training_set_flux.shape[0]:
            raise ValueError(
                "the first axes of the training set flux array should "
                "have the same shape as the number of rows in the labelled set "
                "(N_stars, N_pixels)")

        if not np.all(np.isfinite(self.training_set_labels)):
            raise ValueError("training set labels are not all finite")

        if not np.all(np.isfinite(self.training_set_flux)):
            raise ValueError("training set fluxes are not all finite")

        if not np.all(self.training_set_ivar >= 0) \
        or not np.all(np.isfinite(self.training_set_ivar)):
            raise ValueError("training set ivars are not all positive finite")

        # Look for very high correlation coefficients between labels, which
        # could make the training time very difficult. `np.corrcoef` returns
        # a 0-d array for a single label, hence `atleast_2d`.
        rho = np.atleast_2d(np.corrcoef(self.training_set_labels.T))

        # Ignore self-correlations and undefined (NaN) correlations, e.g. from
        # a constant label: NaNs would sort first after reversal and stop the
        # search below before any real correlation is checked.
        K = rho.shape[0]
        rho[np.diag_indices(K)] = 0.0
        rho[~np.isfinite(rho)] = 0.0
        indices = np.argsort(rho.flatten())[::-1]

        for index in indices:
            x, y = (index % K, index // K)
            rho_xy = rho[x, y]
            if rho_xy < rho_warning:
                # Sorted in descending order: nothing further qualifies.
                break

            if x > y: # One warning per correlated label pair.
                logger.warning("Labels '{X}' and '{Y}' are highly correlated ("\
                    "rho = {rho_xy:.2}). This may cause very slow training "\
                    "times. Are both labels needed?".format(
                        X=self.vectorizer.label_names[x],
                        Y=self.vectorizer.label_names[y],
                        rho_xy=rho_xy))
        return None


    def in_convex_hull(self, labels):
        """
        Return whether the provided labels are inside a convex hull constructed
        from the labelled set.

        :param labels:
            A `NxK` array of `N` sets of `K` labels, where `K` is the number of
            labels that make up the vectorizer.

        :returns:
            A boolean array as to whether the points are in the convex hull of
            the labelled set.
        """

        labels = np.atleast_2d(labels)
        if labels.shape[1] != self.training_set_labels.shape[1]:
            raise ValueError("expected {} labels; got {}".format(
                self.training_set_labels.shape[1], labels.shape[1]))

        hull = Delaunay(self.training_set_labels)
        return hull.find_simplex(labels) >= 0


    def write(self, path, include_training_set_spectra=False, overwrite=False,
        protocol=-1):
        """
        Serialise the trained model and save it to disk. This will save all
        relevant training attributes, and optionally, the training data.

        :param path:
            The path to save the model to.

        :param include_training_set_spectra: [optional]
            Save the normalised flux and inverse variance used to train the
            model. The training set labels are always saved.

        :param overwrite: [optional]
            Overwrite the existing file path, if it already exists.

        :param protocol: [optional]
            The Python pickling protocol to employ. Use 2 for compatibility with
            previous Python releases, -1 for performance.

        :raises IOError:
            If `path` exists and `overwrite` is false.
        """

        if os.path.exists(path) and not overwrite:
            raise IOError("path already exists: {0}".format(path))

        attributes = list(self._descriptive_attributes) \
                   + list(self._trained_attributes) \
                   + list(self._data_attributes)

        if "metadata" in attributes:
            logger.warning("'metadata' is a protected attribute. Ignoring.")
            attributes.remove("metadata")

        # Store up all the trained attributes and a hash of the training set.
        state = {}
        for attribute in attributes:
            value = getattr(self, attribute)

            # Vectorizers and censoring masks are stored via their own state
            # (`read` rebuilds them from it). Other values (arrays, floats,
            # None) are stored as-is: a generic `__getstate__` (such as the
            # `object.__getstate__` default in Python 3.11+) may return None
            # for them instead of the value.
            if isinstance(value, (BaseVectorizer, censoring.Censors)):
                value = value.__getstate__()

            state[attribute] = value

        # Create a metadata dictionary.
        state["metadata"] = dict(
            version=__version__,
            model_class=type(self).__name__,
            modified=str(datetime.now()),
            data_attributes=self._data_attributes,
            descriptive_attributes=self._descriptive_attributes,
            trained_attributes=self._trained_attributes,
            training_set_hash=utils.short_hash(
                getattr(self, attr) for attr in self._data_attributes),
        )

        if not include_training_set_spectra:
            state.pop("training_set_flux")
            state.pop("training_set_ivar")

            if not self.is_trained:
                logger.warning("The training set spectra won't be saved, and this "\
                               "model is not already trained. The saved model will "\
                               "not be able to be trained when loaded!")

        with open(path, "wb") as fp:
            pickle.dump(state, fp, protocol)
        return None


    @staticmethod
    def _version_tuple(version):
        """
        Convert a version string such as "0.2.0" or "0.10.1.dev3" into a
        tuple of integers, e.g. (0, 2, 0) or (0, 10, 1), so that versions
        compare numerically rather than lexicographically. Parsing stops at
        the first dot-separated component that does not start with a digit.
        """
        components = []
        for part in str(version).split("."):
            digits = part[:len(part) - len(part.lstrip("0123456789"))]
            if not digits:
                break
            components.append(int(digits))
        return tuple(components)


    @classmethod
    def read(cls, path, **kwargs):
        """
        Read a saved model from disk.

        NOTE: this uses `pickle.load`, which can execute arbitrary code; only
        read model files from trusted sources.

        :param path:
            The path where to load the model from.

        :returns:
            A model rebuilt from the saved state, with any trained attributes
            restored.

        :raises NotImplementedError:
            If the file was saved by a version older than 0.2.0.
        """

        # Try each encoding (used by Python 3 when unpickling Python 2 str
        # objects) in turn, stopping at the first that loads successfully.
        encodings = ("utf-8", "latin-1")
        for encoding in encodings:
            load_kwds = {"encoding": encoding} if version_info[0] >= 3 else {}
            try:
                with open(path, "rb") as fp:
                    state = pickle.load(fp, **load_kwds)

            except UnicodeDecodeError:
                if encoding == encodings[-1]:
                    raise

            else:
                break

        # Parse the state.
        metadata = state.get("metadata", {})
        version_saved = metadata.get("version", "0.1.0")
        if cls._version_tuple(version_saved) >= (0, 2, 0): # Refactor'd.

            init_attributes = list(metadata["data_attributes"]) \
                            + list(metadata["descriptive_attributes"])

            init_kwds = dict((a, state.get(a, None)) for a in init_attributes)

            # Initiate the vectorizer from its saved (class name, kwargs).
            vectorizer_class, vectorizer_kwds = init_kwds["vectorizer"]
            klass = getattr(vectorizer_module, vectorizer_class)
            init_kwds["vectorizer"] = klass(**vectorizer_kwds)

            # Initiate the censors, if any were saved.
            if init_kwds.get("censors") is not None:
                init_kwds["censors"] = censoring.Censors(**init_kwds["censors"])

            model = cls(**init_kwds)

            # Set training attributes.
            for attr in metadata["trained_attributes"]:
                setattr(model, "_{}".format(attr), state.get(attr, None))

            return model

        else:
            raise NotImplementedError(
                "Cannot auto-convert old model files yet; "
                "contact Andy Casey <[email protected]> if you need this")


    def train(self, threads=None, op_method=None, op_strict=True, op_kwds=None,
        **kwargs):
        """
        Train the model.

        :param threads: [optional]
            The number of parallel worker processes (`multiprocessing.Pool`)
            to use. `None` or 1 runs serially.

        :param op_method: [optional]
            The optimization algorithm to use: l_bfgs_b (default) and powell
            are available.

        :param op_strict: [optional]
            Default to Powell's optimization method if BFGS fails.

        :param op_kwds:
            Keyword arguments to provide directly to the optimization function.

        :returns:
            A three-length tuple containing the spectral coefficients `theta`,
            the squared scatter term at each pixel `s2`, and metadata related to
            the training of each pixel.

        :raises TypeError:
            If the training set spectra are not available.
        """

        kwds = dict(op_method=op_method, op_strict=op_strict, op_kwds=op_kwds)
        kwds.update(kwargs)

        if self.training_set_flux is None or self.training_set_ivar is None:
            raise TypeError(
                "cannot train: training set spectra not saved with the model")

        S, P = self.training_set_flux.shape
        T = self.design_matrix.shape[1]

        logger.info("Training {0}-label {1} with {2} stars and {3} pixels/star"\
            .format(len(self.vectorizer.label_names), type(self).__name__, S, P))

        # Parallelise out.
        if threads in (1, None):
            mapper, pool = (map, None)

        else:
            pool = mp.Pool(threads)
            mapper = pool.map

        try:
            func = utils.wrapper(fitting.fit_pixel_fixed_scatter, None, kwds, P)

            meta = []
            theta = np.full((P, T), np.nan)
            s2 = np.full(P, np.nan)

            for pixel, (flux, ivar) \
            in enumerate(zip(self.training_set_flux.T, self.training_set_ivar.T)):

                # NOTE(review): pixels are dispatched one at a time, so a pool
                # gives little speed-up here; batching all pixels would trade
                # memory (one censored design matrix per pixel) for parallelism.
                args = (
                    flux, ivar,
                    self._initial_theta(pixel),
                    self._censored_design_matrix(pixel),
                    self._pixel_access(self.regularization, pixel, 0.0),
                    None
                )
                (pixel_theta, pixel_s2, pixel_meta), = mapper(func, [args])

                meta.append(pixel_meta)
                theta[pixel], s2[pixel] = (pixel_theta, pixel_s2)

            self._theta, self._s2 = (theta, s2)

        finally:
            # Release worker processes even if a pixel fit raises.
            if pool is not None:
                pool.close()
                pool.join()

        return (theta, s2, meta)


    @requires_training
    def __call__(self, labels):
        """
        Return spectral fluxes, given the labels.

        :param labels:
            An array of stellar labels: one set of labels, or one row per star.

        :returns:
            The model fluxes: a 1D array for a single set of labels, otherwise
            one row of fluxes per set of labels.
        """

        # Scale and offset the labels.
        scaled_labels = (np.atleast_2d(labels) - self._fiducials)/self._scales
        flux = np.dot(self.theta, self.vectorizer(scaled_labels)).T
        return flux[0] if flux.shape[0] == 1 else flux


    @requires_training
    def test(self, flux, ivar, initial_labels=None, threads=None,
        use_derivatives=True, op_kwds=None):
        """
        Run the test step on spectra.

        :param flux:
            The (pseudo-continuum-normalized) spectral flux.

        :param ivar:
            The inverse variance values for the spectral fluxes.

        :param initial_labels: [optional]
            The initial labels to try for each spectrum. This can be a single
            set of initial values, or one set of initial values for each star.

        :param threads: [optional]
            The number of parallel worker processes to use.

        :param use_derivatives: [optional]
            Boolean `True` indicating to use analytic derivatives provided by
            the vectorizer, `None` to calculate on the fly, or a callable
            function to calculate your own derivatives.

        :param op_kwds: [optional]
            Optimization keywords that get passed to `scipy.optimize.leastsq`.

        :returns:
            A three-length tuple: an array of labels, an array of covariance
            matrices, and a tuple of per-spectrum metadata.

        :raises ValueError:
            If `flux` or `ivar` is None, or their shapes differ.
        """

        if flux is None or ivar is None:
            raise ValueError("flux and ivar must not be None")

        if op_kwds is None:
            op_kwds = dict()

        flux, ivar = (np.atleast_2d(flux), np.atleast_2d(ivar))
        S, P = flux.shape

        if ivar.shape != flux.shape:
            raise ValueError("flux and ivar arrays must be the same shape")

        if initial_labels is None:
            initial_labels = self._fiducials

        initial_labels = np.atleast_2d(initial_labels)
        if initial_labels.shape[0] != S and len(initial_labels.shape) == 2:
            # Give every star the same set(s) of initial labels.
            initial_labels = np.tile(initial_labels.flatten(), S)\
                             .reshape(S, -1, len(self._fiducials))

        args = (self.vectorizer, self.theta, self.s2, self._fiducials,
            self._scales)
        kwargs = dict(use_derivatives=use_derivatives, op_kwds=op_kwds)

        func = utils.wrapper(fitting.fit_spectrum, args, kwargs, S,
            message="Running test step on {} spectra".format(S))

        # Create the pool only after validation so that it cannot leak.
        if threads in (1, None):
            mapper, pool = (map, None)

        else:
            pool = mp.Pool(threads)
            mapper = pool.map

        try:
            labels, cov, meta = zip(*mapper(func, zip(*(flux, ivar, initial_labels))))

        finally:
            if pool is not None:
                pool.close()
                pool.join()

        return (np.array(labels), np.array(cov), meta)


    def _initial_theta(self, pixel_index, **kwargs):
        """
        Return a list of guesses of the spectral coefficients for the given
        pixel index. Initial values are sourced in the following preference
        order:

            (1) a previously trained `theta` value for this pixel,
            (2) an estimate of `theta` using linear algebra,
            (3) a neighbouring pixel's `theta` value,
            (4) the fiducial value of [1, 0, ..., 0].

        :param pixel_index:
            The zero-indexed integer of the pixel.

        :param kwargs: [optional]
            `s2` is passed to the linear algebra estimate (default 0.0).

        :returns:
            A list of `(theta, source)` tuples, where `source` names where
            each guess came from. Non-finite guesses are skipped.
        """

        guesses = []

        if self.theta is not None:
            # Previously trained theta value.
            if np.all(np.isfinite(self.theta[pixel_index])):
                guesses.append((self.theta[pixel_index], "previously_trained"))

        # Estimate from linear algebra.
        theta, cov = fitting.fit_theta_by_linalg(
            self.training_set_flux[:, pixel_index],
            self.training_set_ivar[:, pixel_index],
            s2=kwargs.get("s2", 0.0), design_matrix=self.design_matrix)

        if np.all(np.isfinite(theta)):
            guesses.append((theta, "linear_algebra"))

        if self.theta is not None:
            # Neighbouring pixels value (clipped to the valid pixel range).
            for neighbour_pixel_index in set(np.clip(
                [pixel_index - 1, pixel_index + 1],
                0, self.training_set_flux.shape[1] - 1)):

                if np.all(np.isfinite(self.theta[neighbour_pixel_index])):
                    guesses.append(
                        (self.theta[neighbour_pixel_index], "neighbour_pixel"))

        # Fiducial value.
        fiducial = np.hstack([1.0, np.zeros(len(self.vectorizer.terms))])
        guesses.append((fiducial, "fiducial"))

        return guesses
801