Passed
Pull Request — master (#2635)
by Axel
02:31
created

FluxPointsDataset.models()   A

Complexity

Conditions 4

Size

Total Lines 3
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 3
rs 9.9
c 0
b 0
f 0
cc 4
nop 2
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import logging
3
import numpy as np
4
from astropy import units as u
5
from astropy.io.registry import IORegistryError
6
from astropy.table import Table, vstack
7
from gammapy.modeling import Dataset, Datasets, Fit, Parameters
8
from gammapy.modeling.models import (
9
    PowerLawSpectralModel,
10
    ScaleSpectralModel,
11
    SkyModel,
12
    SkyModels,
13
)
14
from gammapy.utils.interpolation import interpolate_profile
15
from gammapy.utils.scripts import make_path
16
from gammapy.utils.table import table_from_row_data, table_standardise_units_copy
17
from .dataset import SpectrumDatasetOnOff
18
19
__all__ = ["FluxPoints", "FluxPointsEstimator", "FluxPointsDataset"]

log = logging.getLogger(__name__)

# Minimum set of columns a table must contain to be a valid representation
# of each supported SED type (checked by ``FluxPoints._validate_table``).
REQUIRED_COLUMNS = {
    "dnde": ["e_ref", "dnde"],
    "e2dnde": ["e_ref", "e2dnde"],
    "flux": ["e_min", "e_max", "flux"],
    "eflux": ["e_min", "e_max", "eflux"],
    # TODO: extend required columns
    "likelihood": [
        "e_min",
        "e_max",
        "e_ref",
        "ref_dnde",
        "norm",
        "norm_scan",
        "stat_scan",
    ],
}

# Error / upper-limit columns that may accompany each SED type, e.g. for
# "dnde": dnde_err, dnde_errp, dnde_errn, dnde_ul and the shared "is_ul" flag.
OPTIONAL_COLUMNS = {
    _sed: [_sed + _suffix for _suffix in ("_err", "_errp", "_errn", "_ul")]
    + ["is_ul"]
    for _sed in ("dnde", "e2dnde", "flux", "eflux")
}

# Default unit of the main quantity for each (non-likelihood) SED type.
DEFAULT_UNIT = {
    _sed: u.Unit(_unit_str)
    for _sed, _unit_str in (
        ("dnde", "cm-2 s-1 TeV-1"),
        ("e2dnde", "erg cm-2 s-1"),
        ("flux", "cm-2 s-1"),
        ("eflux", "erg cm-2 s-1"),
    )
}
53
54
55
class FluxPoints:
    """Flux points container.

    The supported formats are described here: :ref:`gadf:flux-points`

    In summary, the following formats and minimum required columns are:

    * Format ``dnde``: columns ``e_ref`` and ``dnde``
    * Format ``e2dnde``: columns ``e_ref``, ``e2dnde``
    * Format ``flux``: columns ``e_min``, ``e_max``, ``flux``
    * Format ``eflux``: columns ``e_min``, ``e_max``, ``eflux``
    * Format ``likelihood``: columns ``e_min``, ``e_max``, ``e_ref``,
      ``ref_dnde``, ``norm``, ``norm_scan``, ``stat_scan``

    Parameters
    ----------
    table : `~astropy.table.Table`
        Table with flux point data. ``table.meta["SED_TYPE"]`` must be set.

    Attributes
    ----------
    table : `~astropy.table.Table`
        Table with flux point data (copy of the input produced by
        ``table_standardise_units_copy``).

    Examples
    --------
    The `FluxPoints` object is most easily created by reading a file with
    flux points given in one of the formats documented above::

        from gammapy.spectrum import FluxPoints
        filename = '$GAMMAPY_DATA/tests/spectrum/flux_points/flux_points.fits'
        flux_points = FluxPoints.read(filename)
        flux_points.plot()

    An instance of `FluxPoints` can also be created by passing an instance of
    `astropy.table.Table`, which contains the required columns, such as `'e_ref'`
    and `'dnde'`. The corresponding `sed_type` has to be defined in the meta data
    of the table::

        import numpy as np
        from astropy import units as u
        from astropy.table import Table
        from gammapy.spectrum import FluxPoints
        from gammapy.modeling.models import PowerLawSpectralModel

        table = Table()
        pwl = PowerLawSpectralModel()
        e_ref = np.logspace(0, 2, 7) * u.TeV
        table['e_ref'] = e_ref
        table['dnde'] = pwl(e_ref)
        table.meta['SED_TYPE'] = 'dnde'

        flux_points = FluxPoints(table)
        flux_points.plot()

    If you have flux points in a different data format, the format can be changed
    by renaming the table columns and adding meta data::


        from astropy import units as u
        from astropy.table import Table
        from gammapy.spectrum import FluxPoints

        table = Table.read('$GAMMAPY_DATA/tests/spectrum/flux_points/flux_points_ctb_37b.txt',
                           format='ascii.csv', delimiter=' ', comment='#')
        table.meta['SED_TYPE'] = 'dnde'
        table.rename_column('Differential_Flux', 'dnde')
        table['dnde'].unit = 'cm-2 s-1 TeV-1'

        table.rename_column('lower_error', 'dnde_errn')
        table['dnde_errn'].unit = 'cm-2 s-1 TeV-1'

        table.rename_column('upper_error', 'dnde_errp')
        table['dnde_errp'].unit = 'cm-2 s-1 TeV-1'

        table.rename_column('E', 'e_ref')
        table['e_ref'].unit = 'TeV'

        flux_points = FluxPoints(table)
        flux_points.plot()

    """

    def __init__(self, table):
        self.table = table_standardise_units_copy(table)
        # validate that the table is a valid representation
        # of the given flux point sed type
        self._validate_table(self.table, table.meta["SED_TYPE"])

    def __repr__(self):
        return f"{self.__class__.__name__}(sed_type={self.sed_type!r}, n_points={len(self.table)})"

    @property
    def table_formatted(self):
        """Return formatted version of the flux points table. Used for pretty printing"""
        table = self.table.copy()

        for column in table.colnames:
            if column.startswith(("dnde", "eflux", "flux", "e2dnde", "ref")):
                table[column].format = ".3e"
            elif column.startswith(
                ("e_min", "e_max", "e_ref", "sqrt_ts", "norm", "ts", "stat")
            ):
                table[column].format = ".3f"

        return table

    @classmethod
    def read(cls, filename, **kwargs):
        """Read flux points.

        If the table has no ``SED_TYPE`` meta entry, it is guessed from the
        column names. Likelihood columns in the GADF ``loglike`` convention are
        converted to the internal ``stat`` columns.

        Parameters
        ----------
        filename : str
            Filename
        kwargs : dict
            Keyword arguments passed to `astropy.table.Table.read`.
        """
        filename = make_path(filename)
        try:
            table = Table.read(filename, **kwargs)
        except IORegistryError:
            # format could not be identified from the file name: fall back to ECSV
            kwargs.setdefault("format", "ascii.ecsv")
            table = Table.read(filename, **kwargs)

        if "SED_TYPE" not in table.meta:
            sed_type = cls._guess_sed_type(table)
            table.meta["SED_TYPE"] = sed_type

        # TODO: check sign and factor 2 here
        # https://github.com/gammapy/gammapy/pull/2546#issuecomment-554274318
        # The idea below is to support the format here:
        # https://gamma-astro-data-formats.readthedocs.io/en/latest/spectra/flux_points/index.html#likelihood-columns
        # but internally to go to the uniform "stat"

        if "loglike" in table.colnames and "stat" not in table.colnames:
            table["stat"] = 2 * table["loglike"]

        if "loglike_null" in table.colnames and "stat_null" not in table.colnames:
            table["stat_null"] = 2 * table["loglike_null"]

        if "dloglike_scan" in table.colnames and "stat_scan" not in table.colnames:
            table["stat_scan"] = 2 * table["dloglike_scan"]

        return cls(table=table)

    def write(self, filename, **kwargs):
        """Write flux points.

        Parameters
        ----------
        filename : str
            Filename
        kwargs : dict
            Keyword arguments passed to `astropy.table.Table.write`.
        """
        filename = make_path(filename)
        try:
            self.table.write(filename, **kwargs)
        except IORegistryError:
            # format could not be identified from the file name: fall back to ECSV
            kwargs.setdefault("format", "ascii.ecsv")
            self.table.write(filename, **kwargs)

    @classmethod
    def stack(cls, flux_points):
        """Create flux points by stacking list of flux points.

        The first `FluxPoints` object in the list is taken as a reference to infer
        column names and units for the stacked object. The input objects are
        not modified.

        Parameters
        ----------
        flux_points : list of `FluxPoints`
            List of flux points to stack.

        Returns
        -------
        flux_points : `FluxPoints`
            Stacked flux points, with the columns and units of the first entry.
        """
        reference = flux_points[0].table

        tables = []
        for fp in flux_points:
            # Work on a copy: converting units must not alter the caller's tables.
            table = fp.table.copy()
            for colname in reference.colnames:
                unit = reference[colname].unit
                if unit:
                    table[colname] = table[colname].quantity.to(unit)
            tables.append(table[reference.colnames])

        table_stacked = vstack(tables)
        table_stacked.meta["SED_TYPE"] = reference.meta["SED_TYPE"]

        return cls(table_stacked)

    def drop_ul(self):
        """Drop upper limit flux points.

        Returns
        -------
        flux_points : `FluxPoints`
            Flux points with upper limit points removed.

        Examples
        --------
        >>> from gammapy.spectrum import FluxPoints
        >>> filename = '$GAMMAPY_DATA/tests/spectrum/flux_points/flux_points.fits'
        >>> flux_points = FluxPoints.read(filename)
        >>> print(flux_points)
        FluxPoints(sed_type='flux', n_points=24)
        >>> print(flux_points.drop_ul())
        FluxPoints(sed_type='flux', n_points=19)
        """
        table_drop_ul = self.table[~self.is_ul]
        return self.__class__(table_drop_ul)

    def _flux_to_dnde(self, e_ref, table, model, pwl_approx):
        """Helper for `to_sed_type`: add ``e_ref``, ``dnde`` and derived columns.

        Error columns (``flux_err``, ``flux_errn``, ``flux_errp``) are each
        converted independently, scaled by the same ``dnde / flux`` ratio.
        """
        if model is None:
            model = PowerLawSpectralModel()

        e_min, e_max = self.e_min, self.e_max

        flux = table["flux"].quantity
        dnde = self._dnde_from_flux(flux, model, e_ref, e_min, e_max, pwl_approx)

        # Add to result table
        table["e_ref"] = e_ref
        table["dnde"] = dnde

        # Each error column is handled on its own so that e.g. a table with
        # only ``flux_errn`` does not raise a KeyError for ``flux_errp``.
        for suffix in ["_err", "_errn", "_errp"]:
            if "flux" + suffix in table.colnames:
                table["dnde" + suffix] = dnde * table["flux" + suffix].quantity / flux

        if "flux_ul" in table.colnames:
            flux_ul = table["flux_ul"].quantity
            dnde_ul = self._dnde_from_flux(
                flux_ul, model, e_ref, e_min, e_max, pwl_approx
            )
            table["dnde_ul"] = dnde_ul

        return table

    @staticmethod
    def _dnde_to_e2dnde(e_ref, table):
        """Add ``e2dnde*`` columns computed as ``e_ref ** 2 * dnde*`` where present."""
        for suffix in ["", "_ul", "_err", "_errp", "_errn"]:
            try:
                data = table["dnde" + suffix].quantity
                table["e2dnde" + suffix] = (e_ref ** 2 * data).to(
                    DEFAULT_UNIT["e2dnde"]
                )
            except KeyError:
                continue

        return table

    @staticmethod
    def _e2dnde_to_dnde(e_ref, table):
        """Add ``dnde*`` columns computed as ``e2dnde* / e_ref ** 2`` where present."""
        for suffix in ["", "_ul", "_err", "_errp", "_errn"]:
            try:
                data = table["e2dnde" + suffix].quantity
                table["dnde" + suffix] = (data / e_ref ** 2).to(DEFAULT_UNIT["dnde"])
            except KeyError:
                continue

        return table

    def to_sed_type(self, sed_type, method="log_center", model=None, pwl_approx=False):
        """Convert to a different SED type (return new `FluxPoints`).

        Supported conversions are ``flux`` -> ``dnde``, ``dnde`` -> ``e2dnde``,
        ``e2dnde`` -> ``dnde`` and ``likelihood`` -> ``dnde``/``flux``/``eflux``.

        See: https://ui.adsabs.harvard.edu/abs/1995NIMPA.355..541L for details
        on the `'lafferty'` method.

        Parameters
        ----------
        sed_type : {'dnde', 'e2dnde', 'flux', 'eflux'}
             SED type to convert to.
        model : `~gammapy.modeling.models.SpectralModel`
            Spectral model assumption.  Note that the value of the amplitude parameter
            does not matter. Still it is recommended to use something with the right
            scale and units. E.g. `amplitude = 1e-12 * u.Unit('cm-2 s-1 TeV-1')`
        method : {'lafferty', 'log_center', 'table'}
            Flux points `e_ref` estimation method (only used for ``flux`` -> ``dnde``):

                * `'lafferty'` Lafferty & Wyatt model-based e_ref
                * `'log_center'` log bin center e_ref
                * `'table'` using column 'e_ref' from input flux_points
        pwl_approx : bool
            Use local power law approximation at e_ref to compute differential flux
            from the integral flux. This method is used by the Fermi-LAT catalogs.

        Returns
        -------
        flux_points : `FluxPoints`
            Flux points including differential quantity columns `dnde`
            and `dnde_err` (optional), `dnde_ul` (optional).

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported values.
        NotImplementedError
            If the requested conversion is not supported.

        Examples
        --------
        >>> from gammapy.spectrum import FluxPoints
        >>> from gammapy.modeling.models import PowerLawSpectralModel
        >>> filename = '$GAMMAPY_DATA/tests/spectrum/flux_points/flux_points.fits'
        >>> flux_points = FluxPoints.read(filename)
        >>> model = PowerLawSpectralModel(index=2.2)
        >>> flux_points_dnde = flux_points.to_sed_type('dnde', model=model)
        """
        # TODO: implement other directions.
        table = self.table.copy()

        if self.sed_type == "flux" and sed_type == "dnde":
            # Compute e_ref
            if method == "table":
                e_ref = table["e_ref"].quantity
            elif method == "log_center":
                e_ref = np.sqrt(self.e_min * self.e_max)
            elif method == "lafferty":
                # set e_ref that it represents the mean dnde in the given energy bin
                e_ref = self._e_ref_lafferty(model, self.e_min, self.e_max)
            else:
                raise ValueError(f"Invalid method: {method}")
            table = self._flux_to_dnde(e_ref, table, model, pwl_approx)

        elif self.sed_type == "dnde" and sed_type == "e2dnde":
            table = self._dnde_to_e2dnde(self.e_ref, table)

        elif self.sed_type == "e2dnde" and sed_type == "dnde":
            table = self._e2dnde_to_dnde(self.e_ref, table)

        elif self.sed_type == "likelihood" and sed_type in ["dnde", "flux", "eflux"]:
            # scale the reference quantity by each available norm column
            for suffix in ["", "_ul", "_err", "_errp", "_errn"]:
                try:
                    table[sed_type + suffix] = (
                        table["ref_" + sed_type] * table["norm" + suffix]
                    )
                except KeyError:
                    continue
        else:
            raise NotImplementedError

        table.meta["SED_TYPE"] = sed_type
        return FluxPoints(table)

    @staticmethod
    def _e_ref_lafferty(model, e_min, e_max):
        """Helper for `to_sed_type`.

        Compute e_ref that the value at e_ref corresponds
        to the mean value between e_min and e_max.
        """
        flux = model.integral(e_min, e_max)
        dnde_mean = flux / (e_max - e_min)
        return model.inverse(dnde_mean)

    @staticmethod
    def _dnde_from_flux(flux, model, e_ref, e_min, e_max, pwl_approx):
        """Helper for `to_sed_type`.

        Compute dnde under the assumption that flux equals expected
        flux from model.
        """
        dnde_model = model(e_ref)

        if pwl_approx:
            index = model.spectral_index(e_ref)
            flux_model = PowerLawSpectralModel.evaluate_integral(
                emin=e_min,
                emax=e_max,
                index=index,
                reference=e_ref,
                amplitude=dnde_model,
            )
        else:
            flux_model = model.integral(e_min, e_max, intervals=True)

        return dnde_model * (flux / flux_model)

    @property
    def sed_type(self):
        """SED type (str).

        One of: {'dnde', 'e2dnde', 'flux', 'eflux', 'likelihood'}
        """
        return self.table.meta["SED_TYPE"]

    @staticmethod
    def _guess_sed_type(table):
        """Guess SED type from table content.

        Returns the first SED type (in ``REQUIRED_COLUMNS`` order) whose
        required columns are all present, or ``None`` if none match.
        """
        for sed_type, required in REQUIRED_COLUMNS.items():
            if set(required).issubset(table.colnames):
                return sed_type
        return None

    @staticmethod
    def _guess_sed_type_from_unit(unit):
        """Guess SED type from unit (``None`` if no default unit is equivalent)."""
        for sed_type, default_unit in DEFAULT_UNIT.items():
            if unit.is_equivalent(default_unit):
                return sed_type
        return None

    @staticmethod
    def _validate_table(table, sed_type):
        """Validate input table.

        Raises
        ------
        ValueError
            If a column required for ``sed_type`` is missing.
        """
        required = set(REQUIRED_COLUMNS[sed_type])

        if not required.issubset(table.colnames):
            missing = required.difference(table.colnames)
            raise ValueError(f"Missing columns for sed type '{sed_type}': {missing}")

    @staticmethod
    def _get_y_energy_unit(y_unit):
        """Get energy part of the given y unit (TeV if it has none)."""
        return next(
            (base for base in y_unit.bases if base.physical_type == "energy"),
            u.Unit("TeV"),
        )

    def _plot_get_energy_err(self):
        """Compute energy error for given sed type"""
        try:
            e_min = self.table["e_min"].quantity
            e_max = self.table["e_max"].quantity
            e_ref = self.e_ref
            x_err = ((e_ref - e_min), (e_max - e_ref))
        except KeyError:
            x_err = None
        return x_err

    def _plot_get_flux_err(self, sed_type=None):
        """Compute flux error for given sed type"""
        try:
            # asymmetric error
            y_errn = self.table[sed_type + "_errn"].quantity
            y_errp = self.table[sed_type + "_errp"].quantity
            y_err = (y_errn, y_errp)
        except KeyError:
            try:
                # symmetric error
                y_err = self.table[sed_type + "_err"].quantity
                y_err = (y_err, y_err)
            except KeyError:
                # no error at all
                y_err = None
        return y_err

    @property
    def is_ul(self):
        """Boolean mask of upper-limit points.

        Taken from the ``is_ul`` column if present, otherwise points whose
        main SED column is NaN are treated as upper limits.
        """
        try:
            return self.table["is_ul"].data.astype("bool")
        except KeyError:
            return np.isnan(self.table[self.sed_type])

    @property
    def e_ref(self):
        """Reference energy.

        Defined by `e_ref` column in `FluxPoints.table` or computed as log
        center, if `e_min` and `e_max` columns are present in `FluxPoints.table`.

        Returns
        -------
        e_ref : `~astropy.units.Quantity`
            Reference energy.
        """
        try:
            return self.table["e_ref"].quantity
        except KeyError:
            return np.sqrt(self.e_min * self.e_max)

    @property
    def e_edges(self):
        """Edges of the energy bin.

        Built from all ``e_min`` values plus the last ``e_max`` value.

        Returns
        -------
        e_edges : `~astropy.units.Quantity`
            Energy edges.
        """
        e_edges = list(self.e_min)
        e_edges += [self.e_max[-1]]
        return u.Quantity(e_edges, self.e_min.unit, copy=False)

    @property
    def e_min(self):
        """Lower bound of energy bin.

        Defined by `e_min` column in `FluxPoints.table`.

        Returns
        -------
        e_min : `~astropy.units.Quantity`
            Lower bound of energy bin.
        """
        return self.table["e_min"].quantity

    @property
    def e_max(self):
        """Upper bound of energy bin.

        Defined by ``e_max`` column in ``table``.

        Returns
        -------
        e_max : `~astropy.units.Quantity`
            Upper bound of energy bin.
        """
        return self.table["e_max"].quantity

    def plot(
        self, ax=None, energy_unit="TeV", flux_unit=None, energy_power=0, **kwargs
    ):
        """Plot flux points.

        Upper limits are drawn as downward arrows in the same color.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            Axis object to plot on.
        energy_unit : str, `~astropy.units.Unit`, optional
            Unit of the energy axis
        flux_unit : str, `~astropy.units.Unit`, optional
            Unit of the flux axis
        energy_power : int
            Power of energy to multiply y axis with
        kwargs : dict
            Keyword arguments passed to :func:`matplotlib.pyplot.errorbar`

        Returns
        -------
        ax : `~matplotlib.axes.Axes`
            Axis object
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.gca()

        sed_type = self.sed_type
        y_unit = u.Unit(flux_unit or DEFAULT_UNIT[sed_type])

        y = self.table[sed_type].quantity.to(y_unit)
        x = self.e_ref.to(energy_unit)

        # get errors and ul
        is_ul = self.is_ul
        x_err_all = self._plot_get_energy_err()
        y_err_all = self._plot_get_flux_err(sed_type)

        # handle energy power
        e_unit = self._get_y_energy_unit(y_unit)
        y_unit = y.unit * e_unit ** energy_power
        y = (y * np.power(x, energy_power)).to(y_unit)

        y_err, x_err = None, None

        if y_err_all:
            y_errn = (y_err_all[0] * np.power(x, energy_power)).to(y_unit)
            y_errp = (y_err_all[1] * np.power(x, energy_power)).to(y_unit)
            y_err = (y_errn[~is_ul].to_value(y_unit), y_errp[~is_ul].to_value(y_unit))

        if x_err_all:
            x_errn, x_errp = x_err_all
            x_err = (
                x_errn[~is_ul].to_value(energy_unit),
                x_errp[~is_ul].to_value(energy_unit),
            )

        # set flux points plotting defaults
        kwargs.setdefault("marker", "+")
        kwargs.setdefault("ls", "None")

        ebar = ax.errorbar(
            x[~is_ul].value, y[~is_ul].value, yerr=y_err, xerr=x_err, **kwargs
        )

        if is_ul.any():
            if x_err_all:
                x_errn, x_errp = x_err_all
                x_err = (
                    x_errn[is_ul].to_value(energy_unit),
                    x_errp[is_ul].to_value(energy_unit),
                )

            y_ul = self.table[sed_type + "_ul"].quantity
            y_ul = (y_ul * np.power(x, energy_power)).to(y_unit)

            # arrow length: half of the upper-limit value, pointing down only
            y_err = (0.5 * y_ul[is_ul].value, np.zeros_like(y_ul[is_ul].value))

            kwargs.setdefault("color", ebar[0].get_color())

            # pop label keyword to avoid that it appears twice in the legend
            kwargs.pop("label", None)
            ax.errorbar(
                x[is_ul].value,
                y_ul[is_ul].value,
                xerr=x_err,
                yerr=y_err,
                uplims=True,
                **kwargs,
            )

        # NOTE(review): ``nonposx``/``nonposy`` were renamed ``nonpositive`` in
        # newer matplotlib releases; confirm against the supported version.
        ax.set_xscale("log", nonposx="clip")
        ax.set_yscale("log", nonposy="clip")
        ax.set_xlabel(f"Energy ({energy_unit})")
        ax.set_ylabel(f"{self.sed_type} ({y_unit})")
        return ax

    def plot_ts_profiles(
        self,
        ax=None,
        energy_unit="TeV",
        add_cbar=True,
        y_values=None,
        y_unit=None,
        **kwargs,
    ):
        """Plot fit statistic SED profiles as a density plot.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            Axis object to plot on.
        energy_unit : str, `~astropy.units.Unit`, optional
            Unit of the energy axis
        y_values : `astropy.units.Quantity`
            Array of y-values to use for the fit statistic profile evaluation.
            Defaults to 500 log-spaced values from 0.2 times the smallest to
            5 times the largest reference value, expressed in ``y_unit``.
        y_unit : str or `astropy.units.Unit`
            Unit to use for the y-axis.
        add_cbar : bool
            Whether to add a colorbar to the plot.
        kwargs : dict
            Keyword arguments passed to :func:`matplotlib.pyplot.pcolormesh`

        Returns
        -------
        ax : `~matplotlib.axes.Axes`
            Axis object
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.gca()

        self._validate_table(self.table, "likelihood")
        y_unit = u.Unit(y_unit or DEFAULT_UNIT[self.sed_type])

        # reference values per energy bin; used to convert y values into norms
        ref_values = self.table["ref_" + self.sed_type].quantity

        if y_values is None:
            # Convert to ``y_unit`` first: the raw column values may be in a
            # different (equivalent) unit and would give a wrong axis range.
            ref = ref_values.to_value(y_unit)
            y_values = np.logspace(
                np.log10(0.2 * ref.min()),
                np.log10(5 * ref.max()),
                500,
            )
            y_values = u.Quantity(y_values, y_unit, copy=False)

        x = self.e_edges.to(energy_unit)

        # Compute fit statistic "image" one energy bin at a time
        # by interpolating the stat profile as a function of norm
        z = np.empty((len(self.table), len(y_values)))
        for idx, row in enumerate(self.table):
            norm = (y_values / ref_values[idx]).to_value("")
            norm_scan = row["norm_scan"]
            ts_scan = row["stat_scan"] - row["stat"]
            interp = interpolate_profile(norm_scan, ts_scan)
            z[idx] = interp((norm,))

        kwargs.setdefault("vmax", 0)
        kwargs.setdefault("vmin", -4)
        kwargs.setdefault("zorder", 0)
        kwargs.setdefault("cmap", "Blues")
        kwargs.setdefault("linewidths", 0)

        # clipped values are set to NaN so that they appear white on the plot
        z[-z < kwargs["vmin"]] = np.nan
        caxes = ax.pcolormesh(x.value, y_values.value, -z.T, **kwargs)
        ax.set_xscale("log", nonposx="clip")
        ax.set_yscale("log", nonposy="clip")
        ax.set_xlabel(f"Energy ({energy_unit})")
        ax.set_ylabel(f"{self.sed_type} ({y_values.unit})")

        if add_cbar:
            label = "fit statistic difference"
            ax.figure.colorbar(caxes, ax=ax, label=label)

        return ax
743
744
745
class FluxPointsEstimator:
746
    """Flux points estimator.
747
748
    Estimates flux points for a given list of spectral datasets, energies and
749
    spectral model.
750
751
    To estimate the flux point the amplitude of the reference spectral model is
752
    fitted within the energy range defined by the energy group. This is done for
753
    each group independently. The amplitude is re-normalized using the "norm" parameter,
754
    which specifies the deviation of the flux from the reference model in this
755
    energy group. See https://gamma-astro-data-formats.readthedocs.io/en/latest/spectra/binned_likelihoods/index.html
756
    for details.
757
758
    The method is also described in the Fermi-LAT catalog paper
759
    https://ui.adsabs.harvard.edu/#abs/2015ApJS..218...23A
760
    or the HESS Galactic Plane Survey paper
761
    https://ui.adsabs.harvard.edu/#abs/2018A%26A...612A...1H
762
763
    Parameters
764
    ----------
765
    datasets : list of `~gammapy.spectrum.SpectrumDataset`
766
        Spectrum datasets.
767
    e_edges : `~astropy.units.Quantity`
768
        Energy edges of the flux point bins.
769
    source : str
770
        For which source in the model to compute the flux points.
771
    norm_min : float
772
        Minimum value for the norm used for the fit statistic profile evaluation.
773
    norm_max : float
774
        Maximum value for the norm used for the fit statistic profile evaluation.
775
    norm_n_values : int
776
        Number of norm values used for the fit statistic profile.
777
    norm_values : `numpy.ndarray`
778
        Array of norm values to be used for the fit statistic profile.
779
    sigma : int
780
        Sigma to use for asymmetric error computation.
781
    sigma_ul : int
782
        Sigma to use for upper limit computation.
783
    reoptimize : bool
784
        Re-optimize other free model parameters.
785
    """
786
787
    def __init__(
788
        self,
789
        datasets,
790
        e_edges,
791
        source="",
792
        norm_min=0.2,
793
        norm_max=5,
794
        norm_n_values=11,
795
        norm_values=None,
796
        sigma=1,
797
        sigma_ul=2,
798
        reoptimize=False,
799
    ):
800
        # make a copy to not modify the input datasets
801
        if not isinstance(datasets, Datasets):
802
            datasets = Datasets(datasets)
803
804
        if not datasets.is_all_same_type and datasets.is_all_same_shape:
805
            raise ValueError(
806
                "Flux point estimation requires a list of datasets"
807
                " of the same type and data shape."
808
            )
809
810
        self.datasets = datasets.copy()
811
        self.e_edges = e_edges
812
813
        dataset = self.datasets[0]
814
815
        # TODO: this is complex and non-obvious behaviour. Simlify!
816
        if len(dataset.models) > 1:
817
            model = dataset.models[source].spectral_model
818
        else:
819
            model = dataset.models[0].spectral_model
820
821
        self.model = ScaleSpectralModel(model)
822
        self.model.norm.min = 0
823
        self.model.norm.max = 1e3
824
825
        if norm_values is None:
826
            norm_values = np.logspace(
827
                np.log10(norm_min), np.log10(norm_max), norm_n_values
828
            )
829
830
        self.norm_values = norm_values
831
        self.sigma = sigma
832
        self.sigma_ul = sigma_ul
833
        self.reoptimize = reoptimize
834
        self.source = source
835
        self.fit = Fit(self.datasets)
836
837
        self._set_scale_model()
838
        self._contribute_to_stat = False
839
840
    def _freeze_parameters(self):
841
        # freeze other parameters
842
        for par in self.datasets.parameters:
843
            if par is not self.model.norm:
844
                par.frozen = True
845
846
    def _freeze_empty_background(self):
847
        from gammapy.cube import MapDataset
848
849
        counts_all = self.estimate_counts()["counts"]
850
851
        for counts, dataset in zip(counts_all, self.datasets):
852
            if isinstance(dataset, MapDataset) and counts == 0:
853
                if dataset.background_model is not None:
854
                    dataset.background_model.parameters.freeze_all()
855
856
    def _set_scale_model(self):
857
        # set the model on all datasets
858
        for dataset in self.datasets:
859
            if len(dataset.models) > 1:
860
                dataset.models[self.source].spectral_model = self.model
861
            else:
862
                dataset.models[0].spectral_model = self.model
863
864
    @property
865
    def ref_model(self):
866
        return self.model.model
867
868
    @property
869
    def e_groups(self):
870
        """Energy grouping table `~astropy.table.Table`"""
871
        dataset = self.datasets[0]
872
        if isinstance(dataset, SpectrumDatasetOnOff):
873
            energy_axis = dataset.counts.energy
874
        else:
875
            energy_axis = dataset.counts.geom.get_axis_by_name("energy")
876
        return energy_axis.group_table(self.e_edges)
877
878
    def __str__(self):
879
        s = f"{self.__class__.__name__}:\n"
880
        s += str(self.datasets) + "\n"
881
        s += str(self.e_edges) + "\n"
882
        s += str(self.model) + "\n"
883
        return s
884
885
    def run(self, steps="all"):
886
        """Run the flux point estimator for all energy groups.
887
888
        Returns
889
        -------
890
        flux_points : `FluxPoints`
891
            Estimated flux points.
892
        steps : list of str
893
            Which steps to execute. See `estimate_flux_point` for details
894
            and available options.
895
        """
896
        rows = []
897
        for e_group in self.e_groups:
898
            if e_group["bin_type"].strip() != "normal":
899
                log.debug("Skipping under-/ overflow bin in flux point estimation.")
900
                continue
901
902
            row = self.estimate_flux_point(e_group, steps=steps)
903
            rows.append(row)
904
905
        table = table_from_row_data(rows=rows, meta={"SED_TYPE": "likelihood"})
906
        return FluxPoints(table).to_sed_type("dnde")
907
908
    def _energy_mask(self, e_group, dataset):
909
        energy_mask = np.zeros(dataset.data_shape)
910
        energy_mask[e_group["idx_min"] : e_group["idx_max"] + 1] = 1
911
        return energy_mask.astype(bool)
912
913
    def estimate_flux_point(self, e_group, steps="all"):
914
        """Estimate flux point for a single energy group.
915
916
        Parameters
917
        ----------
918
        e_group : `~astropy.table.Row`
919
            Energy group to compute the flux point for.
920
        steps : list of str
921
            Which steps to execute. Available options are:
922
923
                * "err": estimate symmetric error.
924
                * "errn-errp": estimate asymmetric errors.
925
                * "ul": estimate upper limits.
926
                * "ts": estimate ts and sqrt(ts) values.
927
                * "norm-scan": estimate fit statistic profiles.
928
929
            By default all steps are executed.
930
931
        Returns
932
        -------
933
        result : dict
934
            Dict with results for the flux point.
935
        """
936
        e_min, e_max = e_group["energy_min"], e_group["energy_max"]
937
        # Put at log center of the bin
938
        e_ref = np.sqrt(e_min * e_max)
939
940
        result = {
941
            "e_ref": e_ref,
942
            "e_min": e_min,
943
            "e_max": e_max,
944
            "ref_dnde": self.ref_model(e_ref),
945
            "ref_flux": self.ref_model.integral(e_min, e_max),
946
            "ref_eflux": self.ref_model.energy_flux(e_min, e_max),
947
            "ref_e2dnde": self.ref_model(e_ref) * e_ref ** 2,
948
        }
949
950
        for dataset in self.datasets:
951
            dataset.mask_fit = self._energy_mask(e_group=e_group, dataset=dataset)
952
            mask = dataset.mask_fit
953
954
            if dataset.mask_safe is not None:
955
                mask &= dataset.mask_safe
956
957
            self._contribute_to_stat |= mask.any()
958
959
        with self.datasets.parameters.restore_values:
960
961
            self._freeze_empty_background()
962
963
            if not self.reoptimize:
964
                self._freeze_parameters()
965
966
            result.update(self.estimate_norm())
967
968
            if not result.pop("success"):
969
                log.warning(
970
                    "Fit failed for flux point between {e_min:.3f} and {e_max:.3f},"
971
                    " setting NaN.".format(e_min=e_min, e_max=e_max)
972
                )
973
974
            if steps == "all":
975
                steps = ["err", "counts", "errp-errn", "ul", "ts", "norm-scan"]
976
977
            if "err" in steps:
978
                result.update(self.estimate_norm_err())
979
980
            if "counts" in steps:
981
                result.update(self.estimate_counts())
982
983
            if "errp-errn" in steps:
984
                result.update(self.estimate_norm_errn_errp())
985
986
            if "ul" in steps:
987
                result.update(self.estimate_norm_ul())
988
989
            if "ts" in steps:
990
                result.update(self.estimate_norm_ts())
991
992
            if "norm-scan" in steps:
993
                result.update(self.estimate_norm_scan())
994
995
        return result
996
997
    def estimate_norm_errn_errp(self):
998
        """Estimate asymmetric errors for a flux point.
999
1000
        Returns
1001
        -------
1002
        result : dict
1003
            Dict with asymmetric errors for the flux point norm.
1004
        """
1005
        if not self._contribute_to_stat:
1006
            return {"norm_errp": np.nan, "norm_errn": np.nan}
1007
1008
        result = self.fit.confidence(parameter=self.model.norm, sigma=self.sigma)
1009
        return {"norm_errp": result["errp"], "norm_errn": result["errn"]}
1010
1011
    def estimate_norm_err(self):
1012
        """Estimate covariance errors for a flux point.
1013
1014
        Returns
1015
        -------
1016
        result : dict
1017
            Dict with symmetric error for the flux point norm.
1018
        """
1019
        if not self._contribute_to_stat:
1020
            return {"norm_err": np.nan}
1021
1022
        result = self.fit.covariance()
1023
        norm_err = result.parameters.error(self.model.norm)
1024
        return {"norm_err": norm_err}
1025
1026
    def estimate_counts(self):
1027
        """Estimate counts for the flux point.
1028
1029
        Returns
1030
        -------
1031
        result : dict
1032
            Dict with an array with one entry per dataset with counts for the flux point.
1033
        """
1034
        if not self._contribute_to_stat:
1035
            return {"counts": np.zeros(len(self.datasets))}
1036
1037
        counts = []
1038
        for dataset in self.datasets:
1039
            mask = dataset.mask_fit
1040
            if dataset.mask_safe is not None:
1041
                mask &= dataset.mask_safe
1042
1043
            counts.append(dataset.counts.data[mask].sum())
1044
1045
        return {"counts": np.array(counts, dtype=int)}
1046
1047 View Code Duplication
    def estimate_norm_ul(self):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
1048
        """Estimate upper limit for a flux point.
1049
1050
        Returns
1051
        -------
1052
        result : dict
1053
            Dict with upper limit for the flux point norm.
1054
        """
1055
        if not self._contribute_to_stat:
1056
            return {"norm_ul": np.nan}
1057
1058
        norm = self.model.norm
1059
1060
        # TODO: the minuit backend has convergence problems when the fit statistic is not
1061
        #  of parabolic shape, which is the case, when there are zero counts in the
1062
        #  energy bin. For this case we change to the scipy backend.
1063
        counts = self.estimate_counts()["counts"]
1064
1065
        if np.all(counts == 0):
1066
            result = self.fit.confidence(
1067
                parameter=norm,
1068
                sigma=self.sigma_ul,
1069
                backend="scipy",
1070
                reoptimize=self.reoptimize,
1071
            )
1072
        else:
1073
            result = self.fit.confidence(parameter=norm, sigma=self.sigma_ul)
1074
1075
        return {"norm_ul": result["errp"] + norm.value}
1076
1077 View Code Duplication
    def estimate_norm_ts(self):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
1078
        """Estimate ts and sqrt(ts) for the flux point.
1079
1080
        Returns
1081
        -------
1082
        result : dict
1083
            Dict with ts and sqrt(ts) for the flux point.
1084
        """
1085
        if not self._contribute_to_stat:
1086
            return {"sqrt_ts": np.nan, "ts": np.nan}
1087
1088
        stat = self.datasets.stat_sum()
1089
1090
        # store best fit amplitude, set amplitude of fit model to zero
1091
        self.model.norm.value = 0
1092
        self.model.norm.frozen = True
1093
1094
        if self.reoptimize:
1095
            _ = self.fit.optimize()
1096
1097
        stat_null = self.datasets.stat_sum()
1098
1099
        # compute sqrt TS
1100
        ts = np.abs(stat_null - stat)
1101
        sqrt_ts = np.sqrt(ts)
1102
        return {"sqrt_ts": sqrt_ts, "ts": ts}
1103
1104
    def estimate_norm_scan(self):
1105
        """Estimate fit statistic profile for the norm parameter.
1106
1107
        Returns
1108
        -------
1109
        result : dict
1110
            Keys: "norm_scan", "stat_scan"
1111
        """
1112
        if not self._contribute_to_stat:
1113
            nans = np.nan * np.empty_like(self.norm_values)
1114
            return {"norm_scan": nans, "stat_scan": nans}
1115
1116
        result = self.fit.stat_profile(
1117
            self.model.norm, values=self.norm_values, reoptimize=self.reoptimize
1118
        )
1119
        return {"norm_scan": result["values"], "stat_scan": result["stat"]}
1120
1121
    def estimate_norm(self):
1122
        """Fit norm of the flux point.
1123
1124
        Returns
1125
        -------
1126
        result : dict
1127
            Dict with "norm" and "stat" for the flux point.
1128
        """
1129
        if not self._contribute_to_stat:
1130
            return {"norm": np.nan, "stat": np.nan, "success": False}
1131
1132
        # start optimization with norm=1
1133
        self.model.norm.value = 1.0
1134
        self.model.norm.frozen = False
1135
1136
        result = self.fit.optimize()
1137
1138
        if result.success:
1139
            norm = self.model.norm.value
1140
        else:
1141
            norm = np.nan
1142
1143
        return {"norm": norm, "stat": result.total_stat, "success": result.success}
1144
1145
1146
class FluxPointsDataset(Dataset):
1147
    """
1148
    Fit a set of flux points with a parametric model.
1149
1150
    Parameters
1151
    ----------
1152
    models : `~gammapy.modeling.models.SkyModels`
1153
        Models (only spectral part needs to be set)
1154
    data : `~gammapy.spectrum.FluxPoints`
1155
        Flux points.
1156
    mask_fit : `numpy.ndarray`
1157
        Mask to apply for fitting
1158
    likelihood : {"chi2", "chi2assym"}
1159
        Likelihood function to use for the fit.
1160
    mask_safe : `numpy.ndarray`
1161
        Mask defining the safe data range.
1162
1163
    Examples
1164
    --------
1165
    Load flux points from file and fit with a power-law model::
1166
1167
        from gammapy.modeling import Fit
1168
        from gammapy.modeling.models import PowerLawSpectralModel, SkyModel
1169
        from gammapy.spectrum import FluxPoints, FluxPointsDataset
1170
1171
        filename = "$GAMMAPY_DATA/tests/spectrum/flux_points/diff_flux_points.fits"
1172
        flux_points = FluxPoints.read(filename)
1173
1174
        model = SkyModel(spectral_model=PowerLawSpectralModel())
1175
1176
        dataset = FluxPointsDataset(model, flux_points)
1177
        fit = Fit([dataset])
1178
        result = fit.run()
1179
        print(result)
1180
        print(result.parameters.to_table())
1181
    """
1182
1183
    tag = "FluxPointsDataset"
1184
1185
    def __init__(
1186
        self, models, data, mask_fit=None, likelihood="chi2", mask_safe=None, name=""
1187
    ):
1188
        self.data = data
1189
        self.mask_fit = mask_fit
1190
        self.name = name
1191
        self.models = models
1192
        if data.sed_type != "dnde":
1193
            raise ValueError("Currently only flux points of type 'dnde' are supported.")
1194
1195
        if mask_safe is None:
1196
            mask_safe = np.isfinite(data.table["dnde"])
1197
1198
        self.mask_safe = mask_safe
1199
1200
        if likelihood in ["chi2", "chi2assym"]:
1201
            self.likelihood_type = likelihood
1202
        else:
1203
            raise ValueError(
1204
                f"Invalid likelihood: {likelihood!r}."
1205
                " Choose either 'chi2' or 'chi2assym'."
1206
            )
1207
1208
    @property
1209
    def models(self):
1210
        return self._models
1211
1212 View Code Duplication
    @models.setter
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
1213
    def models(self, value):
1214
        if isinstance(value, SkyModels):
1215
            models = value
1216
        elif isinstance(value, list):
1217
            models = SkyModels(value)
1218
        elif isinstance(value, SkyModel):
1219
            models = SkyModels([value])
1220
        else:
1221
            raise TypeError(f"Invalid: {value!r}")
1222
1223
        self._models = models
1224
1225
    @property
1226
    def parameters(self):
1227
        """List of parameters (`~gammapy.modeling.Parameters`)"""
1228
        parameters = []
1229
1230
        for component in self.models:
1231
            parameters.append(component.spectral_model.parameters)
1232
1233
        return Parameters.from_stack(parameters)
1234
1235
    def write(self, filename, overwrite=True, **kwargs):
1236
        """Write flux point dataset to file.
1237
1238
        Parameters
1239
        ----------
1240
        filename : str
1241
            Filename to write to.
1242
        overwrite : bool
1243
            Overwrite existing file.
1244
        **kwargs : dict
1245
             Keyword arguments passed to `~astropy.table.Table.write`.
1246
        """
1247
        table = self.data.table.copy()
1248
        if self.mask_fit is None:
1249
            mask_fit = self.mask_safe
1250
        else:
1251
            mask_fit = self.mask_fit
1252
1253
        table["mask_fit"] = mask_fit
1254
        table["mask_safe"] = self.mask_safe
1255
        table.write(filename, overwrite=overwrite, **kwargs)
1256
1257
    @classmethod
1258
    def from_dict(cls, data, components, models):
1259
        """Create flux point dataset from dict.
1260
1261
        Parameters
1262
        ----------
1263
        data : dict
1264
            Dict containing data to create dataset from.
1265
        components : list of dict
1266
            Not used.
1267
        models : list of `SkyModel`
1268
            List of model components.
1269
1270
        Returns
1271
        -------
1272
        dataset : `FluxPointsDataset`
1273
            Flux point datasets.
1274
        """
1275
        models = [model for model in models if model.name in data["models"]]
1276
        # TODO: assumes that the model is a skymodel
1277
        # so this will work only when this change will be effective
1278
        table = Table.read(data["filename"])
1279
        mask_fit = table["mask_fit"].data.astype("bool")
1280
        mask_safe = table["mask_safe"].data.astype("bool")
1281
        table.remove_columns(["mask_fit", "mask_safe"])
1282
        return cls(
1283
            models=models,
1284
            name=data["name"],
1285
            data=FluxPoints(table),
1286
            mask_fit=mask_fit,
1287
            mask_safe=mask_safe,
1288
            likelihood=data["likelihood"],
1289
        )
1290
1291
    def to_dict(self, filename=""):
1292
        """Convert to dict for YAML serialization."""
1293
        if self.models is not None:
1294
            models = [_.name for _ in self.models]
1295
        else:
1296
            models = []
1297
1298
        return {
1299
            "name": self.name,
1300
            "type": self.tag,
1301
            "models": models,
1302
            "likelihood": self.likelihood_type,
1303
            "filename": str(filename),
1304
        }
1305
1306
    def __str__(self):
1307
        str_ = f"{self.__class__.__name__}: \n"
1308
        str_ += "\n"
1309
        if self.models is None:
1310
            str_ += "\t{:32}:   {} \n".format("Model Name", "No Model")
1311
        else:
1312
            str_ += "\t{:32}:   {} \n".format("Total flux points", len(self.data.table))
1313
            str_ += "\t{:32}:   {} \n".format(
1314
                "Points used for the fit", self.mask.sum()
1315
            )
1316
            str_ += "\t{:32}:   {} \n".format(
1317
                "Excluded for safe energy range", (~self.mask_safe).sum()
1318
            )
1319
            if self.mask_fit is None:
1320
                str_ += "\t{:32}:   {} \n".format("Excluded by user", "0")
1321
            else:
1322
                str_ += "\t{:32}:   {} \n".format(
1323
                    "Excluded by user", (~self.mask_fit).sum()
1324
                )
1325
            str_ += "\t{:32}:   {}\n".format(
1326
                "Model Name", self.models.__class__.__name__
1327
            )
1328
            str_ += "\t{:32}:   {}\n".format("N parameters", len(self.parameters))
1329
            str_ += "\t{:32}:   {}\n".format(
1330
                "N free parameters", len(self.parameters.free_parameters)
1331
            )
1332
            str_ += "\tList of parameters\n"
1333
            for par in self.parameters:
1334
                if par.frozen:
1335
                    if par.name == "amplitude":
1336
                        str_ += "\t \t {:14} (Frozen):   {:.2e} {} \n".format(
1337
                            par.name, par.value, par.unit
1338
                        )
1339
                    else:
1340
                        str_ += "\t \t {:14} (Frozen):   {:.2f} {} \n".format(
1341
                            par.name, par.value, par.unit
1342
                        )
1343
                else:
1344
                    if par.name == "amplitude":
1345
                        str_ += "\t \t {:23}:   {:.2e} {} \n".format(
1346
                            par.name, par.value, par.unit
1347
                        )
1348
                    else:
1349
                        str_ += "\t \t {:23}:   {:.2f} {} \n".format(
1350
                            par.name, par.value, par.unit
1351
                        )
1352
            str_ += "\t{:32}:   {}\n".format("Likelihood type", self.likelihood_type)
1353
            str_ += "\t{:32}:   {:.2f}\n".format("Likelihood value", self.stat_sum())
1354
        return str_
1355
1356
    def data_shape(self):
1357
        """Shape of the flux points data (tuple)."""
1358
        return self.data.e_ref.shape
1359
1360
    @staticmethod
1361
    def _stat_chi2(data, model, sigma):
1362
        return ((data - model) / sigma).to_value("") ** 2
1363
1364
    @staticmethod
1365
    def _stat_chi2_assym(data, model, sigma_n, sigma_p):
1366
        """Assymetric chi2 statistics for a list of flux points and model."""
1367
        is_p = model > data
1368
        sigma = sigma_n
1369
        sigma[is_p] = sigma_p[is_p]
1370
        return FluxPointsDataset._stat_chi2(data, model, sigma)
1371
1372
    def flux_pred(self):
1373
        """Compute predicted flux."""
1374
        flux = 0.0
1375
        for component in self.models:
1376
            flux += component.spectral_model(self.data.e_ref)
1377
        return flux
1378
1379
    def stat_array(self):
1380
        """Fit statistic array."""
1381
        model = self.flux_pred()
1382
        data = self.data.table["dnde"].quantity
1383
1384
        if self.likelihood_type == "chi2":
1385
            sigma = self.data.table["dnde_err"].quantity
1386
            return self._stat_chi2(data, model, sigma)
1387
        elif self.likelihood_type == "chi2assym":
1388
            sigma_n = self.data.table["dnde_errn"].quantity
1389
            sigma_p = self.data.table["dnde_errp"].quantity
1390
            return self._stat_chi2_assym(data, model, sigma_n, sigma_p)
1391
        else:
1392
            # TODO: add fit statistic profiles
1393
            pass
1394
1395
    def residuals(self, method="diff"):
1396
        """Compute the flux point residuals ().
1397
1398
        Parameters
1399
        ----------
1400
        method: {"diff", "diff/model", "diff/sqrt(model)"}
1401
            Method used to compute the residuals. Available options are:
1402
                - `diff` (default): data - model
1403
                - `diff/model`: (data - model) / model
1404
                - `diff/sqrt(model)`: (data - model) / sqrt(model)
1405
                - `norm='sqrt_model'` for: (flux points - model)/sqrt(model)
1406
1407
1408
        Returns
1409
        -------
1410
        residuals : `~numpy.ndarray`
1411
            Residuals array.
1412
        """
1413
        fp = self.data
1414
        data = fp.table[fp.sed_type]
1415
1416
        model = self.flux_pred()
1417
1418
        residuals = self._compute_residuals(data, model, method)
1419
        # Remove residuals for upper_limits
1420
        residuals[fp.is_ul] = np.nan
1421
        return residuals
1422
1423
    def peek(self, method="diff/model", **kwargs):
1424
        """Plot flux points, best fit model and residuals.
1425
1426
        Parameters
1427
        ----------
1428
        method : {"diff", "diff/model", "diff/sqrt(model)"}
1429
            Method used to compute the residuals, see `MapDataset.residuals()`
1430
        """
1431
        from matplotlib.gridspec import GridSpec
1432
        import matplotlib.pyplot as plt
1433
1434
        gs = GridSpec(7, 1)
1435
1436
        ax_spectrum = plt.subplot(gs[:5, :])
1437
        self.plot_spectrum(ax=ax_spectrum, **kwargs)
1438
1439
        ax_spectrum.set_xticks([])
1440
1441
        ax_residuals = plt.subplot(gs[5:, :])
1442
        self.plot_residuals(ax=ax_residuals, method=method)
1443
        return ax_spectrum, ax_residuals
1444
1445
    @property
1446
    def _e_range(self):
1447
        try:
1448
            return u.Quantity([self.data.e_min.min(), self.data.e_max.max()])
1449
        except KeyError:
1450
            return u.Quantity([self.data.e_ref.min(), self.data.e_ref.max()])
1451
1452
    @property
1453
    def _e_unit(self):
1454
        return self.data.e_ref.unit
1455
1456
    def plot_residuals(self, ax=None, method="diff", **kwargs):
1457
        """Plot flux point residuals.
1458
1459
        Parameters
1460
        ----------
1461
        ax : `~matplotlib.pyplot.Axes`
1462
            Axes object.
1463
        method : {"diff", "diff/model", "diff/sqrt(model)"}
1464
            Method used to compute the residuals, see `MapDataset.residuals()`
1465
        **kwargs : dict
1466
            Keyword arguments passed to `~matplotlib.pyplot.errorbar`.
1467
1468
        Returns
1469
        -------
1470
        ax : `~matplotlib.pyplot.Axes`
1471
            Axes object.
1472
        """
1473
        import matplotlib.pyplot as plt
1474
1475
        ax = plt.gca() if ax is None else ax
1476
1477
        residuals = self.residuals(method=method)
1478
1479
        fp = self.data
1480
1481
        xerr = fp._plot_get_energy_err()
1482
        if xerr is not None:
1483
            xerr = xerr[0].to_value(self._e_unit), xerr[1].to_value(self._e_unit)
1484
1485
        model = self.flux_pred()
1486
        yerr = fp._plot_get_flux_err(fp.sed_type)
1487
1488
        if method == "diff":
1489
            unit = yerr[0].unit
1490
            yerr = yerr[0].to_value(unit), yerr[1].to_value(unit)
1491
        elif method == "diff/model":
1492
            unit = ""
1493
            yerr = (yerr[0] / model).to_value(""), (yerr[1] / model).to_value(unit)
1494
        else:
1495
            raise ValueError("Invalid method, choose between 'diff' and 'diff/model'")
1496
1497
        kwargs.setdefault("marker", "+")
1498
        kwargs.setdefault("ls", "None")
1499
        kwargs.setdefault("color", "black")
1500
1501
        ax.errorbar(
1502
            self.data.e_ref.value, residuals.value, xerr=xerr, yerr=yerr, **kwargs
1503
        )
1504
1505
        # format axes
1506
        ax.axhline(0, color="black", lw=0.5)
1507
        ax.set_ylabel("Residuals {}".format(unit.__str__()))
1508
        ax.set_xlabel(f"Energy ({self._e_unit})")
1509
        ax.set_xscale("log")
1510
        ax.set_xlim(self._e_range.to_value(self._e_unit))
1511
        y_max = 2 * np.nanmax(residuals).value
1512
        ax.set_ylim(-y_max, y_max)
1513
        return ax
1514
1515
    def plot_spectrum(self, ax=None, fp_kwargs=None, model_kwargs=None):
1516
        """
1517
        Plot spectrum including flux points and model.
1518
1519
        Parameters
1520
        ----------
1521
        ax : `~matplotlib.pyplot.Axes`
1522
            Axes object.
1523
        fp_kwargs : dict
1524
            Keyword arguments passed to `FluxPoints.plot`.
1525
        model_kwargs : dict
1526
            Keywords passed to `SpectralModel.plot` and `SpectralModel.plot_error`
1527
1528
        Returns
1529
        -------
1530
        ax : `~matplotlib.pyplot.Axes`
1531
            Axes object.
1532
        """
1533
        import matplotlib.pyplot as plt
1534
1535
        ax = plt.gca() if ax is None else ax
1536
        fp_kwargs = {} if fp_kwargs is None else fp_kwargs
1537
        model_kwargs = {} if model_kwargs is None else model_kwargs
1538
1539
        kwargs = {
1540
            "flux_unit": "erg-1 cm-2 s-1",
1541
            "energy_unit": "TeV",
1542
            "energy_power": 2,
1543
        }
1544
1545
        # plot flux points
1546
        plot_kwargs = kwargs.copy()
1547
        plot_kwargs.update(fp_kwargs)
1548
        plot_kwargs.setdefault("label", "Flux points")
1549
        ax = self.data.plot(ax=ax, **plot_kwargs)
1550
1551
        plot_kwargs = kwargs.copy()
1552
        plot_kwargs.setdefault("energy_range", self._e_range)
1553
        plot_kwargs.setdefault("zorder", 10)
1554
        plot_kwargs.update(model_kwargs)
1555
        plot_kwargs.setdefault("label", "Best fit model")
1556
        for _ in self.models:
1557
            _.spectral_model.plot(ax=ax, **plot_kwargs)
1558
1559
        plot_kwargs.setdefault("color", ax.lines[-1].get_color())
1560
        del plot_kwargs["label"]
1561
1562
        if self.models.parameters.covariance is not None:
1563
            try:
1564
                self.models.plot_error(ax=ax, **plot_kwargs)
1565
            except AttributeError:
1566
                log.debug("Model does not support evaluation of errors")
1567
1568
        # format axes
1569
        ax.set_xlim(self._e_range.to_value(self._e_unit))
1570
        return ax
1571