Passed
Push — master ( f05dd0...7b2f36 )
by Axel
02:40 queued 12s
created

gammapy.maps.core.Map.interp_to_geom()   B

Complexity

Conditions 6

Size

Total Lines 38
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 14
nop 5
dl 0
loc 38
rs 8.6666
c 0
b 0
f 0
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import abc
3
import copy
4
import inspect
5
import json
6
import numpy as np
7
from astropy import units as u
8
from astropy.io import fits
9
from gammapy.utils.scripts import make_path
10
from .axes import MapAxis
11
from .coord import MapCoord
12
from .geom import pix_tuple_to_idx
13
from .io import JsonQuantityDecoder
14
15
__all__ = ["Map"]
16
17
18
class Map(abc.ABC):
19
    """Abstract map class.
20
21
    This can represent WCS- or HEALPIX-based maps
22
    with 2 spatial dimensions and N non-spatial dimensions.
23
24
    Parameters
25
    ----------
26
    geom : `~gammapy.maps.Geom`
27
        Geometry
28
    data : `~numpy.ndarray` or `~astropy.units.Quantity`
29
        Data array
30
    meta : `dict`
31
        Dictionary to store meta data
32
    unit : str or `~astropy.units.Unit`
33
        Data unit, ignored if data is a Quantity.
34
    """
35
36
    tag = "map"
37
38
    def __init__(self, geom, data, meta=None, unit=""):
        # Store the geometry first: the ``data`` setter reshapes incoming
        # arrays to ``self.geom.data_shape`` and therefore needs it.
        self._geom = geom

        if not isinstance(data, u.Quantity):
            self.data = data
            self.unit = unit
        else:
            # Set the plain unit first; assigning ``quantity`` afterwards
            # overwrites both data and unit, so ``unit`` is ignored here.
            self.unit = unit
            self.quantity = data

        self.meta = {} if meta is None else meta
53
    def _init_copy(self, **kwargs):
54
        """Init map instance by copying missing init arguments from self."""
55
        argnames = inspect.getfullargspec(self.__init__).args
56
        argnames.remove("self")
57
        argnames.remove("dtype")
58
59
        for arg in argnames:
60
            value = getattr(self, "_" + arg)
61
            kwargs.setdefault(arg, copy.deepcopy(value))
62
63
        return self.from_geom(**kwargs)
64
65
    @property
66
    def is_mask(self):
67
        """Whether map is mask with bool dtype"""
68
        return self.data.dtype == bool
69
70
    @property
71
    def geom(self):
72
        """Map geometry (`~gammapy.maps.Geom`)"""
73
        return self._geom
74
75
    @property
76
    def data(self):
77
        """Data array (`~numpy.ndarray`)"""
78
        return self._data
79
80
    @data.setter
    def data(self, value):
        """Set data

        Scalars are broadcast to a constant array of shape
        ``self.geom.data_shape``. Other array-like input (e.g. nested lists)
        is converted with `numpy.asanyarray` and reshaped to that shape if
        needed.

        Parameters
        ----------
        value : array-like or scalar
            Data array

        Raises
        ------
        TypeError
            If ``value`` is an `~astropy.units.Quantity`; the unit must be
            set separately.
        ValueError
            If the number of elements does not fit ``self.geom.data_shape``
            (raised by the reshape).
        """
        if np.isscalar(value):
            value = value * np.ones(self.geom.data_shape, dtype=type(value))

        if isinstance(value, u.Quantity):
            raise TypeError("Map data must be a Numpy array. Set unit separately")

        # Accept any array-like, as documented. ``asanyarray`` returns
        # ndarrays (and subclasses) unchanged, so arrays are still stored
        # without a copy.
        value = np.asanyarray(value)

        if value.shape != self.geom.data_shape:
            value = value.reshape(self.geom.data_shape)

        self._data = value
99
100
    @property
101
    def unit(self):
102
        """Map unit (`~astropy.units.Unit`)"""
103
        return self._unit
104
105
    @unit.setter
    def unit(self, val):
        """Set the map unit; ``val`` is converted with `~astropy.units.Unit`."""
        converted = u.Unit(val)
        self._unit = converted
108
109
    @property
110
    def meta(self):
111
        """Map meta (`dict`)"""
112
        return self._meta
113
114
    @meta.setter
115
    def meta(self, val):
116
        self._meta = val
117
118
    @property
    def quantity(self):
        """Map data combined with the map unit (`~astropy.units.Quantity`).

        Created with ``copy=False``, so no copy of the data is requested.
        """
        data, unit = self.data, self.unit
        return u.Quantity(data, unit, copy=False)
122
123
    @quantity.setter
    def quantity(self, val):
        """Set data and unit in one step.

        Parameters
        ----------
        val : `~astropy.units.Quantity`
            Quantity; its ``value`` becomes the data array and its ``unit``
            the map unit. Other input is first converted with
            `~astropy.units.Quantity`.
        """
        as_quantity = u.Quantity(val, copy=False)
        self.data = as_quantity.value
        self.unit = as_quantity.unit
136
137
    @staticmethod
    def create(**kwargs):
        """Create an empty map object.

        This method accepts generic options listed below, as well as options
        for `HpxMap` and `WcsMap` objects. For WCS-specific options, see
        `WcsMap.create` and for HPX-specific options, see `HpxMap.create`.

        Parameters
        ----------
        frame : str
            Coordinate system, either Galactic ("galactic") or Equatorial
            ("icrs").
        map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse', 'region'}
            Map type.  Selects the class that will be used to
            instantiate the map. Default is 'wcs'.
        binsz : float or `~numpy.ndarray`
            Pixel size in degrees.
        skydir : `~astropy.coordinates.SkyCoord`
            Coordinate of map center.
        axes : list
            List of `~MapAxis` objects for each non-spatial dimension.
            If None then the map will be a 2D image.
        dtype : str
            Data type, default is 'float32'
        unit : str or `~astropy.units.Unit`
            Data unit.
        meta : `dict`
            Dictionary to store meta data.
        region : `~regions.SkyRegion`
            Sky region used for the region map.

        Returns
        -------
        map : `Map`
            Empty map object.

        Raises
        ------
        ValueError
            If ``map_type`` is not recognized.
        """
        from .hpx import HpxMap
        from .region import RegionNDMap
        from .wcs import WcsMap

        # ``map_type`` defaults to "wcs" and stays in ``kwargs`` for the WCS
        # and HPX factories (substring match, case-insensitive); it is
        # removed before calling the region factory.
        map_type = kwargs.setdefault("map_type", "wcs")
        normalized = map_type.lower()

        if "wcs" in normalized:
            return WcsMap.create(**kwargs)

        if "hpx" in normalized:
            return HpxMap.create(**kwargs)

        if map_type == "region":
            del kwargs["map_type"]
            return RegionNDMap.create(**kwargs)

        raise ValueError(f"Unrecognized map type: {map_type!r}")
188
189
    @staticmethod
    def read(
        filename, hdu=None, hdu_bands=None, map_type="auto", format=None, colname=None
    ):
        """Read a map from a FITS file.

        Parameters
        ----------
        filename : str or `~pathlib.Path`
            Name of the FITS file.
        hdu : str
            Name or index of the HDU with the map data.
        hdu_bands : str
            Name or index of the HDU with the BANDS table.  If not
            defined this will be inferred from the FITS header of the
            map HDU.
        map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse', 'auto', 'region'}
            Map type.  Selects the class that will be used to
            instantiate the map.  The map type should be consistent
            with the format of the input file.  If map_type is 'auto'
            then an appropriate map type will be inferred from the
            input file.
        format : str, optional
            FITS format convention, forwarded to `Map.from_hdulist`.
        colname : str, optional
            Data column name to be used for the HEALPix map.

        Returns
        -------
        map_out : `Map`
            Map object

        Notes
        -----
        The file is opened with ``memmap=False`` and is closed again before
        this method returns.
        """
        path = str(make_path(filename))
        with fits.open(path, memmap=False) as hdulist:
            map_out = Map.from_hdulist(
                hdulist,
                hdu,
                hdu_bands,
                map_type,
                format=format,
                colname=colname,
            )
        return map_out
