# Licensed under a 3-clause BSD style license - see LICENSE.rst
import logging
import numpy as np
import astropy.units as u
from astropy.io import fits
from astropy.table import Table
from regions import CircleSkyRegion
from gammapy.data import GTI
from gammapy.irf import EDispKernelMap, EDispMap, PSFKernel, PSFMap
from gammapy.maps import Map, MapAxis
from gammapy.modeling.models import (
    DatasetModels,
    FoVBackgroundModel,
)
from gammapy.stats import (
    CashCountsStatistic,
    WStatCountsStatistic,
    cash,
    cash_sum_cython,
    get_wstat_mu_bkg,
    wstat,
)
from gammapy.utils.fits import HDULocation, LazyFitsData
from gammapy.utils.random import get_random_state
from gammapy.utils.scripts import make_name, make_path
from gammapy.utils.table import hstack_columns
from .core import Dataset
from .utils import get_axes
from .evaluator import MapEvaluator

__all__ = ["MapDataset", "MapDatasetOnOff", "create_map_dataset_geoms"]

log = logging.getLogger(__name__)


RAD_MAX = 0.66
RAD_AXIS_DEFAULT = MapAxis.from_bounds(
    0, RAD_MAX, nbin=66, node_type="edges", name="rad", unit="deg"
)
MIGRA_AXIS_DEFAULT = MapAxis.from_bounds(
    0.2, 5, nbin=48, node_type="edges", name="migra"
)

BINSZ_IRF_DEFAULT = 0.2

EVALUATION_MODE = "local"
USE_NPRED_CACHE = True


def create_map_dataset_geoms(
    geom,
    energy_axis_true=None,
    migra_axis=None,
    rad_axis=None,
    binsz_irf=None,
):
    """Create map geometries for a `MapDataset`

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        Reference target geometry in reco energy, used for counts and background maps
    energy_axis_true : `~gammapy.maps.MapAxis`
        True energy axis used for IRF maps
    migra_axis : `~gammapy.maps.MapAxis`
        If set, this provides the migration axis for the energy dispersion map.
        If not set, an EDispKernelMap is produced instead. Default is None
    rad_axis : `~gammapy.maps.MapAxis`
        Rad axis for the psf map
    binsz_irf : float
        IRF Map pixel size in degrees.

    Returns
    -------
    geoms : dict
        Dict with map geometries.
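
    Examples
    --------
    A minimal usage sketch; the axis binning and map size below are
    illustrative choices, not defaults::

        from gammapy.maps import MapAxis, WcsGeom

        energy_axis = MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=5)
        geom = WcsGeom.create(npix=(100, 100), binsz=0.05, axes=[energy_axis])
        geoms = create_map_dataset_geoms(geom)
        print(geoms.keys())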
    """
    rad_axis = rad_axis or RAD_AXIS_DEFAULT

    if energy_axis_true is not None:
        energy_axis_true.assert_name("energy_true")
    else:
        energy_axis_true = geom.axes["energy"].copy(name="energy_true")

    binsz_irf = binsz_irf or BINSZ_IRF_DEFAULT
    geom_image = geom.to_image()
    geom_exposure = geom_image.to_cube([energy_axis_true])
    geom_irf = geom_image.to_binsz(binsz=binsz_irf)
    geom_psf = geom_irf.to_cube([rad_axis, energy_axis_true])

    if migra_axis:
        geom_edisp = geom_irf.to_cube([migra_axis, energy_axis_true])
    else:
        geom_edisp = geom_irf.to_cube([geom.axes["energy"], energy_axis_true])

    return {
        "geom": geom,
        "geom_exposure": geom_exposure,
        "geom_psf": geom_psf,
        "geom_edisp": geom_edisp,
    }


class MapDataset(Dataset):
    """Perform sky model likelihood fit on maps.

    If an `HDULocation` is passed, the map is loaded lazily: the map data
    is only read into memory when the corresponding data attribute of the
    `MapDataset` is first accessed. Once accessed, the data is cached for
    subsequent accesses.

    Parameters
    ----------
    models : `~gammapy.modeling.models.Models`
        Source sky models.
    counts : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Counts cube
    exposure : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Exposure cube
    background : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Background cube
    mask_fit : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Mask to apply to the likelihood for fitting.
    psf : `~gammapy.irf.PSFMap` or `~gammapy.utils.fits.HDULocation`
        PSF kernel
    edisp : `~gammapy.irf.EDispKernelMap` or `~gammapy.irf.EDispMap` or `~gammapy.utils.fits.HDULocation`
        Energy dispersion kernel
    mask_safe : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Mask defining the safe data range.
    gti : `~gammapy.data.GTI`
        GTI of the observation or union of GTI if it is a stacked observation
    meta_table : `~astropy.table.Table`
        Table listing information on observations used to create the dataset.
        One line per observation for stacked datasets.


    See Also
    --------
    MapDatasetOnOff, SpectrumDataset, FluxPointsDataset
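
    Examples
    --------
    A usage sketch; the ``$GAMMAPY_DATA`` path is an assumption and must
    point at the public gammapy test datasets::

        from gammapy.datasets import MapDataset

        filename = "$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz"
        dataset = MapDataset.read(filename, name="cta-dataset")
        print(dataset)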
    """

    stat_type = "cash"
    tag = "MapDataset"
    counts = LazyFitsData(cache=True)
    exposure = LazyFitsData(cache=True)
    edisp = LazyFitsData(cache=True)
    background = LazyFitsData(cache=True)
    psf = LazyFitsData(cache=True)
    mask_fit = LazyFitsData(cache=True)
    mask_safe = LazyFitsData(cache=True)

    _lazy_data_members = [
        "counts",
        "exposure",
        "edisp",
        "psf",
        "mask_fit",
        "mask_safe",
        "background",
    ]

    def __init__(
        self,
        models=None,
        counts=None,
        exposure=None,
        background=None,
        psf=None,
        edisp=None,
        mask_safe=None,
        mask_fit=None,
        gti=None,
        meta_table=None,
        name=None,
    ):
        self._name = make_name(name)
        self._evaluators = {}

        self.counts = counts
        self.exposure = exposure
        self.background = background
        self._background_cached = None
        self._background_parameters_cached = None

        self.mask_fit = mask_fit

        if psf and not isinstance(psf, (PSFMap, HDULocation)):
            raise ValueError(
                f"'psf' must be a 'PSFMap' or 'HDULocation' object, got {type(psf)}"
            )

        self.psf = psf

        if edisp and not isinstance(edisp, (EDispMap, EDispKernelMap, HDULocation)):
            raise ValueError(
                f"'edisp' must be an 'EDispMap', 'EDispKernelMap' or 'HDULocation' object, got {type(edisp)}"
            )

        self.edisp = edisp
        self.mask_safe = mask_safe
        self.gti = gti
        self.models = models
        self.meta_table = meta_table

    # TODO: keep or remove?
    @property
    def background_model(self):
        try:
            return self.models[f"{self.name}-bkg"]
        except (ValueError, TypeError):
            pass

    def __str__(self):
        str_ = f"{self.__class__.__name__}\n"
        str_ += "-" * len(self.__class__.__name__) + "\n"
        str_ += "\n"
        str_ += "\t{:32}: {{name}} \n\n".format("Name")
        str_ += "\t{:32}: {{counts:.0f}} \n".format("Total counts")
        str_ += "\t{:32}: {{background:.2f}}\n".format("Total background counts")
        str_ += "\t{:32}: {{excess:.2f}}\n\n".format("Total excess counts")

        str_ += "\t{:32}: {{npred:.2f}}\n".format("Predicted counts")
        str_ += "\t{:32}: {{npred_background:.2f}}\n".format(
            "Predicted background counts"
        )
        str_ += "\t{:32}: {{npred_signal:.2f}}\n\n".format("Predicted excess counts")

        str_ += "\t{:32}: {{exposure_min:.2e}}\n".format("Exposure min")
        str_ += "\t{:32}: {{exposure_max:.2e}}\n\n".format("Exposure max")

        str_ += "\t{:32}: {{n_bins}} \n".format("Number of total bins")
        str_ += "\t{:32}: {{n_fit_bins}} \n\n".format("Number of fit bins")

        # likelihood section
        str_ += "\t{:32}: {{stat_type}}\n".format("Fit statistic type")
        str_ += "\t{:32}: {{stat_sum:.2f}}\n\n".format(
            "Fit statistic value (-2 log(L))"
        )

        info = self.info_dict()
        str_ = str_.format(**info)

        # model section
        n_models, n_pars, n_free_pars = 0, 0, 0
        if self.models is not None:
            n_models = len(self.models)
            n_pars = len(self.models.parameters)
            n_free_pars = len(self.models.parameters.free_parameters)

        str_ += "\t{:32}: {} \n".format("Number of models", n_models)
        str_ += "\t{:32}: {}\n".format("Number of parameters", n_pars)
        str_ += "\t{:32}: {}\n\n".format("Number of free parameters", n_free_pars)

        if self.models is not None:
            str_ += "\t" + "\n\t".join(str(self.models).split("\n")[2:])

        return str_.expandtabs(tabsize=2)

    @property
    def geoms(self):
        """Map geometries

        Returns
        -------
        geoms : dict
            Dict of map geometries involved in the dataset.
        """
        geoms = {}

        geoms["geom"] = self._geom

        if self.exposure:
            geoms["geom_exposure"] = self.exposure.geom

        if self.psf:
            geoms["geom_psf"] = self.psf.psf_map.geom

        if self.edisp:
            geoms["geom_edisp"] = self.edisp.edisp_map.geom

        return geoms

    @property
    def models(self):
        """Models (`~gammapy.modeling.models.Models`)."""
        return self._models

    @property
    def excess(self):
        """Excess (counts - background)."""
        return self.counts - self.background

    @models.setter
    def models(self, models):
        """Models setter"""
        self._evaluators = {}

        if models is not None:
            models = DatasetModels(models)
            models = models.select(datasets_names=self.name)

            for model in models:
                if not isinstance(model, FoVBackgroundModel):
                    evaluator = MapEvaluator(
                        model=model,
                        evaluation_mode=EVALUATION_MODE,
                        gti=self.gti,
                        use_cache=USE_NPRED_CACHE,
                    )
                    self._evaluators[model.name] = evaluator

        self._models = models

    @property
    def evaluators(self):
        """Model evaluators"""
        return self._evaluators

    @property
    def _geom(self):
        """Main analysis geometry"""
        if self.counts is not None:
            return self.counts.geom
        elif self.background is not None:
            return self.background.geom
        elif self.mask_safe is not None:
            return self.mask_safe.geom
        elif self.mask_fit is not None:
            return self.mask_fit.geom
        else:
            raise ValueError(
                "Either 'counts', 'background', 'mask_fit'"
                " or 'mask_safe' must be defined."
            )

    @property
    def data_shape(self):
        """Shape of the counts or background data (tuple)"""
        return self._geom.data_shape

    def _energy_range(self, mask_map=None):
        """Compute the energy range maps with or without the fit mask."""
        geom = self._geom
        energy = geom.axes["energy"].edges
        e_i = geom.axes.index_data("energy")
        geom = geom.drop("energy")

        if mask_map is not None:
            mask = mask_map.data
            if mask.any():
                idx = mask.argmax(e_i)
                energy_min = energy.value[idx]
                mask_nan = ~mask.any(e_i)
                energy_min[mask_nan] = np.nan

                mask = np.flip(mask, e_i)
                idx = mask.argmax(e_i)
                energy_max = energy.value[::-1][idx]
                energy_max[mask_nan] = np.nan
            else:
                energy_min = np.full(geom.data_shape, np.nan)
                energy_max = energy_min.copy()
        else:
            data_shape = geom.data_shape
            energy_min = np.full(data_shape, energy.value[0])
            energy_max = np.full(data_shape, energy.value[-1])

        map_min = Map.from_geom(geom, data=energy_min, unit=energy.unit)
        map_max = Map.from_geom(geom, data=energy_max, unit=energy.unit)
        return map_min, map_max

    @property
    def energy_range(self):
        """Energy range maps defined by the full mask (mask_safe and mask_fit)."""
        return self._energy_range(self.mask)

    @property
    def energy_range_safe(self):
        """Energy range maps defined by the mask_safe only."""
        return self._energy_range(self.mask_safe)

    @property
    def energy_range_fit(self):
        """Energy range maps defined by the mask_fit only."""
        return self._energy_range(self.mask_fit)

    def npred(self):
        """Predicted source and background counts

        Returns
        -------
        npred : `Map`
            Total predicted counts
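
        Examples
        --------
        A sketch, assuming ``dataset`` is a `MapDataset` with models
        attached::

            npred = dataset.npred()
            npred.sum_over_axes().plot(add_cbar=True)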
        """
        npred_total = self.npred_signal()

        if self.background:
            npred_total += self.npred_background()

        return npred_total

    def npred_background(self):
        """Predicted background counts

        The predicted background counts depend on the parameters
        of the `FoVBackgroundModel` defined in the dataset.

        Returns
        -------
        npred_background : `Map`
            Predicted counts from the background.
        """
        background = self.background
        if self.background_model and background:
            if self._background_parameters_changed:
                values = self.background_model.evaluate_geom(geom=self.background.geom)
                if self._background_cached is None:
                    self._background_cached = background * values
                else:
                    self._background_cached.data = background.data * values.value
                    self._background_cached.unit = background.unit
            return self._background_cached
        else:
            return background

    @property
    def _background_parameters_changed(self):
        values = self.background_model.parameters.value
        # TODO: possibly allow for a tolerance here?
        changed = ~np.all(self._background_parameters_cached == values)
        if changed:
            self._background_parameters_cached = values
        return changed

    def npred_signal(self, model_name=None):
        """Model predicted signal counts.

        If a model name is passed, the predicted counts from that component
        are returned. Otherwise, the total predicted signal counts are
        returned.

        Parameters
        ----------
        model_name : str
            Name of the `SkyModel` for which to compute npred.
            If None, the sum of all components (excluding the background
            model) is returned.

        Returns
        -------
        npred_sig : `gammapy.maps.Map`
            Map of the predicted signal counts
        """
        npred_total = Map.from_geom(self._geom, dtype=float)

        evaluators = self.evaluators
        if model_name is not None:
            evaluators = {model_name: self.evaluators[model_name]}

        for evaluator in evaluators.values():
            if evaluator.needs_update:
                evaluator.update(
                    self.exposure,
                    self.psf,
                    self.edisp,
                    self._geom,
                    self.mask_image,
                )

            if evaluator.contributes:
                npred = evaluator.compute_npred()
                npred_total.stack(npred)

        return npred_total

    @classmethod
    def from_geoms(
        cls,
        geom,
        geom_exposure=None,
        geom_psf=None,
        geom_edisp=None,
        reference_time="2000-01-01",
        name=None,
        **kwargs,
    ):
        """
        Create a MapDataset object with zero filled maps according to the specified geometries.

        Parameters
        ----------
        geom : `Geom`
            Geometry for the counts and background maps.
        geom_exposure : `Geom`
            Geometry for the exposure map.
        geom_psf : `Geom`
            Geometry for the PSF map.
        geom_edisp : `Geom`
            Geometry for the energy dispersion kernel map.
            If geom_edisp has a migra axis, this will create an EDispMap instead.
        reference_time : `~astropy.time.Time`
            The reference time to use in the GTI definition.
        name : str
            Name of the returned dataset.

        Returns
        -------
        dataset : `MapDataset` or `SpectrumDataset`
            A dataset containing zero filled maps
        """
        name = make_name(name)
        kwargs = kwargs.copy()
        kwargs["name"] = name
        kwargs["counts"] = Map.from_geom(geom, unit="")
        kwargs["background"] = Map.from_geom(geom, unit="")

        if geom_exposure:
            kwargs["exposure"] = Map.from_geom(geom_exposure, unit="m2 s")

        if geom_edisp:
            if "energy" in geom_edisp.axes.names:
                kwargs["edisp"] = EDispKernelMap.from_geom(geom_edisp)
            else:
                kwargs["edisp"] = EDispMap.from_geom(geom_edisp)

        if geom_psf:
            kwargs["psf"] = PSFMap.from_geom(geom_psf)

        kwargs.setdefault(
            "gti", GTI.create([] * u.s, [] * u.s, reference_time=reference_time)
        )
        kwargs["mask_safe"] = Map.from_geom(geom, unit="", dtype=bool)
        return cls(**kwargs)

    @classmethod
    def create(
        cls,
        geom,
        energy_axis_true=None,
        migra_axis=None,
        rad_axis=None,
        binsz_irf=None,
        reference_time="2000-01-01",
        name=None,
        meta_table=None,
        **kwargs,
    ):
        """Create a MapDataset object with zero filled maps.

        Parameters
        ----------
        geom : `~gammapy.maps.WcsGeom`
            Reference target geometry in reco energy, used for counts and background maps
        energy_axis_true : `~gammapy.maps.MapAxis`
            True energy axis used for IRF maps
        migra_axis : `~gammapy.maps.MapAxis`
            If set, this provides the migration axis for the energy dispersion map.
            If not set, an EDispKernelMap is produced instead. Default is None
        rad_axis : `~gammapy.maps.MapAxis`
            Rad axis for the psf map
        binsz_irf : float
            IRF Map pixel size in degrees.
        reference_time : `~astropy.time.Time`
            The reference time to use in the GTI definition.
        name : str
            Name of the returned dataset.
        meta_table : `~astropy.table.Table`
            Table listing information on observations used to create the dataset.
            One line per observation for stacked datasets.

        Returns
        -------
        empty_maps : `MapDataset`
            A MapDataset containing zero filled maps
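
        Examples
        --------
        A sketch; the sky position, width and binning below are
        illustrative choices, not defaults::

            from gammapy.maps import MapAxis, WcsGeom

            energy_axis = MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=11)
            geom = WcsGeom.create(
                skydir=(83.63, 22.01), width=5, binsz=0.05, axes=[energy_axis]
            )
            empty = MapDataset.create(geom=geom, name="empty")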
        """
        geoms = create_map_dataset_geoms(
            geom=geom,
            energy_axis_true=energy_axis_true,
            rad_axis=rad_axis,
            migra_axis=migra_axis,
            binsz_irf=binsz_irf,
        )

        kwargs.update(geoms)
        return cls.from_geoms(reference_time=reference_time, name=name, **kwargs)

    @property
    def mask_safe_image(self):
        """Reduced mask safe"""
        if self.mask_safe is None:
            return None
        return self.mask_safe.reduce_over_axes(func=np.logical_or)

    @property
    def mask_fit_image(self):
        """Reduced mask fit"""
        if self.mask_fit is None:
            return None
        return self.mask_fit.reduce_over_axes(func=np.logical_or)

    @property
    def mask_image(self):
        """Reduced mask"""
        if self.mask is None:
            mask = Map.from_geom(self._geom.to_image(), dtype=bool)
            mask.data |= True
            return mask

        return self.mask.reduce_over_axes(func=np.logical_or)

    @property
    def mask_safe_psf(self):
        """Mask safe for psf maps"""
        if self.mask_safe is None or self.psf is None:
            return None

        geom = self.psf.psf_map.geom.squash("energy_true").squash("rad")
        mask_safe_psf = self.mask_safe_image.interp_to_geom(geom.to_image())
        return mask_safe_psf.to_cube(geom.axes)

    @property
    def mask_safe_edisp(self):
        """Mask safe for edisp maps"""
        if self.mask_safe is None or self.edisp is None:
            return None

        if self.mask_safe.geom.is_region:
            return self.mask_safe

        geom = self.edisp.edisp_map.geom.squash("energy_true")

        if "migra" in geom.axes.names:
            geom = geom.squash("migra")
            mask_safe_edisp = self.mask_safe_image.interp_to_geom(geom.to_image())
            return mask_safe_edisp.to_cube(geom.axes)

        return self.mask_safe.interp_to_geom(geom)

    def to_masked(self, name=None, nan_to_num=True):
        """Return masked dataset

        Parameters
        ----------
        name : str
            Name of the masked dataset.
        nan_to_num : bool
            Non-finite values are replaced by zero if True (default).

        Returns
        -------
        dataset : `MapDataset` or `SpectrumDataset`
            Masked dataset
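
        Examples
        --------
        A sketch, assuming ``dataset`` is an existing `MapDataset` with a
        mask defined::

            masked = dataset.to_masked(name="masked-dataset")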
        """
        dataset = self.__class__.from_geoms(**self.geoms, name=name)
        dataset.stack(self, nan_to_num=nan_to_num)
        return dataset

    def stack(self, other, nan_to_num=True):
        r"""Stack another dataset in place.

        The safe mask is applied to compute the stacked counts data. Counts
        outside each dataset's safe mask are lost.

        The stacking of 2 datasets is implemented as follows. Here, :math:`k`
        denotes a bin in reconstructed energy and :math:`j = {1,2}` is the
        dataset number.

        The ``mask_safe`` of each dataset is defined as:

        .. math::

            \epsilon_{jk} = \left\{\begin{array}{cl} 1, &
            \mbox{if bin k is inside the thresholds}\\ 0, &
            \mbox{otherwise} \end{array}\right.

        Then the total ``counts`` and model background ``bkg`` are computed according to:

        .. math::

            \overline{\mathrm{n_{on}}}_k = \mathrm{n_{on}}_{1k} \cdot \epsilon_{1k} +
             \mathrm{n_{on}}_{2k} \cdot \epsilon_{2k}

            \overline{\mathrm{bkg}}_k = \mathrm{bkg}_{1k} \cdot \epsilon_{1k} +
             \mathrm{bkg}_{2k} \cdot \epsilon_{2k}

        The stacked ``safe_mask`` is then:

        .. math::

            \overline{\epsilon_k} = \epsilon_{1k} \lor \epsilon_{2k}


        Parameters
        ----------
        other : `~gammapy.datasets.MapDataset` or `~gammapy.datasets.MapDatasetOnOff`
            Map dataset to be stacked with this one. If other is an on-off
            dataset, alpha * counts_off is used as a background model.
        nan_to_num : bool
            Non-finite values are replaced by zero if True (default).
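
        Examples
        --------
        A sketch with two datasets of compatible geometry; ``dataset_1``
        and ``dataset_2`` are assumed to exist::

            stacked = dataset_1.to_masked(name="stacked")
            stacked.stack(dataset_2)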

        """
        if self.counts and other.counts:
            self.counts.stack(
                other.counts, weights=other.mask_safe, nan_to_num=nan_to_num
            )

        if self.exposure and other.exposure:
            self.exposure.stack(
                other.exposure, weights=other.mask_safe_image, nan_to_num=nan_to_num
            )
            # TODO: check whether this can be improved e.g. handling this in GTI

            if "livetime" in other.exposure.meta and np.any(other.mask_safe_image):
                if "livetime" in self.exposure.meta:
                    self.exposure.meta["livetime"] += other.exposure.meta["livetime"]
                else:
                    self.exposure.meta["livetime"] = other.exposure.meta[
                        "livetime"
                    ].copy()

        if self.stat_type == "cash":
            if self.background and other.background:
                background = self.npred_background() * self.mask_safe
                background.stack(
                    other.npred_background(),
                    weights=other.mask_safe,
                    nan_to_num=nan_to_num,
                )
                self.background = background

        if self.psf and other.psf:
            self.psf.stack(other.psf, weights=other.mask_safe_psf)

        if self.edisp and other.edisp:
            self.edisp.stack(other.edisp, weights=other.mask_safe_edisp)

        if self.mask_safe and other.mask_safe:
            self.mask_safe.stack(other.mask_safe)

        if self.gti and other.gti:
            self.gti.stack(other.gti)
            self.gti = self.gti.union()

        if self.meta_table and other.meta_table:
            self.meta_table = hstack_columns(self.meta_table, other.meta_table)
        elif other.meta_table:
            self.meta_table = other.meta_table.copy()

    def stat_array(self):
        """Likelihood per bin given the current model parameters"""
        return cash(n_on=self.counts.data, mu_on=self.npred().data)

    def residuals(self, method="diff", **kwargs):
        """Compute residuals map.

        Parameters
        ----------
        method : {"diff", "diff/model", "diff/sqrt(model)"}
            Method used to compute the residuals. Available options are:
                - "diff" (default): data - model
                - "diff/model": (data - model) / model
                - "diff/sqrt(model)": (data - model) / sqrt(model)
        **kwargs : dict
            Keyword arguments forwarded to `Map.smooth()`

        Returns
        -------
        residuals : `gammapy.maps.Map`
            Residual map.
        """
        npred, counts = self.npred(), self.counts.copy()

        if self.mask:
            npred = npred * self.mask
            counts = counts * self.mask

        if kwargs:
            kwargs.setdefault("mode", "constant")
            kwargs.setdefault("width", "0.1 deg")
            kwargs.setdefault("kernel", "gauss")
            with np.errstate(invalid="ignore", divide="ignore"):
                npred = npred.smooth(**kwargs)
                counts = counts.smooth(**kwargs)
                if self.mask:
                    mask = self.mask.smooth(**kwargs)
                    npred /= mask
                    counts /= mask

        residuals = self._compute_residuals(counts, npred, method=method)

        if self.mask:
            residuals.data[~self.mask.data] = np.nan

        return residuals

    def plot_residuals_spatial(
        self,
        ax=None,
        method="diff",
        smooth_kernel="gauss",
        smooth_radius="0.1 deg",
        **kwargs,
    ):
        """Plot spatial residuals.

        The normalization used for the residuals computation can be controlled
        using the method parameter.

        Parameters
        ----------
        ax : `~astropy.visualization.wcsaxes.WCSAxes`
            Axes to plot on.
        method : {"diff", "diff/model", "diff/sqrt(model)"}
            Normalization used to compute the residuals, see `MapDataset.residuals`.
        smooth_kernel : {"gauss", "box"}
            Kernel shape.
        smooth_radius : `~astropy.units.Quantity`, str or float
            Smoothing width given as quantity or float. If a float is given, it
            is interpreted as smoothing width in pixels.
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.axes.Axes.imshow`.

        Returns
        -------
        ax : `~astropy.visualization.wcsaxes.WCSAxes`
            WCSAxes object.
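
        Examples
        --------
        A sketch, assuming ``dataset`` is a `MapDataset` with fitted
        models attached::

            ax = dataset.plot_residuals_spatial(method="diff/sqrt(model)", vmin=-1, vmax=1)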
        """
        counts, npred = self.counts.copy(), self.npred()

        if counts.geom.is_region:
            raise ValueError("Cannot plot spatial residuals for RegionNDMap")

        if self.mask is not None:
            counts *= self.mask
            npred *= self.mask

        counts_spatial = counts.sum_over_axes().smooth(
            width=smooth_radius, kernel=smooth_kernel
        )
        npred_spatial = npred.sum_over_axes().smooth(
            width=smooth_radius, kernel=smooth_kernel
        )
        residuals = self._compute_residuals(counts_spatial, npred_spatial, method)

        if self.mask_safe is not None:
            mask = self.mask_safe.reduce_over_axes(func=np.logical_or, keepdims=True)
            residuals.data[~mask.data] = np.nan

        kwargs.setdefault("add_cbar", True)
        kwargs.setdefault("cmap", "coolwarm")
        kwargs.setdefault("vmin", -5)
        kwargs.setdefault("vmax", 5)
        ax = residuals.plot(ax, **kwargs)
        return ax

    def plot_residuals_spectral(self, ax=None, method="diff", region=None, **kwargs):
        """Plot spectral residuals.

        The residuals are extracted from the provided region, and the normalization
        used for their computation can be controlled using the method parameter.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            Axes to plot on.
        method : {"diff", "diff/sqrt(model)"}
            Normalization used to compute the residuals, see `SpectrumDataset.residuals`.
        region : `~regions.SkyRegion` (required)
            Target sky region.
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.axes.Axes.errorbar`.

        Returns
        -------
        ax : `~matplotlib.axes.Axes`
            Axes object.
        """
        counts, npred = self.counts.copy(), self.npred()

        if self.mask is None:
            mask = self.counts.copy()
            mask.data = 1
        else:
            mask = self.mask
        counts *= mask
        npred *= mask

        counts_spec = counts.get_spectrum(region)
        npred_spec = npred.get_spectrum(region)
        residuals = self._compute_residuals(counts_spec, npred_spec, method)

        if method == "diff":
            if self.stat_type == "wstat":
                counts_off = (self.counts_off * mask).get_spectrum(region)

                with np.errstate(invalid="ignore"):
                    alpha = (self.background * mask).get_spectrum(region) / counts_off

                mu_sig = (self.npred_signal() * mask).get_spectrum(region)
                stat = WStatCountsStatistic(
                    n_on=counts_spec,
                    n_off=counts_off,
                    alpha=alpha,
                    mu_sig=mu_sig,
                )
            elif self.stat_type == "cash":
                stat = CashCountsStatistic(counts_spec.data, npred_spec.data)
            yerr = stat.error
        elif method == "diff/sqrt(model)":
            yerr = np.ones_like(residuals.data)
        else:
            raise ValueError(
                'Invalid method, choose between "diff" and "diff/sqrt(model)"'
            )

        kwargs.setdefault("color", kwargs.pop("c", "black"))
        ax = residuals.plot(ax, yerr=yerr, **kwargs)
        ax.axhline(0, color=kwargs["color"], lw=0.5)

        label = self._residuals_labels[method]
        ax.set_ylabel(f"Residuals ({label})")
        ax.set_yscale("linear")
        ymin = 1.05 * np.nanmin(residuals.data - yerr)
        ymax = 1.05 * np.nanmax(residuals.data + yerr)
        ax.set_ylim(ymin, ymax)
        return ax

    def plot_residuals(
        self,
        ax_spatial=None,
        ax_spectral=None,
        kwargs_spatial=None,
        kwargs_spectral=None,
    ):
        """Plot spatial and spectral residuals in two panels.

        Calls `~MapDataset.plot_residuals_spatial` and `~MapDataset.plot_residuals_spectral`.
        The spectral residuals are extracted from the provided region, and the
        normalization used for their computation can be controlled using the method
        parameter. The region outline is overlaid on the residuals map.

        Parameters
        ----------
        ax_spatial : `~astropy.visualization.wcsaxes.WCSAxes`
            Axes to plot spatial residuals on.
        ax_spectral : `~matplotlib.axes.Axes`
            Axes to plot spectral residuals on.
        kwargs_spatial : dict
            Keyword arguments passed to `~MapDataset.plot_residuals_spatial`.
        kwargs_spectral : dict (``region`` required)
            Keyword arguments passed to `~MapDataset.plot_residuals_spectral`.

        Returns
        -------
        ax_spatial, ax_spectral : `~astropy.visualization.wcsaxes.WCSAxes`, `~matplotlib.axes.Axes`
            Spatial and spectral residuals plots.
        """
        ax_spatial, ax_spectral = get_axes(
            ax_spatial,
            ax_spectral,
            12,
            4,
            [1, 2, 1],
            [1, 2, 2],
            {"projection": self._geom.to_image().wcs},
        )
        kwargs_spatial = kwargs_spatial or {}

        self.plot_residuals_spatial(ax_spatial, **kwargs_spatial)
        self.plot_residuals_spectral(ax_spectral, **kwargs_spectral)

        # Overlay spectral extraction region on the spatial residuals
        region = kwargs_spectral["region"]
        pix_region = region.to_pixel(self._geom.to_image().wcs)
        pix_region.plot(ax=ax_spatial)

        return ax_spatial, ax_spectral

    def stat_sum(self):
        """Total likelihood given the current model parameters."""
        counts, npred = self.counts.data.astype(float), self.npred().data

        if self.mask is not None:
            return cash_sum_cython(counts[self.mask.data], npred[self.mask.data])
        else:
            return cash_sum_cython(counts.ravel(), npred.ravel())

    def fake(self, random_state="random-seed"):
        """Simulate fake counts for the current model and reduced IRFs.

        This method overwrites the counts defined on the dataset object.

        Parameters
        ----------
        random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
            Defines random number generator initialisation.
            Passed to `~gammapy.utils.random.get_random_state`.
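
        Examples
        --------
        A sketch, assuming ``dataset`` has models attached; the seed is
        arbitrary::

            dataset.fake(random_state=42)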
        """
        random_state = get_random_state(random_state)
        npred = self.npred()
        data = np.nan_to_num(npred.data, copy=True, nan=0.0, posinf=0.0, neginf=0.0)
        npred.data = random_state.poisson(data)
        self.counts = npred

    def to_hdulist(self):
        """Convert map dataset to list of HDUs.

        Returns
        -------
        hdulist : `~astropy.io.fits.HDUList`
            Map dataset list of HDUs.
        """
        # TODO: what to do about the model and background model parameters?
        exclude_primary = slice(1, None)

        hdu_primary = fits.PrimaryHDU()
        hdulist = fits.HDUList([hdu_primary])
        if self.counts is not None:
            hdulist += self.counts.to_hdulist(hdu="counts")[exclude_primary]

        if self.exposure is not None:
            hdulist += self.exposure.to_hdulist(hdu="exposure")[exclude_primary]

        if self.background is not None:
            hdulist += self.background.to_hdulist(hdu="background")[exclude_primary]

        if self.edisp is not None:
            hdulist += self.edisp.to_hdulist()[exclude_primary]

        if self.psf is not None:
            hdulist += self.psf.to_hdulist()[exclude_primary]

        if self.mask_safe is not None:
            hdulist += self.mask_safe.to_hdulist(hdu="mask_safe")[exclude_primary]

        if self.mask_fit is not None:
            hdulist += self.mask_fit.to_hdulist(hdu="mask_fit")[exclude_primary]

        if self.gti is not None:
            hdulist.append(fits.BinTableHDU(self.gti.table, name="GTI"))

        return hdulist

    @classmethod
    def from_hdulist(cls, hdulist, name=None, lazy=False, format="gadf"):
        """Create map dataset from list of HDUs.

        Parameters
        ----------
        hdulist : `~astropy.io.fits.HDUList`
            List of HDUs.
        name : str
            Name of the new dataset.
        format : {"gadf"}
            Format the hdulist is given in.

        Returns
        -------
        dataset : `MapDataset`
            Map dataset.
        """
        name = make_name(name)
        kwargs = {"name": name}

        if "COUNTS" in hdulist:
            kwargs["counts"] = Map.from_hdulist(hdulist, hdu="counts", format=format)

        if "EXPOSURE" in hdulist:
            exposure = Map.from_hdulist(hdulist, hdu="exposure", format=format)
            if exposure.geom.axes[0].name == "energy":
                exposure.geom.axes[0].name = "energy_true"
            kwargs["exposure"] = exposure

        if "BACKGROUND" in hdulist:
            kwargs["background"] = Map.from_hdulist(
                hdulist, hdu="background", format=format
            )

        if "EDISP" in hdulist:
            edisp_map = Map.from_hdulist(hdulist, hdu="edisp", format=format)

            try:
                exposure_map = Map.from_hdulist(
                    hdulist, hdu="edisp_exposure", format=format
                )
            except KeyError:
                exposure_map = None

            if edisp_map.geom.axes[0].name == "energy":
                kwargs["edisp"] = EDispKernelMap(edisp_map, exposure_map)
            else:
                kwargs["edisp"] = EDispMap(edisp_map, exposure_map)

        if "PSF" in hdulist:
            psf_map = Map.from_hdulist(hdulist, hdu="psf", format=format)
            try:
                exposure_map = Map.from_hdulist(
                    hdulist, hdu="psf_exposure", format=format
                )
            except KeyError:
                exposure_map = None
            kwargs["psf"] = PSFMap(psf_map, exposure_map)

        if "MASK_SAFE" in hdulist:
            mask_safe = Map.from_hdulist(hdulist, hdu="mask_safe", format=format)
            mask_safe.data = mask_safe.data.astype(bool)
            kwargs["mask_safe"] = mask_safe

        if "MASK_FIT" in hdulist:
            mask_fit = Map.from_hdulist(hdulist, hdu="mask_fit", format=format)
            mask_fit.data = mask_fit.data.astype(bool)
            kwargs["mask_fit"] = mask_fit

        if "GTI" in hdulist:
            gti = GTI(Table.read(hdulist, hdu="GTI"))
            kwargs["gti"] = gti

        return cls(**kwargs)

    def write(self, filename, overwrite=False):
        """Write map dataset to file.

        Parameters
        ----------
        filename : str
            Filename to write to.
        overwrite : bool
            Overwrite file if it exists.
        """
        self.to_hdulist().writeto(str(make_path(filename)), overwrite=overwrite)

    @classmethod
    def _read_lazy(cls, name, filename, cache, format="gadf"):
        kwargs = {"name": name}
        try:
            kwargs["gti"] = GTI.read(filename)
        except KeyError:
            pass

        path = make_path(filename)
        for hdu_name in ["counts", "exposure", "mask_fit", "mask_safe", "background"]:
            kwargs[hdu_name] = HDULocation(
                hdu_class="map",
                file_dir=path.parent,
                file_name=path.name,
                hdu_name=hdu_name.upper(),
                cache=cache,
                format=format,
            )

        kwargs["edisp"] = HDULocation(
            hdu_class="edisp_kernel_map",
            file_dir=path.parent,
            file_name=path.name,
            hdu_name="EDISP",
            cache=cache,
            format=format,
        )

        kwargs["psf"] = HDULocation(
            hdu_class="psf_map",
            file_dir=path.parent,
            file_name=path.name,
            hdu_name="PSF",
            cache=cache,
            format=format,
        )

        return cls(**kwargs)

    @classmethod
    def read(cls, filename, name=None, lazy=False, cache=True, format="gadf"):
        """Read map dataset from file.

        Parameters
        ----------
        filename : str
            Filename to read from.
        name : str
            Name of the new dataset.
        lazy : bool
            Whether to lazily load the data into memory.
        cache : bool
            Whether to cache the data after loading.
        format : {"gadf"}
            Format of the dataset file.

        Returns
        -------
        dataset : `MapDataset`
            Map dataset.
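
        Examples
        --------
        A usage sketch; the ``$GAMMAPY_DATA`` path is an assumption and
        must point at the public gammapy test datasets::

            from gammapy.datasets import MapDataset

            dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz")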
        """
        name = make_name(name)

        if lazy:
            return cls._read_lazy(
                name=name, filename=filename, cache=cache, format=format
            )
        else:
            with fits.open(str(make_path(filename)), memmap=False) as hdulist:
                return cls.from_hdulist(hdulist, name=name, format=format)

    @classmethod
    def from_dict(cls, data, lazy=False, cache=True):
        """Create from dicts and models list generated from YAML serialization."""
        filename = make_path(data["filename"])
        dataset = cls.read(filename, name=data["name"], lazy=lazy, cache=cache)
        return dataset

    def info_dict(self, in_safe_data_range=True):
        """Info dict with summary statistics, summed over energy

        Parameters
        ----------
        in_safe_data_range : bool
            Whether to sum only in the safe energy range

        Returns
        -------
        info_dict : dict
            Dictionary with summary info.
        """
        info = {}
        info["name"] = self.name

        if self.mask_safe and in_safe_data_range:
            mask = self.mask_safe.data.astype(bool)
        else:
            mask = slice(None)

        counts = 0
        if self.counts:
            counts = self.counts.data[mask].sum()

        info["counts"] = int(counts)

        background = np.nan
        if self.background:
            background = self.background.data[mask].sum()

        info["background"] = float(background)

        info["excess"] = counts - background
        info["sqrt_ts"] = CashCountsStatistic(counts, background).sqrt_ts

        npred = np.nan
        if self.models or not np.isnan(background):
            npred = self.npred().data[mask].sum()

        info["npred"] = float(npred)

        npred_background = np.nan
        if self.background:
            npred_background = self.npred_background().data[mask].sum()

        info["npred_background"] = float(npred_background)

        npred_signal = np.nan
        if self.models:
            npred_signal = self.npred_signal().data[mask].sum()

        info["npred_signal"] = float(npred_signal)

        exposure_min = np.nan * u.Unit("cm s")
        exposure_max = np.nan * u.Unit("cm s")
        livetime = np.nan * u.s

        if self.exposure is not None:
            mask_exposure = self.exposure.data > 0

            if self.mask_safe is not None:
                mask_spatial = self.mask_safe.reduce_over_axes(func=np.logical_or).data
                mask_exposure = mask_exposure & mask_spatial[np.newaxis, :, :]

            if not mask_exposure.any():
                mask_exposure = slice(None)

            exposure_min = np.min(self.exposure.quantity[mask_exposure])
            exposure_max = np.max(self.exposure.quantity[mask_exposure])
            livetime = self.exposure.meta.get("livetime", np.nan * u.s).copy()

        info["exposure_min"] = exposure_min.item()
        info["exposure_max"] = exposure_max.item()
        info["livetime"] = livetime

        ontime = u.Quantity(np.nan, "s")
        if self.gti:
            ontime = self.gti.time_sum

        info["ontime"] = ontime

        info["counts_rate"] = info["counts"] / info["livetime"]
        info["background_rate"] = info["background"] / info["livetime"]
        info["excess_rate"] = info["excess"] / info["livetime"]

        # data section
        n_bins = 0
        if self.counts is not None:
            n_bins = self.counts.data.size
        info["n_bins"] = int(n_bins)

        n_fit_bins = 0
        if self.mask is not None:
            n_fit_bins = np.sum(self.mask.data)

        info["n_fit_bins"] = int(n_fit_bins)
        info["stat_type"] = self.stat_type

        stat_sum = np.nan
        if self.counts is not None and self.models is not None:
            stat_sum = self.stat_sum()

        info["stat_sum"] = float(stat_sum)

        return info

    def to_spectrum_dataset(self, on_region, containment_correction=False, name=None):
        """Return a `~gammapy.datasets.SpectrumDataset` from on_region.

        Counts and background are summed in the on_region. Exposure is taken
        from the average exposure.

        The energy dispersion kernel is obtained at the on_region center.
        Only regions with centers are supported.

        The model is not exported to the `~gammapy.datasets.SpectrumDataset`.
        It must be set after the dataset extraction.

        Parameters
        ----------
        on_region : `~regions.SkyRegion`
            The input ON region on which to extract the spectrum.
        containment_correction : bool
            Apply containment correction for point sources and circular on regions.
        name : str
            Name of the new dataset.

        Returns
        -------
        dataset : `~gammapy.datasets.SpectrumDataset`
            The resulting reduced dataset.
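
        Examples
        --------
        A sketch, assuming ``dataset`` is an existing `MapDataset`; the
        region is an illustrative choice::

            from astropy import units as u
            from astropy.coordinates import SkyCoord
            from regions import CircleSkyRegion

            center = SkyCoord("83.63 deg", "22.01 deg")
            on_region = CircleSkyRegion(center, radius=0.1 * u.deg)
            spectrum_dataset = dataset.to_spectrum_dataset(on_region)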
        """
        from .spectrum import SpectrumDataset

        dataset = self.to_region_map_dataset(region=on_region, name=name)

        if containment_correction:
            if not isinstance(on_region, CircleSkyRegion):
                raise TypeError(
                    "Containment correction is only supported for `CircleSkyRegion`."
                )
            elif self.psf is None or isinstance(self.psf, PSFKernel):
                raise ValueError("No PSFMap set. Containment correction impossible")
            else:
                geom = dataset.exposure.geom
                energy_true = geom.axes["energy_true"].center
                containment = self.psf.containment(
                    position=on_region.center,
                    energy_true=energy_true,
                    rad=on_region.radius,
                )
                dataset.exposure.quantity *= containment.reshape(geom.data_shape)

        kwargs = {"name": name}

        for key in [
            "counts",
            "edisp",
            "mask_safe",
            "mask_fit",
            "exposure",
            "gti",
            "meta_table",
        ]:
            kwargs[key] = getattr(dataset, key)

        if self.stat_type == "cash":
            kwargs["background"] = dataset.background

        return SpectrumDataset(**kwargs)

    def to_region_map_dataset(self, region, name=None):
        """Integrate the map dataset in a given region.

        Counts and background of the dataset are integrated in the given region,
        taking the safe mask into account. The exposure is averaged in the
        region, again taking the safe mask into account. The PSF and energy
        dispersion kernel are taken at the center of the region.

        Parameters
        ----------
        region : `~regions.SkyRegion`
            Region from which to extract the spectrum.
        name : str
            Name of the new dataset.

        Returns
        -------
        dataset : `~gammapy.datasets.MapDataset`
            The resulting reduced dataset.
        """
        name = make_name(name)
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}

        if self.mask_safe:
            kwargs["mask_safe"] = self.mask_safe.to_region_nd_map(region, func=np.any)

        if self.mask_fit:
            kwargs["mask_fit"] = self.mask_fit.to_region_nd_map(region, func=np.any)

        if self.counts:
            kwargs["counts"] = self.counts.to_region_nd_map(
                region, np.sum, weights=self.mask_safe
            )

        if self.stat_type == "cash" and self.background:
            kwargs["background"] = self.background.to_region_nd_map(
                region, func=np.sum, weights=self.mask_safe
            )

        if self.exposure:
            kwargs["exposure"] = self.exposure.to_region_nd_map(region, func=np.mean)

        region = region.center if region else None

        # TODO: Compute average psf in region
        if self.psf:
            kwargs["psf"] = self.psf.to_region_nd_map(region)

        # TODO: Compute average edisp in region
        if self.edisp is not None:
            kwargs["edisp"] = self.edisp.to_region_nd_map(region)

        return self.__class__(**kwargs)

    def cutout(self, position, width, mode="trim", name=None):
        """Cutout map dataset.

        Parameters
        ----------
        position : `~astropy.coordinates.SkyCoord`
            Center position of the cutout region.
        width : tuple of `~astropy.coordinates.Angle`
            Angular sizes of the region in (lon, lat) in that specific order.
            If only one value is passed, a square region is extracted.
        mode : {'trim', 'partial', 'strict'}
            Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.
        name : str
            Name of the new dataset.

        Returns
        -------
        cutout : `MapDataset`
            Cutout map dataset.
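
        Examples
        --------
        A sketch, assuming ``dataset`` is an existing `MapDataset`; the
        position and width are illustrative::

            from astropy.coordinates import SkyCoord

            position = SkyCoord("83.63 deg", "22.01 deg")
            cutout = dataset.cutout(position=position, width="3 deg")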
        """
        name = make_name(name)
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}
        cutout_kwargs = {"position": position, "width": width, "mode": mode}

        if self.counts is not None:
            kwargs["counts"] = self.counts.cutout(**cutout_kwargs)

        if self.exposure is not None:
            kwargs["exposure"] = self.exposure.cutout(**cutout_kwargs)

        if self.background is not None and self.stat_type == "cash":
            kwargs["background"] = self.background.cutout(**cutout_kwargs)

        if self.edisp is not None:
            kwargs["edisp"] = self.edisp.cutout(**cutout_kwargs)

        if self.psf is not None:
            kwargs["psf"] = self.psf.cutout(**cutout_kwargs)

        if self.mask_safe is not None:
            kwargs["mask_safe"] = self.mask_safe.cutout(**cutout_kwargs)

        if self.mask_fit is not None:
            kwargs["mask_fit"] = self.mask_fit.cutout(**cutout_kwargs)

        return self.__class__(**kwargs)

    def downsample(self, factor, axis_name=None, name=None):
        """Downsample map dataset.

        The PSFMap and EDispKernelMap are not downsampled, except if
        a corresponding axis is given.

        Parameters
        ----------
        factor : int
            Downsampling factor.
        axis_name : str
            Which non-spatial axis to downsample. By default only spatial axes are downsampled.
        name : str
            Name of the downsampled dataset.

        Returns
        -------
        dataset : `MapDataset` or `SpectrumDataset`
            Downsampled map dataset.
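
        Examples
        --------
        A sketch, assuming ``dataset`` is an existing `MapDataset`::

            downsampled = dataset.downsample(factor=2, name="downsampled")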
1499
        """
1500
        name = make_name(name)
1501
1502
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}
1503
1504
        if self.counts is not None:
1505
            kwargs["counts"] = self.counts.downsample(
1506
                factor=factor,
1507
                preserve_counts=True,
1508
                axis_name=axis_name,
1509
                weights=self.mask_safe,
1510
            )
1511
1512
        if self.exposure is not None:
1513
            if axis_name is None:
1514
                kwargs["exposure"] = self.exposure.downsample(
1515
                    factor=factor, preserve_counts=False, axis_name=None
1516
                )
1517
            else:
1518
                kwargs["exposure"] = self.exposure.copy()
1519
1520
        if self.background is not None and self.stat_type == "cash":
1521
            kwargs["background"] = self.background.downsample(
1522
                factor=factor, axis_name=axis_name, weights=self.mask_safe
1523
            )
1524
1525
        if self.edisp is not None:
1526
            if axis_name is not None:
1527
                kwargs["edisp"] = self.edisp.downsample(
1528
                    factor=factor, axis_name=axis_name, weights=self.mask_safe_edisp
1529
                )
1530
            else:
1531
                kwargs["edisp"] = self.edisp.copy()
1532
1533
        if self.psf is not None:
1534
            kwargs["psf"] = self.psf.copy()
1535
1536
        if self.mask_safe is not None:
1537
            kwargs["mask_safe"] = self.mask_safe.downsample(
1538
                factor=factor, preserve_counts=False, axis_name=axis_name
1539
            )
1540
1541
        if self.mask_fit is not None:
1542
            kwargs["mask_fit"] = self.mask_fit.downsample(
1543
                factor=factor, preserve_counts=False, axis_name=axis_name
1544
            )
1545
1546
        return self.__class__(**kwargs)
1547
1548
    def pad(self, pad_width, mode="constant", name=None):
1549
        """Pad the spatial dimensions of the dataset.
1550
1551
        The padding only applies to counts, masks, background and exposure.
1552
1553
        Counts, background and masks are padded with zeros, exposure is padded with edge value.
1554
1555
        Parameters
1556
        ----------
1557
        pad_width : {sequence, array_like, int}
1558
            Number of pixels padded to the edges of each axis.
1559
        name : str
1560
            Name of the padded dataset.
1561
1562
        Returns
1563
        -------
1564
        dataset : `MapDataset`
1565
            Padded map dataset.
1566
1567
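
        Examples
        --------
        A short sketch; the data file path is illustrative:

        >>> from gammapy.datasets import MapDataset
        >>> dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz")
        >>> padded = dataset.pad(pad_width=10)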
        """
1568
        name = make_name(name)
1569
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}
1570
1571
        if self.counts is not None:
1572
            kwargs["counts"] = self.counts.pad(pad_width=pad_width, mode=mode)
1573
1574
        if self.exposure is not None:
1575
            kwargs["exposure"] = self.exposure.pad(pad_width=pad_width, mode=mode)
1576
1577
        if self.background is not None:
1578
            kwargs["background"] = self.background.pad(pad_width=pad_width, mode=mode)
1579
1580
        if self.edisp is not None:
1581
            kwargs["edisp"] = self.edisp.copy()
1582
1583
        if self.psf is not None:
1584
            kwargs["psf"] = self.psf.copy()
1585
1586
        if self.mask_safe is not None:
1587
            kwargs["mask_safe"] = self.mask_safe.pad(pad_width=pad_width, mode=mode)
1588
1589
        if self.mask_fit is not None:
1590
            kwargs["mask_fit"] = self.mask_fit.pad(pad_width=pad_width, mode=mode)
1591
1592
        return self.__class__(**kwargs)
1593
1594
    def slice_by_idx(self, slices, name=None):
1595
        """Slice sub dataset.
1596
1597
        The slicing only applies to the maps that define the corresponding axes.
1598
1599
        Parameters
1600
        ----------
1601
        slices : dict
1602
            Dict of axes names and integers or `slice` object pairs. Contains one
1603
            element for each non-spatial dimension. For integer indexing the
1604
            corresponding axes is dropped from the map. Axes not specified in the
1605
            dict are kept unchanged.
1606
        name : str
1607
            Name of the sliced dataset.
1608
1609
        Returns
1610
        -------
1611
        dataset : `MapDataset` or `SpectrumDataset`
1612
            Sliced dataset
1613
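
        Examples
        --------
        A short sketch, keeping only the first three energy bins; the data
        file path is illustrative:

        >>> from gammapy.datasets import MapDataset
        >>> dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz")
        >>> sliced = dataset.slice_by_idx({"energy": slice(0, 3)})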
        """
1614
        name = make_name(name)
1615
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}
1616
1617
        if self.counts is not None:
1618
            kwargs["counts"] = self.counts.slice_by_idx(slices=slices)
1619
1620
        if self.exposure is not None:
1621
            kwargs["exposure"] = self.exposure.slice_by_idx(slices=slices)
1622
1623
        if self.background is not None and self.stat_type == "cash":
1624
            kwargs["background"] = self.background.slice_by_idx(slices=slices)
1625
1626
        if self.edisp is not None:
1627
            kwargs["edisp"] = self.edisp.slice_by_idx(slices=slices)
1628
1629
        if self.psf is not None:
1630
            kwargs["psf"] = self.psf.slice_by_idx(slices=slices)
1631
1632
        if self.mask_safe is not None:
1633
            kwargs["mask_safe"] = self.mask_safe.slice_by_idx(slices=slices)
1634
1635
        if self.mask_fit is not None:
1636
            kwargs["mask_fit"] = self.mask_fit.slice_by_idx(slices=slices)
1637
1638
        return self.__class__(**kwargs)
1639
1640
    def slice_by_energy(self, energy_min=None, energy_max=None, name=None):
1641
        """Select and slice datasets in energy range
1642
1643
        Parameters
1644
        ----------
1645
        energy_min, energy_max : `~astropy.units.Quantity`
1646
            Energy bounds to compute the flux point for.
1647
        name : str
1648
            Name of the sliced dataset.
1649
1650
        Returns
1651
        -------
1652
        dataset : `MapDataset`
1653
            Sliced Dataset
1654
1655
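
        Examples
        --------
        A short sketch; string quantities are parsed by `~astropy.units.Quantity`:

        >>> sliced = dataset.slice_by_energy(energy_min="1 TeV", energy_max="10 TeV")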
        """
1656
        name = make_name(name)
1657
1658
        energy_axis = self._geom.axes["energy"]
1659
1660
        if energy_min is None:
1661
            energy_min = energy_axis.bounds[0]
1662
1663
        if energy_max is None:
1664
            energy_max = energy_axis.bounds[1]
1665
1666
        energy_min, energy_max = u.Quantity(energy_min), u.Quantity(energy_max)
1667
1668
        group = energy_axis.group_table(edges=[energy_min, energy_max])
1669
1670
        is_normal = group["bin_type"] == "normal   "
1671
        group = group[is_normal]
1672
1673
        slices = {
1674
            "energy": slice(int(group["idx_min"][0]), int(group["idx_max"][0]) + 1)
1675
        }
1676
1677
        return self.slice_by_idx(slices, name=name)
1678
1679
    def reset_data_cache(self):
1680
        """Reset data cache to free memory space"""
1681
        for name in self._lazy_data_members:
1682
            if self.__dict__.pop(name, False):
1683
                log.info(f"Clearing {name} cache for dataset {self.name}")
1684
1685
    def resample_energy_axis(self, energy_axis, name=None):
1686
        """Resample MapDataset over new reco energy axis.
1687
1688
        Counts are summed taking into account safe mask.
1689
1690
        Parameters
1691
        ----------
1692
        energy_axis : `~gammapy.maps.MapAxis`
1693
            New reconstructed energy axis.
1694
        name: str
1695
            Name of the new dataset.
1696
1697
        Returns
1698
        -------
1699
        dataset: `MapDataset` or `SpectrumDataset`
1700
            Resampled dataset.
1701
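
        Examples
        --------
        A short sketch; the axis definition is illustrative:

        >>> from gammapy.maps import MapAxis
        >>> energy_axis = MapAxis.from_energy_bounds("1 TeV", "10 TeV", nbin=3)
        >>> resampled = dataset.resample_energy_axis(energy_axis)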
        """
1702
        name = make_name(name)
1703
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}
1704
1705
        if self.exposure:
1706
            kwargs["exposure"] = self.exposure
1707
1708
        if self.psf:
1709
            kwargs["psf"] = self.psf
1710
1711
        if self.mask_safe is not None:
1712
            kwargs["mask_safe"] = self.mask_safe.resample_axis(
1713
                axis=energy_axis, ufunc=np.logical_or
1714
            )
1715
1716
        if self.mask_fit is not None:
1717
            kwargs["mask_fit"] = self.mask_fit.resample_axis(
1718
                axis=energy_axis, ufunc=np.logical_or
1719
            )
1720
1721
        if self.counts is not None:
1722
            kwargs["counts"] = self.counts.resample_axis(
1723
                axis=energy_axis, weights=self.mask_safe
1724
            )
1725
1726
        if self.background is not None and self.stat_type == "cash":
1727
            kwargs["background"] = self.background.resample_axis(
1728
                axis=energy_axis, weights=self.mask_safe
1729
            )
1730
1731
        # Mask_safe or mask_irf??
1732
        if isinstance(self.edisp, EDispKernelMap):
1733
            kwargs["edisp"] = self.edisp.resample_energy_axis(
1734
                energy_axis=energy_axis, weights=self.mask_safe_edisp
1735
            )
1736
        else:  # None or EDispMap
1737
            kwargs["edisp"] = self.edisp
1738
1739
        return self.__class__(**kwargs)
1740
1741
    def to_image(self, name=None):
1742
        """Create images by summing over the reconstructed energy axis.
1743
1744
        Parameters
1745
        ----------
1746
        name : str
1747
            Name of the new dataset.
1748
1749
        Returns
1750
        -------
1751
        dataset : `MapDataset` or `SpectrumDataset`
1752
            Dataset integrated over non-spatial axes.
1753
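
        Examples
        --------
        A short sketch, squashing the energy axis into a single bin:

        >>> image_dataset = dataset.to_image()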
        """
1754
        energy_axis = self._geom.axes["energy"].squash()
1755
        return self.resample_energy_axis(energy_axis=energy_axis, name=name)
1756
1757
1758
    def peek(self, figsize=(12,10)):
1759
        """Quick-look summary plots.
1760
1761
        Parameters
1762
        ----------
1763
        fig : `~matplotlib.figure.Figure`
1764
            Figure to add AxesSubplot on.
1765
1766
        Returns
1767
        -------
1768
        ax1, ax2, ax3 : `~matplotlib.axes.AxesSubplot`
1769
            Counts, excess and exposure.
1770
        """
1771
        def plot_mask(ax, mask, **kwargs):
1772
            if mask is not None:
1773
                mask.plot_mask(ax=ax, **kwargs)
1774
1775
        import matplotlib.pyplot as plt
1776
1777
        fig, axes  = plt.subplots(ncols=2,
1778
                                  nrows=2,
1779
                                  subplot_kw={"projection": self._geom.wcs},
1780
                                  figsize=figsize,
1781
                                  gridspec_kw={"hspace": 0.1, "wspace": 0.1},
1782
                                  )
1783
1784
        axes = axes.flat
1785
        axes[0].set_title("Counts")
1786
        self.counts.sum_over_axes().plot(ax=axes[0], add_cbar=True)
1787
        plot_mask(ax=axes[0], mask=self.mask_fit_image, alpha=0.2)
1788
        plot_mask(ax=axes[0], mask=self.mask_safe_image, hatches=["///"], colors="w")
1789
1790
        axes[1].set_title("Excess counts")
1791
        self.excess.sum_over_axes().plot(ax=axes[1], add_cbar=True)
1792
        plot_mask(ax=axes[1], mask=self.mask_fit_image, alpha=0.2)
1793
        plot_mask(ax=axes[1], mask=self.mask_safe_image,  hatches=["///"], colors="w")
1794
1795
        axes[2].set_title("Exposure")
1796
        self.exposure.sum_over_axes().plot(ax=axes[2], add_cbar=True)
1797
        plot_mask(ax=axes[2], mask=self.mask_safe_image, hatches=["///"], colors="w")
1798
1799
        axes[3].set_title("Background")
1800
        self.background.sum_over_axes().plot(ax=axes[3], add_cbar=True)
1801
        plot_mask(ax=axes[3], mask=self.mask_fit_image, alpha=0.2)
1802
        plot_mask(ax=axes[3], mask=self.mask_safe_image, hatches=["///"], colors="w")
1803
1804
class MapDatasetOnOff(MapDataset):
    """Map dataset for on-off likelihood fitting.

    Parameters
    ----------
    models : `~gammapy.modeling.models.Models`
        Source sky models.
    counts : `~gammapy.maps.WcsNDMap`
        Counts cube
    counts_off : `~gammapy.maps.WcsNDMap`
        Ring-convolved counts cube
    acceptance : `~gammapy.maps.WcsNDMap`
        Acceptance from the IRFs
    acceptance_off : `~gammapy.maps.WcsNDMap`
        Acceptance off
    exposure : `~gammapy.maps.WcsNDMap`
        Exposure cube
    mask_fit : `~gammapy.maps.WcsNDMap`
        Mask to apply to the likelihood for fitting.
    psf : `~gammapy.irf.PSFKernel`
        PSF kernel
    edisp : `~gammapy.irf.EDispKernel`
        Energy dispersion
    mask_safe : `~gammapy.maps.WcsNDMap`
        Mask defining the safe data range.
    gti : `~gammapy.data.GTI`
        GTI of the observation or union of GTI if it is a stacked observation
    meta_table : `~astropy.table.Table`
        Table listing information on observations used to create the dataset.
        One line per observation for stacked datasets.
    name : str
        Name of the dataset.


    See Also
    --------
    MapDataset, SpectrumDataset, FluxPointsDataset

    """

    stat_type = "wstat"
    tag = "MapDatasetOnOff"

    def __init__(
        self,
        models=None,
        counts=None,
        counts_off=None,
        acceptance=None,
        acceptance_off=None,
        exposure=None,
        mask_fit=None,
        psf=None,
        edisp=None,
        name=None,
        mask_safe=None,
        gti=None,
        meta_table=None,
    ):
        self._name = make_name(name)
        self._evaluators = {}

        self.counts = counts
        self.counts_off = counts_off
        self.exposure = exposure
        self.acceptance = acceptance
        self.acceptance_off = acceptance_off
        self.gti = gti
        self.mask_fit = mask_fit
        self.psf = psf
        self.edisp = edisp
        self.models = models
        self.mask_safe = mask_safe
        self.meta_table = meta_table

    def __str__(self):
        str_ = super().__str__()

        counts_off = np.nan
        if self.counts_off is not None:
            counts_off = np.sum(self.counts_off.data)
        str_ += "\t{:32}: {:.0f} \n".format("Total counts_off", counts_off)

        acceptance = np.nan
        if self.acceptance is not None:
            acceptance = np.sum(self.acceptance.data)
        str_ += "\t{:32}: {:.0f} \n".format("Acceptance", acceptance)

        acceptance_off = np.nan
        if self.acceptance_off is not None:
            acceptance_off = np.sum(self.acceptance_off.data)
        str_ += "\t{:32}: {:.0f} \n".format("Acceptance off", acceptance_off)

        return str_.expandtabs(tabsize=2)

    @property
    def _geom(self):
        """Main analysis geometry"""
        if self.counts is not None:
            return self.counts.geom
        elif self.counts_off is not None:
            return self.counts_off.geom
        elif self.acceptance is not None:
            return self.acceptance.geom
        elif self.acceptance_off is not None:
            return self.acceptance_off.geom
        else:
            raise ValueError(
                "Either 'counts', 'counts_off', 'acceptance' or 'acceptance_off' must be defined."
            )

    @property
    def alpha(self):
        """Exposure ratio between signal and background regions

        See :ref:`wstat`

        Returns
        -------
        alpha : `Map`
            Alpha map
        """
        with np.errstate(invalid="ignore", divide="ignore"):
            alpha = self.acceptance / self.acceptance_off

        alpha.data = np.nan_to_num(alpha.data)
        return alpha

    def npred_background(self):
        """Predicted background counts estimated from the marginalized likelihood estimate.

        See :ref:`wstat`

        Returns
        -------
        npred_background : `Map`
            Predicted background counts
        """
        mu_bkg = self.alpha.data * get_wstat_mu_bkg(
            n_on=self.counts.data,
            n_off=self.counts_off.data,
            alpha=self.alpha.data,
            mu_sig=self.npred_signal().data,
        )
        mu_bkg = np.nan_to_num(mu_bkg)
        return Map.from_geom(geom=self._geom, data=mu_bkg)

    def npred_off(self):
        """Predicted counts in the off region

        See :ref:`wstat`

        Returns
        -------
        npred_off : `Map`
            Predicted off counts
        """
        return self.npred_background() / self.alpha

    @property
    def background(self):
        """Computed as alpha * n_off

        See :ref:`wstat`

        Returns
        -------
        background : `Map`
            Background map
        """
        return self.alpha * self.counts_off

    def stat_array(self):
        """Likelihood per bin given the current model parameters"""
        mu_sig = self.npred_signal().data
        on_stat_ = wstat(
            n_on=self.counts.data,
            n_off=self.counts_off.data,
            alpha=self.alpha.data,
            mu_sig=mu_sig,
        )
        return np.nan_to_num(on_stat_)

    @classmethod
    def from_geoms(
        cls,
        geom,
        geom_exposure,
        geom_psf=None,
        geom_edisp=None,
        reference_time="2000-01-01",
        name=None,
        **kwargs,
    ):
        """Create a MapDatasetOnOff object with zero-filled maps according to the specified geometries

        Parameters
        ----------
        geom : `gammapy.maps.WcsGeom`
            Geometry for the counts, counts_off, acceptance and acceptance_off maps
        geom_exposure : `gammapy.maps.WcsGeom`
            Geometry for the exposure map
        geom_psf : `gammapy.maps.WcsGeom`
            Geometry for the psf map
        geom_edisp : `gammapy.maps.WcsGeom`
            Geometry for the energy dispersion kernel map.
            If geom_edisp has a migra axis, this will create an EDispMap instead.
        reference_time : `~astropy.time.Time`
            The reference time to use in GTI definition
        name : str
            Name of the returned dataset.

        Returns
        -------
        empty_maps : `MapDatasetOnOff`
            A MapDatasetOnOff containing zero filled maps
        """
        #  TODO: it seems the super() pattern does not work here?
        dataset = MapDataset.from_geoms(
            geom=geom,
            geom_exposure=geom_exposure,
            geom_psf=geom_psf,
            geom_edisp=geom_edisp,
            name=name,
            reference_time=reference_time,
            **kwargs,
        )

        off_maps = {}

        for key in ["counts_off", "acceptance", "acceptance_off"]:
            off_maps[key] = Map.from_geom(geom, unit="")

        return cls.from_map_dataset(dataset, name=name, **off_maps)

    @classmethod
    def from_map_dataset(
        cls, dataset, acceptance, acceptance_off, counts_off=None, name=None
    ):
        """Create an on-off dataset from a map dataset.

        Parameters
        ----------
        dataset : `MapDataset`
            Map dataset defining counts, edisp, psf, exposure etc.
        acceptance : `Map`
            Relative background efficiency in the on region.
        acceptance_off : `Map`
            Relative background efficiency in the off region.
        counts_off : `Map`
            Off counts map. If the dataset provides a background model
            and no off counts are defined, the off counts are derived as
            background / alpha.
        name : str
            Name of the returned dataset.

        Returns
        -------
        dataset : `MapDatasetOnOff`
            Map dataset on off.
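
        Examples
        --------
        A minimal sketch; the data file path and the scalar acceptances below
        are illustrative:

        >>> from gammapy.datasets import MapDataset, MapDatasetOnOff
        >>> dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz")
        >>> dataset_on_off = MapDatasetOnOff.from_map_dataset(
        ...     dataset, acceptance=1, acceptance_off=10
        ... )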
        """
2066
        if counts_off is None and dataset.background is not None:
2067
            alpha = acceptance / acceptance_off
2068
            counts_off = dataset.npred_background() / alpha
2069
2070
        if np.isscalar(acceptance):
2071
            acceptance = Map.from_geom(dataset._geom, data=acceptance)
2072
2073
        if np.isscalar(acceptance_off):
2074
            acceptance_off = Map.from_geom(dataset._geom, data=acceptance_off)
2075
2076
        return cls(
2077
            models=dataset.models,
2078
            counts=dataset.counts,
2079
            exposure=dataset.exposure,
2080
            counts_off=counts_off,
2081
            edisp=dataset.edisp,
2082
            psf=dataset.psf,
2083
            mask_safe=dataset.mask_safe,
2084
            mask_fit=dataset.mask_fit,
2085
            acceptance=acceptance,
2086
            acceptance_off=acceptance_off,
2087
            gti=dataset.gti,
2088
            name=name,
2089
            meta_table=dataset.meta_table,
2090
        )
2091
2092
    def to_map_dataset(self, name=None):
2093
        """Convert a MapDatasetOnOff to  MapDataset
2094
        The background model template is taken as alpha*counts_off
2095
2096
        Parameters
2097
        ----------
2098
        name: str
2099
            Name of the new dataset
2100
2101
        Returns
2102
        -------
2103
        dataset: `MapDataset`
2104
            Map dataset with cash statistics
2105
        """
2106
        name = make_name(name)
2107
2108
        return MapDataset(
2109
            counts=self.counts,
2110
            exposure=self.exposure,
2111
            psf=self.psf,
2112
            edisp=self.edisp,
2113
            name=name,
2114
            gti=self.gti,
2115
            mask_fit=self.mask_fit,
2116
            mask_safe=self.mask_safe,
2117
            background=self.counts_off * self.alpha,
2118
            meta_table=self.meta_table,
2119
        )
2120
2121
    @property
2122
    def _is_stackable(self):
2123
        """Check if the Dataset contains enough information to be stacked"""
2124
        incomplete = (
2125
            self.acceptance_off is None
2126
            or self.acceptance is None
2127
            or self.counts_off is None
2128
        )
2129
        unmasked = np.any(self.mask_safe.data)
2130
        if incomplete and unmasked:
2131
            return False
2132
        else:
2133
            return True
2134
2135
    def stack(self, other, nan_to_num=True):
2136
        r"""Stack another dataset in place.
2137
2138
        The ``acceptance`` of the stacked dataset is normalized to 1,
2139
        and the stacked ``acceptance_off`` is scaled so that:
2140
2141
        .. math::
2142
            \alpha_\text{stacked} =
2143
            \frac{1}{a_\text{off}} =
2144
            \frac{\alpha_1\text{OFF}_1 + \alpha_2\text{OFF}_2}{\text{OFF}_1 + OFF_2}
2145
2146
        Parameters
2147
        ----------
2148
        other : `MapDatasetOnOff`
2149
            Other dataset
2150
        nan_to_num: bool
2151
            Non-finite values are replaced by zero if True (default).
2152
        """
2153
        if not isinstance(other, MapDatasetOnOff):
2154
            raise TypeError("Incompatible types for MapDatasetOnOff stacking")
2155
2156
        if not self._is_stackable or not other._is_stackable:
2157
            raise ValueError("Cannot stack incomplete MapDatsetOnOff.")
2158
2159
        geom = self.counts.geom
2160
        total_off = Map.from_geom(geom)
2161
        total_alpha = Map.from_geom(geom)
2162
2163
        if self.counts_off:
2164
            total_off.stack(
2165
                self.counts_off, weights=self.mask_safe, nan_to_num=nan_to_num
2166
            )
2167
            total_alpha.stack(
2168
                self.alpha * self.counts_off,
2169
                weights=self.mask_safe,
2170
                nan_to_num=nan_to_num,
2171
            )
2172
        if other.counts_off:
2173
            total_off.stack(
2174
                other.counts_off, weights=other.mask_safe, nan_to_num=nan_to_num
2175
            )
2176
            total_alpha.stack(
2177
                other.alpha * other.counts_off,
2178
                weights=other.mask_safe,
2179
                nan_to_num=nan_to_num,
2180
            )
2181
2182
        with np.errstate(divide="ignore", invalid="ignore"):
2183
            acceptance_off = total_off / total_alpha
2184
            average_alpha = total_alpha.data.sum() / total_off.data.sum()
2185
2186
        # For the bins where the stacked OFF counts equal 0, the alpha value is performed by weighting on the total
2187
        # OFF counts of each run
2188
        is_zero = total_off.data == 0
2189
        acceptance_off.data[is_zero] = 1 / average_alpha
2190
2191
        self.acceptance.data[...] = 1
2192
        self.acceptance_off = acceptance_off
2193
2194
        self.counts_off = total_off
2195
2196
        super().stack(other, nan_to_num=nan_to_num)
2197
2198
    def stat_sum(self):
2199
        """Total likelihood given the current model parameters."""
2200
        return Dataset.stat_sum(self)
2201
2202
    def fake(self, npred_background, random_state="random-seed"):
2203
        """Simulate fake counts (on and off) for the current model and reduced IRFs.
2204
2205
        This method overwrites the counts defined on the dataset object.
2206
2207
        Parameters
2208
        ----------
2209
        random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
2210
                Defines random number generator initialisation.
2211
                Passed to `~gammapy.utils.random.get_random_state`.
2212
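
        Examples
        --------
        A minimal sketch; reusing the current background estimate as the
        expected background is purely illustrative, and assumes ``dataset``
        is a `MapDatasetOnOff` with models and ``counts_off`` set:

        >>> npred_background = dataset.background.copy()
        >>> dataset.fake(npred_background=npred_background)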
        """
2213
        random_state = get_random_state(random_state)
2214
        npred = self.npred_signal()
2215
        data = np.nan_to_num(npred.data, copy=True, nan=0.0, posinf=0.0, neginf=0.0)
2216
        npred.data = random_state.poisson(data)
2217
2218
        npred_bkg = random_state.poisson(npred_background.data)
2219
2220
        self.counts = npred + npred_bkg
2221
2222
        npred_off = npred_background / self.alpha
2223
        data_off = np.nan_to_num(
2224
            npred_off.data, copy=True, nan=0.0, posinf=0.0, neginf=0.0
2225
        )
2226
        npred_off.data = random_state.poisson(data_off)
2227
        self.counts_off = npred_off
2228
2229
    def to_hdulist(self):
2230
        """Convert map dataset to list of HDUs.
2231
2232
        Returns
2233
        -------
2234
        hdulist : `~astropy.io.fits.HDUList`
2235
            Map dataset list of HDUs.
2236
        """
2237
        hdulist = super().to_hdulist()
2238
        exclude_primary = slice(1, None)
2239
2240
        del hdulist["BACKGROUND"]
2241
        del hdulist["BACKGROUND_BANDS"]
2242
2243
        if self.counts_off is not None:
2244
            hdulist += self.counts_off.to_hdulist(hdu="counts_off")[exclude_primary]
2245
2246
        if self.acceptance is not None:
2247
            hdulist += self.acceptance.to_hdulist(hdu="acceptance")[exclude_primary]
2248
2249
        if self.acceptance_off is not None:
2250
            hdulist += self.acceptance_off.to_hdulist(hdu="acceptance_off")[
2251
                exclude_primary
2252
            ]
2253
2254
        return hdulist
2255
2256
    @classmethod
2257
    def from_hdulist(cls, hdulist, name=None, format="gadf"):
2258
        """Create map dataset from list of HDUs.
2259
2260
        Parameters
2261
        ----------
2262
        hdulist : `~astropy.io.fits.HDUList`
2263
            List of HDUs.
2264
        name : str
2265
            Name of the new dataset.
2266
        format : {"gadf"}
2267
            Format the hdulist is given in.
2268
2269
        Returns
2270
        -------
2271
        dataset : `MapDatasetOnOff`
2272
            Map dataset.
2273
        """
2274
        kwargs = {}
2275
        kwargs["name"] = name
2276
2277
        if "COUNTS" in hdulist:
2278
            kwargs["counts"] = Map.from_hdulist(hdulist, hdu="counts", format=format)
2279
2280
        if "COUNTS_OFF" in hdulist:
2281
            kwargs["counts_off"] = Map.from_hdulist(
2282
                hdulist, hdu="counts_off", format=format
2283
            )
2284
2285
        if "ACCEPTANCE" in hdulist:
2286
            kwargs["acceptance"] = Map.from_hdulist(
2287
                hdulist, hdu="acceptance", format=format
2288
            )
2289
2290
        if "ACCEPTANCE_OFF" in hdulist:
2291
            kwargs["acceptance_off"] = Map.from_hdulist(
2292
                hdulist, hdu="acceptance_off", format=format
2293
            )
2294
2295
        if "EXPOSURE" in hdulist:
2296
            kwargs["exposure"] = Map.from_hdulist(
2297
                hdulist, hdu="exposure", format=format
2298
            )
2299

        if "EDISP" in hdulist:
            edisp_map = Map.from_hdulist(hdulist, hdu="edisp", format=format)

            try:
                exposure_map = Map.from_hdulist(
                    hdulist, hdu="edisp_exposure", format=format
                )
            except KeyError:
                exposure_map = None

            if edisp_map.geom.axes[0].name == "energy":
                kwargs["edisp"] = EDispKernelMap(edisp_map, exposure_map)
            else:
                kwargs["edisp"] = EDispMap(edisp_map, exposure_map)

        if "PSF" in hdulist:
            psf_map = Map.from_hdulist(hdulist, hdu="psf", format=format)
            try:
                exposure_map = Map.from_hdulist(
                    hdulist, hdu="psf_exposure", format=format
                )
            except KeyError:
                exposure_map = None
            kwargs["psf"] = PSFMap(psf_map, exposure_map)

        if "MASK_SAFE" in hdulist:
            mask_safe = Map.from_hdulist(hdulist, hdu="mask_safe", format=format)
            kwargs["mask_safe"] = mask_safe

        if "MASK_FIT" in hdulist:
            mask_fit = Map.from_hdulist(hdulist, hdu="mask_fit", format=format)
            kwargs["mask_fit"] = mask_fit

        if "GTI" in hdulist:
            gti = GTI(Table.read(hdulist, hdu="GTI"))
            kwargs["gti"] = gti
        return cls(**kwargs)

    def info_dict(self, in_safe_data_range=True):
        """Basic info dict with summary statistics

        Parameters
        ----------
        in_safe_data_range : bool
            Whether to sum only in the safe energy range

        Returns
        -------
        info_dict : dict
            Dictionary with summary info.
        """
        # TODO: remove code duplication with SpectrumDatasetOnOff
        info = super().info_dict(in_safe_data_range)

        if self.mask_safe and in_safe_data_range:
            mask = self.mask_safe.data.astype(bool)
        else:
            mask = slice(None)

        counts_off = 0
        if self.counts_off is not None:
            counts_off = self.counts_off.data[mask].sum()

        info["counts_off"] = int(counts_off)

        acceptance = 1
        if self.acceptance:
            # TODO: handle energy dependent a_on / a_off
            acceptance = self.acceptance.data[mask].sum()

        info["acceptance"] = float(acceptance)

        acceptance_off = np.nan
        if self.acceptance_off:
            acceptance_off = acceptance * counts_off / info["background"]

        info["acceptance_off"] = float(acceptance_off)

        alpha = np.nan
        if self.acceptance_off and self.acceptance:
            alpha = np.mean(self.alpha.data[mask])

        info["alpha"] = float(alpha)

        info["sqrt_ts"] = WStatCountsStatistic(
            info["counts"],
            info["counts_off"],
            acceptance / acceptance_off,
        ).sqrt_ts
        info["stat_sum"] = self.stat_sum()
        return info

    def to_spectrum_dataset(self, on_region, containment_correction=False, name=None):
        """Return a `~gammapy.datasets.SpectrumDatasetOnOff` from on_region.

        Counts and OFF counts are summed in the on_region.

        Acceptance is the average of all acceptances, while acceptance OFF
        is taken such that the number of excess events is preserved in the on_region.

        Effective area is taken from the average exposure divided by the livetime.
        Here we assume it is the sum of the GTIs.

        The energy dispersion kernel is obtained at the on_region center.
        Only regions with centers are supported.

        The model is not exported to the `~gammapy.datasets.SpectrumDatasetOnOff`.
        It must be set after the dataset extraction.

        Parameters
        ----------
        on_region : `~regions.SkyRegion`
            The input ON region on which to extract the spectrum
        containment_correction : bool
            Apply containment correction for point sources and circular on regions
        name : str
            Name of the new dataset.

        Returns
        -------
        dataset : `~gammapy.datasets.SpectrumDatasetOnOff`
            The resulting reduced dataset
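
        Examples
        --------
        A minimal sketch; the ON region below is illustrative:

        >>> import astropy.units as u
        >>> from astropy.coordinates import SkyCoord
        >>> from regions import CircleSkyRegion
        >>> center = SkyCoord(0, 0, unit="deg", frame="galactic")
        >>> on_region = CircleSkyRegion(center=center, radius=0.1 * u.deg)
        >>> spectrum_dataset = dataset.to_spectrum_dataset(on_region)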
        """
2426
        from .spectrum import SpectrumDatasetOnOff
2427
2428
        dataset = super().to_spectrum_dataset(
2429
            on_region=on_region, containment_correction=containment_correction, name=name
2430
        )
2431
2432
        kwargs = {"name": name}
2433
2434
        if self.counts_off is not None:
2435
            kwargs["counts_off"] = self.counts_off.get_spectrum(
2436
                on_region, np.sum, weights=self.mask_safe
2437
            )
2438
2439
        if self.acceptance is not None:
2440
            kwargs["acceptance"] = self.acceptance.get_spectrum(
2441
                on_region, np.mean, weights=self.mask_safe
2442
            )
2443
            norm = self.background.get_spectrum(
2444
                on_region, np.sum, weights=self.mask_safe
2445
            )
2446
            acceptance_off = kwargs["acceptance"] * kwargs["counts_off"] / norm
2447
            np.nan_to_num(acceptance_off.data, copy=False)
2448
            kwargs["acceptance_off"] = acceptance_off
2449
2450
        return SpectrumDatasetOnOff.from_spectrum_dataset(dataset=dataset, **kwargs)
2451
2452
    def cutout(self, position, width, mode="trim", name=None):
2453
        """Cutout map dataset.
2454
2455
        Parameters
2456
        ----------
2457
        position : `~astropy.coordinates.SkyCoord`
2458
            Center position of the cutout region.
2459
        width : tuple of `~astropy.coordinates.Angle`
2460
            Angular sizes of the region in (lon, lat) in that specific order.
2461
            If only one value is passed, a square region is extracted.
2462
        mode : {'trim', 'partial', 'strict'}
2463
            Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.
2464
        name : str
2465
            Name of the new dataset.
2466
2467
        Returns
2468
        -------
2469
        cutout : `MapDatasetOnOff`
2470
            Cutout map dataset.
2471
        """
2472
        cutout_kwargs = {
2473
            "position": position,
2474
            "width": width,
2475
            "mode": mode,
2476
            "name": name,
2477
        }
2478
2479
        cutout_dataset = super().cutout(**cutout_kwargs)
2480
2481
        del cutout_kwargs["name"]
2482
2483
        if self.counts_off is not None:
2484
            cutout_dataset.counts_off = self.counts_off.cutout(**cutout_kwargs)
2485
2486
        if self.acceptance is not None:
2487
            cutout_dataset.acceptance = self.acceptance.cutout(**cutout_kwargs)
2488
2489
        if self.acceptance_off is not None:
2490
            cutout_dataset.acceptance_off = self.acceptance_off.cutout(**cutout_kwargs)
2491
2492
        return cutout_dataset
2493
2494
    def downsample(self, factor, axis_name=None, name=None):
2495
        """Downsample map dataset.
2496
2497
        The PSFMap and EDispKernelMap are not downsampled, except if
2498
        a corresponding axis is given.
2499
2500
        Parameters
2501
        ----------
2502
        factor : int
2503
            Downsampling factor.
2504
        axis_name : str
2505
            Which non-spatial axis to downsample. By default only spatial axes are downsampled.
2506
        name : str
2507
            Name of the downsampled dataset.
2508
2509
        Returns
2510
        -------
2511
        dataset : `MapDatasetOnOff`
2512
            Downsampled map dataset.
2513
        """
2514
2515
        dataset = super().downsample(factor, axis_name, name)
2516
2517
        counts_off = None
2518
        if self.counts_off is not None:
2519
            counts_off = self.counts_off.downsample(
2520
                factor=factor,
2521
                preserve_counts=True,
2522
                axis_name=axis_name,
2523
                weights=self.mask_safe,
2524
            )
2525
2526
        acceptance, acceptance_off = None, None
2527
        if self.acceptance_off is not None:
2528
            acceptance = self.acceptance.downsample(
2529
                factor=factor, preserve_counts=False, axis_name=axis_name
2530
            )
2531
            factor = self.background.downsample(
2532
                factor=factor,
2533
                preserve_counts=True,
2534
                axis_name=axis_name,
2535
                weights=self.mask_safe,
2536
            )
2537
            acceptance_off = acceptance * counts_off / factor
2538
2539
        return self.__class__.from_map_dataset(
2540
            dataset,
2541
            acceptance=acceptance,
2542
            acceptance_off=acceptance_off,
2543
            counts_off=counts_off,
2544
        )
2545
2546
    def pad(self):
2547
        raise NotImplementedError
2548
2549
    def slice_by_idx(self, slices, name=None):
2550
        """Slice sub dataset.
2551
2552
        The slicing only applies to the maps that define the corresponding axes.
2553
2554
        Parameters
2555
        ----------
2556
        slices : dict
2557
            Dict of axes names and integers or `slice` object pairs. Contains one
2558
            element for each non-spatial dimension. For integer indexing the
2559
            corresponding axes is dropped from the map. Axes not specified in the
2560
            dict are kept unchanged.
2561
        name : str
2562
            Name of the sliced dataset.
2563
2564
        Returns
2565
        -------
2566
        map_out : `Map`
2567
            Sliced map object.
2568
        """
2569
        kwargs = {"name": name}
2570
        dataset = super().slice_by_idx(slices, name)
2571
2572
        if self.counts_off is not None:
2573
            kwargs["counts_off"] = self.counts_off.slice_by_idx(slices=slices)
2574
2575
        if self.acceptance is not None:
2576
            kwargs["acceptance"] = self.acceptance.slice_by_idx(slices=slices)
2577
2578
        if self.acceptance_off is not None:
2579
            kwargs["acceptance_off"] = self.acceptance_off.slice_by_idx(slices=slices)
2580
2581
        return self.from_map_dataset(dataset, **kwargs)
2582
2583
    def resample_energy_axis(self, energy_axis, name=None):
2584
        """Resample MapDatasetOnOff over reconstructed energy edges.
2585
2586
        Counts are summed taking into account safe mask.
2587
2588
        Parameters
2589
        ----------
2590
        energy_axis : `~gammapy.maps.MapAxis`
2591
            New reco energy axis.
2592
        name: str
2593
            Name of the new dataset.
2594
2595
        Returns
2596
        -------
2597
        dataset: `SpectrumDataset`
2598
            Resampled spectrum dataset .
2599
        """
2600
        dataset = super().resample_energy_axis(energy_axis, name)
2601
2602
        counts_off = None
2603
        if self.counts_off is not None:
2604
            counts_off = self.counts_off
2605
            counts_off = counts_off.resample_axis(
2606
                axis=energy_axis, weights=self.mask_safe
2607
            )
2608
2609
        acceptance = 1
2610
        acceptance_off = None
2611
        if self.acceptance is not None:
2612
            acceptance = self.acceptance
2613
            acceptance = acceptance.resample_axis(
2614
                axis=energy_axis, weights=self.mask_safe
2615
            )
2616
2617
            norm_factor = self.background.resample_axis(
2618
                axis=energy_axis, weights=self.mask_safe
2619
            )
2620
2621
            acceptance_off = acceptance * counts_off / norm_factor
2622
2623
        return self.__class__.from_map_dataset(
2624
            dataset,
2625
            acceptance=acceptance,
2626
            acceptance_off=acceptance_off,
2627
            counts_off=counts_off,
2628
            name=name
2629
        )
2630