gammapy.maps.wcs.geom   F
last analyzed

Complexity

Total Complexity 130

Size/Duplication

Total Lines 1223
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 557
dl 0
loc 1223
rs 2
c 0
b 0
f 0
wmc 130

63 Methods

Rating   Name   Duplication   Size   Complexity  
A WcsGeom.region_weights() 0 21 1
A WcsGeom.region_mask() 0 58 3
A WcsGeom.to_even_npix() 0 20 1
A WcsGeom.binary_structure() 0 36 4
A WcsGeom.is_aligned() 0 21 3
A WcsGeom.__repr__() 0 7 1
A WcsGeom.to_odd_npix() 0 31 2
A WcsGeom.npix() 0 4 1
A WcsGeom.width() 0 6 1
A WcsGeom.center_pix() 0 9 1
A WcsGeom.wcs() 0 4 1
A WcsGeom.pixel_scales() 0 13 1
A WcsGeom.projection() 0 4 1
A WcsGeom.shape_axes() 0 4 1
A WcsGeom.ndim() 0 3 1
A WcsGeom.center_coord() 0 9 1
A WcsGeom.pixel_area() 0 5 1
A WcsGeom._shape() 0 4 1
A WcsGeom.center_skydir() 0 9 1
A WcsGeom.is_allsky() 0 9 3
A WcsGeom.frame() 0 7 1
A WcsGeom.__init__() 0 23 3
A WcsGeom.data_shape() 0 4 1
A WcsGeom._shape_edges() 0 4 1
A WcsGeom.cutout_slices() 0 25 1
A WcsGeom.axes() 0 4 1
A WcsGeom.__setstate__() 0 6 3
A WcsGeom.axes_names() 0 4 1
A WcsGeom.data_shape_axes() 0 4 1
A WcsGeom.is_regular() 0 11 2
A WcsGeom._get_pix_all() 0 19 5
A WcsGeom.pix_to_coord() 0 17 2
A WcsGeom.from_header() 0 40 4
A WcsGeom.get_coord() 0 35 3
A WcsGeom._solid_angle() 0 30 1
A WcsGeom.from_aligned() 0 37 1
A WcsGeom.to_binsz() 0 20 1
A WcsGeom._image_geom() 0 5 1
A WcsGeom.contains() 0 3 1
A WcsGeom.__eq__() 0 10 4
A WcsGeom.coord_to_pix() 0 18 3
A WcsGeom.footprint() 0 5 1
A WcsGeom.separation() 0 15 1
C WcsGeom.create() 0 119 9
A WcsGeom.pix_to_idx() 0 23 4
A WcsGeom.to_header() 0 9 2
A WcsGeom.bin_volume() 0 3 1
A WcsGeom.is_allclose() 0 33 3
A WcsGeom._bin_volume() 0 9 2
A WcsGeom.solid_angle() 0 10 1
A WcsGeom.get_pix() 0 19 2
A WcsGeom.upsample() 0 13 3
A WcsGeom.__ne__() 0 2 1
A WcsGeom.__hash__() 0 2 1
A WcsGeom.downsample() 0 19 4
A WcsGeom._pad_spatial() 0 9 2
A WcsGeom._make_bands_cols() 0 33 2
A WcsGeom.crop() 0 9 2
A WcsGeom.boundary_mask() 0 20 1
A WcsGeom.get_idx() 0 5 2
A WcsGeom.to_cube() 0 9 1
A WcsGeom.cutout() 0 40 2
A WcsGeom.to_image() 0 2 1

4 Functions

Rating   Name   Duplication   Size   Complexity  
B cast_to_shape() 0 21 7
A get_resampled_wcs() 0 12 2
A pix2world() 0 30 1
A world2pix() 0 9 1

How to fix   Complexity   

Complexity

Complex classes like gammapy.maps.wcs.geom often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import copy
3
from functools import lru_cache
4
import numpy as np
5
import astropy.units as u
6
from astropy.convolution import Tophat2DKernel
7
from astropy.coordinates import Angle, SkyCoord
8
from astropy.io import fits
9
from astropy.nddata import Cutout2D
10
from astropy.nddata.utils import overlap_slices
11
from astropy.utils import lazyproperty
12
from astropy.wcs import WCS
13
from astropy.wcs.utils import (
14
    celestial_frame_to_wcs,
15
    proj_plane_pixel_scales,
16
    wcs_to_celestial_frame,
17
)
18
from gammapy.utils.array import round_up_to_even, round_up_to_odd
19
from ..axes import MapAxes
20
from ..coord import MapCoord, skycoord_to_lonlat
21
from ..geom import Geom, get_shape, pix_tuple_to_idx
22
from ..utils import INVALID_INDEX, _check_binsz, _check_width
23
24
__all__ = ["WcsGeom"]
25
26
27
def cast_to_shape(param, shape, dtype):
    """Cast a tuple of parameter arrays to a given shape.

    Parameters
    ----------
    param : scalar, array-like or tuple
        Parameter value(s). Anything that is not a tuple is treated as a
        single value and duplicated, so that two arrays are returned.
    shape : tuple
        Target shape of each returned array.
    dtype : data-type
        Data type used when converting the entries to arrays.

    Returns
    -------
    param : tuple of `~numpy.ndarray`
        One array per entry. Entries already matching ``shape`` are kept,
        single-element entries are expanded to ``shape``.

    Raises
    ------
    ValueError
        If an entry has more than one element and a shape different
        from ``shape``.
    """
    if not isinstance(param, tuple):
        param = [param]

    param = [np.array(p, ndmin=1, dtype=dtype) for p in param]

    # A single value applies to both entries; copy so they do not share memory
    if len(param) == 1:
        param = [param[0].copy(), param[0].copy()]

    for i, p in enumerate(param):
        if p.shape == shape:
            continue

        if p.size > 1:
            raise ValueError(
                f"Cannot cast parameter of shape {p.shape} to shape {shape}"
            )

        # Single-element entry: expand it to the requested shape
        param[i] = p * np.ones(shape, dtype=dtype)

    return tuple(param)
48
49
50
def get_resampled_wcs(wcs, factor, downsampled):
    """Return a resampled copy of a WCS object.

    Parameters
    ----------
    wcs : WCS-like object
        Object providing ``deepcopy()`` and a ``wcs`` attribute with
        ``cdelt`` and ``crpix`` arrays. The input is not modified.
    factor : int or float
        Resampling factor.
    downsampled : bool
        If True the pixel size is multiplied by ``factor``, otherwise it
        is divided by it.

    Returns
    -------
    wcs : WCS-like object
        Copy with rescaled ``cdelt`` and shifted ``crpix``.
    """
    resampled = wcs.deepcopy()
    scale = factor if downsampled else 1.0 / factor

    params = resampled.wcs
    params.cdelt *= scale
    # Rescale the reference pixel about the 0.5 offset
    params.crpix = (params.crpix - 0.5) / scale + 0.5
    return resampled