223
224
    @staticmethod
    def from_geom(geom, meta=None, data=None, unit="", dtype="float32"):
        """Generate a map from a `Geom` instance.

        The map class is chosen from the geometry type (HEALPix, WCS or
        region). When ``data`` is None the chosen map class creates the data
        array itself.

        Parameters
        ----------
        geom : `Geom`
            Map geometry.
        meta : `dict`
            Dictionary to store meta data.
        data : `numpy.ndarray`
            data array
        unit : str or `~astropy.units.Unit`
            Data unit.
        dtype : str
            Data type passed to the map class. Default is 'float32'.

        Returns
        -------
        map_out : `Map`
            Map object

        Raises
        ------
        ValueError
            If ``geom`` is not a HEALPix, WCS or region geometry.
        """
        from .hpx import HpxGeom
        from .region import RegionGeom
        from .wcs import WcsGeom

        # Checked in this order; the first matching geometry class wins.
        geom_types = (
            (HpxGeom, "hpx"),
            (WcsGeom, "wcs"),
            (RegionGeom, "region"),
        )
        for geom_cls, kind in geom_types:
            if isinstance(geom, geom_cls):
                map_type = kind
                break
        else:
            raise ValueError("Unrecognized geom type.")

        map_cls = Map._get_map_cls(map_type)
        return map_cls(geom, data=data, meta=meta, unit=unit, dtype=dtype)
260
261
    @staticmethod
262
    def from_hdulist(
263
        hdulist, hdu=None, hdu_bands=None, map_type="auto", format=None, colname=None
264
    ):
265
        """Create from `astropy.io.fits.HDUList`.
266
267
        Parameters
268
        ----------
269
        hdulist :  `~astropy.io.fits.HDUList`
270
            HDU list containing HDUs for map data and bands.
271
        hdu : str
272
            Name or index of the HDU with the map data.
273
        hdu_bands : str
274
            Name or index of the HDU with the BANDS table.
275
        map_type : {"auto", "wcs", "hpx", "region"}
276
            Map type.
277
        format : {'gadf', 'fgst-ccube', 'fgst-template'}
278
            FITS format convention.
279
        colname : str, optional
280
            Data column name to be used for the HEALPix map.
281
282
        Returns
283
        -------
284
        map_out : `Map`
285
            Map object
286
        """
287
        if map_type == "auto":
288
            map_type = Map._get_map_type(hdulist, hdu)
289
        cls_out = Map._get_map_cls(map_type)
290
        if map_type == "hpx":
291
            return cls_out.from_hdulist(
292
                hdulist, hdu=hdu, hdu_bands=hdu_bands, format=format, colname=colname
293
            )
294
        else:
295
            return cls_out.from_hdulist(
296
                hdulist, hdu=hdu, hdu_bands=hdu_bands, format=format
297
            )
298
299
    @staticmethod
300
    def _get_meta_from_header(header):
301
        """Load meta data from a FITS header."""
302
        if "META" in header:
303
            return json.loads(header["META"], cls=JsonQuantityDecoder)
304
        else:
305
            return {}
306
307
    @staticmethod
308
    def _get_map_type(hdu_list, hdu_name):
309
        """Infer map type from a FITS HDU.
310
311
        Only read header, never data, to have good performance.
312
        """
313
        if hdu_name is None:
314
            # Find the header of the first non-empty HDU
315
            header = hdu_list[0].header
316
            if header["NAXIS"] == 0:
317
                header = hdu_list[1].header
318
        else:
319
            header = hdu_list[hdu_name].header
320
321
        if ("PIXTYPE" in header) and (header["PIXTYPE"] == "HEALPIX"):
322
            return "hpx"
323
        elif "CTYPE1" in header:
324
            return "wcs"
325
        else:
326
            return "region"
327
328
    @staticmethod
329
    def _get_map_cls(map_type):
330
        """Get map class for given `map_type` string.
331
332
        This should probably be a registry dict so that users
333
        can add supported map types to the `gammapy.maps` I/O
334
        (see e.g. the Astropy table format I/O registry),
335
        but that's non-trivial to implement without avoiding circular imports.
336
        """
337
        if map_type == "wcs":
338
            from .wcs import WcsNDMap
339
340
            return WcsNDMap
341
        elif map_type == "wcs-sparse":
342
            raise NotImplementedError()
343
        elif map_type == "hpx":
344
            from .hpx import HpxNDMap
345
346
            return HpxNDMap
347
        elif map_type == "hpx-sparse":
348
            raise NotImplementedError()
349
        elif map_type == "region":
350
            from .region import RegionNDMap
351
352
            return RegionNDMap
353
        else:
354
            raise ValueError(f"Unrecognized map type: {map_type!r}")
355
356
    def write(self, filename, overwrite=False, **kwargs):
        """Write to a FITS file.

        The HDU list is built with ``self.to_hdulist(**kwargs)``; all extra
        keyword arguments (documented below) are forwarded to it.

        Parameters
        ----------
        filename : str or `~pathlib.Path`
            Output file name.
        overwrite : bool
            Overwrite existing file?
        hdu : str
            Set the name of the image extension.  By default this will
            be set to SKYMAP (for BINTABLE HDU) or PRIMARY (for IMAGE
            HDU).
        hdu_bands : str
            Set the name of the bands table extension.  By default this will
            be set to BANDS.
        format : str, optional
            FITS format convention.  By default files will be written
            to the gamma-astro-data-formats (GADF) format.  This
            option can be used to write files that are compliant with
            format conventions required by specific software (e.g. the
            Fermi Science Tools). The following formats are supported:

                - "gadf" (default)
                - "fgst-ccube"
                - "fgst-ltcube"
                - "fgst-bexpcube"
                - "fgst-srcmap"
                - "fgst-template"
                - "fgst-srcmap-sparse"
                - "galprop"
                - "galprop2"

        sparse : bool
            Sparsify the map by dropping pixels with zero amplitude.
            This option is only compatible with the 'gadf' format.
        """
        hdulist = self.to_hdulist(**kwargs)
        out_path = str(make_path(filename))
        hdulist.writeto(out_path, overwrite=overwrite)
395
396
    def iter_by_image(self):
397
        """Iterate over image planes of the map.
398
399
        This is a generator yielding ``(data, idx)`` tuples,
400
        where ``data`` is a `numpy.ndarray` view of the image plane data,
401
        and ``idx`` is a tuple of int, the index of the image plane.
402
403
        The image plane index is in data order, so that the data array can be
404
        indexed directly.
405
        """
406
        for idx in np.ndindex(self.geom.shape_axes):
407
            yield self.data[idx[::-1]], idx[::-1]
408
409
    def coadd(self, map_in, weights=None):
        """Add the contents of ``map_in`` to this map, in place.

        This method can be used to combine maps containing integral quantities (e.g. counts)
        or differential quantities if the maps have the same binning.

        Every pixel value of ``map_in`` is filled into this map at that
        pixel's coordinates via `fill_by_coord`; no alignment check of the
        two geometries is done (see the TODO in the code).

        Parameters
        ----------
        map_in : `Map`
            Input map.
        weights: `Map` or `~numpy.ndarray`
            The weight factors while adding; ``map_in`` is multiplied by
            them before filling.

        Raises
        ------
        ValueError
            If the unit of ``map_in`` is not equivalent to this map's unit.
        """
        if not self.unit.is_equivalent(map_in.unit):
            raise ValueError("Incompatible units")

        # TODO: Check whether geometries are aligned and if so sum the
        # data vectors directly
        source = map_in if weights is None else map_in * weights

        pixel_idx = source.geom.get_idx()
        pixel_coords = source.geom.get_coord()
        pixel_vals = u.Quantity(source.get_by_idx(pixel_idx), source.unit)

        self.fill_by_coord(pixel_coords, pixel_vals)
