Passed
Pull Request — master (#2446)
by Axel
02:42
created

gammapy.maps.wcsnd.WcsNDMap.stack()   A

Complexity

Conditions 5

Size

Total Lines 27
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 27
rs 9.2833
c 0
b 0
f 0
cc 5
nop 3
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import logging
3
from collections import OrderedDict
4
import numpy as np
5
import scipy.interpolate
6
import scipy.ndimage
7
import scipy.signal
8
import astropy.units as u
9
from astropy.convolution import Tophat2DKernel
10
from astropy.io import fits
11
from astropy.nddata import Cutout2D
12
from gammapy.extern.skimage import block_reduce
13
from gammapy.utils.interpolation import ScaledRegularGridInterpolator
14
from gammapy.utils.random import InverseCDFSampler, get_random_state
15
from gammapy.utils.units import unit_from_fits_image_hdu
16
from .geom import MapCoord, pix_tuple_to_idx
17
from .reproject import reproject_car_to_hpx, reproject_car_to_wcs
18
from .utils import INVALID_INDEX, interp_to_order
19
from .wcs import _check_width
20
from .wcsmap import WcsGeom, WcsMap
21
22
# Names exported by ``from ... import *``.
__all__ = ["WcsNDMap"]

# Module-level logger (used e.g. when plot axis labels cannot be set).
log = logging.getLogger(name=__name__)
25
26
27
class WcsNDMap(WcsMap):
    """WCS map with any number of non-spatial dimensions.

    This class uses an ND numpy array to store map values. For maps with
    non-spatial dimensions and variable pixel size it will allocate an
    array with dimensions commensurate with the largest image plane.

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        WCS geometry object.
    data : `~numpy.ndarray`
        Data array. If None, a zero-filled array of shape ``geom.data_shape``
        is allocated (pixels with NaN longitude are set to NaN for irregular
        or all-sky geometries).
    dtype : str, optional
        Data type of the allocated array, default is float32.
    meta : `dict`
        Dictionary to store meta data.
    unit : str or `~astropy.units.Unit`
        The map unit
    """

    def __init__(self, geom, data=None, dtype="float32", meta=None, unit=""):
        # TODO: Figure out how to mask pixels for integer data types
        data_shape = geom.data_shape

        if data is None:
            data = self._make_default_data(geom, data_shape, dtype)

        super().__init__(geom, data, meta, unit)

    @staticmethod
    def _make_default_data(geom, shape_np, dtype):
        """Allocate a zero-filled data array for ``geom``.

        For irregular or all-sky geometries, pixels whose longitude
        coordinate is NaN (i.e. outside the valid projection) are set to NaN.
        """
        data = np.zeros(shape_np, dtype=dtype)

        if not geom.is_regular or geom.is_allsky:
            coords = geom.get_coord()
            is_nan = np.isnan(coords.lon)
            data[is_nan] = np.nan

        return data

    @classmethod
    def from_hdu(cls, hdu, hdu_bands=None):
        """Make a WcsNDMap object from a FITS HDU.

        Parameters
        ----------
        hdu : `~astropy.io.fits.BinTableHDU` or `~astropy.io.fits.ImageHDU`
            The map FITS HDU.
        hdu_bands : `~astropy.io.fits.BinTableHDU`
            The BANDS table HDU.

        Returns
        -------
        map_out : `WcsNDMap`
            Map built from the HDU.
        """
        geom = WcsGeom.from_header(hdu.header, hdu_bands)
        shape = tuple([ax.nbin for ax in geom.axes])
        shape_wcs = tuple([np.max(geom.npix[0]), np.max(geom.npix[1])])

        meta = cls._get_meta_from_header(hdu.header)
        unit = unit_from_fits_image_hdu(hdu.header)
        map_out = cls(geom, meta=meta, unit=unit)

        # TODO: Should we support extracting slices?
        if isinstance(hdu, fits.BinTableHDU):
            # Table format: "PIX" holds flattened spatial indices and the
            # optional "CHANNEL" column holds flattened non-spatial indices.
            pix = hdu.data.field("PIX")
            pix = np.unravel_index(pix, shape_wcs[::-1])
            vals = hdu.data.field("VALUE")
            if "CHANNEL" in hdu.data.columns.names and shape:
                chan = hdu.data.field("CHANNEL")
                chan = np.unravel_index(chan, shape[::-1])
                idx = chan + pix
            else:
                idx = pix

            map_out.set_by_idx(idx[::-1], vals)
        else:
            map_out.data = hdu.data

        return map_out

    def get_by_idx(self, idx):
        """Return map values at the given pixel index tuple (lon, lat, ...)."""
        idx = pix_tuple_to_idx(idx)
        return self.data.T[idx]

    def interp_by_coord(self, coords, interp=None, fill_value=None):
        """Interpolate map values at the given map coordinates.

        For regular geometries the coordinates are converted to pixel
        coordinates and passed to `interp_by_pix`. For irregular geometries
        ``scipy.interpolate.griddata`` is used and ``fill_value`` is ignored.
        """
        if self.geom.is_regular:
            pix = self.geom.coord_to_pix(coords)
            return self.interp_by_pix(pix, interp=interp, fill_value=fill_value)
        else:
            return self._interp_by_coord_griddata(coords, interp=interp)

    def interp_by_pix(self, pix, interp=None, fill_value=None):
        """Interpolate map values at the given pixel coordinates.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates in (lon, lat, <non-spatial axes>) order.
        interp : str or int, optional
            Interpolation method, converted to an order with
            ``interp_to_order``. Orders 0/1 use a regular-grid interpolator,
            orders 2/3 use ``scipy.ndimage.map_coordinates``.
        fill_value : float, optional
            Value for points outside the map (orders 0 and 1 only).

        Raises
        ------
        ValueError
            If the geometry is irregular or the order is not in 0-3.
        """
        if not self.geom.is_regular:
            raise ValueError("interp_by_pix only supported for regular geom.")

        order = interp_to_order(interp)
        if order == 0 or order == 1:
            return self._interp_by_pix_linear_grid(
                pix, order=order, fill_value=fill_value
            )
        elif order == 2 or order == 3:
            return self._interp_by_pix_map_coordinates(pix, order=order)
        else:
            raise ValueError(f"Invalid interpolation order: {order!r}")

    def _interp_by_pix_linear_grid(self, pix, order=1, fill_value=None):
        """Nearest (order 0) or linear (order 1) interpolation on the pixel grid."""
        # TODO: Cache interpolator
        method_lookup = {0: "nearest", 1: "linear"}
        try:
            method = method_lookup[order]
        except KeyError:
            raise ValueError(f"Invalid interpolation order: {order!r}")

        grid_pix = [np.arange(n, dtype=float) for n in self.data.shape[::-1]]

        # NOTE(review): non-finite values are replaced by 0 whenever at least
        # one value is finite; if every value is non-finite the data is used
        # unchanged.
        if np.any(np.isfinite(self.data)):
            data = self.data.copy().T
            data[~np.isfinite(data)] = 0.0
        else:
            data = self.data.T

        fn = ScaledRegularGridInterpolator(
            grid_pix, data, fill_value=fill_value, bounds_error=False, method=method
        )
        return fn(tuple(pix), clip=False)

    def _interp_by_pix_map_coordinates(self, pix, order=1):
        """Spline interpolation with ``scipy.ndimage.map_coordinates``."""
        # map_coordinates needs array-valued coordinates of at least 1 dim.
        pix = tuple(
            [
                np.array(x, ndmin=1)
                if not isinstance(x, np.ndarray) or x.ndim == 0
                else x
                for x in pix
            ]
        )
        return scipy.ndimage.map_coordinates(
            self.data.T, pix, order=order, mode="nearest"
        )

    def _interp_by_coord_griddata(self, coords, interp=None):
        """Interpolate on irregular geometries with ``scipy.interpolate.griddata``.

        Points where the requested method yields a non-finite value are
        re-evaluated with nearest-neighbour interpolation.
        """
        order = interp_to_order(interp)
        method_lookup = {0: "nearest", 1: "linear", 3: "cubic"}
        method = method_lookup.get(order, None)
        if method is None:
            raise ValueError(f"Invalid interp: {interp!r}")

        grid_coords = tuple(self.geom.get_coord(flat=True))
        data = self.data[np.isfinite(self.data)]
        vals = scipy.interpolate.griddata(
            grid_coords, data, tuple(coords), method=method
        )

        m = ~np.isfinite(vals)
        if np.any(m):
            vals_fill = scipy.interpolate.griddata(
                grid_coords, data, tuple([c[m] for c in coords]), method="nearest"
            )
            vals[m] = vals_fill

        return vals

    def fill_by_idx(self, idx, weights=None):
        """Accumulate ``weights`` (default 1) into the pixels given by ``idx``.

        Indices equal to ``INVALID_INDEX.int`` are dropped. Repeated indices
        are summed via ``np.bincount``.
        """
        idx = pix_tuple_to_idx(idx)
        msk = np.all(np.stack([t != INVALID_INDEX.int for t in idx]), axis=0)
        idx = [t[msk] for t in idx]

        if weights is not None:
            if isinstance(weights, u.Quantity):
                weights = weights.to_value(self.unit)
            weights = weights[msk]

        idx = np.ravel_multi_index(idx, self.data.T.shape)
        idx, idx_inv = np.unique(idx, return_inverse=True)
        weights = np.bincount(idx_inv, weights=weights).astype(self.data.dtype)
        self.data.T.flat[idx] += weights

    def set_by_idx(self, idx, vals):
        """Set map values at the given pixel index tuple (lon, lat, ...)."""
        idx = pix_tuple_to_idx(idx)
        self.data.T[idx] = vals

    def sum_over_axes(self, keepdims=False):
        """To sum map values over all non-spatial axes.

        Parameters
        ----------
        keepdims : bool, optional
            If this is set to true, the axes which are summed over are left in
            the map with a single bin

        Returns
        -------
        map_out : WcsNDMap
            Map with non-spatial axes summed over
        """
        axis = tuple(range(self.data.ndim - 2))
        geom = self.geom.to_image()
        if keepdims:
            for ax in self.geom.axes:
                geom = geom.to_cube([ax.squash()])
        data = np.nansum(self.data, axis=axis, keepdims=keepdims)
        # TODO: summing over the axis can change the unit, handle this correctly
        return self._init_copy(geom=geom, data=data)

    def _reproject_to_wcs(self, geom, mode="interp", order=1):
        """Reproject each image plane onto the WCS geometry ``geom``.

        NOTE(review): ``order`` is accepted but not used in this method.
        """
        from reproject import reproject_interp, reproject_exact

        data = np.empty(geom.data_shape)

        for img, idx in self.iter_by_image():
            # TODO: Create WCS object for image plane if
            # multi-resolution geom
            shape_out = geom.get_image_shape(idx)[::-1]

            if self.geom.projection == "CAR" and self.geom.is_allsky:
                vals, footprint = reproject_car_to_wcs(
                    (img, self.geom.wcs), geom.wcs, shape_out=shape_out
                )
            elif mode == "interp":
                vals, footprint = reproject_interp(
                    (img, self.geom.wcs), geom.wcs, shape_out=shape_out
                )
            elif mode == "exact":
                vals, footprint = reproject_exact(
                    (img, self.geom.wcs), geom.wcs, shape_out=shape_out
                )
            else:
                raise TypeError(f"mode must be 'interp' or 'exact'. Got: {mode!r}")

            data[idx] = vals

        return self._init_copy(geom=geom, data=data)

    def _reproject_to_hpx(self, geom, mode="interp", order=1):
        """Reproject each image plane onto the HEALPix geometry ``geom``."""
        from reproject import reproject_to_healpix

        data = np.empty(geom.data_shape)
        coordsys = "galactic" if geom.coordsys == "GAL" else "icrs"

        for img, idx in self.iter_by_image():
            # TODO: For partial-sky HPX we need to map from full- to
            # partial-sky indices
            if self.geom.projection == "CAR" and self.geom.is_allsky:
                vals, footprint = reproject_car_to_hpx(
                    (img, self.geom.wcs),
                    coordsys,
                    nside=geom.nside,
                    nested=geom.nest,
                    order=order,
                )
            else:
                vals, footprint = reproject_to_healpix(
                    (img, self.geom.wcs),
                    coordsys,
                    nside=geom.nside,
                    nested=geom.nest,
                    order=order,
                )
            data[idx] = vals

        return self._init_copy(geom=geom, data=data)

    def pad(self, pad_width, mode="constant", cval=0, order=1):
        """Pad the spatial dimensions of the map.

        Parameters
        ----------
        pad_width : int or sequence of int
            Number of pixels added on each side. A scalar pads both spatial
            axes equally. A sequence is given in (lon, lat, <non-spatial>)
            order; missing trailing entries are treated as 0.
        mode : str
            Padding mode. ``"interp"`` (or an irregular geometry) uses
            coadding plus interpolation; other modes are passed to
            ``numpy.pad``.
        cval : float
            Fill value used for ``mode="constant"``.
        order : int
            Interpolation order used for ``mode="interp"``.
        """
        if np.isscalar(pad_width):
            pad_width = (pad_width, pad_width)
        # Always give one entry per data axis (non-spatial axes get 0) so the
        # numpy and coadd code paths see a pad width matching the map ndim.
        pad_width = tuple(pad_width)
        pad_width += (0,) * (self.geom.ndim - len(pad_width))

        geom = self.geom.pad(pad_width[:2])
        if self.geom.is_regular and mode != "interp":
            return self._pad_np(geom, pad_width, mode, cval)
        else:
            return self._pad_coadd(geom, pad_width, mode, cval, order)

    def _pad_np(self, geom, pad_width, mode, cval):
        """Pad a map using ``numpy.pad``.

        This method only works for regular geometries but should be more
        efficient when working with large maps.
        """
        kwargs = {}
        if mode == "constant":
            kwargs["constant_values"] = cval

        pad_width = [(t, t) for t in pad_width]
        # pad_width is in (lon, lat, ...) order, the data array is reversed.
        data = np.pad(self.data, pad_width[::-1], mode, **kwargs)
        return self._init_copy(geom=geom, data=data)

    def _pad_coadd(self, geom, pad_width, mode, cval, order):
        """Pad a map manually by coadding the original map with the new map."""
        idx_in = self.geom.get_idx(flat=True)
        idx_in = tuple([t + w for t, w in zip(idx_in, pad_width)])[::-1]
        idx_out = geom.get_idx(flat=True)[::-1]
        map_out = self._init_copy(geom=geom, data=None)
        map_out.coadd(self)

        if mode == "constant":
            # Mark all output pixels, then unmark those covered by the input.
            pad_msk = np.zeros_like(map_out.data, dtype=bool)
            pad_msk[idx_out] = True
            pad_msk[idx_in] = False
            map_out.data[pad_msk] = cval
        elif mode == "interp":
            coords = geom.pix_to_coord(idx_out[::-1])
            m = self.geom.contains(coords)
            coords = tuple([c[~m] for c in coords])
            vals = self.interp_by_coord(coords, interp=order)
            map_out.set_by_coord(coords, vals)
        else:
            raise ValueError(f"Invalid mode: {mode!r}")

        return map_out

    def crop(self, crop_width):
        """Crop ``crop_width`` pixels from each side of the spatial axes.

        A scalar crops both spatial axes equally.
        """
        if np.isscalar(crop_width):
            crop_width = (crop_width, crop_width)

        geom = self.geom.crop(crop_width)
        if self.geom.is_regular:
            slices = [slice(None)] * len(self.geom.axes)
            slices += [
                slice(crop_width[1], int(self.geom.npix[1] - crop_width[1])),
                slice(crop_width[0], int(self.geom.npix[0] - crop_width[0])),
            ]
            data = self.data[tuple(slices)]
            map_out = self._init_copy(geom=geom, data=data)
        else:
            # FIXME: This could be done more efficiently by
            # constructing the appropriate slices for each image plane
            map_out = self._init_copy(geom=geom, data=None)
            map_out.coadd(self)

        return map_out

    def upsample(self, factor, order=0, preserve_counts=True, axis=None):
        """Upsample the spatial axes (or the named ``axis``) by ``factor``.

        If ``preserve_counts`` is True, values are divided by the number of
        new pixels per original pixel (``factor**2`` spatially, ``factor``
        for a single axis).
        """
        geom = self.geom.upsample(factor, axis=axis)
        idx = geom.get_idx()

        if axis is None:
            pix = (
                (idx[0] - 0.5 * (factor - 1)) / factor,
                (idx[1] - 0.5 * (factor - 1)) / factor,
            ) + idx[2:]
        else:
            pix = list(idx)
            idx_ax = self.geom.get_axis_index_by_name(axis)
            pix[idx_ax] = (pix[idx_ax] - 0.5 * (factor - 1)) / factor

        data = scipy.ndimage.map_coordinates(
            self.data.T, tuple(pix), order=order, mode="nearest"
        )

        if preserve_counts:
            # Out-of-place division: map_coordinates keeps the input dtype,
            # and an in-place true division would fail for integer maps.
            data = data / (factor ** 2 if axis is None else factor)

        return self._init_copy(geom=geom, data=data)

    def downsample(self, factor, preserve_counts=True, axis=None):
        """Downsample the spatial axes (or the named ``axis``) by ``factor``.

        Blocks are combined with ``np.nansum`` if ``preserve_counts`` is
        True, otherwise with ``np.nanmean``.
        """
        geom = self.geom.downsample(factor, axis=axis)
        if axis is None:
            block_size = (factor, factor) + (1,) * len(self.geom.axes)
        else:
            block_size = [1] * self.data.ndim
            idx = self.geom.get_axis_index_by_name(axis)
            block_size[-(idx + 1)] = factor

        func = np.nansum if preserve_counts else np.nanmean
        data = block_reduce(self.data, tuple(block_size[::-1]), func=func)

        return self._init_copy(geom=geom, data=data)

    def plot(self, ax=None, fig=None, add_cbar=False, stretch="linear", **kwargs):
        """
        Plot image on matplotlib WCS axes.

        Parameters
        ----------
        ax : `~astropy.visualization.wcsaxes.WCSAxes`, optional
            WCS axis object to plot on.
        fig : `~matplotlib.figure.Figure`
            Figure object.
        add_cbar : bool
            Add color bar?
        stretch : str
            Passed to `astropy.visualization.simple_norm`.
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.pyplot.imshow`.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            Figure object.
        ax : `~astropy.visualization.wcsaxes.WCSAxes`
            WCS axis object
        cbar : `~matplotlib.colorbar.Colorbar` or None
            Colorbar object.
        """
        import matplotlib.pyplot as plt
        from astropy.visualization import simple_norm
        from astropy.visualization.wcsaxes.frame import EllipticalFrame

        if not self.geom.is_image:
            raise TypeError("Use .plot_interactive() for Map dimension > 2")

        if fig is None:
            fig = plt.gcf()

        if ax is None:
            if self.geom.is_allsky:
                ax = fig.add_subplot(
                    1, 1, 1, projection=self.geom.wcs, frame_class=EllipticalFrame
                )
            else:
                ax = fig.add_subplot(1, 1, 1, projection=self.geom.wcs)

        data = self.data.astype(float)

        kwargs.setdefault("interpolation", "nearest")
        kwargs.setdefault("origin", "lower")
        kwargs.setdefault("cmap", "afmhot")

        norm = simple_norm(data[np.isfinite(data)], stretch)
        kwargs.setdefault("norm", norm)

        caxes = ax.imshow(data, **kwargs)
        cbar = fig.colorbar(caxes, ax=ax, label=str(self.unit)) if add_cbar else None

        if self.geom.is_allsky:
            ax = self._plot_format_allsky(ax)
        else:
            ax = self._plot_format(ax)

        # without this the axis limits are changed when calling scatter
        ax.autoscale(enable=False)
        return fig, ax, cbar

    def _plot_format(self, ax):
        """Set Galactic or equatorial axis labels, whichever the axes have."""
        try:
            ax.coords["glon"].set_axislabel("Galactic Longitude")
            ax.coords["glat"].set_axislabel("Galactic Latitude")
        except KeyError:
            ax.coords["ra"].set_axislabel("Right Ascension")
            ax.coords["dec"].set_axislabel("Declination")
        except AttributeError:
            log.info("Can't set coordinate axes. No WCS information available.")
        return ax

    def _plot_format_allsky(self, ax):
        """Remove the frame, set limits, and style the grid for all-sky plots."""
        # Remove frame
        ax.coords.frame.set_linewidth(0)

        # Set plot axis limits
        ymax, xmax = self.data.shape
        xmargin, _ = self.geom.coord_to_pix({"lon": 180, "lat": 0})
        _, ymargin = self.geom.coord_to_pix({"lon": 0, "lat": -90})

        ax.set_xlim(xmargin, xmax - xmargin)
        ax.set_ylim(ymargin, ymax - ymargin)

        ax.text(0, ymax, self.geom.coordsys + " coords")

        # Grid and ticks
        glon_spacing, glat_spacing = 45, 15
        lon, lat = ax.coords
        lon.set_ticks(spacing=glon_spacing * u.deg, color="w", alpha=0.8)
        lat.set_ticks(spacing=glat_spacing * u.deg)
        lon.set_ticks_visible(False)

        lon.set_ticklabel(color="w", alpha=0.8)
        lon.grid(alpha=0.2, linestyle="solid", color="w")
        lat.grid(alpha=0.2, linestyle="solid", color="w")
        return ax

    def smooth(self, width, kernel="gauss", **kwargs):
        """Smooth the map.

        Iterates over 2D image planes, processing one at a time.

        Parameters
        ----------
        width : `~astropy.units.Quantity`, str or float
            Smoothing width given as quantity or float. If a float is given it
            interpreted as smoothing width in pixels. If an (angular) quantity
            is given it converted to pixels using the mean of
            ``geom.pixel_scales``.
            It corresponds to the standard deviation in case of a Gaussian kernel,
            the radius in case of a disk kernel, and the side length in case
            of a box kernel.
        kernel : {'gauss', 'disk', 'box'}
            Kernel shape
        kwargs : dict
            Keyword arguments passed to `~scipy.ndimage.uniform_filter`
            ('box'), `~scipy.ndimage.gaussian_filter` ('gauss') or
            `~scipy.ndimage.convolve` ('disk').

        Returns
        -------
        image : `WcsNDMap`
            Smoothed image (a copy, the original object is unchanged).
        """
        if isinstance(width, (u.Quantity, str)):
            width = u.Quantity(width) / self.geom.pixel_scales.mean()
            width = width.to_value("")

        smoothed_data = np.empty(self.data.shape, dtype=float)

        for img, idx in self.iter_by_image():
            img = img.astype(float)
            if kernel == "gauss":
                data = scipy.ndimage.gaussian_filter(img, width, **kwargs)
            elif kernel == "disk":
                disk = Tophat2DKernel(width)
                disk.normalize("integral")
                data = scipy.ndimage.convolve(img, disk.array, **kwargs)
            elif kernel == "box":
                data = scipy.ndimage.uniform_filter(img, width, **kwargs)
            else:
                raise ValueError(f"Invalid kernel: {kernel!r}")
            smoothed_data[idx] = data

        return self._init_copy(data=smoothed_data)

    def get_spectrum(self, region=None, func=np.nansum):
        """Extract spectrum in a given region.

        The spectrum can be computed by summing (or, more generally, applying ``func``)
        along the spatial axes in each energy bin. This occurs only inside the ``region``,
        which by default is assumed to be the whole spatial extension of the map.

        Parameters
        ----------
        region: `~regions.Region`
             Region (pixel or sky regions accepted).
        func : numpy.ufunc
            Function to reduce the data.

        Returns
        -------
        spectrum : `~gammapy.spectrum.CountsSpectrum`
            Spectrum in the given region.
        """
        from gammapy.spectrum import CountsSpectrum

        energy_axis = self.geom.get_axis_by_name("energy")

        if region is not None:
            mask = self.geom.region_mask([region])
            # Boolean indexing flattens; regroup the selected pixels per
            # energy bin before reducing.
            data = self.data[mask].reshape(energy_axis.nbin, -1)
            data = func(data, axis=1)
        else:
            data = func(self.data, axis=(1, 2))

        edges = energy_axis.edges
        return CountsSpectrum(
            data=data, energy_lo=edges[:-1], energy_hi=edges[1:], unit=self.unit
        )

    def convolve(self, kernel, use_fft=True, **kwargs):
        """
        Convolve map with a kernel.

        If the kernel is two dimensional, it is applied to all image planes likewise.
        If the kernel is higher dimensional it must match the map in the number of
        dimensions and the corresponding kernel is selected for every image plane.

        Parameters
        ----------
        kernel : `~gammapy.cube.PSFKernel` or `numpy.ndarray`
            Convolution kernel.
        use_fft : bool
            Use `scipy.signal.fftconvolve` or `scipy.ndimage.convolve`.
        kwargs : dict
            Keyword arguments passed to `scipy.signal.fftconvolve` or
            `scipy.ndimage.convolve`.

        Returns
        -------
        map : `WcsNDMap`
            Convolved map.
        """
        from gammapy.cube import PSFKernel

        conv_function = scipy.signal.fftconvolve if use_fft else scipy.ndimage.convolve
        convolved_data = np.empty(self.data.shape, dtype=np.float32)
        if use_fft:
            kwargs.setdefault("mode", "same")

        if isinstance(kernel, PSFKernel):
            kmap = kernel.psf_kernel_map
            if not np.allclose(
                self.geom.pixel_scales.deg, kmap.geom.pixel_scales.deg, rtol=1e-5
            ):
                raise ValueError("Pixel size of kernel and map not compatible.")
            kernel = kmap.data

        for img, idx in self.iter_by_image():
            # A 2D kernel is shared by all planes; otherwise pick the kernel
            # plane matching this image. The output is always written to the
            # current plane only (using the kernel index here would overwrite
            # every plane for a 2D kernel).
            kernel_img = kernel if kernel.ndim == 2 else kernel[idx]
            convolved_data[idx] = conv_function(img, kernel_img, **kwargs)

        return self._init_copy(data=convolved_data)

    def apply_edisp(self, edisp):
        """Apply energy dispersion to map. Requires energy axis.

        Parameters
        ----------
        edisp : `gammapy.irf.EnergyDispersion`
            Energy dispersion matrix

        Returns
        -------
        map : `WcsNDMap`
            Map with energy dispersion applied.
        """
        loc = self.geom.get_axis_index_by_name("energy")
        # Move the energy axis last so np.dot contracts it with the matrix
        # rows, then move the reconstructed-energy axis back into place.
        data = np.rollaxis(self.data, loc, len(self.data.shape))
        data = np.dot(data, edisp.pdf_matrix)
        data = np.rollaxis(data, -1, loc)

        e_reco_axis = edisp.e_reco.copy(name="energy")
        geom_reco = self.geom.to_image().to_cube(axes=[e_reco_axis])
        return self._init_copy(geom=geom_reco, data=data)

    def cutout(self, position, width, mode="trim"):
        """
        Create a cutout around a given position.

        Parameters
        ----------
        position : `~astropy.coordinates.SkyCoord`
            Center position of the cutout region.
        width : tuple of `~astropy.coordinates.Angle`
            Angular sizes of the region in (lon, lat) in that specific order.
            If only one value is passed, a square region is extracted.
        mode : {'trim', 'partial', 'strict'}
            Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.

        Returns
        -------
        cutout : `~gammapy.maps.WcsNDMap`
            Cutout map
        """
        geom_cutout = self.geom.cutout(position=position, width=width, mode=mode)

        slices = geom_cutout.cutout_info["parent-slices"]
        cutout_slices = Ellipsis, slices[0], slices[1]

        data = self.data[cutout_slices]

        return self._init_copy(geom=geom_cutout, data=data)

    def stack(self, other, weights=None):
        """Stack cutout into map.

        The data of ``other`` is added in place to this map. ``other`` must
        either have the same geometry or be a cutout of this map's geometry.

        Parameters
        ----------
        other : `WcsNDMap`
            Other map to stack
        weights : `~numpy.ndarray`
            Array to be used as weights. It multiplies the (cutout) data of
            ``other`` and must be broadcastable to it.

        Raises
        ------
        ValueError
            If ``other`` is neither equivalent to nor a cutout of this map.
        """
        if self.geom == other.geom:
            # Select everything; a view of the full array in both maps.
            parent_slices, cutout_slices = Ellipsis, Ellipsis
        elif other.geom.cutout_info is not None and self.geom == other.geom.cutout_info["parent-geom"]:
            slices = other.geom.cutout_info["parent-slices"]
            parent_slices = Ellipsis, slices[0], slices[1]

            slices = other.geom.cutout_info["cutout-slices"]
            cutout_slices = Ellipsis, slices[0], slices[1]
        else:
            raise ValueError("Can only stack equivalent maps or cutout of the same map.")

        data = other.data[cutout_slices]

        if weights is not None:
            data = data * weights

        self.data[parent_slices] += data

    def sample_coord(self, n_events, random_state=0):
        """Sample position and energy of events.

        Parameters
        ----------
        n_events : int
            Number of events to sample.
        random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
            Defines random number generator initialisation.
            Passed to `~gammapy.utils.random.get_random_state`.

        Returns
        -------
        coords : `~gammapy.maps.MapCoord` object.
            Sequence of coordinates and energies of the sampled events.
        """
        random_state = get_random_state(random_state)
        # The map data itself is used as the (unnormalised) sampling pdf.
        sampler = InverseCDFSampler(pdf=self.data, random_state=random_state)

        coords_pix = sampler.sample(n_events)
        coords = self.geom.pix_to_coord(coords_pix[::-1])

        # TODO: pix_to_coord should return a MapCoord object
        axes_names = ["lon", "lat"] + [ax.name for ax in self.geom.axes]
        cdict = OrderedDict(zip(axes_names, coords))
        cdict["energy"] *= self.geom.get_axis_by_name("energy").unit

        return MapCoord.create(cdict, coordsys=self.geom.coordsys)
738