Passed
Pull Request — master (#2343)
by Axel
03:10
created

gammapy.maps.wcsnd.WcsNDMap.sample_coord()   A

Complexity

Conditions 1

Size

Total Lines 29
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 29
rs 9.95
c 0
b 0
f 0
cc 1
nop 3
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
from collections import OrderedDict
3
import logging
4
import numpy as np
5
import scipy.interpolate
6
import scipy.ndimage
7
import scipy.signal
8
from astropy.table import Table
9
import astropy.units as u
10
from astropy.convolution import Tophat2DKernel
11
from astropy.io import fits
12
from astropy.nddata import Cutout2D
13
from gammapy.extern.skimage import block_reduce
14
from gammapy.utils.interpolation import ScaledRegularGridInterpolator
15
from gammapy.utils.units import unit_from_fits_image_hdu
16
from .geom import pix_tuple_to_idx, MapCoord
17
from .reproject import reproject_car_to_hpx, reproject_car_to_wcs
18
from .utils import INVALID_INDEX, interp_to_order
19
from .utils import get_random_state
20
from gammapy.utils.random import InverseCDFSampler, get_random_state
21
from .wcs import _check_width
22
from .wcsmap import WcsGeom, WcsMap
23
24
# Public API of this module: only the map class is exported.
__all__ = ["WcsNDMap"]
25
26
# Module-level logger (used e.g. when axis labels cannot be set on a plot).
log = logging.getLogger(__name__)
27
28
29
class WcsNDMap(WcsMap):
    """WCS map with any number of non-spatial dimensions.

    This class uses an ND numpy array to store map values. For maps with
    non-spatial dimensions and variable pixel size it will allocate an
    array with dimensions commensurate with the largest image plane.

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        WCS geometry object.
    data : `~numpy.ndarray`, optional
        Data array. If None, a default array is allocated: zeros for pixels
        that belong to the geometry and NaN for pixels that do not.
    dtype : str, optional
        Data type, default is float32.
    meta : `dict`, optional
        Dictionary to store meta data.
    unit : str or `~astropy.units.Unit`, optional
        The map unit.
    """
49
50
    def __init__(self, geom, data=None, dtype="float32", meta=None, unit=""):
        """Create a map on ``geom``, allocating default data if none is given."""
        # TODO: Figure out how to mask pixels for integer data types
        shape_np = geom.data_shape
        data = self._make_default_data(geom, shape_np, dtype) if data is None else data
        super().__init__(geom, data, meta, unit)
59
60
    @staticmethod
    def _make_default_data(geom, shape_np, dtype):
        """Allocate the default data array for ``geom``.

        Pixels belonging to the geometry are set to zero. Pixels that do not
        (the unused part of smaller image planes in an irregular geometry, or
        pixels flagged invalid by ``geom.get_idx``) are set to NaN.

        Parameters
        ----------
        geom : `~gammapy.maps.WcsGeom`
            Map geometry.
        shape_np : tuple of int
            Shape of the array to allocate (numpy / data order).
        dtype : str or `~numpy.dtype`
            Data type of the array.

        Returns
        -------
        data : `~numpy.ndarray`
            Newly allocated data array.
        """
        # Check whether corners of each image plane are valid
        coords = []
        if not geom.is_regular:
            for idx in np.ndindex(geom.shape_axes):
                pix = (
                    np.array([0.0, float(geom.npix[0][idx] - 1)]),
                    np.array([0.0, float(geom.npix[1][idx] - 1)]),
                )
                pix += tuple([np.array(2 * [t]) for t in idx])
                coords += geom.pix_to_coord(pix)
        else:
            pix = (
                np.array([0.0, float(geom.npix[0] - 1)]),
                np.array([0.0, float(geom.npix[1] - 1)]),
            )
            pix += tuple([np.array(2 * [0.0]) for i in range(geom.ndim - 2)])
            coords += geom.pix_to_coord(pix)

        if np.all(np.isfinite(np.vstack(coords))):
            if geom.is_regular:
                data = np.zeros(shape_np, dtype=dtype)
            else:
                data = np.full(shape_np, np.nan, dtype=dtype)
                for idx in np.ndindex(geom.shape_axes):
                    # ``idx`` is in geometry order (ax0, ax1, ...) while the
                    # data array is stored as (..., ax1, ax0, lat/y, lon/x).
                    # Build a single flat index tuple: nesting ``idx`` inside
                    # the index would make numpy treat it as a fancy index
                    # on the first axis only.
                    # NOTE: a static checker once flagged ``slice`` as
                    # undefined here; it is the Python builtin.
                    image_idx = idx[::-1] + (
                        slice(geom.npix[1][idx]),
                        slice(geom.npix[0][idx]),
                    )
                    data[image_idx] = 0.0
        else:
            data = np.full(shape_np, np.nan, dtype=dtype)
            idx = geom.get_idx()
            # Only pixels with a valid index in every dimension are zeroed.
            m = np.all(np.stack([t != INVALID_INDEX.int for t in idx]), axis=0)
            data[m] = 0.0

        return data
95
96
    @classmethod
    def from_hdu(cls, hdu, hdu_bands=None):
        """Create a `WcsNDMap` from a FITS HDU.

        Parameters
        ----------
        hdu : `~astropy.io.fits.BinTableHDU` or `~astropy.io.fits.ImageHDU`
            HDU holding the map data. A binary table is read as a sparse map
            (``PIX``/``VALUE`` columns, optionally ``CHANNEL``); any other HDU
            type is read as a dense image.
        hdu_bands : `~astropy.io.fits.BinTableHDU`, optional
            The BANDS table HDU.

        Returns
        -------
        map_out : `WcsNDMap`
            Map read from the HDU.
        """
        geom = WcsGeom.from_header(hdu.header, hdu_bands)
        axes_shape = tuple(ax.nbin for ax in geom.axes)
        image_shape = (np.max(geom.npix[0]), np.max(geom.npix[1]))

        meta = cls._get_meta_from_header(hdu.header)
        unit = unit_from_fits_image_hdu(hdu.header)
        map_out = cls(geom, meta=meta, unit=unit)

        # TODO: Should we support extracting slices?
        if not isinstance(hdu, fits.BinTableHDU):
            map_out.data = hdu.data
            return map_out

        table = hdu.data
        # Flat pixel numbers -> (y, x) indices in data order.
        idx = np.unravel_index(table.field("PIX"), image_shape[::-1])
        vals = table.field("VALUE")
        if axes_shape and "CHANNEL" in table.columns.names:
            chan_idx = np.unravel_index(table.field("CHANNEL"), axes_shape[::-1])
            idx = chan_idx + idx

        # set_by_idx expects geometry order, i.e. reversed data order.
        map_out.set_by_idx(idx[::-1], vals)
        return map_out
132
133
    def get_by_idx(self, idx):
        """Return map values at pixel indices ``idx``.

        ``idx`` is in geometry order (lon/x, lat/y, non-spatial axes...),
        which is the reverse of the numpy data order; hence ``data.T``.
        """
        return self.data.T[pix_tuple_to_idx(idx)]
136
137
    def interp_by_coord(self, coords, interp=None, fill_value=None):
        """Interpolate map values at the given map coordinates.

        Parameters
        ----------
        coords : tuple, dict or `~gammapy.maps.MapCoord`
            Coordinates at which to interpolate.
        interp : None, str or int
            Interpolation method / order (see ``interp_to_order``).
        fill_value : float or None
            Value for points outside the map. Only used for regular
            geometries; ignored by the irregular-geometry code path.

        Returns
        -------
        vals : `~numpy.ndarray`
            Interpolated values.
        """
        if not self.geom.is_regular:
            return self._interp_by_coord_griddata(coords, interp=interp)

        pix = self.geom.coord_to_pix(coords)
        return self.interp_by_pix(pix, interp=interp, fill_value=fill_value)
144
145
    def interp_by_pix(self, pix, interp=None, fill_value=None):
        """Interpolate map values at the given pixel coordinates.

        Orders 0 and 1 use a regular-grid interpolator; orders 2 and 3 use
        spline interpolation via `scipy.ndimage.map_coordinates`.

        Raises
        ------
        ValueError
            If the geometry is not regular or the order is unsupported.
        """
        if not self.geom.is_regular:
            raise ValueError("interp_by_pix only supported for regular geom.")

        order = interp_to_order(interp)
        if order in (0, 1):
            return self._interp_by_pix_linear_grid(
                pix, order=order, fill_value=fill_value
            )
        if order in (2, 3):
            return self._interp_by_pix_map_coordinates(pix, order=order)
        raise ValueError("Invalid interpolation order: {!r}".format(order))
160
161
    def _interp_by_pix_linear_grid(self, pix, order=1, fill_value=None):
        """Nearest (order 0) or linear (order 1) interpolation on the pixel grid.

        Non-finite data values are replaced by zero before interpolating,
        unless no value is finite at all, in which case the data are used
        unchanged.

        Raises
        ------
        ValueError
            If ``order`` is not 0 or 1.
        """
        # TODO: Cache interpolator
        method_lookup = {0: "nearest", 1: "linear"}
        try:
            method = method_lookup[order]
        except KeyError:
            raise ValueError("Invalid interpolation order: {!r}".format(order))

        grid_pix = [np.arange(n, dtype=float) for n in self.data.shape[::-1]]

        finite = np.isfinite(self.data)
        if finite.all() or not finite.any():
            # Nothing to replace (or nothing finite to keep): avoid copying
            # the whole data array.
            data = self.data.T
        else:
            data = self.data.copy().T
            data[~finite.T] = 0.0

        fn = ScaledRegularGridInterpolator(
            grid_pix, data, fill_value=fill_value, bounds_error=False, method=method
        )
        return fn(tuple(pix), clip=False)
181
182
    def _interp_by_pix_map_coordinates(self, pix, order=1):
        """Spline-interpolate at ``pix`` with `scipy.ndimage.map_coordinates`.

        Scalars and 0-d arrays in ``pix`` are promoted to 1-d arrays. Points
        outside the map take the value of the nearest edge pixel.
        """

        def _at_least_1d(x):
            if isinstance(x, np.ndarray) and x.ndim > 0:
                return x
            return np.array(x, ndmin=1)

        pix = tuple(_at_least_1d(x) for x in pix)
        return scipy.ndimage.map_coordinates(
            self.data.T, pix, order=order, mode="nearest"
        )
194
195
    def _interp_by_coord_griddata(self, coords, interp=None):
        """Interpolate an irregular-geometry map with `scipy.interpolate.griddata`.

        Points for which the requested method yields a non-finite value are
        filled with nearest-neighbour values.

        Raises
        ------
        ValueError
            If ``interp`` does not map to order 0, 1 or 3.
        """
        methods = {0: "nearest", 1: "linear", 3: "cubic"}
        method = methods.get(interp_to_order(interp))
        if method is None:
            raise ValueError("Invalid interp: {!r}".format(interp))

        points = tuple(self.geom.get_coord(flat=True))
        values = self.data[np.isfinite(self.data)]
        result = scipy.interpolate.griddata(
            points, values, tuple(coords), method=method
        )

        # "linear"/"cubic" return NaN outside the convex hull of the points.
        missing = ~np.isfinite(result)
        if np.any(missing):
            result[missing] = scipy.interpolate.griddata(
                points, values, tuple([c[missing] for c in coords]), method="nearest"
            )

        return result
216
217
    def fill_by_idx(self, idx, weights=None):
        """Accumulate ``weights`` into the map at pixel indices ``idx``.

        Parameters
        ----------
        idx : tuple
            Pixel indices in geometry order. Entries that are
            ``INVALID_INDEX.int`` in any dimension are skipped.
        weights : array-like, scalar, `~astropy.units.Quantity` or None
            Weight per index. A Quantity is converted to the map unit; a
            scalar is applied to every index. If None, each index adds 1.
        """
        idx = pix_tuple_to_idx(idx)
        msk = np.all(np.stack([t != INVALID_INDEX.int for t in idx]), axis=0)
        idx = [t[msk] for t in idx]

        if weights is not None:
            if isinstance(weights, u.Quantity):
                weights = weights.to_value(self.unit)
            # Accept lists and scalars: one weight per (unmasked) index.
            weights = np.broadcast_to(np.asarray(weights), msk.shape)[msk]

        idx = np.ravel_multi_index(idx, self.data.T.shape)
        # Repeated indices must be summed: ``a.flat[i] += w`` applies only one
        # update per repeated ``i``, so aggregate with bincount first.
        idx, idx_inv = np.unique(idx, return_inverse=True)
        weights = np.bincount(idx_inv, weights=weights).astype(self.data.dtype)
        self.data.T.flat[idx] += weights
231
232
    def set_by_idx(self, idx, vals):
        """Assign ``vals`` at pixel indices ``idx`` (geometry order)."""
        self.data.T[pix_tuple_to_idx(idx)] = vals
235
236
    def sum_over_axes(self, keepdims=False):
        """Sum map values over all non-spatial axes, ignoring NaN.

        Parameters
        ----------
        keepdims : bool, optional
            If True, the summed axes are kept in the map with a single bin.

        Returns
        -------
        map_out : WcsNDMap
            Map with non-spatial axes summed over.
        """
        # Non-spatial axes are the leading axes of the numpy data array.
        non_spatial = tuple(range(self.data.ndim - 2))

        out_geom = self.geom.to_image()
        if keepdims:
            for ax in self.geom.axes:
                out_geom = out_geom.to_cube([ax.squash()])

        summed = np.nansum(self.data, axis=non_spatial, keepdims=keepdims)
        # TODO: summing over the axis can change the unit, handle this correctly
        return self._init_copy(geom=out_geom, data=summed)
258
259
    def _reproject_to_wcs(self, geom, mode="interp", order=1):
        """Reproject the map onto the WCS geometry ``geom``, one image plane at a time.

        Parameters
        ----------
        geom : `~gammapy.maps.WcsGeom`
            Target geometry.
        mode : {'interp', 'exact'}
            Use ``reproject_interp`` or ``reproject_exact``. All-sky CAR input
            is always handled by ``reproject_car_to_wcs``.
        order : int
            Interpolation order forwarded to ``reproject_interp``
            (1 is bilinear, reproject's default).

        Returns
        -------
        map_out : `WcsNDMap`
            Reprojected map.

        Raises
        ------
        TypeError
            If ``mode`` is invalid (for non all-sky-CAR input).
        """
        from reproject import reproject_interp, reproject_exact

        data = np.empty(geom.data_shape)

        for img, idx in self.iter_by_image():
            # TODO: Create WCS object for image plane if
            # multi-resolution geom
            shape_out = geom.get_image_shape(idx)[::-1]
            input_data = (img, self.geom.wcs)

            if self.geom.projection == "CAR" and self.geom.is_allsky:
                vals, footprint = reproject_car_to_wcs(
                    input_data, geom.wcs, shape_out=shape_out
                )
            elif mode == "interp":
                # Previously ``order`` was accepted but silently ignored here.
                vals, footprint = reproject_interp(
                    input_data, geom.wcs, shape_out=shape_out, order=order
                )
            elif mode == "exact":
                vals, footprint = reproject_exact(
                    input_data, geom.wcs, shape_out=shape_out
                )
            else:
                raise TypeError(
                    "mode must be 'interp' or 'exact'. Got: {!r}".format(mode)
                )

            data[idx] = vals

        return self._init_copy(geom=geom, data=data)
289
290
    def _reproject_to_hpx(self, geom, mode="interp", order=1):
        """Reproject the map onto the HEALPix geometry ``geom``, plane by plane.

        ``mode`` is accepted for interface compatibility and is not used.
        All-sky CAR input is handled by ``reproject_car_to_hpx``, anything
        else by ``reproject.reproject_to_healpix``.
        """
        from reproject import reproject_to_healpix

        data = np.empty(geom.data_shape)
        coordsys = "galactic" if geom.coordsys == "GAL" else "icrs"

        for img, idx in self.iter_by_image():
            # TODO: For partial-sky HPX we need to map from full- to
            # partial-sky indices
            if self.geom.projection == "CAR" and self.geom.is_allsky:
                reproject_func = reproject_car_to_hpx
            else:
                reproject_func = reproject_to_healpix

            vals, _footprint = reproject_func(
                (img, self.geom.wcs),
                coordsys,
                nside=geom.nside,
                nested=geom.nest,
                order=order,
            )
            data[idx] = vals

        return self._init_copy(geom=geom, data=data)
318
319
    def pad(self, pad_width, mode="constant", cval=0, order=1):
        """Pad the spatial dimensions of the map.

        Parameters
        ----------
        pad_width : int or sequence of int
            Pixels added on each side. A scalar pads both spatial axes by the
            same amount. A sequence starts with the (lon, lat) padding; missing
            trailing entries for the non-spatial axes are filled with 0.
        mode : str
            'interp', or for regular geometries any `numpy.pad` mode.
            Irregular geometries support only 'constant' and 'interp'.
        cval : float
            Fill value for ``mode='constant'``.
        order : int
            Interpolation order for ``mode='interp'``.

        Returns
        -------
        map_out : `WcsNDMap`
            Padded map.
        """
        if np.isscalar(pad_width):
            pad_width = (pad_width, pad_width)
        pad_width = tuple(pad_width)
        # Both padding backends need one entry per map dimension;
        # non-spatial axes are never padded.
        pad_width += (0,) * (self.geom.ndim - len(pad_width))

        geom = self.geom.pad(pad_width[:2])
        if self.geom.is_regular and mode != "interp":
            return self._pad_np(geom, pad_width, mode, cval)
        return self._pad_coadd(geom, pad_width, mode, cval, order)
329
330
    def _pad_np(self, geom, pad_width, mode, cval):
        """Pad a map using ``numpy.pad``.

        This method only works for regular geometries but should be more
        efficient when working with large maps.

        Parameters
        ----------
        geom : `~gammapy.maps.WcsGeom`
            Padded geometry.
        pad_width : tuple of int
            Symmetric padding per dimension, in geometry order
            (lon, lat, non-spatial axes...).
        mode : str
            `numpy.pad` mode.
        cval : float
            Fill value used when ``mode == 'constant'``.
        """
        kwargs = {}
        if mode == "constant":
            kwargs["constant_values"] = cval

        # Same padding before and after; reversed into numpy data order.
        pad_width = [(t, t) for t in pad_width]
        # Forward kwargs, otherwise ``cval`` is ignored and np.pad uses 0.
        data = np.pad(self.data, pad_width[::-1], mode, **kwargs)
        return self._init_copy(geom=geom, data=data)
343
344
    def _pad_coadd(self, geom, pad_width, mode, cval, order):
        """Pad a map manually by coadding the original map with the new map.

        With ``mode='constant'`` the new pixels are set to ``cval``; with
        ``mode='interp'`` they are filled by interpolating the original map.

        Raises
        ------
        ValueError
            For any other ``mode``.
        """
        # Original pixel indices shifted into the padded frame (data order).
        shifted_in = tuple(
            [t + w for t, w in zip(self.geom.get_idx(flat=True), pad_width)]
        )[::-1]
        idx_out = geom.get_idx(flat=True)[::-1]

        map_out = self._init_copy(geom=geom, data=None)
        map_out.coadd(self)

        if mode == "interp":
            coords = geom.pix_to_coord(idx_out[::-1])
            outside = ~self.geom.contains(coords)
            coords = tuple([c[outside] for c in coords])
            vals = self.interp_by_coord(coords, interp=order)
            map_out.set_by_coord(coords, vals)
        elif mode == "constant":
            pad_msk = np.zeros_like(map_out.data, dtype=bool)
            pad_msk[idx_out] = True
            pad_msk[shifted_in] = False
            map_out.data[pad_msk] = cval
        else:
            raise ValueError("Invalid mode: {!r}".format(mode))

        return map_out
367
368
    def crop(self, crop_width):
        """Remove ``crop_width`` pixels from each side of the spatial axes.

        ``crop_width`` is a scalar (both axes) or a (lon, lat) pair.
        """
        if np.isscalar(crop_width):
            crop_width = (crop_width, crop_width)

        geom = self.geom.crop(crop_width)

        if not self.geom.is_regular:
            # FIXME: This could be done more efficiently by
            # constructing the appropriate slices for each image plane
            map_out = self._init_copy(geom=geom, data=None)
            map_out.coadd(self)
            return map_out

        width_x, width_y = crop_width[0], crop_width[1]
        # Data order: (non-spatial axes..., lat/y, lon/x).
        spatial = (
            slice(width_y, int(self.geom.npix[1] - width_y)),
            slice(width_x, int(self.geom.npix[0] - width_x)),
        )
        slices = (slice(None),) * len(self.geom.axes) + spatial
        return self._init_copy(geom=geom, data=self.data[slices])
388
389
    def upsample(self, factor, order=0, preserve_counts=True, axis=None):
        """Upsample the map by an integer factor.

        Parameters
        ----------
        factor : int
            Upsampling factor.
        order : int
            Spline order passed to `scipy.ndimage.map_coordinates`.
        preserve_counts : bool
            Divide values by the number of new pixels per old pixel
            (``factor**2`` for the spatial axes, ``factor`` for one axis).
        axis : str, optional
            Name of a non-spatial axis to upsample. If None, the two spatial
            axes are upsampled.

        Returns
        -------
        map_out : `WcsNDMap`
            Upsampled map.
        """
        geom = self.geom.upsample(factor, axis=axis)
        idx = geom.get_idx()

        def _to_old_pix(i):
            # Centre of a new pixel expressed in old-pixel coordinates.
            return (i - 0.5 * (factor - 1)) / factor

        if axis is None:
            pix = (_to_old_pix(idx[0]), _to_old_pix(idx[1])) + idx[2:]
        else:
            pix = list(idx)
            # NOTE(review): ``pix`` is in geometry order (lon, lat, axes...);
            # confirm that get_axis_index_by_name accounts for the two
            # spatial dimensions.
            idx_ax = self.geom.get_axis_index_by_name(axis)
            pix[idx_ax] = _to_old_pix(pix[idx_ax])

        data = scipy.ndimage.map_coordinates(
            self.data.T, tuple(pix), order=order, mode="nearest"
        )

        if preserve_counts:
            # map_coordinates keeps the input dtype; an in-place ``/=`` raises
            # for integer maps, so divide out-of-place.
            data = data / (factor ** 2 if axis is None else factor)

        return self._init_copy(geom=geom, data=data)
414
415
    def downsample(self, factor, preserve_counts=True, axis=None):
        """Downsample the map by an integer factor.

        Blocks of pixels are combined with `numpy.nansum` if
        ``preserve_counts`` is True, otherwise with `numpy.nanmean`.
        ``axis`` optionally names a single non-spatial axis to downsample.
        """
        geom = self.geom.downsample(factor, axis=axis)

        # Block sizes are built in geometry order and reversed to data order.
        if axis is not None:
            block_size = [1] * self.data.ndim
            ax_idx = self.geom.get_axis_index_by_name(axis)
            block_size[-(ax_idx + 1)] = factor
        else:
            block_size = (factor, factor) + (1,) * len(self.geom.axes)

        reducer = np.nanmean
        if preserve_counts:
            reducer = np.nansum

        data = block_reduce(self.data, tuple(block_size[::-1]), func=reducer)
        return self._init_copy(geom=geom, data=data)
428
429
    def plot(self, ax=None, fig=None, add_cbar=False, stretch="linear", **kwargs):
        """
        Plot the image on matplotlib WCS axes.

        Parameters
        ----------
        ax : `~astropy.visualization.wcsaxes.WCSAxes`, optional
            WCS axis object to plot on.
        fig : `~matplotlib.figure.Figure`
            Figure object (defaults to the current figure).
        add_cbar : bool
            Add color bar?
        stretch : str
            Passed to `astropy.visualization.simple_norm`.
        **kwargs : dict
            Keyword arguments passed to `~matplotlib.pyplot.imshow`.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            Figure object.
        ax : `~astropy.visualization.wcsaxes.WCSAxes`
            WCS axis object
        cbar : `~matplotlib.colorbar.Colorbar` or None
            Colorbar object.
        """
        import matplotlib.pyplot as plt
        from astropy.visualization import simple_norm
        from astropy.visualization.wcsaxes.frame import EllipticalFrame

        if not self.geom.is_image:
            raise TypeError("Use .plot_interactive() for Map dimension > 2")

        fig = plt.gcf() if fig is None else fig

        if ax is None:
            subplot_kwargs = {"projection": self.geom.wcs}
            if self.geom.is_allsky:
                subplot_kwargs["frame_class"] = EllipticalFrame
            ax = fig.add_subplot(1, 1, 1, **subplot_kwargs)

        data = self.data.astype(float)

        imshow_defaults = {"interpolation": "nearest", "origin": "lower", "cmap": "afmhot"}
        for key, value in imshow_defaults.items():
            kwargs.setdefault(key, value)
        # The norm is computed from finite pixels only.
        kwargs.setdefault("norm", simple_norm(data[np.isfinite(data)], stretch))

        caxes = ax.imshow(data, **kwargs)
        cbar = None
        if add_cbar:
            cbar = fig.colorbar(caxes, ax=ax, label=str(self.unit))

        if self.geom.is_allsky:
            ax = self._plot_format_allsky(ax)
        else:
            ax = self._plot_format(ax)

        # without this the axis limits are changed when calling scatter
        ax.autoscale(enable=False)
        return fig, ax, cbar
493
494
    def _plot_format(self, ax):
        """Label the coordinate axes: galactic if available, else equatorial."""
        galactic = (("glon", "Galactic Longitude"), ("glat", "Galactic Latitude"))
        equatorial = (("ra", "Right Ascension"), ("dec", "Declination"))
        try:
            for name, label in galactic:
                ax.coords[name].set_axislabel(label)
        except KeyError:
            for name, label in equatorial:
                ax.coords[name].set_axislabel(label)
        except AttributeError:
            log.info("Can't set coordinate axes. No WCS information available.")
        return ax
504
505
    def _plot_format_allsky(self, ax):
        """Format ``ax`` for an all-sky image: no frame, limits, grid and ticks."""
        # Remove frame
        ax.coords.frame.set_linewidth(0)

        # Set plot axis limits
        n_rows, n_cols = self.data.shape
        x_margin, _unused = self.geom.coord_to_pix({"lon": 180, "lat": 0})
        _unused, y_margin = self.geom.coord_to_pix({"lon": 0, "lat": -90})

        ax.set_xlim(x_margin, n_cols - x_margin)
        ax.set_ylim(y_margin, n_rows - y_margin)

        ax.text(0, n_rows, self.geom.coordsys + " coords")

        # Grid and ticks (spacing in degrees)
        lon_spacing, lat_spacing = 45, 15
        lon, lat = ax.coords
        lon.set_ticks(spacing=lon_spacing * u.deg, color="w", alpha=0.8)
        lat.set_ticks(spacing=lat_spacing * u.deg)
        lon.set_ticks_visible(False)

        lon.set_ticklabel(color="w", alpha=0.8)
        grid_style = {"alpha": 0.2, "linestyle": "solid", "color": "w"}
        lon.grid(**grid_style)
        lat.grid(**grid_style)
        return ax
530
531
    def smooth(self, width, kernel="gauss", **kwargs):
        """Smooth the map, one 2D image plane at a time.

        Parameters
        ----------
        width : `~astropy.units.Quantity`, str or float
            Smoothing width. A float is interpreted in pixels; an angular
            quantity (or string) is converted to pixels using the mean
            pixel scale of the geometry. It is the standard deviation for
            the Gaussian kernel, the radius for the disk kernel and the side
            length for the box kernel.
        kernel : {'gauss', 'disk', 'box'}
            Kernel shape.
        kwargs : dict
            Keyword arguments passed to `~scipy.ndimage.uniform_filter`
            ('box'), `~scipy.ndimage.gaussian_filter` ('gauss') or
            `~scipy.ndimage.convolve` ('disk').

        Returns
        -------
        image : `WcsNDMap`
            Smoothed copy; the original map is unchanged.
        """
        if isinstance(width, (u.Quantity, str)):
            width = (u.Quantity(width) / self.geom.pixel_scales.mean()).to_value("")

        def _smooth_plane(plane):
            if kernel == "gauss":
                return scipy.ndimage.gaussian_filter(plane, width, **kwargs)
            if kernel == "disk":
                disk = Tophat2DKernel(width)
                disk.normalize("integral")
                return scipy.ndimage.convolve(plane, disk.array, **kwargs)
            if kernel == "box":
                return scipy.ndimage.uniform_filter(plane, width, **kwargs)
            raise ValueError("Invalid kernel: {!r}".format(kernel))

        smoothed_data = np.empty(self.data.shape, dtype=float)
        for img, idx in self.iter_by_image():
            smoothed_data[idx] = _smooth_plane(img.astype(float))

        return self._init_copy(data=smoothed_data)
578
579
    def get_spectrum(self, region=None, func=np.nansum):
        """Extract a spectrum, optionally restricted to a region.

        ``func`` (by default `numpy.nansum`) reduces the spatial pixels in
        each energy bin. Without ``region`` the whole spatial extent is used.

        Parameters
        ----------
        region : `~regions.Region`
            Region (pixel or sky regions accepted).
        func : `~numpy.ufunc`
            Function to reduce the data.

        Returns
        -------
        spectrum : `CountsSpectrum`
            Spectrum in the given region.
        """
        from gammapy.spectrum import CountsSpectrum

        energy_axis = self.geom.get_axis_by_name("energy")

        if not region:
            data = func(self.data, axis=(1, 2))
        else:
            mask = self.geom.region_mask([region])
            pixels_per_bin = self.data[mask].reshape(energy_axis.nbin, -1)
            data = func(pixels_per_bin, axis=1)

        edges = energy_axis.edges
        return CountsSpectrum(
            data=data, energy_lo=edges[:-1], energy_hi=edges[1:], unit=self.unit
        )
613
614
    def convolve(self, kernel, use_fft=True, **kwargs):
        """
        Convolve map with a kernel.

        If the kernel is two-dimensional, it is applied to all image planes
        alike. If it is higher-dimensional it must match the map in the
        number of dimensions, and the corresponding kernel plane is used for
        every image plane.

        Parameters
        ----------
        kernel : `~gammapy.cube.PSFKernel` or `numpy.ndarray`
            Convolution kernel.
        use_fft : bool
            Use `scipy.signal.fftconvolve` or `scipy.ndimage.convolve`.
        kwargs : dict
            Keyword arguments passed to `scipy.signal.fftconvolve` or
            `scipy.ndimage.convolve`.

        Returns
        -------
        map : `WcsNDMap`
            Convolved map (float32 data).

        Raises
        ------
        ValueError
            If a `PSFKernel` has a pixel size incompatible with the map.
        """
        from gammapy.cube import PSFKernel

        conv_function = scipy.signal.fftconvolve if use_fft else scipy.ndimage.convolve
        convolved_data = np.empty(self.data.shape, dtype=np.float32)
        if use_fft:
            kwargs.setdefault("mode", "same")

        if isinstance(kernel, PSFKernel):
            kmap = kernel.psf_kernel_map
            if not np.allclose(
                self.geom.pixel_scales.deg, kmap.geom.pixel_scales.deg, rtol=1e-5
            ):
                raise ValueError("Pixel size of kernel and map not compatible.")
            kernel = kmap.data

        for img, idx in self.iter_by_image():
            # The kernel index and the output index must differ for 2D
            # kernels: writing to ``convolved_data[Ellipsis]`` would overwrite
            # every plane, leaving only the last plane's result.
            kernel_idx = Ellipsis if kernel.ndim == 2 else idx
            convolved_data[idx] = conv_function(img, kernel[kernel_idx], **kwargs)

        return self._init_copy(data=convolved_data)
657
658
    def cutout(self, position, width, mode="trim"):
        """
        Create a cutout around a given position.

        The spatial slices are computed on the first image plane and then
        applied to every non-spatial plane.

        Parameters
        ----------
        position : `~astropy.coordinates.SkyCoord`
            Center position of the cutout region.
        width : tuple of `~astropy.coordinates.Angle`
            Angular sizes of the region in (lon, lat) in that specific order.
            If only one value is passed, a square region is extracted.
        mode : {'trim', 'partial', 'strict'}
            Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.

        Returns
        -------
        cutout : `~gammapy.maps.WcsNDMap`
            Cutout map
        """
        width = _check_width(width)
        first_plane = (0,) * len(self.geom.axes)
        c2d = Cutout2D(
            data=self.data[first_plane],
            wcs=self.geom.wcs,
            position=position,
            # Cutout2D takes size with order (lat, lon)
            size=width[::-1] * u.deg,
            mode=mode,
        )

        row_slice, col_slice = c2d.slices_original[0], c2d.slices_original[1]
        geom = WcsGeom(c2d.wcs, c2d.shape[::-1], axes=self.geom.axes)
        data = self.data[..., row_slice, col_slice]

        return self._init_copy(geom=geom, data=data)
695
696
    def sample_coord(self, n_events, random_state=0):
        """Sample coordinates (position and non-spatial axes) of events.

        The map data array is passed as ``pdf`` to
        `~gammapy.utils.random.InverseCDFSampler`, and the sampled pixel
        positions are converted to map coordinates.

        Parameters
        ----------
        n_events : int
            Number of events to sample.
        random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
            Defines random number generator initialisation.
            Passed to `~gammapy.utils.random.get_random_state`.

        Returns
        -------
        coords : `~gammapy.maps.MapCoord` object.
            Sampled coordinates with ``lon``, ``lat`` and one entry per
            non-spatial axis. If the map has an ``energy`` axis, that entry
            is a `~astropy.units.Quantity` in the energy axis unit.
        """
        random_state = get_random_state(random_state)
        sampler = InverseCDFSampler(pdf=self.data, random_state=random_state)

        # The sampler works on the numpy data array (data order); reverse to
        # the geometry order expected by pix_to_coord.
        coords_pix = sampler.sample(n_events)
        coords = self.geom.pix_to_coord(coords_pix[::-1])

        # TODO: pix_to_coord should return a MapCoord object
        axes_names = ["lon", "lat"] + [ax.name for ax in self.geom.axes]
        cdict = OrderedDict(zip(axes_names, coords))

        # Attach the unit only if there is an energy axis, so that maps
        # without one (e.g. 2D images) can be sampled as well.
        if "energy" in cdict:
            energy_unit = self.geom.get_axis_by_name("energy").unit
            cdict["energy"] = cdict["energy"] * energy_unit

        return MapCoord.create(cdict, coordsys=self.geom.coordsys)
725