433
434
    def pad(self, pad_width, axis_name=None, mode="constant", cval=0, method="linear"):
435
        """Pad the spatial dimensions of the map.
436
437
        Parameters
438
        ----------
439
        pad_width : {sequence, array_like, int}
440
            Number of pixels padded to the edges of each axis.
441
        axis_name : str
442
            Which axis to downsample. By default spatial axes are padded.
443
        mode : {'edge', 'constant', 'interp'}
444
            Padding mode.  'edge' pads with the closest edge value.
445
            'constant' pads with a constant value. 'interp' pads with
446
            an extrapolated value.
447
        cval : float
448
            Padding value when mode='consant'.
449
450
        Returns
451
        -------
452
        map : `Map`
453
            Padded map.
454
455
        """
456
        if axis_name:
457
            if np.isscalar(pad_width):
458
                pad_width = (pad_width, pad_width)
459
460
            geom = self.geom.pad(pad_width=pad_width, axis_name=axis_name)
461
            idx = self.geom.axes.index_data(axis_name)
462
            pad_width_np = [(0, 0)] * self.data.ndim
463
            pad_width_np[idx] = pad_width
464
465
            kwargs = {}
466
            if mode == "constant":
467
                kwargs["constant_values"] = cval
468
469
            data = np.pad(self.data, pad_width=pad_width_np, mode=mode, **kwargs)
470
            return self.__class__(
471
                geom=geom, data=data, unit=self.unit, meta=self.meta.copy()
472
            )
473
474
        return self._pad_spatial(pad_width, mode="constant", cval=cval)
475
476
    @abc.abstractmethod
477
    def _pad_spatial(self, pad_width, mode="constant", cval=0, order=1):
478
        pass
479
480
    @abc.abstractmethod
481
    def crop(self, crop_width):
482
        """Crop the spatial dimensions of the map.
483
484
        Parameters
485
        ----------
486
        crop_width : {sequence, array_like, int}
487
            Number of pixels cropped from the edges of each axis.
488
            Defined analogously to ``pad_with`` from `numpy.pad`.
489
490
        Returns
491
        -------
492
        map : `Map`
493
            Cropped map.
494
        """
495
        pass
496
497
    @abc.abstractmethod
498
    def downsample(self, factor, preserve_counts=True, axis_name=None):
499
        """Downsample the spatial dimension by a given factor.
500
501
        Parameters
502
        ----------
503
        factor : int
504
            Downsampling factor.
505
        preserve_counts : bool
506
            Preserve the integral over each bin.  This should be true
507
            if the map is an integral quantity (e.g. counts) and false if
508
            the map is a differential quantity (e.g. intensity).
509
        axis_name : str
510
            Which axis to downsample. By default spatial axes are downsampled.
511
512
        Returns
513
        -------
514
        map : `Map`
515
            Downsampled map.
516
        """
517
        pass
518
519
    @abc.abstractmethod
520
    def upsample(self, factor, order=0, preserve_counts=True, axis=None):
521
        """Upsample the spatial dimension by a given factor.
522
523
        Parameters
524
        ----------
525
        factor : int
526
            Upsampling factor.
527
        order : int
528
            Order of the interpolation used for upsampling.
529
        preserve_counts : bool
530
            Preserve the integral over each bin.  This should be true
531
            if the map is an integral quantity (e.g. counts) and false if
532
            the map is a differential quantity (e.g. intensity).
533
        axis : str
534
            Which axis to upsample. By default spatial axes are upsampled.
535
536
537
        Returns
538
        -------
539
        map : `Map`
540
            Upsampled map.
541
        """
542
        pass
543
544
    def resample_axis(self, axis, weights=None, ufunc=np.add):
        """Resample map to a new axis binning by grouping over smaller bins and apply ufunc to the bin contents.

        By default, the map content are summed over the smaller bins. Other numpy ufunc can be used,
        e.g. np.logical_and, np.logical_or

        Parameters
        ----------
        axis : `MapAxis`
            New map axis. Its name selects which axis of the map is resampled.
        weights : `Map`
            Map whose data are multiplied with this map's data before the
            reduction. Its data must broadcast against this map's data.
        ufunc : `~numpy.ufunc`
            ufunc to use to resample the axis. Default is numpy.add.

        Returns
        -------
        map : `Map`
            Map with resampled axis.
        """
        from .hpx import HpxGeom

        geom = self.geom.resample_axis(axis)

        axis_self = self.geom.axes[axis.name]
        axis_resampled = geom.axes[axis.name]

        # We don't use MapAxis.coord_to_idx because is does not behave as needed with boundaries
        coord = axis_resampled.edges.value
        edges = axis_self.edges.value
        indices = np.digitize(coord, edges) - 1

        idx = self.geom.axes.index_data(axis.name)

        weights = 1 if weights is None else weights.data

        # A single zero slice is appended along the resampled axis so that the
        # last edge index from ``np.digitize`` is a valid ``reduceat`` index.
        # The reduction over that padding is dropped by the final slicing to
        # ``geom.data_shape``. Compare axes by name: comparing the map's own
        # axis with the (differently binned) new axis is always unequal and
        # used to append a zero block as large as the whole data array.
        if not isinstance(self.geom, HpxGeom):
            shape = self.geom._shape[:2]
        else:
            shape = (self.geom.data_shape[-1],)
        shape += tuple(
            1 if ax.name == axis.name else ax.nbin for ax in self.geom.axes
        )

        padded_array = np.append(self.data * weights, np.zeros(shape[::-1]), axis=idx)

        slices = tuple(slice(0, n) for n in geom.data_shape)
        data = ufunc.reduceat(padded_array, indices=indices, axis=idx)[slices]

        return self._init_copy(data=data, geom=geom)
594
595
    def slice_by_idx(
596
        self,
597
        slices,
598
    ):
599
        """Slice sub map from map object.
600
601
        Parameters
602
        ----------
603
        slices : dict
604
            Dict of axes names and integers or `slice` object pairs. Contains one
605
            element for each non-spatial dimension. For integer indexing the
606
            corresponding axes is dropped from the map. Axes not specified in the
607
            dict are kept unchanged.
608
609
        Returns
610
        -------
611
        map_out : `Map`
612
            Sliced map object.
613
        """
614
        geom = self.geom.slice_by_idx(slices)
615
        slices = tuple([slices.get(ax.name, slice(None)) for ax in self.geom.axes])
616
        data = self.data[slices[::-1]]
617
        return self.__class__(geom=geom, data=data, unit=self.unit, meta=self.meta)
618
619
    def get_image_by_coord(self, coords):
620
        """Return spatial map at the given axis coordinates.
621
622
        Parameters
623
        ----------
624
        coords : tuple or dict
625
            Tuple should be ordered as (x_0, ..., x_n) where x_i are coordinates
626
            for non-spatial dimensions of the map. Dict should specify the axis
627
            names of the non-spatial axes such as {'axes0': x_0, ..., 'axesn': x_n}.
628
629
        Returns
630
        -------
631
        map_out : `Map`
632
            Map with spatial dimensions only.
633
634
        See Also
635
        --------
636
        get_image_by_idx, get_image_by_pix
637
638
        Examples
639
        --------
640
        ::
641
642
            import numpy as np
643
            from gammapy.maps import Map, MapAxis
644
            from astropy.coordinates import SkyCoord
645
            from astropy import units as u
646
647
            # Define map axes
648
            energy_axis = MapAxis.from_edges(
649
                np.logspace(-1., 1., 4), unit='TeV', name='energy',
650
            )
651
652
            time_axis = MapAxis.from_edges(
653
                np.linspace(0., 10, 20), unit='h', name='time',
654
            )
655
656
            # Define map center
657
            skydir = SkyCoord(0, 0, frame='galactic', unit='deg')
658
659
            # Create map
660
            m_wcs = Map.create(
661
                map_type='wcs',
662
                binsz=0.02,
663
                skydir=skydir,
664
                width=10.0,
665
                axes=[energy_axis, time_axis],
666
            )
667
668
            # Get image by coord tuple
669
            image = m_wcs.get_image_by_coord(('500 GeV', '1 h'))
670
671
            # Get image by coord dict with strings
672
            image = m_wcs.get_image_by_coord({'energy': '500 GeV', 'time': '1 h'})
673
674
            # Get image by coord dict with quantities
675
            image = m_wcs.get_image_by_coord({'energy': 0.5 * u.TeV, 'time': 1 * u.h})
676
        """
