gammapy.maps.core   F
last analyzed

Complexity

Total Complexity 205

Size/Duplication

Total Lines 1908
Duplicated Lines 1.47 %

Importance

Changes 0
Metric Value
eloc 650
dl 28
loc 1908
rs 1.95
c 0
b 0
f 0
wmc 205

90 Methods

Rating   Name   Duplication   Size   Complexity  
A Map.quantity() 0 4 1
A Map.meta() 0 4 1
A Map.read() 0 33 2
B Map._get_map_cls() 0 27 6
A Map.upsample() 0 24 1
A Map.downsample() 0 21 1
A Map.unit() 0 4 1
A Map.crop() 0 16 1
A Map.pad() 0 41 4
A Map.data() 0 4 4
A Map.geom() 0 4 1
A Map.from_geom() 0 36 4
A Map.__init__() 0 14 3
A Map._get_meta_from_header() 0 7 2
A Map.write() 0 39 1
A Map.coadd() 0 24 3
A Map.rename_axes() 0 17 1
B Map._get_map_type() 0 20 6
A Map.iter_by_axis() 0 17 3
A Map.create() 0 51 4
A Map._pad_spatial() 0 3 1
A Map.iter_by_image() 0 24 3
A Map.from_hdulist() 0 36 3
A Map.iter_by_image_data() 0 18 2
A Map._init_copy() 0 11 2
A Map.is_mask() 0 4 1
A Map.interp_by_coord() 0 23 1
A Map.interp_by_pix() 0 24 1
A Map.get_by_idx() 0 18 1
A Map.get_by_pix() 0 32 2
A Map.slice_by_idx() 0 23 1
A Map.resample_axis() 0 50 4
A Map._resample_by_idx() 0 18 1
A Map.resample() 0 30 2
A Map.get_image_by_idx() 0 27 2
A Map.get_image_by_coord() 0 63 2
A Map.get_image_by_pix() 0 21 1
A Map.get_by_coord() 0 22 1
A Map.set_by_idx() 0 14 1
A Map.__mul__() 0 2 1
A Map.mask_nearest_position() 0 23 2
A Map._boolean_arithmetics() 0 16 5
A Map.__isub__() 0 2 1
A Map.__itruediv__() 0 2 1
F Map.reproject_to_geom() 0 72 14
A Map.__lt__() 0 2 1
A Map.to_unit() 0 15 1
A Map.normalize() 0 16 2
A Map.__le__() 0 2 1
A Map.cumsum() 0 34 2
A Map.to_cube() 0 25 3
C Map.plot_interactive() 0 79 9
F Map.plot_grid() 0 91 14
A Map.get_spectrum() 0 26 2
A Map.set_by_coord() 0 14 1
A Map._arithmetics() 0 13 4
A Map.__ge__() 0 2 1
A Map.__repr__() 0 7 2
A Map.apply_edisp() 0 26 2
B Map.interp_to_geom() 0 41 6
A Map.fill_by_coord() 0 14 1
A Map.sum_over_axes() 0 20 1
A Map.__ior__() 0 2 1
A Map.fill_events() 0 3 1
A Map.is_allclose() 28 28 3
A Map.__imul__() 0 2 1
A Map.__add__() 0 2 1
A Map.integral() 0 24 1
A Map.__iadd__() 0 2 1
A Map.__sub__() 0 2 1
A Map.split_by_axis() 0 19 2
A Map.set_by_pix() 0 16 1
A Map.__xor__() 0 2 1
A Map.reduce() 0 34 3
A Map.__ixor__() 0 2 1
A Map.reduce_over_axes() 0 33 3
A Map.__eq__() 0 2 1
A Map.fill_by_pix() 0 16 1
B Map.from_stack() 0 46 7
A Map.copy() 0 23 3
A Map.__array__() 0 2 1
A Map.__truediv__() 0 2 1
A Map.__invert__() 0 2 1
A Map.sample_coord() 0 27 1
A Map.__ne__() 0 2 1
A Map.__or__() 0 2 1
A Map.__iand__() 0 2 1
A Map.fill_by_idx() 0 14 1
A Map.__gt__() 0 2 1
A Map.__and__() 0 2 1


# Licensed under a 3-clause BSD style license - see LICENSE.rst
import abc
import copy
import inspect
import json
from collections import OrderedDict
import numpy as np
from astropy import units as u
from astropy.io import fits
import matplotlib.pyplot as plt
from gammapy.utils.random import InverseCDFSampler, get_random_state
from gammapy.utils.scripts import make_path
from gammapy.utils.units import energy_unit_format
from .axes import MapAxis
from .coord import MapCoord
from .geom import pix_tuple_to_idx
from .io import JsonQuantityDecoder

__all__ = ["Map"]


class Map(abc.ABC):
    """Abstract map class.

    This can represent WCS- or HEALPix-based maps
    with 2 spatial dimensions and N non-spatial dimensions.

    Parameters
    ----------
    geom : `~gammapy.maps.Geom`
        Geometry.
    data : `~numpy.ndarray` or `~astropy.units.Quantity`
        Data array.
    meta : `dict`
        Dictionary to store metadata.
    unit : str or `~astropy.units.Unit`
        Data unit, ignored if data is a Quantity.
    """

    tag = "map"

    def __init__(self, geom, data, meta=None, unit=""):
        self._geom = geom

        if isinstance(data, u.Quantity):
            self._unit = u.Unit(unit)
            self.quantity = data
        else:
            self.data = data
            self._unit = u.Unit(unit)

        if meta is None:
            self.meta = {}
        else:
            self.meta = meta

    def _init_copy(self, **kwargs):
        """Initialize a new map, deep-copying any ``__init__`` arguments not given in ``kwargs`` from ``self``."""
        argnames = inspect.getfullargspec(self.__init__).args
        argnames.remove("self")
        argnames.remove("dtype")

        for arg in argnames:
            value = getattr(self, "_" + arg)
            kwargs.setdefault(arg, copy.deepcopy(value))

        return self.from_geom(**kwargs)
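`_init_copy` relies on a naming convention: every `__init__` argument `x` is mirrored by an attribute `_x`. The toy class below (a hypothetical `Box`, standard library only) is a minimal sketch of that introspection pattern. Unlike the method above it rebuilds through `self.__class__` rather than `from_geom`, so it has no `dtype` argument to drop.

```python
import copy
import inspect


class Box:
    """Hypothetical container whose __init__ arguments are stored as _<name>."""

    def __init__(self, data, meta=None, unit=""):
        self._data = data
        self._meta = {} if meta is None else meta
        self._unit = unit

    def _init_copy(self, **kwargs):
        # getfullargspec on a bound method still lists "self", so drop it
        argnames = inspect.getfullargspec(self.__init__).args
        argnames.remove("self")
        # Arguments not passed explicitly are deep-copied from "_<name>"
        for arg in argnames:
            kwargs.setdefault(arg, copy.deepcopy(getattr(self, "_" + arg)))
        return self.__class__(**kwargs)


original = Box(data=[1, 2, 3], meta={"origin": "sim"}, unit="cm")
clone = original._init_copy(data=[4, 5, 6])
print(clone._data, clone._meta, clone._unit)  # [4, 5, 6] {'origin': 'sim'} cm
```

Because of the deep copy, mutating `clone._meta` leaves `original._meta` untouched.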

    @property
    def is_mask(self):
        """Whether the map is a mask, i.e. has a boolean dtype."""
        return self.data.dtype == bool

    @property
    def geom(self):
        """Map geometry (`~gammapy.maps.Geom`)."""
        return self._geom

    @property
    def data(self):
        """Data array (`~numpy.ndarray`)."""
        return self._data

    @data.setter
    def data(self, value):
        """Set data.

        Parameters
        ----------
        value : array-like
            Data array.
        """
        if np.isscalar(value):
            value = value * np.ones(self.geom.data_shape, dtype=type(value))

        if isinstance(value, u.Quantity):
            raise TypeError("Map data must be a Numpy array. Set unit separately")

        if not value.shape == self.geom.data_shape:
            value = value.reshape(self.geom.data_shape)

        self._data = value

    @property
    def unit(self):
        """Map unit (`~astropy.units.Unit`)."""
        return self._unit

    @property
    def meta(self):
        """Map meta (`dict`)."""
        return self._meta

    @meta.setter
    def meta(self, val):
        self._meta = val

    @property
    def quantity(self):
        """Map data times unit (`~astropy.units.Quantity`)."""
        return u.Quantity(self.data, self.unit, copy=False)

    @quantity.setter
    def quantity(self, val):
        """Set data and unit.

        Parameters
        ----------
        val : `~astropy.units.Quantity`
            Quantity.
        """
        val = u.Quantity(val, copy=False)

        self.data = val.value
        self._unit = val.unit
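The `data` setter normalizes its input before storing it: a scalar is broadcast to the full data shape, and an array with the right number of elements but a different shape is reshaped. A NumPy-only sketch of those two steps, using a hypothetical helper `normalize_data` and leaving out the `Quantity` check so that astropy is not needed:

```python
import numpy as np


def normalize_data(value, data_shape):
    """Mimic the scalar-broadcast and reshape steps of the Map.data setter."""
    # A scalar becomes a constant array with the target shape
    if np.isscalar(value):
        value = value * np.ones(data_shape, dtype=type(value))
    # Same number of elements, different shape: reshape (raises ValueError otherwise)
    if value.shape != data_shape:
        value = value.reshape(data_shape)
    return value


filled = normalize_data(2.5, (3, 4, 5))
flat = normalize_data(np.arange(60), (3, 4, 5))
print(filled.shape, flat.shape)  # (3, 4, 5) (3, 4, 5)
```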

    def rename_axes(self, names, new_names):
        """Rename the Map axes.

        Parameters
        ----------
        names : list or str
            Names of the axes.
        new_names : list or str
            New names of the axes (list must be of the same length as `names`).

        Returns
        -------
        map : `Map`
            Renamed Map.
        """
        geom = self.geom.rename_axes(names=names, new_names=new_names)
        return self._init_copy(geom=geom)

    @staticmethod
    def create(**kwargs):
        """Create an empty map object.

        This method accepts generic options listed below, as well as options
        for `HpxMap` and `WcsMap` objects. For WCS-specific options, see
        `WcsMap.create` and for HPX-specific options, see `HpxMap.create`.

        Parameters
        ----------
        frame : str
            Coordinate system, either Galactic ("galactic") or Equatorial
            ("icrs").
        map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse', 'region'}
            Map type.  Selects the class that will be used to
            instantiate the map.
        binsz : float or `~numpy.ndarray`
            Pixel size in degrees.
        skydir : `~astropy.coordinates.SkyCoord`
            Coordinate of map center.
        axes : list
            List of `~MapAxis` objects for each non-spatial dimension.
            If None then the map will be a 2D image.
        dtype : str
            Data type, default is 'float32'
        unit : str or `~astropy.units.Unit`
            Data unit.
        meta : `dict`
            Dictionary to store meta data.
        region : `~regions.SkyRegion`
            Sky region used for the region map.

        Returns
        -------
        map : `Map`
            Empty map object.
        """
        from .hpx import HpxMap
        from .region import RegionNDMap
        from .wcs import WcsMap

        map_type = kwargs.setdefault("map_type", "wcs")
        if "wcs" in map_type.lower():
            return WcsMap.create(**kwargs)
        elif "hpx" in map_type.lower():
            return HpxMap.create(**kwargs)
        elif map_type == "region":
            _ = kwargs.pop("map_type")
            return RegionNDMap.create(**kwargs)
        else:
            raise ValueError(f"Unrecognized map type: {map_type!r}")

    @staticmethod
    def read(
        filename, hdu=None, hdu_bands=None, map_type="auto", format=None, colname=None
    ):
        """Read a map from a FITS file.

        Parameters
        ----------
        filename : str or `~pathlib.Path`
            Name of the FITS file.
        hdu : str
            Name or index of the HDU with the map data.
        hdu_bands : str
            Name or index of the HDU with the BANDS table.  If not
            defined this will be inferred from the FITS header of the
            map HDU.
        map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse', 'auto', 'region'}
            Map type.  Selects the class that will be used to
            instantiate the map.  The map type should be consistent
            with the format of the input file.  If map_type is 'auto'
            then an appropriate map type will be inferred from the
            input file.
        format : str, optional
            FITS format convention, passed on to `Map.from_hdulist`.
        colname : str, optional
            Data column name to be used for the HEALPix map.

        Returns
        -------
        map_out : `Map`
            Map object
        """
        with fits.open(str(make_path(filename)), memmap=False) as hdulist:
            return Map.from_hdulist(
                hdulist, hdu, hdu_bands, map_type, format=format, colname=colname
            )

    @staticmethod
    def from_geom(geom, meta=None, data=None, unit="", dtype="float32"):
        """Generate an empty map from a `Geom` instance.

        Parameters
        ----------
        geom : `Geom`
            Map geometry.
        data : `numpy.ndarray`
            Data array.
        meta : `dict`
            Dictionary to store meta data.
        unit : str or `~astropy.units.Unit`
            Data unit.
        dtype : str
            Data type, default is 'float32'.

        Returns
        -------
        map_out : `Map`
            Map object
        """
        from .hpx import HpxGeom
        from .region import RegionGeom
        from .wcs import WcsGeom

        if isinstance(geom, HpxGeom):
            map_type = "hpx"
        elif isinstance(geom, WcsGeom):
            map_type = "wcs"
        elif isinstance(geom, RegionGeom):
            map_type = "region"
        else:
            raise ValueError("Unrecognized geom type.")

        cls_out = Map._get_map_cls(map_type)
        return cls_out(geom, data=data, meta=meta, unit=unit, dtype=dtype)

    @staticmethod
    def from_hdulist(
        hdulist, hdu=None, hdu_bands=None, map_type="auto", format=None, colname=None
    ):
        """Create from `astropy.io.fits.HDUList`.

        Parameters
        ----------
        hdulist : `~astropy.io.fits.HDUList`
            HDU list containing HDUs for map data and bands.
        hdu : str
            Name or index of the HDU with the map data.
        hdu_bands : str
            Name or index of the HDU with the BANDS table.
        map_type : {"auto", "wcs", "hpx", "region"}
            Map type.
        format : {'gadf', 'fgst-ccube', 'fgst-template'}
            FITS format convention.
        colname : str, optional
            Data column name to be used for the HEALPix map.

        Returns
        -------
        map_out : `Map`
            Map object
        """
        if map_type == "auto":
            map_type = Map._get_map_type(hdulist, hdu)
        cls_out = Map._get_map_cls(map_type)
        if map_type == "hpx":
            return cls_out.from_hdulist(
                hdulist, hdu=hdu, hdu_bands=hdu_bands, format=format, colname=colname
            )
        else:
            return cls_out.from_hdulist(
                hdulist, hdu=hdu, hdu_bands=hdu_bands, format=format
            )

    @staticmethod
    def _get_meta_from_header(header):
        """Load metadata from a FITS header."""
        if "META" in header:
            return json.loads(header["META"], cls=JsonQuantityDecoder)
        else:
            return {}

    @staticmethod
    def _get_map_type(hdu_list, hdu_name):
        """Infer map type from a FITS HDU.

        Only the header is read, never the data, to keep this fast.
        """
        if hdu_name is None:
            # Use the primary HDU header, or the second HDU if the primary is empty
            header = hdu_list[0].header
            if header["NAXIS"] == 0:
                header = hdu_list[1].header
        else:
            header = hdu_list[hdu_name].header

        if ("PIXTYPE" in header) and (header["PIXTYPE"] == "HEALPIX"):
            return "hpx"
        elif "CTYPE1" in header:
            return "wcs"
        else:
            return "region"
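`_get_map_type` decides purely from header keywords: a `PIXTYPE` of `HEALPIX` means HEALPix, a `CTYPE1` keyword means WCS, and anything else is treated as a region map. The sketch below reproduces that decision with plain dicts standing in for `astropy.io.fits` headers; `infer_map_type` and the sample headers are hypothetical.

```python
def infer_map_type(headers, hdu_name=None):
    """Guess the map type from header keywords only (dicts stand in for FITS headers)."""
    if hdu_name is None:
        header = headers[0]
        # An empty primary HDU (NAXIS == 0) means the map is in the next HDU
        if header["NAXIS"] == 0:
            header = headers[1]
    else:
        header = headers[hdu_name]

    if header.get("PIXTYPE") == "HEALPIX":
        return "hpx"
    elif "CTYPE1" in header:
        return "wcs"
    return "region"


wcs_file = [{"NAXIS": 2, "CTYPE1": "GLON-CAR", "CTYPE2": "GLAT-CAR"}]
hpx_file = [{"NAXIS": 0}, {"NAXIS": 2, "PIXTYPE": "HEALPIX"}]
region_file = {"COUNTS": {"NAXIS": 2}}

print(infer_map_type(wcs_file))  # wcs
print(infer_map_type(hpx_file))  # hpx
print(infer_map_type(region_file, hdu_name="COUNTS"))  # region
```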

    @staticmethod
    def _get_map_cls(map_type):
        """Get map class for given `map_type` string.

        This should probably be a registry dict so that users
        can add supported map types to the `gammapy.maps` I/O
        (see e.g. the Astropy table format I/O registry),
        but that's non-trivial to implement without introducing circular imports.
        """
        if map_type == "wcs":
            from .wcs import WcsNDMap

            return WcsNDMap
        elif map_type == "wcs-sparse":
            raise NotImplementedError()
        elif map_type == "hpx":
            from .hpx import HpxNDMap

            return HpxNDMap
        elif map_type == "hpx-sparse":
            raise NotImplementedError()
        elif map_type == "region":
            from .region import RegionNDMap

            return RegionNDMap
        else:
            raise ValueError(f"Unrecognized map type: {map_type!r}")
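The docstring above suggests a registry dict. One way to keep such a registry free of circular imports, sketched here as an assumption rather than as gammapy's design, is to store import paths as strings and resolve them with `importlib` only when a class is requested. `register_map_type`, `get_map_cls`, and the stand-in entries (standard-library classes, so the sketch runs anywhere) are hypothetical.

```python
import importlib

# Hypothetical registry: map_type -> "module:attribute", resolved lazily so
# that importing the registry never imports the map subclasses themselves.
_MAP_REGISTRY = {}


def register_map_type(map_type, path):
    _MAP_REGISTRY[map_type] = path


def get_map_cls(map_type):
    try:
        path = _MAP_REGISTRY[map_type]
    except KeyError:
        raise ValueError(f"Unrecognized map type: {map_type!r}") from None
    module_name, attr = path.split(":")
    # The import only happens here, at lookup time
    return getattr(importlib.import_module(module_name), attr)


# Stand-in entries; a real registry would point at the map classes instead
register_map_type("ordered", "collections:OrderedDict")
register_map_type("counter", "collections:Counter")

print(get_map_cls("counter"))  # <class 'collections.Counter'>
```

Third-party code could then add a map type with a single `register_map_type` call instead of editing an `if`/`elif` chain.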

    def write(self, filename, overwrite=False, **kwargs):
        """Write to a FITS file.

        Parameters
        ----------
        filename : str or `~pathlib.Path`
            Output file name.
        overwrite : bool
            Whether to overwrite an existing file.
        hdu : str
            Set the name of the image extension.  By default this will
            be set to SKYMAP (for BINTABLE HDU) or PRIMARY (for IMAGE
            HDU).
        hdu_bands : str
            Set the name of the bands table extension.  By default this will
            be set to BANDS.
        format : str, optional
            FITS format convention.  By default files will be written
            to the gamma-astro-data-formats (GADF) format.  This
            option can be used to write files that are compliant with
            format conventions required by specific software (e.g. the
            Fermi Science Tools). The following formats are supported:

                - "gadf" (default)
                - "fgst-ccube"
                - "fgst-ltcube"
                - "fgst-bexpcube"
                - "fgst-srcmap"
                - "fgst-template"
                - "fgst-srcmap-sparse"
                - "galprop"
                - "galprop2"

        sparse : bool
            Sparsify the map by dropping pixels with zero amplitude.
            This option is only compatible with the 'gadf' format.
        """
        hdulist = self.to_hdulist(**kwargs)
        hdulist.writeto(str(make_path(filename)), overwrite=overwrite)

    def iter_by_axis(self, axis_name, keepdims=False):
        """Iterate over a given axis.

        Parameters
        ----------
        axis_name : str
            Name of the axis to iterate over.
        keepdims : bool
            If True, keep the iterated axis as a length-1 axis instead of dropping it.

        Yields
        ------
        map : `Map`
            Map iteration.

        See Also
        --------
        iter_by_image : iterate by image returning a map
        """
        axis = self.geom.axes[axis_name]
        for idx in range(axis.nbin):
            idx_axis = slice(idx, idx + 1) if keepdims else idx
            slices = {axis_name: idx_axis}
            yield self.slice_by_idx(slices=slices)

    def iter_by_image(self, keepdims=False):
        """Iterate over image planes of a map.

        Parameters
        ----------
        keepdims : bool
            If True, keep the non-spatial axes as length-1 axes instead of dropping them.

        Yields
        ------
        map : `Map`
            Map iteration.

        See Also
        --------
        iter_by_image_data : iterate by image returning data and index
        """
        for idx in np.ndindex(self.geom.shape_axes):
            if keepdims:
                names = self.geom.axes.names
                slices = {name: slice(_, _ + 1) for name, _ in zip(names, idx)}
                yield self.slice_by_idx(slices=slices)
            else:
                yield self.get_image_by_idx(idx=idx)

    def iter_by_image_data(self):
        """Iterate over image planes of the map.

        The image plane index is in data order, so that the data array can be
        indexed directly.

        Yields
        ------
        (data, idx) : tuple
            Where ``data`` is a `numpy.ndarray` view of the image plane data,
            and ``idx`` is a tuple of int, the index of the image plane.

        See Also
        --------
        iter_by_image : iterate by image returning a map
        """
        for idx in np.ndindex(self.geom.shape_axes):
            yield self.data[idx[::-1]], idx[::-1]
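Non-spatial axes are listed in axis order (for example energy, then time), while the data array stores them reversed, in front of the spatial dimensions. The NumPy-only sketch below shows why `iter_by_image_data` reverses each `np.ndindex` tuple before indexing; the axis sizes and the helper `iter_image_data` are illustrative.

```python
import numpy as np

n_energy, n_time, ny, nx = 3, 2, 4, 5
shape_axes = (n_energy, n_time)  # non-spatial axes in axis order: (energy, time)

# Data order is reversed: (time, energy, lat, lon)
data = np.zeros((n_time, n_energy, ny, nx))
for i_time in range(n_time):
    for i_energy in range(n_energy):
        data[i_time, i_energy] = 10 * i_energy + i_time  # tag each image plane


def iter_image_data(data, shape_axes):
    """Yield (image view, data-order index) pairs like Map.iter_by_image_data."""
    for idx in np.ndindex(shape_axes):
        data_idx = idx[::-1]  # axis order -> data order
        yield data[data_idx], data_idx


planes = list(iter_image_data(data, shape_axes))
for plane, (i_time, i_energy) in planes:
    # The yielded index addresses the data array directly
    assert plane.shape == (ny, nx)
    assert np.all(plane == 10 * i_energy + i_time)
```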

    def coadd(self, map_in, weights=None):
        """Add the contents of ``map_in`` to this map.

        This method can be used to combine maps containing integral quantities (e.g. counts)
        or differential quantities if the maps have the same binning.

        Parameters
        ----------
        map_in : `Map`
            Input map.
        weights : `Map` or `~numpy.ndarray`
            Weight factors applied to ``map_in`` before adding.
        """
        if not self.unit.is_equivalent(map_in.unit):
            raise ValueError("Incompatible units")

        # TODO: Check whether geometries are aligned and if so sum the
        # data vectors directly
        if weights is not None:
            map_in = map_in * weights
        idx = map_in.geom.get_idx()
        coords = map_in.geom.get_coord()
        vals = u.Quantity(map_in.get_by_idx(idx), map_in.unit)
        self.fill_by_coord(coords, vals)

    def pad(self, pad_width, axis_name=None, mode="constant", cval=0, method="linear"):
        """Pad the spatial dimensions of the map, or a single axis if ``axis_name`` is given.

        Parameters
        ----------
        pad_width : {sequence, array_like, int}
            Number of pixels padded to the edges of each axis.
        axis_name : str
            Which axis to pad. By default spatial axes are padded.
        mode : {'edge', 'constant', 'interp'}
            Padding mode.  'edge' pads with the closest edge value.
            'constant' pads with a constant value. 'interp' pads with
            an extrapolated value.
        cval : float
            Padding value when mode='constant'.

        Returns
        -------
        map : `Map`
            Padded map.
        """
        if axis_name:
            if np.isscalar(pad_width):
                pad_width = (pad_width, pad_width)

            geom = self.geom.pad(pad_width=pad_width, axis_name=axis_name)
            idx = self.geom.axes.index_data(axis_name)
            pad_width_np = [(0, 0)] * self.data.ndim
            pad_width_np[idx] = pad_width

            kwargs = {}
            if mode == "constant":
                kwargs["constant_values"] = cval

            data = np.pad(self.data, pad_width=pad_width_np, mode=mode, **kwargs)
            return self.__class__(
                geom=geom, data=data, unit=self.unit, meta=self.meta.copy()
            )

        # Forward the requested mode instead of always padding with a constant
        return self._pad_spatial(pad_width, mode=mode, cval=cval)

    @abc.abstractmethod
    def _pad_spatial(self, pad_width, mode="constant", cval=0, order=1):
        pass
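When `axis_name` is given, `pad` builds a per-dimension width list that is `(0, 0)` everywhere except at the padded axis's data index, and hands it to `numpy.pad`. A standalone sketch with a made-up 3D array, where data index 0 stands in for the value that `self.geom.axes.index_data(axis_name)` would return:

```python
import numpy as np

# Data order: (energy, lat, lon); pad only the energy axis at data index 0
data = np.ones((3, 2, 2))
axis_idx = 0
pad_width = 1

# A scalar width means "pad both edges by the same amount"
if np.isscalar(pad_width):
    pad_width = (pad_width, pad_width)

pad_width_np = [(0, 0)] * data.ndim
pad_width_np[axis_idx] = pad_width

padded = np.pad(data, pad_width=pad_width_np, mode="constant", constant_values=0)
print(padded.shape)  # (5, 2, 2)
print(padded[:, 0, 0])  # [0. 1. 1. 1. 0.]
```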

    @abc.abstractmethod
    def crop(self, crop_width):
        """Crop the spatial dimensions of the map.

        Parameters
        ----------
        crop_width : {sequence, array_like, int}
            Number of pixels cropped from the edges of each axis.
            Defined analogously to ``pad_width`` from `numpy.pad`.

        Returns
        -------
        map : `Map`
            Cropped map.
        """
        pass

    @abc.abstractmethod
    def downsample(self, factor, preserve_counts=True, axis_name=None):
        """Downsample the spatial dimension by a given factor.

        Parameters
        ----------
        factor : int
            Downsampling factor.
        preserve_counts : bool
            Preserve the integral over each bin.  This should be true
            if the map is an integral quantity (e.g. counts) and false if
            the map is a differential quantity (e.g. intensity).
        axis_name : str
            Which axis to downsample. By default spatial axes are downsampled.

        Returns
        -------
        map : `Map`
            Downsampled map.
        """
        pass

    @abc.abstractmethod
    def upsample(self, factor, order=0, preserve_counts=True, axis_name=None):
        """Upsample the spatial dimension by a given factor.

        Parameters
        ----------
        factor : int
            Upsampling factor.
        order : int
            Order of the interpolation used for upsampling.
        preserve_counts : bool
            Preserve the integral over each bin.  This should be true
            if the map is an integral quantity (e.g. counts) and false if
            the map is a differential quantity (e.g. intensity).
        axis_name : str
            Which axis to upsample. By default spatial axes are upsampled.

        Returns
        -------
        map : `Map`
            Upsampled map.
        """
        pass

    def resample(self, geom, weights=None, preserve_counts=True):
        """Resample pixels to ``geom`` with given ``weights``.

        Parameters
        ----------
        geom : `~gammapy.maps.Geom`
            Target Map geometry
        weights : `~numpy.ndarray`
            Weights vector. Default is weight of one. Must have same shape as
            the data of the map.
        preserve_counts : bool
            Preserve the integral over each bin.  This should be true
            if the map is an integral quantity (e.g. counts) and false if
            the map is a differential quantity (e.g. intensity)

        Returns
        -------
        resampled_map : `Map`
            Resampled map
        """
        coords = self.geom.get_coord()
        idx = geom.coord_to_idx(coords)

        weights = 1 if weights is None else weights

        resampled = self.from_geom(geom=geom)
        resampled._resample_by_idx(
            idx, weights=self.data * weights, preserve_counts=preserve_counts
        )
        return resampled

    @abc.abstractmethod
    def _resample_by_idx(self, idx, weights=None, preserve_counts=False):
        """Resample pixels at ``idx`` with given ``weights``.

        Parameters
        ----------
        idx : tuple
            Tuple of pixel index arrays for each dimension of the map.
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
        weights : `~numpy.ndarray`
            Weights vector. Default is weight of one.
        preserve_counts : bool
            Preserve the integral over each bin.  This should be true
            if the map is an integral quantity (e.g. counts) and false if
            the map is a differential quantity (e.g. intensity)
        """
        pass
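`resample` maps the coordinate of every source pixel to an index in the target geometry and leaves it to `_resample_by_idx` to accumulate the weighted values there. The 1D sketch below illustrates that idea with `numpy.digitize` and `numpy.bincount`. The helper `resample_1d` is hypothetical, and averaging the contributing pixels when `preserve_counts=False` is an assumption about how differential quantities could be treated, not a description of any concrete subclass.

```python
import numpy as np


def resample_1d(src_centers, src_values, dst_edges, preserve_counts=True):
    """Accumulate source pixels into target bins by coordinate (1D sketch)."""
    n_dst = len(dst_edges) - 1
    # Target bin index of each source pixel centre; -1 and n_dst mean "outside"
    idx = np.digitize(src_centers, dst_edges) - 1
    inside = (idx >= 0) & (idx < n_dst)
    idx, values = idx[inside], src_values[inside]

    summed = np.bincount(idx, weights=values, minlength=n_dst)
    if preserve_counts:
        return summed
    # Assumed behaviour for differential quantities: mean of contributing pixels
    n_contrib = np.bincount(idx, minlength=n_dst)
    with np.errstate(divide="ignore", invalid="ignore"):
        return summed / n_contrib  # empty target bins become NaN


centers = np.arange(6) + 0.5  # six unit-width source pixels on [0, 6)
values = np.arange(1.0, 7.0)  # pixel values 1, 2, ..., 6
edges = np.array([0.0, 2.0, 4.0, 6.0])  # three target bins of width 2

summed = resample_1d(centers, values, edges)  # sums 1+2, 3+4, 5+6
averaged = resample_1d(centers, values, edges, preserve_counts=False)
```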

    def resample_axis(self, axis, weights=None, ufunc=np.add):
        """Resample map to a new axis by grouping and reducing smaller bins by a given ufunc.

        By default, the map contents are summed over the smaller bins. Other NumPy ufuncs can be
        used, e.g. `numpy.logical_and` or `numpy.logical_or`.

        Parameters
        ----------
        axis : `MapAxis`
            New map axis.
        weights : `Map`
            Array to be used as weights. The spatial geometry must be equivalent
            to that of this map and additional axes must be broadcastable.
        ufunc : `~numpy.ufunc`
            ufunc to use to resample the axis. Default is numpy.add.

        Returns
        -------
        map : `Map`
            Map with resampled axis.
        """
        from .hpx import HpxGeom

        geom = self.geom.resample_axis(axis)

        axis_self = self.geom.axes[axis.name]
        axis_resampled = geom.axes[axis.name]

        # We don't use MapAxis.coord_to_idx because it does not behave as needed with boundaries
        coord = axis_resampled.edges.value
        edges = axis_self.edges.value
        indices = np.digitize(coord, edges) - 1

        idx = self.geom.axes.index_data(axis.name)

        weights = 1 if weights is None else weights.data

        if not isinstance(self.geom, HpxGeom):
            shape = self.geom._shape[:2]
        else:
            shape = (self.geom.data_shape[-1],)
        shape += tuple([ax.nbin if ax != axis else 1 for ax in self.geom.axes])

        padded_array = np.append(self.data * weights, np.zeros(shape[::-1]), axis=idx)

        slices = tuple([slice(0, _) for _ in geom.data_shape])
        data = ufunc.reduceat(padded_array, indices=indices, axis=idx)[slices]

        return self._init_copy(data=data, geom=geom)
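`resample_axis` turns the new, coarser edges into start indices on the old axis with `numpy.digitize`, appends an empty bin so that the last start index is a valid position, reduces each group with `ufunc.reduceat`, and trims the result to the new number of bins. A 1D sketch of the same steps, assuming the coarse edges coincide with fine edges; the arrays are illustrative:

```python
import numpy as np

fine_edges = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # 6 fine bins
coarse_edges = np.array([0.0, 2.0, 6.0])  # 2 coarse bins
n_coarse = len(coarse_edges) - 1

# Start index (in fine bins) of each coarse edge: [0, 2, 6]
indices = np.digitize(coarse_edges, fine_edges) - 1

data = np.arange(1.0, 7.0)  # fine-bin values 1..6
# Append one empty bin so that index 6 is valid for reduceat, then trim
padded = np.append(data, np.zeros(1))
reduced = np.add.reduceat(padded, indices)[:n_coarse]  # sums 1+2 and 3+4+5+6

# Other ufuncs work the same way, e.g. OR-ing a boolean mask per coarse bin
mask = np.array([False, True, False, False, False, False])
any_flagged = np.logical_or.reduceat(np.append(mask, False), indices)[:n_coarse]
```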

    def slice_by_idx(
        self,
        slices,
    ):
        """Slice sub map from map object.

        Parameters
        ----------
        slices : dict
            Dict of axes names and integers or `slice` object pairs. Contains at
            most one element for each non-spatial dimension. For integer indexing
            the corresponding axis is dropped from the map. Axes not specified in
            the dict are kept unchanged.

        Returns
        -------
        map_out : `Map`
            Sliced map object.
        """
        geom = self.geom.slice_by_idx(slices)
        slices = tuple([slices.get(ax.name, slice(None)) for ax in self.geom.axes])
        data = self.data[slices[::-1]]
        return self.__class__(geom=geom, data=data, unit=self.unit, meta=self.meta)
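`slice_by_idx` turns the `slices` dict into a tuple in axis order, filling unspecified axes with `slice(None)`, and reverses it so that it lines up with the leading (non-spatial) data dimensions. A NumPy-only sketch with illustrative axis names and shapes:

```python
import numpy as np

axis_names = ["energy", "time"]  # non-spatial axes in axis order
# Data order: (time, energy, lat, lon)
data = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

slices = {"energy": slice(0, 2), "time": 1}
# Axis-order tuple, defaulting to a full slice for axes not in the dict
idx = tuple(slices.get(name, slice(None)) for name in axis_names)
# Reverse to data order; the trailing spatial dimensions are left whole
sub = data[idx[::-1]]
print(sub.shape)  # (2, 4, 5): time dropped (integer index), energy cut to 2 bins
```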

    def get_image_by_coord(self, coords):
        """Return spatial map at the given axis coordinates.

        Parameters
        ----------
        coords : tuple or dict
            Tuple should be ordered as (x_0, ..., x_n) where x_i are coordinates
            for non-spatial dimensions of the map. Dict should specify the axis
            names of the non-spatial axes such as {'axes0': x_0, ..., 'axesn': x_n}.

        Returns
        -------
        map_out : `Map`
            Map with spatial dimensions only.

        See Also
        --------
        get_image_by_idx, get_image_by_pix

        Examples
        --------
        ::

            import numpy as np
            from gammapy.maps import Map, MapAxis
            from astropy.coordinates import SkyCoord
            from astropy import units as u

            # Define map axes
            energy_axis = MapAxis.from_edges(
                np.logspace(-1., 1., 4), unit='TeV', name='energy',
            )

            time_axis = MapAxis.from_edges(
                np.linspace(0., 10, 20), unit='h', name='time',
            )

            # Define map center
            skydir = SkyCoord(0, 0, frame='galactic', unit='deg')

            # Create map
            m_wcs = Map.create(
                map_type='wcs',
                binsz=0.02,
                skydir=skydir,
                width=10.0,
                axes=[energy_axis, time_axis],
            )

            # Get image by coord tuple
            image = m_wcs.get_image_by_coord(('500 GeV', '1 h'))

            # Get image by coord dict with strings
            image = m_wcs.get_image_by_coord({'energy': '500 GeV', 'time': '1 h'})

            # Get image by coord dict with quantities
            image = m_wcs.get_image_by_coord({'energy': 0.5 * u.TeV, 'time': 1 * u.h})
        """
        if isinstance(coords, tuple):
            coords = dict(zip(self.geom.axes.names, coords))

        idx = self.geom.axes.coord_to_idx(coords)
        return self.get_image_by_idx(idx)
799
800
    def get_image_by_pix(self, pix):
801
        """Return spatial map at the given axis pixel coordinates
802
803
        Parameters
804
        ----------
805
        pix : tuple
806
            Tuple of scalar pixel coordinates for each non-spatial dimension of
807
            the map. Tuple should be ordered as (I_0, ..., I_n). Pixel coordinates
808
            can be either float or integer type.
809
810
        See Also
811
        --------
812
        get_image_by_coord, get_image_by_idx
813
814
        Returns
815
        -------
816
        map_out : `Map`
817
            Map with spatial dimensions only.
818
        """
819
        idx = self.geom.pix_to_idx(pix)
820
        return self.get_image_by_idx(idx)
821
822
    def get_image_by_idx(self, idx):
823
        """Return spatial map at the given axis pixel indices.
824
825
        Parameters
826
        ----------
827
        idx : tuple
828
            Tuple of scalar indices for each non spatial dimension of the map.
829
            Tuple should be ordered as (I_0, ..., I_n).
830
831
        See Also
832
        --------
833
        get_image_by_coord, get_image_by_pix
834
835
        Returns
836
        -------
837
        map_out : `Map`
838
            Map with spatial dimensions only.
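
        Notes
        -----
        The index tuple is reversed before indexing ``self.data`` because, as
        the ``self.data[idx[::-1]]`` line implies, the data array stores the
        non-spatial axes in reverse order, followed by the (lat, lon) image
        axes. A minimal NumPy sketch with made-up shapes, assuming two
        non-spatial axes (energy, then time, in geometry order):

        ```python
        import numpy as np

        # Geometry order: energy (3 bins), time (4 bins); image is 5 x 6.
        # The data array is therefore ordered (time, energy, lat, lon).
        n_energy, n_time, ny, nx = 3, 4, 5, 6
        data = np.arange(n_time * n_energy * ny * nx).reshape(n_time, n_energy, ny, nx)

        # Indices are given in geometry order (I_energy, I_time) ...
        idx = (2, 1)
        # ... and reversed to match the order of the data array
        image = data[idx[::-1]]
        ```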
        """
        if len(idx) != len(self.geom.axes):
            raise ValueError("Tuple length must equal number of non-spatial dimensions")

        # Only support scalar indices per axis
        idx = tuple([int(_) for _ in idx])

        geom = self.geom.to_image()
        data = self.data[idx[::-1]]
        return self.__class__(geom=geom, data=data, unit=self.unit, meta=self.meta)

    def get_by_coord(self, coords, fill_value=np.nan):
        """Return map values at the given map coordinates.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays for each dimension of the map.  Tuple
            should be ordered as (lon, lat, x_0, ..., x_n) where x_i
            are coordinates for non-spatial dimensions of the map.
        fill_value : float
            Value which is returned if the position is outside of the projection
            footprint.

        Returns
        -------
        vals : `~numpy.ndarray`
            Values of pixels in the map. ``fill_value`` is returned for
            coordinates outside of the map.
        """
        pix = self.geom.coord_to_pix(coords=coords)
        vals = self.get_by_pix(pix, fill_value=fill_value)
        return vals

    def get_by_pix(self, pix, fill_value=np.nan):
        """Return map values at the given pixel coordinates.

        Parameters
        ----------
        pix : tuple
            Tuple of pixel index arrays for each dimension of the map.
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
            Pixel indices can be either float or integer type.
        fill_value : float
            Value which is returned if the position is outside of the projection
            footprint.

        Returns
        -------
        vals : `~numpy.ndarray`
            Array of pixel values. ``fill_value`` is returned for coordinates
            outside of the map.
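
        Notes
        -----
        Out-of-footprint values are handled by casting to the type of
        ``fill_value`` before assignment, so that e.g. an integer map can hold
        ``np.nan``. A minimal NumPy sketch, where ``contains`` is a made-up
        stand-in for ``self.geom.contains_pix(pix)``:

        ```python
        import numpy as np

        vals = np.array([3, 5, 7, 9])                    # integer pixel values
        contains = np.array([True, False, True, False])  # stand-in for contains_pix
        fill_value = np.nan

        if not contains.all():
            # type(np.nan) is float, so the values are cast to float64 first;
            # assigning NaN into the integer array directly would raise.
            vals = vals.astype(type(fill_value))
            vals[~contains] = fill_value
        ```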
        """
        # FIXME: Support local indexing here?
        # FIXME: Support slicing?
        pix = np.broadcast_arrays(*pix)
        idx = self.geom.pix_to_idx(pix)
        vals = self.get_by_idx(idx)
        mask = self.geom.contains_pix(pix)

        if not mask.all():
            vals = vals.astype(type(fill_value))
            vals[~mask] = fill_value

        return vals

    @abc.abstractmethod
    def get_by_idx(self, idx):
        """Return map values at the given pixel indices.

        Parameters
        ----------
        idx : tuple
            Tuple of pixel index arrays for each dimension of the map.
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.

        Returns
        -------
        vals : `~numpy.ndarray`
            Array of pixel values.
            ``np.nan`` is used to flag coordinates outside of the map.
        """
        pass

    @abc.abstractmethod
    def interp_by_coord(self, coords, method="linear", fill_value=None):
        """Interpolate map values at the given map coordinates.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays for each dimension of the map.  Tuple
            should be ordered as (lon, lat, x_0, ..., x_n) where x_i
            are coordinates for non-spatial dimensions of the map.
        method : {"linear", "nearest"}
            Method to interpolate data values. By default linear
            interpolation is performed.
        fill_value : None or float value
            The value to use for points outside of the interpolation domain.
            If None, values outside the domain are extrapolated.

        Returns
        -------
        vals : `~numpy.ndarray`
            Interpolated pixel values.
        """
        pass

    @abc.abstractmethod
    def interp_by_pix(self, pix, method="linear", fill_value=None):
        """Interpolate map values at the given pixel coordinates.

        Parameters
        ----------
        pix : tuple
            Tuple of pixel coordinate arrays for each dimension of the
            map.  Tuple should be ordered as (p_lon, p_lat, p_0, ...,
            p_n) where p_i are pixel coordinates for non-spatial
            dimensions of the map.
        method : {"linear", "nearest"}
            Method to interpolate data values. By default linear
            interpolation is performed.
        fill_value : None or float value
            The value to use for points outside of the interpolation domain.
            If None, values outside the domain are extrapolated.

        Returns
        -------
        vals : `~numpy.ndarray`
            Interpolated pixel values.
        """
        pass

    def interp_to_geom(self, geom, preserve_counts=False, fill_value=0, **kwargs):
        """Interpolate map to input geometry.

        Parameters
        ----------
        geom : `~gammapy.maps.Geom`
            Target map geometry.
        preserve_counts : bool
            Preserve the integral over each bin.  This should be true
            if the map is an integral quantity (e.g. counts) and false if
            the map is a differential quantity (e.g. intensity).
        fill_value : float
            Value used for points outside of the interpolation domain.
            Default is 0.
        **kwargs : dict
            Keyword arguments passed to `Map.interp_by_coord`.

        Returns
        -------
        interp_map : `Map`
            Interpolated map.
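
        Notes
        -----
        With ``preserve_counts=True`` the map is divided by the pixel solid
        angle, interpolated as a density and multiplied by the target solid
        angle. A one-dimensional NumPy sketch of that idea, with bin widths
        standing in for solid angles and ``np.interp`` standing in for
        `Map.interp_by_coord` (bins and counts are illustrative):

        ```python
        import numpy as np

        # Input: 4 bins of width 1.0 holding 10 counts each
        edges_in = np.linspace(0.0, 4.0, 5)
        counts_in = np.full(4, 10.0)
        centers_in = 0.5 * (edges_in[:-1] + edges_in[1:])

        # Target: 8 bins of width 0.5 over the same range
        edges_out = np.linspace(0.0, 4.0, 9)
        centers_out = 0.5 * (edges_out[:-1] + edges_out[1:])

        # Counts -> density, interpolate the density, density -> counts
        density = counts_in / np.diff(edges_in)
        counts_out = np.interp(centers_out, centers_in, density) * np.diff(edges_out)
        ```

        For this flat example the total of 40 counts is recovered exactly; in
        general the interpolation only approximately conserves the integral.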
        """
        coords = geom.get_coord()
        map_copy = self.copy()

        if preserve_counts:
            if geom.ndim > 2 and geom.axes[0] != self.geom.axes[0]:
                raise ValueError(
                    f"Energy axes do not match: expected {self.geom.axes[0]},"
                    f" but got {geom.axes[0]}."
                )
            map_copy.data /= map_copy.geom.solid_angle().to_value("deg2")

        if map_copy.is_mask:
            # TODO: check this NaN handling is needed
            data = map_copy.get_by_coord(coords)
            data = np.nan_to_num(data, nan=fill_value).astype(bool)
        else:
            data = map_copy.interp_by_coord(coords, fill_value=fill_value, **kwargs)

        if preserve_counts:
            data *= geom.solid_angle().to_value("deg2")

        return Map.from_geom(geom, data=data, unit=self.unit)

    def reproject_to_geom(self, geom, preserve_counts=False, precision_factor=10):
        """Reproject map to input geometry.

        Parameters
        ----------
        geom : `~gammapy.maps.Geom`
            Target map geometry.
        preserve_counts : bool
            Preserve the integral over each bin.  This should be true
            if the map is an integral quantity (e.g. counts) and false if
            the map is a differential quantity (e.g. intensity).
        precision_factor : int
            Minimal factor between the bin size of the output map and the oversampled base map.
            Used only for the oversampling method.

        Returns
        -------
        output_map : `Map`
            Reprojected map.
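
        Notes
        -----
        If the target pixels are not at least ``precision_factor`` times larger
        than the input pixels, the input is upsampled first. A sketch of how
        the upsampling factor is chosen, written as a hypothetical stand-alone
        helper ``upsample_factor`` (not part of gammapy) that takes the minimum
        pixel scales of the target and input geometries:

        ```python
        import numpy as np


        def upsample_factor(target_scale, input_scale, precision_factor=10, is_hpx=False):
            # Hypothetical helper mirroring the factor logic of reproject_to_geom
            base_factor = target_scale / input_scale
            if base_factor >= precision_factor:
                return 1  # input is fine enough, no upsampling needed
            factor = precision_factor / base_factor
            if is_hpx:
                # HEALPix: round the factor up to a power of two
                return int(2 ** np.ceil(np.log(factor) / np.log(2)))
            return int(np.ceil(factor))
        ```

        For example, a target/input pixel-size ratio of 4 gives a factor of 3
        for WCS input (10 / 4 = 2.5, rounded up) and 4 for HEALPix input.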
        """
        from .hpx import HpxGeom
        from .region import RegionGeom

        axes = [ax.copy() for ax in self.geom.axes]
        geom3d = geom.copy(axes=axes)

        if not geom.is_image:
            if geom.axes.names != geom3d.axes.names:
                raise ValueError("Axis names and order should be the same.")
            if geom.axes != geom3d.axes and (
                isinstance(geom3d, HpxGeom) or isinstance(self.geom, HpxGeom)
            ):
                raise TypeError(
                    "Reprojection to 3d geom with non-identical axes is not supported for HpxGeom. "
                    "Reproject to 2d geom first and then use interp_to_geom method."
                )
        if isinstance(geom3d, RegionGeom):
            base_factor = (
                geom3d.to_wcs_geom().pixel_scales.min() / self.geom.pixel_scales.min()
            )
        elif isinstance(self.geom, RegionGeom):
            base_factor = (
                geom3d.pixel_scales.min() / self.geom.to_wcs_geom().pixel_scales.min()
            )
        else:
            base_factor = geom3d.pixel_scales.min() / self.geom.pixel_scales.min()

        if base_factor >= precision_factor:
            input_map = self
        else:
            factor = precision_factor / base_factor
            if isinstance(self.geom, HpxGeom):
                factor = int(2 ** np.ceil(np.log(factor) / np.log(2)))
            else:
                factor = int(np.ceil(factor))
            input_map = self.upsample(factor=factor, preserve_counts=preserve_counts)

        output_map = input_map.resample(geom3d, preserve_counts=preserve_counts)

        if not geom.is_image and geom.axes != geom3d.axes:
            for base_ax, target_ax in zip(geom3d.axes, geom.axes):
                base_factor = base_ax.bin_width.min() / target_ax.bin_width.min()
                if not base_factor >= precision_factor:
                    factor = precision_factor / base_factor
                    factor = int(np.ceil(factor))
                    output_map = output_map.upsample(
                        factor=factor,
                        preserve_counts=preserve_counts,
                        axis_name=base_ax.name,
                    )
            output_map = output_map.resample(geom, preserve_counts=preserve_counts)
        return output_map

    def fill_events(self, events):
        """Fill event coordinates (`~gammapy.data.EventList`)."""
        self.fill_by_coord(events.map_coord(self.geom))

    def fill_by_coord(self, coords, weights=None):
        """Fill pixels at ``coords`` with given ``weights``.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays for each dimension of the map.  Tuple
            should be ordered as (lon, lat, x_0, ..., x_n) where x_i
            are coordinates for non-spatial dimensions of the map.
        weights : `~numpy.ndarray`
            Weights vector. Default is a weight of one.
        """
        idx = self.geom.coord_to_idx(coords)
        self.fill_by_idx(idx, weights=weights)

    def fill_by_pix(self, pix, weights=None):
        """Fill pixels at ``pix`` with given ``weights``.

        Parameters
        ----------
        pix : tuple
            Tuple of pixel index arrays for each dimension of the map.
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
            Pixel indices can be either float or integer type.  Float
            indices will be rounded to the nearest integer.
        weights : `~numpy.ndarray`
            Weights vector. Default is a weight of one.
        """
        idx = pix_tuple_to_idx(pix)
        return self.fill_by_idx(idx, weights=weights)

    @abc.abstractmethod
    def fill_by_idx(self, idx, weights=None):
        """Fill pixels at ``idx`` with given ``weights``.

        Parameters
        ----------
        idx : tuple
            Tuple of pixel index arrays for each dimension of the map.
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
        weights : `~numpy.ndarray`
            Weights vector. Default is a weight of one.
        """
        pass

    def set_by_coord(self, coords, vals):
        """Set pixels at ``coords`` with given ``vals``.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays for each dimension of the map.  Tuple
            should be ordered as (lon, lat, x_0, ..., x_n) where x_i
            are coordinates for non-spatial dimensions of the map.
        vals : `~numpy.ndarray`
            Values vector.
        """
        pix = self.geom.coord_to_pix(coords)
        self.set_by_pix(pix, vals)

    def set_by_pix(self, pix, vals):
        """Set pixels at ``pix`` with given ``vals``.

        Parameters
        ----------
        pix : tuple
            Tuple of pixel index arrays for each dimension of the map.
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
            Pixel indices can be either float or integer type.  Float
            indices will be rounded to the nearest integer.
        vals : `~numpy.ndarray`
            Values vector.
        """
        idx = pix_tuple_to_idx(pix)
        return self.set_by_idx(idx, vals)

    @abc.abstractmethod
    def set_by_idx(self, idx, vals):
        """Set pixels at ``idx`` with given ``vals``.

        Parameters
        ----------
        idx : tuple
            Tuple of pixel index arrays for each dimension of the map.
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
        vals : `~numpy.ndarray`
            Values vector.
        """
        pass

    def plot_grid(self, figsize=None, ncols=3, **kwargs):
        """Plot map as a grid of subplots, one per bin of the non-spatial axis.

        Parameters
        ----------
        figsize : tuple of int
            Figure size. If None, a width of 12 is used and the height is
            scaled by the ratio of rows to columns.
        ncols : int
            Maximum number of columns in the grid.
        **kwargs : dict
            Keyword arguments passed to `Map.plot`.

        Returns
        -------
        axes : `~numpy.ndarray` of `~matplotlib.pyplot.Axes`
            Axes grid.
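
        Notes
        -----
        The grid has ``min(ncols, nbin)`` columns and a ceiling-division number
        of rows; ``np.unravel_index`` maps each panel index to ``(row, col)`` so
        that latitude tick labels are kept only in the first column and
        longitude tick labels only in the last row. A NumPy sketch of this
        bookkeeping for an arbitrary axis with 7 bins:

        ```python
        import numpy as np

        nbin, ncols = 7, 3
        cols = min(ncols, nbin)
        rows = 1 + (nbin - 1) // cols  # ceiling division: 7 bins -> 3 rows

        show_lat, show_lon = [], []
        for idx in range(nbin):
            row, col = np.unravel_index(idx, shape=(rows, cols))
            if col == 0:
                show_lat.append(idx)  # latitude tick labels kept
            if row == rows - 1:
                show_lon.append(idx)  # longitude tick labels kept

        hidden = list(range(nbin, rows * cols))  # unused panels are hidden
        ```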
        """
        if len(self.geom.axes) > 1:
            raise ValueError("Grid plotting is only supported for one non-spatial axis")

        axis = self.geom.axes[0]

        cols = min(ncols, axis.nbin)
        rows = 1 + (axis.nbin - 1) // cols

        if figsize is None:
            width = 12
            figsize = (width, width * rows / cols)

        if self.geom.is_hpx:
            wcs = self.geom.to_wcs_geom().wcs
        else:
            wcs = self.geom.wcs

        fig, axes = plt.subplots(
            ncols=cols,
            nrows=rows,
            subplot_kw={"projection": wcs},
            figsize=figsize,
            gridspec_kw={"hspace": 0.1, "wspace": 0.1},
        )

        for idx in range(cols * rows):
            ax = axes.flat[idx]

            try:
                image = self.get_image_by_idx((idx,))
            except IndexError:
                ax.set_visible(False)
                continue

            if image.geom.is_hpx:
                image_wcs = image.to_wcs(
                    normalize=False,
                    proj="AIT",
                    oversample=2,
                )
            else:
                image_wcs = image

            image_wcs.plot(ax=ax, **kwargs)

            if axis.node_type == "center":
                if axis.name == "energy" or axis.name == "energy_true":
                    info = energy_unit_format(axis.center[idx])
                else:
                    info = f"{axis.center[idx]:.1f}"
            else:
                if axis.name == "energy" or axis.name == "energy_true":
                    info = (
                        f"{energy_unit_format(axis.edges[idx])} - "
                        f"{energy_unit_format(axis.edges[idx+1])}"
                    )
                else:
                    info = f"{axis.edges[idx]:.1f} - {axis.edges[idx + 1]:.1f} "
            ax.set_title(f"{axis.name.capitalize()} " + info)
            lon, lat = ax.coords[0], ax.coords[1]
            lon.set_ticks_position("b")
            lat.set_ticks_position("l")

            row, col = np.unravel_index(idx, shape=(rows, cols))

            if col > 0:
                lat.set_ticklabel_visible(False)
                lat.set_axislabel("")

            if row < (rows - 1):
                lon.set_ticklabel_visible(False)
                lon.set_axislabel("")

        return axes

    def plot_interactive(self, rc_params=None, **kwargs):
        """
        Plot map with interactive widgets to explore the non-spatial axes.

        Parameters
        ----------
        rc_params : dict
            Passed to ``matplotlib.rc_context(rc=rc_params)`` to style the plot.
        **kwargs : dict
            Keyword arguments passed to `WcsNDMap.plot`.

        Examples
        --------
        You can try this out, e.g. using a Fermi-LAT diffuse model cube with an energy axis::

            from gammapy.maps import Map

            m = Map.read("$GAMMAPY_DATA/fermi_3fhl/gll_iem_v06_cutout.fits")
            m.plot_interactive(add_cbar=True, stretch="sqrt")

        If you would like to adjust the figure size, you can use the ``rc_params`` argument::

            rc_params = {'figure.figsize': (12, 6), 'font.size': 12}
            m.plot_interactive(rc_params=rc_params)
        """
        import matplotlib as mpl
        from ipywidgets import RadioButtons, SelectionSlider
        from ipywidgets.widgets.interaction import fixed, interact

        if self.geom.is_image:
            raise TypeError("Use .plot() for 2D Maps")

        kwargs.setdefault("interpolation", "nearest")
        kwargs.setdefault("origin", "lower")
        kwargs.setdefault("cmap", "afmhot")

        rc_params = rc_params or {}
        stretch = kwargs.pop("stretch", "sqrt")

        interact_kwargs = {}

        for axis in self.geom.axes:
            if axis.node_type == "center":
                if axis.name == "energy" or axis.name == "energy_true":
                    options = energy_unit_format(axis.center)
                else:
                    options = axis.as_plot_labels
            elif axis.name == "energy" or axis.name == "energy_true":
                E = energy_unit_format(axis.edges)
                options = [f"{E[i]} - {E[i+1]}" for i in range(len(E) - 1)]
            else:
                options = axis.as_plot_labels
            interact_kwargs[axis.name] = SelectionSlider(
                options=options,
                description=f"Select {axis.name}:",
                continuous_update=False,
                style={"description_width": "initial"},
                layout={"width": "50%"},
            )
            interact_kwargs[axis.name + "_options"] = fixed(options)

        interact_kwargs["stretch"] = RadioButtons(
            options=["linear", "sqrt", "log"],
            value=stretch,
            description="Select stretch:",
            style={"description_width": "initial"},
        )

        @interact(**interact_kwargs)
        def _plot_interactive(**ikwargs):
            idx = [
                ikwargs[ax.name + "_options"].index(ikwargs[ax.name])
                for ax in self.geom.axes
            ]
            img = self.get_image_by_idx(idx)
            stretch = ikwargs["stretch"]
            with mpl.rc_context(rc=rc_params):
                img.plot(stretch=stretch, **kwargs)
                plt.show()

    def copy(self, **kwargs):
        """Copy map instance and overwrite given attributes.

        A new ``geom`` is only accepted if it has the same data shape as the
        current geometry; otherwise a `ValueError` is raised.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to overwrite in the map constructor.

        Returns
        -------
        copy : `Map`
            Copied Map.
        """
        if "geom" in kwargs:
            geom = kwargs["geom"]
            if not geom.data_shape == self.geom.data_shape:
                raise ValueError(
                    "Can't copy and change data size of the map. "
                    f"Current shape {self.geom.data_shape},"
                    f" requested shape {geom.data_shape}"
                )

        return self._init_copy(**kwargs)

    def apply_edisp(self, edisp):
        """Apply energy dispersion to map. Requires an ``energy_true`` axis.

        Parameters
        ----------
        edisp : `gammapy.irf.EDispKernel`
            Energy dispersion matrix. If None, the data are left unchanged and
            the ``energy_true`` axis is only renamed to ``energy``.

        Returns
        -------
        map : `Map`
            Map with energy dispersion applied.
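
        Notes
        -----
        For a map whose only non-spatial axis is ``energy_true`` (data axis 0),
        the true-energy axis is rolled to the end, contracted with the
        ``(energy_true, energy)`` PDF matrix by ``np.dot`` and rolled back. A
        NumPy sketch with random arrays standing in for ``self.data`` and
        ``edisp.pdf_matrix``:

        ```python
        import numpy as np

        rng = np.random.default_rng(0)
        n_true, n_reco, ny, nx = 4, 3, 5, 6
        cube = rng.random((n_true, ny, nx))        # stand-in for self.data
        pdf_matrix = rng.random((n_true, n_reco))  # stand-in for edisp.pdf_matrix

        loc = 0
        data = np.rollaxis(cube, loc, cube.ndim)   # (ny, nx, n_true)
        data = np.dot(data, pdf_matrix)            # (ny, nx, n_reco)
        data = np.rollaxis(data, -1, loc)          # (n_reco, ny, nx)

        # The same contraction written with einsum, for comparison
        expected = np.einsum("tyx,tr->ryx", cube, pdf_matrix)
        ```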
        """
        # TODO: either use sparse matrix multiplication or something like edisp.is_diagonal
        if edisp is not None:
            loc = self.geom.axes.index("energy_true")
            data = np.rollaxis(self.data, loc, len(self.data.shape))
            data = np.dot(data, edisp.pdf_matrix)
            data = np.rollaxis(data, -1, loc)
            energy_axis = edisp.axes["energy"].copy(name="energy")
        else:
            data = self.data
            energy_axis = self.geom.axes["energy_true"].copy(name="energy")

        geom = self.geom.to_image().to_cube(axes=[energy_axis])
        return self._init_copy(geom=geom, data=data)

    def mask_nearest_position(self, position):
        """Given a sky coordinate, return the nearest valid position in the mask.

        Only 2D image masks are supported; a `ValueError` is raised otherwise.

        Parameters
        ----------
        position : `~astropy.coordinates.SkyCoord`
            Test position.

        Returns
        -------
        position : `~astropy.coordinates.SkyCoord`
            Nearest position in the mask.
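
        Notes
        -----
        Separations of pixels outside the mask are set to infinity so that
        ``np.argmin`` can only select a valid pixel; the flat index it returns
        is then applied to the flattened coordinates. A NumPy sketch on a
        made-up 3 x 3 grid, using flat-sky distances in place of
        `~astropy.coordinates.SkyCoord.separation`:

        ```python
        import numpy as np

        lon, lat = np.meshgrid(np.arange(3.0), np.arange(3.0))
        mask = np.zeros((3, 3), dtype=bool)
        mask[0, 1] = True  # valid pixel at (lon, lat) = (1, 0)
        mask[2, 2] = True  # valid pixel at (lon, lat) = (2, 2)

        # Flat-sky distance to the test position (0, 0)
        separation = np.hypot(lon - 0.0, lat - 0.0)
        separation[~mask] = np.inf
        idx = np.argmin(separation)  # flat index into the 2D array
        nearest = (lon.flatten()[idx], lat.flatten()[idx])
        ```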
        """
        if not self.geom.is_image:
            raise ValueError("Method only supported for 2D images")

        coords = self.geom.to_image().get_coord().skycoord
        separation = coords.separation(position)
        separation[~self.data] = np.inf
        idx = np.argmin(separation)
        return coords.flatten()[idx]

    def sum_over_axes(self, axes_names=None, keepdims=True, weights=None):
        """Sum map values over the non-spatial axes.

        Parameters
        ----------
        axes_names : list of str
            Names of the `MapAxis` to sum over. If None, all non-spatial axes
            are summed over.
        keepdims : bool, optional
            If True, the axes which are summed over are kept in the map with a
            single bin.
        weights : `Map`
            Weights to be applied. The map should have the same geometry.

        Returns
        -------
        map_out : `~Map`
            Map with non-spatial axes summed over.
        """
        return self.reduce_over_axes(
            func=np.add, axes_names=axes_names, keepdims=keepdims, weights=weights
        )

    def reduce_over_axes(
        self, func=np.add, keepdims=False, axes_names=None, weights=None
    ):
        """Reduce map over non-spatial axes.

        Parameters
        ----------
        func : `~numpy.ufunc`
            Function to use for reducing the data.
        keepdims : bool, optional
            If True, the axes which are reduced over are kept in the map with a
            single bin.
        axes_names : list of str
            Names of the `MapAxis` to reduce over. If None, all non-spatial
            axes are reduced.
        weights : `Map`
            Weights to be applied.

        Returns
        -------
        map_out : `~Map`
            Map with non-spatial axes reduced.
        """
        if axes_names is None:
            axes_names = self.geom.axes.names

        map_out = self.copy()
        for axis_name in axes_names:
            map_out = map_out.reduce(
                axis_name, func=func, keepdims=keepdims, weights=weights
            )

        return map_out

    def reduce(self, axis_name, func=np.add, keepdims=False, weights=None):
        """Reduce map over a single non-spatial axis.

        Parameters
        ----------
        axis_name : str
            The name of the axis to reduce over.
        func : `~numpy.ufunc`
            Function to use for reducing the data.
        keepdims : bool, optional
            If True, the reduced axis is kept in the map with a single bin.
        weights : `Map`
            Weights to be applied.

        Returns
        -------
        map_out : `~Map`
            Map with the given non-spatial axis reduced.
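
        Notes
        -----
        NaN entries are excluded by passing ``where=~np.isnan(data)`` to the
        ufunc's ``reduce`` method, so for ``np.add`` they do not contribute to
        the sum. A NumPy sketch on a small made-up array ordered
        (energy, lat, lon); note that ``where`` requires a ufunc with an
        identity, such as ``np.add``, or an explicit ``initial`` value:

        ```python
        import numpy as np

        data = np.array([[[1.0, 2.0]], [[np.nan, 4.0]], [[5.0, 6.0]]])

        kept = np.add.reduce(data, axis=0, keepdims=True, where=~np.isnan(data))
        dropped = np.add.reduce(data, axis=0, keepdims=False, where=~np.isnan(data))
        ```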
        """
        if keepdims:
            geom = self.geom.squash(axis_name=axis_name)
        else:
            geom = self.geom.drop(axis_name=axis_name)

        idx = self.geom.axes.index_data(axis_name)

        data = self.data

        if weights is not None:
            data = data * weights

        data = func.reduce(data, axis=idx, keepdims=keepdims, where=~np.isnan(data))
        return self._init_copy(geom=geom, data=data)

    def cumsum(self, axis_name):
        """Compute cumulative sum along a given axis.

        Parameters
        ----------
        axis_name : str
            Along which axis to integrate.

        Returns
        -------
        cumsum : `Map`
            Map with cumulative sum.
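
        Notes
        -----
        The values are weighted by the bin width (and by ``2 * pi * r`` for a
        ``rad`` axis), summed cumulatively and prefixed with a zero, so the
        result has ``nbin + 1`` nodes placed on the original bin edges. A
        one-dimensional NumPy sketch with a hypothetical radial profile:

        ```python
        import numpy as np

        edges = np.array([0.0, 0.1, 0.2, 0.3])  # hypothetical rad axis edges
        centers = 0.5 * (edges[:-1] + edges[1:])
        profile = np.array([3.0, 2.0, 1.0])     # values per bin

        values = profile * np.diff(edges)
        values = 2 * np.pi * centers * values   # Jacobian for a "rad" axis
        cumulative = np.insert(values.cumsum(), 0, 0)  # one node per edge
        ```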
        """
        axis = self.geom.axes[axis_name]
        axis_idx = self.geom.axes.index_data(axis_name)

        # TODO: the broadcasting should be done by axis.center, axis.bin_width etc.
        shape = [1] * len(self.geom.data_shape)
        shape[axis_idx] = -1

        values = self.quantity * axis.bin_width.reshape(shape)

        if axis_name == "rad":
            # take Jacobian into account
            values = 2 * np.pi * axis.center.reshape(shape) * values

        data = np.insert(values.cumsum(axis=axis_idx), 0, 0, axis=axis_idx)

        axis_shifted = MapAxis.from_nodes(
            axis.edges, name=axis.name, interp=axis.interp
        )
        axes = self.geom.axes.replace(axis_shifted)
        geom = self.geom.to_image().to_cube(axes)
        return self.__class__(geom=geom, data=data.value, unit=data.unit)

    def integral(self, axis_name, coords, **kwargs):
        """Compute integral along a given axis.

        This method uses interpolation of the cumulative sum.

        Parameters
        ----------
        axis_name : str
            Along which axis to integrate.
        coords : dict or `MapCoord`
            Map coordinates at which the integral is evaluated.
        **kwargs : dict
            Keyword arguments passed to `Map.interp_by_coord`.

        Returns
        -------
        array : `~astropy.units.Quantity`
            Integral along ``axis_name`` from the lowest bin edge up to the
            axis values given in ``coords``.
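
        Notes
        -----
        Because the cumulative sum lives on the bin edges, interpolating it at
        an arbitrary axis value gives the integral from the lowest edge up to
        that value. A one-dimensional sketch with ``np.interp`` (linear
        interpolation) standing in for `Map.interp_by_coord`; the edges and
        values are hypothetical and no ``rad`` Jacobian is applied:

        ```python
        import numpy as np

        edges = np.array([0.0, 1.0, 2.0, 3.0])
        values = np.array([2.0, 4.0, 6.0])  # values per bin
        cumulative = np.insert((values * np.diff(edges)).cumsum(), 0, 0)

        # Integral from edges[0] up to each requested axis value
        result = np.interp([0.5, 2.0, 3.0], edges, cumulative)
        ```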
        """
        cumsum = self.cumsum(axis_name=axis_name)
        cumsum = cumsum.pad(pad_width=1, axis_name=axis_name, mode="edge")
        return u.Quantity(
            cumsum.interp_by_coord(coords, **kwargs), cumsum.unit, copy=False
        )

    def normalize(self, axis_name=None):
        """Normalize data in place along a given axis.

        The data are divided by the maximum of the cumulative sum (see
        `Map.cumsum`) along the axis. NaN values are then set to zero and
        infinities to large finite numbers by `numpy.nan_to_num`.

        Parameters
        ----------
        axis_name : str
            Along which axis to normalize.
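
        Notes
        -----
        For non-negative data the maximum of the cumulative sum equals the
        bin-width weighted total along the axis, so after normalization that
        weighted sum is one. A NumPy sketch with two made-up profiles along
        axis 0, the second of which is empty (0 / 0 gives NaN, which
        ``np.nan_to_num`` turns into zero):

        ```python
        import numpy as np

        bin_width = np.array([1.0, 2.0, 1.0])
        data = np.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])

        weighted = data * bin_width[:, np.newaxis]
        cumsum = np.insert(weighted.cumsum(axis=0), 0, 0, axis=0)
        with np.errstate(invalid="ignore", divide="ignore"):
            normed = data / cumsum.max(axis=0, keepdims=True)
        normed = np.nan_to_num(normed)
        ```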
        """
        cumsum = self.cumsum(axis_name=axis_name).quantity

        with np.errstate(invalid="ignore", divide="ignore"):
            axis = self.geom.axes.index_data(axis_name=axis_name)
            normed = self.quantity / cumsum.max(axis=axis, keepdims=True)

        self.quantity = np.nan_to_num(normed)

    @classmethod
    def from_stack(cls, maps, axis=None, axis_name=None):
        """Create Map from list of images and a non-spatial axis.

        The image geometries must be aligned, except for the axis that is stacked.

        Parameters
        ----------
        maps : list of `Map` objects
            List of maps.
        axis : `MapAxis`
            If a `MapAxis` is provided, the maps are stacked along it and the
            new axis is appended as the last non-spatial axis.
        axis_name : str
            If an axis name is given as a string, the maps are stacked along
            the existing axis with that name. If neither ``axis`` nor
            ``axis_name`` is given, the last non-spatial axis of the first map
            is used.

        Returns
        -------
        map : `Map`
            Map with additional non-spatial axis.
1624
        """
1625
        geom = maps[0].geom
1626
1627
        if axis_name is None and axis is None:
1628
            axis_name = geom.axes.names[-1]
1629
1630
        if axis_name:
1631
            axis = MapAxis.from_stack(axes=[m.geom.axes[axis_name] for m in maps])
1632
            geom = geom.drop(axis_name=axis_name)
1633
1634
        data = []
1635
1636
        for m in maps:
1637
            if axis_name:
1638
                m_geom = m.geom.drop(axis_name=axis_name)
1639
            else:
1640
                m_geom = m.geom
1641
1642
            if not m_geom == geom:
1643
                raise ValueError(f"Image geometries not aligned: {m.geom} and {geom}")
1644
1645
            data.append(m.quantity.to_value(maps[0].unit))
1646
1647
        return cls.from_geom(
1648
            data=np.stack(data), geom=geom.to_cube(axes=[axis]), unit=maps[0].unit
1649
        )
1650
1651
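The alignment check and the stacking step of `from_stack` can be sketched with NumPy alone. In this sketch, array shapes stand in for map geometries. That is an assumption, because comparing shapes is weaker than the geometry comparison the method performs. `stack_images` is a hypothetical helper. `np.stack` puts the new axis first in data order, which matches an axis appended last to the geometry.

```python
import numpy as np


def stack_images(images):
    """Stack 2D arrays with identical shapes along a new leading axis."""
    reference = images[0].shape
    for img in images:
        # Shape equality stands in for geometry alignment (assumption).
        if img.shape != reference:
            raise ValueError(f"Image shapes not aligned: {img.shape} and {reference}")
    return np.stack(images)


cube = stack_images([np.zeros((3, 4)), np.ones((3, 4))])
print(cube.shape)  # (2, 3, 4)
```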
    def split_by_axis(self, axis_name):
        """Split a Map along an axis into multiple maps.

        Parameters
        ----------
        axis_name : str
            Name of the axis to split.

        Returns
        -------
        maps : list
            A list of `~gammapy.maps.Map`, one per bin of the axis.
        """
        maps = []
        axis = self.geom.axes[axis_name]
        for idx in range(axis.nbin):
            m = self.slice_by_idx({axis_name: idx})
            maps.append(m)
        return maps

    def to_cube(self, axes):
        """Append non-spatial axes to create a higher-dimensional Map.

        This results in a Map with a new geometry of N+M dimensions, where N is
        the number of current dimensions and M is the number of axes in the list.
        The data are reshaped onto the new geometry.

        Parameters
        ----------
        axes : list
            Axes that will be appended to this Map.
            Each axis must have exactly one bin.

        Returns
        -------
        map : `~gammapy.maps.Map`
            New map with the appended axes.
        """
        for ax in axes:
            if ax.nbin > 1:
                raise ValueError(f"Axis '{ax.name}' should have only one bin")
        geom = self.geom.to_cube(axes)
        data = self.data.reshape((1,) * len(axes) + self.data.shape)
        return self.from_geom(data=data, geom=geom, unit=self.unit)

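The reshape in `to_cube` only prepends length-one dimensions, which is why every appended axis must have a single bin: the number of elements cannot change. Here is a NumPy sketch of the same reshape, using a hypothetical array. Because the data are stored in reverse axis order, axes appended to the geometry appear first in the data shape.

```python
import numpy as np

# Hypothetical 2D image in data order (lat, lon).
image = np.arange(12.0).reshape(3, 4)

# Appending M single-bin axes prepends M length-1 dimensions in data order;
# the values and their total count are unchanged.
n_new_axes = 2
cube = image.reshape((1,) * n_new_axes + image.shape)

print(cube.shape)  # (1, 1, 3, 4)
```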
    def get_spectrum(self, region=None, func=np.nansum, weights=None):
        """Extract spectrum in a given region.

        The spectrum is computed by summing (or, more generally, applying ``func``)
        along the spatial axes in each energy bin. Only pixels inside ``region``
        are used; by default the region is the whole spatial extent of the map.

        Parameters
        ----------
        region : `~regions.Region`
            Region (pixel or sky regions accepted).
        func : callable
            Function to reduce the data. Default is `numpy.nansum`.
            For a boolean Map, use `numpy.any` or `numpy.all`.
        weights : `WcsNDMap`
            Map to be used as weights. The geometry must be equivalent.

        Returns
        -------
        spectrum : `~gammapy.maps.RegionNDMap`
            Spectrum in the given region.
        """
        if not self.geom.has_energy_axis:
            raise ValueError("Energy axis required")

        return self.to_region_nd_map(region=region, func=func, weights=weights)

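The reduction that `get_spectrum` describes reduces the spatial pixels inside the region separately for each energy bin. The following is a NumPy sketch of that idea, assuming a boolean spatial mask stands in for the region. The actual method delegates to `to_region_nd_map`, and weights are not modelled here. `np.nansum` ignores the NaN pixel.

```python
import numpy as np

# Hypothetical cube in data order: (energy, lat, lon).
cube = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3)
cube[0, 0, 0] = np.nan  # a missing pixel, ignored by np.nansum

# Boolean spatial mask standing in for the region (assumption).
mask = np.zeros((3, 3), dtype=bool)
mask[:2, :2] = True

# cube[:, mask] has shape (n_energy, n_pixels_in_region).
spectrum = np.nansum(cube[:, mask], axis=1)
print(spectrum)
```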
    def to_unit(self, unit):
        """Convert map to a different unit.

        Parameters
        ----------
        unit : `~astropy.units.Unit` or str
            New unit.

        Returns
        -------
        map : `Map`
            Map with new unit and converted data.
        """
        data = self.quantity.to_value(unit)
        return self.from_geom(self.geom, data=data, unit=unit)

    def is_allclose(self, other, rtol_axes=1e-3, atol_axes=1e-6, **kwargs):
        """Compare two Maps for close equivalency.

        Parameters
        ----------
        other : `gammapy.maps.Map`
            The Map to compare against.
        rtol_axes : float
            Relative tolerance for the axes comparison.
        atol_axes : float
            Absolute tolerance for the axes comparison.
        **kwargs : dict
            Keywords passed to `numpy.allclose`.

        Returns
        -------
        is_allclose : bool
            Whether the Map is all close.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        if self.data.shape != other.data.shape:
            return False

        axes_eq = self.axes.is_allclose(other.axes, rtol=rtol_axes, atol=atol_axes)
        data_eq = np.allclose(self.quantity, other.quantity, **kwargs)
        return axes_eq and data_eq

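The data comparison forwards `**kwargs` to `numpy.allclose`, which tests `|a - b| <= atol + rtol * |b|` element-wise. The example below shows how a relative tolerance and an absolute tolerance behave on values of very different magnitude. It assumes, without confirmation, that the axes comparison combines `rtol_axes` and `atol_axes` the same way. The arrays are hypothetical.

```python
import numpy as np

a = np.array([1.0, 1000.0])
b = np.array([1.0005, 1000.5])

# Relative tolerance scales with |b|: both differences pass.
print(np.allclose(a, b, rtol=1e-3, atol=0.0))  # True

# A tight absolute tolerance alone rejects both differences.
print(np.allclose(a, b, rtol=0.0, atol=1e-6))  # False
```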
    def __repr__(self):
        geom = self.geom.__class__.__name__
        axes = ["skycoord"] if self.geom.is_hpx else ["lon", "lat"]
        axes = axes + [_.name for _ in self.geom.axes]

        return (
            f"{self.__class__.__name__}\n\n"
            f"\tgeom  : {geom}\n"
            f"\taxes  : {axes}\n"
            f"\tshape : {self.geom.data_shape[::-1]}\n"
            f"\tndim  : {self.geom.ndim}\n"
            f"\tunit  : {self.unit}\n"
            f"\tdtype : {self.data.dtype}\n"
        )

    def _arithmetics(self, operator, other, copy):
        """Perform arithmetic on maps after checking geometry consistency."""
        if isinstance(other, Map):
            if self.geom == other.geom:
                q = other.quantity
            else:
                raise ValueError("Map Arithmetic: Inconsistent geometries.")
        else:
            q = u.Quantity(other, copy=False)

        out = self.copy() if copy else self
        out.quantity = operator(out.quantity, q)
        return out

    def _boolean_arithmetics(self, operator, other, copy):
        """Perform logical operations on maps after checking geometry consistency."""
        if operator == np.logical_not:
            out = self.copy()
            out.data = operator(out.data)
            return out

        if isinstance(other, Map):
            if self.geom == other.geom:
                other = other.data
            else:
                raise ValueError("Map Arithmetic: Inconsistent geometries.")

        out = self.copy() if copy else self
        out.data = operator(out.data, other)
        return out

    def __add__(self, other):
        return self._arithmetics(np.add, other, copy=True)

    def __iadd__(self, other):
        return self._arithmetics(np.add, other, copy=False)

    def __sub__(self, other):
        return self._arithmetics(np.subtract, other, copy=True)

    def __isub__(self, other):
        return self._arithmetics(np.subtract, other, copy=False)

    def __mul__(self, other):
        return self._arithmetics(np.multiply, other, copy=True)

    def __imul__(self, other):
        return self._arithmetics(np.multiply, other, copy=False)

    def __truediv__(self, other):
        return self._arithmetics(np.true_divide, other, copy=True)

    def __itruediv__(self, other):
        return self._arithmetics(np.true_divide, other, copy=False)

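The operators above share one helper and differ only in `copy`. Binary operators pass `copy=True` and return a new object. Augmented assignments pass `copy=False`, so `a += b` keeps the object's identity. A self-contained sketch of that pattern follows. `Grid` is a hypothetical stand-in for `Map`, and comparing array shapes is a simplification of the geometry check.

```python
import numpy as np


class Grid:
    """Hypothetical stand-in for a Map: a data array whose shape plays the geometry."""

    def __init__(self, data):
        self.data = np.asarray(data, dtype=float)

    def _arithmetics(self, operator, other, copy):
        # Grid operands must share the same shape; scalars and arrays broadcast.
        if isinstance(other, Grid):
            if other.data.shape != self.data.shape:
                raise ValueError("Grid arithmetic: inconsistent shapes.")
            other = other.data
        out = Grid(self.data.copy()) if copy else self
        out.data = operator(out.data, other)
        return out

    def __add__(self, other):
        return self._arithmetics(np.add, other, copy=True)

    def __iadd__(self, other):
        return self._arithmetics(np.add, other, copy=False)


a = Grid([1.0, 2.0])
b = Grid([10.0, 20.0])
c = a + b  # new object; a is unchanged
alias = a
a += b  # same object, data replaced
```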
    def __le__(self, other):
        return self._arithmetics(np.less_equal, other, copy=True)

    def __lt__(self, other):
        return self._arithmetics(np.less, other, copy=True)

    def __ge__(self, other):
        return self._arithmetics(np.greater_equal, other, copy=True)

    def __gt__(self, other):
        return self._arithmetics(np.greater, other, copy=True)

    def __eq__(self, other):
        return self._arithmetics(np.equal, other, copy=True)

    def __ne__(self, other):
        return self._arithmetics(np.not_equal, other, copy=True)

    def __and__(self, other):
        return self._boolean_arithmetics(np.logical_and, other, copy=True)

    def __or__(self, other):
        return self._boolean_arithmetics(np.logical_or, other, copy=True)

    def __invert__(self):
        return self._boolean_arithmetics(np.logical_not, None, copy=True)

    def __xor__(self, other):
        return self._boolean_arithmetics(np.logical_xor, other, copy=True)

    def __iand__(self, other):
        return self._boolean_arithmetics(np.logical_and, other, copy=False)

    def __ior__(self, other):
        return self._boolean_arithmetics(np.logical_or, other, copy=False)

    def __ixor__(self, other):
        return self._boolean_arithmetics(np.logical_xor, other, copy=False)

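`_boolean_arithmetics` special-cases `np.logical_not` because it is the only unary function among the logical operators. This is also why `__invert__` passes `None` as the second operand. The dispatch can be sketched on plain boolean arrays as follows. `boolean_op` and the masks are hypothetical.

```python
import numpy as np


def boolean_op(operator, data, other=None):
    """Apply a NumPy logical function; logical_not is unary, the others are binary."""
    if operator is np.logical_not:
        return operator(data)
    return operator(data, other)


mask_a = np.array([True, False, True])
mask_b = np.array([True, True, False])

print(boolean_op(np.logical_and, mask_a, mask_b))  # [ True False False]
print(boolean_op(np.logical_xor, mask_a, mask_b))  # [False  True  True]
print(boolean_op(np.logical_not, mask_a))  # [False  True False]
```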
    def __array__(self):
        return self.data

    def sample_coord(self, n_events, random_state=0):
        """Sample position and energy of events.

        Parameters
        ----------
        n_events : int
            Number of events to sample.
        random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
            Defines random number generator initialisation.
            Passed to `~gammapy.utils.random.get_random_state`.

        Returns
        -------
        coords : `~gammapy.maps.MapCoord`
            Sequence of coordinates and energies of the sampled events.
        """

        random_state = get_random_state(random_state)
        sampler = InverseCDFSampler(pdf=self.data, random_state=random_state)

        coords_pix = sampler.sample(n_events)
        coords = self.geom.pix_to_coord(coords_pix[::-1])

        # TODO: pix_to_coord should return a MapCoord object
        cdict = OrderedDict(zip(self.geom.axes_names, coords))

        return MapCoord.create(cdict, frame=self.geom.frame)
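`sample_coord` delegates the sampling to `InverseCDFSampler`. The core of inverse-CDF sampling on a binned PDF is sketched below in NumPy. This is a generic illustration of the technique, not the gammapy sampler, whose details (for example, sub-pixel offsets) may differ. `sample_pixels` and the PDF array are hypothetical. `np.unravel_index` returns indices in data order. This is presumably why the method reverses `coords_pix` before calling `pix_to_coord`, although that is an assumption about the sampler's output order.

```python
import numpy as np


def sample_pixels(pdf, n_events, seed=0):
    """Draw integer pixel indices from a non-negative array via its flattened CDF."""
    rng = np.random.default_rng(seed)
    cdf = np.cumsum(pdf.ravel())
    cdf = cdf / cdf[-1]  # normalize so that the last value is exactly 1
    u = rng.random(n_events)  # uniform draws in [0, 1)
    # side="right" never selects bins whose probability is zero.
    flat_idx = np.searchsorted(cdf, u, side="right")
    return np.unravel_index(flat_idx, pdf.shape)  # indices in data order


pdf = np.array([[0.0, 1.0], [3.0, 0.0]])
rows, cols = sample_pixels(pdf, n_events=1000)
```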