gammapy.datasets.map   Rating: F

Complexity

Total Complexity: 327

Size/Duplication

Total Lines: 2793
Duplicated Lines: 1%

Importance

Changes: 0

Metric   Value
eloc     1315
dl       28
loc      2793
rs       0.8
c        0
b        0
f        0
wmc      327

1 Function

Rating   Name   Duplication   Size   Complexity  
A create_map_dataset_geoms() 0 51 3

79 Methods

Rating   Name   Duplication   Size   Complexity  
A MapDataset.npred() 0 14 2
A MapDataset.energy_range() 0 4 1
A MapDataset.data_shape() 0 4 1
A MapDataset.models() 0 4 4
A MapDataset.energy_range_fit() 0 4 1
B MapDataset.__init__() 0 43 5
A MapDataset.__str__() 0 45 3
A MapDataset.energy_range_safe() 0 4 1
A MapDataset._geom() 0 14 5
A MapDataset.npred_signal() 0 39 5
A MapDataset.npred_background() 0 26 5
B MapDataset.from_geoms() 0 61 7
A MapDataset.evaluators() 0 4 1
A MapDataset._energy_range() 0 30 3
A MapDataset.geoms() 0 23 4
A MapDataset.background_model() 0 6 2
A MapDataset.energy_range_total() 0 5 1
A MapDataset.excess() 0 4 1
A MapDataset._background_parameters_changed() 0 7 2
A MapDatasetOnOff.from_geoms() 0 51 2
A MapDataset.create() 0 71 1
F MapDataset.stack() 0 94 25
A MapDatasetOnOff.to_hdulist() 0 26 4
C MapDataset.downsample() 0 67 11
A MapDatasetOnOff._is_stackable() 0 13 3
A MapDataset.peek() 0 41 2
A MapDataset.plot_residuals() 0 66 2
A MapDataset.mask_safe_edisp() 0 17 5
A MapDatasetOnOff._counts_statistic() 0 4 1
A MapDatasetOnOff.stat_array() 0 10 1
A MapDatasetOnOff.to_map_dataset() 0 28 1
A MapDatasetOnOff._geom() 0 14 5
A MapDatasetOnOff.__str__() 0 19 4
A MapDataset.from_dict() 0 6 1
A MapDataset.to_image() 0 15 1
A MapDataset.reset_data_cache() 0 5 3
A MapDataset.stat_array() 0 3 1
A MapDataset.mask_safe_image() 0 6 2
A MapDataset._counts_statistic() 0 4 1
A MapDataset.write() 0 14 1
A MapDataset.to_masked() 0 18 1
A MapDatasetOnOff.alpha() 0 16 2
B MapDatasetOnOff.stack() 0 62 7
B MapDataset.pad() 0 45 8
A MapDatasetOnOff.pad() 0 2 1
C MapDataset.resample_energy_axis() 0 55 9
A MapDataset._read_lazy() 0 39 3
C MapDataset.to_hdulist() 0 45 10
A MapDataset.mask_fit_image() 0 6 2
A MapDataset.plot_residuals_spatial() 0 66 4
A MapDatasetOnOff.to_spectrum_dataset() 0 57 3
C MapDataset.cutout() 0 46 9
A MapDatasetOnOff.npred_background() 0 18 1
A MapDatasetOnOff.background() 0 14 2
B MapDatasetOnOff.info_dict() 0 50 6
A MapDataset.mask_safe_psf() 0 9 3
A MapDataset.read() 0 35 4
A MapDatasetOnOff.__init__() 0 31 1
B MapDatasetOnOff.from_map_dataset() 0 51 5
A MapDataset.mask_image() 0 9 2
A MapDatasetOnOff.slice_by_idx() 0 33 4
A MapDatasetOnOff.stat_sum() 0 3 1
C MapDataset.slice_by_idx() 0 62 9
C MapDataset.to_region_map_dataset() 0 53 10
B MapDataset.residuals() 0 42 6
F MapDatasetOnOff.from_hdulist() 14 85 15
F MapDataset.info_dict() 0 107 17
A MapDatasetOnOff.npred_off() 0 11 1
A MapDataset.slice_by_energy() 0 45 3
B MapDataset.plot_residuals_spectral() 0 83 7
A MapDataset.stat_sum() 0 8 2
A MapDatasetOnOff.downsample() 0 50 3
A MapDataset.fake() 0 16 1
A MapDatasetOnOff.resample_energy_axis() 0 46 3
A MapDatasetOnOff.fake() 0 26 1
A MapDatasetOnOff._read_lazy() 0 4 1
B MapDataset.to_spectrum_dataset() 0 64 7
F MapDataset.from_hdulist() 14 79 14
A MapDatasetOnOff.cutout() 0 41 4

How to fix

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
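As a minimal sketch of that rule of three (all names below are hypothetical, not taken from gammapy), logic that has been copy-pasted into several classes can be moved into one shared helper that each caller delegates to:

    import numpy as np


    def _validate_map_shape(data, expected_shape):
        # Shared helper: the check that was previously duplicated in
        # each constructor now lives in a single place.
        data = np.asanyarray(data)
        if data.shape != tuple(expected_shape):
            raise ValueError(f"expected shape {expected_shape}, got {data.shape}")
        return data


    class CountsMap:
        def __init__(self, data, shape):
            # Each former duplicate becomes a one-line call.
            self.data = _validate_map_shape(data, shape)


    class BackgroundMap:
        def __init__(self, data, shape):
            self.data = _validate_map_shape(data, shape)

Once the third copy appears, the helper pays for itself: a fix to the validation logic lands in one place instead of three.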


Complexity

Tip: Before tackling complexity, make sure that you eliminate any duplication first. This can often reduce the size of classes significantly.

Complex classes like gammapy.datasets.map often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
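As a minimal sketch of Extract Class under that heuristic (an illustration, not a proposed gammapy patch): the mask_safe, mask_safe_image, mask_safe_psf and mask_safe_edisp members of MapDataset share a prefix, so they form a candidate component the dataset could delegate to:

    class SafeMaskMaps:
        """Extracted component bundling everything derived from the safe mask."""

        def __init__(self, mask_safe):
            self.mask_safe = mask_safe

        @property
        def image(self):
            # Would hold the logic currently in MapDataset.mask_safe_image.
            ...

        @property
        def psf(self):
            # Would hold the logic currently in MapDataset.mask_safe_psf.
            ...


    class MapDataset:
        def __init__(self, mask_safe=None):
            # The dataset keeps a single field and delegates derived masks,
            # shrinking its own method count and weighted method complexity.
            self.safe_mask = SafeMaskMaps(mask_safe)

The analyzed module source follows.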

# Licensed under a 3-clause BSD style license - see LICENSE.rst
import logging
import numpy as np
import astropy.units as u
from astropy.io import fits
from astropy.table import Table
from regions import CircleSkyRegion
import matplotlib.pyplot as plt
from gammapy.data import GTI
from gammapy.irf import EDispKernelMap, EDispMap, PSFKernel, PSFMap, RecoPSFMap
from gammapy.maps import Map, MapAxis
from gammapy.modeling.models import DatasetModels, FoVBackgroundModel
from gammapy.stats import (
    CashCountsStatistic,
    WStatCountsStatistic,
    cash,
    cash_sum_cython,
    get_wstat_mu_bkg,
    wstat,
)
from gammapy.utils.fits import HDULocation, LazyFitsData
from gammapy.utils.random import get_random_state
from gammapy.utils.scripts import make_name, make_path
from gammapy.utils.table import hstack_columns
from .core import Dataset
from .evaluator import MapEvaluator
from .utils import get_axes

__all__ = ["MapDataset", "MapDatasetOnOff", "create_map_dataset_geoms"]

log = logging.getLogger(__name__)


RAD_MAX = 0.66
RAD_AXIS_DEFAULT = MapAxis.from_bounds(
    0, RAD_MAX, nbin=66, node_type="edges", name="rad", unit="deg"
)
MIGRA_AXIS_DEFAULT = MapAxis.from_bounds(
    0.2, 5, nbin=48, node_type="edges", name="migra"
)

BINSZ_IRF_DEFAULT = 0.2

EVALUATION_MODE = "local"
USE_NPRED_CACHE = True


def create_map_dataset_geoms(
    geom,
    energy_axis_true=None,
    migra_axis=None,
    rad_axis=None,
    binsz_irf=None,
):
    """Create map geometries for a `MapDataset`

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        Reference target geometry in reco energy, used for counts and background maps
    energy_axis_true : `~gammapy.maps.MapAxis`
        True energy axis used for IRF maps
    migra_axis : `~gammapy.maps.MapAxis`
        If set, this provides the migration axis for the energy dispersion map.
        If not set, an EDispKernelMap is produced instead. Default is None
    rad_axis : `~gammapy.maps.MapAxis`
        Rad axis for the psf map
    binsz_irf : float
        IRF Map pixel size in degrees.

    Returns
    -------
    geoms : dict
        Dict with map geometries.
    """
    rad_axis = rad_axis or RAD_AXIS_DEFAULT

    if energy_axis_true is not None:
        energy_axis_true.assert_name("energy_true")
    else:
        energy_axis_true = geom.axes["energy"].copy(name="energy_true")

    binsz_irf = binsz_irf or BINSZ_IRF_DEFAULT
    geom_image = geom.to_image()
    geom_exposure = geom_image.to_cube([energy_axis_true])
    geom_irf = geom_image.to_binsz(binsz=binsz_irf)
    geom_psf = geom_irf.to_cube([rad_axis, energy_axis_true])

    if migra_axis:
        geom_edisp = geom_irf.to_cube([migra_axis, energy_axis_true])
    else:
        geom_edisp = geom_irf.to_cube([geom.axes["energy"], energy_axis_true])

    return {
        "geom": geom,
        "geom_exposure": geom_exposure,
        "geom_psf": geom_psf,
        "geom_edisp": geom_edisp,
    }

class MapDataset(Dataset):
    """
    Bundle together binned counts, background, IRFs, models and compute a likelihood.
    Uses Cash statistics by default.

    Parameters
    ----------
    models : `~gammapy.modeling.models.Models`
        Source sky models.
    counts : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Counts cube
    exposure : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Exposure cube
    background : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Background cube
    mask_fit : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Mask to apply to the likelihood for fitting.
    psf : `~gammapy.irf.PSFMap` or `~gammapy.utils.fits.HDULocation`
        PSF kernel
    edisp : `~gammapy.irf.EDispMap` or `~gammapy.utils.fits.HDULocation`
        Energy dispersion kernel
    mask_safe : `~gammapy.maps.WcsNDMap` or `~gammapy.utils.fits.HDULocation`
        Mask defining the safe data range.
    gti : `~gammapy.data.GTI`
        GTI of the observation or union of GTI if it is a stacked observation
    meta_table : `~astropy.table.Table`
        Table listing information on observations used to create the dataset.
        One line per observation for stacked datasets.

    If an `HDULocation` is passed, the map is loaded lazily: the map data is
    only read into memory when the corresponding data attribute on the
    MapDataset is first accessed, and it is then cached for subsequent access.

    Examples
    --------
    >>> from gammapy.datasets import MapDataset
    >>> filename = "$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz"
    >>> dataset = MapDataset.read(filename, name="cta-dataset")
    >>> print(dataset)
    MapDataset
    ----------
    <BLANKLINE>
      Name                            : cta-dataset
    <BLANKLINE>
      Total counts                    : 104317
      Total background counts         : 91507.70
      Total excess counts             : 12809.30
    <BLANKLINE>
      Predicted counts                : 91507.69
      Predicted background counts     : 91507.70
      Predicted excess counts         : nan
    <BLANKLINE>
      Exposure min                    : 6.28e+07 m2 s
      Exposure max                    : 1.90e+10 m2 s
    <BLANKLINE>
      Number of total bins            : 768000
      Number of fit bins              : 691680
    <BLANKLINE>
      Fit statistic type              : cash
      Fit statistic value (-2 log(L)) : nan
    <BLANKLINE>
      Number of models                : 0
      Number of parameters            : 0
      Number of free parameters       : 0


    See Also
    --------
    MapDatasetOnOff, SpectrumDataset, FluxPointsDataset
    """

    stat_type = "cash"
    tag = "MapDataset"
    counts = LazyFitsData(cache=True)
    exposure = LazyFitsData(cache=True)
    edisp = LazyFitsData(cache=True)
    background = LazyFitsData(cache=True)
    psf = LazyFitsData(cache=True)
    mask_fit = LazyFitsData(cache=True)
    mask_safe = LazyFitsData(cache=True)

    _lazy_data_members = [
        "counts",
        "exposure",
        "edisp",
        "psf",
        "mask_fit",
        "mask_safe",
        "background",
    ]

    def __init__(
        self,
        models=None,
        counts=None,
        exposure=None,
        background=None,
        psf=None,
        edisp=None,
        mask_safe=None,
        mask_fit=None,
        gti=None,
        meta_table=None,
        name=None,
    ):
        self._name = make_name(name)
        self._evaluators = {}

        self.counts = counts
        self.exposure = exposure
        self.background = background
        self._background_cached = None
        self._background_parameters_cached = None

        self.mask_fit = mask_fit

        if psf and not isinstance(psf, (PSFMap, HDULocation)):
            raise ValueError(
                f"'psf' must be a 'PSFMap' or `HDULocation` object, got {type(psf)}"
            )

        self.psf = psf

        if edisp and not isinstance(edisp, (EDispMap, EDispKernelMap, HDULocation)):
            raise ValueError(
                "'edisp' must be a 'EDispMap', `EDispKernelMap` or 'HDULocation' "
                f"object, got `{type(edisp)}` instead."
            )

        self.edisp = edisp
        self.mask_safe = mask_safe
        self.gti = gti
        self.models = models
        self.meta_table = meta_table

    # TODO: keep or remove?
    @property
    def background_model(self):
        try:
            return self.models[f"{self.name}-bkg"]
        except (ValueError, TypeError):
            pass

    def __str__(self):
        str_ = f"{self.__class__.__name__}\n"
        str_ += "-" * len(self.__class__.__name__) + "\n"
        str_ += "\n"
        str_ += "\t{:32}: {{name}} \n\n".format("Name")
        str_ += "\t{:32}: {{counts:.0f}} \n".format("Total counts")
        str_ += "\t{:32}: {{background:.2f}}\n".format("Total background counts")
        str_ += "\t{:32}: {{excess:.2f}}\n\n".format("Total excess counts")

        str_ += "\t{:32}: {{npred:.2f}}\n".format("Predicted counts")
        str_ += "\t{:32}: {{npred_background:.2f}}\n".format(
            "Predicted background counts"
        )
        str_ += "\t{:32}: {{npred_signal:.2f}}\n\n".format("Predicted excess counts")

        str_ += "\t{:32}: {{exposure_min:.2e}}\n".format("Exposure min")
        str_ += "\t{:32}: {{exposure_max:.2e}}\n\n".format("Exposure max")

        str_ += "\t{:32}: {{n_bins}} \n".format("Number of total bins")
        str_ += "\t{:32}: {{n_fit_bins}} \n\n".format("Number of fit bins")

        # likelihood section
        str_ += "\t{:32}: {{stat_type}}\n".format("Fit statistic type")
        str_ += "\t{:32}: {{stat_sum:.2f}}\n\n".format(
            "Fit statistic value (-2 log(L))"
        )

        info = self.info_dict()
        str_ = str_.format(**info)

        # model section
        n_models, n_pars, n_free_pars = 0, 0, 0
        if self.models is not None:
            n_models = len(self.models)
            n_pars = len(self.models.parameters)
            n_free_pars = len(self.models.parameters.free_parameters)

        str_ += "\t{:32}: {} \n".format("Number of models", n_models)
        str_ += "\t{:32}: {}\n".format("Number of parameters", n_pars)
        str_ += "\t{:32}: {}\n\n".format("Number of free parameters", n_free_pars)

        if self.models is not None:
            str_ += "\t" + "\n\t".join(str(self.models).split("\n")[2:])

        return str_.expandtabs(tabsize=2)

    @property
    def geoms(self):
        """Map geometries

        Returns
        -------
        geoms : dict
            Dict of map geometries involved in the dataset.
        """
        geoms = {}

        geoms["geom"] = self._geom

        if self.exposure:
            geoms["geom_exposure"] = self.exposure.geom

        if self.psf:
            geoms["geom_psf"] = self.psf.psf_map.geom

        if self.edisp:
            geoms["geom_edisp"] = self.edisp.edisp_map.geom

        return geoms

    @property
    def models(self):
        """Models set on the dataset (`~gammapy.modeling.models.Models`)."""
        return self._models

    @property
    def excess(self):
        """Observed excess: counts - background"""
        return self.counts - self.background

    @models.setter
    def models(self, models):
        """Models setter"""
        self._evaluators = {}

        if models is not None:
            models = DatasetModels(models)
            models = models.select(datasets_names=self.name)

            for model in models:
                if not isinstance(model, FoVBackgroundModel):
                    evaluator = MapEvaluator(
                        model=model,
                        evaluation_mode=EVALUATION_MODE,
                        gti=self.gti,
                        use_cache=USE_NPRED_CACHE,
                    )
                    self._evaluators[model.name] = evaluator

        self._models = models

    @property
    def evaluators(self):
        """Model evaluators"""
        return self._evaluators

    @property
    def _geom(self):
        """Main analysis geometry"""
        if self.counts is not None:
            return self.counts.geom
        elif self.background is not None:
            return self.background.geom
        elif self.mask_safe is not None:
            return self.mask_safe.geom
        elif self.mask_fit is not None:
            return self.mask_fit.geom
        else:
            raise ValueError(
                "Either 'counts', 'background', 'mask_fit'"
                " or 'mask_safe' must be defined."
            )

    @property
    def data_shape(self):
        """Shape of the counts or background data (tuple)"""
        return self._geom.data_shape

    def _energy_range(self, mask_map=None):
        """Compute the energy range maps with or without the fit mask."""
        geom = self._geom
        energy = geom.axes["energy"].edges
        e_i = geom.axes.index_data("energy")
        geom = geom.drop("energy")

        if mask_map is not None:
            mask = mask_map.data
            if mask.any():
                idx = mask.argmax(e_i)
                energy_min = energy.value[idx]
                mask_nan = ~mask.any(e_i)
                energy_min[mask_nan] = np.nan

                mask = np.flip(mask, e_i)
                idx = mask.argmax(e_i)
                energy_max = energy.value[::-1][idx]
                energy_max[mask_nan] = np.nan
            else:
                energy_min = np.full(geom.data_shape, np.nan)
                energy_max = energy_min.copy()
        else:
            data_shape = geom.data_shape
            energy_min = np.full(data_shape, energy.value[0])
            energy_max = np.full(data_shape, energy.value[-1])

        map_min = Map.from_geom(geom, data=energy_min, unit=energy.unit)
        map_max = Map.from_geom(geom, data=energy_max, unit=energy.unit)
        return map_min, map_max

    @property
    def energy_range(self):
        """Energy range maps defined by the mask_safe and mask_fit."""
        return self._energy_range(self.mask)

    @property
    def energy_range_safe(self):
        """Energy range maps defined by the mask_safe only."""
        return self._energy_range(self.mask_safe)

    @property
    def energy_range_fit(self):
        """Energy range maps defined by the mask_fit only."""
        return self._energy_range(self.mask_fit)

    @property
    def energy_range_total(self):
        """Largest energy range among all pixels, defined by mask_safe and mask_fit."""
        energy_min_map, energy_max_map = self.energy_range
        return np.nanmin(energy_min_map.quantity), np.nanmax(energy_max_map.quantity)

    def npred(self):
        """Total predicted source and background counts

        Returns
        -------
        npred : `Map`
            Total predicted counts
        """
        npred_total = self.npred_signal()

        if self.background:
            npred_total += self.npred_background()
        npred_total.data[npred_total.data < 0.0] = 0
        return npred_total

    def npred_background(self):
        """Predicted background counts

        The predicted background counts depend on the parameters
        of the `FoVBackgroundModel` defined in the dataset.

        Returns
        -------
        npred_background : `Map`
            Predicted counts from the background.
        """
        background = self.background
        if self.background_model and background:
            if self._background_parameters_changed:
                values = self.background_model.evaluate_geom(geom=self.background.geom)
                if self._background_cached is None:
                    self._background_cached = background * values
                else:
                    self._background_cached.quantity = (
                        background.quantity * values.value
                    )
            return self._background_cached

        return background

    @property
    def _background_parameters_changed(self):
        # Exposed as a property so that the truth test in `npred_background`
        # checks the actual flag rather than a (always truthy) bound method.
        values = self.background_model.parameters.value
        # TODO: possibly allow for a tolerance here?
        changed = ~np.all(self._background_parameters_cached == values)
        if changed:
            self._background_parameters_cached = values
        return changed

    def npred_signal(self, model_name=None):
        """Model predicted signal counts.

        If a model name is passed, predicted counts from that component are returned.
        Else, the total signal counts are returned.

        Parameters
        ----------
        model_name: str
            Name of the SkyModel for which to compute npred.
            If None, the sum of all components (minus the background model)
            is returned.

        Returns
        -------
        npred_sig: `gammapy.maps.Map`
            Map of the predicted signal counts
        """
        npred_total = Map.from_geom(self._geom, dtype=float)

        evaluators = self.evaluators
        if model_name is not None:
            evaluators = {model_name: self.evaluators[model_name]}

        for evaluator in evaluators.values():
            if evaluator.needs_update:
                evaluator.update(
                    self.exposure,
                    self.psf,
                    self.edisp,
                    self._geom,
                    self.mask_image,
                )

            if evaluator.contributes:
                npred = evaluator.compute_npred()
                npred_total.stack(npred)

        return npred_total

    @classmethod
    def from_geoms(
        cls,
        geom,
        geom_exposure=None,
        geom_psf=None,
        geom_edisp=None,
        reference_time="2000-01-01",
        name=None,
        **kwargs,
    ):
        """
        Create a MapDataset object with zero filled maps according to the specified geometries

        Parameters
        ----------
        geom : `Geom`
            geometry for the counts and background maps
        geom_exposure : `Geom`
            geometry for the exposure map
        geom_psf : `Geom`
            geometry for the psf map
        geom_edisp : `Geom`
            geometry for the energy dispersion kernel map.
            If geom_edisp has a migra axis, this will create an EDispMap instead.
        reference_time : `~astropy.time.Time`
            the reference time to use in GTI definition
        name : str
            Name of the returned dataset.

        Returns
        -------
        dataset : `MapDataset` or `SpectrumDataset`
            A dataset containing zero filled maps
        """
        name = make_name(name)
        kwargs = kwargs.copy()
        kwargs["name"] = name
        kwargs["counts"] = Map.from_geom(geom, unit="")
        kwargs["background"] = Map.from_geom(geom, unit="")

        if geom_exposure:
            kwargs["exposure"] = Map.from_geom(geom_exposure, unit="m2 s")

        if geom_edisp:
            if "energy" in geom_edisp.axes.names:
                kwargs["edisp"] = EDispKernelMap.from_geom(geom_edisp)
            else:
                kwargs["edisp"] = EDispMap.from_geom(geom_edisp)

        if geom_psf:
            if "energy_true" in geom_psf.axes.names:
                kwargs["psf"] = PSFMap.from_geom(geom_psf)
            elif "energy" in geom_psf.axes.names:
                kwargs["psf"] = RecoPSFMap.from_geom(geom_psf)

        kwargs.setdefault(
            "gti", GTI.create([] * u.s, [] * u.s, reference_time=reference_time)
        )
        kwargs["mask_safe"] = Map.from_geom(geom, unit="", dtype=bool)
        return cls(**kwargs)

    @classmethod
    def create(
        cls,
        geom,
        energy_axis_true=None,
        migra_axis=None,
        rad_axis=None,
        binsz_irf=None,
        reference_time="2000-01-01",
        name=None,
        meta_table=None,
        **kwargs,
    ):
        """Create a MapDataset object with zero filled maps.

        Parameters
        ----------
        geom : `~gammapy.maps.WcsGeom`
            Reference target geometry in reco energy, used for counts and background maps
        energy_axis_true : `~gammapy.maps.MapAxis`
            True energy axis used for IRF maps
        migra_axis : `~gammapy.maps.MapAxis`
            If set, this provides the migration axis for the energy dispersion map.
            If not set, an EDispKernelMap is produced instead. Default is None
        rad_axis : `~gammapy.maps.MapAxis`
            Rad axis for the psf map
        binsz_irf : float
            IRF Map pixel size in degrees.
        reference_time : `~astropy.time.Time`
            the reference time to use in GTI definition
        name : str
            Name of the returned dataset.
        meta_table : `~astropy.table.Table`
            Table listing information on observations used to create the dataset.
            One line per observation for stacked datasets.

        Returns
        -------
        empty_maps : `MapDataset`
            A MapDataset containing zero filled maps

        Examples
        --------
        >>> from gammapy.datasets import MapDataset
        >>> from gammapy.maps import WcsGeom, MapAxis

        >>> energy_axis = MapAxis.from_energy_bounds(1.0, 10.0, 4, unit="TeV")
        >>> energy_axis_true = MapAxis.from_energy_bounds(
        ...     0.5, 20, 10, unit="TeV", name="energy_true"
        ... )
        >>> geom = WcsGeom.create(
        ...     skydir=(83.633, 22.014),
        ...     binsz=0.02, width=(2, 2),
        ...     frame="icrs",
        ...     proj="CAR",
        ...     axes=[energy_axis]
        ... )
        >>> empty = MapDataset.create(geom=geom, energy_axis_true=energy_axis_true, name="empty")
        """

        geoms = create_map_dataset_geoms(
            geom=geom,
            energy_axis_true=energy_axis_true,
            rad_axis=rad_axis,
            migra_axis=migra_axis,
            binsz_irf=binsz_irf,
        )

        kwargs.update(geoms)
        return cls.from_geoms(
            reference_time=reference_time, name=name, meta_table=meta_table, **kwargs
        )

    @property
    def mask_safe_image(self):
        """Reduced mask safe"""
        if self.mask_safe is None:
            return None
        return self.mask_safe.reduce_over_axes(func=np.logical_or)

    @property
    def mask_fit_image(self):
        """Reduced mask fit"""
        if self.mask_fit is None:
            return None
        return self.mask_fit.reduce_over_axes(func=np.logical_or)

    @property
    def mask_image(self):
        """Reduced mask"""
        if self.mask is None:
            mask = Map.from_geom(self._geom.to_image(), dtype=bool)
            mask.data |= True
            return mask

        return self.mask.reduce_over_axes(func=np.logical_or)

    @property
    def mask_safe_psf(self):
        """Mask safe for psf maps"""
        if self.mask_safe is None or self.psf is None:
            return None

        geom = self.psf.psf_map.geom.squash("energy_true").squash("rad")
        mask_safe_psf = self.mask_safe_image.interp_to_geom(geom.to_image())
        return mask_safe_psf.to_cube(geom.axes)

    @property
    def mask_safe_edisp(self):
        """Mask safe for edisp maps"""
        if self.mask_safe is None or self.edisp is None:
            return None

        if self.mask_safe.geom.is_region:
            return self.mask_safe

        geom = self.edisp.edisp_map.geom.squash("energy_true")

        if "migra" in geom.axes.names:
            geom = geom.squash("migra")
            mask_safe_edisp = self.mask_safe_image.interp_to_geom(geom.to_image())
            return mask_safe_edisp.to_cube(geom.axes)

        return self.mask_safe.interp_to_geom(geom)

    def to_masked(self, name=None, nan_to_num=True):
        """Return masked dataset

        Parameters
        ----------
        name : str
            Name of the masked dataset.
        nan_to_num: bool
            Non-finite values are replaced by zero if True (default).

        Returns
        -------
        dataset : `MapDataset` or `SpectrumDataset`
            Masked dataset
        """
        dataset = self.__class__.from_geoms(**self.geoms, name=name)
        dataset.stack(self, nan_to_num=nan_to_num)
        return dataset

    def stack(self, other, nan_to_num=True):
        r"""Stack another dataset in place. The original dataset is modified.

        Safe mask is applied to compute the stacked counts data. Counts outside
        each dataset safe mask are lost.

        The stacking of 2 datasets is implemented as follows. Here, :math:`k`
        denotes a bin in reconstructed energy and :math:`j = {1,2}` is the dataset number

        The ``mask_safe`` of each dataset is defined as:

        .. math::

            \epsilon_{jk} =\left\{\begin{array}{cl} 1, &
            \mbox{if bin k is inside the thresholds}\\ 0, &
            \mbox{otherwise} \end{array}\right.

        Then the total ``counts`` and model background ``bkg`` are computed according to:

        .. math::

            \overline{\mathrm{n_{on}}}_k =  \mathrm{n_{on}}_{1k} \cdot \epsilon_{1k} +
             \mathrm{n_{on}}_{2k} \cdot \epsilon_{2k}

            \overline{bkg}_k = bkg_{1k} \cdot \epsilon_{1k} +
             bkg_{2k} \cdot \epsilon_{2k}

        The stacked ``safe_mask`` is then:

        .. math::

            \overline{\epsilon_k} = \epsilon_{1k} OR \epsilon_{2k}


        Parameters
        ----------
        other: `~gammapy.datasets.MapDataset` or `~gammapy.datasets.MapDatasetOnOff`
            Map dataset to be stacked with this one. If other is an on-off
            dataset alpha * counts_off is used as a background model.
        nan_to_num: bool
            Non-finite values are replaced by zero if True (default).

        """
        if self.counts and other.counts:
            self.counts.stack(
                other.counts, weights=other.mask_safe, nan_to_num=nan_to_num
            )

        if self.exposure and other.exposure:
            self.exposure.stack(
                other.exposure, weights=other.mask_safe_image, nan_to_num=nan_to_num
            )
            # TODO: check whether this can be improved e.g. handling this in GTI

            if "livetime" in other.exposure.meta and np.any(other.mask_safe_image):
                if "livetime" in self.exposure.meta:
                    self.exposure.meta["livetime"] += other.exposure.meta["livetime"]
                else:
                    self.exposure.meta["livetime"] = other.exposure.meta[
                        "livetime"
                    ].copy()

        if self.stat_type == "cash":
            if self.background and other.background:
                background = self.npred_background() * self.mask_safe
                background.stack(
                    other.npred_background(),
                    weights=other.mask_safe,
                    nan_to_num=nan_to_num,
                )
                self.background = background

        if self.psf and other.psf:
            self.psf.stack(other.psf, weights=other.mask_safe_psf)

        if self.edisp and other.edisp:
            self.edisp.stack(other.edisp, weights=other.mask_safe_edisp)

        if self.mask_safe and other.mask_safe:
            self.mask_safe.stack(other.mask_safe)

        if self.mask_fit and other.mask_fit:
            self.mask_fit.stack(other.mask_fit)
        elif other.mask_fit:
            self.mask_fit = other.mask_fit.copy()

        if self.gti and other.gti:
            self.gti.stack(other.gti)
            self.gti = self.gti.union()

        if self.meta_table and other.meta_table:
            self.meta_table = hstack_columns(self.meta_table, other.meta_table)
        elif other.meta_table:
            self.meta_table = other.meta_table.copy()

    def stat_array(self):
        """Likelihood per bin given the current model parameters"""
        return cash(n_on=self.counts.data, mu_on=self.npred().data)

    def residuals(self, method="diff", **kwargs):
        """Compute residuals map.

        Parameters
        ----------
        method: {"diff", "diff/model", "diff/sqrt(model)"}
            Method used to compute the residuals. Available options are:
                - "diff" (default): data - model
                - "diff/model": (data - model) / model
                - "diff/sqrt(model)": (data - model) / sqrt(model)
        **kwargs : dict
            Keyword arguments forwarded to `Map.smooth()`

        Returns
        -------
        residuals : `gammapy.maps.Map`
            Residual map.
        """
        npred, counts = self.npred(), self.counts.copy()

        if self.mask:
            npred = npred * self.mask
            counts = counts * self.mask

        if kwargs:
            kwargs.setdefault("mode", "constant")
            kwargs.setdefault("width", "0.1 deg")
            kwargs.setdefault("kernel", "gauss")
            with np.errstate(invalid="ignore", divide="ignore"):
                npred = npred.smooth(**kwargs)
                counts = counts.smooth(**kwargs)
                if self.mask:
                    mask = self.mask.smooth(**kwargs)
                    npred /= mask
                    counts /= mask

        residuals = self._compute_residuals(counts, npred, method=method)

        if self.mask:
            residuals.data[~self.mask.data] = np.nan

        return residuals

    def plot_residuals_spatial(
        self,
        ax=None,
        method="diff",
        smooth_kernel="gauss",
        smooth_radius="0.1 deg",
        **kwargs,
    ):
        """Plot spatial residuals.

        The normalization used for the residuals computation can be controlled
        using the method parameter.

        Parameters
        ----------
        ax : `~astropy.visualization.wcsaxes.WCSAxes`
            Axes to plot on.
        method : {"diff", "diff/model", "diff/sqrt(model)"}
            Normalization used to compute the residuals, see `MapDataset.residuals`.
        smooth_kernel : {"gauss", "box"}
            Kernel shape.
        smooth_radius: `~astropy.units.Quantity`, str or float
            Smoothing width given as quantity or float. If a float is given, it
            is interpreted as smoothing width in pixels.
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.axes.Axes.imshow`.

        Returns
        -------
        ax : `~astropy.visualization.wcsaxes.WCSAxes`
            WCSAxes object.

        Examples
        --------
        >>> from gammapy.datasets import MapDataset
        >>> dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz")
        >>> kwargs = {"cmap": "RdBu_r", "vmin":-5, "vmax":5, "add_cbar": True}
        >>> dataset.plot_residuals_spatial(method="diff/sqrt(model)", **kwargs) # doctest: +SKIP
        """
        counts, npred = self.counts.copy(), self.npred()

        if counts.geom.is_region:
            raise ValueError("Cannot plot spatial residuals for RegionNDMap")

        if self.mask is not None:
            counts *= self.mask
            npred *= self.mask

        counts_spatial = counts.sum_over_axes().smooth(
            width=smooth_radius, kernel=smooth_kernel
        )
        npred_spatial = npred.sum_over_axes().smooth(
            width=smooth_radius, kernel=smooth_kernel
        )
        residuals = self._compute_residuals(counts_spatial, npred_spatial, method)

        if self.mask_safe is not None:
            mask = self.mask_safe.reduce_over_axes(func=np.logical_or, keepdims=True)
            residuals.data[~mask.data] = np.nan

        kwargs.setdefault("add_cbar", True)
        kwargs.setdefault("cmap", "coolwarm")
        kwargs.setdefault("vmin", -5)
        kwargs.setdefault("vmax", 5)
        ax = residuals.plot(ax, **kwargs)
        return ax

    def plot_residuals_spectral(self, ax=None, method="diff", region=None, **kwargs):
        """Plot spectral residuals.

        The residuals are extracted from the provided region, and the normalization
        used for its computation can be controlled using the method parameter.

        The error bars are computed using the uncertainty on the excess with a symmetric assumption.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            Axes to plot on.
        method : {"diff", "diff/sqrt(model)"}
            Normalization used to compute the residuals, see `SpectrumDataset.residuals`.
        region: `~regions.SkyRegion` (required)
            Target sky region.
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.axes.Axes.errorbar`.

        Returns
        -------
        ax : `~matplotlib.axes.Axes`
            Axes object.

        Examples
        --------
        >>> from gammapy.datasets import MapDataset
        >>> dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz")
        >>> kwargs = {"markerfacecolor": "blue", "markersize":8, "marker":'s'}
        >>> dataset.plot_residuals_spectral(method="diff/sqrt(model)", **kwargs) # doctest: +SKIP

        """
        counts, npred = self.counts.copy(), self.npred()

        if self.mask is None:
            mask = self.counts.copy()
            mask.data = 1
        else:
            mask = self.mask
        counts *= mask
        npred *= mask

        counts_spec = counts.get_spectrum(region)
        npred_spec = npred.get_spectrum(region)
        residuals = self._compute_residuals(counts_spec, npred_spec, method)

        if self.stat_type == "wstat":
            counts_off = (self.counts_off * mask).get_spectrum(region)

            with np.errstate(invalid="ignore"):
                alpha = (self.background * mask).get_spectrum(region) / counts_off

            mu_sig = (self.npred_signal() * mask).get_spectrum(region)
            stat = WStatCountsStatistic(
                n_on=counts_spec,
                n_off=counts_off,
                alpha=alpha,
                mu_sig=mu_sig,
            )
        elif self.stat_type == "cash":
            stat = CashCountsStatistic(counts_spec.data, npred_spec.data)
        else:
            # Static analysis flagged that `stat` was not defined for all
            # execution paths; fail explicitly for unsupported stat types.
            raise ValueError(f"Unsupported stat type: {self.stat_type}")

        excess_error = stat.error

        if method == "diff":
            yerr = excess_error
        elif method == "diff/sqrt(model)":
            yerr = excess_error / np.sqrt(npred_spec.data)
        else:
            raise ValueError(
                'Invalid method, choose between "diff" and "diff/sqrt(model)"'
            )

        kwargs.setdefault("color", kwargs.pop("c", "black"))
        ax = residuals.plot(ax, yerr=yerr, **kwargs)
        ax.axhline(0, color=kwargs["color"], lw=0.5)

        label = self._residuals_labels[method]
        ax.set_ylabel(f"Residuals ({label})")
        ax.set_yscale("linear")
        ymin = 1.05 * np.nanmin(residuals.data - yerr)
        ymax = 1.05 * np.nanmax(residuals.data + yerr)
        ax.set_ylim(ymin, ymax)
        return ax

    def plot_residuals(
        self,
        ax_spatial=None,
        ax_spectral=None,
        kwargs_spatial=None,
        kwargs_spectral=None,
    ):
        """Plot spatial and spectral residuals in two panels.

        Calls `~MapDataset.plot_residuals_spatial` and `~MapDataset.plot_residuals_spectral`.
        The spectral residuals are extracted from the provided region, and the
        normalization used for its computation can be controlled using the method
        parameter. The region outline is overlaid on the residuals map. If no region is passed,
        the residuals are computed for the entire map.

        Parameters
        ----------
        ax_spatial : `~astropy.visualization.wcsaxes.WCSAxes`
            Axes to plot spatial residuals on.
        ax_spectral : `~matplotlib.axes.Axes`
            Axes to plot spectral residuals on.
        kwargs_spatial : dict
            Keyword arguments passed to `~MapDataset.plot_residuals_spatial`.
        kwargs_spectral : dict
            Keyword arguments passed to `~MapDataset.plot_residuals_spectral`.
            The region should be passed as a dictionary key

        Returns
        -------
        ax_spatial, ax_spectral : `~astropy.visualization.wcsaxes.WCSAxes`, `~matplotlib.axes.Axes`
            Spatial and spectral residuals plots.

        Examples
        --------
        >>> from regions import CircleSkyRegion
        >>> from astropy.coordinates import SkyCoord
        >>> import astropy.units as u
        >>> from gammapy.datasets import MapDataset
        >>> dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz")
        >>> reg = CircleSkyRegion(SkyCoord(0,0, unit="deg", frame="galactic"), radius=1.0 * u.deg)
        >>> kwargs_spatial = {"cmap": "RdBu_r", "vmin":-5, "vmax":5, "add_cbar": True}
        >>> kwargs_spectral = {"region":reg, "markerfacecolor": "blue", "markersize": 8, "marker": "s"}  # noqa: E501
        >>> dataset.plot_residuals(kwargs_spatial=kwargs_spatial, kwargs_spectral=kwargs_spectral) # doctest: +SKIP noqa: E501
        """
        ax_spatial, ax_spectral = get_axes(
            ax_spatial,
            ax_spectral,
            12,
            4,
            [1, 2, 1],
            [1, 2, 2],
            {"projection": self._geom.to_image().wcs},
        )
        kwargs_spatial = kwargs_spatial or {}
        kwargs_spectral = kwargs_spectral or {}

        self.plot_residuals_spatial(ax_spatial, **kwargs_spatial)
        self.plot_residuals_spectral(ax_spectral, **kwargs_spectral)

        # Overlay spectral extraction region on the spatial residuals
        region = kwargs_spectral.get("region")
        if region is not None:
            pix_region = region.to_pixel(self._geom.to_image().wcs)
            pix_region.plot(ax=ax_spatial)

        return ax_spatial, ax_spectral

    def stat_sum(self):
        """Total likelihood given the current model parameters."""
        counts, npred = self.counts.data.astype(float), self.npred().data

        if self.mask is not None:
            return cash_sum_cython(counts[self.mask.data], npred[self.mask.data])
        else:
            return cash_sum_cython(counts.ravel(), npred.ravel())

    def fake(self, random_state="random-seed"):
        """Simulate fake counts for the current model and reduced IRFs.

        This method overwrites the counts defined on the dataset object.

        Parameters
        ----------
        random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
            Defines random number generator initialisation.
            Passed to `~gammapy.utils.random.get_random_state`.
        """
        random_state = get_random_state(random_state)
        npred = self.npred()
        data = np.nan_to_num(npred.data, copy=True, nan=0.0, posinf=0.0, neginf=0.0)
        npred.data = random_state.poisson(data)
        self.counts = npred

    def to_hdulist(self):
        """Convert map dataset to list of HDUs.

        Returns
        -------
        hdulist : `~astropy.io.fits.HDUList`
            Map dataset list of HDUs.
        """
        # TODO: what to do about the model and background model parameters?
        exclude_primary = slice(1, None)

        hdu_primary = fits.PrimaryHDU()

        header = hdu_primary.header
        header["NAME"] = self.name

        hdulist = fits.HDUList([hdu_primary])
        if self.counts is not None:
            hdulist += self.counts.to_hdulist(hdu="counts")[exclude_primary]

        if self.exposure is not None:
            hdulist += self.exposure.to_hdulist(hdu="exposure")[exclude_primary]

        if self.background is not None:
            hdulist += self.background.to_hdulist(hdu="background")[exclude_primary]

        if self.edisp is not None:
            hdulist += self.edisp.to_hdulist()[exclude_primary]

        if self.psf is not None:
            hdulist += self.psf.to_hdulist()[exclude_primary]

        if self.mask_safe is not None:
            hdulist += self.mask_safe.to_hdulist(hdu="mask_safe")[exclude_primary]

        if self.mask_fit is not None:
            hdulist += self.mask_fit.to_hdulist(hdu="mask_fit")[exclude_primary]

        if self.gti is not None:
            hdulist.append(fits.BinTableHDU(self.gti.table, name="GTI"))

        if self.meta_table is not None:
            hdulist.append(fits.BinTableHDU(self.meta_table, name="META_TABLE"))

        return hdulist

    @classmethod
    def from_hdulist(cls, hdulist, name=None, lazy=False, format="gadf"):
        """Create map dataset from list of HDUs.

        Parameters
        ----------
        hdulist : `~astropy.io.fits.HDUList`
            List of HDUs.
        name : str
            Name of the new dataset.
        format : {"gadf"}
            Format the hdulist is given in.

        Returns
        -------
        dataset : `MapDataset`
            Map dataset.
        """
        name = make_name(name)
        kwargs = {"name": name}

        if "COUNTS" in hdulist:
            kwargs["counts"] = Map.from_hdulist(hdulist, hdu="counts", format=format)

        if "EXPOSURE" in hdulist:
            exposure = Map.from_hdulist(hdulist, hdu="exposure", format=format)
            if exposure.geom.axes[0].name == "energy":
                exposure.geom.axes[0].name = "energy_true"
            kwargs["exposure"] = exposure

        if "BACKGROUND" in hdulist:
            kwargs["background"] = Map.from_hdulist(
                hdulist, hdu="background", format=format
            )

        # Note: the analyzer flags this block as duplicated elsewhere in the
        # project (see also `MapDatasetOnOff.from_hdulist`).
        if "EDISP" in hdulist:
            edisp_map = Map.from_hdulist(hdulist, hdu="edisp", format=format)

            try:
                exposure_map = Map.from_hdulist(
                    hdulist, hdu="edisp_exposure", format=format
                )
            except KeyError:
                exposure_map = None

            if edisp_map.geom.axes[0].name == "energy":
                kwargs["edisp"] = EDispKernelMap(edisp_map, exposure_map)
            else:
                kwargs["edisp"] = EDispMap(edisp_map, exposure_map)

        if "PSF" in hdulist:
            psf_map = Map.from_hdulist(hdulist, hdu="psf", format=format)
            try:
                exposure_map = Map.from_hdulist(
                    hdulist, hdu="psf_exposure", format=format
                )
            except KeyError:
                exposure_map = None
            kwargs["psf"] = PSFMap(psf_map, exposure_map)

        if "MASK_SAFE" in hdulist:
            mask_safe = Map.from_hdulist(hdulist, hdu="mask_safe", format=format)
            mask_safe.data = mask_safe.data.astype(bool)
            kwargs["mask_safe"] = mask_safe

        if "MASK_FIT" in hdulist:
            mask_fit = Map.from_hdulist(hdulist, hdu="mask_fit", format=format)
            mask_fit.data = mask_fit.data.astype(bool)
            kwargs["mask_fit"] = mask_fit

        if "GTI" in hdulist:
            gti = GTI(Table.read(hdulist, hdu="GTI"))
            kwargs["gti"] = gti

        if "META_TABLE" in hdulist:
            meta_table = Table.read(hdulist, hdu="META_TABLE")
            kwargs["meta_table"] = meta_table

        return cls(**kwargs)

    def write(self, filename, overwrite=False):
        """Write Dataset to file.

        A MapDataset is serialised using the GADF format with a WCS geometry.
        A SpectrumDataset uses the same format, with a RegionGeom.

        Parameters
        ----------
        filename : str
            Filename to write to.
        overwrite : bool
            Overwrite file if it exists.
        """
        self.to_hdulist().writeto(str(make_path(filename)), overwrite=overwrite)

    @classmethod
    def _read_lazy(cls, name, filename, cache, format=format):
        name = make_name(name)
        kwargs = {"name": name}
        try:
            kwargs["gti"] = GTI.read(filename)
        except KeyError:
            pass

        path = make_path(filename)
        for hdu_name in ["counts", "exposure", "mask_fit", "mask_safe", "background"]:
            kwargs[hdu_name] = HDULocation(
                hdu_class="map",
                file_dir=path.parent,
                file_name=path.name,
                hdu_name=hdu_name.upper(),
                cache=cache,
                format=format,
            )

        kwargs["edisp"] = HDULocation(
            hdu_class="edisp_kernel_map",
            file_dir=path.parent,
            file_name=path.name,
            hdu_name="EDISP",
            cache=cache,
            format=format,
        )

        kwargs["psf"] = HDULocation(
            hdu_class="psf_map",
            file_dir=path.parent,
            file_name=path.name,
            hdu_name="PSF",
            cache=cache,
            format=format,
        )

        return cls(**kwargs)

    @classmethod
    def read(cls, filename, name=None, lazy=False, cache=True, format="gadf"):
        """Read a dataset from file.

        Parameters
        ----------
        filename : str
            Filename to read from.
        name : str
            Name of the new dataset.
        lazy : bool
            Whether to lazily load data into memory.
        cache : bool
            Whether to cache the data after loading.
        format : {"gadf"}
            Format of the dataset file.

        Returns
        -------
        dataset : `MapDataset`
            Map dataset.
        """

        if name is None:
            header = fits.getheader(str(make_path(filename)))
            name = header.get("NAME", name)
        ds_name = make_name(name)

        if lazy:
            return cls._read_lazy(
                name=ds_name, filename=filename, cache=cache, format=format
            )
        else:
            with fits.open(str(make_path(filename)), memmap=False) as hdulist:
                return cls.from_hdulist(hdulist, name=ds_name, format=format)

    @classmethod
    def from_dict(cls, data, lazy=False, cache=True):
        """Create from dicts and models list generated from YAML serialization."""
        filename = make_path(data["filename"])
        dataset = cls.read(filename, name=data["name"], lazy=lazy, cache=cache)
        return dataset

    @property
1333
    def _counts_statistic(self):
1334
        """Counts statistics of the dataset."""
1335
        return CashCountsStatistic(self.counts, self.background)
1336

    def info_dict(self, in_safe_data_range=True):
        """Info dict with summary statistics, summed over energy.

        Parameters
        ----------
        in_safe_data_range : bool
            Whether to sum only in the safe energy range.

        Returns
        -------
        info_dict : dict
            Dictionary with summary info.
        """
        info = {}
        info["name"] = self.name

        if self.mask_safe and in_safe_data_range:
            mask = self.mask_safe.data.astype(bool)
        else:
            mask = slice(None)

        counts = 0
        background, excess, sqrt_ts = np.nan, np.nan, np.nan
        if self.counts:
            summed_stat = self._counts_statistic[mask].sum()
            counts = summed_stat.n_on

            if self.background:
                background = summed_stat.n_bkg
                excess = summed_stat.n_sig
                sqrt_ts = summed_stat.sqrt_ts

        info["counts"] = int(counts)
        info["excess"] = float(excess)
        info["sqrt_ts"] = sqrt_ts
        info["background"] = float(background)

        npred = np.nan
        if self.models or not np.isnan(background):
            npred = self.npred().data[mask].sum()

        info["npred"] = float(npred)

        npred_background = np.nan
        if self.background:
            npred_background = self.npred_background().data[mask].sum()

        info["npred_background"] = float(npred_background)

        npred_signal = np.nan
        if self.models:
            npred_signal = self.npred_signal().data[mask].sum()

        info["npred_signal"] = float(npred_signal)

        exposure_min = np.nan * u.Unit("cm s")
        exposure_max = np.nan * u.Unit("cm s")
        livetime = np.nan * u.s

        if self.exposure is not None:
            mask_exposure = self.exposure.data > 0

            if self.mask_safe is not None:
                mask_spatial = self.mask_safe.reduce_over_axes(func=np.logical_or).data
                mask_exposure = mask_exposure & mask_spatial[np.newaxis, :, :]

            if not mask_exposure.any():
                mask_exposure = slice(None)

            exposure_min = np.min(self.exposure.quantity[mask_exposure])
            exposure_max = np.max(self.exposure.quantity[mask_exposure])
            livetime = self.exposure.meta.get("livetime", np.nan * u.s).copy()

        info["exposure_min"] = exposure_min.item()
        info["exposure_max"] = exposure_max.item()
        info["livetime"] = livetime

        ontime = u.Quantity(np.nan, "s")
        if self.gti:
            ontime = self.gti.time_sum

        info["ontime"] = ontime

        info["counts_rate"] = info["counts"] / info["livetime"]
        info["background_rate"] = info["background"] / info["livetime"]
        info["excess_rate"] = info["excess"] / info["livetime"]

        # data section
        n_bins = 0
        if self.counts is not None:
            n_bins = self.counts.data.size
        info["n_bins"] = int(n_bins)

        n_fit_bins = 0
        if self.mask is not None:
            n_fit_bins = np.sum(self.mask.data)

        info["n_fit_bins"] = int(n_fit_bins)
        info["stat_type"] = self.stat_type

        stat_sum = np.nan
        if self.counts is not None and self.models is not None:
            stat_sum = self.stat_sum()

        info["stat_sum"] = float(stat_sum)

        return info
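
    # Usage sketch (illustrative): a quick summary of the dataset, restricted
    # to the safe energy range. Assumes `dataset` was read as in the sketch
    # above.
    #
    #     info = dataset.info_dict(in_safe_data_range=True)
    #     print(info["counts"], info["excess"], info["sqrt_ts"])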

    def to_spectrum_dataset(self, on_region, containment_correction=False, name=None):
        """Return a `~gammapy.datasets.SpectrumDataset` from on_region.

        Counts and background are summed in the on_region. Exposure is taken
        from the average exposure.

        The energy dispersion kernel is obtained at the on_region center.
        Only regions with centers are supported.

        The model is not exported to the `~gammapy.datasets.SpectrumDataset`.
        It must be set after the dataset extraction.

        Parameters
        ----------
        on_region : `~regions.SkyRegion`
            The input ON region on which to extract the spectrum.
        containment_correction : bool
            Apply containment correction for point sources and circular on regions.
        name : str
            Name of the new dataset.

        Returns
        -------
        dataset : `~gammapy.datasets.SpectrumDataset`
            The resulting reduced dataset.
        """
        from .spectrum import SpectrumDataset

        dataset = self.to_region_map_dataset(region=on_region, name=name)

        if containment_correction:
            if not isinstance(on_region, CircleSkyRegion):
                raise TypeError(
                    "Containment correction is only supported for `CircleSkyRegion`."
                )
            elif self.psf is None or isinstance(self.psf, PSFKernel):
                raise ValueError("No PSFMap set. Containment correction impossible")
            else:
                geom = dataset.exposure.geom
                energy_true = geom.axes["energy_true"].center
                containment = self.psf.containment(
                    position=on_region.center,
                    energy_true=energy_true,
                    rad=on_region.radius,
                )
                dataset.exposure.quantity *= containment.reshape(geom.data_shape)

        kwargs = {"name": name}

        for key in [
            "counts",
            "edisp",
            "mask_safe",
            "mask_fit",
            "exposure",
            "gti",
            "meta_table",
        ]:
            kwargs[key] = getattr(dataset, key)

        if self.stat_type == "cash":
            kwargs["background"] = dataset.background

        return SpectrumDataset(**kwargs)
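
    # Usage sketch (illustrative): spectrum extraction from a circular ON
    # region. The region center and radius are placeholder values, and the
    # containment correction assumes a PSFMap is set on the dataset.
    #
    #     from astropy import units as u
    #     from astropy.coordinates import SkyCoord
    #     from regions import CircleSkyRegion
    #
    #     center = SkyCoord(0, 0, unit="deg", frame="galactic")
    #     on_region = CircleSkyRegion(center=center, radius=0.1 * u.deg)
    #     spectrum = dataset.to_spectrum_dataset(
    #         on_region, containment_correction=True
    #     )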
1509
1510
    def to_region_map_dataset(self, region, name=None):
1511
        """Integrate the map dataset in a given region.
1512
1513
        Counts and background of the dataset are integrated in the given region,
1514
        taking the safe mask into accounts. The exposure is averaged in the
1515
        region again taking the safe mask into account. The PSF and energy
1516
        dispersion kernel are taken at the center of the region.
1517
1518
        Parameters
1519
        ----------
1520
        region : `~regions.SkyRegion`
1521
            Region from which to extract the spectrum
1522
        name : str
1523
            Name of the new dataset.
1524
1525
        Returns
1526
        -------
1527
        dataset : `~gammapy.datasets.MapDataset`
1528
            the resulting reduced dataset
1529
        """
1530
        name = make_name(name)
1531
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}
1532
1533
        if self.mask_safe:
1534
            kwargs["mask_safe"] = self.mask_safe.to_region_nd_map(region, func=np.any)
1535
1536
        if self.mask_fit:
1537
            kwargs["mask_fit"] = self.mask_fit.to_region_nd_map(region, func=np.any)
1538
1539
        if self.counts:
1540
            kwargs["counts"] = self.counts.to_region_nd_map(
1541
                region, np.sum, weights=self.mask_safe
1542
            )
1543
1544
        if self.stat_type == "cash" and self.background:
1545
            kwargs["background"] = self.background.to_region_nd_map(
1546
                region, func=np.sum, weights=self.mask_safe
1547
            )
1548
1549
        if self.exposure:
1550
            kwargs["exposure"] = self.exposure.to_region_nd_map(region, func=np.mean)
1551
1552
        region = region.center if region else None
1553
1554
        # TODO: Compute average psf in region
1555
        if self.psf:
1556
            kwargs["psf"] = self.psf.to_region_nd_map(region)
1557
1558
        # TODO: Compute average edisp in region
1559
        if self.edisp is not None:
1560
            kwargs["edisp"] = self.edisp.to_region_nd_map(region)
1561
1562
        return self.__class__(**kwargs)

    def cutout(self, position, width, mode="trim", name=None):
        """Cutout map dataset.

        Parameters
        ----------
        position : `~astropy.coordinates.SkyCoord`
            Center position of the cutout region.
        width : tuple of `~astropy.coordinates.Angle`
            Angular sizes of the region in (lon, lat) in that specific order.
            If only one value is passed, a square region is extracted.
        mode : {'trim', 'partial', 'strict'}
            Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.
        name : str
            Name of the new dataset.

        Returns
        -------
        cutout : `MapDataset`
            Cutout map dataset.
        """
        name = make_name(name)
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}
        cutout_kwargs = {"position": position, "width": width, "mode": mode}

        if self.counts is not None:
            kwargs["counts"] = self.counts.cutout(**cutout_kwargs)

        if self.exposure is not None:
            kwargs["exposure"] = self.exposure.cutout(**cutout_kwargs)

        if self.background is not None and self.stat_type == "cash":
            kwargs["background"] = self.background.cutout(**cutout_kwargs)

        if self.edisp is not None:
            kwargs["edisp"] = self.edisp.cutout(**cutout_kwargs)

        if self.psf is not None:
            kwargs["psf"] = self.psf.cutout(**cutout_kwargs)

        if self.mask_safe is not None:
            kwargs["mask_safe"] = self.mask_safe.cutout(**cutout_kwargs)

        if self.mask_fit is not None:
            kwargs["mask_fit"] = self.mask_fit.cutout(**cutout_kwargs)

        return self.__class__(**kwargs)
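
    # Usage sketch (illustrative): extract a 2 deg x 2 deg cutout around the
    # Galactic center; position and width are placeholder values.
    #
    #     from astropy import units as u
    #     from astropy.coordinates import SkyCoord
    #
    #     position = SkyCoord(0, 0, unit="deg", frame="galactic")
    #     cutout = dataset.cutout(position=position, width=2 * u.deg)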

    def downsample(self, factor, axis_name=None, name=None):
        """Downsample map dataset.

        The PSFMap and EDispKernelMap are not downsampled, except if
        a corresponding axis is given.

        Parameters
        ----------
        factor : int
            Downsampling factor.
        axis_name : str
            Which non-spatial axis to downsample. By default only spatial axes are downsampled.
        name : str
            Name of the downsampled dataset.

        Returns
        -------
        dataset : `MapDataset` or `SpectrumDataset`
            Downsampled map dataset.
        """
        name = make_name(name)

        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}

        if self.counts is not None:
            kwargs["counts"] = self.counts.downsample(
                factor=factor,
                preserve_counts=True,
                axis_name=axis_name,
                weights=self.mask_safe,
            )

        if self.exposure is not None:
            if axis_name is None:
                kwargs["exposure"] = self.exposure.downsample(
                    factor=factor, preserve_counts=False, axis_name=None
                )
            else:
                kwargs["exposure"] = self.exposure.copy()

        if self.background is not None and self.stat_type == "cash":
            kwargs["background"] = self.background.downsample(
                factor=factor, axis_name=axis_name, weights=self.mask_safe
            )

        if self.edisp is not None:
            if axis_name is not None:
                kwargs["edisp"] = self.edisp.downsample(
                    factor=factor, axis_name=axis_name, weights=self.mask_safe_edisp
                )
            else:
                kwargs["edisp"] = self.edisp.copy()

        if self.psf is not None:
            kwargs["psf"] = self.psf.copy()

        if self.mask_safe is not None:
            kwargs["mask_safe"] = self.mask_safe.downsample(
                factor=factor, preserve_counts=False, axis_name=axis_name
            )

        if self.mask_fit is not None:
            kwargs["mask_fit"] = self.mask_fit.downsample(
                factor=factor, preserve_counts=False, axis_name=axis_name
            )

        return self.__class__(**kwargs)
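
    # Usage sketch (illustrative): downsample the spatial axes by a factor of
    # 2, then rebin the reconstructed energy axis by a factor of 2 as well.
    #
    #     spatial = dataset.downsample(factor=2)
    #     rebinned = spatial.downsample(factor=2, axis_name="energy")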

    def pad(self, pad_width, mode="constant", name=None):
        """Pad the spatial dimensions of the dataset.

        The padding only applies to counts, masks, background and exposure.

        Counts, background and masks are padded with zeros, exposure is padded with edge value.

        Parameters
        ----------
        pad_width : {sequence, array_like, int}
            Number of pixels padded to the edges of each axis.
        mode : str
            Pad mode, passed to the ``pad`` method of the individual maps.
        name : str
            Name of the padded dataset.

        Returns
        -------
        dataset : `MapDataset`
            Padded map dataset.

        """
        name = make_name(name)
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}

        if self.counts is not None:
            kwargs["counts"] = self.counts.pad(pad_width=pad_width, mode=mode)

        if self.exposure is not None:
            kwargs["exposure"] = self.exposure.pad(pad_width=pad_width, mode=mode)

        if self.background is not None:
            kwargs["background"] = self.background.pad(pad_width=pad_width, mode=mode)

        if self.edisp is not None:
            kwargs["edisp"] = self.edisp.copy()

        if self.psf is not None:
            kwargs["psf"] = self.psf.copy()

        if self.mask_safe is not None:
            kwargs["mask_safe"] = self.mask_safe.pad(pad_width=pad_width, mode=mode)

        if self.mask_fit is not None:
            kwargs["mask_fit"] = self.mask_fit.pad(pad_width=pad_width, mode=mode)

        return self.__class__(**kwargs)
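
    # Usage sketch (illustrative): pad the spatial dimensions by 5 pixels on
    # each side, keeping the default zero / edge-value behaviour.
    #
    #     padded = dataset.pad(pad_width=5)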

    def slice_by_idx(self, slices, name=None):
        """Slice sub dataset.

        The slicing only applies to the maps that define the corresponding axes.

        Parameters
        ----------
        slices : dict
            Dict of axes names and integers or `slice` object pairs. Contains one
            element for each non-spatial dimension. For integer indexing the
            corresponding axes is dropped from the map. Axes not specified in the
            dict are kept unchanged.
        name : str
            Name of the sliced dataset.

        Returns
        -------
        dataset : `MapDataset` or `SpectrumDataset`
            Sliced dataset.

        Examples
        --------
        >>> from gammapy.datasets import MapDataset
        >>> dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz")
        >>> slices = {"energy": slice(0, 3)}  # to get the first 3 energy slices
        >>> sliced = dataset.slice_by_idx(slices)
        >>> print(sliced.geoms["geom"])
        WcsGeom
                axes       : ['lon', 'lat', 'energy']
                shape      : (320, 240, 3)
                ndim       : 3
                frame      : galactic
                projection : CAR
                center     : 0.0 deg, 0.0 deg
                width      : 8.0 deg x 6.0 deg
                wcs ref    : 0.0 deg, 0.0 deg
        """
        name = make_name(name)
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}

        if self.counts is not None:
            kwargs["counts"] = self.counts.slice_by_idx(slices=slices)

        if self.exposure is not None:
            kwargs["exposure"] = self.exposure.slice_by_idx(slices=slices)

        if self.background is not None and self.stat_type == "cash":
            kwargs["background"] = self.background.slice_by_idx(slices=slices)

        if self.edisp is not None:
            kwargs["edisp"] = self.edisp.slice_by_idx(slices=slices)

        if self.psf is not None:
            kwargs["psf"] = self.psf.slice_by_idx(slices=slices)

        if self.mask_safe is not None:
            kwargs["mask_safe"] = self.mask_safe.slice_by_idx(slices=slices)

        if self.mask_fit is not None:
            kwargs["mask_fit"] = self.mask_fit.slice_by_idx(slices=slices)

        return self.__class__(**kwargs)

    def slice_by_energy(self, energy_min=None, energy_max=None, name=None):
        """Select and slice datasets in energy range.

        Parameters
        ----------
        energy_min, energy_max : `~astropy.units.Quantity`
            Energy bounds to compute the flux point for.
        name : str
            Name of the sliced dataset.

        Returns
        -------
        dataset : `MapDataset`
            Sliced dataset.

        Examples
        --------
        >>> from gammapy.datasets import MapDataset
        >>> dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz")
        >>> sliced = dataset.slice_by_energy(energy_min="1 TeV", energy_max="5 TeV")
        >>> sliced.data_shape
        (3, 240, 320)
        """
        name = make_name(name)

        energy_axis = self._geom.axes["energy"]

        if energy_min is None:
            energy_min = energy_axis.bounds[0]

        if energy_max is None:
            energy_max = energy_axis.bounds[1]

        energy_min, energy_max = u.Quantity(energy_min), u.Quantity(energy_max)

        group = energy_axis.group_table(edges=[energy_min, energy_max])

        # The "bin_type" column holds fixed-width strings, hence the trailing
        # spaces in the comparison below.
        is_normal = group["bin_type"] == "normal   "
        group = group[is_normal]

        slices = {
            "energy": slice(int(group["idx_min"][0]), int(group["idx_max"][0]) + 1)
        }

        return self.slice_by_idx(slices, name=name)

    def reset_data_cache(self):
        """Reset data cache to free memory space."""
        for name in self._lazy_data_members:
            if self.__dict__.pop(name, False):
                log.info(f"Clearing {name} cache for dataset {self.name}")

    def resample_energy_axis(self, energy_axis, name=None):
        """Resample MapDataset over new reco energy axis.

        Counts are summed taking into account safe mask.

        Parameters
        ----------
        energy_axis : `~gammapy.maps.MapAxis`
            New reconstructed energy axis.
        name : str
            Name of the new dataset.

        Returns
        -------
        dataset : `MapDataset` or `SpectrumDataset`
            Resampled dataset.
        """
        name = make_name(name)
        kwargs = {"gti": self.gti, "name": name, "meta_table": self.meta_table}

        if self.exposure:
            kwargs["exposure"] = self.exposure

        if self.psf:
            kwargs["psf"] = self.psf

        if self.mask_safe is not None:
            kwargs["mask_safe"] = self.mask_safe.resample_axis(
                axis=energy_axis, ufunc=np.logical_or
            )

        if self.mask_fit is not None:
            kwargs["mask_fit"] = self.mask_fit.resample_axis(
                axis=energy_axis, ufunc=np.logical_or
            )

        if self.counts is not None:
            kwargs["counts"] = self.counts.resample_axis(
                axis=energy_axis, weights=self.mask_safe
            )

        if self.background is not None and self.stat_type == "cash":
            kwargs["background"] = self.background.resample_axis(
                axis=energy_axis, weights=self.mask_safe
            )

        # TODO: weight with mask_safe or mask_irf?
        if isinstance(self.edisp, EDispKernelMap):
            kwargs["edisp"] = self.edisp.resample_energy_axis(
                energy_axis=energy_axis, weights=self.mask_safe_edisp
            )
        else:  # None or EDispMap
            kwargs["edisp"] = self.edisp

        return self.__class__(**kwargs)
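
    # Usage sketch (illustrative): rebin the dataset onto a coarser
    # reconstructed energy axis with 5 bins between 0.1 and 10 TeV.
    #
    #     from gammapy.maps import MapAxis
    #
    #     energy_axis = MapAxis.from_energy_bounds("0.1 TeV", "10 TeV", nbin=5)
    #     resampled = dataset.resample_energy_axis(energy_axis=energy_axis)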

    def to_image(self, name=None):
        """Create images by summing over the reconstructed energy axis.

        Parameters
        ----------
        name : str
            Name of the new dataset.

        Returns
        -------
        dataset : `MapDataset` or `SpectrumDataset`
            Dataset integrated over non-spatial axes.
        """
        energy_axis = self._geom.axes["energy"].squash()
        return self.resample_energy_axis(energy_axis=energy_axis, name=name)

    def peek(self, figsize=(12, 10)):
        """Quick-look summary plots.

        Parameters
        ----------
        figsize : tuple
            Size of the figure.

        """

        def plot_mask(ax, mask, **kwargs):
            if mask is not None:
                mask.plot_mask(ax=ax, **kwargs)

        fig, axes = plt.subplots(
            ncols=2,
            nrows=2,
            subplot_kw={"projection": self._geom.wcs},
            figsize=figsize,
            gridspec_kw={"hspace": 0.1, "wspace": 0.1},
        )

        axes = axes.flat
        axes[0].set_title("Counts")
        self.counts.sum_over_axes().plot(ax=axes[0], add_cbar=True)
        plot_mask(ax=axes[0], mask=self.mask_fit_image, alpha=0.2)
        plot_mask(ax=axes[0], mask=self.mask_safe_image, hatches=["///"], colors="w")

        axes[1].set_title("Excess counts")
        self.excess.sum_over_axes().plot(ax=axes[1], add_cbar=True)
        plot_mask(ax=axes[1], mask=self.mask_fit_image, alpha=0.2)
        plot_mask(ax=axes[1], mask=self.mask_safe_image, hatches=["///"], colors="w")

        axes[2].set_title("Exposure")
        self.exposure.sum_over_axes().plot(ax=axes[2], add_cbar=True)
        plot_mask(ax=axes[2], mask=self.mask_safe_image, hatches=["///"], colors="w")

        axes[3].set_title("Background")
        self.background.sum_over_axes().plot(ax=axes[3], add_cbar=True)
        plot_mask(ax=axes[3], mask=self.mask_fit_image, alpha=0.2)
        plot_mask(ax=axes[3], mask=self.mask_safe_image, hatches=["///"], colors="w")


class MapDatasetOnOff(MapDataset):
    """Map dataset for on-off likelihood fitting. Uses wstat statistics.

    Parameters
    ----------
    models : `~gammapy.modeling.models.Models`
        Source sky models.
    counts : `~gammapy.maps.WcsNDMap`
        Counts cube.
    counts_off : `~gammapy.maps.WcsNDMap`
        Ring-convolved counts cube.
    acceptance : `~gammapy.maps.WcsNDMap`
        Acceptance from the IRFs.
    acceptance_off : `~gammapy.maps.WcsNDMap`
        Acceptance off.
    exposure : `~gammapy.maps.WcsNDMap`
        Exposure cube.
    mask_fit : `~gammapy.maps.WcsNDMap`
        Mask to apply to the likelihood for fitting.
    psf : `~gammapy.irf.PSFKernel`
        PSF kernel.
    edisp : `~gammapy.irf.EDispKernel`
        Energy dispersion.
    mask_safe : `~gammapy.maps.WcsNDMap`
        Mask defining the safe data range.
    gti : `~gammapy.data.GTI`
        GTI of the observation or union of GTI if it is a stacked observation.
    meta_table : `~astropy.table.Table`
        Table listing information on observations used to create the dataset.
        One line per observation for stacked datasets.
    name : str
        Name of the dataset.

    See Also
    --------
    MapDataset, SpectrumDataset, FluxPointsDataset

    """

    stat_type = "wstat"
    tag = "MapDatasetOnOff"

    def __init__(
        self,
        models=None,
        counts=None,
        counts_off=None,
        acceptance=None,
        acceptance_off=None,
        exposure=None,
        mask_fit=None,
        psf=None,
        edisp=None,
        name=None,
        mask_safe=None,
        gti=None,
        meta_table=None,
    ):
        self._name = make_name(name)
        self._evaluators = {}

        self.counts = counts
        self.counts_off = counts_off
        self.exposure = exposure
        self.acceptance = acceptance
        self.acceptance_off = acceptance_off
        self.gti = gti
        self.mask_fit = mask_fit
        self.psf = psf
        self.edisp = edisp
        self.models = models
        self.mask_safe = mask_safe
        self.meta_table = meta_table

    def __str__(self):
        str_ = super().__str__()

        counts_off = np.nan
        if self.counts_off is not None:
            counts_off = np.sum(self.counts_off.data)
        str_ += "\t{:32}: {:.0f} \n".format("Total counts_off", counts_off)

        acceptance = np.nan
        if self.acceptance is not None:
            acceptance = np.sum(self.acceptance.data)
        str_ += "\t{:32}: {:.0f} \n".format("Acceptance", acceptance)

        acceptance_off = np.nan
        if self.acceptance_off is not None:
            acceptance_off = np.sum(self.acceptance_off.data)
        str_ += "\t{:32}: {:.0f} \n".format("Acceptance off", acceptance_off)

        return str_.expandtabs(tabsize=2)

    @property
    def _geom(self):
        """Main analysis geometry"""
        if self.counts is not None:
            return self.counts.geom
        elif self.counts_off is not None:
            return self.counts_off.geom
        elif self.acceptance is not None:
            return self.acceptance.geom
        elif self.acceptance_off is not None:
            return self.acceptance_off.geom
        else:
            raise ValueError(
                "Either 'counts', 'counts_off', 'acceptance' or 'acceptance_off' must be defined."
            )

    @property
    def alpha(self):
        """Exposure ratio between signal and background regions.

        See :ref:`wstat`

        Returns
        -------
        alpha : `Map`
            Alpha map
        """
        with np.errstate(invalid="ignore", divide="ignore"):
            alpha = self.acceptance / self.acceptance_off

        alpha.data = np.nan_to_num(alpha.data)
        return alpha

    def npred_background(self):
        """Predicted background counts, estimated from the marginalized likelihood.

        See :ref:`wstat`

        Returns
        -------
        npred_background : `Map`
            Predicted background counts
        """
        mu_bkg = self.alpha.data * get_wstat_mu_bkg(
            n_on=self.counts.data,
            n_off=self.counts_off.data,
            alpha=self.alpha.data,
            mu_sig=self.npred_signal().data,
        )
        mu_bkg = np.nan_to_num(mu_bkg)
        return Map.from_geom(geom=self._geom, data=mu_bkg)

    def npred_off(self):
        """Predicted counts in the off region; mu_bkg / alpha.

        See :ref:`wstat`

        Returns
        -------
        npred_off : `Map`
            Predicted off counts
        """
        return self.npred_background() / self.alpha

    @property
    def background(self):
        """Computed as alpha * n_off.

        See :ref:`wstat`

        Returns
        -------
        background : `Map`
            Background map
        """
        if self.counts_off is None:
            return None
        return self.alpha * self.counts_off

    def stat_array(self):
        """Likelihood per bin given the current model parameters."""
        mu_sig = self.npred_signal().data
        on_stat_ = wstat(
            n_on=self.counts.data,
            n_off=self.counts_off.data,
            alpha=self.alpha.data,
            mu_sig=mu_sig,
        )
        return np.nan_to_num(on_stat_)
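
    # Usage sketch (illustrative): `wstat` also works on plain numbers, which
    # makes the per-bin statistic easy to check by hand.
    #
    #     from gammapy.stats import wstat
    #
    #     # 25 ON counts, 60 OFF counts, alpha = 0.2, predicted signal of 10
    #     stat = wstat(n_on=25, n_off=60, alpha=0.2, mu_sig=10)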

    @property
    def _counts_statistic(self):
        """Counts statistics of the dataset."""
        return WStatCountsStatistic(self.counts, self.counts_off, self.alpha)

    @classmethod
    def from_geoms(
        cls,
        geom,
        geom_exposure,
        geom_psf=None,
        geom_edisp=None,
        reference_time="2000-01-01",
        name=None,
        **kwargs,
    ):
        """Create an empty `MapDatasetOnOff` object according to the specified geometries.

        Parameters
        ----------
        geom : `gammapy.maps.WcsGeom`
            Geometry for the counts, counts_off, acceptance and acceptance_off maps.
        geom_exposure : `gammapy.maps.WcsGeom`
            Geometry for the exposure map.
        geom_psf : `gammapy.maps.WcsGeom`
            Geometry for the psf map.
        geom_edisp : `gammapy.maps.WcsGeom`
            Geometry for the energy dispersion kernel map.
            If geom_edisp has a migra axis, this will create an EDispMap instead.
        reference_time : `~astropy.time.Time`
            The reference time to use in GTI definition.
        name : str
            Name of the returned dataset.

        Returns
        -------
        empty_maps : `MapDatasetOnOff`
            A MapDatasetOnOff containing zero filled maps.
        """
        # TODO: it seems the super() pattern does not work here?
        dataset = MapDataset.from_geoms(
            geom=geom,
            geom_exposure=geom_exposure,
            geom_psf=geom_psf,
            geom_edisp=geom_edisp,
            name=name,
            reference_time=reference_time,
            **kwargs,
        )

        off_maps = {}

        for key in ["counts_off", "acceptance", "acceptance_off"]:
            off_maps[key] = Map.from_geom(geom, unit="")

        return cls.from_map_dataset(dataset, name=name, **off_maps)

    @classmethod
    def from_map_dataset(
        cls, dataset, acceptance, acceptance_off, counts_off=None, name=None
    ):
        """Create on off dataset from a map dataset.

        Parameters
        ----------
        dataset : `MapDataset`
            Spectrum dataset defining counts, edisp, aeff, livetime etc.
        acceptance : `Map`
            Relative background efficiency in the on region.
        acceptance_off : `Map`
            Relative background efficiency in the off region.
        counts_off : `Map`
            Off counts map. If the dataset provides a background model and no
            off counts are given, the off counts are derived as
            npred_background / alpha.
        name : str
            Name of the returned dataset.

        Returns
        -------
        dataset : `MapDatasetOnOff`
            Map dataset on off.

        """
        if counts_off is None and dataset.background is not None:
            alpha = acceptance / acceptance_off
            counts_off = dataset.npred_background() / alpha

        if np.isscalar(acceptance):
            acceptance = Map.from_geom(dataset._geom, data=acceptance)

        if np.isscalar(acceptance_off):
            acceptance_off = Map.from_geom(dataset._geom, data=acceptance_off)

        return cls(
            models=dataset.models,
            counts=dataset.counts,
            exposure=dataset.exposure,
            counts_off=counts_off,
            edisp=dataset.edisp,
            psf=dataset.psf,
            mask_safe=dataset.mask_safe,
            mask_fit=dataset.mask_fit,
            acceptance=acceptance,
            acceptance_off=acceptance_off,
            gti=dataset.gti,
            name=name,
            meta_table=dataset.meta_table,
        )
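
    # Usage sketch (illustrative): build an on-off dataset from an existing
    # `MapDataset` with scalar acceptances; the values below are placeholders.
    #
    #     onoff = MapDatasetOnOff.from_map_dataset(
    #         dataset, acceptance=1.0, acceptance_off=5.0
    #     )
    #     print(onoff.alpha.data.mean())  # alpha = acceptance / acceptance_off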

    def to_map_dataset(self, name=None):
        """Convert a MapDatasetOnOff to a MapDataset.

        The background model template is taken as alpha * counts_off.

        Parameters
        ----------
        name : str
            Name of the new dataset.

        Returns
        -------
        dataset : `MapDataset`
            Map dataset with cash statistics.
        """
        name = make_name(name)

        return MapDataset(
            counts=self.counts,
            exposure=self.exposure,
            psf=self.psf,
            edisp=self.edisp,
            name=name,
            gti=self.gti,
            mask_fit=self.mask_fit,
            mask_safe=self.mask_safe,
            background=self.counts_off * self.alpha,
            meta_table=self.meta_table,
        )

    @property
    def _is_stackable(self):
        """Check if the Dataset contains enough information to be stacked."""
        incomplete = (
            self.acceptance_off is None
            or self.acceptance is None
            or self.counts_off is None
        )
        unmasked = np.any(self.mask_safe.data)
        if incomplete and unmasked:
            return False
        else:
            return True

    def stack(self, other, nan_to_num=True):
        r"""Stack another dataset in place.

        The ``acceptance`` of the stacked dataset is normalized to 1,
        and the stacked ``acceptance_off`` is scaled so that:

        .. math::
            \alpha_\text{stacked} =
            \frac{1}{a_\text{off}} =
            \frac{\alpha_1\text{OFF}_1 + \alpha_2\text{OFF}_2}{\text{OFF}_1 + \text{OFF}_2}

        Parameters
        ----------
        other : `MapDatasetOnOff`
            Other dataset.
        nan_to_num : bool
            Non-finite values are replaced by zero if True (default).
        """
        if not isinstance(other, MapDatasetOnOff):
            raise TypeError("Incompatible types for MapDatasetOnOff stacking")

        if not self._is_stackable or not other._is_stackable:
            raise ValueError("Cannot stack incomplete MapDatasetOnOff.")

        geom = self.counts.geom
        total_off = Map.from_geom(geom)
        total_alpha = Map.from_geom(geom)

        if self.counts_off:
            total_off.stack(
                self.counts_off, weights=self.mask_safe, nan_to_num=nan_to_num
            )
            total_alpha.stack(
                self.alpha * self.counts_off,
                weights=self.mask_safe,
                nan_to_num=nan_to_num,
            )
        if other.counts_off:
            total_off.stack(
                other.counts_off, weights=other.mask_safe, nan_to_num=nan_to_num
            )
            total_alpha.stack(
                other.alpha * other.counts_off,
                weights=other.mask_safe,
                nan_to_num=nan_to_num,
            )

        with np.errstate(divide="ignore", invalid="ignore"):
            acceptance_off = total_off / total_alpha
            average_alpha = total_alpha.data.sum() / total_off.data.sum()

        # For the bins where the stacked OFF counts equal 0, the alpha value
        # is computed from the average alpha, weighted by the total OFF counts
        # of each run.
        is_zero = total_off.data == 0
        acceptance_off.data[is_zero] = 1 / average_alpha

        self.acceptance.data[...] = 1
        self.acceptance_off = acceptance_off

        self.counts_off = total_off

        super().stack(other, nan_to_num=nan_to_num)
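
    # Worked example (illustrative) of the stacking formula above: with
    # alpha_1 = 0.1, OFF_1 = 100 and alpha_2 = 0.3, OFF_2 = 50,
    #
    #     alpha_stacked = (0.1 * 100 + 0.3 * 50) / (100 + 50)
    #                   = 25 / 150
    #                   = 1 / 6
    #
    # so the stacked acceptance_off is 6 where the acceptance is normalized
    # to 1.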

    def stat_sum(self):
        """Total likelihood given the current model parameters."""
        return Dataset.stat_sum(self)

    def fake(self, npred_background, random_state="random-seed"):
        """Simulate fake counts (on and off) for the current model and reduced IRFs.

        This method overwrites the counts defined on the dataset object.

        Parameters
        ----------
        npred_background : `Map`
            Expected number of background counts in the on region.
        random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
            Defines random number generator initialisation.
            Passed to `~gammapy.utils.random.get_random_state`.
        """
        random_state = get_random_state(random_state)
        npred = self.npred_signal()
        data = np.nan_to_num(npred.data, copy=True, nan=0.0, posinf=0.0, neginf=0.0)
        npred.data = random_state.poisson(data)

        npred_bkg = random_state.poisson(npred_background.data)

        self.counts = npred + npred_bkg

        npred_off = npred_background / self.alpha
        data_off = np.nan_to_num(
            npred_off.data, copy=True, nan=0.0, posinf=0.0, neginf=0.0
        )
        npred_off.data = random_state.poisson(data_off)
        self.counts_off = npred_off
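
    # Usage sketch (illustrative): simulate ON and OFF counts with a
    # reproducible random state, using the alpha-scaled off counts as the
    # background expectation. Assumes `onoff` is a fully defined
    # MapDatasetOnOff with a model set.
    #
    #     npred_background = onoff.alpha * onoff.counts_off
    #     onoff.fake(npred_background=npred_background, random_state=42)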

    def to_hdulist(self):
        """Convert map dataset to list of HDUs.

        Returns
        -------
        hdulist : `~astropy.io.fits.HDUList`
            Map dataset list of HDUs.
        """
        hdulist = super().to_hdulist()
        exclude_primary = slice(1, None)

        del hdulist["BACKGROUND"]
        del hdulist["BACKGROUND_BANDS"]

        if self.counts_off is not None:
            hdulist += self.counts_off.to_hdulist(hdu="counts_off")[exclude_primary]

        if self.acceptance is not None:
            hdulist += self.acceptance.to_hdulist(hdu="acceptance")[exclude_primary]

        if self.acceptance_off is not None:
            hdulist += self.acceptance_off.to_hdulist(hdu="acceptance_off")[
                exclude_primary
            ]

        return hdulist

    @classmethod
    def _read_lazy(cls, filename, name=None, cache=True, format="gadf"):
        raise NotImplementedError(
            f"Lazy loading is not implemented for {cls}, please use option lazy=False."
        )

    @classmethod
    def from_hdulist(cls, hdulist, name=None, format="gadf"):
        """Create map dataset from list of HDUs.

        Parameters
        ----------
        hdulist : `~astropy.io.fits.HDUList`
            List of HDUs.
        name : str
            Name of the new dataset.
        format : {"gadf"}
            Format the hdulist is given in.

        Returns
        -------
        dataset : `MapDatasetOnOff`
            Map dataset.
        """
        kwargs = {}
        kwargs["name"] = name

        if "COUNTS" in hdulist:
            kwargs["counts"] = Map.from_hdulist(hdulist, hdu="counts", format=format)

        if "COUNTS_OFF" in hdulist:
            kwargs["counts_off"] = Map.from_hdulist(
                hdulist, hdu="counts_off", format=format
            )

        if "ACCEPTANCE" in hdulist:
            kwargs["acceptance"] = Map.from_hdulist(
                hdulist, hdu="acceptance", format=format
            )

        if "ACCEPTANCE_OFF" in hdulist:
            kwargs["acceptance_off"] = Map.from_hdulist(
                hdulist, hdu="acceptance_off", format=format
            )

        if "EXPOSURE" in hdulist:
            kwargs["exposure"] = Map.from_hdulist(
                hdulist, hdu="exposure", format=format
            )

        if "EDISP" in hdulist:
            edisp_map = Map.from_hdulist(hdulist, hdu="edisp", format=format)

            try:
                exposure_map = Map.from_hdulist(
                    hdulist, hdu="edisp_exposure", format=format
                )
            except KeyError:
                exposure_map = None

            if edisp_map.geom.axes[0].name == "energy":
                kwargs["edisp"] = EDispKernelMap(edisp_map, exposure_map)
            else:
                kwargs["edisp"] = EDispMap(edisp_map, exposure_map)

        if "PSF" in hdulist:
            psf_map = Map.from_hdulist(hdulist, hdu="psf", format=format)
            try:
                exposure_map = Map.from_hdulist(
                    hdulist, hdu="psf_exposure", format=format
                )
            except KeyError:
                exposure_map = None
            kwargs["psf"] = PSFMap(psf_map, exposure_map)

        if "MASK_SAFE" in hdulist:
            mask_safe = Map.from_hdulist(hdulist, hdu="mask_safe", format=format)
            kwargs["mask_safe"] = mask_safe

        if "MASK_FIT" in hdulist:
            mask_fit = Map.from_hdulist(hdulist, hdu="mask_fit", format=format)
            kwargs["mask_fit"] = mask_fit

        if "GTI" in hdulist:
            gti = GTI(Table.read(hdulist, hdu="GTI"))
            kwargs["gti"] = gti

        if "META_TABLE" in hdulist:
            meta_table = Table.read(hdulist, hdu="META_TABLE")
            kwargs["meta_table"] = meta_table
        return cls(**kwargs)

    def info_dict(self, in_safe_data_range=True):
        """Basic info dict with summary statistics.

        Parameters
        ----------
        in_safe_data_range : bool
            Whether to sum only in the safe energy range.

        Returns
        -------
        info_dict : dict
            Dictionary with summary info.
        """
        # TODO: remove code duplication with SpectrumDatasetOnOff
        info = super().info_dict(in_safe_data_range)

        if self.mask_safe and in_safe_data_range:
            mask = self.mask_safe.data.astype(bool)
        else:
            mask = slice(None)

        summed_stat = self._counts_statistic[mask].sum()

        counts_off = 0
        if self.counts_off is not None:
            counts_off = summed_stat.n_off

        info["counts_off"] = int(counts_off)

        acceptance = 1
        if self.acceptance:
            acceptance = self.acceptance.data[mask].sum()

        info["acceptance"] = float(acceptance)

        acceptance_off = np.nan
        alpha = np.nan

        if self.acceptance_off:
            alpha = summed_stat.alpha
            acceptance_off = acceptance / alpha

        info["acceptance_off"] = float(acceptance_off)
        info["alpha"] = float(alpha)

        info["stat_sum"] = self.stat_sum()
        return info

    def to_spectrum_dataset(self, on_region, containment_correction=False, name=None):
        """Return a `~gammapy.datasets.SpectrumDatasetOnOff` from on_region.

        Counts and OFF counts are summed in the on_region.

        Acceptance is the average of all acceptances, while acceptance OFF
        is taken such that the number of excess events is preserved in the
        on_region.

        Effective area is taken from the average exposure.

        The energy dispersion kernel is obtained at the on_region center.
        Only regions with centers are supported.

        The models are not exported to the `~gammapy.datasets.SpectrumDatasetOnOff`.
        They must be set after the dataset extraction.

        Parameters
        ----------
        on_region : `~regions.SkyRegion`
            The input ON region on which to extract the spectrum.
        containment_correction : bool
            Apply containment correction for point sources and circular on regions.
        name : str
            Name of the new dataset.

        Returns
        -------
        dataset : `~gammapy.datasets.SpectrumDatasetOnOff`
            The resulting reduced dataset.
        """
        from .spectrum import SpectrumDatasetOnOff

        dataset = super().to_spectrum_dataset(
            on_region=on_region,
            containment_correction=containment_correction,
            name=name,
        )

        kwargs = {"name": name}

        if self.counts_off is not None:
            kwargs["counts_off"] = self.counts_off.get_spectrum(
                on_region, np.sum, weights=self.mask_safe
            )

        if self.acceptance is not None:
            kwargs["acceptance"] = self.acceptance.get_spectrum(
                on_region, np.mean, weights=self.mask_safe
            )
            norm = self.background.get_spectrum(
                on_region, np.sum, weights=self.mask_safe
            )
            acceptance_off = kwargs["acceptance"] * kwargs["counts_off"] / norm
            np.nan_to_num(acceptance_off.data, copy=False)
            kwargs["acceptance_off"] = acceptance_off

        return SpectrumDatasetOnOff.from_spectrum_dataset(dataset=dataset, **kwargs)

    def cutout(self, position, width, mode="trim", name=None):
        """Cutout map dataset.

        Parameters
        ----------
        position : `~astropy.coordinates.SkyCoord`
            Center position of the cutout region.
        width : tuple of `~astropy.coordinates.Angle`
            Angular sizes of the region in (lon, lat) in that specific order.
            If only one value is passed, a square region is extracted.
        mode : {'trim', 'partial', 'strict'}
            Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.
        name : str
            Name of the new dataset.

        Returns
        -------
        cutout : `MapDatasetOnOff`
            Cutout map dataset.
        """
        cutout_kwargs = {
            "position": position,
            "width": width,
            "mode": mode,
            "name": name,
        }

        cutout_dataset = super().cutout(**cutout_kwargs)

        del cutout_kwargs["name"]

        if self.counts_off is not None:
            cutout_dataset.counts_off = self.counts_off.cutout(**cutout_kwargs)

        if self.acceptance is not None:
            cutout_dataset.acceptance = self.acceptance.cutout(**cutout_kwargs)

        if self.acceptance_off is not None:
            cutout_dataset.acceptance_off = self.acceptance_off.cutout(**cutout_kwargs)

        return cutout_dataset

    def downsample(self, factor, axis_name=None, name=None):
        """Downsample map dataset.

        The PSFMap and EDispKernelMap are not downsampled, except if
        a corresponding axis is given.

        Parameters
        ----------
        factor : int
            Downsampling factor.
        axis_name : str
            Which non-spatial axis to downsample. By default only spatial axes are downsampled.
        name : str
            Name of the downsampled dataset.

        Returns
        -------
        dataset : `MapDatasetOnOff`
            Downsampled map dataset.
        """
        dataset = super().downsample(factor, axis_name, name)

        counts_off = None
        if self.counts_off is not None:
            counts_off = self.counts_off.downsample(
                factor=factor,
                preserve_counts=True,
                axis_name=axis_name,
                weights=self.mask_safe,
            )

        acceptance, acceptance_off = None, None
        if self.acceptance_off is not None:
            acceptance = self.acceptance.downsample(
                factor=factor, preserve_counts=False, axis_name=axis_name
            )
            # Do not reuse the name `factor` here: the downsampled background
            # is only the normalisation term for the off acceptance.
            norm_factor = self.background.downsample(
                factor=factor,
                preserve_counts=True,
                axis_name=axis_name,
                weights=self.mask_safe,
            )
            acceptance_off = acceptance * counts_off / norm_factor

        return self.__class__.from_map_dataset(
            dataset,
            acceptance=acceptance,
            acceptance_off=acceptance_off,
            counts_off=counts_off,
        )

    def pad(self):
        """Not implemented for MapDatasetOnOff."""
        raise NotImplementedError

    def slice_by_idx(self, slices, name=None):
        """Slice sub dataset.

        The slicing only applies to the maps that define the corresponding axes.

        Parameters
        ----------
        slices : dict
            Dict of axes names and integers or `slice` object pairs. Contains one
            element for each non-spatial dimension. For integer indexing the
            corresponding axes is dropped from the map. Axes not specified in the
            dict are kept unchanged.
        name : str
            Name of the sliced dataset.

        Returns
        -------
        dataset : `MapDatasetOnOff`
            Sliced dataset.
        """
        kwargs = {"name": name}
        dataset = super().slice_by_idx(slices, name)

        if self.counts_off is not None:
            kwargs["counts_off"] = self.counts_off.slice_by_idx(slices=slices)

        if self.acceptance is not None:
            kwargs["acceptance"] = self.acceptance.slice_by_idx(slices=slices)

        if self.acceptance_off is not None:
            kwargs["acceptance_off"] = self.acceptance_off.slice_by_idx(slices=slices)

        return self.from_map_dataset(dataset, **kwargs)

    def resample_energy_axis(self, energy_axis, name=None):
        """Resample MapDatasetOnOff over new reconstructed energy axis.

        Counts are summed taking into account safe mask.

        Parameters
        ----------
        energy_axis : `~gammapy.maps.MapAxis`
            New reco energy axis.
        name : str
            Name of the new dataset.

        Returns
        -------
        dataset : `MapDatasetOnOff`
            Resampled dataset.
        """
        dataset = super().resample_energy_axis(energy_axis, name)

        counts_off = None
        if self.counts_off is not None:
            counts_off = self.counts_off.resample_axis(
                axis=energy_axis, weights=self.mask_safe
            )

        acceptance = 1
        acceptance_off = None
        if self.acceptance is not None:
            acceptance = self.acceptance.resample_axis(
                axis=energy_axis, weights=self.mask_safe
            )

            norm_factor = self.background.resample_axis(
                axis=energy_axis, weights=self.mask_safe
            )

            acceptance_off = acceptance * counts_off / norm_factor

        return self.__class__.from_map_dataset(
            dataset,
            acceptance=acceptance,
            acceptance_off=acceptance_off,
            counts_off=counts_off,
            name=name,
        )