677
        if isinstance(coords, tuple):
678
            coords = dict(zip(self.geom.axes.names, coords))
679
680
        idx = self.geom.axes.coord_to_idx(coords)
681
        return self.get_image_by_idx(idx)
682
683
    def get_image_by_pix(self, pix):
684
        """Return spatial map at the given axis pixel coordinates
685
686
        Parameters
687
        ----------
688
        pix : tuple
689
            Tuple of scalar pixel coordinates for each non-spatial dimension of
690
            the map. Tuple should be ordered as (I_0, ..., I_n). Pixel coordinates
691
            can be either float or integer type.
692
693
        See Also
694
        --------
695
        get_image_by_coord, get_image_by_idx
696
697
        Returns
698
        -------
699
        map_out : `Map`
700
            Map with spatial dimensions only.
701
        """
702
        idx = self.geom.pix_to_idx(pix)
703
        return self.get_image_by_idx(idx)
704
705
    def get_image_by_idx(self, idx):
706
        """Return spatial map at the given axis pixel indices.
707
708
        Parameters
709
        ----------
710
        idx : tuple
711
            Tuple of scalar indices for each non spatial dimension of the map.
712
            Tuple should be ordered as (I_0, ..., I_n).
713
714
        See Also
715
        --------
716
        get_image_by_coord, get_image_by_pix
717
718
        Returns
719
        -------
720
        map_out : `Map`
721
            Map with spatial dimensions only.
722
        """
723
        if len(idx) != len(self.geom.axes):
724
            raise ValueError("Tuple length must equal number of non-spatial dimensions")
725
726
        # Only support scalar indices per axis
727
        idx = tuple([int(_) for _ in idx])
728
729
        geom = self.geom.to_image()
730
        data = self.data[idx[::-1]]
731
        return self.__class__(geom=geom, data=data, unit=self.unit, meta=self.meta)
732
733
    def get_by_coord(self, coords, fill_value=np.nan):
        """Return map values at the given map coordinates.

        The coordinates are normalized with `MapCoord.create`, converted to
        pixel coordinates with ``self.geom.coord_to_pix`` and looked up with
        `get_by_pix`.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays for each dimension of the map.  Tuple
            should be ordered as (lon, lat, x_0, ..., x_n) where x_i
            are coordinates for non-spatial dimensions of the map.
        fill_value : float
            Value which is returned if the position is outside of the projection
            footprint

        Returns
        -------
        vals : `~numpy.ndarray`
           Values of pixels in the map.  ``fill_value`` (NaN by default) is
           used to flag coords outside of map.
        """
        map_coords = MapCoord.create(
            coords, frame=self.geom.frame, axis_names=self.geom.axes.names
        )
        pixels = self.geom.coord_to_pix(map_coords)
        return self.get_by_pix(pixels, fill_value=fill_value)
758
759
    def get_by_pix(self, pix, fill_value=np.nan):
760
        """Return map values at the given pixel coordinates.
761
762
        Parameters
763
        ----------
764
        pix : tuple
765
            Tuple of pixel index arrays for each dimension of the map.
766
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
767
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
768
            Pixel indices can be either float or integer type.
769
        fill_value : float
770
            Value which is returned if the position is outside of the projection
771
            footprint
772
773
        Returns
774
        -------
775
        vals : `~numpy.ndarray`
776
           Array of pixel values.  np.nan used to flag coordinates
777
           outside of map
778
        """
779
        # FIXME: Support local indexing here?
780
        # FIXME: Support slicing?
781
        pix = np.broadcast_arrays(*pix)
782
        idx = self.geom.pix_to_idx(pix)
783
        vals = self.get_by_idx(idx)
784
        mask = self.geom.contains_pix(pix)
785
786
        if not mask.all():
787
            vals = vals.astype(type(fill_value))
788
            vals[~mask] = fill_value
789
790
        return vals
791
792
    @abc.abstractmethod
793
    def get_by_idx(self, idx):
794
        """Return map values at the given pixel indices.
795
796
        Parameters
797
        ----------
798
        idx : tuple
799
            Tuple of pixel index arrays for each dimension of the map.
800
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
801
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
802
803
        Returns
804
        -------
805
        vals : `~numpy.ndarray`
806
           Array of pixel values.
807
           np.nan used to flag coordinate outside of map
808
        """
809
        pass
810
811
    @abc.abstractmethod
812
    def interp_by_coord(self, coords, method="nearest", fill_value=None):
813
        """Interpolate map values at the given map coordinates.
814
815
        Parameters
816
        ----------
817
        coords : tuple or `~gammapy.maps.MapCoord`
818
            Coordinate arrays for each dimension of the map.  Tuple
819
            should be ordered as (lon, lat, x_0, ..., x_n) where x_i
820
            are coordinates for non-spatial dimensions of the map.
821
        method : {"nearest", "linear"}
822
            Method to interpolate data values. By default no
823
            interpolation is performed and the return value will be
824
            the amplitude of the pixel encompassing the given
825
            coordinate.
826
        fill_value : None or float value
827
            The value to use for points outside of the interpolation domain.
828
            If None, values outside the domain are extrapolated.
829
830
        Returns
831
        -------
832
        vals : `~numpy.ndarray`
833
            Interpolated pixel values.
834
        """
835
        pass
836
837
    @abc.abstractmethod
838
    def interp_by_pix(self, pix, method="nearest", fill_value=None):
839
        """Interpolate map values at the given pixel coordinates.
840
841
        Parameters
842
        ----------
843
        pix : tuple
844
            Tuple of pixel coordinate arrays for each dimension of the
845
            map.  Tuple should be ordered as (p_lon, p_lat, p_0, ...,
846
            p_n) where p_i are pixel coordinates for non-spatial
847
            dimensions of the map.
848
        method : {"nearest", "linear"}
849
            Method to interpolate data values.  By default no
850
            interpolation is performed and the return value will be
851
            the amplitude of the pixel encompassing the given
852
            coordinate.
853
        fill_value : None or float value
854
            The value to use for points outside of the interpolation domain.
855
            If None, values outside the domain are extrapolated.
856
857
        Returns
858
        -------
859
        vals : `~numpy.ndarray`
860
            Interpolated pixel values.
861
        """
862
        pass
863
864
    def interp_to_geom(self, geom, preserve_counts=False, fill_value=0, **kwargs):
        """Interpolate map to input geometry.

        Parameters
        ----------
        geom : `~gammapy.maps.Geom`
            Target Map geometry
        preserve_counts : bool
            Preserve the integral over each bin.  This should be true
            if the map is an integral quantity (e.g. counts) and false if
            the map is a differential quantity (e.g. intensity)
        fill_value : float
            Value used for target positions outside of this map. For
            non-mask maps it is forwarded to `Map.interp_by_coord` (where
            ``None`` means extrapolate). For boolean mask maps it replaces
            the NaN returned by `Map.get_by_coord` before casting to bool.
        **kwargs : dict
            Keyword arguments passed to `Map.interp_by_coord`

        Returns
        -------
        interp_map : `Map`
            Interpolated Map, with the unit of this map.

        Raises
        ------
        ValueError
            If ``preserve_counts`` is True and the first non-spatial axis of
            ``geom`` differs from the one of this map.
        """
        coords = geom.get_coord()
        map_copy = self.copy()

        if preserve_counts:
            if geom.ndim > 2 and geom.axes[0] != self.geom.axes[0]:
                raise ValueError(
                    f"Energy axis do not match: expected {self.geom.axes[0]}, "
                    f"but got {geom.axes[0]}."
                )
            # Interpolate a per-solid-angle density and re-multiply by the
            # target pixel solid angle below.
            map_copy.data /= map_copy.geom.solid_angle().to_value("deg2")

        if map_copy.is_mask:
            # TODO: check this NaN handling is needed
            data = map_copy.get_by_coord(coords)
            data = np.nan_to_num(data, nan=fill_value).astype(bool)
        else:
            # Forward ``fill_value``: before, it was only honoured for masks
            # and silently ignored here, so callers had no way to control
            # values outside the map footprint.
            data = map_copy.interp_by_coord(coords, fill_value=fill_value, **kwargs)

        if preserve_counts:
            data *= geom.solid_angle().to_value("deg2")

        return Map.from_geom(geom, data=data, unit=self.unit)
