Passed
Pull Request — master (#1898)
by
unknown
04:26
created

gammapy/maps/wcs.py (1 issue)

1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
0 ignored issues
show
Too many lines in module (1104/1000)
Loading history...
2
from __future__ import absolute_import, division, print_function, unicode_literals
3
import copy
4
import numpy as np
5
from collections import OrderedDict
6
from astropy.wcs import WCS
7
from astropy.io import fits
8
from astropy.coordinates import SkyCoord, Angle
9
from astropy.coordinates.angle_utilities import angular_separation
10
from astropy.wcs.utils import proj_plane_pixel_scales
11
import astropy.units as u
12
from regions import SkyRegion
13
from ..utils.wcs import get_resampled_wcs
14
from .geom import MapGeom, MapCoord, pix_tuple_to_idx, skycoord_to_lonlat
15
from .geom import get_shape, make_axes, find_and_read_bands
16
17
__all__ = ["WcsGeom"]
18
19
20
def _check_width(width):
    """Normalise the ``width`` argument to a ``(lon, lat)`` pair in degrees.

    A tuple is read as ``(lon_width, lat_width)`` and each entry is converted
    on its own.  Any other input is converted as a whole: a scalar result is
    used for both axes, a non-scalar result is turned into a tuple.
    """
    if not isinstance(width, tuple):
        value = Angle(width, "deg").deg
        return (value, value) if np.isscalar(value) else tuple(value)

    lon_width = Angle(width[0], "deg").deg
    lat_width = Angle(width[1], "deg").deg
    return lon_width, lat_width
35
36
37
def cast_to_shape(param, shape, dtype):
    """Cast a tuple of parameter arrays to a given shape.

    Parameters
    ----------
    param : scalar, array-like or tuple
        A single value (used for both returned entries) or a tuple of values,
        one per entry.
    shape : tuple
        Target shape of every returned array.
    dtype : data-type
        Data type of the returned arrays.

    Returns
    -------
    arrays : tuple of `~numpy.ndarray`
        Arrays of shape ``shape``.  Single-element inputs are broadcast,
        inputs that already have the target shape are passed through.

    Raises
    ------
    ValueError
        If an input has more than one element and its shape differs
        from ``shape``.
    """
    if not isinstance(param, tuple):
        param = [param]

    arrays = [np.array(p, ndmin=1, dtype=dtype) for p in param]

    # A single parameter applies to both entries; copy so they do not alias.
    if len(arrays) == 1:
        arrays = [arrays[0].copy(), arrays[0].copy()]

    for i, arr in enumerate(arrays):
        if arr.shape == shape:
            continue

        if arr.size > 1:
            raise ValueError(
                "Cannot cast parameter of shape {!r} to shape {!r}".format(
                    arr.shape, shape
                )
            )

        arrays[i] = arr * np.ones(shape, dtype=dtype)

    return tuple(arrays)
58
59
60
# TODO: remove this function, move code to the one caller below
61
def _make_image_header(
    nxpix=100,
    nypix=100,
    binsz=0.1,
    xref=0,
    yref=0,
    proj="CAR",
    coordsys="GAL",
    xrefpix=None,
    yrefpix=None,
):
    """Generate a FITS header from scratch.

    Uses the same parameter names as the Fermi tool gtbin.

    If no reference pixel position is given it is assumed to be
    at the center of the image.

    Parameters
    ----------
    nxpix : int, optional
        Number of pixels in x axis. Default is 100.
    nypix : int, optional
        Number of pixels in y axis. Default is 100.
    binsz : float, optional
        Bin size for x and y axes in units of degrees. Default is 0.1.
    xref : float, optional
        Coordinate system value at reference pixel for x axis. Default is 0.
    yref : float, optional
        Coordinate system value at reference pixel for y axis. Default is 0.
    proj : string, optional
        Projection type. Default is 'CAR' (cartesian).
    coordsys : {'CEL', 'GAL'}, optional
        Coordinate system. Default is 'GAL' (Galactic).
    xrefpix : float, optional
        Coordinate system reference pixel for x axis. Default is None,
        which selects the image center ``(nxpix + 1) / 2``.
    yrefpix : float, optional
        Coordinate system reference pixel for y axis. Default is None,
        which selects the image center ``(nypix + 1) / 2``.

    Returns
    -------
    header : `~astropy.io.fits.Header`
        Header

    Raises
    ------
    ValueError
        If ``coordsys`` is neither 'CEL' nor 'GAL'.
    """
    nxpix = int(nxpix)
    nypix = int(nypix)

    # Compare against None explicitly: a reference pixel of 0 is a valid
    # value and must not be replaced by the image center.
    if xrefpix is None:
        xrefpix = (nxpix + 1) / 2.0
    if yrefpix is None:
        yrefpix = (nypix + 1) / 2.0

    if coordsys == "CEL":
        ctype1, ctype2 = "RA---", "DEC--"
    elif coordsys == "GAL":
        ctype1, ctype2 = "GLON-", "GLAT-"
    else:
        raise ValueError("Unsupported coordsys: {!r}".format(coordsys))

    pars = {
        "NAXIS": 2,
        "NAXIS1": nxpix,
        "NAXIS2": nypix,
        "CTYPE1": ctype1 + proj,
        "CRVAL1": xref,
        "CRPIX1": xrefpix,
        "CUNIT1": "deg",
        # Negative step along the first axis, positive along the second.
        "CDELT1": -binsz,
        "CTYPE2": ctype2 + proj,
        "CRVAL2": yref,
        "CRPIX2": yrefpix,
        "CUNIT2": "deg",
        "CDELT2": binsz,
    }

    header = fits.Header()
    header.update(pars)

    return header
139
140
141
class WcsGeom(MapGeom):
142
    """Geometry class for WCS maps.
143
144
    This class encapsulates both the WCS transformation object and the
145
    image extent (number of pixels in each dimension).  Provides
146
    methods for accessing the properties of the WCS object and
147
    performing transformations between pixel and world coordinates.
148
149
    Parameters
150
    ----------
151
    wcs : `~astropy.wcs.WCS`
152
        WCS projection object
153
    npix : tuple
154
        Number of pixels in each spatial dimension
155
    cdelt : tuple
156
        Pixel size in each image plane.  If none then a constant pixel size will be used.
157
    crpix : tuple
158
        Reference pixel coordinate in each image plane.
159
    axes : list
160
        Axes for non-spatial dimensions
161
    conv : {'gadf', 'fgst-ccube', 'fgst-template'}
162
        Serialization format convention.  This sets the default format
163
        that will be used when writing this geometry to a file.
164
    """
165
166
    _slice_spatial_axes = slice(0, 2)
167
    _slice_non_spatial_axes = slice(2, -1)
168
    is_hpx = False
169
170
    def __init__(self, wcs, npix, cdelt=None, crpix=None, axes=None, conv="gadf"):
        """Store the WCS object and the per-plane pixelization.

        See the class docstring for the meaning of the parameters.

        Raises
        ------
        ValueError
            If ``npix`` / ``cdelt`` carry more than one element and their
            shape differs from the shape of the non-spatial axes.
        """
        self._wcs = wcs
        self._coordsys = get_coordys(wcs)
        self._projection = get_projection(wcs)
        self._conv = conv
        self._axes = make_axes(axes, conv)

        # Fall back to the (absolute) pixel size stored in the WCS object.
        if cdelt is None:
            cdelt = tuple(np.abs(self.wcs.wcs.cdelt))

        # Shape to use for WCS transformations
        wcs_shape = max([get_shape(t) for t in [npix, cdelt]])
        if np.sum(wcs_shape) > 1 and wcs_shape != self.shape:
            raise ValueError(
                "Shape of npix / cdelt {!r} does not match the shape of the "
                "non-spatial axes {!r}".format(wcs_shape, self.shape)
            )

        self._npix = cast_to_shape(npix, wcs_shape, int)
        self._cdelt = cast_to_shape(cdelt, wcs_shape, float)

        # By convention CRPIX is indexed from 1
        if crpix is None:
            crpix = tuple(1.0 + (np.array(self._npix) - 1.0) / 2.0)

        self._crpix = crpix
193
194
    @property
    def data_shape(self):
        """Shape of the Numpy data array matching this geometry.

        The order is reversed with respect to the geometry axes: non-spatial
        axes come first, followed by latitude and longitude.  For each spatial
        axis the maximum pixel count over all image planes is used.
        """
        n_lon = np.max(self.npix[0])
        n_lat = np.max(self.npix[1])
        dims = [n_lon, n_lat] + [ax.nbin for ax in self.axes]
        return tuple(reversed(dims))
200
201
    @property
    def wcs(self):
        """WCS projection object passed to the constructor."""
        return self._wcs
205
206
    @property
    def coordsys(self):
        """Coordinate system of the projection.

        Either Galactic ('GAL') or Equatorial ('CEL').
        """
        return self._coordsys
211
212
    @property
    def projection(self):
        """Map projection, as derived from the WCS object at construction."""
        return self._projection
216
217
    @property
    def is_allsky(self):
        """Flag for all-sky maps.

        True when the longitude extent (``npix * cdelt``) is close to
        360 deg in every image plane.
        """
        lon_extent = self._npix[0] * self._cdelt[0]
        return bool(np.all(np.isclose(lon_extent, 360.0)))
224
225
    @property
    def is_regular(self):
        """Flag identifying whether this geometry is regular in non-spatial
        dimensions.

        False for multi-resolution or irregular geometries.  If true all
        image planes have the same pixel geometry.
        """
        # More than one stored pixel count means the planes differ.
        return self.npix[0].size <= 1
236
237
    @property
    def width(self):
        """Tuple with image dimension in deg in longitude and latitude.

        Each entry is the product of pixel size and pixel count.
        """
        lon_width = self._cdelt[0] * self._npix[0]
        lat_width = self._cdelt[1] * self._npix[1]
        return (lon_width, lat_width)
241
242
    @property
    def pixel_area(self):
        """Pixel area in deg^2, computed as the product of the pixel sizes."""
        # FIXME: Correctly compute solid angle for projection
        size_lon, size_lat = self._cdelt[0], self._cdelt[1]
        return size_lon * size_lat
247
248
    @property
    def npix(self):
        """Tuple with image dimension in pixels in longitude and latitude.

        Each entry is an integer array as produced by ``cast_to_shape``.
        """
        return self._npix
252
253
    @property
    def conv(self):
        """Name of the default FITS convention associated with this geometry."""
        return self._conv
257
258
    @property
    def axes(self):
        """List of non-spatial axes, as built by ``make_axes``."""
        return self._axes
262
263
    @property
    def shape(self):
        """Shape of non-spatial axes (number of bins of each axis)."""
        return tuple(ax.nbin for ax in self._axes)
267
268
    @property
    def ndim(self):
        """Number of dimensions, i.e. the length of ``data_shape``."""
        n_dimensions = len(self.data_shape)
        return n_dimensions
271
272
    @property
    def center_coord(self):
        """Map coordinate of the center of the geometry.

        Obtained by converting ``center_pix`` with ``pix_to_coord``.

        Returns
        -------
        coord : tuple
        """
        center = self.center_pix
        return self.pix_to_coord(center)
281
282
    @property
    def center_pix(self):
        """Pixel coordinate of the center of the geometry.

        Computed as ``(n - 1) / 2`` for every entry of ``data_shape`` and
        returned in geometry axis order (longitude first).

        Returns
        -------
        pix : tuple
        """
        half = (np.array(self.data_shape) - 1.0) / 2
        return tuple(half[::-1])
291
292
    @property
    def center_skydir(self):
        """Sky coordinate of the center of the geometry.

        Returns
        -------
        pix : `~astropy.coordinates.SkyCoord`
        """
        xp = self.center_pix[0]
        yp = self.center_pix[1]
        return SkyCoord.from_pixel(xp, yp, self.wcs)
301
302
    @property
    def pixel_scales(self):
        """
        Pixel scale.

        Returns angles along each axis of the image at the CRPIX location once
        it is projected onto the plane of intermediate world coordinates.

        Returns
        -------
        angle: `~astropy.coordinates.Angle`
        """
        scales = proj_plane_pixel_scales(self.wcs)
        return Angle(scales, "deg")
315
316
    @classmethod
    def create(
        cls,
        npix=None,
        binsz=0.5,
        proj="CAR",
        coordsys="CEL",
        refpix=None,
        axes=None,
        skydir=None,
        width=None,
        conv="gadf",
    ):
        """Create a WCS geometry object.

        Pixelization of the map is set with
        ``binsz`` and one of either ``npix`` or ``width`` arguments.
        For maps with non-spatial dimensions a different pixelization
        can be used for each image plane by passing a list or array
        argument for any of the pixelization parameters.  If both npix
        and width are None then an all-sky geometry will be created.

        Parameters
        ----------
        npix : int or tuple or list
            Width of the map in pixels. A tuple will be interpreted as
            parameters for longitude and latitude axes.  For maps with
            non-spatial dimensions, list input can be used to define a
            different map width in each image plane.  This option
            supersedes width.
        width : float or tuple or list
            Width of the map in degrees.  A tuple will be interpreted
            as parameters for longitude and latitude axes.  For maps
            with non-spatial dimensions, list input can be used to
            define a different map width in each image plane.
        binsz : float or tuple or list
            Map pixel size in degrees.  A tuple will be interpreted
            as parameters for longitude and latitude axes.  For maps
            with non-spatial dimensions, list input can be used to
            define a different bin size in each image plane.
        skydir : tuple or `~astropy.coordinates.SkyCoord`
            Sky position of map center.  Can be either a SkyCoord
            object or a tuple of longitude and latitude in deg in the
            coordinate system of the map.
        coordsys : {'CEL', 'GAL'}, optional
            Coordinate system, either Galactic ('GAL') or Equatorial ('CEL').
        axes : list
            List of non-spatial axes.
        proj : string, optional
            Any valid WCS projection type. Default is 'CAR' (cartesian).
        refpix : tuple
            Reference pixel of the projection.  If None this will be
            set to the center of the map.
        conv : string, optional
            FITS format convention ('fgst-ccube', 'fgst-template',
            'gadf').  Default is 'gadf'.

        Returns
        -------
        geom : `~WcsGeom`
            A WCS geometry object.

        Examples
        --------
        >>> from gammapy.maps import WcsGeom
        >>> from gammapy.maps import MapAxis
        >>> axis = MapAxis.from_bounds(0,1,2)
        >>> geom = WcsGeom.create(npix=(100,100), binsz=0.1)
        >>> geom = WcsGeom.create(npix=[100,200], binsz=[0.1,0.05], axes=[axis])
        >>> geom = WcsGeom.create(width=[5.0,8.0], binsz=[0.1,0.05], axes=[axis])
        >>> geom = WcsGeom.create(npix=([100,200],[100,200]), binsz=0.1, axes=[axis])
        """
        # Resolve the reference sky position in the map coordinate system.
        if skydir is None:
            xref, yref = 0.0, 0.0
        elif isinstance(skydir, tuple):
            xref, yref = skydir
        elif isinstance(skydir, SkyCoord):
            xref, yref, _ = skycoord_to_lonlat(skydir, coordsys=coordsys)
        else:
            raise ValueError("Invalid type for skydir: {!r}".format(type(skydir)))

        if width is not None:
            width = _check_width(width)

        # The common shape is derived before the all-sky default is applied.
        shape = max([get_shape(t) for t in [npix, binsz, width]])
        binsz = cast_to_shape(binsz, shape, float)

        if npix is not None:
            npix = cast_to_shape(npix, shape, int)
        else:
            # If both npix and width are None then create an all-sky geometry
            if width is None:
                width = (360.0, 180.0)

            width = cast_to_shape(width, shape, float)
            npix = (
                np.rint(width[0] / binsz[0]).astype(int),
                np.rint(width[1] / binsz[1]).astype(int),
            )

        if refpix is None:
            refpix = (None, None)

        # The header is built from the first image plane only; the per-plane
        # pixelization is carried by ``npix`` and ``cdelt`` below.
        header = _make_image_header(
            nxpix=npix[0].flat[0],
            nypix=npix[1].flat[0],
            binsz=binsz[0].flat[0],
            xref=float(xref),
            yref=float(yref),
            proj=proj,
            coordsys=coordsys,
            xrefpix=refpix[0],
            yrefpix=refpix[1],
        )
        return cls(WCS(header), npix, cdelt=binsz, axes=axes, conv=conv)
432
433
    @classmethod
    def from_header(cls, header, hdu_bands=None):
        """Create a WCS geometry object from a FITS header.

        Parameters
        ----------
        header : `~astropy.io.fits.Header`
            The FITS header
        hdu_bands : `~astropy.io.fits.BinTableHDU`
            The BANDS table HDU.

        Returns
        -------
        wcs : `~WcsGeom`
            WCS geometry object.
        """
        # Local import: only needed to parse the WCSSHAPE header keyword.
        import ast

        wcs = WCS(header)
        naxis = wcs.naxis
        # Keep only the first two axes of the WCS.
        for i in range(naxis - 2):
            wcs = wcs.dropaxis(2)

        axes = find_and_read_bands(hdu_bands, header)
        shape = tuple([ax.nbin for ax in axes])
        conv = "gadf"

        # Discover FITS convention
        if hdu_bands is not None:
            if hdu_bands.name == "EBOUNDS":
                conv = "fgst-ccube"
            elif hdu_bands.name == "ENERGIES":
                conv = "fgst-template"

        if hdu_bands is not None and "NPIX" in hdu_bands.columns.names:
            # Per-plane pixelization stored in the BANDS table.
            npix = hdu_bands.data.field("NPIX").reshape(shape + (2,))
            npix = (npix[..., 0], npix[..., 1])
            cdelt = hdu_bands.data.field("CDELT").reshape(shape + (2,))
            cdelt = (cdelt[..., 0], cdelt[..., 1])
        elif "WCSSHAPE" in header:
            # The header value comes from a file, so parse it as a literal
            # instead of evaluating it as arbitrary Python code.
            wcs_shape = ast.literal_eval(header["WCSSHAPE"])
            npix = (wcs_shape[0], wcs_shape[1])
            cdelt = None
        else:
            npix = (header["NAXIS1"], header["NAXIS2"])
            cdelt = None

        return cls(wcs, npix, cdelt=cdelt, axes=axes, conv=conv)
479
480
    def _make_bands_cols(self, hdu=None, conv=None):
        """Build the FITS columns describing the per-plane pixelization.

        Returns an empty list for regular geometries.  Otherwise the list
        holds the NPIX, CDELT and CRPIX columns, each with one
        (lon, lat) pair per image plane.  ``hdu`` and ``conv`` are not used.
        """
        if self.is_regular:
            return []

        def pairs(lon_lat):
            # One row per image plane, two values (lon, lat) per row.
            return np.vstack((np.ravel(lon_lat[0]), np.ravel(lon_lat[1]))).T

        return [
            fits.Column("NPIX", "2I", dim="(2)", array=pairs(self.npix)),
            fits.Column("CDELT", "2E", dim="(2)", array=pairs(self._cdelt)),
            fits.Column("CRPIX", "2E", dim="(2)", array=pairs(self._crpix)),
        ]
513
514
    def make_header(self):
        """Build a FITS header from the WCS object and the non-spatial axes.

        The keyword ``WCSSHAPE`` stores the maximum pixel count of both
        spatial axes followed by the number of bins of every non-spatial
        axis, formatted as ``"(n1,n2,...)"``.
        """
        header = self.wcs.to_header()
        self._fill_header_from_axes(header)

        dims = [np.max(self.npix[0]), np.max(self.npix[1])]
        dims += [ax.nbin for ax in self.axes]
        joined = ",".join("{}".format(dim) for dim in dims)
        header["WCSSHAPE"] = "({})".format(joined)
        return header
522
523
    def get_image_shape(self, idx):
        """Get the shape of the image plane at index ``idx``.

        For regular geometries ``idx`` is ignored.
        """
        n_lon, n_lat = self.npix[0], self.npix[1]
        if not self.is_regular:
            n_lon, n_lat = n_lon[idx], n_lat[idx]
        return int(n_lon), int(n_lat)
530
531
    def get_idx(self, idx=None, flat=False):
        """Get pixel indices from the geometry.

        The pixel center coordinates from ``get_pix`` are converted with
        ``pix_tuple_to_idx``.  With ``flat=True`` non-finite entries are
        dropped first, which also flattens the arrays.
        """
        centers = self.get_pix(idx=idx, mode="center")
        if flat:
            centers = tuple(arr[np.isfinite(arr)] for arr in centers)
        return pix_tuple_to_idx(centers)
536
537
    def get_pix(self, idx=None, mode="center"):
        """Get map pix coordinates from the geometry.

        Parameters
        ----------
        idx : tuple, optional
            Index of a single image plane in the non-spatial axes.
            If None, coordinates for all image planes are returned.
        mode : {'center', 'edges'}
            Get center or edge pix coordinates for the spatial axes.

        Returns
        -------
        coord : tuple
            Map pix coordinate tuple.  Entries that fall outside an image
            plane or whose world coordinate is not finite are set to NaN.
        """
        # FIXME: Figure out if there is some way to employ open/sparse
        # vectors

        # FIXME: It would be more efficient to split this into one
        # method that computes indices and a method that casts
        # those to floats and adds the appropriate offset

        npix = copy.deepcopy(self.npix)

        if mode == "edges":
            # There is one more edge than pixels along each spatial axis;
            # the in-place add only touches the copied arrays.
            for pix_num in npix[:2]:
                pix_num += 1

        if self.axes and not self.is_regular:
            # Size the buffer from the (possibly incremented) local ``npix``
            # so that the largest image plane also fits in "edges" mode.
            shape = (np.max(npix[0]), np.max(npix[1]))

            if idx is None:
                shape = shape + self.shape
            else:
                shape = shape + (1,) * len(self.axes)

            pix2 = [
                np.full(shape, np.nan, dtype=float) for _ in range(2 + len(self.axes))
            ]
            for idx_img in np.ndindex(self.shape):

                if idx is not None and idx_img != idx:
                    continue

                npix0, npix1 = npix[0][idx_img], npix[1][idx_img]
                pix_img = np.meshgrid(
                    np.arange(npix0), np.arange(npix1), indexing="ij", sparse=False
                )

                # Only the region covered by this plane is filled,
                # the remainder of the buffer stays NaN.
                if idx is None:
                    s_img = (slice(0, npix0), slice(0, npix1)) + idx_img
                else:
                    s_img = (slice(0, npix0), slice(0, npix1)) + (0,) * len(self.axes)

                pix2[0][s_img] = pix_img[0]
                pix2[1][s_img] = pix_img[1]
                for j in range(len(self.axes)):
                    pix2[j + 2][s_img] = idx_img[j]
            pix = [t.T for t in pix2]
        else:
            pix = [np.arange(npix[0], dtype=float), np.arange(npix[1], dtype=float)]

            if idx is None:
                pix += [np.arange(ax.nbin, dtype=float) for ax in self.axes]
            else:
                pix += [float(t) for t in idx]

            pix = np.meshgrid(*pix[::-1], indexing="ij", sparse=False)[::-1]

        if mode == "edges":
            # Shift from pixel centers to the lower pixel edges.
            for pix_array in pix[self._slice_spatial_axes]:
                pix_array -= 0.5

        # Blank out pixels that do not map to a finite world coordinate.
        coords = self.pix_to_coord(pix)
        invalid = ~np.isfinite(coords[0])
        for pix_array in pix:
            pix_array[invalid] = np.nan
        return pix
614
615
    def get_coord(self, idx=None, flat=False, mode="center"):
        """Get map coordinates from the geometry.

        Parameters
        ----------
        idx : tuple, optional
            Passed on to ``get_pix``.
        flat : bool
            If True, non-finite entries are dropped from every coordinate
            array, which also flattens the arrays.
        mode : {'center', 'edges'}
            Get center or edge coordinates for the spatial axes.

        Returns
        -------
        coord : `~MapCoord`
            Map coordinate object.
        """
        world = self.pix_to_coord(self.get_pix(idx=idx, mode=mode))

        if flat:
            world = tuple(arr[np.isfinite(arr)] for arr in world)

        names = ["lon", "lat"]
        names.extend(ax.name for ax in self.axes)

        return MapCoord.create(OrderedDict(zip(names, world)), coordsys=self.coordsys)
638
639
    def coord_to_pix(self, coords):
        """Convert map coordinates to pixel coordinates.

        ``coords`` is normalised with ``MapCoord.create``.  The result is a
        tuple with the two spatial pixel coordinates followed by one pixel
        coordinate per non-spatial axis.  Empty input gives a tuple of empty
        arrays.
        """
        coords = MapCoord.create(coords, coordsys=self.coordsys)
        if coords.size == 0:
            return tuple(np.array([]) for _ in range(coords.ndim))

        c = self.coord_to_tuple(coords)

        if self.is_regular:
            pix = self._wcs.wcs_world2pix(coords.lon, coords.lat, 0)
            pix += [ax.coord_to_pix(c[i + 2]) for i, ax in enumerate(self.axes)]
            return tuple(pix)

        # Variable Bin Size: pick the pixelization of the image plane
        # each coordinate falls into (clipped to the valid bin range).
        bins = [ax.coord_to_pix(c[i + 2]) for i, ax in enumerate(self.axes)]
        idxs = tuple(
            np.clip(ax.coord_to_idx(c[i + 2]), 0, ax.nbin - 1)
            for i, ax in enumerate(self.axes)
        )
        crpix = [t[idxs] for t in self._crpix]
        cdelt = [t[idxs] for t in self._cdelt]
        spatial = world2pix(self.wcs, cdelt, crpix, (coords.lon, coords.lat))
        return tuple(list(spatial) + bins)
664
665
    def pix_to_coord(self, pix):
        """Convert pixel coordinates to map coordinates.

        ``pix`` holds the two spatial pixel coordinates followed by one
        entry per non-spatial axis.  The result is a tuple in the same order.
        """
        n_axes = len(self.axes)

        if self.is_regular:
            coords = self._wcs.wcs_pix2world(pix[0], pix[1], 0)
            coords += [ax.pix_to_coord(pix[i + 2]) for i, ax in enumerate(self.axes)]
            return tuple(coords)

        # Variable Bin Size: use the pixelization of the image plane
        # selected by the non-spatial pixel coordinates.
        idxs = pix_tuple_to_idx([pix[2 + i] for i in range(n_axes)])
        vals = [ax.pix_to_coord(pix[2 + i]) for i, ax in enumerate(self.axes)]
        crpix = [t[idxs] for t in self._crpix]
        cdelt = [t[idxs] for t in self._cdelt]
        coords = pix2world(self.wcs, cdelt, crpix, pix[:2])
        coords += vals
        return tuple(coords)
680
681
    def pix_to_idx(self, pix, clip=False):
        """Convert pixel coordinates to pixel indices.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates: two spatial entries followed by one entry
            per non-spatial axis.
        clip : bool
            If True, indices are clipped to the valid range
            ``[0, n - 1]`` of each axis.  If False, indices outside the
            valid range are set to -1.

        Returns
        -------
        idxs : tuple
            Pixel indices, as returned by ``pix_tuple_to_idx``.
        """
        # TODO: copy idx to avoid modifying input pix?
        # pix_tuple_to_idx seems to always make a copy!?
        idxs = pix_tuple_to_idx(pix)
        if not self.is_regular:
            # Select the pixel counts of the image plane each entry falls in.
            ibin = [pix[2 + i] for i, ax in enumerate(self.axes)]
            ibin = pix_tuple_to_idx(ibin)
            for i, ax in enumerate(self.axes):
                np.clip(ibin[i], 0, ax.nbin - 1, out=ibin[i])
            npix = (self.npix[0][ibin], self.npix[1][ibin])
        else:
            npix = self.npix

        for i, idx in enumerate(idxs):
            # Number of valid indices along this axis.
            nmax = npix[i] if i < 2 else self.axes[i - 2].nbin

            if clip:
                # The last valid index is ``nmax - 1`` for spatial and
                # non-spatial axes alike.
                np.clip(idxs[i], 0, nmax - 1, out=idxs[i])
            else:
                np.putmask(idxs[i], (idx < 0) | (idx >= nmax), -1)

        return idxs
707
708
    def contains(self, coords):
        """Check which coordinates fall inside the geometry.

        A coordinate is inside when none of its pixel indices from
        ``coord_to_idx`` equals -1.
        """
        valid = [axis_idx != -1 for axis_idx in self.coord_to_idx(coords)]
        return np.all(np.stack(valid), axis=0)
711
712
    def to_image(self):
        """Create a geometry without non-spatial axes.

        Uses the maximum pixel count and maximum pixel size over all image
        planes and shares the WCS object with this geometry.
        """
        max_npix = (np.max(self._npix[0]), np.max(self._npix[1]))
        max_cdelt = (np.max(self._cdelt[0]), np.max(self._cdelt[1]))
        return type(self)(self._wcs, max_npix, cdelt=max_cdelt)
716
717
    def to_cube(self, axes):
        """Create a geometry with ``axes`` appended to the non-spatial axes.

        Uses the maximum pixel count and maximum pixel size over all image
        planes; the WCS object and the existing axes are deep-copied.
        """
        max_npix = (np.max(self._npix[0]), np.max(self._npix[1]))
        max_cdelt = (np.max(self._cdelt[0]), np.max(self._cdelt[1]))
        all_axes = copy.deepcopy(self.axes) + axes
        return type(self)(
            self._wcs.deepcopy(), max_npix, cdelt=max_cdelt, axes=all_axes
        )
722
723
    def pad(self, pad_width):
        """Create a geometry enlarged by ``pad_width`` pixels on every side.

        ``pad_width`` is a scalar (used for both spatial axes) or a
        (lon, lat) pair.  The reference pixel of the copied WCS is shifted
        by the pad width.
        """
        if np.isscalar(pad_width):
            pad_width = (pad_width, pad_width)

        padded_wcs = self._wcs.deepcopy()
        padded_wcs.wcs.crpix += np.array(pad_width)

        padded_npix = (
            self.npix[0] + 2 * pad_width[0],
            self.npix[1] + 2 * pad_width[1],
        )
        return type(self)(
            padded_wcs,
            padded_npix,
            cdelt=copy.deepcopy(self._cdelt),
            axes=copy.deepcopy(self.axes),
        )
732
733
    def crop(self, crop_width):
        """Create a geometry reduced by ``crop_width`` pixels on every side.

        ``crop_width`` is a scalar (used for both spatial axes) or a
        (lon, lat) pair.  The reference pixel of the copied WCS is shifted
        by the crop width.
        """
        if np.isscalar(crop_width):
            crop_width = (crop_width, crop_width)

        cropped_wcs = self._wcs.deepcopy()
        cropped_wcs.wcs.crpix -= np.array(crop_width)

        cropped_npix = (
            self.npix[0] - 2 * crop_width[0],
            self.npix[1] - 2 * crop_width[1],
        )
        return type(self)(
            cropped_wcs,
            cropped_npix,
            cdelt=copy.deepcopy(self._cdelt),
            axes=copy.deepcopy(self.axes),
        )
742
743
    def downsample(self, factor):
        """Create a geometry with ``factor`` times fewer, larger pixels.

        Parameters
        ----------
        factor : int
            Downsampling factor.  The pixel count of both spatial axes
            must be divisible by it.

        Raises
        ------
        ValueError
            If the pixel count of a spatial axis is not divisible
            by ``factor``.
        """
        divisible = all(
            np.all(np.mod(self.npix[i], factor) == 0) for i in (0, 1)
        )
        if not divisible:
            raise ValueError(
                "Data shape not divisible by factor {!r} in all axes."
                " You need to pad prior to calling downsample.".format(factor)
            )

        # Floor division keeps the pixel counts integral; it is exact here
        # because divisibility has been checked above.
        npix = (self.npix[0] // factor, self.npix[1] // factor)
        cdelt = (self._cdelt[0] * factor, self._cdelt[1] * factor)
        wcs = get_resampled_wcs(self.wcs, factor, True)
        return self.__class__(wcs, npix, cdelt=cdelt, axes=copy.deepcopy(self.axes))
757
758
    def upsample(self, factor):
        """Create a geometry with ``factor`` times more, smaller pixels.

        Pixel counts are multiplied and pixel sizes divided by ``factor``;
        the WCS is resampled with ``get_resampled_wcs``.
        """
        finer_npix = tuple(self.npix[i] * factor for i in (0, 1))
        finer_cdelt = tuple(self._cdelt[i] / factor for i in (0, 1))
        finer_wcs = get_resampled_wcs(self.wcs, factor, False)
        return type(self)(
            finer_wcs, finer_npix, cdelt=finer_cdelt, axes=copy.deepcopy(self.axes)
        )
763
764
    def solid_angle(self):
        """Solid angle array (`~astropy.units.Quantity` in ``sr``).

        The array has the same dimension as the WcsGeom object.

        To return solid angles for the spatial dimensions only use::

            WcsGeom.to_image().solid_angle()
        """
        edges = self.get_coord(mode="edges")
        lon = edges.lon * np.pi / 180.0
        lat = edges.lat * np.pi / 180.0

        # Compute solid angle using the approximation that it's
        # the product between angular separation of pixel corners.
        # First index is "y", second index is "x"
        lower = slice(None, -1)
        upper = slice(1, None)
        lon_ref, lat_ref = lon[..., lower, lower], lat[..., lower, lower]

        dx = angular_separation(
            lon_ref, lat_ref, lon[..., lower, upper], lat[..., lower, upper]
        )
        dy = angular_separation(
            lon_ref, lat_ref, lon[..., upper, lower], lat[..., upper, lower]
        )

        return u.Quantity(dx * dy, "sr", copy=False)
788
789
    def separation(self, center):
        """Compute sky separation wrt a given center.

        The separation is evaluated at the pixel centers of the image
        geometry returned by ``to_image``.

        Parameters
        ----------
        center : `~astropy.coordinates.SkyCoord`
            Center position

        Returns
        -------
        separation : `~astropy.coordinates.Angle`
            Separation angle array (2D)
        """
        image_geom = self.to_image()
        pixel_centers = image_geom.get_coord()
        return center.separation(pixel_centers.skycoord)
804
805
    def region_mask(self, regions, inside=True):
        """Create a mask from a given list of regions

        Parameters
        ----------
        regions : list of  `~regions.Region`
            Python list of regions (pixel or sky regions accepted)
        inside : bool
            For ``inside=True``, pixels in the region are True (the default).
            For ``inside=False``, pixels in the region are False.

        Returns
        -------
        mask_map : `~numpy.ndarray` of boolean type
            Boolean region mask

        Raises
        ------
        ValueError
            If the geometry is not regular.

        Examples
        --------
        Make an exclusion mask for a circular region::

            from regions import CircleSkyRegion
            from astropy.coordinates import SkyCoord, Angle
            from gammapy.maps import WcsNDMap, WcsGeom

            pos = SkyCoord(0, 0, unit='deg')
            geom = WcsGeom.create(skydir=pos, npix=100, binsz=0.1)

            region = CircleSkyRegion(
                SkyCoord(3, 2, unit='deg'),
                Angle(1, 'deg'),
            )
            mask = geom.region_mask([region], inside=False)

        Note how we made a list with a single region,
        since this method expects a list of regions.

        The return ``mask`` is a boolean Numpy array.
        If you want a map object (e.g. for storing in FITS or plotting),
        this is how you can make the map::

            mask_map = WcsNDMap(geom=geom, data=mask)
            mask_map.plot()
        """
        from regions import PixCoord

        if not self.is_regular:
            raise ValueError("Multi-resolution maps not supported yet")

        pixel_idx = self.get_idx()
        pixcoord = PixCoord(pixel_idx[0], pixel_idx[1])

        mask = np.zeros(self.data_shape, dtype=bool)

        for region in regions:
            # Sky regions are projected onto the pixel grid first.
            pix_region = (
                region.to_pixel(self.wcs) if isinstance(region, SkyRegion) else region
            )
            # In-place add on a boolean array accumulates the union.
            mask += pix_region.contains(pixcoord)

        if inside is False:
            np.logical_not(mask, out=mask)

        return mask
867
868
    def __repr__(self):
        """Multi-line summary: axes, shape, ndim, coordsys, projection,
        center and width (of the first image plane)."""
        axis_names = ["lon", "lat"] + [ax.name for ax in self.axes]
        parts = [
            self.__class__.__name__,
            "\n\n",
            "\taxes       : {}\n".format(", ".join(axis_names)),
            "\tshape      : {}\n".format(self.data_shape[::-1]),
            "\tndim       : {}\n".format(self.ndim),
            "\tcoordsys   : {}\n".format(self.coordsys),
            "\tprojection : {}\n".format(self.projection),
        ]

        center_lon = self.center_skydir.data.lon.deg
        center_lat = self.center_skydir.data.lat.deg
        parts.append(
            "\tcenter     : {:.1f} deg, {:.1f} deg\n".format(center_lon, center_lat)
        )
        parts.append(
            "\twidth      : {width[0][0]:.1f} x {width[1][0]:.1f} deg\n".format(
                width=self.width
            )
        )
        return "".join(parts)
884
885
    def __eq__(self, other):
        """Compare two geometries.

        Geometries are equal when their data shapes match, all paired
        non-spatial axes compare equal and the WCS objects agree within a
        tolerance of 1e-6.  Other types yield ``NotImplemented``.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented

        # check overall shape and axes compatibility
        if self.data_shape != other.data_shape:
            return False

        if any(mine != theirs for mine, theirs in zip(self.axes, other.axes)):
            return False

        # check WCS consistency with a priori tolerance of 1e-6
        return self.wcs.wcs.compare(other.wcs.wcs, tolerance=1e-6)
899
900
    def __ne__(self, other):
        """Negation of ``__eq__``.

        ``NotImplemented`` is passed through unchanged so that Python can
        fall back to the other operand; negating it would wrongly report
        objects of unrelated types as equal.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result
902
903
def create_wcs(
    skydir, coordsys="CEL", projection="AIT", cdelt=1.0, crpix=1.0, axes=None
):
    """Create a WCS object.

    Parameters
    ----------
    skydir : `~astropy.coordinates.SkyCoord`
        Sky coordinate of the WCS reference point
    coordsys : {"CEL", "GAL"}
        Coordinate system of the two sky axes. "CEL" produces RA/DEC axes
        with the reference value taken from ``skydir.icrs``; "GAL" produces
        GLON/GLAT axes with the reference value taken from
        ``skydir.galactic``.
    projection : str
        Projection code appended to the axis-type prefix to build CTYPE
        (e.g. "AIT" gives "RA---AIT" / "DEC--AIT").
    cdelt : float
        Pixel size. It is stored as ``-cdelt`` on the first axis and
        ``+cdelt`` on the second axis.
    crpix : float or (float,float)
        In the first case the same value is used for x and y axes
    axes : list
        List of non-spatial axes. Only its length is used here, to set the
        number of WCS axes.

    Returns
    -------
    wcs : `~astropy.wcs.WCS`
        The newly created WCS object.

    Raises
    ------
    ValueError
        If ``coordsys`` is neither "CEL" nor "GAL".
    """
    naxis = 2 if axes is None else 2 + len(axes)
    w = WCS(naxis=naxis)

    if coordsys == "CEL":
        frame = skydir.icrs
        ctype_templates = ("RA---{}", "DEC--{}")
        reference = (frame.ra.deg, frame.dec.deg)
    elif coordsys == "GAL":
        frame = skydir.galactic
        ctype_templates = ("GLON-{}", "GLAT-{}")
        reference = (frame.l.deg, frame.b.deg)
    else:
        raise ValueError("Invalid coordsys: {!r}".format(coordsys))

    for i in (0, 1):
        w.wcs.ctype[i] = ctype_templates[i].format(projection)
        w.wcs.crval[i] = reference[i]

    # A scalar reference pixel applies to both spatial axes.
    if isinstance(crpix, tuple):
        crpix_x, crpix_y = crpix[0], crpix[1]
    else:
        crpix_x = crpix_y = crpix
    w.wcs.crpix[0] = crpix_x
    w.wcs.crpix[1] = crpix_y

    # The first axis gets a negative increment, the second a positive one.
    w.wcs.cdelt[0] = -cdelt
    w.wcs.cdelt[1] = cdelt

    # Rebuild the object from its own FITS header representation.
    w = WCS(w.to_header())

    # FIXME: Figure out what to do with the non-spatial axes. They currently
    # only increase ``naxis``; no crpix / crval / cdelt / ctype / cunit is
    # set for them (e.g. an 'Energy' axis in 'MeV' when naxis == 3).

    return w
962
963
964
def wcs_add_energy_axis(wcs, energies):
    """Copy a WCS object, and add on the energy axis.

    The crpix, ctype, crval and cdelt values of the two input axes are
    copied into a new three-axis WCS. The third axis gets reference pixel 1,
    reference value ``energies[0]``, increment ``energies[1] - energies[0]``
    and axis type "Energy".

    Parameters
    ----------
    wcs : `~astropy.wcs.WCS`
        WCS
    energies : array-like
        Array of energies

    Returns
    -------
    wcs : `~astropy.wcs.WCS`
        New three-axis WCS object.

    Raises
    ------
    ValueError
        If the input WCS does not have exactly two axes.
    """
    if wcs.naxis != 2:
        raise ValueError("WCS naxis must be 2. Got: {}".format(wcs.naxis))

    w = WCS(naxis=3)

    # Copy the per-axis keywords of both spatial axes.
    for keyword in ("crpix", "ctype", "crval", "cdelt"):
        for i in (0, 1):
            getattr(w.wcs, keyword)[i] = getattr(wcs.wcs, keyword)[i]

    # Rebuild from the header before filling in the third axis.
    w = WCS(w.to_header())

    energy_axis = 2
    w.wcs.crpix[energy_axis] = 1
    w.wcs.crval[energy_axis] = energies[0]
    w.wcs.cdelt[energy_axis] = energies[1] - energies[0]
    w.wcs.ctype[energy_axis] = "Energy"

    return w
994
995
996
def offset_to_sky(skydir, offset_lon, offset_lat, coordsys="CEL", projection="AIT"):
    """Convert a cartesian offset (X,Y) in the given projection into
    a pair of spherical coordinates.

    The offsets are interpreted as 0-based pixel coordinates of a WCS built
    by ``create_wcs(skydir, coordsys, projection)``. The result is whatever
    ``wcs_pix2world`` returns for the stacked (X, Y) pairs.
    """
    x = np.array(offset_lon, ndmin=1)
    y = np.array(offset_lat, ndmin=1)

    w = create_wcs(skydir, coordsys, projection)

    # One (X, Y) pair per row.
    pairs = np.vstack((x, y)).T
    return w.wcs_pix2world(pairs, 0)
1006
1007
1008
def sky_to_offset(skydir, lon, lat, coordsys="CEL", projection="AIT"):
    """Convert sky coordinates to a projected offset.

    This function is the inverse of offset_to_sky.

    The coordinates are stacked into (lon, lat) rows and passed through
    ``wcs_world2pix`` (0-based origin) of a WCS built by
    ``create_wcs(skydir, coordsys, projection)``. Empty input is returned
    as the empty stacked array without calling the transform.
    """
    w = create_wcs(skydir, coordsys, projection)

    # One (lon, lat) pair per row.
    pairs = np.vstack((lon, lat)).T

    # Nothing to transform: hand back the empty array as is.
    if pairs.shape[0] == 0:
        return pairs

    return w.wcs_world2pix(pairs, 0)
1020
1021
1022
def offset_to_skydir(skydir, offset_lon, offset_lat, coordsys="CEL", projection="AIT"):
    """Convert a cartesian offset (X,Y) in the given projection into
    a SkyCoord.

    The offsets are interpreted as 0-based pixel coordinates of a WCS built
    by ``create_wcs(skydir, coordsys, projection)``.
    """
    x = np.array(offset_lon, ndmin=1)
    y = np.array(offset_lat, ndmin=1)

    projection_wcs = create_wcs(skydir, coordsys, projection)
    return SkyCoord.from_pixel(x, y, projection_wcs, origin=0)
1030
1031
1032
def pix2world(wcs, cdelt, crpix, pix):
    """Perform pixel to world coordinate transformation for a WCS
    projection with a given pixel size (CDELT) and reference pixel
    (CRPIX).  This method can be used to perform WCS transformations
    for projections with different pixelizations but the same
    reference coordinate (CRVAL), projection type, and coordinate
    system.

    The input pixel coordinates are first re-expressed in the pixelization
    of ``wcs`` (rescaled by the ratio of pixel sizes and shifted between the
    two reference pixels), then passed to ``wcs.wcs_pix2world`` with a
    0-based origin.

    Parameters
    ----------
    wcs : `astropy.wcs.WCS`
        WCS transform object.
    cdelt : tuple
        Tuple of X/Y pixel size in deg.  Each element should have the
        same length as ``pix``.
    crpix : tuple
        Tuple of reference pixel parameters in X and Y dimensions.  Each
        element should have the same length as ``pix``.
    pix : tuple
        Tuple of pixel coordinates.

    Returns
    -------
    The value returned by ``wcs.wcs_pix2world`` for the converted pixel
    coordinates.
    """
    native_cdelt = wcs.wcs.cdelt
    native_crpix = wcs.wcs.crpix

    # Size of a pixel of ``wcs`` relative to the requested pixel size.
    scale = [np.abs(native_cdelt[i] / cdelt[i]) for i in (0, 1)]

    # Offset from the requested reference pixel, rescaled, then moved to the
    # reference pixel of ``wcs``.
    native_pix = [
        (pix[i] - (crpix[i] - 1.0)) / scale[i] + native_crpix[i] - 1.0
        for i in (0, 1)
    ]
    return wcs.wcs_pix2world(native_pix[0], native_pix[1], 0)
1062
1063
1064
def world2pix(wcs, cdelt, crpix, coord):
    """Perform world to pixel coordinate transformation for a WCS
    projection with a given pixel size (CDELT) and reference pixel (CRPIX).

    The world coordinates are converted with ``wcs.wcs_world2pix`` (0-based
    origin). The resulting pixel coordinates are then rescaled by the ratio
    of pixel sizes and shifted from the reference pixel of ``wcs`` to
    ``crpix``.

    Parameters
    ----------
    wcs : `astropy.wcs.WCS`
        WCS transform object.
    cdelt : tuple
        Tuple of X/Y pixel size of the target pixelization.
    crpix : tuple
        Tuple of reference pixel parameters in X and Y dimensions of the
        target pixelization.
    coord : tuple
        Tuple of world coordinates; the first two elements are used.

    Returns
    -------
    pix : tuple
        Tuple (x, y) of pixel coordinates in the target pixelization.
    """
    native_cdelt = wcs.wcs.cdelt
    native_crpix = wcs.wcs.crpix

    # Size of a pixel of ``wcs`` relative to the requested pixel size.
    scale = [np.abs(native_cdelt[i] / cdelt[i]) for i in (0, 1)]

    native_pix = wcs.wcs_world2pix(coord[0], coord[1], 0)

    x = (native_pix[0] - (native_crpix[0] - 1.0)) * scale[0] + crpix[0] - 1.0
    y = (native_pix[1] - (native_crpix[1] - 1.0)) * scale[1] + crpix[1] - 1.0
    return (x, y)
1074
1075
1076
def get_projection(wcs):
    """Return the projection part of the first axis type of a WCS.

    This is the CTYPE string of the first axis with its first five
    characters removed (e.g. "RA---AIT" gives "AIT").
    """
    first_ctype = wcs.wcs.ctype[0]
    return first_ctype[5:]
1078
1079
1080
def get_coordys(wcs):
    """Return the coordinate system label of a WCS.

    The first axis type (CTYPE) is inspected: "CEL" is returned if it
    contains "RA", otherwise "GAL" if it contains "GLON".

    Raises
    ------
    ValueError
        If the first axis type contains neither "RA" nor "GLON".
    """
    first_ctype = wcs.wcs.ctype[0]

    # Checked in order: the first marker found decides the label.
    for marker, label in (("RA", "CEL"), ("GLON", "GAL")):
        if marker in first_ctype:
            return label

    raise ValueError("Unrecognized WCS coordinate system.")
1087
1088
1089
def wcs_to_axes(w, npix):
    """Generate a sequence of bin edge vectors corresponding to the
    axes of a WCS object.

    Parameters
    ----------
    w : `~astropy.wcs.WCS`
        WCS object; its first three axes are used.
    npix : sequence of int
        Number of pixels per axis, given in reversed axis order (the last
        element belongs to the first WCS axis).

    Returns
    -------
    x, y, z : `~numpy.ndarray`
        ``x`` and ``y`` hold ``n + 1`` edges centred on zero and spaced by
        the absolute pixel size of the first / second axis. ``z`` holds
        ``n + 1`` edges starting at ``log10(crval[2])`` with a constant step
        of ``log10((cdelt[2] + crval[2]) / crval[2])``.
    """
    # Bring the pixel counts into WCS axis order.
    counts = npix[::-1]

    spatial_edges = []
    for i in (0, 1):
        n = counts[i]
        step = np.abs(w.wcs.cdelt[i])
        spatial_edges.append(np.linspace(-n / 2.0, n / 2.0, n + 1) * step)
    x, y = spatial_edges

    # The third axis is laid out uniformly in log10 of the axis value.
    crval2 = w.wcs.crval[2]
    log_step = np.log10((w.wcs.cdelt[2] + crval2) / crval2)
    n_z = counts[2]
    z = np.linspace(0, n_z, n_z + 1) * log_step + np.log10(w.wcs.crval[2])

    return x, y, z
1105