62
63
64
class WcsGeom(Geom):
65
    """Geometry class for WCS maps.
66
67
    This class encapsulates both the WCS transformation object and the
68
    image extent (number of pixels in each dimension).  Provides
69
    methods for accessing the properties of the WCS object and
70
    performing transformations between pixel and world coordinates.
71
72
    Parameters
73
    ----------
74
    wcs : `~astropy.wcs.WCS`
75
        WCS projection object
76
    npix : tuple
77
        Number of pixels in each spatial dimension
78
    cdelt : tuple
79
        Pixel size in each image plane.  If None, a constant pixel size taken from the absolute WCS ``cdelt`` is used.
80
    crpix : tuple
81
        Reference pixel coordinate in each image plane.
82
    axes : list
83
        Axes for non-spatial dimensions
84
    """
85
86
    _slice_spatial_axes = slice(0, 2)
87
    _slice_non_spatial_axes = slice(2, None)
88
    is_hpx = False
89
    is_region = False
90
91
    def __init__(self, wcs, npix, cdelt=None, crpix=None, axes=None):
        """Set up the geometry from a WCS object and the pixelization."""
        self._wcs = wcs
        self._frame = wcs_to_celestial_frame(wcs).name
        # Projection code is what follows the first five characters of CTYPE1
        self._projection = wcs.wcs.ctype[0][5:]
        self._axes = MapAxes.from_default(axes, n_spatial_axes=2)

        # Fall back to the (absolute) pixel size stored in the WCS
        if cdelt is None:
            cdelt = tuple(np.abs(self.wcs.wcs.cdelt))

        # Common shape used for the per-plane WCS parameters
        plane_shape = max(get_shape(npix), get_shape(cdelt))
        self._npix = cast_to_shape(npix, plane_shape, int)
        self._cdelt = cast_to_shape(cdelt, plane_shape, float)

        # CRPIX is 1-based by convention; default to the image center
        if crpix is None:
            crpix = tuple(1.0 + (np.array(self._npix) - 1.0) / 2.0)

        self._crpix = crpix

        # Per-instance caches around the coordinate getters
        self.get_coord = lru_cache()(self.get_coord)
        self.get_pix = lru_cache()(self.get_pix)
114
115
    def __setstate__(self, state):
        """Restore the instance dict, re-wrapping the cached getters.

        The ``get_coord`` and ``get_pix`` entries, when present in
        ``state``, are wrapped in a fresh ``lru_cache``.
        """
        for name in ("get_coord", "get_pix"):
            if name in state:
                state[name] = lru_cache()(state[name])

        self.__dict__ = state
121
122
    @property
    def data_shape(self):
        """Shape of the Numpy data array for this geometry (reversed ``_shape``)."""
        return self._shape[::-1]
126
127
    @property
    def axes_names(self):
        """Names of all axes: "lon", "lat", then the non-spatial axes names."""
        spatial_names = ["lon", "lat"]
        return spatial_names + self.axes.names
131
132
    @property
    def data_shape_axes(self):
        """Data shape of the non-spatial axes, with length-one spatial axes appended."""
        unit_spatial = (1, 1)
        return self.axes.shape[::-1] + unit_spatial
136
137
    @property
    def _shape(self):
        """Maximum number of pixels in lon and lat, followed by the axes shape."""
        nx = np.max(self.npix[0])
        ny = np.max(self.npix[1])
        return (nx, ny) + self.axes.shape
141
142
    @property
    def _shape_edges(self):
        """Like ``_shape`` but with one extra element per spatial dimension."""
        nx_edges = np.max(self.npix[0]) + 1
        ny_edges = np.max(self.npix[1]) + 1
        return (nx_edges, ny_edges) + self.axes.shape
146
147
    @property
    def shape_axes(self):
        """Shape of the non-spatial axes (the part of ``_shape`` after lon/lat)."""
        full_shape = self._shape
        return full_shape[self._slice_non_spatial_axes]
151
152
    @property
    def wcs(self):
        """The WCS projection object passed at construction."""
        return self._wcs
156
157
    @property
    def frame(self):
        """Name of the celestial frame of the projection.

        For example "galactic" (Galactic) or "icrs" (Equatorial).
        """
        return self._frame
164
165
    def cutout_slices(self, geom, mode="partial"):
        """Compute the slices relating this geometry to a parent geometry.

        Parameters
        ----------
        geom : `WcsGeom`
            Parent geometry.
        mode : {"trim", "partial", "strict"}
            Mode passed on to `~astropy.nddata.utils.overlap_slices`.

        Returns
        -------
        slices : dict
            Dictionary with keys "parent-slices" and "cutout-slices".
        """
        # Pixel position of this geometry's center in the parent image
        center_pix = geom.to_image().coord_to_pix(self.center_skydir)

        parent_slices, cutout_slices = overlap_slices(
            large_array_shape=geom.data_shape[-2:],
            small_array_shape=self.data_shape[-2:],
            # reversed to match the order of the data shapes
            position=center_pix[::-1],
            mode=mode,
        )
        return {
            "parent-slices": parent_slices,
            "cutout-slices": cutout_slices,
        }
191
192
    @property
    def projection(self):
        """Map projection code extracted from the WCS ``ctype``."""
        return self._projection
196
197
    @property
    def is_allsky(self):
        """Whether the map spans 360 deg in longitude and 180 deg in latitude (bool)."""
        lon_extent = self._npix[0] * self._cdelt[0]
        if not np.all(np.isclose(lon_extent, 360.0)):
            return False

        lat_extent = self._npix[1] * self._cdelt[1]
        return bool(np.all(np.isclose(lat_extent, 180.0)))
206
207
    @property
    def is_regular(self):
        """Whether this geometry is regular in non-spatial dimensions (bool).

        - False for multi-resolution or irregular geometries.
        - True if all image planes have the same pixel geometry.
        """
        # More than one npix entry means a per-plane pixelization
        return self.npix[0].size <= 1
218
219
    @property
    def width(self):
        """Image extent in longitude and latitude, as a quantity in deg."""
        lon_cdelt, lat_cdelt = self._cdelt[0], self._cdelt[1]
        lon_npix, lat_npix = self._npix[0], self._npix[1]
        return (lon_cdelt * lon_npix, lat_cdelt * lat_npix) * u.deg
225
226
    @property
    def pixel_area(self):
        """Pixel area in deg^2, as the product of the two pixel sizes."""
        # FIXME: Correctly compute solid angle for projection
        size_lon, size_lat = self._cdelt[0], self._cdelt[1]
        return size_lon * size_lat
231
232
    @property
    def npix(self):
        """Tuple with the number of pixels in longitude and latitude."""
        return self._npix
236
237
    @property
    def axes(self):
        """Non-spatial axes of the geometry."""
        return self._axes
241
242
    @property
    def ndim(self):
        """Number of dimensions, i.e. the length of ``data_shape``."""
        shape = self.data_shape
        return len(shape)
