Passed
Pull Request — master (#2446)
by Axel
02:42
created

gammapy.maps.wcs.WcsGeom.cutout_info()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import copy
3
from functools import lru_cache
4
import numpy as np
5
import astropy.units as u
6
from astropy.coordinates import Angle, SkyCoord
7
from astropy.io import fits
8
from astropy.nddata import Cutout2D
9
from astropy.wcs import WCS
10
from astropy.wcs.utils import proj_plane_pixel_scales
11
from regions import SkyRegion
12
from .geom import (
13
    Geom,
14
    MapCoord,
15
    axes_pix_to_coord,
16
    find_and_read_bands,
17
    get_shape,
18
    make_axes,
19
    pix_tuple_to_idx,
20
    skycoord_to_lonlat,
21
)
22
from .utils import INVALID_INDEX
23
24
# Public API of this module: only the WCS geometry class is exported.
__all__ = ["WcsGeom"]
25
26
27
def _check_width(width):
    """Normalise a width argument to a ``(lon, lat)`` pair in degrees.

    Parameters
    ----------
    width : tuple, scalar or array-like
        Either a ``(lon, lat)`` tuple, a single value used for both axes,
        or a sequence of values. Plain numbers are interpreted as degrees.

    Returns
    -------
    width : tuple
        Width values converted to degrees.
    """
    if isinstance(width, tuple):
        lon_width, lat_width = width[0], width[1]
        return Angle(lon_width, "deg").deg, Angle(lat_width, "deg").deg

    deg = Angle(width, "deg").deg
    # A single scalar describes a square region; a sequence is unpacked as-is.
    return (deg, deg) if np.isscalar(deg) else tuple(deg)
42
43
44
def cast_to_shape(param, shape, dtype):
    """Cast a tuple of parameter arrays to a given shape.

    Parameters
    ----------
    param : scalar, array-like or tuple
        Parameter value(s). A non-tuple value is used for both entries of
        the returned pair.
    shape : tuple
        Target shape of every returned array.
    dtype : data-type
        Data type of the returned arrays.

    Returns
    -------
    param : tuple of `~numpy.ndarray`
        One array per entry, each of shape ``shape``. Single-element inputs
        are broadcast to ``shape``.

    Raises
    ------
    ValueError
        If an entry has more than one element and its shape differs from
        ``shape``.
    """
    if not isinstance(param, tuple):
        param = [param]

    arrays = [np.array(p, ndmin=1, dtype=dtype) for p in param]

    # A single value applies to both entries; copy so that the two returned
    # arrays never share memory.
    if len(arrays) == 1:
        arrays = [arrays[0].copy(), arrays[0].copy()]

    for i, arr in enumerate(arrays):
        if arr.shape == shape:
            continue
        if arr.size > 1:
            raise ValueError(
                f"Cannot cast parameter with shape {arr.shape} to shape {shape}"
            )
        # Single-element entries are broadcast to the target shape.
        arrays[i] = arr * np.ones(shape, dtype=dtype)

    return tuple(arrays)
65
66
67
def get_resampled_wcs(wcs, factor, downsampled):
    """Return a resampled copy of a WCS object.

    Parameters
    ----------
    wcs : `~astropy.wcs.WCS`
        Input WCS; it is not modified.
    factor : int or float
        Resampling factor.
    downsampled : bool
        If True, pixels grow by ``factor`` (downsampling). Otherwise they
        shrink by ``factor`` (upsampling).

    Returns
    -------
    wcs : `~astropy.wcs.WCS`
        Resampled copy of ``wcs``.
    """
    scale = factor if downsampled else 1.0 / factor

    resampled = wcs.deepcopy()
    resampled.wcs.cdelt *= scale
    # Rescale CRPIX around pixel coordinate 0.5, the lower edge of the first
    # pixel in the 1-based FITS convention, which stays fixed.
    resampled.wcs.crpix = 0.5 + (resampled.wcs.crpix - 0.5) / scale
    return resampled
79
80
81
# TODO: remove this function, move code to the one caller below
82
def _make_image_header(
    nxpix=100,
    nypix=100,
    binsz=0.1,
    xref=0,
    yref=0,
    proj="CAR",
    coordsys="GAL",
    xrefpix=None,
    yrefpix=None,
):
    """Generate a FITS header from scratch.

    Uses the same parameter names as the Fermi tool gtbin.

    If no reference pixel position is given it is assumed to be
    at the center of the image.

    Parameters
    ----------
    nxpix : int, optional
        Number of pixels in x axis. Default is 100.
    nypix : int, optional
        Number of pixels in y axis. Default is 100.
    binsz : float, optional
        Bin size for x and y axes in units of degrees. Default is 0.1.
    xref : float, optional
        Coordinate system value at reference pixel for x axis. Default is 0.
    yref : float, optional
        Coordinate system value at reference pixel for y axis. Default is 0.
    proj : string, optional
        Projection type. Default is 'CAR' (cartesian).
    coordsys : {'CEL', 'GAL'}, optional
        Coordinate system. Default is 'GAL' (Galactic).
    xrefpix : float, optional
        Reference pixel (1-based FITS convention) for the x axis.
        Default is None, meaning the image center.
    yrefpix : float, optional
        Reference pixel (1-based FITS convention) for the y axis.
        Default is None, meaning the image center.

    Returns
    -------
    header : `~astropy.io.fits.Header`
        Header

    Raises
    ------
    ValueError
        If ``coordsys`` is neither 'CEL' nor 'GAL'.
    """
    nxpix = int(nxpix)
    nypix = int(nypix)

    # Explicit ``None`` checks: a reference pixel of 0 is a legitimate value
    # and must not be silently replaced by the image center.
    if xrefpix is None:
        xrefpix = (nxpix + 1) / 2.0
    if yrefpix is None:
        yrefpix = (nypix + 1) / 2.0

    if coordsys == "CEL":
        ctype1, ctype2 = "RA---", "DEC--"
    elif coordsys == "GAL":
        ctype1, ctype2 = "GLON-", "GLAT-"
    else:
        raise ValueError(f"Unsupported coordsys: {coordsys!r}")

    pars = {
        "NAXIS": 2,
        "NAXIS1": nxpix,
        "NAXIS2": nypix,
        "CTYPE1": ctype1 + proj,
        "CRVAL1": xref,
        "CRPIX1": xrefpix,
        "CUNIT1": "deg",
        # Negative x step: longitude decreases with increasing x pixel.
        "CDELT1": -binsz,
        "CTYPE2": ctype2 + proj,
        "CRVAL2": yref,
        "CRPIX2": yrefpix,
        "CUNIT2": "deg",
        "CDELT2": binsz,
    }

    header = fits.Header()
    header.update(pars)

    return header
160
161
162
class WcsGeom(Geom):
163
    """Geometry class for WCS maps.
164
165
    This class encapsulates both the WCS transformation object and the
166
    the image extent (number of pixels in each dimension).  Provides
167
    methods for accessing the properties of the WCS object and
168
    performing transformations between pixel and world coordinates.
169
170
    Parameters
171
    ----------
172
    wcs : `~astropy.wcs.WCS`
173
        WCS projection object
174
    npix : tuple
175
        Number of pixels in each spatial dimension
176
    cdelt : tuple
177
        Pixel size in each image plane.  If none then a constant pixel size will be used.
178
    crpix : tuple
179
        Reference pixel coordinate in each image plane.
180
    axes : list
181
        Axes for non-spatial dimensions
182
    cutout_info : dict
183
        Dict with cutout info, if the `WcsGeom` was created by `WcsGeom.cutout()`
184
    """
185
186
    _slice_spatial_axes = slice(0, 2)
187
    _slice_non_spatial_axes = slice(2, None)
188
    is_hpx = False
189
190
    def __init__(self, wcs, npix, cdelt=None, crpix=None, axes=None, cutout_info=None):
        # WCS object and the metadata derived from it.
        self._wcs = wcs
        self._coordsys = get_coordys(wcs)
        self._projection = get_projection(wcs)
        self._axes = make_axes(axes)

        # Without an explicit pixel size, use the unsigned CDELT of the WCS.
        if cdelt is None:
            cdelt = tuple(np.abs(self.wcs.wcs.cdelt))

        # Common broadcast shape of the per-plane pixelization parameters.
        wcs_shape = max(get_shape(npix), get_shape(cdelt))
        self._npix = cast_to_shape(npix, wcs_shape, int)
        self._cdelt = cast_to_shape(cdelt, wcs_shape, float)

        # CRPIX is 1-based (FITS convention); default to the image center.
        if crpix is None:
            crpix = tuple((np.array(self._npix) - 1.0) / 2.0 + 1.0)

        self._crpix = crpix
        self._cutout_info = cutout_info
210
211
    @property
    def data_shape(self):
        """Shape of the Numpy data array matching this geometry.

        This is ``_shape`` reversed, i.e. non-spatial axes first and the
        spatial (lat, lon) dimensions last.
        """
        geom_shape = self._shape
        return geom_shape[::-1]
215
216
    @property
    def _shape(self):
        """Geometry shape in (lon, lat, axes...) order.

        Spatial sizes are the maximum over all image planes.
        """
        spatial = (np.max(self.npix[0]), np.max(self.npix[1]))
        non_spatial = tuple(ax.nbin for ax in self.axes)
        return spatial + non_spatial
221
222
    @property
    def _shape_edges(self):
        """Like ``_shape`` but with one extra pixel edge in each spatial dimension."""
        spatial = (np.max(self.npix[0]) + 1, np.max(self.npix[1]) + 1)
        non_spatial = tuple(ax.nbin for ax in self.axes)
        return spatial + non_spatial
227
228
    @property
    def shape_axes(self):
        """Shape of non-spatial axes (number of bins of each axis)."""
        non_spatial = self._slice_non_spatial_axes
        return self._shape[non_spatial]
232
233
    @property
    def wcs(self):
        """WCS projection object (`~astropy.wcs.WCS`) of this geometry."""
        wcs = self._wcs
        return wcs
237
238
    @property
    def coordsys(self):
        """Coordinate system of the projection.

        Either Galactic ('GAL') or Equatorial ('CEL'), as determined from
        the WCS when the geometry was created.
        """
        coordsys = self._coordsys
        return coordsys
245
246
    @property
    def cutout_info(self):
        """Cutout info dict, or None if this geometry is not a cutout."""
        info = self._cutout_info
        return info
250
251
    @property
    def projection(self):
        """Map projection derived from the WCS (e.g. 'CAR')."""
        projection = self._projection
        return projection
255
256
    @property
    def is_allsky(self):
        """Flag for all-sky maps.

        True when the longitude extent (npix * cdelt) is 360 deg in every
        image plane.
        """
        lon_extent = self._npix[0] * self._cdelt[0]
        return bool(np.all(np.isclose(lon_extent, 360.0)))
263
264
    @property
    def is_regular(self):
        """Is this geometry regular in non-spatial dimensions (bool)?

        - False for multi-resolution or irregular geometries (per-plane npix).
        - True if all image planes share the same pixel geometry.
        """
        return self.npix[0].size <= 1
275
276
    @property
    def width(self):
        """Image width in longitude and latitude as a `~astropy.units.Quantity` in deg."""
        lon_width = self._cdelt[0] * self._npix[0]
        lat_width = self._cdelt[1] * self._npix[1]
        return (lon_width, lat_width) * u.deg
282
283
    @property
    def pixel_area(self):
        """Pixel area in deg^2, computed as the product of the pixel sizes."""
        # FIXME: Correctly compute solid angle for projection
        lon_size, lat_size = self._cdelt[0], self._cdelt[1]
        return lon_size * lat_size
288
289
    @property
    def npix(self):
        """Tuple with image dimension in pixels in longitude and latitude."""
        npix = self._npix
        return npix
293
294
    @property
    def axes(self):
        """List of non-spatial axes."""
        axes = self._axes
        return axes
298
299
    @property
    def ndim(self):
        """Number of dimensions of the data array (spatial plus non-spatial)."""
        shape = self.data_shape
        return len(shape)
302
303
    @property
    def center_coord(self):
        """Map coordinate of the center of the geometry.

        Returns
        -------
        coord : tuple
            Coordinates obtained by converting ``center_pix`` with ``pix_to_coord``.
        """
        center = self.center_pix
        return self.pix_to_coord(center)
312
313
    @property
    def center_pix(self):
        """Pixel coordinate of the center of the geometry.

        Returns
        -------
        pix : tuple
            0-based center pixel in (x, y, axes...) order.
        """
        dims = np.array(self.data_shape)
        center = (dims - 1.0) / 2
        return tuple(center)[::-1]
322
323
    @property
    def center_skydir(self):
        """Sky coordinate of the center of the geometry.

        Returns
        -------
        skydir : `~astropy.coordinates.SkyCoord`
            Sky position of the central spatial pixel.
        """
        x_center = self.center_pix[0]
        y_center = self.center_pix[1]
        return SkyCoord.from_pixel(x_center, y_center, self.wcs)
332
333
    @property
    def pixel_scales(self):
        """Pixel scale.

        Angles along each image axis at the CRPIX location, projected onto
        the plane of intermediate world coordinates.

        Returns
        -------
        angle: `~astropy.coordinates.Angle`
        """
        scales = proj_plane_pixel_scales(self.wcs)
        return Angle(scales, "deg")
346
347
    @classmethod
    def create(
        cls,
        npix=None,
        binsz=0.5,
        proj="CAR",
        coordsys="CEL",
        refpix=None,
        axes=None,
        skydir=None,
        width=None,
    ):
        """Create a WCS geometry object.

        Pixelization of the map is set with
        ``binsz`` and one of either ``npix`` or ``width`` arguments.
        For maps with non-spatial dimensions a different pixelization
        can be used for each image plane by passing a list or array
        argument for any of the pixelization parameters.  If both npix
        and width are None then an all-sky geometry will be created.

        Parameters
        ----------
        npix : int or tuple or list
            Width of the map in pixels. A tuple will be interpreted as
            parameters for longitude and latitude axes.  For maps with
            non-spatial dimensions, list input can be used to define a
            different map width in each image plane.  This option
            supersedes width.
        width : float or tuple or list
            Width of the map in degrees.  A tuple will be interpreted
            as parameters for longitude and latitude axes.  For maps
            with non-spatial dimensions, list input can be used to
            define a different map width in each image plane.
        binsz : float or tuple or list
            Map pixel size in degrees.  A tuple will be interpreted
            as parameters for longitude and latitude axes.  For maps
            with non-spatial dimensions, list input can be used to
            define a different bin size in each image plane.
        skydir : tuple or `~astropy.coordinates.SkyCoord`
            Sky position of map center.  Can be either a SkyCoord
            object or a tuple of longitude and latitude in deg in the
            coordinate system of the map.
        coordsys : {'CEL', 'GAL'}, optional
            Coordinate system, either Galactic ('GAL') or Equatorial ('CEL').
        axes : list
            List of non-spatial axes.
        proj : string, optional
            Any valid WCS projection type. Default is 'CAR' (cartesian).
        refpix : tuple
            Reference pixel of the projection.  If None this will be
            set to the center of the map.

        Returns
        -------
        geom : `~WcsGeom`
            A WCS geometry object.

        Raises
        ------
        ValueError
            If ``skydir`` is neither None, a tuple nor a SkyCoord.

        Examples
        --------
        >>> from gammapy.maps import WcsGeom
        >>> from gammapy.maps import MapAxis
        >>> axis = MapAxis.from_bounds(0,1,2)
        >>> geom = WcsGeom.create(npix=(100,100), binsz=0.1)
        >>> geom = WcsGeom.create(npix=[100,200], binsz=[0.1,0.05], axes=[axis])
        >>> geom = WcsGeom.create(width=[5.0,8.0], binsz=[0.1,0.05], axes=[axis])
        >>> geom = WcsGeom.create(npix=([100,200],[100,200]), binsz=0.1, axes=[axis])
        """
        # Reference sky position in the coordinate system of the map.
        if skydir is None:
            xref, yref = 0.0, 0.0
        elif isinstance(skydir, tuple):
            xref, yref = skydir
        elif isinstance(skydir, SkyCoord):
            xref, yref, _ = skycoord_to_lonlat(skydir, coordsys=coordsys)
        else:
            raise ValueError(f"Invalid type for skydir: {type(skydir)!r}")

        if width is not None:
            width = _check_width(width)

        # Broadcast shape shared by all per-plane pixelization parameters.
        shape = max(get_shape(param) for param in (npix, binsz, width))
        binsz = cast_to_shape(binsz, shape, float)

        if npix is not None:
            npix = cast_to_shape(npix, shape, int)
        else:
            # Neither npix nor width given: build an all-sky geometry.
            if width is None:
                width = (360.0, 180.0)
            width = cast_to_shape(width, shape, float)
            npix = (
                np.rint(width[0] / binsz[0]).astype(int),
                np.rint(width[1] / binsz[1]).astype(int),
            )

        if refpix is None:
            xrefpix, yrefpix = None, None
        else:
            xrefpix, yrefpix = refpix[0], refpix[1]

        header = _make_image_header(
            nxpix=npix[0].flat[0],
            nypix=npix[1].flat[0],
            binsz=binsz[0].flat[0],
            xref=float(xref),
            yref=float(yref),
            proj=proj,
            coordsys=coordsys,
            xrefpix=xrefpix,
            yrefpix=yrefpix,
        )
        return cls(WCS(header), npix, cdelt=binsz, axes=axes)
459
460
    @classmethod
    def from_header(cls, header, hdu_bands=None):
        """Create a WCS geometry object from a FITS header.

        Parameters
        ----------
        header : `~astropy.io.fits.Header`
            The FITS header
        hdu_bands : `~astropy.io.fits.BinTableHDU`
            The BANDS table HDU.

        Returns
        -------
        wcs : `~WcsGeom`
            WCS geometry object.
        """
        import ast

        wcs = WCS(header)
        # Drop every WCS axis beyond the first two. Non-spatial axes are read
        # from the BANDS table instead.
        for _ in range(wcs.naxis - 2):
            wcs = wcs.dropaxis(2)

        axes = find_and_read_bands(hdu_bands)
        shape = tuple(ax.nbin for ax in axes)

        if hdu_bands is not None and "NPIX" in hdu_bands.columns.names:
            # Irregular geometry: per-plane npix / cdelt stored in the BANDS table.
            npix = hdu_bands.data.field("NPIX").reshape(shape + (2,))
            npix = (npix[..., 0], npix[..., 1])
            cdelt = hdu_bands.data.field("CDELT").reshape(shape + (2,))
            cdelt = (cdelt[..., 0], cdelt[..., 1])
        elif "WCSSHAPE" in header:
            # WCSSHAPE is written by ``make_header`` as a tuple literal such as
            # "(100,200,3)". Parse it with ``ast.literal_eval`` instead of
            # ``eval`` so that a crafted FITS file cannot execute code.
            wcs_shape = ast.literal_eval(header["WCSSHAPE"])
            npix = (wcs_shape[0], wcs_shape[1])
            cdelt = None
        else:
            npix = (header["NAXIS1"], header["NAXIS2"])
            cdelt = None

        return cls(wcs, npix, cdelt=cdelt, axes=axes)
498
499
    def _make_bands_cols(self, hdu=None, conv=None):
        """Build the BANDS-table columns describing per-plane pixelization.

        Only irregular geometries produce columns (NPIX, CDELT, CRPIX, each a
        pair per plane). Regular geometries get an empty list. ``hdu`` and
        ``conv`` are accepted but not used here.
        """
        if self.is_regular:
            return []

        column_specs = [
            ("NPIX", "2I", self.npix),
            ("CDELT", "2E", self._cdelt),
            ("CRPIX", "2E", self._crpix),
        ]
        # Each row holds the (x, y) pair for one image plane.
        return [
            fits.Column(
                name,
                fmt,
                dim="(2)",
                array=np.vstack((np.ravel(pair[0]), np.ravel(pair[1]))).T,
            )
            for name, fmt, pair in column_specs
        ]
532
533
    def make_header(self):
        """Create a FITS header for this geometry.

        The header holds the WCS keywords, the non-spatial axes info, and a
        ``WCSSHAPE`` keyword of the form "(nx,ny,nbin1,...)".
        """
        header = self.wcs.to_header()
        self._fill_header_from_axes(header)

        dims = [f"{np.max(self.npix[0])}", f"{np.max(self.npix[1])}"]
        dims.extend(f"{ax.nbin}" for ax in self.axes)
        header["WCSSHAPE"] = "({})".format(",".join(dims))
        return header
541
542
    def get_image_shape(self, idx):
        """Get the shape of the image plane at index ``idx``.

        Parameters
        ----------
        idx : int or tuple
            Index of the image plane. Ignored for regular geometries.

        Returns
        -------
        shape : tuple of int
            ``(nx, ny)`` number of pixels of that image plane.
        """
        if self.is_regular:
            # Regular geometries store npix as single-element arrays. Use
            # ``.item()`` because ``int()`` on an array with ndim > 0 is
            # deprecated in NumPy >= 1.25.
            nx = np.asarray(self.npix[0]).item()
            ny = np.asarray(self.npix[1]).item()
            return int(nx), int(ny)
        return int(self.npix[0][idx]), int(self.npix[1][idx])
548
549
    @lru_cache()
    def get_idx(self, idx=None, flat=False):
        """Get integer pixel indices of the geometry's pixel centers.

        With ``flat=True``, entries whose pixel coordinate is not finite
        (outside the projection footprint) are dropped.
        """
        centers = self.get_pix(idx=idx, mode="center")
        if flat:
            centers = tuple(arr[np.isfinite(arr)] for arr in centers)
        return pix_tuple_to_idx(centers)
555
556
    def _get_pix_all(self, idx=None, mode="center"):
        """Get idx coordinate array without footprint of the projection applied.

        Returns one grid per dimension in (x, y, axes...) order. The grid
        arrays are shaped like the data array (last axis = x). With
        ``mode="edges"`` the spatial grids sit on pixel edges, i.e. shifted
        by -0.5.
        """
        shape = self._shape_edges if mode == "edges" else self._shape

        if idx is None:
            grids = [np.arange(n, dtype=float) for n in shape]
        else:
            spatial_shape = shape[self._slice_spatial_axes]
            grids = [np.arange(n, dtype=float) for n in spatial_shape]
            grids += [float(t) for t in idx]

        if mode == "edges":
            for grid in grids[self._slice_spatial_axes]:
                grid -= 0.5

        # Reverse to data order for meshgrid, then back to (x, y, axes...) order.
        return np.meshgrid(*grids[::-1], indexing="ij")[::-1]
575
576
    @lru_cache()
    def get_pix(self, idx=None, mode="center"):
        """Get map pix coordinates from the geometry.

        Pixels whose longitude is not finite (outside the projection
        footprint) are set to ``INVALID_INDEX.float``.

        Parameters
        ----------
        mode : {'center', 'edges'}
            Get center or edge pix coordinates for the spatial axes.

        Returns
        -------
        coord : tuple
            Map pix coordinate tuple.
        """
        pix = self._get_pix_all(idx=idx, mode=mode)
        lon = self.pix_to_coord(pix)[0]
        outside = ~np.isfinite(lon)
        for grid in pix:
            grid[outside] = INVALID_INDEX.float
        return pix
596
597
    @lru_cache()
    def get_coord(self, idx=None, flat=False, mode="center", coordsys=None):
        """Get map coordinates from the geometry.

        Parameters
        ----------
        mode : {'center', 'edges'}
            Get center or edge coordinates for the spatial axes.
        flat : bool
            If True, drop entries whose longitude is not finite.
        coordsys : str, optional
            Output coordinate system. Defaults to the geometry's own.

        Returns
        -------
        coord : `~MapCoord`
            Map coordinate object.
        """
        pix = self._get_pix_all(idx=idx, mode=mode)
        values = self.pix_to_coord(pix)

        if flat:
            finite = np.isfinite(values[0])
            values = tuple(v[finite] for v in values)

        names = ["lon", "lat"] + [ax.name for ax in self.axes]
        target = self.coordsys if coordsys is None else coordsys
        coord = MapCoord.create(dict(zip(names, values)), coordsys=self.coordsys)
        return coord.to_coordsys(target)
625
626
    def coord_to_pix(self, coords):
        """Convert map coordinates to pixel coordinates.

        Returns a tuple of pixel coordinate arrays in (x, y, axes...) order.
        An empty input yields a tuple of empty arrays.
        """
        coords = MapCoord.create(coords, coordsys=self.coordsys)

        if coords.size == 0:
            return tuple(np.array([]) for _ in range(coords.ndim))

        c = self.coord_to_tuple(coords)

        if self.is_regular:
            pix = list(self._wcs.wcs_world2pix(coords.lon, coords.lat, 0))
        else:
            # Variable bin size: select each coordinate's plane-specific crpix/cdelt.
            plane = tuple(
                np.clip(ax.coord_to_idx(c[i + 2]), 0, ax.nbin - 1)
                for i, ax in enumerate(self.axes)
            )
            crpix = [t[plane] for t in self._crpix]
            cdelt = [t[plane] for t in self._cdelt]
            pix = list(world2pix(self.wcs, cdelt, crpix, (coords.lon, coords.lat)))

        pix.extend(
            ax.coord_to_pix(value)
            for value, ax in zip(c[self._slice_non_spatial_axes], self.axes)
        )
        return tuple(pix)
652
653
    def pix_to_coord(self, pix):
        """Convert pixel coordinates to map coordinates.

        Returns a tuple ``(lon, lat, axes...)``. lon and lat are Quantities
        in deg.
        """
        if self.is_regular:
            lonlat = self._wcs.wcs_pix2world(pix[0], pix[1], 0)
        else:
            # Variable bin size: select crpix/cdelt from each pixel's non-spatial bin.
            plane = pix_tuple_to_idx(pix[self._slice_non_spatial_axes])
            crpix = [t[plane] for t in self._crpix]
            cdelt = [t[plane] for t in self._cdelt]
            lonlat = pix2world(self.wcs, cdelt, crpix, pix[self._slice_spatial_axes])

        lon = u.Quantity(lonlat[0], unit="deg", copy=False)
        lat = u.Quantity(lonlat[1], unit="deg", copy=False)
        others = axes_pix_to_coord(self.axes, pix[self._slice_non_spatial_axes])
        return tuple([lon, lat, *others])
670
671
    def pix_to_idx(self, pix, clip=False):
        """Convert pixel coordinates to integer pixel indices.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates in (x, y, axes...) order.
        clip : bool
            If True, out-of-range indices are clipped to the nearest valid
            index ``[0, n - 1]``. Otherwise they are set to -1.

        Returns
        -------
        idxs : tuple of `~numpy.ndarray`
            Integer index arrays, one per dimension.
        """
        # TODO: copy idx to avoid modifying input pix?
        # pix_tuple_to_idx seems to always make a copy!?
        idxs = pix_tuple_to_idx(pix)
        if not self.is_regular:
            # Look up each pixel's image-plane size from its (clipped) non-spatial bin.
            ibin = pix_tuple_to_idx(pix[self._slice_non_spatial_axes])
            for i, ax in enumerate(self.axes):
                np.clip(ibin[i], 0, ax.nbin - 1, out=ibin[i])
            npix = (self.npix[0][ibin], self.npix[1][ibin])
        else:
            npix = self.npix

        for i, idx in enumerate(idxs):
            # Number of valid indices along this dimension.
            n = npix[i] if i < 2 else self.axes[i - 2].nbin
            if clip:
                # Valid indices run from 0 to n - 1 for spatial and
                # non-spatial dimensions alike (previously spatial indices
                # were clipped to n, one past the last pixel).
                np.clip(idx, 0, n - 1, out=idx)
            else:
                np.putmask(idx, (idx < 0) | (idx >= n), -1)

        return idxs
697
698
    def contains(self, coords):
        """Boolean mask of which coordinates fall inside the geometry.

        A coordinate is inside when none of its indices equals
        ``INVALID_INDEX.int``.
        """
        indices = self.coord_to_idx(coords)
        valid = [arr != INVALID_INDEX.int for arr in indices]
        return np.all(np.stack(valid), axis=0)
701
702
    @lru_cache()
    def to_image(self):
        """Return the 2D spatial geometry, without non-spatial axes.

        The largest npix / cdelt over all image planes is used. If this
        geometry is a cutout, its parent geometry is converted to an image
        as well.
        """
        cutout_info = None
        if self.cutout_info:
            cutout_info = self.cutout_info.copy()
            cutout_info["parent-geom"] = cutout_info["parent-geom"].to_image()

        npix = (np.max(self._npix[0]), np.max(self._npix[1]))
        cdelt = (np.max(self._cdelt[0]), np.max(self._cdelt[1]))
        return self.__class__(self._wcs, npix, cdelt=cdelt, cutout_info=cutout_info)
713
714
    def to_cube(self, axes):
        """Return a geometry with ``axes`` appended to the existing non-spatial axes.

        Parameters
        ----------
        axes : list
            Non-spatial axes to append.

        Returns
        -------
        geom : `WcsGeom`
            Geometry using the largest npix / cdelt over all image planes.
        """
        npix = (np.max(self._npix[0]), np.max(self._npix[1]))
        cdelt = (np.max(self._cdelt[0]), np.max(self._cdelt[1]))
        cube_axes = list(copy.deepcopy(self.axes)) + list(axes)

        if self.cutout_info:
            cutout_info = self.cutout_info.copy()
            # ``cutout`` records ``self`` as the parent without changing its
            # axes, so the parent already holds this geometry's axes. Append
            # only the newly requested ``axes``; passing ``cube_axes`` would
            # duplicate the existing axes in the parent.
            cutout_info["parent-geom"] = cutout_info["parent-geom"].to_cube(axes)
        else:
            cutout_info = None

        return self.__class__(
            self._wcs.deepcopy(), npix, cdelt=cdelt, axes=cube_axes, cutout_info=cutout_info
        )
726
727
    def pad(self, pad_width):
        """Return a geometry enlarged by ``pad_width`` pixels on each side.

        Parameters
        ----------
        pad_width : int or tuple
            Pixels added on each side. A scalar applies to both spatial
            dimensions; a tuple is ``(x, y)``.
        """
        widths = (pad_width, pad_width) if np.isscalar(pad_width) else pad_width

        padded_wcs = self._wcs.deepcopy()
        # Shift the reference pixel by the padding added before the first pixel.
        padded_wcs.wcs.crpix += np.array(widths)

        npix = (self.npix[0] + 2 * widths[0], self.npix[1] + 2 * widths[1])
        return self.__class__(
            padded_wcs, npix, cdelt=copy.deepcopy(self._cdelt), axes=copy.deepcopy(self.axes)
        )
736
737
    def crop(self, crop_width):
        """Return a geometry shrunk by ``crop_width`` pixels on each side.

        Parameters
        ----------
        crop_width : int or tuple
            Pixels removed on each side. A scalar applies to both spatial
            dimensions; a tuple is ``(x, y)``.
        """
        widths = (crop_width, crop_width) if np.isscalar(crop_width) else crop_width

        cropped_wcs = self._wcs.deepcopy()
        # Shift the reference pixel by the pixels removed before the first kept pixel.
        cropped_wcs.wcs.crpix -= np.array(widths)

        npix = (self.npix[0] - 2 * widths[0], self.npix[1] - 2 * widths[1])
        return self.__class__(
            cropped_wcs, npix, cdelt=copy.deepcopy(self._cdelt), axes=copy.deepcopy(self.axes)
        )
746
747
    def downsample(self, factor, axis=None):
        """Downsample the spatial dimensions or one non-spatial axis.

        Parameters
        ----------
        factor : int
            Downsampling factor.
        axis : str, optional
            Name of the non-spatial axis to downsample. If None, the spatial
            dimensions are downsampled.

        Returns
        -------
        geom : `WcsGeom`
            Downsampled geometry.

        Raises
        ------
        ValueError
            If the spatial shape is not divisible by ``factor``.
        NotImplementedError
            If a non-spatial axis of an irregular geometry is downsampled.
        """
        if axis is None:
            if np.any(np.mod(self.npix, factor) > 0):
                raise ValueError(
                    f"Spatial shape not divisible by factor {factor!r} in all axes."
                    " You need to pad prior to calling downsample."
                )

            npix = (self.npix[0] / factor, self.npix[1] / factor)
            cdelt = (self._cdelt[0] * factor, self._cdelt[1] * factor)
            wcs = get_resampled_wcs(self.wcs, factor, True)
            return self._init_copy(wcs=wcs, npix=npix, cdelt=cdelt)

        if not self.is_regular:
            raise NotImplementedError(
                "Downsampling in non-spatial axes not supported for irregular geometries"
            )

        axes = copy.deepcopy(self.axes)
        idx = self.get_axis_index_by_name(axis)
        axes[idx] = axes[idx].downsample(factor)
        return self._init_copy(axes=axes)
769
770
    def upsample(self, factor, axis=None):
        """Upsample the spatial dimensions or one non-spatial axis.

        Parameters
        ----------
        factor : int
            Upsampling factor.
        axis : str, optional
            Name of the non-spatial axis to upsample. If None, the spatial
            dimensions are upsampled.
        """
        if axis is not None:
            if not self.is_regular:
                raise NotImplementedError(
                    "Upsampling in non-spatial axes not supported for irregular geometries"
                )
            axes = copy.deepcopy(self.axes)
            idx = self.get_axis_index_by_name(axis)
            axes[idx] = axes[idx].upsample(factor)
            return self._init_copy(axes=axes)

        upsampled_wcs = get_resampled_wcs(self.wcs, factor, False)
        npix = (self.npix[0] * factor, self.npix[1] * factor)
        cdelt = (self._cdelt[0] / factor, self._cdelt[1] / factor)
        return self._init_copy(wcs=upsampled_wcs, npix=npix, cdelt=cdelt)
785
786
    def to_binsz(self, binsz):
        """Change pixel size of the geometry.

        The new geometry keeps the center, width, projection, coordinate
        system and (copied) non-spatial axes of this one.

        Parameters
        ----------
        binsz : float or tuple or list
            New pixel size in degree.

        Returns
        -------
        geom : `WcsGeom`
            Geometry with new pixel size.
        """
        return self.create(
            skydir=self.center_skydir,
            binsz=binsz,
            width=self.width,
            proj=self.projection,
            coordsys=self.coordsys,
            axes=copy.deepcopy(self.axes),
        )
807
808
    @lru_cache()
    def solid_angle(self):
        """Solid angle array (`~astropy.units.Quantity` in ``sr``).

        The array has the same dimension as the WcsGeom object.

        To return solid angles for the spatial dimensions only use::

            WcsGeom.to_image().solid_angle()

        Each pixel is split along the diagonal between ``low_left`` and
        ``up_right`` into two triangles. Their areas are computed with the
        planar formula ``0.5 * a * b * sin(C)`` from the sky separations and
        position angles of the pixel corners.
        """
        # Sky coordinates of the pixel corners (spatial grid on pixel edges).
        coord = self.get_coord(mode="edges").skycoord

        # define pixel corners
        # NOTE(review): the last array axis follows pixel x and the second-to-last
        # follows pixel y (see the meshgrid ordering in ``_get_pix_all``). So
        # "low_right" is offset by one pixel in y and "up_left" by one pixel in x;
        # the names do not literally describe on-sky positions.
        low_left = coord[..., :-1, :-1]
        low_right = coord[..., 1:, :-1]
        up_left = coord[..., :-1, 1:]
        up_right = coord[..., 1:, 1:]

        # compute side lengths
        low = low_left.separation(low_right)
        left = low_left.separation(up_left)
        up = up_left.separation(up_right)
        right = low_right.separation(up_right)

        # compute enclosed angles
        angle_low_right = low_right.position_angle(up_right) - low_right.position_angle(
            low_left
        )
        # NOTE(review): this subtracts ``low_left.position_angle(up_left)`` (the
        # direction pointing *into* up_left) instead of
        # ``up_left.position_angle(low_left)``. The two differ by ~180 deg, which
        # flips the sign of the sine. Presumably this compensates for the two
        # triangles being traversed in opposite orientations, so both areas come
        # out with the same sign. Confirm before "fixing".
        angle_up_left = up_left.position_angle(up_right) - low_left.position_angle(
            up_left
        )

        # compute area assuming a planar triangle
        area_low_right = 0.5 * low * right * np.sin(angle_low_right)
        area_up_left = 0.5 * up * left * np.sin(angle_up_left)

        # Products of angular separations carry squared angular units, which
        # are converted to steradian here.
        return u.Quantity(area_low_right + area_up_left, "sr", copy=False)
845
846
    @lru_cache()
    def bin_volume(self):
        """Bin volume (`~astropy.units.Quantity`).

        Product of the spatial pixel solid angle and the bin widths of every
        non-spatial axis, broadcast to the data shape.
        """
        volume = self.to_image().solid_angle()
        ndim = self.ndim

        for axis_number, ax in enumerate(self.axes):
            # Data arrays store non-spatial axes in reverse order ahead of
            # (lat, lon), so axis k lives at position -(k + 3).
            broadcast_shape = [1] * ndim
            broadcast_shape[-(axis_number + 3)] = -1
            volume = volume * ax.bin_width.reshape(tuple(broadcast_shape))

        return volume
857
858
    def separation(self, center):
        """Compute sky separation wrt a given center.

        Parameters
        ----------
        center : `~astropy.coordinates.SkyCoord`
            Center position

        Returns
        -------
        separation : `~astropy.coordinates.Angle`
            Separation angle array (2D), evaluated at the pixel centers of
            the image geometry.
        """
        image_coords = self.to_image().get_coord()
        pixel_skycoords = image_coords.skycoord
        return center.separation(pixel_skycoords)
873
874
    def cutout(self, position, width, mode="trim"):
        """
        Create a cutout geometry around a given position.

        Parameters
        ----------
        position : `~astropy.coordinates.SkyCoord`
            Center position of the cutout region.
        width : tuple of `~astropy.coordinates.Angle`
            Angular sizes of the region in (lon, lat) in that specific order.
            If only one value is passed, a square region is extracted.
        mode : {'trim', 'partial', 'strict'}
            Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.

        Returns
        -------
        cutout : `~gammapy.maps.WcsGeom`
            Cutout geometry. Its ``cutout_info`` dict holds the parent geometry
            (``"parent-geom"``) and the slices of the cutout region in the
            parent (``"parent-slices"``) and cutout (``"cutout-slices"``)
            pixel arrays.
        """
        width = _check_width(width)
        # Cutout2D needs data of the right 2D shape to compute the slices and
        # the cutout WCS; the values themselves are never used.
        dummy_data = np.empty(self.to_image().data_shape)
        c2d = Cutout2D(
            data=dummy_data,
            wcs=self.wcs,
            position=position,
            # Cutout2D takes size with order (lat, lon)
            size=width[::-1] * u.deg,
            mode=mode,
        )

        cutout_info = {
            "parent-geom": self,
            "parent-slices": c2d.slices_original,
            "cutout-slices": c2d.slices_cutout,
        }

        # Cutout2D shapes follow array order (ny, nx); npix is (nx, ny).
        return self._init_copy(wcs=c2d.wcs, npix=c2d.shape[::-1], cutout_info=cutout_info)
911
912
    def region_mask(self, regions, inside=True):
        """Create a mask from a given list of regions.

        Parameters
        ----------
        regions : list of `~regions.Region`
            Python list of regions (pixel or sky regions accepted).
        inside : bool
            For ``inside=True`` (the default), pixels in any region are set to True.
            For ``inside=False``, pixels in any region are set to False.

        Returns
        -------
        mask : `~numpy.ndarray` of boolean type
            Boolean region mask with shape ``self.data_shape``.

        Raises
        ------
        ValueError
            If the geometry is not regular (multi-resolution).

        Examples
        --------
        Make an exclusion mask for a circular region::

            from regions import CircleSkyRegion
            from astropy.coordinates import SkyCoord, Angle
            from gammapy.maps import WcsNDMap, WcsGeom

            pos = SkyCoord(0, 0, unit='deg')
            geom = WcsGeom.create(skydir=pos, npix=100, binsz=0.1)

            region = CircleSkyRegion(
                SkyCoord(3, 2, unit='deg'),
                Angle(1, 'deg'),
            )
            mask = geom.region_mask([region], inside=False)

        Note how we made a list with a single region,
        since this method expects a list of regions.

        The return ``mask`` is a boolean Numpy array.
        If you want a map object (e.g. for storing in FITS or plotting),
        this is how you can make the map::

            mask_map = WcsNDMap(geom=geom, data=mask)
            mask_map.plot()
        """
        from regions import PixCoord

        if not self.is_regular:
            raise ValueError("Multi-resolution maps not supported yet")

        # The first two index arrays are the x (lon) and y (lat) pixel positions.
        idx = self.get_idx()
        pixcoord = PixCoord(idx[0], idx[1])

        mask = np.zeros(self.data_shape, dtype=bool)

        for region in regions:
            if isinstance(region, SkyRegion):
                region = region.to_pixel(self.wcs)
            # Accumulate the union of all regions.
            mask |= region.contains(pixcoord)

        # Use truthiness rather than identity with ``False`` so that falsy
        # values such as ``0`` or ``np.False_`` also invert the mask.
        if not inside:
            np.logical_not(mask, out=mask)

        return mask
974
975
    def __repr__(self):
        # Summary: class name, then one tab-indented "label : value" row per property.
        axes = ["lon", "lat"] + [axis.name for axis in self.axes]
        skydir = self.center_skydir.data
        lon, lat = skydir.lon.deg, skydir.lat.deg
        width_lon, width_lat = self.width[0][0], self.width[1][0]

        # Labels are left-justified to 11 characters so the colons line up.
        rows = [
            ("axes", axes),
            ("shape", self.data_shape[::-1]),
            ("ndim", self.ndim),
            ("coordsys", self.coordsys),
            ("projection", self.projection),
            ("center", f"{lon:.1f} deg, {lat:.1f} deg"),
            ("width", f"{width_lon:.1f} x {width_lat:.1f}"),
        ]
        body = "".join(f"\t{label:<11}: {value}\n" for label, value in rows)
        return f"{type(self).__name__}\n\n" + body
990
991
    def __eq__(self, other):
        """Value equality: same data shape, same axes and matching WCS."""
        # Only geometries of the same class are comparable; anything else is
        # handed back to Python's default comparison machinery.
        if not isinstance(other, self.__class__):
            return NotImplemented

        # Cheap structural checks first: overall data shape, then each axis.
        if self.data_shape != other.data_shape:
            return False
        if any(axis != other_axis for axis, other_axis in zip(self.axes, other.axes)):
            return False

        # Finally compare the WCS with an a priori tolerance of 1e-6.
        return self.wcs.wcs.compare(other.wcs.wcs, tolerance=1e-6)
1005
1006
    def __ne__(self, other):
        """Negation of ``__eq__`` that propagates ``NotImplemented``.

        ``__eq__`` returns ``NotImplemented`` for objects of another class.
        Negating that sentinel directly evaluates to ``False``, which wrongly
        reports the objects as equal. It is also deprecated since Python 3.9
        and an error in newer versions. Returning ``NotImplemented`` lets
        Python try the reflected operation and fall back to an identity-based
        result instead.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result
1008
1009
    def __hash__(self):
        # Hash by object identity. ``bin_volume`` is wrapped in ``lru_cache``,
        # which requires ``self`` to be hashable.
        # NOTE(review): ``__eq__`` compares by value, so two equal geometries
        # may hash differently. The Python data model expects equal objects to
        # hash equal; confirm this identity hash is intentional (caching).
        identity = id(self)
        return identity
1011
1012
1013
def pix2world(wcs, cdelt, crpix, pix):
    """Convert pixel coordinates of a rescaled grid to world coordinates.

    The grid described by ``cdelt`` (pixel size) and ``crpix`` (reference
    pixel) shares the reference coordinate (CRVAL), projection type and
    coordinate system of ``wcs`` but may use a different pixelization. The
    input pixels are first mapped onto the pixel grid of ``wcs`` and then
    transformed with it.

    Parameters
    ----------
    wcs : `astropy.wcs.WCS`
        WCS transform object.
    cdelt : tuple
        X/Y pixel size in deg. Each element should have the same length
        as ``pix``.
    crpix : tuple
        X/Y reference pixel, treated like FITS CRPIX (1 is subtracted).
        Each element should have the same length as ``pix``.
    pix : tuple
        X/Y pixel coordinates (0-based).

    Returns
    -------
    coords
        World coordinates as returned by ``wcs.wcs_pix2world``.
    """
    ref = wcs.wcs
    # Ratio of pixel sizes; the sign of CDELT is ignored.
    x_ratio = np.abs(ref.cdelt[0] / cdelt[0])
    y_ratio = np.abs(ref.cdelt[1] / cdelt[1])

    # Move from the input grid onto the (0-based) pixel grid of ``wcs``.
    x_ref = (pix[0] - (crpix[0] - 1.0)) / x_ratio + ref.crpix[0] - 1.0
    y_ref = (pix[1] - (crpix[1] - 1.0)) / y_ratio + ref.crpix[1] - 1.0

    return wcs.wcs_pix2world(x_ref, y_ref, 0)
1043
1044
1045
def world2pix(wcs, cdelt, crpix, coord):
    """Convert world coordinates to pixel coordinates of a rescaled grid.

    This undoes the mapping performed by ``pix2world``. The coordinates are
    first transformed with ``wcs`` (0-based pixels). The result is then
    rescaled onto the grid described by ``cdelt`` (pixel size) and ``crpix``
    (reference pixel, treated like FITS CRPIX).

    Parameters
    ----------
    wcs : `astropy.wcs.WCS`
        WCS transform object.
    cdelt : tuple
        X/Y pixel size in deg of the target grid.
    crpix : tuple
        X/Y reference pixel of the target grid.
    coord : tuple
        X/Y world coordinates.

    Returns
    -------
    pix : tuple
        X/Y pixel coordinates (0-based) on the target grid.
    """
    ref = wcs.wcs
    # Ratio of pixel sizes; the sign of CDELT is ignored.
    x_ratio = np.abs(ref.cdelt[0] / cdelt[0])
    y_ratio = np.abs(ref.cdelt[1] / cdelt[1])

    native = wcs.wcs_world2pix(coord[0], coord[1], 0)
    x_native, y_native = native[0], native[1]

    x_out = (x_native - (ref.crpix[0] - 1.0)) * x_ratio + crpix[0] - 1.0
    y_out = (y_native - (ref.crpix[1] - 1.0)) * y_ratio + crpix[1] - 1.0
    return x_out, y_out
1055
1056
1057
def get_projection(wcs):
    """Return the projection code of ``wcs`` (e.g. ``"CAR"`` for ``"GLON-CAR"``).

    The code is everything after the first five characters of the first
    CTYPE entry.
    """
    lon_ctype = wcs.wcs.ctype[0]
    return lon_ctype[5:]
1059
1060
1061
def get_coordys(wcs):
    """Return the coordinate system of ``wcs`` as ``"CEL"`` or ``"GAL"``.

    The decision is based on the first CTYPE entry: containing ``"RA"`` means
    celestial, containing ``"GLON"`` means galactic (checked in that order).

    Raises
    ------
    ValueError
        If neither marker is found.
    """
    lon_ctype = wcs.wcs.ctype[0]
    if "RA" in lon_ctype:
        return "CEL"
    if "GLON" in lon_ctype:
        return "GAL"
    raise ValueError("Unrecognized WCS coordinate system.")
1068