902
903
    def fill_events(self, events):
        """Fill the map with the coordinates of an event list.

        Parameters
        ----------
        events : `~gammapy.data.EventList`
            Events; their ``map_coord`` for this map's geometry is passed to
            `fill_by_coord` without weights.
        """
        event_coords = events.map_coord(self.geom)
        self.fill_by_coord(event_coords)
906
907
    def fill_by_coord(self, coords, weights=None):
        """Fill the pixels located at ``coords`` with the given ``weights``.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays, one per map dimension, ordered as
            (lon, lat, x_0, ..., x_n) where the x_i are coordinates along the
            non-spatial dimensions.
        weights : `~numpy.ndarray`
            Weights vector. Default is weight of one.
        """
        # Translate world coordinates to pixel indices, then delegate.
        self.fill_by_idx(self.geom.coord_to_idx(coords), weights)
921
922
    def fill_by_pix(self, pix, weights=None):
        """Fill the pixels at ``pix`` with the given ``weights``.

        Parameters
        ----------
        pix : tuple
            Pixel index arrays, one per map dimension, ordered as
            (I_lon, I_lat, I_0, ..., I_n) for WCS maps and
            (I_hpx, I_0, ..., I_n) for HEALPix maps. Indices may be float or
            integer; they are converted with ``pix_tuple_to_idx`` (float
            indices are rounded to the nearest integer).
        weights : `~numpy.ndarray`
            Weights vector. Default is weight of one.

        Returns
        -------
        result : object
            Return value of `fill_by_idx`, passed through unchanged.
        """
        integer_idx = pix_tuple_to_idx(pix)
        result = self.fill_by_idx(integer_idx, weights=weights)
        return result
938
939
    @abc.abstractmethod
    def fill_by_idx(self, idx, weights=None):
        """Fill the pixels at ``idx`` with ``weights`` (implemented by subclasses).

        Parameters
        ----------
        idx : tuple
            Pixel index arrays, one per map dimension, ordered as
            (I_lon, I_lat, I_0, ..., I_n) for WCS maps and
            (I_hpx, I_0, ..., I_n) for HEALPix maps.
        weights : `~numpy.ndarray`
            Weights vector. Default is weight of one.
        """
953
954
    def set_by_coord(self, coords, vals):
        """Set the pixels located at ``coords`` to ``vals``.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays, one per map dimension, ordered as
            (lon, lat, x_0, ..., x_n) where the x_i are coordinates along the
            non-spatial dimensions.
        vals : `~numpy.ndarray`
            Values vector.
        """
        # coord_to_pix yields (possibly fractional) pixel coordinates;
        # set_by_pix converts them to integer indices.
        pix = self.geom.coord_to_pix(coords)
        self.set_by_pix(pix, vals)
968
969
    def set_by_pix(self, pix, vals):
        """Set the pixels at ``pix`` to ``vals``.

        Parameters
        ----------
        pix : tuple
            Pixel index arrays, one per map dimension, ordered as
            (I_lon, I_lat, I_0, ..., I_n) for WCS maps and
            (I_hpx, I_0, ..., I_n) for HEALPix maps. Indices may be float or
            integer; they are converted with ``pix_tuple_to_idx`` (float
            indices are rounded to the nearest integer).
        vals : `~numpy.ndarray`
            Values vector.

        Returns
        -------
        result : object
            Return value of `set_by_idx`, passed through unchanged.
        """
        integer_idx = pix_tuple_to_idx(pix)
        result = self.set_by_idx(integer_idx, vals)
        return result
985
986
    @abc.abstractmethod
    def set_by_idx(self, idx, vals):
        """Set the pixels at ``idx`` to ``vals`` (implemented by subclasses).

        Parameters
        ----------
        idx : tuple
            Pixel index arrays, one per map dimension, ordered as
            (I_lon, I_lat, I_0, ..., I_n) for WCS maps and
            (I_hpx, I_0, ..., I_n) for HEALPix maps.
        vals : `~numpy.ndarray`
            Values vector.
        """
1000
1001
    def plot_grid(self, figsize=None, ncols=3, **kwargs):
        """Plot map as a grid of subplots for non-spatial axes

        Parameters
        ----------
        figsize : tuple of int
            Figsize to plot on
        ncols : int
            Number of columns to plot
        **kwargs : dict
            Keyword arguments passed to `Map.plot`.

        Returns
        -------
        axes : `~numpy.ndarray` of `~matplotlib.pyplot.Axes` or `~matplotlib.pyplot.Axes`
            Axes grid. As with the default behaviour of
            `matplotlib.pyplot.subplots`, a single-row or single-column grid
            is returned as a 1D array and a 1x1 grid as a single Axes.

        Raises
        ------
        ValueError
            If the map does not have exactly one non-spatial axis.
        """
        import matplotlib.pyplot as plt

        # Exactly one non-spatial axis is required; with none, the lookup of
        # ``self.geom.axes[0]`` below would fail with an unrelated error.
        if len(self.geom.axes) != 1:
            raise ValueError("Grid plotting is only supported for one non spatial axis")

        axis = self.geom.axes[0]

        cols = min(ncols, axis.nbin)
        rows = 1 + (axis.nbin - 1) // cols

        if figsize is None:
            width = 12
            figsize = (width, width * rows / cols)

        if self.geom.is_hpx:
            wcs = self.geom.to_wcs_geom().wcs
        else:
            wcs = self.geom.wcs

        # squeeze=False always yields a 2D array of Axes. With the default
        # squeeze=True a single-bin axis (1x1 grid) gives a bare Axes that
        # has no ``.flat`` attribute.
        _, axes = plt.subplots(
            ncols=cols,
            nrows=rows,
            subplot_kw={"projection": wcs},
            figsize=figsize,
            gridspec_kw={"hspace": 0.1, "wspace": 0.1},
            squeeze=False,
        )

        for idx, ax in enumerate(axes.flat):
            try:
                image = self.get_image_by_idx((idx,))
            except IndexError:
                # Presumably raised for grid cells past the last bin; hide them.
                ax.set_visible(False)
                continue

            if image.geom.is_hpx:
                image_wcs = image.to_wcs(
                    normalize=False,
                    proj="AIT",
                    oversample=2,
                )
            else:
                image_wcs = image

            image_wcs.plot(ax=ax, **kwargs)

            if axis.node_type == "center":
                info = f"{axis.center[idx]:.1f}"
            else:
                info = f"{axis.edges[idx]:.1f} - {axis.edges[idx + 1]:.1f} "

            ax.set_title(f"{axis.name.capitalize()} " + info)
            lon, lat = ax.coords[0], ax.coords[1]
            lon.set_ticks_position("b")
            lat.set_ticks_position("l")

            row, col = np.unravel_index(idx, shape=(rows, cols))

            # Only the left column keeps latitude labels and only the bottom
            # row keeps longitude labels.
            if col > 0:
                lat.set_ticklabel_visible(False)
                lat.set_axislabel("")

            if row < (rows - 1):
                lon.set_ticklabel_visible(False)
                lon.set_axislabel("")

        # Give callers the same shape subplots returns with squeeze=True.
        if axes.size == 1:
            return axes[0, 0]
        return axes.squeeze()