245
246
    @property
    def center_coord(self):
        """Map coordinate at the center pixel of the geometry.

        Returns
        -------
        coord : tuple
        """
        center = self.center_pix
        return self.pix_to_coord(center)
255
256
    @property
    def center_pix(self):
        """Pixel coordinate of the center of the geometry.

        Returns
        -------
        pix : tuple
        """
        # data_shape is reversed, so flip back after halving
        half_extent = (np.array(self.data_shape) - 1.0) / 2
        return tuple(half_extent[::-1])
265
266
    @property
    def center_skydir(self):
        """Sky coordinate of the center of the geometry.

        Returns
        -------
        skydir : `~astropy.coordinates.SkyCoord`
        """
        center = self.center_pix
        return SkyCoord.from_pixel(center[0], center[1], self.wcs)
275
276
    @property
    def pixel_scales(self):
        """Pixel scale.

        Angles along each image axis at the CRPIX location, once projected
        onto the plane of intermediate world coordinates.

        Returns
        -------
        angle : `~astropy.coordinates.Angle`
        """
        scales = proj_plane_pixel_scales(self.wcs)
        return Angle(scales, "deg")
289
290
    @classmethod
    def create(
        cls,
        npix=None,
        binsz=0.5,
        proj="CAR",
        frame="icrs",
        refpix=None,
        axes=None,
        skydir=None,
        width=None,
    ):
        """Create a WCS geometry object.

        The pixelization is defined by ``binsz`` together with either
        ``npix`` or ``width``. For maps with non-spatial dimensions, a
        list or array can be given for any pixelization parameter to use
        a different pixelization per image plane. If both ``npix`` and
        ``width`` are None, an all-sky geometry (360 x 180 deg) is created.

        Parameters
        ----------
        npix : int or tuple or list
            Width of the map in pixels. A tuple is interpreted as
            parameters for the longitude and latitude axes. For maps with
            non-spatial dimensions, list input can define a different map
            width in each image plane. This option supersedes width.
        width : float or tuple or list or string
            Width of the map in degrees. A tuple is interpreted as
            parameters for the longitude and latitude axes. For maps with
            non-spatial dimensions, list input can define a different map
            width in each image plane.
        binsz : float or tuple or list
            Map pixel size in degrees. A tuple is interpreted as
            parameters for the longitude and latitude axes. For maps with
            non-spatial dimensions, list input can define a different bin
            size in each image plane.
        skydir : tuple or `~astropy.coordinates.SkyCoord`
            Sky position of the map center, either a SkyCoord or a tuple
            of longitude and latitude in deg in the coordinate system of
            the map. Defaults to (0, 0).
        frame : {"icrs", "galactic"}, optional
            Coordinate system, either Galactic ("galactic") or Equatorial ("icrs").
        axes : list
            List of non-spatial axes.
        proj : string, optional
            Any valid WCS projection type. Default is 'CAR' (Plate-Carrée projection).
            See `WCS supported projections <https://docs.astropy.org/en/stable/wcs/supported_projections.html>`__  # noqa: E501
        refpix : tuple
            Reference pixel of the projection. If None it is derived from
            the number of pixels of the first image plane.

        Returns
        -------
        geom : `~WcsGeom`
            A WCS geometry object.

        Raises
        ------
        ValueError
            If ``skydir`` is neither None, a tuple nor a SkyCoord.

        Examples
        --------
        >>> from gammapy.maps import WcsGeom
        >>> from gammapy.maps import MapAxis
        >>> axis = MapAxis.from_bounds(0,1,2)
        >>> geom = WcsGeom.create(npix=(100,100), binsz=0.1)
        >>> geom = WcsGeom.create(npix=(100,100), binsz="0.1deg")
        >>> geom = WcsGeom.create(npix=[100,200], binsz=[0.1,0.05], axes=[axis])
        >>> geom = WcsGeom.create(npix=[100,200], binsz=["0.1deg","0.05deg"], axes=[axis])
        >>> geom = WcsGeom.create(width=[5.0,8.0], binsz=[0.1,0.05], axes=[axis])
        >>> geom = WcsGeom.create(npix=([100,200],[100,200]), binsz=0.1, axes=[axis])
        """
        # Reference sky position
        if skydir is None:
            crval = (0.0, 0.0)
        elif isinstance(skydir, tuple):
            crval = skydir
        elif isinstance(skydir, SkyCoord):
            lon_ref, lat_ref, frame = skycoord_to_lonlat(skydir, frame=frame)
            crval = (lon_ref, lat_ref)
        else:
            raise ValueError(f"Invalid type for skydir: {type(skydir)!r}")

        if width is not None:
            width = _check_width(width)

        binsz = _check_binsz(binsz)

        # Common per-plane shape of the pixelization parameters
        shape = max(get_shape(par) for par in (npix, binsz, width))
        binsz = cast_to_shape(binsz, shape, float)

        if npix is None:
            # Neither npix nor width given: all-sky geometry
            if width is None:
                width = (360.0, 180.0)

            width = cast_to_shape(width, shape, float)
            npix = tuple(
                np.rint(extent / size).astype(int)
                for extent, size in zip(width[:2], binsz[:2])
            )
        else:
            npix = cast_to_shape(npix, shape, int)

        # Number of pixels of the first image plane
        nx_first = npix[0].flat[0]
        ny_first = npix[1].flat[0]

        if refpix is None:
            refpix = ((int(nx_first) + 1) / 2.0, (int(ny_first) + 1) / 2.0)

        # get frame class
        frame = SkyCoord(np.nan, np.nan, frame=frame, unit="deg").frame
        wcs = celestial_frame_to_wcs(frame, projection=proj)
        wcs.wcs.crpix = refpix
        wcs.wcs.crval = crval

        # Longitude pixel size is stored with negative sign
        cdelt = (-binsz[0].flat[0], binsz[1].flat[0])
        wcs.wcs.cdelt = cdelt

        # NOTE(review): astropy documents ``array_shape`` in (row, column)
        # order, while (nx, ny) is assigned here -- confirm this is intended.
        wcs.array_shape = nx_first, ny_first
        wcs.wcs.datfix()
        return cls(wcs, npix, cdelt=binsz, axes=axes)
409
410
    @property
    def footprint(self):
        """Footprint of the geometry, from ``wcs.calc_footprint()``, as a SkyCoord."""
        corners = self.wcs.calc_footprint()
        return SkyCoord(corners, frame=self.frame, unit="deg")
