Passed
Pull Request — master (#1898)
by
unknown
04:26
created

gammapy/maps/base.py (1 issue)

1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
0 ignored issues
show
Too many lines in module (1070/1000)
Loading history...
2
from __future__ import absolute_import, division, print_function, unicode_literals
3
import abc
4
import copy
5
import inspect
6
import json
7
import numpy as np
8
from collections import OrderedDict
9
from astropy import units as u
10
from astropy.utils.misc import InheritDocstrings
11
from astropy.io import fits
12
from .geom import pix_tuple_to_idx, MapCoord
13
from .utils import unpack_seq
14
from ..extern import six
15
from ..utils.scripts import make_path
16
17
# Only the abstract base class is exported from this module; the concrete map
# classes are imported lazily from sibling modules of this package.
__all__ = ["Map"]
18
19
20
class MapMeta(InheritDocstrings, abc.ABCMeta):
    """Metaclass used by `Map`.

    Combines Astropy's `~astropy.utils.misc.InheritDocstrings` with
    `abc.ABCMeta`, so that `Map` can declare abstract methods while also
    using Astropy's docstring inheritance.
    """
22
23
24
@six.add_metaclass(MapMeta)
25
class Map(object):
26
    """Abstract map class.
27
28
    This can represent WCS- or HEALPIX-based maps
29
    with 2 spatial dimensions and N non-spatial dimensions.
30
31
    Parameters
32
    ----------
33
    geom : `~gammapy.maps.MapGeom`
34
        Geometry
35
    data : `~numpy.ndarray`
36
        Data array
37
    meta : `~collections.OrderedDict`
38
        Dictionary to store meta data.
39
    unit : str or `~astropy.units.Unit`
40
        Data unit
41
    """
42
43
    def __init__(self, geom, data, meta=None, unit=""):
        # The ``data`` setter validates the array shape against
        # ``self.geom.data_shape``, so the geometry must be assigned first.
        self.geom = geom
        self.data = data
        self.unit = unit
        # The ``meta`` setter converts whatever it receives to an OrderedDict.
        self.meta = {} if meta is None else meta
52
53
    def _init_copy(self, **kwargs):
        """Create a new map, copying any init argument not given in ``kwargs``.

        The ``__init__`` signature of the concrete class is inspected. Every
        argument except ``self`` and ``dtype`` that the caller did not supply
        is taken from the matching private attribute (``_<name>``) of this
        map and deep-copied.

        Parameters
        ----------
        **kwargs : dict
            Init arguments that override the values copied from ``self``.

        Returns
        -------
        map : `Map`
            New map instance created via `Map.from_geom`.
        """
        # ``inspect.getargspec`` was removed in Python 3.11; use
        # ``getfullargspec`` when it exists and fall back for Python 2.
        getargspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec
        argnames = [
            name
            for name in getargspec(self.__init__).args
            if name not in ("self", "dtype")
        ]

        for name in argnames:
            # Only copy what the caller did not override. This avoids needlessly
            # deep-copying large attributes such as the data array.
            if name not in kwargs:
                kwargs[name] = copy.deepcopy(getattr(self, "_" + name))

        return self.from_geom(**kwargs)
65
66
    @property
    def geom(self):
        """Geometry of the map (`~gammapy.maps.MapGeom`)."""
        return self._geom

    @geom.setter
    def geom(self, geometry):
        # Stored as-is; no validation is performed here.
        self._geom = geometry
74
75
    @property
    def data(self):
        """Underlying data array (`~numpy.ndarray`)."""
        return self._data

    @data.setter
    def data(self, array):
        expected_shape = self.geom.data_shape
        if array.shape != expected_shape:
            message = "Shape {!r} does not match map data shape {!r}".format(
                array.shape, expected_shape
            )
            raise ValueError(message)

        # The unit is kept separately in ``self.unit`` (see the ``quantity``
        # property), so plain arrays are required here.
        if isinstance(array, u.Quantity):
            raise TypeError("Map data must be a Numpy array. Set unit separately")

        self._data = array
92
93
    @property
    def unit(self):
        """Unit of the map data (`~astropy.units.Unit`)."""
        return self._unit

    @unit.setter
    def unit(self, value):
        # Normalise strings and unit-like objects to an astropy Unit.
        self._unit = u.Unit(value)
101
102
    @property
    def meta(self):
        """Metadata dictionary (`~collections.OrderedDict`)."""
        return self._meta

    @meta.setter
    def meta(self, mapping):
        # Always store a new OrderedDict (a shallow copy of ``mapping``).
        self._meta = OrderedDict(mapping)
110
111
    @property
    def quantity(self):
        """Map data with its unit attached (`~astropy.units.Quantity`).

        Built from ``self.data`` and ``self.unit`` with ``copy=False``.
        """
        return u.Quantity(self.data, self.unit, copy=False)

    @quantity.setter
    def quantity(self, value):
        # Split into a plain array and a unit. Data is assigned first, so a
        # shape mismatch raises before the unit is touched.
        as_quantity = u.Quantity(value, copy=False)
        self.data = as_quantity.value
        self.unit = as_quantity.unit
121
122
    @staticmethod
    def create(**kwargs):
        """Create an empty map object.

        The generic options are listed below. All keywords, including
        ``map_type``, are forwarded to the selected factory. See
        `WcsMap.create` for WCS-specific options and `HpxMap.create` for
        HEALPix-specific options.

        Parameters
        ----------
        coordsys : str
            Coordinate system, either Galactic ('GAL') or Equatorial ('CEL').
        map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse'}
            Map type, selecting the class that is instantiated.
            Defaults to 'wcs'.
        binsz : float or `~numpy.ndarray`
            Pixel size in degrees.
        skydir : `~astropy.coordinates.SkyCoord`
            Coordinate of the map center.
        axes : list
            `~MapAxis` objects, one per non-spatial dimension.
            If None the map will be a 2D image.
        dtype : str
            Data type, default is 'float32'.
        unit : str or `~astropy.units.Unit`
            Data unit.
        meta : `~collections.OrderedDict`
            Dictionary to store meta data.

        Returns
        -------
        map : `Map`
            Empty map object.

        Raises
        ------
        ValueError
            If ``map_type`` contains neither 'wcs' nor 'hpx'.
        """
        from .hpxmap import HpxMap
        from .wcsmap import WcsMap

        map_type = kwargs.setdefault("map_type", "wcs")
        lowered = map_type.lower()

        # Substring match, so e.g. 'hpx-sparse' is routed to the HEALPix factory.
        if "wcs" in lowered:
            factory = WcsMap.create
        elif "hpx" in lowered:
            factory = HpxMap.create
        else:
            raise ValueError("Unrecognized map type: {!r}".format(map_type))

        return factory(**kwargs)
167
168
    @staticmethod
    def read(filename, hdu=None, hdu_bands=None, map_type="auto"):
        """Read a map from a FITS file.

        Parameters
        ----------
        filename : str or `~pathlib.Path`
            Name of the FITS file.
        hdu : str or int
            Name or index of the HDU with the map data.
        hdu_bands : str or int
            Name or index of the HDU with the BANDS table. If not given it is
            inferred from the FITS header of the map HDU.
        map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse', 'auto'}
            Map type, selecting the class that is instantiated. It should be
            consistent with the format of the input file. With 'auto' the
            map type is inferred from the file.

        Returns
        -------
        map_out : `Map`
            Map object.
        """
        path = str(make_path(filename))
        # The HDU list is closed when this block exits, so memory mapping is
        # disabled.
        with fits.open(path, memmap=False) as hdulist:
            return Map.from_hdulist(
                hdulist, hdu=hdu, hdu_bands=hdu_bands, map_type=map_type
            )
197
198
    @staticmethod
    def from_geom(geom, meta=None, data=None, map_type="auto", unit=""):
        """Create a map from a `MapGeom` instance.

        If ``data`` is not given, it is left to the concrete map class to
        create it (presumably an empty array).

        Parameters
        ----------
        geom : `MapGeom`
            Map geometry.
        meta : `~collections.OrderedDict`
            Dictionary to store meta data.
        data : `numpy.ndarray`
            Data array.
        map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse', 'auto'}
            Map type, selecting the class that is instantiated. It should be
            consistent with the geometry. With 'auto' the map type is
            inferred from the type of ``geom``.
        unit : str or `~astropy.units.Unit`
            Data unit.

        Returns
        -------
        map_out : `Map`
            Map object.

        Raises
        ------
        ValueError
            If ``map_type`` is 'auto' and ``geom`` is neither an `HpxGeom`
            nor a `WcsGeom`, or if ``map_type`` is not recognized.
        """
        if map_type == "auto":
            from .hpx import HpxGeom
            from .wcs import WcsGeom

            for geom_cls, inferred in ((HpxGeom, "hpx"), (WcsGeom, "wcs")):
                if isinstance(geom, geom_cls):
                    map_type = inferred
                    break
            else:
                raise ValueError("Unrecognized geom type.")

        map_cls = Map._get_map_cls(map_type)
        return map_cls(geom, data=data, meta=meta, unit=unit)
238
239
    @staticmethod
    def from_hdulist(hdulist, hdu=None, hdu_bands=None, map_type="auto"):
        """Create a map from an `astropy.io.fits.HDUList`.

        Parameters
        ----------
        hdulist : `~astropy.io.fits.HDUList`
            HDU list containing the map.
        hdu : str or int
            Name or index of the HDU with the map data.
        hdu_bands : str or int
            Name or index of the HDU with the BANDS table.
        map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse', 'auto'}
            Map type. With 'auto' it is inferred from the HDU header.

        Returns
        -------
        map_out : `Map`
            Map object.
        """
        resolved_type = (
            Map._get_map_type(hdulist, hdu) if map_type == "auto" else map_type
        )
        map_cls = Map._get_map_cls(resolved_type)
        return map_cls.from_hdulist(hdulist, hdu=hdu, hdu_bands=hdu_bands)
246
247
    @staticmethod
    def _get_meta_from_header(header):
        """Load meta data from a FITS header.

        The ``META`` keyword, when present, is parsed as JSON into an
        `~collections.OrderedDict`, preserving key order. Otherwise an empty
        OrderedDict is returned.
        """
        if "META" not in header:
            return OrderedDict()
        return json.loads(header["META"], object_pairs_hook=OrderedDict)
255
256
    @staticmethod
    def _get_map_type(hdu_list, hdu_name):
        """Infer the map type ('hpx' or 'wcs') from a FITS HDU.

        Only headers are inspected, never data, to keep this cheap.

        Parameters
        ----------
        hdu_list : `~astropy.io.fits.HDUList`
            HDU list to inspect.
        hdu_name : str, int or None
            HDU to inspect. If None, the primary HDU is used, or HDU 1 when
            the primary header has ``NAXIS == 0``.

        Returns
        -------
        map_type : str
            'hpx' if the header has ``PIXTYPE = 'HEALPIX'``, else 'wcs'.
        """
        if hdu_name is not None:
            header = hdu_list[hdu_name].header
        else:
            header = hdu_list[0].header
            # Empty primary HDU: fall back to the first extension.
            # NOTE(review): raises IndexError if there is no extension.
            if header["NAXIS"] == 0:
                header = hdu_list[1].header

        is_healpix = ("PIXTYPE" in header) and (header["PIXTYPE"] == "HEALPIX")
        return "hpx" if is_healpix else "wcs"
274
275
    @staticmethod
    def _get_map_cls(map_type):
        """Return the concrete map class for a given ``map_type`` string.

        Imports are done lazily inside the branches to avoid circular imports.
        Ideally this would be a registry dict, so that users could add map
        types to the `gammapy.maps` I/O (cf. the Astropy table I/O registry).
        That is non-trivial to do without circular imports.

        Raises
        ------
        NotImplementedError
            For 'wcs-sparse'.
        ValueError
            For any unrecognized ``map_type``.
        """
        if map_type == "wcs":
            from .wcsnd import WcsNDMap as map_cls
        elif map_type == "wcs-sparse":
            raise NotImplementedError()
        elif map_type == "hpx":
            from .hpxnd import HpxNDMap as map_cls
        elif map_type == "hpx-sparse":
            from .hpxsparse import HpxSparseMap as map_cls
        else:
            raise ValueError("Unrecognized map type: {!r}".format(map_type))

        return map_cls
300
301
    def write(self, filename, overwrite=False, **kwargs):
        """Write to a FITS file.

        Parameters
        ----------
        filename : str or `~pathlib.Path`
            Output file name. Normalised with ``make_path`` in the same way
            as in `Map.read`.
        overwrite : bool
            Overwrite existing file?
        hdu : str
            Set the name of the image extension. By default this will
            be set to SKYMAP (for BINTABLE HDU) or PRIMARY (for IMAGE
            HDU).
        hdu_bands : str
            Set the name of the bands table extension. By default this will
            be set to BANDS.
        conv : str
            FITS format convention. By default files will be written
            to the gamma-astro-data-formats (GADF) format. This
            option can be used to write files that are compliant with
            format conventions required by specific software (e.g. the
            Fermi Science Tools). Supported conventions are 'gadf',
            'fgst-ccube', 'fgst-ltcube', 'fgst-bexpcube',
            'fgst-template', 'fgst-srcmap', 'fgst-srcmap-sparse',
            'galprop', and 'galprop2'.
        sparse : bool
            Sparsify the map by dropping pixels with zero amplitude.
            This option is only compatible with the 'gadf' format.
        """
        # Resolve the path the same way ``read`` does, so that read/write
        # accept the same kinds of path arguments.
        filename = str(make_path(filename))
        hdulist = self.to_hdulist(**kwargs)
        hdulist.writeto(filename, overwrite=overwrite)
332
333
    def iter_by_image(self):
        """Iterate over the image planes of the map.

        The image plane index is in data order, so it can be used to index
        the data array directly. See :ref:`mapiter` for further information.

        Yields
        ------
        val : `~numpy.ndarray`
            Array of image plane values.
        idx : tuple
            Index of the image plane (data order).
        """
        for geom_idx in np.ndindex(self.geom.shape):
            # Flip from geometry order to data order before indexing.
            data_idx = geom_idx[::-1]
            yield self.data[data_idx], data_idx
349
350
    def iter_by_pix(self, buffersize=1):
        """Iterate over map elements, yielding values and pixel coordinates.

        Parameters
        ----------
        buffersize : int
            Size of the buffer. The map is returned in chunks of this size.

        Returns
        -------
        iterator
            Yields ``(val, pix)`` chunks, where ``val`` is an
            `~numpy.ndarray` of map values and ``pix`` is a tuple of pixel
            coordinate arrays.
        """
        pix_arrays = list(self.geom.get_idx(flat=True))
        # NOTE(review): non-finite values are dropped here, but ``pix_arrays``
        # is not filtered by the same mask. If ``data`` contains NaN/inf, the
        # lengths may disagree — confirm against ``geom.get_idx``.
        finite_vals = self.data[np.isfinite(self.data)]
        operands = [finite_vals] + pix_arrays
        iterator = np.nditer(
            operands, flags=["external_loop", "buffered"], buffersize=buffersize
        )
        return unpack_seq(iterator)
373
374
    def iter_by_coord(self, buffersize=1):
        """Iterate over map elements, yielding values and map coordinates.

        Parameters
        ----------
        buffersize : int
            Size of the buffer. The map is returned in chunks of this size.

        Returns
        -------
        iterator
            Yields ``(val, coords)`` chunks, where ``val`` is an
            `~numpy.ndarray` of map values and ``coords`` is a tuple of
            map coordinate arrays.
        """
        coord_arrays = list(self.geom.get_coord(flat=True))
        # NOTE(review): non-finite values are dropped here, but
        # ``coord_arrays`` is not filtered by the same mask — confirm the
        # lengths agree when ``data`` contains NaN/inf.
        finite_vals = self.data[np.isfinite(self.data)]
        operands = [finite_vals] + coord_arrays
        iterator = np.nditer(
            operands, flags=["external_loop", "buffered"], buffersize=buffersize
        )
        return unpack_seq(iterator)
397
398
    @abc.abstractmethod
    def sum_over_axes(self):
        """Sum over all non-spatial dimensions, reducing to a 2D image.

        Abstract; implemented by the concrete map classes.
        """
402
403
    def coadd(self, map_in):
        """Add the contents of ``map_in`` to this map.

        Suitable for combining maps of integral quantities (e.g. counts), or
        of differential quantities if both maps have the same binning.

        Parameters
        ----------
        map_in : `Map`
            Input map.

        Raises
        ------
        ValueError
            If the units of the two maps are not equivalent.
        """
        if not self.unit.is_equivalent(map_in.unit):
            raise ValueError("Incompatible units")

        # TODO: Check whether geometries are aligned and if so sum the
        # data vectors directly.
        # Until then, every pixel of ``map_in`` is filled into ``self`` by its
        # coordinates, weighted by its value (with unit attached).
        src_idx = map_in.geom.get_idx()
        src_coords = map_in.geom.get_coord()
        src_vals = u.Quantity(map_in.get_by_idx(src_idx), map_in.unit)
        self.fill_by_coord(src_coords, src_vals)
422
423
    def reproject(self, geom, order=1, mode="interp"):
        """Reproject this map to a different geometry.

        Only spatial axes are reprojected. To reproject non-spatial axes,
        consider using `Map.interp_by_coord()` instead.

        Parameters
        ----------
        geom : `MapGeom`
            Geometry of projection. If it is an image (no non-spatial axes),
            copies of this map's non-spatial axes are added to it.
        order : int or str
            Order of interpolating polynomial (0 = nearest-neighbor, 1 =
            linear, 2 = quadratic, 3 = cubic).
        mode : {'interp', 'exact'}
            Method for reprojection. 'interp' method interpolates at pixel
            centers. 'exact' method integrates over intersection of pixels.

        Returns
        -------
        map : `Map`
            Reprojected map.

        Raises
        ------
        ValueError
            If ``geom`` has non-spatial axes that do not match this map's.
        """
        if geom.is_image:
            axes = [ax.copy() for ax in self.geom.axes]
            geom = geom.copy(axes=axes)
        else:
            same_ndim = geom.ndim == self.geom.ndim
            same_axes = np.all(
                [ax0 == ax1 for ax0, ax1 in zip(geom.axes, self.geom.axes)]
            )
            if not (same_ndim and same_axes):
                # Trailing space: the two literals are concatenated.
                raise ValueError(
                    "Map and target geometry non-spatial axes must match. "
                    "Use interp_by_coord to interpolate in non-spatial axes."
                )

        if geom.is_hpx:
            return self._reproject_to_hpx(geom, mode=mode, order=order)
        return self._reproject_to_wcs(geom, mode=mode, order=order)
464
465
    @abc.abstractmethod
    def pad(self, pad_width, mode="constant", cval=0, order=1):
        """Pad the spatial dimensions of the map.

        The edges of the map are extended by the given number of pixels.

        Parameters
        ----------
        pad_width : {sequence, array_like, int}
            Number of pixels padded to the edges of each axis.
        mode : {'edge', 'constant', 'interp'}
            Padding mode. 'edge' pads with the closest edge value.
            'constant' pads with a constant value. 'interp' pads with
            an extrapolated value.
        cval : float
            Padding value when ``mode='constant'``.
        order : int
            Order of the interpolation used when ``mode='interp'`` (0 =
            nearest-neighbor, 1 = linear, 2 = quadratic, 3 = cubic).

        Returns
        -------
        map : `Map`
            Padded map.
        """
491
492
    @abc.abstractmethod
    def crop(self, crop_width):
        """Crop the spatial dimensions of the map.

        A number of pixels is removed from the edges of the map.

        Parameters
        ----------
        crop_width : {sequence, array_like, int}
            Number of pixels cropped from the edges of each axis.
            Defined analogously to ``pad_width`` from `numpy.pad`.

        Returns
        -------
        map : `Map`
            Cropped map.
        """
509
510
    @abc.abstractmethod
    def downsample(self, factor, preserve_counts=True):
        """Downsample the spatial dimensions by a given factor.

        Parameters
        ----------
        factor : int
            Downsampling factor.
        preserve_counts : bool
            Preserve the integral over each bin. Use True for integral
            quantities (e.g. counts) and False for differential quantities
            (e.g. intensity).

        Returns
        -------
        map : `Map`
            Downsampled map.
        """
529
530
    @abc.abstractmethod
    def upsample(self, factor, order=0, preserve_counts=True):
        """Upsample the spatial dimensions by a given factor.

        Parameters
        ----------
        factor : int
            Upsampling factor.
        order : int
            Order of the interpolation used for upsampling.
        preserve_counts : bool
            Preserve the integral over each bin. Use True for integral
            quantities (e.g. counts) and False for differential quantities
            (e.g. intensity).

        Returns
        -------
        map : `Map`
            Upsampled map.
        """
552
553
    def slice_by_idx(self, slices):
        """Slice a sub-map out of this map.

        For usage examples, see :ref:`mapslicing`.

        Parameters
        ----------
        slices : dict
            Maps non-spatial axis names to an integer or a `slice`. An
            integer index drops the corresponding axis from the result.
            Axes not present in the dict are kept unchanged.

        Returns
        -------
        map_out : `Map`
            Sliced map object.
        """
        sliced_geom = self.geom.slice_by_idx(slices)
        # One entry per non-spatial axis, in geometry axis order; reversed
        # below to index the data array.
        axis_slices = tuple(
            slices.get(ax.name, slice(None)) for ax in self.geom.axes
        )
        sliced_data = self.data[axis_slices[::-1]]
        return self.__class__(
            geom=sliced_geom, data=sliced_data, unit=self.unit, meta=self.meta
        )
575
576
    def get_image_by_coord(self, coords):
        """Return the spatial map at the given non-spatial axis coordinates.

        Parameters
        ----------
        coords : tuple or dict
            A tuple ordered as (x_0, ..., x_n), where x_i are coordinates for
            the non-spatial dimensions of the map; or a dict keyed by the
            non-spatial axis names, e.g. {'axes0': x_0, ..., 'axesn': x_n}.

        Returns
        -------
        map_out : `Map`
            Map with spatial dimensions only.

        See Also
        --------
        get_image_by_idx, get_image_by_pix

        Examples
        --------
        ::

            import numpy as np
            from gammapy.maps import Map, MapAxis
            from astropy.coordinates import SkyCoord
            from astropy import units as u

            # Define map axes
            energy_axis = MapAxis.from_edges(
                np.logspace(-1., 1., 4), unit='TeV', name='energy',
            )

            time_axis = MapAxis.from_edges(
                np.linspace(0., 10, 20), unit='h', name='time',
            )

            # Define map center
            skydir = SkyCoord(0, 0, frame='galactic', unit='deg')

            # Create map
            m_wcs = Map.create(
                map_type='wcs',
                binsz=0.02,
                skydir=skydir,
                width=10.0,
                axes=[energy_axis, time_axis],
            )

            # Get image by coord tuple
            image = m_wcs.get_image_by_coord(('500 GeV', '1 h'))

            # Get image by coord dict with strings
            image = m_wcs.get_image_by_coord({'energy': '500 GeV', 'time': '1 h'})

            # Get image by coord dict with quantities
            image = m_wcs.get_image_by_coord({'energy': 0.5 * u.TeV, 'time': 1 * u.h})
        """
        if isinstance(coords, tuple):
            # Positional form: pair the values with the axis names in order.
            names = [ax.name for ax in self.geom.axes]
            coords = OrderedDict(zip(names, coords))

        idx = [ax.coord_to_idx(coords[ax.name]) for ax in self.geom.axes]
        return self.get_image_by_idx(idx)
645
646
    def get_image_by_pix(self, pix):
        """Return the spatial map at the given non-spatial pixel coordinates.

        Parameters
        ----------
        pix : tuple
            Scalar pixel coordinates, one per non-spatial dimension, ordered
            as (I_0, ..., I_n). Coordinates may be float or integer.

        Returns
        -------
        map_out : `Map`
            Map with spatial dimensions only.

        See Also
        --------
        get_image_by_coord, get_image_by_idx
        """
        return self.get_image_by_idx(self.geom.pix_to_idx(pix))
667
668
    def get_image_by_idx(self, idx):
        """Return the spatial map at the given non-spatial pixel indices.

        Parameters
        ----------
        idx : tuple
            Scalar indices, one per non-spatial dimension, ordered as
            (I_0, ..., I_n).

        Returns
        -------
        map_out : `Map`
            Map with spatial dimensions only.

        Raises
        ------
        ValueError
            If ``idx`` does not have one entry per non-spatial axis.

        See Also
        --------
        get_image_by_coord, get_image_by_pix
        """
        if len(idx) != len(self.geom.axes):
            raise ValueError("Tuple length must equal number of non-spatial dimensions")

        # Only scalar indices per axis are supported; coerce each to ``int``.
        plane_idx = tuple(int(i) for i in idx)

        image_geom = self.geom.to_image()
        image_data = self.data[plane_idx[::-1]]
        return self.__class__(
            geom=image_geom, data=image_data, unit=self.unit, meta=self.meta
        )
695
696
    def get_by_coord(self, coords):
        """Return map values at the given map coordinates.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays for each dimension of the map. A tuple should
            be ordered as (lon, lat, x_0, ..., x_n), where x_i are
            coordinates for the non-spatial dimensions.

        Returns
        -------
        vals : `~numpy.ndarray`
            Pixel values, with the dtype of the map data. Coordinates outside
            the map are flagged with ``np.nan``.
        """
        coords = MapCoord.create(coords, coordsys=self.geom.coordsys)
        inside = self.geom.contains(coords)
        # NOTE(review): ``vals`` uses the data dtype, so the NaN flag below
        # is not representable for integer maps — confirm intended behaviour.
        vals = np.empty(coords.shape, dtype=self.data.dtype)
        inside_idx = self.geom.coord_to_idx(coords.apply_mask(inside))
        vals[inside] = self.get_by_idx(inside_idx)
        vals[~inside] = np.nan
        return vals
720
721
    def get_by_pix(self, pix):
        """Return map values at the given pixel coordinates.

        Parameters
        ----------
        pix : tuple
            Tuple of pixel index arrays for each dimension of the map.
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
            Pixel indices can be either float or integer type, and the
            arrays are broadcast against each other.

        Returns
        -------
        vals : `~numpy.ndarray`
            Array of pixel values. np.nan is used to flag coordinates
            outside of the map.
        """
        # FIXME: Support local indexing here?
        # FIXME: Support slicing?
        # ``np.array(p, copy=False, ndmin=1)`` raises under NumPy >= 2.0
        # whenever a copy is required (e.g. list input). ``asarray`` +
        # ``atleast_1d`` gives "copy only if needed, at least 1-D" on all
        # NumPy versions.
        pix = [np.atleast_1d(np.asarray(p)) for p in pix]
        pix = np.broadcast_arrays(*pix)
        msk = self.geom.contains_pix(pix)
        vals = np.empty(pix[0].shape, dtype=self.data.dtype)
        pix = tuple(p[msk] for p in pix)
        idx = self.geom.pix_to_idx(pix)
        vals[msk] = self.get_by_idx(idx)
        vals[~msk] = np.nan
        return vals
749
750
    @abc.abstractmethod
    def get_by_idx(self, idx):
        """Return map values at the given pixel indices.

        Parameters
        ----------
        idx : tuple
            Pixel index arrays, one per map dimension, ordered as
            (I_lon, I_lat, I_0, ..., I_n) for WCS maps and
            (I_hpx, I_0, ..., I_n) for HEALPix maps.

        Returns
        -------
        vals : `~numpy.ndarray`
            Array of pixel values. np.nan is used to flag coordinates
            outside of the map.
        """
768
769
    @abc.abstractmethod
    def interp_by_coord(self, coords, interp=None, fill_value=None):
        """Interpolate map values at the given map coordinates.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays for each dimension of the map. A tuple should
            be ordered as (lon, lat, x_0, ..., x_n), where x_i are
            coordinates for the non-spatial dimensions.
        interp : {None, 'nearest', 'linear', 'cubic', 0, 1, 2, 3}
            Interpolation method. By default no interpolation is performed
            and the value of the pixel containing each coordinate is
            returned. Integers select the method of that order
            (0='nearest', 1='linear', 2='quadratic', 3='cubic'). Only
            'nearest' and 'linear' are supported for all map types.
        fill_value : None or float value
            Value used for points outside the interpolation domain. If None,
            values outside the domain are extrapolated.

        Returns
        -------
        vals : `~numpy.ndarray`
            Interpolated pixel values.
        """
799
800
    @abc.abstractmethod
    def interp_by_pix(self, pix, interp=None, fill_value=None):
        """Interpolate map values at the given pixel coordinates.

        Parameters
        ----------
        pix : tuple
            Pixel coordinate arrays, one per map dimension, ordered as
            (p_lon, p_lat, p_0, ..., p_n), where p_i are pixel coordinates
            for the non-spatial dimensions.
        interp : {None, 'nearest', 'linear', 'cubic', 0, 1, 2, 3}
            Interpolation method. By default no interpolation is performed
            and the value of the pixel containing each coordinate is
            returned. Integers select the method of that order
            (0='nearest', 1='linear', 2='quadratic', 3='cubic'). Only
            'nearest' and 'linear' are supported for all map types.
        fill_value : None or float value
            Value used for points outside the interpolation domain. If None,
            values outside the domain are extrapolated.

        Returns
        -------
        vals : `~numpy.ndarray`
            Interpolated pixel values.
        """
831
832
    def fill_by_coord(self, coords, weights=None):
        """Fill the pixels containing ``coords`` with the given ``weights``.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Coordinate arrays for each dimension of the map. A tuple should
            be ordered as (lon, lat, x_0, ..., x_n), where x_i are
            coordinates for the non-spatial dimensions.
        weights : `~numpy.ndarray`
            Weights vector. Default is weight of one.
        """
        # Convert to pixel indices and delegate to the index-based filler.
        self.fill_by_idx(self.geom.coord_to_idx(coords), weights)
846
847
    def fill_by_pix(self, pix, weights=None):
        """Fill the pixels at ``pix`` with the given ``weights``.

        Parameters
        ----------
        pix : tuple
            Pixel index arrays, one per map dimension, ordered as
            (I_lon, I_lat, I_0, ..., I_n) for WCS maps and
            (I_hpx, I_0, ..., I_n) for HEALPix maps. Float indices are
            rounded to the nearest integer.
        weights : `~numpy.ndarray`
            Weights vector. Default is weight of one.
        """
        return self.fill_by_idx(pix_tuple_to_idx(pix), weights=weights)
863
864
    @abc.abstractmethod
    def fill_by_idx(self, idx, weights=None):
        """Fill the pixels at ``idx`` with the given ``weights``.

        Parameters
        ----------
        idx : tuple
            Pixel index arrays, one per map dimension, ordered as
            (I_lon, I_lat, I_0, ..., I_n) for WCS maps and
            (I_hpx, I_0, ..., I_n) for HEALPix maps.
        weights : `~numpy.ndarray`
            Weights vector. Default is weight of one.
        """
878
879
    def set_by_coord(self, coords, vals):
880
        """Set pixels at ``coords`` with given ``vals``.
881
882
        Parameters
883
        ----------
884
        coords : tuple or `~gammapy.maps.MapCoord`
885
            Coordinate arrays for each dimension of the map.  Tuple
886
            should be ordered as (lon, lat, x_0, ..., x_n) where x_i
887
            are coordinates for non-spatial dimensions of the map.
888
        vals : `~numpy.ndarray`
889
            Values vector.
890
        """
891
        idx = self.geom.coord_to_pix(coords)
892
        self.set_by_pix(idx, vals)
893
894
    def set_by_pix(self, pix, vals):
        """Assign ``vals`` to the map at pixel coordinates ``pix``.

        The pixel coordinates are converted to integer indices with
        ``pix_tuple_to_idx``; the result of `set_by_idx` is returned.

        Parameters
        ----------
        pix : tuple
            One pixel-coordinate array per map dimension, ordered as
            (I_lon, I_lat, I_0, ..., I_n) for WCS maps and
            (I_hpx, I_0, ..., I_n) for HEALPix maps.  Both float and
            integer values are accepted; floats are rounded to the nearest
            integer.
        vals : `~numpy.ndarray`
            Values vector.
        """
        indices = pix_tuple_to_idx(pix)
        return self.set_by_idx(indices, vals)
910
911
    @abc.abstractmethod
912
    def set_by_idx(self, idx, vals):
913
        """Set pixels at ``idx`` with given ``vals``.
914
915
        Parameters
916
        ----------
917
        idx : tuple
918
            Tuple of pixel index arrays for each dimension of the map.
919
            Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
920
            for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
921
        vals : `~numpy.ndarray`
922
            Values vector.
923
        """
924
        pass
925
926
    def plot_interactive(self, rc_params=None, **kwargs):
        """
        Plot map with interactive widgets to explore the non spatial axes.

        One selection slider is created for every non-spatial axis and a set
        of radio buttons for the image stretch; each change re-plots the
        image selected through ``get_image_by_idx``.

        Parameters
        ----------
        rc_params : dict
            Passed to ``matplotlib.rc_context(rc=rc_params)`` to style the plot.
        **kwargs : dict
            Keyword arguments passed to the ``plot`` method of the selected
            image (e.g. `WcsNDMap.plot`).  ``interpolation``, ``origin`` and
            ``cmap`` default to ``"nearest"``, ``"lower"`` and ``"afmhot"``.
            ``stretch`` (default ``"sqrt"``) is not forwarded as is; it sets
            the initial value of the stretch selector.

        Raises
        ------
        TypeError
            If the map is a 2D image (``self.geom.is_image``); use ``plot``.

        Examples
        --------
        You can try this out e.g. using a Fermi-LAT diffuse model cube with an energy axis::

            from gammapy.maps import Map

            m = Map.read("$GAMMAPY_EXTRA/datasets/vela_region/gll_iem_v05_rev1_cutout.fits")
            m.plot_interactive(cmap='gnuplot2')

        If you would like to adjust the figure size you can use the ``rc_params`` argument::

            rc_params = {'figure.figsize': (12, 6), 'font.size': 12}
            m.plot_interactive(rc_params=rc_params)
        """
        import matplotlib as mpl
        import matplotlib.pyplot as plt
        from ipywidgets.widgets.interaction import interact
        from ipywidgets import SelectionSlider, RadioButtons

        if self.geom.is_image:
            raise TypeError("Use .plot() for 2D Maps")

        kwargs.setdefault("interpolation", "nearest")
        kwargs.setdefault("origin", "lower")
        kwargs.setdefault("cmap", "afmhot")

        rc_params = rc_params or {}
        stretch = kwargs.pop("stretch", "sqrt")

        interact_kwargs = {}

        for axis in self.geom.axes:
            if axis.node_type == "edges":
                labels = [
                    "{:.2e} - {:.2e} {}".format(val_min, val_max, axis.unit)
                    for val_min, val_max in zip(axis.edges[:-1], axis.edges[1:])
                ]
            else:
                labels = ["{:.2e} {}".format(val, axis.unit) for val in axis.center]

            # Use (label, bin index) pairs so the slider reports the bin index
            # directly. Looking the label up again with ``list.index`` would
            # return the first match when two bins format to the same string
            # (possible with the 3-significant-digit format above), which
            # displays the wrong slice.
            interact_kwargs[axis.name] = SelectionSlider(
                options=[(label, bin_idx) for bin_idx, label in enumerate(labels)],
                description="Select {}:".format(axis.name),
                continuous_update=False,
                style={"description_width": "initial"},
                layout={"width": "50%"},
            )

        interact_kwargs["stretch"] = RadioButtons(
            options=["linear", "sqrt", "log"],
            value=stretch,
            description="Select stretch:",
            style={"description_width": "initial"},
        )

        @interact(**interact_kwargs)
        def _plot_interactive(**ikwargs):
            # One bin index per non-spatial axis, in axis order.
            idx = tuple(ikwargs[ax.name] for ax in self.geom.axes)
            img = self.get_image_by_idx(idx)
            with mpl.rc_context(rc=rc_params):
                img.plot(stretch=ikwargs["stretch"], **kwargs)
                plt.show()
1004
1005
    def copy(self, **kwargs):
1006
        """Copy map instance and overwrite given attributes, except for geometry.
1007
1008
        Parameters
1009
        ----------
1010
        **kwargs : dict
1011
            Keyword arguments to overwrite in the map constructor.
1012
1013
        Returns
1014
        --------
1015
        copy : `Map`
1016
            Copied Map.
1017
        """
1018
        if "geom" in kwargs:
1019
            raise ValueError("Can't copy and change geometry of the map.")
1020
        return self._init_copy(**kwargs)
1021
1022
    def __repr__(self):
1023
        str_ = self.__class__.__name__
1024
        str_ += "\n\n"
1025
        geom = self.geom.__class__.__name__
1026
        str_ += "\tgeom  : {} \n ".format(geom)
1027
        axes = ["skycoord"] if self.geom.is_hpx else ["lon", "lat"]
1028
        axes = axes + [_.name for _ in self.geom.axes]
1029
        str_ += "\taxes  : {}\n".format(", ".join(axes))
1030
        str_ += "\tshape : {}\n".format(self.geom.data_shape[::-1])
1031
        str_ += "\tndim  : {}\n".format(self.geom.ndim)
1032
        str_ += "\tunit  : {!r} \n".format(str(self.unit))
1033
        str_ += "\tdtype : {} \n".format(self.data.dtype)
1034
        return str_
1035
1036
    def _arithmetics(self, operator, other, copy):
        """Apply the binary ``operator`` to this map's quantity and ``other``.

        Parameters
        ----------
        operator : callable
            Binary callable such as `numpy.add`, invoked as
            ``operator(map_quantity, other_quantity)``.
        other : `Map` or value convertible to `~astropy.units.Quantity`
            Right-hand operand.  For a `Map` its ``quantity`` is used;
            anything else is wrapped with ``u.Quantity(other, copy=False)``.
        copy : bool
            If True the result is stored in ``self.copy()``; otherwise
            ``self`` is modified in place.

        Returns
        -------
        out : `Map`
            The map holding the result (``self`` when ``copy`` is False).
        """
        if not isinstance(other, Map):
            rhs = u.Quantity(other, copy=False)
        else:
            # TODO: check geometry consistency between the two maps
            rhs = other.quantity

        target = self if not copy else self.copy()
        target.quantity = operator(target.quantity, rhs)
        return target
1047
1048
    def __add__(self, other):
1049
        return self._arithmetics(np.add, other, copy=True)
1050
1051
    def __iadd__(self, other):
1052
        return self._arithmetics(np.add, other, copy=False)
1053
1054
    def __sub__(self, other):
1055
        return self._arithmetics(np.subtract, other, copy=True)
1056
1057
    def __isub__(self, other):
1058
        return self._arithmetics(np.subtract, other, copy=False)
1059
1060
    def __mul__(self, other):
1061
        return self._arithmetics(np.multiply, other, copy=True)
1062
1063
    def __imul__(self, other):
1064
        return self._arithmetics(np.multiply, other, copy=False)
1065
1066
    def __truediv__(self, other):
1067
        return self._arithmetics(np.true_divide, other, copy=True)
1068
1069
    def __itruediv__(self, other):
1070
        return self._arithmetics(np.true_divide, other, copy=False)
1071