1086
1087
    def plot_interactive(self, rc_params=None, **kwargs):
        """
        Plot map with interactive widgets to explore the non spatial axes.

        One selection slider is created per non-spatial axis, plus radio
        buttons to choose the image stretch.

        Parameters
        ----------
        rc_params : dict
            Passed to ``matplotlib.rc_context(rc=rc_params)`` to style the plot.
        **kwargs : dict
            Keyword arguments passed to `WcsNDMap.plot`. A ``stretch`` entry,
            if present, sets the initially selected stretch (default "sqrt").

        Raises
        ------
        TypeError
            If the map is a 2D image (use ``plot`` instead).

        Examples
        --------
        You can try this out e.g. using a Fermi-LAT diffuse model cube with an energy axis::

            from gammapy.maps import Map

            m = Map.read("$GAMMAPY_DATA/fermi_3fhl/gll_iem_v06_cutout.fits")
            m.plot_interactive(add_cbar=True, stretch="sqrt")

        If you would like to adjust the figure size you can use the ``rc_params`` argument::

            rc_params = {'figure.figsize': (12, 6), 'font.size': 12}
            m.plot_interactive(rc_params=rc_params)
        """
        import matplotlib as mpl
        import matplotlib.pyplot as plt
        from ipywidgets import RadioButtons, SelectionSlider
        from ipywidgets.widgets.interaction import fixed, interact

        if self.geom.is_image:
            raise TypeError("Use .plot() for 2D Maps")

        plot_defaults = {"interpolation": "nearest", "origin": "lower", "cmap": "afmhot"}
        for key, value in plot_defaults.items():
            kwargs.setdefault(key, value)

        rc_params = rc_params or {}
        initial_stretch = kwargs.pop("stretch", "sqrt")

        widgets = {}
        for map_axis in self.geom.axes:
            labels = map_axis.as_plot_labels
            widgets[map_axis.name] = SelectionSlider(
                options=labels,
                description=f"Select {map_axis.name}:",
                continuous_update=False,
                style={"description_width": "initial"},
                layout={"width": "50%"},
            )
            # The label list is passed through unchanged so the callback can
            # turn the selected label back into a bin index.
            widgets[map_axis.name + "_options"] = fixed(labels)

        widgets["stretch"] = RadioButtons(
            options=["linear", "sqrt", "log"],
            value=initial_stretch,
            description="Select stretch:",
            style={"description_width": "initial"},
        )

        @interact(**widgets)
        def _plot_interactive(**selection):
            idx = []
            for map_axis in self.geom.axes:
                labels = selection[map_axis.name + "_options"]
                idx.append(labels.index(selection[map_axis.name]))
            img = self.get_image_by_idx(idx)
            chosen_stretch = selection["stretch"]
            with mpl.rc_context(rc=rc_params):
                img.plot(stretch=chosen_stretch, **kwargs)
                plt.show()
1158
1159
    def copy(self, **kwargs):
        """Copy the map, optionally overriding constructor arguments.

        A new ``geom`` may be given only if it has the same ``data_shape`` as
        the current geometry; the data size of the map cannot change.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to overwrite in the map constructor.

        Returns
        -------
        copy : `Map`
            Copied Map.

        Raises
        ------
        ValueError
            If ``geom`` is given with a different ``data_shape``.
        """
        if "geom" in kwargs:
            requested = kwargs["geom"].data_shape
            current = self.geom.data_shape
            if not requested == current:
                raise ValueError(
                    "Can't copy and change data size of the map. "
                    f" Current shape {current},"
                    f" requested shape {requested}"
                )
        return self._init_copy(**kwargs)
1182
1183
    def apply_edisp(self, edisp):
        """Apply energy dispersion to the map.

        The map must have an ``energy_true`` axis. In the returned map this
        axis is replaced by an ``energy`` axis.

        Parameters
        ----------
        edisp : `gammapy.irf.EDispKernel` or None
            Energy dispersion matrix. If None, the data is left unchanged and
            the ``energy_true`` axis is only renamed to ``energy``.

        Returns
        -------
        map : `Map`
            Map with energy dispersion applied.

        Notes
        -----
        The output geometry is the spatial image plus the single energy axis;
        other non-spatial axes are not carried over. NOTE(review): this
        presumably assumes ``energy_true`` is the only non-spatial axis.
        """
        # TODO: either use sparse matrix mutiplication or something like edisp.is_diagonal
        if edisp is None:
            data = self.data
            energy_axis = self.geom.axes["energy_true"].copy(name="energy")
        else:
            loc = self.geom.axes.index("energy_true")
            # Move the true-energy dimension last so it is contracted with the
            # first dimension of the dispersion matrix, then move the
            # resulting energy dimension back to the same position.
            moved = np.moveaxis(self.data, loc, -1)
            dispersed = np.dot(moved, edisp.pdf_matrix)
            data = np.moveaxis(dispersed, -1, loc)
            energy_axis = edisp.axes["energy"].copy(name="energy")

        out_geom = self.geom.to_image().to_cube(axes=[energy_axis])
        return self._init_copy(geom=out_geom, data=data)
1209
1210
    def mask_nearest_position(self, position):
        """Given a sky coordinate return nearest valid position in the mask

        Only 2D image masks are supported.

        Parameters
        ----------
        position : `SkyCoord`
            Test position

        Returns
        -------
        position : `SkyCoord`
            Nearest position in the mask

        Raises
        ------
        ValueError
            If the map is not a 2D image, or if the mask has no valid
            (True) pixel.
        """
        if not self.geom.is_image:
            raise ValueError("Method only supported for 2D images")

        # With no valid pixel every separation would be inf and argmin would
        # silently return the first pixel of the map.
        if not np.any(self.data):
            raise ValueError("Mask contains no valid pixel")

        coords = self.geom.to_image().get_coord().skycoord
        separation = coords.separation(position)
        # Exclude invalid pixels from the search.
        separation[~self.data] = np.inf
        idx = np.argmin(separation)
        return coords.flatten()[idx]
1233
1234
    def sum_over_axes(self, axes_names=None, keepdims=True, weights=None):
        """Sum map values over non-spatial axes.

        This is `reduce_over_axes` with ``func=np.add``.

        Parameters
        ----------
        axes_names : list of str
            Names of the MapAxis to sum over. If None, all non-spatial axes
            are summed over.
        keepdims : bool, optional
            If True, the summed axes are kept in the map with a single bin.
        weights : `Map`
            Weights to be applied. The Map should have the same geometry.

        Returns
        -------
        map_out : `~Map`
            Map with the non-spatial axes summed over.
        """
        return self.reduce_over_axes(
            axes_names=axes_names, func=np.add, weights=weights, keepdims=keepdims
        )
1255
1256
    def reduce_over_axes(
        self, func=np.add, keepdims=False, axes_names=None, weights=None
    ):
        """Reduce map over non-spatial axes

        Parameters
        ----------
        func : `~numpy.ufunc`
            Function to use for reducing the data.
        keepdims : bool, optional
            If this is set to true, the axes which are summed over are left in
            the map with a single bin
        axes_names: list
            Names of MapAxis to reduce over
            If None, all will reduced
        weights : `Map`
            Weights to be applied, matching the geometry of this map. They
            are multiplied into the data once, in the first reduction.

        Returns
        -------
        map_out : `~Map`
            Map with non-spatial axes reduced
        """
        if axes_names is None:
            axes_names = self.geom.axes.names

        map_out = self.copy()
        for axis_name in axes_names:
            map_out = map_out.reduce(
                axis_name, func=func, keepdims=keepdims, weights=weights
            )
            # The weights match the unreduced geometry. Passing them again to
            # the already-reduced map would weight the data a second time and
            # broadcast them against the wrong shape.
            weights = None

        return map_out