415
416
    @classmethod
    def from_aligned(cls, geom, skydir, width):
        """Create a geometry aligned with an existing one.

        Parameters
        ----------
        geom : `~WcsGeom`
            A reference WCS geometry object.
        skydir : tuple or `~astropy.coordinates.SkyCoord`
            Sky position of the map center, either a SkyCoord or a tuple
            of longitude and latitude in deg in the coordinate system of
            the map.
        width : float or tuple or list or string
            Width of the map in degrees. A tuple is interpreted as
            parameters for the longitude and latitude axes.

        Returns
        -------
        geom : `~WcsGeom`
            An aligned WCS geometry object with specified size and center.
        """
        pixel_scales = geom.pixel_scales
        size = _check_width(width) * u.deg
        npix = tuple(np.round(size / pixel_scales).astype(int))

        # Pixel position of the requested center in the reference image
        xpix, ypix = geom.to_image().coord_to_pix(skydir)

        # Shift the reference pixel by a whole number of pixels
        ref_crpix = geom.wcs.wcs.crpix
        refpix = (
            int(np.floor(-xpix + npix[0] / 2.0)) + ref_crpix[0],
            int(np.floor(-ypix + npix[1] / 2.0)) + ref_crpix[1],
        )

        return cls.create(
            skydir=tuple(geom.wcs.wcs.crval),
            npix=npix,
            refpix=refpix,
            frame=geom.frame,
            binsz=tuple(pixel_scales.deg),
            axes=geom.axes,
            proj=geom.projection,
        )
454
455
    @classmethod
    def from_header(cls, header, hdu_bands=None, format="gadf"):
        """Create a WCS geometry object from a FITS header.

        Parameters
        ----------
        header : `~astropy.io.fits.Header`
            The FITS header
        hdu_bands : `~astropy.io.fits.BinTableHDU`
            The BANDS table HDU. If it has an "NPIX" column, the per-plane
            pixelization is read from its "NPIX" and "CDELT" columns.
        format : {'gadf', 'fgst-ccube','fgst-template'}
            FITS format convention.

        Returns
        -------
        wcs : `~WcsGeom`
            WCS geometry object.
        """
        import ast

        wcs = WCS(header, naxis=2)
        # TODO: see https://github.com/astropy/astropy/issues/9259
        wcs._naxis = wcs._naxis[:2]

        axes = MapAxes.from_table_hdu(hdu_bands, format=format)
        shape = axes.shape

        if hdu_bands is not None and "NPIX" in hdu_bands.columns.names:
            npix = hdu_bands.data.field("NPIX").reshape(shape + (2,))
            npix = (npix[..., 0], npix[..., 1])
            cdelt = hdu_bands.data.field("CDELT").reshape(shape + (2,))
            cdelt = (cdelt[..., 0], cdelt[..., 1])
        elif "WCSSHAPE" in header:
            # The header comes from a file, so parse the tuple literal
            # without executing arbitrary code (previously ``eval``).
            wcs_shape = ast.literal_eval(header["WCSSHAPE"])
            npix = (wcs_shape[0], wcs_shape[1])
            cdelt = None
            wcs.array_shape = npix
        else:
            npix = (header["NAXIS1"], header["NAXIS2"])
            cdelt = None

        return cls(wcs, npix, cdelt=cdelt, axes=axes)
495
496
    def _make_bands_cols(self):
        """Build the NPIX, CDELT and CRPIX columns for irregular geometries.

        Returns
        -------
        cols : list of `~astropy.io.fits.Column`
            Empty for regular geometries.
        """
        if self.is_regular:
            return []

        # (column name, FITS format, lon/lat pair of per-plane arrays)
        column_specs = (
            ("NPIX", "2I", self.npix),
            ("CDELT", "2E", self._cdelt),
            ("CRPIX", "2E", self._crpix),
        )

        cols = []
        for name, fits_format, pair in column_specs:
            # One row per image plane, with the lon and lat value
            values = np.vstack((np.ravel(pair[0]), np.ravel(pair[1]))).T
            cols.append(fits.Column(name, fits_format, dim="(2)", array=values))

        return cols
529
530
    def to_header(self):
        """Create a FITS header from the WCS and the non-spatial axes.

        The map shape is stored under the "WCSSHAPE" key as a string of
        the form "(nx,ny,nbin_1,...)".
        """
        header = self.wcs.to_header()
        header.update(self.axes.to_header())

        sizes = [np.max(self.npix[0]), np.max(self.npix[1])]
        sizes.extend(ax.nbin for ax in self.axes)
        shape = ",".join(str(size) for size in sizes)

        header["WCSSHAPE"] = f"({shape})"
        return header
539
540
    def get_idx(self, idx=None, flat=False):
        """Get the pixel indices of the geometry.

        Parameters
        ----------
        idx : tuple, optional
            Passed on to `get_pix`.
        flat : bool
            If True, keep only the finite pixel coordinates, which
            flattens each array.
        """
        pix = self.get_pix(idx=idx, mode="center")

        if flat:
            pix = tuple(values[np.isfinite(values)] for values in pix)

        return pix_tuple_to_idx(pix)
545
546
    def _get_pix_all(
        self, idx=None, mode="center", sparse=False, axis_name=("lon", "lat")
    ):
        """Get idx coordinate array without footprint of the projection applied"""

        def axis_pix(name, nbin):
            # Edge coordinates start half a pixel before the first center
            if mode == "edges" and name in axis_name:
                return np.arange(-0.5, nbin, dtype=float)
            return np.arange(nbin, dtype=float)

        grids = [
            axis_pix(name, nbin) for name, nbin in zip(self.axes_names, self._shape)
        ]

        # TODO: improve varying bin size coordinate handling
        if idx is not None:
            # Replace the non-spatial grids by the given fixed indices
            fixed = [float(value) for value in idx]
            grids = grids[self._slice_spatial_axes] + fixed

        # Grids are built in reversed order and flipped back afterwards
        mesh = np.meshgrid(*grids[::-1], indexing="ij", sparse=sparse)
        return mesh[::-1]
565
566
    def get_pix(self, idx=None, mode="center"):
        """Get map pix coordinates from the geometry.

        Parameters
        ----------
        idx : tuple, optional
            Passed on to `_get_pix_all`.
        mode : {'center', 'edges'}
            Get center or edge pix coordinates for the spatial axes.

        Returns
        -------
        coord : tuple
            Map pix coordinate tuple. Pixels whose longitude is not
            finite are set to ``INVALID_INDEX.float``.
        """
        pix = self._get_pix_all(idx=idx, mode=mode)
        world = self.pix_to_coord(pix)

        # Pixels without a finite longitude are flagged as invalid
        invalid = ~np.isfinite(world[0])
        for values in pix:
            values[invalid] = INVALID_INDEX.float

        return pix
585
586
    def get_coord(
        self, idx=None, mode="center", frame=None, sparse=False, axis_name=None
    ):
        """Get map coordinates from the geometry.

        Parameters
        ----------
        idx : tuple, optional
            Passed on to `_get_pix_all`.
        mode : {'center', 'edges'}
            Get center or edge coordinates for the spatial axes.
        frame : str or `~astropy.coordinates.Frame`
            Coordinate frame. Defaults to the frame of the geometry.
        sparse : bool
            Compute sparse coordinates
        axis_name : str
            If mode = "edges", the edges will be returned for this axis.
            Defaults to ("lon", "lat").

        Returns
        -------
        coord : `~MapCoord`
            Map coordinate object.
        """
        edge_axes = ("lon", "lat") if axis_name is None else axis_name
        target_frame = self.frame if frame is None else frame

        pix = self._get_pix_all(idx=idx, mode=mode, sparse=sparse, axis_name=edge_axes)
        world = self.pix_to_coord(pix)

        # Build the coordinates in the native frame, then convert
        native = MapCoord.create(
            data=world, frame=self.frame, axis_names=self.axes.names
        )
        return native.to_frame(target_frame)
