Passed
Pull Request — master (#1898)
by
unknown
04:26
created

gammapy/maps/geom.py (1 issue)

1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
0 ignored issues
show
Too many lines in module (1417/1000)
Loading history...
2
from __future__ import absolute_import, division, print_function, unicode_literals
3
import abc
4
import copy
5
import inspect
6
import re
7
from collections import OrderedDict
8
import numpy as np
9
from ..extern import six
10
from astropy.utils.misc import InheritDocstrings
11
from astropy.io import fits
12
from astropy import units as u
13
from astropy.coordinates import SkyCoord
14
from .utils import find_hdu, find_bands_hdu
15
from ..utils.interpolation import interpolation_scale
16
17
__all__ = ["MapCoord", "MapGeom", "MapAxis"]
18
19
20
def make_axes(axes_in, conv):
    """Build a list of `~MapAxis` objects from a sequence of axis specs.

    Parameters
    ----------
    axes_in : sequence or None
        Axis objects or `~numpy.ndarray` instances.  Arrays are wrapped
        in a `~MapAxis`; any other item is used as is.
    conv : str
        Format convention.  For ``"fgst-ccube"`` and ``"fgst-template"``
        every axis is renamed to ``"energy"``.

    Returns
    -------
    axes : list
        List of axis objects; empty if ``axes_in`` is None.

    Notes
    -----
    Axis names are assigned in place, i.e. axis objects passed in by the
    caller are modified.
    """
    if axes_in is None:
        return []

    fermi_conventions = ("fgst-ccube", "fgst-template")
    result = []
    for index, axis in enumerate(axes_in):
        axis = MapAxis(axis) if isinstance(axis, np.ndarray) else axis

        if conv in fermi_conventions:
            axis.name = "energy"
        elif axis.name == "":
            # Unnamed axes get a positional default name.
            axis.name = "axis%i" % index

        result.append(axis)

    return result
38
39
40
def make_axes_cols(axes, axis_names=None):
    """Make FITS table columns for map axes.

    The first column is ``CHANNEL`` (format ``"I"``).  It is followed,
    for every axis, by three ``"E"`` format columns holding the bin
    centers, lower bin edges and upper bin edges.

    Parameters
    ----------
    axes : list
        Python list of `MapAxis` objects
    axis_names : list of str, optional
        Names used to derive the column names.  If None the names of
        ``axes`` are used.  Names are converted to upper case.

    Returns
    -------
    cols : list
        Python list of `~astropy.io.fits.Column`
    """
    n_chan = np.prod([ax.nbin for ax in axes])
    cols = [fits.Column("CHANNEL", "I", array=np.arange(0, n_chan))]

    if axis_names is None:
        axis_names = [ax.name for ax in axes]
    axis_names = [name.upper() for name in axis_names]

    # One coordinate grid per quantity, in the order center / min / max.
    grids = [
        np.meshgrid(*[ax.center for ax in axes]),
        np.meshgrid(*[ax.edges[:-1] for ax in axes]),
        np.meshgrid(*[ax.edges[1:] for ax in axes]),
    ]

    # Axes with dedicated column names; all others use NAME, NAME_MIN, NAME_MAX.
    special_colnames = {
        "ENERGY": ["ENERGY", "E_MIN", "E_MAX"],
        "TIME": ["TIME", "T_MIN", "T_MAX"],
    }

    for i, (ax, name) in enumerate(zip(axes, axis_names)):
        if name in special_colnames:
            colnames = special_colnames[name]
        else:
            stem = "AXIS%i" % i if name == "" else name
            colnames = [stem, stem + "_MIN", stem + "_MAX"]

        unit = ax.unit.to_string("fits")
        for colname, grid in zip(colnames, grids):
            values = np.ravel(grid[i])
            cols.append(fits.Column(colname, "E", array=values, unit=unit))

    return cols
81
82
83
def find_and_read_bands(hdu, header=None):
    """Read and return the map axes from a BANDS table.

    Parameters
    ----------
    hdu : `~astropy.io.fits.BinTableHDU`
        The BANDS table HDU.  HDUs named ``ENERGIES`` and ``EBOUNDS`` are
        read as a single energy axis; for any other HDU the axis columns
        are taken from the ``AXCOLS0`` .. ``AXCOLS4`` header keywords.
    header : `~astropy.io.fits.Header`
        Header used to look up ``CUNIT<n>`` when a column has no unit.

    Returns
    -------
    axes : list of `~MapAxis`
        List of axis objects.  Empty if ``hdu`` is None.
    """
    if hdu is None:
        return []

    if hdu.name == "ENERGIES":
        axis_cols = [["ENERGY"]]
    elif hdu.name == "EBOUNDS":
        axis_cols = [["E_MIN", "E_MAX"]]
    else:
        axis_cols = []
        for i in range(5):
            key = "AXCOLS%i" % i
            if key in hdu.header:
                axis_cols.append(hdu.header[key].split(","))

    axes = []
    for i, cols in enumerate(axis_cols):
        # The interpolation mode is decided per axis.  It must be reset
        # here, otherwise every axis following an energy axis would
        # silently inherit "log".
        interp = "lin"
        min_match = re.search("(.+)_MIN", cols[0])

        if "ENERGY" in cols or "E_MIN" in cols:
            name = "energy"
            interp = "log"
        elif min_match:
            name = min_match.group(1)
        else:
            name = cols[0]

        unit = hdu.data.columns[cols[0]].unit
        if unit is None and header is not None:
            # Non-spatial axes start at the third WCS axis.
            unit = header.get("CUNIT%i" % (3 + i), "")
        if unit is None:
            unit = ""

        if len(cols) == 2:
            # (min, max) column pair: rebuild the edge array.
            xmin = np.unique(hdu.data.field(cols[0]))
            xmax = np.unique(hdu.data.field(cols[1]))
            nodes = np.append(xmin, xmax[-1])
            axes.append(MapAxis(nodes, name=name, unit=unit, interp=interp))
        else:
            # Single column: values are bin centers.
            nodes = np.unique(hdu.data.field(cols[0]))
            axes.append(MapAxis.from_nodes(nodes, name=name, unit=unit, interp=interp))

    return axes
138
139
140
def get_shape(param):
    """Return the largest array shape found in ``param``.

    ``None`` yields an empty tuple.  A tuple is treated as a collection
    of array-likes; any other value is treated as a single array-like.
    Each item is converted to an array with at least one dimension, and
    the maximum of the resulting shape tuples (ordinary tuple
    comparison) is returned.
    """
    if param is None:
        return ()

    items = param if isinstance(param, tuple) else [param]
    shapes = [np.array(item, ndmin=1).shape for item in items]
    return max(shapes)
148
149
150
def coordsys_to_frame(coordsys):
    """Translate a coordinate system code into a frame name.

    ``"CEL"`` / ``"C"`` map to ``"icrs"`` and ``"GAL"`` / ``"G"`` map to
    ``"galactic"``.  Any other value raises `ValueError`.
    """
    if coordsys in ("CEL", "C"):
        return "icrs"
    if coordsys in ("GAL", "G"):
        return "galactic"
    raise ValueError("Unrecognized coordinate system: {!r}".format(coordsys))
157
158
159
def skycoord_to_lonlat(skycoord, coordsys=None):
    """Extract longitude, latitude and frame name from a sky coordinate.

    Parameters
    ----------
    skycoord : `~astropy.coordinates.SkyCoord`
        Sky coordinate.
    coordsys : {'CEL', 'C', 'GAL', 'G', None}
        If given, ``skycoord`` is first transformed to the "icrs"
        ('CEL', 'C') or "galactic" ('GAL', 'G') frame.  Any other value
        leaves the coordinate in its current frame.

    Returns
    -------
    lon : `~numpy.ndarray`
        Longitude in degrees.

    lat : `~numpy.ndarray`
        Latitude in degrees.

    frame : str
        Name of coordinate frame.

    Raises
    ------
    ValueError
        If the resulting frame is not "icrs", "fk5" or "galactic".
    """
    if coordsys in ("CEL", "C"):
        skycoord = skycoord.transform_to("icrs")
    elif coordsys in ("GAL", "G"):
        skycoord = skycoord.transform_to("galactic")

    frame = skycoord.frame.name
    if frame in ("icrs", "fk5"):
        lon, lat = skycoord.ra.deg, skycoord.dec.deg
    elif frame == "galactic":
        lon, lat = skycoord.l.deg, skycoord.b.deg
    else:
        raise ValueError("Unrecognized SkyCoord frame: {!r}".format(frame))

    return lon, lat, frame
186
187
188
def lonlat_to_skycoord(lon, lat, coordsys):
    """Create a `~astropy.coordinates.SkyCoord` from longitude and latitude in degrees.

    ``coordsys`` is converted to a frame name with `coordsys_to_frame`,
    which raises `ValueError` for unrecognized values.
    """
    frame = coordsys_to_frame(coordsys)
    return SkyCoord(lon, lat, frame=frame, unit="deg")
190
191
192
def pix_tuple_to_idx(pix):
    """Convert a tuple of pixel coordinate arrays to a tuple of pixel
    indices.

    Pixel coordinates are rounded to the closest integer value.
    Non-finite coordinates (NaN, +/-inf) are mapped to index -1.

    Parameters
    ----------
    pix : tuple
        Tuple of pixel coordinates with one element for each dimension.

    Returns
    -------
    idx : tuple of `~numpy.ndarray`
        Tuple of integer pixel index arrays, one per dimension.  Each
        array has at least one dimension.
    """
    idx = []
    for p in pix:
        p = np.array(p, ndmin=1)
        if np.issubdtype(p.dtype, np.integer):
            # Already integer indices, nothing to round.
            idx.append(p)
            continue

        # Substitute the invalid-index marker *before* the integer cast:
        # casting NaN/inf to int is undefined and triggers a
        # RuntimeWarning on recent NumPy versions.
        finite = np.isfinite(p)
        idx.append(np.rint(np.where(finite, p, -1)).astype(int))

    return tuple(idx)
219
220
221
def axes_pix_to_coord(axes, pix):
    """Convert pixel coordinates to axis coordinates for a list of
    `~MapAxis` objects.

    Axes and pixel coordinates are paired by position; iteration stops at
    the shorter of the two sequences.

    Parameters
    ----------
    axes : list
        List of `~MapAxis`.

    pix : tuple
        Tuple of pixel coordinates.

    Returns
    -------
    coords : list
        One entry per axis with the result of its ``pix_to_coord`` method.
    """
    return [axis.pix_to_coord(values) for axis, values in zip(axes, pix)]
238
239
240
def coord_to_idx(edges, x, clip=False):
    """Convert axis coordinates ``x`` to bin indices.

    Parameters
    ----------
    edges : `~numpy.ndarray`
        Bin edges in increasing order; defines ``len(edges) - 1`` bins.
    x : array-like
        Axis coordinate values.
    clip : bool
        If True, indices of values outside the edge range are clipped to
        the first / last valid bin index (``0`` / ``len(edges) - 2``).
        If False, values below the lower edge or above the upper edge
        get index -1.

    Returns
    -------
    ibin : `~numpy.ndarray`
        Bin indices (at least one-dimensional).  Non-finite inputs always
        get index -1.

    Notes
    -----
    Without clipping a value exactly equal to the upper edge is not
    "above" it and is returned as ``len(edges) - 1``.
    """
    x = np.array(x, ndmin=1)
    ibin = np.digitize(x, edges) - 1

    if clip:
        # The last valid bin index is nbin - 1 == len(edges) - 2.
        nbin = len(edges) - 1
        ibin = np.clip(ibin, 0, nbin - 1)
    else:
        with np.errstate(invalid="ignore"):
            ibin[x > edges[-1]] = -1

    ibin[~np.isfinite(x)] = -1
    return ibin
257
258
259
def bin_to_val(edges, bins):
    """Return the bin-center values for the bin indices ``bins``.

    Centers are the arithmetic mean of adjacent entries of ``edges``.
    """
    lower, upper = edges[:-1], edges[1:]
    centers = 0.5 * (upper + lower)
    return centers[bins]
262
263
264
def coord_to_pix(edges, coord, interp="lin"):
    """Convert axis coordinates to pixel coordinates using the chosen
    interpolation scheme.

    ``edges[i]`` is assigned pixel coordinate ``i``.  The interpolation
    is done on values transformed with the scale selected by ``interp``;
    coordinates outside the range of ``edges`` are extrapolated.
    """
    from scipy.interpolate import interp1d

    scale = interpolation_scale(interp)
    pix_grid = np.arange(len(edges), dtype=float)
    to_pix = interp1d(scale(edges), pix_grid, fill_value="extrapolate")
    return to_pix(scale(coord))
276
277
278
def pix_to_coord(edges, pix, interp="lin"):
    """Convert pixel coordinates to grid coordinates using the chosen
    interpolation scheme.

    Pixel coordinate ``i`` corresponds to ``edges[i]``.  The
    interpolation is done on values transformed with the scale selected
    by ``interp`` and the result is mapped back with the inverse scale;
    pixel coordinates outside ``[0, len(edges) - 1]`` are extrapolated.
    """
    from scipy.interpolate import interp1d

    scale = interpolation_scale(interp)
    pix_grid = np.arange(len(edges), dtype=float)
    to_scaled = interp1d(pix_grid, scale(edges), fill_value="extrapolate")
    scaled = to_scaled(pix)
    return scale.inverse(scaled)
290
291
292
class MapAxis(object):
    """Class representing an axis of a map.

    Provides methods for
    transforming to/from axis and pixel coordinates.  An axis is
    defined by a sequence of node values.  For ``node_type='center'``
    the nodes lie at the center of each bin and the pixel coordinate at
    each node is equal to its index in the node array (0, 1, ..).  Bin
    edges are offset by 0.5 in pixel coordinates from the bin centers
    such that the lower/upper edge of the first bin is (-0.5,0.5).

    Parameters
    ----------
    nodes : `~numpy.ndarray`
        Array of node values.  These will be interpreted as either bin
        edges or centers according to ``node_type``.
    interp : str
        Interpolation method used to transform between axis and pixel
        coordinates.  Valid options are 'log', 'lin', and 'sqrt'.
    name : str
        Axis name
    node_type : str
        Flag indicating whether coordinate nodes correspond to pixel
        edges (node_type = 'edges') or pixel centers (node_type =
        'center').  'center' should be used where the map values are
        defined at a specific coordinate (e.g. differential
        quantities). 'edges' should be used where map values are
        defined by an integral over coordinate intervals (e.g. a
        counts histogram).
    unit : str
        String specifying the data units.

    Raises
    ------
    ValueError
        If ``node_type`` is neither 'edges' nor 'center'.
    """

    __slots__ = [
        "_name",
        "_nodes",
        "_node_type",
        "_interp",
        "_pix_offset",
        "_nbin",
        "_unit",
    ]

    # TODO: Add methods to faciliate FITS I/O.
    # TODO: Cache an interpolation object?

    def __init__(self, nodes, interp="lin", name="", node_type="edges", unit=""):
        self.name = name
        self.unit = unit
        self._nodes = nodes
        self._node_type = node_type
        self._interp = interp

        # Set pixel coordinate of first node
        if node_type == "edges":
            self._pix_offset = -0.5
            nbin = len(nodes) - 1
        elif node_type == "center":
            self._pix_offset = 0.0
            nbin = len(nodes)
        else:
            raise ValueError("Invalid node type: {!r}".format(node_type))

        self._nbin = nbin

    def __eq__(self, other):
        """Test axis equality. Absolute and relative tolerances of 1e-6 are used"""
        if not isinstance(other, self.__class__):
            return NotImplemented

        # TODO: implement an allclose method for MapAxis and call it here
        if self.edges.shape != other.edges.shape:
            return False

        return (
            np.allclose(self.edges, other.edges, atol=1e-6, rtol=1e-6)
            and self.unit == other.unit
            and self._node_type == other._node_type
            and self._interp == other._interp
            and self.name.upper() == other.name.upper()
        )

    def __ne__(self, other):
        """Test axis inequality (negation of ``__eq__``)."""
        equal = self.__eq__(other)
        # ``not NotImplemented`` would evaluate to False, i.e. an axis
        # would compare "not unequal" to any foreign object.  Propagate
        # NotImplemented so Python falls back to its default handling.
        if equal is NotImplemented:
            return NotImplemented
        return not equal

    @property
    def name(self):
        """Name of the axis."""
        return self._name

    @name.setter
    def name(self, val):
        self._name = val

    @property
    def edges(self):
        """Return array of bin edges."""
        pix = np.arange(self.nbin + 1, dtype=float) - 0.5
        return self.pix_to_coord(pix)

    @property
    def center(self):
        """Return array of bin centers."""
        pix = np.arange(self.nbin, dtype=float)
        return self.pix_to_coord(pix)

    @property
    def nbin(self):
        """Return number of bins."""
        return self._nbin

    @property
    def node_type(self):
        """Return node type ('center' or 'edges')."""
        return self._node_type

    @property
    def unit(self):
        """Return coordinate axis unit."""
        return self._unit

    @unit.setter
    def unit(self, val):
        self._unit = u.Unit(val)

    @classmethod
    def from_bounds(cls, lo_bnd, hi_bnd, nbin, **kwargs):
        """Generate an axis object from a lower/upper bound and number of bins.

        If node_type = 'edges' then bounds correspond to the
        lower and upper bound of the first and last bin.  If node_type
        = 'center' then bounds correspond to the centers of the first
        and last bin.

        Parameters
        ----------
        lo_bnd : float
            Lower bound of first axis bin.
        hi_bnd : float
            Upper bound of last axis bin.
        nbin : int
            Number of bins.
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates.  Default: 'lin'.
        node_type : {'edges', 'center'}
            Node type.  Default: 'edges'.
        **kwargs
            Remaining keyword arguments are passed to the constructor.

        Raises
        ------
        ValueError
            If ``node_type`` or ``interp`` is not one of the listed values.
        """
        interp = kwargs.setdefault("interp", "lin")
        node_type = kwargs.setdefault("node_type", "edges")

        if node_type == "edges":
            nnode = nbin + 1
        elif node_type == "center":
            nnode = nbin
        else:
            raise ValueError("Invalid node type: {!r}".format(node_type))

        # Nodes are equally spaced in the scale selected by ``interp``.
        if interp == "lin":
            nodes = np.linspace(lo_bnd, hi_bnd, nnode)
        elif interp == "log":
            nodes = np.exp(np.linspace(np.log(lo_bnd), np.log(hi_bnd), nnode))
        elif interp == "sqrt":
            nodes = np.linspace(lo_bnd ** 0.5, hi_bnd ** 0.5, nnode) ** 2.0
        else:
            raise ValueError("Invalid interp: {}".format(interp))

        return cls(nodes, **kwargs)

    @classmethod
    def from_nodes(cls, nodes, **kwargs):
        """Generate an axis object from a sequence of nodes (bin centers).

        This will create a sequence of bins with edges half-way
        between the node values.  This method should be used to
        construct an axis where the bin center should lie at a
        specific value (e.g. a map of a continuous function).

        Parameters
        ----------
        nodes : `~numpy.ndarray`
            Axis nodes (bin center).
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates.  Default: 'lin'.

        Raises
        ------
        ValueError
            If ``nodes`` is empty.
        """
        nodes = np.array(nodes, ndmin=1)
        if len(nodes) < 1:
            raise ValueError("Nodes array must have at least one element.")

        return cls(nodes, node_type="center", **kwargs)

    @classmethod
    def from_edges(cls, edges, **kwargs):
        """Generate an axis object from a sequence of bin edges.

        This method should be used to construct an axis where the bin
        edges should lie at specific values (e.g. a histogram).  The
        number of bins will be one less than the number of edges.

        Parameters
        ----------
        edges : `~numpy.ndarray`
            Axis bin edges.
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates.  Default: 'lin'.

        Raises
        ------
        ValueError
            If fewer than two edges are given.
        """
        if len(edges) < 2:
            raise ValueError("Edges array must have at least two elements.")

        return cls(edges, node_type="edges", **kwargs)

    def pix_to_coord(self, pix):
        """Transform from pixel to axis coordinates.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.

        Returns
        -------
        coord : `~numpy.ndarray`
            Array of axis coordinate values.
        """
        # Shift to node-index coordinates before interpolating on the nodes.
        pix = pix - self._pix_offset
        return pix_to_coord(self._nodes, pix, interp=self._interp)

    def coord_to_pix(self, coord):
        """Transform from axis to pixel coordinates.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis coordinate values.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.
        """
        coord = u.Quantity(coord, self.unit, copy=False).value
        pix = coord_to_pix(self._nodes, coord, interp=self._interp)
        return np.array(pix + self._pix_offset, ndmin=1)

    def coord_to_idx(self, coord, clip=False):
        """Transform from axis coordinate to bin index.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis coordinate values.
        clip : bool
            Choose whether to clip the index to the valid range of the
            axis.  If false then indices for values outside the axis
            range will be set -1.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.
        """
        coord = u.Quantity(coord, self.unit, copy=False).value
        return coord_to_idx(self.edges, coord, clip)

    def slice(self, idx):
        """Create a new axis object by extracting a slice from this axis.

        Parameters
        ----------
        idx : slice
            Slice object selecting a subselection of the axis.

        Returns
        -------
        axis : `~MapAxis`
            Sliced axis objected.
        """
        center = self.center[idx]
        idx = self.coord_to_idx(center)
        # For edge nodes we need to keep N+1 nodes
        if self._node_type == "edges":
            idx = tuple(list(idx) + [1 + idx[-1]])

        nodes = self._nodes[(idx,)]
        return MapAxis(
            nodes,
            interp=self._interp,
            name=self._name,
            node_type=self._node_type,
            unit=self._unit,
        )

    def __repr__(self):
        str_ = self.__class__.__name__
        str_ += "\n\n"
        fmt = "\t{:<10s} : {:<10s}\n"
        str_ += fmt.format("name", self.name)
        str_ += fmt.format("unit", "{!r}".format(str(self.unit)))
        str_ += fmt.format("nbins", str(self.nbin))
        str_ += fmt.format("node type", self.node_type)
        vals = self.edges if self.node_type == "edges" else self.center
        str_ += fmt.format(
            "{} min".format(self.node_type),
            "{:.1e} {}".format(vals.min(), str(self.unit)),
        )
        str_ += fmt.format(
            "{} max".format(self.node_type),
            "{:.1e} {}".format(vals.max(), str(self.unit)),
        )
        str_ += fmt.format("interp", self._interp)
        return str_

    def copy(self):
        """Copy `MapAxis` object"""
        return copy.deepcopy(self)
607
608
609
class MapCoord(object):
    """Represents a sequence of n-dimensional map coordinates.

    Contains coordinates for 2 spatial dimensions and an arbitrary
    number of additional non-spatial dimensions.

    For further information see :ref:`mapcoord`.

    Parameters
    ----------
    data : `~collections.OrderedDict` of `~numpy.ndarray`
        Dictionary of coordinate arrays.  Must contain the keys 'lon'
        and 'lat'.  All arrays are broadcast against each other.
    coordsys : {'CEL', 'GAL', None}
        Spatial coordinate system.  If None then the coordinate system
        will be set to the native coordinate system of the geometry.
    match_by_name : bool
        Match coordinates to axes by name?
        If false coordinates will be matched by index.

    Raises
    ------
    ValueError
        If 'lon' or 'lat' is missing from ``data`` or is a
        `~astropy.units.Quantity`.
    """

    def __init__(self, data, coordsys=None, match_by_name=True):

        if "lon" not in data or "lat" not in data:
            raise ValueError("data dictionary must contain axes named 'lon' and 'lat'.")

        if isinstance(data["lon"], u.Quantity):
            raise ValueError("Quantity not supported for 'lon'")
        if isinstance(data["lat"], u.Quantity):
            raise ValueError("Quantity not supported for 'lat'")

        data = OrderedDict(
            [(k, np.atleast_1d(np.asanyarray(v))) for k, v in data.items()]
        )
        vals = np.broadcast_arrays(*data.values(), subok=True)
        self._data = OrderedDict(zip(data.keys(), vals))
        self._coordsys = coordsys
        self._match_by_name = match_by_name

    def __getitem__(self, key):
        # String keys select by axis name, anything else by position.
        if isinstance(key, six.string_types):
            return self._data[key]
        else:
            return list(self._data.values())[key]

    def __iter__(self):
        return iter(self._data.values())

    @property
    def ndim(self):
        """Number of dimensions."""
        return len(self._data)

    @property
    def shape(self):
        """Coordinate array shape."""
        return self[0].shape

    @property
    def size(self):
        """Number of elements of the coordinate arrays."""
        return self[0].size

    @property
    def lon(self):
        """Longitude coordinate in degrees."""
        return self._data["lon"]

    @property
    def lat(self):
        """Latitude coordinate in degrees."""
        return self._data["lat"]

    @property
    def theta(self):
        """Theta co-latitude angle in radians"""
        return np.pi / 2.0 - np.radians(self.lat)

    @property
    def phi(self):
        """Phi longitude angle in radians"""
        return np.radians(self.lon)

    @property
    def coordsys(self):
        """Coordinate system (str)"""
        return self._coordsys

    @property
    def match_by_name(self):
        """Boolean flag indicating whether axis lookup should be performed by
        name (True) or index (False).
        """
        return self._match_by_name

    @property
    def skycoord(self):
        """Spatial coordinates as `~astropy.coordinates.SkyCoord`."""
        return SkyCoord(
            self.lon, self.lat, unit="deg", frame=coordsys_to_frame(self.coordsys)
        )

    @classmethod
    def _from_lonlat(cls, coords, coordsys=None):
        """Create a `~MapCoord` from a tuple of coordinate vectors.

        The first two elements of the tuple should be longitude and latitude in degrees.
        Further elements are stored as ``axis0``, ``axis1``, ... and the
        result matches axes by index.

        Parameters
        ----------
        coords : tuple
            Tuple of `~numpy.ndarray`.

        Returns
        -------
        coord : `~MapCoord`
            A coordinates object.

        Raises
        ------
        ValueError
            If ``coords`` is neither a list nor a tuple.
        """
        if isinstance(coords, (list, tuple)):
            coords_dict = OrderedDict([("lon", coords[0]), ("lat", coords[1])])
            for i, c in enumerate(coords[2:]):
                coords_dict["axis{}".format(i)] = c
        else:
            raise ValueError("Unrecognized input type.")

        return cls(coords_dict, coordsys=coordsys, match_by_name=False)

    @classmethod
    def _from_skycoord(cls, coords, coordsys=None):
        """Create from vector of `~astropy.coordinates.SkyCoord`.

        Parameters
        ----------
        coords : tuple or list
            Coordinate sequence with first element of type
            `~astropy.coordinates.SkyCoord`.
        coordsys : {'CEL', 'GAL', None}
            Spatial coordinate system of output `~MapCoord` object.
            If None the coordinate system will be set to the frame of
            the `~astropy.coordinates.SkyCoord` object.

        Raises
        ------
        ValueError
            If the frame of the sky coordinate is not "icrs", "fk5" or
            "galactic".
        """
        skycoord = coords[0]
        # ``coords`` may be a list (see `create`); normalise the remaining
        # elements to a tuple so they can be concatenated below.
        extra = tuple(coords[1:])
        name = skycoord.frame.name
        if name in ["icrs", "fk5"]:
            coords = (skycoord.ra.deg, skycoord.dec.deg) + extra
            coords = cls._from_lonlat(coords, coordsys="CEL")
        elif name in ["galactic"]:
            coords = (skycoord.l.deg, skycoord.b.deg) + extra
            coords = cls._from_lonlat(coords, coordsys="GAL")
        else:
            raise ValueError("Unrecognized coordinate frame: {!r}".format(name))

        if coordsys is None:
            return coords
        else:
            return coords.to_coordsys(coordsys)

    @classmethod
    def _from_tuple(cls, coords, coordsys=None):
        """Create from tuple of coordinate vectors.

        Dispatches on the type of the first element: lists, arrays and
        scalars are read as (lon, lat, ...), a
        `~astropy.coordinates.SkyCoord` as (skycoord, ...).

        Raises
        ------
        TypeError
            If the first element has any other type.
        """
        if isinstance(coords[0], (list, np.ndarray)) or np.isscalar(coords[0]):
            return cls._from_lonlat(coords, coordsys=coordsys)
        elif isinstance(coords[0], SkyCoord):
            return cls._from_skycoord(coords, coordsys=coordsys)
        else:
            # Report the type of the offending element, not of the container.
            raise TypeError("Type not supported: {!r}".format(type(coords[0])))

    @classmethod
    def _from_dict(cls, coords, coordsys=None):
        """Create from a dictionary of coordinate vectors.

        The dictionary must contain either 'lon' and 'lat' or
        'skycoord'; otherwise `ValueError` is raised.
        """
        if "lon" in coords and "lat" in coords:
            return cls(coords, coordsys=coordsys)
        elif "skycoord" in coords:
            coords_dict = OrderedDict()
            lon, lat, frame = skycoord_to_lonlat(coords["skycoord"], coordsys=coordsys)
            coords_dict["lon"] = lon
            coords_dict["lat"] = lat
            for k, v in coords.items():
                if k == "skycoord":
                    continue
                coords_dict[k] = v
            return cls(coords_dict, coordsys=coordsys)
        else:
            raise ValueError("coords dict must contain 'lon'/'lat' or 'skycoord'.")

    @classmethod
    def create(cls, data, coordsys=None):
        """Create a new `~MapCoord` object.

        This method can be used to create either unnamed (with tuple input)
        or named (via dict input) axes.

        Parameters
        ----------
        data : tuple, dict, `MapCoord` or `~astropy.coordinates.SkyCoord`
            Object containing coordinate arrays.
        coordsys : {'CEL', 'GAL', None}, optional
            Set the coordinate system for longitude and latitude. If
            None longitude and latitude will be assumed to be in
            the coordinate system native to a given map geometry.

        Raises
        ------
        TypeError
            If the type of ``data`` is not supported.

        Examples
        --------
        >>> from astropy.coordinates import SkyCoord
        >>> from gammapy.maps import MapCoord

        >>> lon, lat = [1, 2], [2, 3]
        >>> skycoord = SkyCoord(lon, lat, unit='deg')
        >>> energy = [1000]
        >>> c = MapCoord.create((lon,lat))
        >>> c = MapCoord.create((skycoord,))
        >>> c = MapCoord.create((lon,lat,energy))
        >>> c = MapCoord.create(dict(lon=lon,lat=lat))
        >>> c = MapCoord.create(dict(lon=lon,lat=lat,energy=energy))
        >>> c = MapCoord.create(dict(skycoord=skycoord,energy=energy))
        """
        if isinstance(data, cls):
            # An existing object is returned unchanged unless it carries a
            # coordinate system different from the requested one.
            if data.coordsys is None or coordsys == data.coordsys:
                return data
            else:
                return data.to_coordsys(coordsys)
        elif isinstance(data, dict):
            return cls._from_dict(data, coordsys=coordsys)
        elif isinstance(data, (list, tuple)):
            return cls._from_tuple(data, coordsys=coordsys)
        elif isinstance(data, SkyCoord):
            return cls._from_skycoord((data,), coordsys=coordsys)
        else:
            raise TypeError("Unsupported input type: {!r}".format(type(data)))

    def to_coordsys(self, coordsys):
        """Convert to a different coordinate frame.

        Parameters
        ----------
        coordsys : {'CEL', 'GAL'}
            Coordinate system, either Galactic ('GAL') or Equatorial ('CEL').

        Returns
        -------
        coords : `~MapCoord`
            A coordinates object.
        """
        if coordsys == self.coordsys:
            return copy.deepcopy(self)
        else:
            skycoord = lonlat_to_skycoord(self.lon, self.lat, self.coordsys)
            lon, lat, frame = skycoord_to_lonlat(skycoord, coordsys=coordsys)
            data = copy.deepcopy(self._data)
            data["lon"] = lon
            data["lat"] = lat
            return self.__class__(data, coordsys, self._match_by_name)

    def apply_mask(self, mask):
        """Return a masked copy of this coordinate object.

        Parameters
        ----------
        mask : `~numpy.ndarray`
            Boolean mask.

        Returns
        -------
        coords : `~MapCoord`
            A coordinates object.
        """
        data = OrderedDict([(k, v[mask]) for k, v in self._data.items()])
        return self.__class__(data, self.coordsys, self._match_by_name)

    def copy(self):
        """Copy `MapCoord` object."""
        return copy.deepcopy(self)

    def __repr__(self):
        str_ = self.__class__.__name__
        str_ += "\n\n"
        str_ += "\taxes     : {}\n".format(", ".join(self._data.keys()))
        str_ += "\tshape    : {}\n".format(self.shape[::-1])
        str_ += "\tndim     : {}\n".format(self.ndim)
        str_ += "\tcoordsys : {}\n".format(self.coordsys)
        return str_

    # TODO: this is a temporary solution until we have decided how to handle
    # quantities uniformly. This should be called after any `MapCoord.create()`
    # to support that users can pass quantities in any Map.xxx_by_coord() method.
    def match_axes_units(self, geom):
        """Match the units of the non-spatial axes to a given map geometry.

        Parameters
        ----------
        geom : `MapGeom`
            Map geometry with specified units per axis.

        Returns
        -------
        coords : `MapCoord`
            Map coord object with matched units
        """
        coords = OrderedDict()

        for name, coord in self._data.items():
            if name in ["lon", "lat"]:
                coords[name] = coord
            else:
                ax = geom.get_axis_by_name(name)
                coords[name] = u.Quantity(coord, ax.unit, copy=False).value

        return self.__class__(coords, coordsys=self.coordsys)
914
915
916
class MapGeomMeta(InheritDocstrings, abc.ABCMeta):
    """Metaclass combining `InheritDocstrings` and `abc.ABCMeta`."""
918
919
920
@six.add_metaclass(MapGeomMeta)
921
class MapGeom(object):
922
    """Base class for WCS and HEALPix geometries."""
923
924
    @property
    @abc.abstractmethod
    def data_shape(self):
        """Shape of the Numpy data array that matches this geometry."""
929
930
    @property
    @abc.abstractmethod
    def is_allsky(self):
        """Flag for all-sky geometries."""
        # NOTE(review): meaning inferred from the property name only; confirm
        # against the concrete WCS / HEALPix implementations.
934
935
    @property
    @abc.abstractmethod
    def center_coord(self):
        """Center of the geometry in map coordinates."""
        # NOTE(review): meaning inferred from the property name only; confirm
        # against the concrete WCS / HEALPix implementations.
939
940
    @property
    @abc.abstractmethod
    def center_pix(self):
        """Center of the geometry in pixel coordinates."""
        # NOTE(review): meaning inferred from the property name only; confirm
        # against the concrete WCS / HEALPix implementations.
944
945
    @property
    @abc.abstractmethod
    def center_skydir(self):
        """Center of the geometry as a sky direction."""
        # NOTE(review): meaning inferred from the property name only; confirm
        # against the concrete WCS / HEALPix implementations.
949
950
    @classmethod
    def from_hdulist(cls, hdulist, hdu=None, hdu_bands=None):
        """Create a geometry object from a FITS HDUList.

        Parameters
        ----------
        hdulist :  `~astropy.io.fits.HDUList`
            HDU list holding the HDUs for the map data and the bands.
        hdu : str
            Name or index of the HDU with the map data. If not given,
            the HDU is looked up with ``find_hdu``.
        hdu_bands : str
            Name or index of the HDU with the BANDS table. If not given,
            it is looked up with ``find_bands_hdu`` from the map HDU.

        Returns
        -------
        geom : `~MapGeom`
            Geometry object built by ``cls.from_header``.
        """
        map_hdu = find_hdu(hdulist) if hdu is None else hdulist[hdu]

        bands = hdu_bands
        if bands is None:
            bands = find_bands_hdu(hdulist, map_hdu)

        # The lookup above may still yield None, i.e. no BANDS HDU at all.
        if bands is not None:
            bands = hdulist[bands]

        return cls.from_header(map_hdu.header, bands)
982
983
    def make_bands_hdu(self, hdu=None, hdu_skymap=None, conv=None):
        """Build the binary table HDU describing the non-spatial axes.

        Parameters
        ----------
        hdu : str, optional
            Name of the table HDU. Replaced by ``"EBOUNDS"`` for the
            ``"fgst-ccube"`` convention and by ``"ENERGIES"`` for the
            ``"fgst-template"`` convention.
        hdu_skymap : str, optional
            For the ``"gadf"`` convention with ``hdu`` not given, the
            table is named ``"<hdu_skymap>_BANDS"``, or ``"BANDS"`` if this
            is empty.
        conv : str, optional
            Format convention. Defaults to ``self.conv``.

        Returns
        -------
        hdu : `~astropy.io.fits.BinTableHDU`
            Table HDU with the axis columns followed by the columns from
            ``_make_bands_cols``.
        """
        if conv is None:
            conv = self.conv

        header = fits.Header()
        self._fill_header_from_axes(header)
        axis_names = None

        # FIXME: Check whether convention is compatible with
        # dimensionality of geometry

        # Conventions not matched below leave ``hdu`` and ``axis_names``
        # untouched; no error is raised for an unknown convention.
        if conv in ("fgst-ccube", "fgst-template"):
            hdu = "EBOUNDS" if conv == "fgst-ccube" else "ENERGIES"
            axis_names = ["energy"]
        elif conv == "gadf" and hdu is None:
            hdu = "{}_{}".format(hdu_skymap, "BANDS") if hdu_skymap else "BANDS"

        columns = make_axes_cols(self.axes, axis_names)
        columns += self._make_bands_cols()
        return fits.BinTableHDU.from_columns(columns, header, name=hdu)
1009
1010
    @abc.abstractmethod
1011
    def _make_bands_cols(self):
1012
        pass
1013
1014
    @abc.abstractmethod
    def get_idx(self, idx=None, local=False, flat=False):
        """Return the pixel indices of this geometry as a tuple.

        By default the indices of all pixels are returned. To get only
        the pixels of one image plane, pass the index tuple of that
        plane as ``idx``.

        Parameters
        ----------
        idx : tuple, optional
            One index per non-spatial dimension. If given, only the
            pixels of the image plane with this index are returned.
            If None, all pixels are returned.
        local : bool
            Whether to return local or global pixel indices. Local
            indices run from 0 to the number of pixels in a given
            image plane.
        flat : bool, optional
            If set, return a flattened array that contains only the
            indices of pixels contained in the geometry.

        Returns
        -------
        idx : tuple
            Pixel index vectors, one for each dimension.
        """
1044
1045
    @abc.abstractmethod
    def get_coord(self, idx=None, flat=False):
        """Return the coordinate array of this geometry.

        The coordinate array has the same shape as the data array, and
        pixels outside the geometry are set to NaN. To get only the
        coordinates of one image plane, pass the index tuple of that
        plane as ``idx``.

        Parameters
        ----------
        idx : tuple, optional
            One index per non-spatial dimension. If given, only the
            coordinates of the image plane with this index are
            returned. If None, coordinates of all pixels are returned.
        flat : bool, optional
            If set, return a flattened array that contains only the
            coordinates of pixels contained in the geometry.

        Returns
        -------
        coords : tuple
            Coordinate vectors, one for each dimension.
        """
1072
1073
    @abc.abstractmethod
    def coord_to_pix(self, coords):
        """Transform map coordinates into pixel coordinates.

        Parameters
        ----------
        coords : tuple
            Coordinate values for each dimension of the map, given
            either as a tuple of numpy arrays or as a MapCoord object.
            A tuple must be ordered as
            (longitude, latitude, c_0, ..., c_N), with c_i the
            coordinate vector for axis i.

        Returns
        -------
        pix : tuple
            Pixel coordinates in the image and band dimensions.
        """
1092
1093
    def coord_to_idx(self, coords, clip=False):
        """Transform map coordinates into pixel indices.

        This chains `coord_to_pix` and `pix_to_idx`.

        Parameters
        ----------
        coords : tuple or `~MapCoord`
            Coordinate values for each dimension of the map, given
            either as a tuple of numpy arrays or as a MapCoord object.
            A tuple must be ordered as
            (longitude, latitude, c_0, ..., c_N), with c_i the
            coordinate vector for axis i.
        clip : bool
            Whether to clip indices to the valid range of the
            geometry. If false, indices for coordinates outside the
            geometry range are set to -1.

        Returns
        -------
        pix : tuple
            Pixel indices in the image and band dimensions. Elements
            equal to -1 correspond to coordinates outside the map.
        """
        return self.pix_to_idx(self.coord_to_pix(coords), clip=clip)
1118
1119
    @abc.abstractmethod
    def pix_to_coord(self, pix):
        """Transform pixel coordinates into map coordinates.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates.

        Returns
        -------
        coords : tuple
            Map coordinates.
        """
1134
1135
    @abc.abstractmethod
    def pix_to_idx(self, pix, clip=False):
        """Transform pixel coordinates into pixel indices.

        Pixel coordinates that lie outside of the map yield -1.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates.
        clip : bool
            Whether to clip indices to the valid range of the
            geometry. If false, indices for coordinates outside the
            geometry range are set to -1.

        Returns
        -------
        idx : tuple
            Pixel indices.
        """
1155
1156
    @abc.abstractmethod
    def contains(self, coords):
        """Test whether map coordinates are contained in the geometry.

        Parameters
        ----------
        coords : tuple or `~gammapy.maps.MapCoord`
            Map coordinates.

        Returns
        -------
        containment : `~numpy.ndarray`
            Bool array.
        """
1171
1172
    def contains_pix(self, pix):
        """Test whether pixel coordinates are contained in the geometry.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates.

        Returns
        -------
        containment : `~numpy.ndarray`
            Bool array.
        """
        indices = self.pix_to_idx(pix)
        # A pixel is inside only if no dimension was flagged with -1.
        valid_per_dim = np.stack([component != -1 for component in indices])
        return np.all(valid_per_dim, axis=0)
1187
1188
    def slice_by_idx(self, slices):
        """Make a new geometry by cutting in the non-spatial dimensions
        of this geometry.

        Parameters
        ----------
        slices : dict
            Maps axis names to integers or `slice` objects. For integer
            indexing the corresponding axis is dropped from the map. Axes
            that are not listed in the dict are kept unchanged.

        Returns
        -------
        geom : `~MapGeom`
            Sliced geometry.
        """
        kept_axes = []
        for axis in self.axes:
            selection = slices.get(axis.name, slice(None))
            if not isinstance(selection, slice):
                # Anything that is not a slice (e.g. an integer) drops the axis.
                continue
            kept_axes.append(axis.slice(selection))

        return self._init_copy(axes=kept_axes)
1214
1215
    @abc.abstractmethod
    def to_image(self):
        """Make a 2D geometry by dropping every non-spatial dimension of
        this geometry.

        Returns
        -------
        geom : `~MapGeom`
            Image geometry.
        """
1226
1227
    @abc.abstractmethod
    def to_cube(self, axes):
        """Make a new geometry by appending non-spatial axes to this one.

        The new geometry has N+M dimensions, where N is the current
        number of dimensions and M is the number of axes in the list.

        Parameters
        ----------
        axes : list
            Axes to append to this geometry.

        Returns
        -------
        geom : `~MapGeom`
            Map geometry.
        """
1245
1246
    def coord_to_tuple(self, coord):
        """Generate a coordinate tuple compatible with this geometry.

        Parameters
        ----------
        coord : `~MapCoord`
            Coordinate object with the same number of dimensions as this
            geometry.

        Returns
        -------
        coord_tuple : tuple
            If ``coord.match_by_name`` is false, the coordinate values in
            the order stored on ``coord``. Otherwise ``(lon, lat)`` followed
            by the value for each axis of this geometry, looked up by axis
            name.

        Raises
        ------
        ValueError
            If ``coord.ndim`` differs from ``self.ndim``.
        """
        if self.ndim != coord.ndim:
            raise ValueError("ndim mismatch")

        if not coord.match_by_name:
            return tuple(coord._data.values())

        values = [coord.lon, coord.lat]
        values.extend(coord[ax.name] for ax in self.axes)

        # Always return a tuple, so that both code paths agree on the type.
        return tuple(values)
1264
1265
    @abc.abstractmethod
    def pad(self, pad_width):
        """Pad the geometry at its edges.

        Parameters
        ----------
        pad_width : {sequence, array_like, int}
            Number of values padded to the edges of each axis.

        Returns
        -------
        geom : `~MapGeom`
            Padded geometry.
        """
1281
1282
    @abc.abstractmethod
    def crop(self, crop_width):
        """Crop the geometry at its edges.

        Parameters
        ----------
        crop_width : {sequence, array_like, int}
            Number of values cropped from the edges of each axis.

        Returns
        -------
        geom : `~MapGeom`
            Cropped geometry.
        """
1298
1299
    @abc.abstractmethod
    def downsample(self, factor):
        """Downsample the spatial dimension of the geometry.

        Parameters
        ----------
        factor : int
            Downsampling factor.

        Returns
        -------
        geom : `~MapGeom`
            Downsampled geometry.
        """
1315
1316
    @abc.abstractmethod
    def upsample(self, factor):
        """Upsample the spatial dimension of the geometry.

        Parameters
        ----------
        factor : int
            Upsampling factor.

        Returns
        -------
        geom : `~MapGeom`
            Upsampled geometry.
        """
1332
1333
    @abc.abstractmethod
    def solid_angle(self):
        """Solid angle as a `~astropy.units.Quantity` in ``sr``."""
1337
1338
    def _fill_header_from_axes(self, header):
1339
        for idx, ax in enumerate(self.axes, start=1):
1340
            key = "AXCOLS%i" % idx
1341
            name = ax.name.upper()
1342
            if ax.name == "energy" and ax.node_type == "edges":
1343
                header[key] = "E_MIN,E_MAX"
1344
            elif ax.name == "energy" and ax.node_type == "center":
1345
                header[key] = "ENERGY"
1346
            elif ax.node_type == "edges":
1347
                header[key] = "{}_MIN,{}_MAX".format(name, name)
1348
            elif ax.node_type == "center":
1349
                header[key] = name
1350
            else:
1351
                raise ValueError("Invalid node type {!r}".format(ax.node_type))
1352
1353
    @property
    def is_image(self):
        """Whether the geom is equivalent to an image without extra dimensions.

        True when ``axes`` is None or empty.
        """
        axes = self.axes
        return axes is None or len(axes) == 0
1359
1360
    def get_axis_by_name(self, name):
        """Look up an axis by its name, ignoring case.

        Parameters
        ----------
        name : str
           Name of the requested axis

        Returns
        -------
        axis : `~gammapy.maps.MapAxis`
            Axis

        Raises
        ------
        KeyError
            If no axis has the given name.
        """
        by_upper_name = {}
        for axis in self.axes:
            # A later axis replaces an earlier one with the same upper-cased name.
            by_upper_name[axis.name.upper()] = axis

        return by_upper_name[name.upper()]
1375
1376
    def get_axis_index_by_name(self, name):
        """Look up the position of an axis by its name, ignoring case.

        Parameters
        ----------
        name : str
           Axis name

        Returns
        -------
        index : int
            Axis index

        Raises
        ------
        ValueError
            If no axis has the given name.
        """
        upper_names = []
        for axis in self.axes:
            upper_names.append(axis.name.upper())

        wanted = name.upper()
        return upper_names.index(wanted)
1391
1392
    def _init_copy(self, **kwargs):
1393
        """Init map instance by copying missing init arguments from self.
1394
        """
1395
        argnames = inspect.getargspec(self.__init__).args
1396
        argnames.remove("self")
1397
1398
        for arg in argnames:
1399
            value = getattr(self, "_" + arg)
1400
            kwargs.setdefault(arg, copy.deepcopy(value))
1401
1402
        return self.__class__(**kwargs)
1403
1404
    def copy(self, **kwargs):
        """Make a copy of this `MapGeom`, replacing the given attributes.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to overwrite in the map geometry constructor.

        Returns
        -------
        copy : `MapGeom`
            Copied map geometry.
        """
        duplicate = self._init_copy(**kwargs)
        return duplicate
1418