1289
1290
    def reduce(self, axis_name, func=np.add, keepdims=False, weights=None):
        """Reduce map over a single non-spatial axis

        NaN entries (after weighting) are excluded from the reduction through
        the ``where`` argument of ``func.reduce``.

        Parameters
        ----------
        axis_name: str
            The name of the axis to reduce over
        func : `~numpy.ufunc`
            Function to use for reducing the data.
        keepdims : bool, optional
            If this is set to true, the axes which are summed over are left in
            the map with a single bin
        weights : `Map`
            Weights to be applied, multiplied into the data before reducing.

        Returns
        -------
        map_out : `~Map`
            Map with the given non-spatial axes reduced
        """
        reduced_geom = (
            self.geom.squash(axis_name=axis_name)
            if keepdims
            else self.geom.drop(axis_name=axis_name)
        )
        data_axis = self.geom.axes.index_data(axis_name)

        values = self.data if weights is None else self.data * weights
        values = func.reduce(
            values, axis=data_axis, keepdims=keepdims, where=~np.isnan(values)
        )
        return self._init_copy(geom=reduced_geom, data=values)
1324
1325
    def cumsum(self, axis_name):
        """Compute cumulative sum along a given axis

        Each value is multiplied by its bin width (and, for an axis named
        ``rad``, additionally by ``2 * pi * r`` at the bin center) before
        summing. A leading zero is inserted, so the result has one entry per
        bin edge and is defined on a new axis whose nodes are those edges.

        Parameters
        ----------
        axis_name : str
            Along which axis to integrate.

        Returns
        -------
        cumsum : `Map`
            Map with cumulative sum
        """
        axis = self.geom.axes[axis_name]
        data_idx = self.geom.axes.index_data(axis_name)

        # TODO: the broadcasting should be done by axis.center, axis.bin_width etc.
        bcast_shape = [1] * len(self.geom.data_shape)
        bcast_shape[data_idx] = -1

        integrand = self.quantity * axis.bin_width.reshape(bcast_shape)
        if axis_name == "rad":
            # take Jacobian into account
            integrand = 2 * np.pi * axis.center.reshape(bcast_shape) * integrand

        running_total = integrand.cumsum(axis=data_idx)
        data = np.insert(running_total, 0, 0, axis=data_idx)

        edge_axis = MapAxis.from_nodes(axis.edges, name=axis.name, interp=axis.interp)
        new_axes = self.geom.axes.replace(edge_axis)
        geom = self.geom.to_image().to_cube(new_axes)
        return self.__class__(geom=geom, data=data.value, unit=data.unit)
1359
1360
    def integral(self, axis_name, coords, **kwargs):
        """Compute integral along a given axis

        This method interpolates the cumulative sum (see `Map.cumsum`),
        padded by one bin on each side in "edge" mode along ``axis_name``.

        Parameters
        ----------
        axis_name : str
            Along which axis to integrate.
        coords : dict or `MapCoord`
            Map coordinates at which the integral is evaluated.
        **kwargs : dict
            Keyword arguments passed to ``interp_by_coord`` of the
            cumulative-sum map.

        Returns
        -------
        array : `~astropy.units.Quantity`
            Integral values at ``coords``, in the unit of the cumulative sum.
        """
        padded = self.cumsum(axis_name=axis_name).pad(
            pad_width=1, axis_name=axis_name, mode="edge"
        )
        values = padded.interp_by_coord(coords, **kwargs)
        return u.Quantity(values, padded.unit, copy=False)
1385
1386
    def normalize(self, axis_name=None):
        """Normalise data in place along a given axis.

        The data is divided by the maximum of its cumulative sum (see
        `Map.cumsum`) along ``axis_name``. Division warnings are suppressed
        and the result is passed through `numpy.nan_to_num`.

        Parameters
        ----------
        axis_name : str
            Along which axis to normalize. The None default is not
            special-cased: it is passed straight to `cumsum`.
        """
        total = self.cumsum(axis_name=axis_name).quantity

        with np.errstate(invalid="ignore", divide="ignore"):
            data_axis = self.geom.axes.index_data(axis_name=axis_name)
            ratio = self.quantity / total.max(axis=data_axis, keepdims=True)

        self.quantity = np.nan_to_num(ratio)
1402
1403
    @classmethod
    def from_stack(cls, maps, axis=None, axis_name=None):
        """Create Map from list of images and a non-spatial axis.

        The image geometries must be aligned, except for the axis that is stacked.

        Parameters
        ----------
        maps : list of `Map` objects
            List of maps
        axis : `MapAxis`
            If a `MapAxis` is provided, each map becomes one bin of this new
            axis, which is appended as the last non-spatial axis.
        axis_name : str
            If an axis name is given, the maps are concatenated along that
            existing axis (maps may have different numbers of bins). If
            neither ``axis`` nor ``axis_name`` is given, the last
            non-spatial axis of the first map is used.

        Returns
        -------
        map : `Map`
            Map with additional non-spatial axis.

        Raises
        ------
        ValueError
            If the geometries, without the stacked axis, are not aligned.
        """
        geom = maps[0].geom

        if axis_name is None and axis is None:
            axis_name = geom.axes.names[-1]

        if axis_name:
            axis = MapAxis.from_stack(axes=[m.geom.axes[axis_name] for m in maps])
            geom = geom.drop(axis_name=axis_name)

        data = []

        for m in maps:
            if axis_name:
                m_geom = m.geom.drop(axis_name=axis_name)
            else:
                m_geom = m.geom

            if not m_geom == geom:
                raise ValueError(f"Image geometries not aligned: {m.geom} and {geom}")

            values = m.quantity.to_value(maps[0].unit)
            if axis_name:
                # ``to_cube`` appends the axis as the last non-spatial axis,
                # i.e. data position 0; bring this map's axis there first.
                values = np.moveaxis(values, m.geom.axes.index_data(axis_name), 0)
            data.append(values)

        # An existing axis is concatenated (its bins are joined); a new axis is
        # created by stacking one slice per map.
        data = np.concatenate(data) if axis_name else np.stack(data)

        return cls.from_geom(
            data=data, geom=geom.to_cube(axes=[axis]), unit=maps[0].unit
        )
1450
1451
    def to_cube(self, axes):
        """Append non-spatial axes to create a higher-dimensional Map.

        This will result in a Map with a new geometry with
        N+M dimensions where N is the number of current dimensions and
        M is the number of axes in the list. The data is reshaped onto the
        new geometry

        Parameters
        ----------
        axes : list
            Axes that will be appended to this Map.
            The axes should have only one bin

        Returns
        -------
        map : `Map`
            new map, built with ``from_geom``

        Raises
        ------
        ValueError
            If any of ``axes`` has more than one bin.
        """
        for ax in axes:
            if ax.nbin > 1:
                raise ValueError(f"{ax.name} should have only one bin")
        geom = self.geom.to_cube(axes)
        # The new single-bin axes become leading length-1 dimensions of the data.
        data = self.data.reshape((1,) * len(axes) + self.data.shape)
        return self.from_geom(data=data, geom=geom, unit=self.unit)
1476
1477
    def get_spectrum(self, region=None, func=np.nansum, weights=None):
        """Extract spectrum in a given region.

        ``func`` is applied along the spatial axes, in each energy bin, to
        the part of the map inside ``region`` (the whole spatial extent of
        the map by default). This delegates to ``to_region_nd_map``.

        Parameters
        ----------
        region: `~regions.Region`
             Region (pixel or sky regions accepted).
        func : callable
            Function to reduce the data. Default is np.nansum.
            For a boolean Map, use np.any or np.all.
        weights : `WcsNDMap`
            Array to be used as weights. The geometry must be equivalent.

        Returns
        -------
        spectrum : `~gammapy.maps.RegionNDMap`
            Spectrum in the given region.

        Raises
        ------
        ValueError
            If the map has no energy axis.
        """
        if self.geom.has_energy_axis:
            return self.to_region_nd_map(region=region, func=func, weights=weights)
        raise ValueError("Energy axis required")