621
622
    def coord_to_pix(self, coords):
        """Convert map coordinates to pixel coordinates.

        Parameters
        ----------
        coords : tuple or `~MapCoord`
            Coordinates, converted with ``MapCoord.create`` using the
            frame and axes names of this geometry.

        Returns
        -------
        pix : tuple
            Spatial pixel coordinates followed by those of the
            non-spatial axes. Empty arrays for empty input.
        """
        coords = MapCoord.create(coords, frame=self.frame, axis_names=self.axes.names)

        if coords.size == 0:
            return tuple(np.array([]) for _ in range(coords.ndim))

        if self.is_regular:
            pix = self._wcs.wcs_world2pix(coords.lon, coords.lat, 0)
        else:
            # Variable Bin Size: select the WCS parameters per image plane
            plane_idx = self.axes.coord_to_idx(coords, clip=True)
            crpix = [par[plane_idx] for par in self._crpix]
            cdelt = [par[plane_idx] for par in self._cdelt]
            pix = list(world2pix(self.wcs, cdelt, crpix, (coords.lon, coords.lat)))

        pix += self.axes.coord_to_pix(coords)
        return tuple(pix)
640
641
    def pix_to_coord(self, pix):
        """Convert pixel coordinates to map coordinates.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates, spatial axes first.

        Returns
        -------
        coords : tuple
            Longitude and latitude as quantities in deg, followed by the
            coordinates of the non-spatial axes.
        """
        non_spatial_pix = pix[self._slice_non_spatial_axes]

        if self.is_regular:
            lonlat = self._wcs.wcs_pix2world(pix[0], pix[1], 0)
        else:
            # Variable Bin Size: select the WCS parameters per image plane
            plane_idx = pix_tuple_to_idx(non_spatial_pix)
            crpix = [par[plane_idx] for par in self._crpix]
            cdelt = [par[plane_idx] for par in self._cdelt]
            lonlat = pix2world(self.wcs, cdelt, crpix, pix[self._slice_spatial_axes])

        lon = u.Quantity(lonlat[0], unit="deg", copy=False)
        lat = u.Quantity(lonlat[1], unit="deg", copy=False)

        coords = (lon, lat)
        coords += self.axes.pix_to_coord(non_spatial_pix)
        return coords
658
659
    def pix_to_idx(self, pix, clip=False):
        """Convert pixel coordinates to pixel indices.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates, spatial axes first.
        clip : bool
            If True, spatial indices are clipped into the range
            ``[0, npix - 1]``. Otherwise spatial indices outside that
            range are set to -1. The flag is also passed on to the
            non-spatial axes.

        Returns
        -------
        idx : tuple
            Spatial indices followed by the non-spatial indices.
        """
        pix = pix_tuple_to_idx(pix)

        idx_non_spatial = self.axes.pix_to_idx(
            pix[self._slice_non_spatial_axes], clip=clip
        )

        # For irregular geometries the image size depends on the plane
        if not self.is_regular:
            npix = (self.npix[0][idx_non_spatial], self.npix[1][idx_non_spatial])
        else:
            npix = self.npix

        idx_spatial = []

        for idx, npix_ in zip(pix[self._slice_spatial_axes], npix):
            if clip:
                # Last valid index is npix_ - 1; clipping to npix_ would
                # return an out-of-range index.
                idx = np.clip(idx, 0, npix_ - 1)
            else:
                idx = np.where((idx < 0) | (idx >= npix_), -1, idx)

            idx_spatial.append(idx)

        return tuple(idx_spatial) + idx_non_spatial
682
683
    def contains(self, coords):
        """Check which coordinates have no index equal to ``INVALID_INDEX.int``.

        Returns a boolean array that is True where every index component
        from ``coord_to_idx`` is valid.
        """
        idx = self.coord_to_idx(coords)
        valid = [component != INVALID_INDEX.int for component in idx]
        return np.all(np.stack(valid), axis=0)
686
687
    def to_image(self):
        """Return the cached geometry without non-spatial axes."""
        return self._image_geom
689
690
    @lazyproperty
    def _image_geom(self):
        """Geometry without non-spatial axes, built from the maximum npix and cdelt."""
        max_npix = (np.max(self._npix[0]), np.max(self._npix[1]))
        max_cdelt = (np.max(self._cdelt[0]), np.max(self._cdelt[1]))
        return self.__class__(self._wcs, max_npix, cdelt=max_cdelt)
695
696
    def to_cube(self, axes):
        """Create a geometry with additional non-spatial axes appended.

        The WCS and the existing axes are deep-copied; the maximum npix
        and cdelt are used for the spatial pixelization.
        """
        max_npix = (np.max(self._npix[0]), np.max(self._npix[1]))
        max_cdelt = (np.max(self._cdelt[0]), np.max(self._cdelt[1]))
        all_axes = copy.deepcopy(self.axes) + axes

        return self.__class__(
            self._wcs.deepcopy(),
            max_npix,
            cdelt=max_cdelt,
            axes=all_axes,
        )
706
707
    def _pad_spatial(self, pad_width):
        """Create a geometry padded by ``pad_width`` pixels on each spatial side.

        A scalar ``pad_width`` is used for both spatial dimensions.
        """
        if np.isscalar(pad_width):
            pad_width = (pad_width, pad_width)

        pad_x, pad_y = pad_width[0], pad_width[1]
        npix = (self.npix[0] + 2 * pad_x, self.npix[1] + 2 * pad_y)

        # Shift the reference pixel by the padding
        wcs = self._wcs.deepcopy()
        wcs.wcs.crpix += np.array(pad_width)

        return self.__class__(
            wcs,
            npix,
            cdelt=copy.deepcopy(self._cdelt),
            axes=copy.deepcopy(self.axes),
        )
716
717
    def crop(self, crop_width):
        """Create a geometry cropped by ``crop_width`` pixels on each spatial side.

        A scalar ``crop_width`` is used for both spatial dimensions.
        """
        if np.isscalar(crop_width):
            crop_width = (crop_width, crop_width)

        crop_x, crop_y = crop_width[0], crop_width[1]
        npix = (self.npix[0] - 2 * crop_x, self.npix[1] - 2 * crop_y)

        # Shift the reference pixel by the cropped margin
        wcs = self._wcs.deepcopy()
        wcs.wcs.crpix -= np.array(crop_width)

        return self.__class__(
            wcs,
            npix,
            cdelt=copy.deepcopy(self._cdelt),
            axes=copy.deepcopy(self.axes),
        )
726
727
    def downsample(self, factor, axis_name=None):
        """Downsample the geometry by a given factor.

        Parameters
        ----------
        factor : int
            Downsampling factor.
        axis_name : str, optional
            Name of the non-spatial axis to downsample. If None, the
            spatial axes are downsampled.

        Returns
        -------
        geom : `WcsGeom`
            Downsampled geometry.

        Raises
        ------
        ValueError
            If the spatial number of pixels is not divisible by ``factor``.
        NotImplementedError
            If a non-spatial axis is requested for an irregular geometry.
        """
        if axis_name is None:
            if np.any(np.mod(self.npix, factor) > 0):
                raise ValueError(
                    f"Spatial shape not divisible by factor {factor!r} in all axes."
                    f" You need to pad prior to calling downsample."
                )

            # Divisibility was checked above, so integer division is exact
            npix = (self.npix[0] // factor, self.npix[1] // factor)
            cdelt = (self._cdelt[0] * factor, self._cdelt[1] * factor)
            wcs = get_resampled_wcs(self.wcs, factor, True)
            return self._init_copy(wcs=wcs, npix=npix, cdelt=cdelt)
        else:
            if not self.is_regular:
                raise NotImplementedError(
                    "Downsampling in non-spatial axes not supported for irregular geometries"
                )
            axes = self.axes.downsample(factor=factor, axis_name=axis_name)
            return self._init_copy(axes=axes)
746
747
    def upsample(self, factor, axis_name=None):
        """Upsample the geometry by a given factor.

        Parameters
        ----------
        factor : int
            Upsampling factor.
        axis_name : str, optional
            Name of the non-spatial axis to upsample. If None, the
            spatial axes are upsampled.

        Raises
        ------
        NotImplementedError
            If a non-spatial axis is requested for an irregular geometry.
        """
        if axis_name is not None:
            if not self.is_regular:
                raise NotImplementedError(
                    "Upsampling in non-spatial axes not supported for irregular geometries"
                )
            upsampled_axes = self.axes.upsample(factor=factor, axis_name=axis_name)
            return self._init_copy(axes=upsampled_axes)

        # Spatial upsampling: more pixels of smaller size
        new_npix = (self.npix[0] * factor, self.npix[1] * factor)
        new_cdelt = (self._cdelt[0] / factor, self._cdelt[1] / factor)
        new_wcs = get_resampled_wcs(self.wcs, factor, False)
        return self._init_copy(wcs=new_wcs, npix=new_npix, cdelt=new_cdelt)
760
761
    def to_binsz(self, binsz):
        """Change pixel size of the geometry.

        The new geometry is created with the same center, width,
        projection and frame, and a deep copy of the axes.

        Parameters
        ----------
        binsz : float or tuple or list
            New pixel size in degree.

        Returns
        -------
        geom : `WcsGeom`
            Geometry with new pixel size.
        """
        settings = {
            "skydir": self.center_skydir,
            "binsz": binsz,
            "width": self.width,
            "proj": self.projection,
            "frame": self.frame,
            "axes": copy.deepcopy(self.axes),
        }
        return self.create(**settings)
782
783
    def solid_angle(self):
        """Solid angle array (`~astropy.units.Quantity` in ``sr``).

        The array has the same dimension as the WcsGeom object. The value
        is computed once and cached.

        To return solid angles for the spatial dimensions only use::

             WcsGeom.to_image().solid_angle()
        """
        return self._solid_angle
793
794
    @lazyproperty
    def _solid_angle(self):
        """Cached solid angle per pixel, from two triangles spanned by the pixel corners."""
        edges = self.get_coord(mode="edges").skycoord

        # the four corners of every pixel
        corner_ll = edges[..., :-1, :-1]
        corner_lr = edges[..., 1:, :-1]
        corner_ul = edges[..., :-1, 1:]
        corner_ur = edges[..., 1:, 1:]

        # lengths of the four pixel sides
        side_low = corner_ll.separation(corner_lr)
        side_left = corner_ll.separation(corner_ul)
        side_up = corner_ul.separation(corner_ur)
        side_right = corner_lr.separation(corner_ur)

        # angles enclosed by the sides of the two triangles
        angle_lr = corner_lr.position_angle(corner_ur) - corner_lr.position_angle(
            corner_ll
        )
        angle_ul = corner_ul.position_angle(corner_ur) - corner_ll.position_angle(
            corner_ul
        )

        # triangle areas in the planar approximation
        area_lr = 0.5 * side_low * side_right * np.sin(angle_lr)
        area_ul = 0.5 * side_up * side_left * np.sin(angle_ul)

        # TODO: for non-negative cdelt a negative solid angle is returned
        #  find out why and fix properly
        total = u.Quantity(area_lr + area_ul, "sr", copy=False)
        return np.abs(total)
824
825
    def bin_volume(self):
        """Bin volume (`~astropy.units.Quantity`).

        Thin accessor for the cached ``_bin_volume`` property.
        """
        cached = self._bin_volume
        return cached
828
829
    @lazyproperty
    def _bin_volume(self):
        """Cached bin volume.

        For an image geometry this is the solid angle of the spatial image
        geometry. Otherwise that solid angle is multiplied by the bin volume
        of the non-spatial axes.
        """
        spatial = self.to_image().solid_angle()

        if self.is_image:
            return spatial

        return spatial * self.axes.bin_volume()
838
839
    def separation(self, center):
        """Sky separation of the image pixel coordinates from a given center.

        Parameters
        ----------
        center : `~astropy.coordinates.SkyCoord`
            Center position

        Returns
        -------
        separation : `~astropy.coordinates.Angle`
            Separation angle array (2D)
        """
        # only the spatial (image) part of the geometry is used
        pixel_positions = self.to_image().get_coord().skycoord
        return center.separation(pixel_positions)
854
855
    def cutout(self, position, width, mode="trim", odd_npix=False):
        """Create a cutout of the geometry around a given position.

        Parameters
        ----------
        position : `~astropy.coordinates.SkyCoord`
            Center position of the cutout region.
        width : tuple of `~astropy.coordinates.Angle`
            Angular sizes of the region in (lon, lat) in that specific order.
            If only one value is passed, a square region is extracted.
        mode : {'trim', 'partial', 'strict'}
            Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.
        odd_npix : bool
            Force width to odd number of pixels.

        Returns
        -------
        cutout
            Copy of this geometry, built with ``_init_copy``, using the WCS
            and pixel shape of the cutout.
        """
        requested = _check_width(width) * u.deg
        pixel_size = self.pixel_scales

        # width in pixels, with a lower bound of one pixel per dimension
        n_pixels = np.clip((requested / pixel_size).to_value(""), 1, None)

        if odd_npix:
            # plain pixel counts, rounded up to odd numbers
            size = round_up_to_odd(n_pixels)
        else:
            size = n_pixels * pixel_size

        # Cutout2D only serves to derive the WCS and shape; the data content
        # is never read, hence the uninitialised placeholder array
        placeholder = np.empty(self.to_image().data_shape, dtype=bool)
        c2d = Cutout2D(
            data=placeholder,
            wcs=self.wcs,
            position=position,
            # Cutout2D takes size with order (lat, lon)
            size=size[::-1],
            mode=mode,
        )
        return self._init_copy(wcs=c2d.wcs, npix=c2d.shape[::-1])
895
896
    def boundary_mask(self, width):
        """Create a mask applying binary erosion with a given width from geom edges.

        An all-True map on this geometry is eroded with a box kernel whose
        width is twice the requested margin.

        Parameters
        ----------
        width : tuple of `~astropy.units.Quantity`
            Angular sizes of the margin in (lon, lat) in that specific order.
            If only one value is passed, the same margin is applied in (lon, lat).

        Returns
        -------
        mask_map : `~gammapy.maps.WcsNDMap` of boolean type
            Boundary mask
        """
        from .ndmap import WcsNDMap

        all_true = WcsNDMap.from_geom(self, data=np.ones(self.data_shape, dtype=bool))
        kernel_width = 2 * u.Quantity(width)
        return all_true.binary_erode(width=kernel_width, kernel="box")
917
918
    def region_mask(self, regions, inside=True):
        """Create a mask from a given list of regions.

        Pixels inside the region(s) are filled with "True". To obtain the
        inverse, e.g. a mask for exclusion regions, apply the tilde (~)
        operator to the result (see example below).

        Parameters
        ----------
        regions : str, `~regions.Region` or list of `~regions.Region`
            Region or list of regions (pixel or sky regions accepted).
            A region can be defined as a string in DS9 format as well.
            See http://ds9.si.edu/doc/ref/region.html for details.
        inside : bool
            For ``inside=True``, pixels in the region are True (the default).
            For ``inside=False``, pixels in the region are False.

        Returns
        -------
        mask_map : `~gammapy.maps.WcsNDMap` of boolean type
            Boolean region mask

        Raises
        ------
        ValueError
            If the geometry is not regular.

        Examples
        --------
        Make an exclusion mask for a circular region::

            from regions import CircleSkyRegion
            from astropy.coordinates import SkyCoord, Angle
            from gammapy.maps import WcsNDMap, WcsGeom

            pos = SkyCoord(0, 0, unit='deg')
            geom = WcsGeom.create(skydir=pos, npix=100, binsz=0.1)

            region = CircleSkyRegion(
                SkyCoord(3, 2, unit='deg'),
                Angle(1, 'deg'),
            )

            # the Gammapy convention for exclusion regions is to take the inverse
            mask = ~geom.region_mask([region])

        Note how we made a list with a single region,
        since this method expects a list of regions.
        """
        from gammapy.maps import Map, RegionGeom

        if not self.is_regular:
            raise ValueError("Multi-resolution maps not supported yet")

        region_geom = RegionGeom.from_regions(regions, wcs=self.wcs)
        pixel_idx = self.get_idx()
        mask = region_geom.contains_wcs_pix(pixel_idx)

        if inside:
            return Map.from_geom(self, data=mask)

        # flip the mask in place so that pixels in the region become False
        np.logical_not(mask, out=mask)
        return Map.from_geom(self, data=mask)
976
977
    def region_weights(self, regions, oversampling_factor=10):
        """Compute region weights.

        The region mask is evaluated on an upsampled copy of the geometry,
        cast to float, and downsampled back by the same factor with
        ``preserve_counts=False``.

        Parameters
        ----------
        regions : str, `~regions.Region` or list of `~regions.Region`
            Region or list of regions (pixel or sky regions accepted).
            A region can be defined as a string in DS9 format as well.
            See http://ds9.si.edu/doc/ref/region.html for details.
        oversampling_factor : int
            Over-sampling factor to compute the region weights

        Returns
        -------
        map : `~gammapy.maps.WcsNDMap`
            Region weights map (the mask data is converted to float before
            downsampling).
        """
        fine_geom = self.upsample(factor=oversampling_factor)
        weights = fine_geom.region_mask(regions=regions)
        weights.data = weights.data.astype(float)
        return weights.downsample(factor=oversampling_factor, preserve_counts=False)
998
999
    def binary_structure(self, width, kernel="disk"):
        """Get binary structure

        Parameters
        ----------
        width : `~astropy.units.Quantity`, str or float
            If a float is given it is interpreted as width in pixels. If an
            (angular) quantity is given it is converted to pixels by dividing
            by ``geom.pixel_scales``. A single value is applied to both spatial
            dimensions.
            The width corresponds to radius in case of a disk kernel, and
            the side length in case of a box kernel.
        kernel : {'disk', 'box'}
            Kernel shape

        Returns
        -------
        structure : `~numpy.ndarray`
            Binary structure, reshaped so that every non-spatial axis of the
            geometry gets a leading dimension of length one.

        Raises
        ------
        ValueError
            If ``kernel`` is neither ``"disk"`` nor ``"box"``.
        """
        width = u.Quantity(width)

        if width.unit.is_equivalent("deg"):
            width = width / self.pixel_scales

        # A scalar pixel width would otherwise stay 0-d here: indexing it for
        # the disk kernel fails and the box kernel would come out 1-D.
        # Broadcasting yields one (odd) size per spatial dimension; an input
        # that already has two entries passes through unchanged.
        width = np.broadcast_to(round_up_to_odd(width.to_value("")), (2,))

        if kernel == "disk":
            # only the first entry is used as the disk radius
            disk = Tophat2DKernel(width[0])
            disk.normalize("peak")
            structure = disk.array
        elif kernel == "box":
            structure = np.ones(tuple(width))
        else:
            raise ValueError(f"Invalid kernel: {kernel!r}")

        # prepend one length-1 dimension per non-spatial axis
        shape = (1,) * len(self.axes) + structure.shape
        return structure.reshape(shape)
1035
1036
    def __repr__(self):
1037
        lon = self.center_skydir.data.lon.deg
1038
        lat = self.center_skydir.data.lat.deg
1039
        lon_ref, lat_ref = self.wcs.wcs.crval
1040
1041
        return (
1042
            f"{self.__class__.__name__}\n\n"
1043
            f"\taxes       : {self.axes_names}\n"
1044
            f"\tshape      : {self.data_shape[::-1]}\n"
1045
            f"\tndim       : {self.ndim}\n"
1046
            f"\tframe      : {self.frame}\n"
1047
            f"\tprojection : {self.projection}\n"
1048
            f"\tcenter     : {lon:.1f} deg, {lat:.1f} deg\n"
1049
            f"\twidth      : {self.width[0][0]:.1f} x {self.width[1][0]:.1f}\n"
1050
            f"\twcs ref    : {lon_ref:.1f} deg, {lat_ref:.1f} deg\n"
1051
        )
1052
1053
    def to_odd_npix(self, max_radius=None):
        """Create a new geom object with an odd number of pixels and a maximum size.

        This is useful for PSF kernel creation.

        Parameters
        ----------
        max_radius : `~astropy.units.Quantity`
            Max. radius of the geometry (half the width). If None, the
            largest width of this geometry is used.

        Returns
        -------
        geom : `WcsGeom`
            Geom with odd number of pixels
        """
        width = self.width.max() if max_radius is None else 2 * u.Quantity(max_radius)

        # the largest pixel scale is used as the single bin size
        binsz = self.pixel_scales.max()
        npix = round_up_to_odd((width / binsz).to_value(""))

        return WcsGeom.create(
            skydir=self.center_skydir,
            binsz=binsz,
            npix=npix,
            proj=self.projection,
            frame=self.frame,
            axes=self.axes,
        )
1085
1086
    def to_even_npix(self):
        """Create a new geom object with an even number of pixels.

        The pixel number is derived from the largest width and the largest
        pixel scale of this geometry, rounded with ``round_up_to_even``.

        Returns
        -------
        geom : `WcsGeom`
            Geom with even number of pixels
        """
        # the largest pixel scale is used as the single bin size
        binsz = self.pixel_scales.max()
        extent = self.width.max()
        npix = round_up_to_even((extent / binsz).to_value(""))

        return WcsGeom.create(
            skydir=self.center_skydir,
            binsz=binsz,
            npix=npix,
            proj=self.projection,
            frame=self.frame,
            axes=self.axes,
        )
1107
1108
    def is_aligned(self, other, tolerance=1e-6):
        """Check if WCS and extra axes are aligned.

        Parameters
        ----------
        other : `WcsGeom`
            Other geom.
        tolerance : float
            Tolerance for the WCS comparison.

        Returns
        -------
        aligned : bool
            Whether geometries are aligned
        """
        # Axes are compared pairwise; ``zip`` stops at the shorter sequence.
        if any(mine != theirs for mine, theirs in zip(self.axes, other.axes)):
            return False

        # WCS consistency check, delegated to ``compare`` with ``cmp=2``
        return self.wcs.wcs.compare(other.wcs.wcs, cmp=2, tolerance=tolerance)
1129
1130
    def is_allclose(self, other, rtol_axes=1e-6, atol_axes=1e-6, rtol_wcs=1e-6):
        """Compare two geometries for equivalency within tolerances.

        Parameters
        ----------
        other :  `WcsGeom`
            Geom to compare against
        rtol_axes : float
            Relative tolerance for the axes comparison.
        atol_axes : float
            Absolute tolerance for the axes comparison.
        rtol_wcs : float
            Tolerance passed to the WCS comparison.

        Returns
        -------
        is_allclose : bool
            Whether the geometry is all close.

        Raises
        ------
        TypeError
            If ``other`` is not an instance of this geometry's class.
        """
        if not isinstance(other, self.__class__):
            # Raise the error: returning the exception object would hand the
            # caller a truthy value that reads as "geometries are close".
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        if self.data_shape != other.data_shape:
            return False

        axes_eq = self.axes.is_allclose(other.axes, rtol=rtol_axes, atol=atol_axes)

        # check WCS consistency with a priori tolerance of 1e-6
        # cmp=1 parameter ensures no comparison with ancillary information
        # see https://github.com/astropy/astropy/pull/4522/files
        wcs_eq = self.wcs.wcs.compare(other.wcs.wcs, cmp=1, tolerance=rtol_wcs)

        return axes_eq and wcs_eq
1163
1164
    def __eq__(self, other):
1165
        if not isinstance(other, self.__class__):
1166
            return False
1167
1168
        if not (self.is_regular and other.is_regular):
1169
            raise NotImplementedError(
1170
                "Geom comparison is not possible for irregular geometries."
1171
            )
1172
1173
        return self.is_allclose(other=other, rtol_wcs=1e-6, rtol_axes=1e-6)
1174
1175
    def __ne__(self, other):
1176
        return not self.__eq__(other)
1177
1178
    def __hash__(self):
1179
        return id(self)
1180
1181
1182
def pix2world(wcs, cdelt, crpix, pix):
    """Perform pixel to world coordinate transformation.

    For a WCS projection with a given pixel size (CDELT) and reference pixel
    (CRPIX). This method can be used to perform WCS transformations
    for projections with different pixelizations but the same
    reference coordinate (CRVAL), projection type, and coordinate system.

    The input pixel coordinates are first rescaled into the pixelization of
    ``wcs`` and then passed to ``wcs.wcs_pix2world`` with origin 0.

    Parameters
    ----------
    wcs : `astropy.wcs.WCS`
        WCS transform object.
    cdelt : tuple
        Tuple of X/Y pixel size in deg.  Each element should have the
        same length as ``pix``.
    crpix : tuple
        Tuple of reference pixel parameters in X and Y dimensions.  Each
        element should have the same length as ``pix``.
    pix : tuple
        Tuple of pixel coordinates.

    Returns
    -------
    coords
        The result of ``wcs.wcs_pix2world`` for the rescaled pixel coordinates.
    """
    # ratio between the pixel size of ``wcs`` and the requested pixel size
    ratios = [np.abs(wcs.wcs.cdelt[dim] / cdelt[dim]) for dim in range(2)]

    # shift relative to the given reference pixel, rescale, then shift onto
    # the reference pixel of ``wcs``
    x_pix, y_pix = (
        (pix[dim] - (crpix[dim] - 1.0)) / ratios[dim] + wcs.wcs.crpix[dim] - 1.0
        for dim in range(2)
    )
    return wcs.wcs_pix2world(x_pix, y_pix, 0)
1212
1213
1214
def world2pix(wcs, cdelt, crpix, coord):
    """Perform world to pixel coordinate transformation.

    The world coordinates are converted with ``wcs.wcs_world2pix`` (origin 0)
    and the resulting pixel coordinates are rescaled into a pixelization with
    the given pixel size (CDELT) and reference pixel (CRPIX).

    Parameters
    ----------
    wcs : `astropy.wcs.WCS`
        WCS transform object.
    cdelt : tuple
        Tuple of X/Y pixel size.
    crpix : tuple
        Tuple of reference pixel parameters in X and Y dimensions.
    coord : tuple
        Tuple of world coordinates; the first two entries are used.

    Returns
    -------
    pix : tuple
        Tuple of the two rescaled pixel coordinates.
    """
    # ratio between the pixel size of ``wcs`` and the requested pixel size
    ratios = [np.abs(wcs.wcs.cdelt[dim] / cdelt[dim]) for dim in range(2)]

    native = wcs.wcs_world2pix(coord[0], coord[1], 0)

    # shift relative to the reference pixel of ``wcs``, rescale, then shift
    # onto the given reference pixel
    return tuple(
        (native[dim] - (wcs.wcs.crpix[dim] - 1.0)) * ratios[dim] + crpix[dim] - 1.0
        for dim in range(2)
    )
1224