1503
1504
    def to_unit(self, unit):
        """Return a new map with the data converted to ``unit``.

        The result is built with ``from_geom`` from the geometry, converted
        values and new unit only; ``meta`` is not passed along.

        Parameters
        ----------
        unit : `~astropy.unit.Unit` or str
            New unit

        Returns
        -------
        map : `Map`
            Map with new unit and converted data
        """
        converted = self.quantity.to_value(unit)
        return self.from_geom(self.geom, data=converted, unit=unit)
1519
1520
    def __repr__(self):
        """Multi-line summary: class, geometry type, axes, shape, ndim, unit, dtype."""
        geom = self.geom
        if geom.is_hpx:
            axes = ["skycoord"]
        else:
            axes = ["lon", "lat"]
        axes.extend(ax.name for ax in geom.axes)

        # NOTE: the space after the newline on the geom line is part of the
        # existing output format and is kept unchanged.
        return (
            f"{self.__class__.__name__}\n\n"
            f"\tgeom  : {geom.__class__.__name__} \n "
            f"\taxes  : {axes}\n"
            f"\tshape : {geom.data_shape[::-1]}\n"
            f"\tndim  : {geom.ndim}\n"
            f"\tunit  : {self.unit}\n"
            f"\tdtype : {self.data.dtype}\n"
        )
1534
1535
    def _arithmetics(self, operator, other, copy):
        """Apply a binary ``operator`` to this map's quantity and ``other``.

        Parameters
        ----------
        operator : callable
            Binary operation such as `numpy.add`.
        other : `Map` or array-like or `~astropy.units.Quantity`
            Second operand. A `Map` must have exactly the same geometry;
            anything else is wrapped in a `~astropy.units.Quantity`.
        copy : bool
            If True the result is stored in a copy of the map, otherwise
            ``self`` is modified and returned.

        Returns
        -------
        out : `Map`
            Map holding the result.

        Raises
        ------
        ValueError
            If ``other`` is a `Map` with a different geometry.
        """
        if not isinstance(other, Map):
            operand = u.Quantity(other, copy=False)
        elif self.geom == other.geom:
            operand = other.quantity
        else:
            raise ValueError("Map Arithmetic: Inconsistent geometries.")

        out = self.copy() if copy else self
        out.quantity = operator(out.quantity, operand)
        return out
1548
1549
    def _boolean_arithmetics(self, operator, other, copy):
        """Apply a logical ``operator`` to this map's data.

        ``np.logical_not`` is treated as unary: ``other`` and ``copy`` are
        ignored and a copy is always returned. For binary operators a `Map`
        operand must share this map's geometry; its data is used.

        Parameters
        ----------
        operator : callable
            Logical operation such as `numpy.logical_and`.
        other : `Map`, array-like or None
            Second operand (unused for ``np.logical_not``).
        copy : bool
            If True the result is stored in a copy of the map, otherwise
            ``self`` is modified and returned.

        Returns
        -------
        out : `Map`
            Map holding the result.

        Raises
        ------
        ValueError
            If ``other`` is a `Map` with a different geometry.
        """
        if operator is np.logical_not:
            negated = self.copy()
            negated.data = operator(negated.data)
            return negated

        if isinstance(other, Map):
            if not self.geom == other.geom:
                raise ValueError("Map Arithmetic: Inconsistent geometries.")
            other = other.data

        out = self.copy() if copy else self
        out.data = operator(out.data, other)
        return out
1565
1566
    def __add__(self, other):
        """Return ``self + other`` as a new map."""
        return self._arithmetics(operator=np.add, other=other, copy=True)
1568
1569
    def __iadd__(self, other):
        """Add ``other`` to this map in place and return it."""
        return self._arithmetics(operator=np.add, other=other, copy=False)
1571
1572
    def __sub__(self, other):
        """Return ``self - other`` as a new map."""
        return self._arithmetics(operator=np.subtract, other=other, copy=True)
1574
1575
    def __isub__(self, other):
        """Subtract ``other`` from this map in place and return it."""
        return self._arithmetics(operator=np.subtract, other=other, copy=False)
1577
1578
    def __mul__(self, other):
        """Return ``self * other`` as a new map."""
        return self._arithmetics(operator=np.multiply, other=other, copy=True)
1580
1581
    def __imul__(self, other):
        """Multiply this map by ``other`` in place and return it."""
        return self._arithmetics(operator=np.multiply, other=other, copy=False)
1583
1584
    def __truediv__(self, other):
        """Return ``self / other`` as a new map."""
        return self._arithmetics(operator=np.true_divide, other=other, copy=True)
1586
1587
    def __itruediv__(self, other):
        """Divide this map by ``other`` in place and return it."""
        return self._arithmetics(operator=np.true_divide, other=other, copy=False)
1589
1590
    def __le__(self, other):
        """Element-wise ``self <= other``, returned as a new map."""
        return self._arithmetics(operator=np.less_equal, other=other, copy=True)
1592
1593
    def __lt__(self, other):
        """Element-wise ``self < other``, returned as a new map."""
        return self._arithmetics(operator=np.less, other=other, copy=True)
1595
1596
    def __ge__(self, other):
        """Element-wise ``self >= other``, returned as a new map."""
        return self._arithmetics(operator=np.greater_equal, other=other, copy=True)
1598
1599
    def __gt__(self, other):
        """Element-wise ``self > other``, returned as a new map."""
        return self._arithmetics(operator=np.greater, other=other, copy=True)
1601
1602
    def __eq__(self, other):
        """Element-wise ``self == other``, returned as a new map."""
        return self._arithmetics(operator=np.equal, other=other, copy=True)
1604
1605
    def __ne__(self, other):
        """Element-wise ``self != other``, returned as a new map."""
        return self._arithmetics(operator=np.not_equal, other=other, copy=True)
1607
1608
    def __and__(self, other):
        """Element-wise logical AND with ``other``, returned as a new map."""
        return self._boolean_arithmetics(operator=np.logical_and, other=other, copy=True)
1610
1611
    def __or__(self, other):
        """Element-wise logical OR with ``other``, returned as a new map."""
        return self._boolean_arithmetics(operator=np.logical_or, other=other, copy=True)
1613
1614
    def __invert__(self):
        """Element-wise logical NOT, returned as a new map."""
        return self._boolean_arithmetics(operator=np.logical_not, other=None, copy=True)
1616
1617
    def __xor__(self, other):
        """Element-wise logical XOR with ``other``, returned as a new map."""
        return self._boolean_arithmetics(operator=np.logical_xor, other=other, copy=True)
1619
1620
    def __iand__(self, other):
        """Logical AND with ``other`` applied to this map in place; returns it."""
        return self._boolean_arithmetics(operator=np.logical_and, other=other, copy=False)
1622
1623
    def __ior__(self, other):
        """Logical OR with ``other`` applied to this map in place; returns it."""
        return self._boolean_arithmetics(operator=np.logical_or, other=other, copy=False)
1625
1626
    def __ixor__(self, other):
        """Logical XOR with ``other`` applied to this map in place; returns it."""
        return self._boolean_arithmetics(operator=np.logical_xor, other=other, copy=False)
1628
1629
    def __array__(self, dtype=None, copy=None):
        """Return the map data as a NumPy array.

        Parameters
        ----------
        dtype : `~numpy.dtype`, optional
            If given, the data is converted to this dtype.
        copy : bool, optional
            If True, a copy is always returned. Otherwise the underlying data
            array itself is returned when no conversion is needed.

        Notes
        -----
        NumPy passes ``dtype`` (and, since NumPy 2.0, ``copy``) to
        ``__array__`` during conversions such as ``np.asarray(m, dtype=float)``.
        Accepting them avoids a ``TypeError`` and the NumPy 2 deprecation
        warning for signatures without ``copy``.
        """
        if copy:
            return np.array(self.data, dtype=dtype, copy=True)
        return np.asarray(self.data, dtype=dtype)
1631