Completed
Push — master ( 7b2f36...894ff4 )
by Axel
35s queued 19s
created

gammapy.maps.axes.TimeMapAxis.from_time_bounds()   A

Complexity

Conditions 1

Size

Total Lines 23
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 10
nop 6
dl 0
loc 23
rs 9.9
c 0
b 0
f 0
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import copy
3
import inspect
4
from collections.abc import Sequence
5
import numpy as np
6
import scipy
7
import astropy.units as u
8
from astropy.io import fits
9
from astropy.table import Column, Table, hstack
10
from astropy.time import Time
11
from astropy.utils import lazyproperty
12
from gammapy.utils.interpolation import interpolation_scale
13
from gammapy.utils.time import time_ref_from_dict, time_ref_to_dict
14
from .utils import INVALID_INDEX, edges_from_lo_hi
15
16
__all__ = [
    "MapAxes",
    "MapAxis",
    "TimeMapAxis",
    "LabelMapAxis",
]
17
18
19
def flat_if_equal(array):
    """Collapse a 2D array to its first row if all rows are identical.

    Parameters
    ----------
    array : `~numpy.ndarray`
        Input array.

    Returns
    -------
    array : `~numpy.ndarray`
        ``array[0]`` if ``array`` is 2D, has at least one row and every row
        equals the first one; otherwise ``array`` itself, unchanged.
    """
    # A 2D array with zero rows has no first row: ``array[0]`` would raise
    # IndexError, so such input is returned untouched.
    if array.ndim == 2 and len(array) > 0 and np.all(array == array[0]):
        return array[0]
    return array
24
25
26
def coord_to_pix(edges, coord, interp="lin"):
    """Convert axis to pixel coordinates for given interpolation scheme.

    The node values ``edges`` are assigned the pixel positions
    ``0, 1, ..., len(edges) - 1``. ``coord`` is interpolated linearly in the
    space transformed by the ``interp`` scale and extrapolated outside the
    node range.

    Parameters
    ----------
    edges : array-like
        Node values, one per integer pixel position.
    coord : array-like
        Axis coordinates to convert.
    interp : str
        Name of the interpolation scale (passed to ``interpolation_scale``).

    Returns
    -------
    pix : `~numpy.ndarray`
        Fractional pixel coordinates.
    """
    scale = interpolation_scale(interp)
    scaled_nodes = scale(edges)
    node_pix = np.arange(len(edges), dtype=float)
    to_pix = scipy.interpolate.interp1d(
        scaled_nodes, node_pix, fill_value="extrapolate"
    )
    return to_pix(scale(coord))
35
36
37
def pix_to_coord(edges, pix, interp="lin"):
    """Convert pixel to grid coordinates for given interpolation scheme.

    Inverse of ``coord_to_pix``: pixel positions ``0, 1, ...`` map onto the
    node values ``edges``. Interpolation happens in the space of the
    ``interp`` scale, extrapolating outside the node range, and the result is
    mapped back through the scale's inverse.

    Parameters
    ----------
    edges : array-like
        Node values, one per integer pixel position.
    pix : array-like
        Pixel coordinates to convert.
    interp : str
        Name of the interpolation scale (passed to ``interpolation_scale``).

    Returns
    -------
    coord : `~numpy.ndarray`
        Axis coordinates.
    """
    scale = interpolation_scale(interp)
    node_pix = np.arange(len(edges), dtype=float)
    scaled_nodes = scale(edges)
    to_scaled = scipy.interpolate.interp1d(
        node_pix, scaled_nodes, fill_value="extrapolate"
    )
    scaled_coord = to_scaled(pix)
    return scale.inverse(scaled_coord)
46
47
48
class MapAxis:
49
    """Class representing an axis of a map.
50
51
    Provides methods for
52
    transforming to/from axis and pixel coordinates.  An axis is
53
    defined by a sequence of node values that lie at the center of
54
    each bin.  The pixel coordinate at each node is equal to its index
55
    in the node array (0, 1, ..).  Bin edges are offset by 0.5 in
56
    pixel coordinates from the nodes such that the lower/upper edge of
57
    the first bin is (-0.5,0.5).
58
59
    Parameters
60
    ----------
61
    nodes : `~numpy.ndarray` or `~astropy.units.Quantity`
62
        Array of node values.  These will be interpreted as either bin
63
        edges or centers according to ``node_type``.
64
    interp : str
65
        Interpolation method used to transform between axis and pixel
66
        coordinates.  Valid options are 'log', 'lin', and 'sqrt'.
67
    name : str
68
        Axis name
69
    node_type : str
70
        Flag indicating whether coordinate nodes correspond to pixel
71
        edges (node_type = 'edge') or pixel centers (node_type =
72
        'center').  'center' should be used where the map values are
73
        defined at a specific coordinate (e.g. differential
74
        quantities). 'edge' should be used where map values are
75
        defined by an integral over coordinate intervals (e.g. a
76
        counts histogram).
77
    unit : str
78
        String specifying the data units.
79
    """
80
81
    # TODO: Cache an interpolation object?
82
    def __init__(self, nodes, interp="lin", name="", node_type="edges", unit=""):
        self._name = name

        if len(nodes) != len(np.unique(nodes)):
            raise ValueError("MapAxis: node values must be unique")

        # Nodes may be sorted in ascending or descending order. Use ``not``
        # instead of the bitwise ``~``: ``~`` only acts as logical negation on
        # NumPy booleans; on a plain Python bool ``~True == -2`` is truthy.
        is_sorted = np.all(nodes == np.sort(nodes)) or np.all(
            nodes[::-1] == np.sort(nodes)
        )
        if not is_sorted:
            raise ValueError("MapAxis: node values must be sorted")

        if len(nodes) == 1 and node_type == "center":
            raise ValueError("Single bins can only be used with node-type 'edges'")

        # The unit carried by Quantity input takes precedence over ``unit``.
        if isinstance(nodes, u.Quantity):
            unit = nodes.unit if nodes.unit is not None else ""
            nodes = nodes.value
        else:
            nodes = np.array(nodes)

        self._unit = u.Unit(unit)
        self._nodes = nodes.astype(float)
        self._node_type = node_type
        self._interp = interp

        if (self._nodes < 0).any() and interp != "lin":
            raise ValueError(
                f"Interpolation scaling {interp!r} only support for positive node values."
            )

        # Set pixel coordinate of first node: edge nodes sit half a pixel
        # below the first bin center, center nodes sit exactly on it.
        if node_type == "edges":
            self._pix_offset = -0.5
            nbin = len(nodes) - 1
        elif node_type == "center":
            self._pix_offset = 0.0
            nbin = len(nodes)
        else:
            raise ValueError(f"Invalid node type: {node_type!r}")

        self._nbin = nbin
121
122
    def assert_name(self, required_name):
        """Raise if this axis does not carry the required name.

        Parameters
        ----------
        required_name : str
            Name the axis is required to have.

        Raises
        ------
        ValueError
            If ``self.name`` differs from ``required_name``.
        """
        if self.name == required_name:
            return
        raise ValueError(
            "Unexpected axis name,"
            f' expected "{required_name}", got: "{self.name}"'
        )
135
136
    def is_aligned(self, other, atol=2e-2):
        """Check if other map axis is aligned.

        The axes are aligned when the bin centers of each axis land, within
        ``atol`` bins, on integer pixel positions of the other axis, and both
        axes use the same interpolation scheme.

        Parameters
        ----------
        other : `MapAxis`
            Other map axis.
        atol : float
            Absolute numerical tolerance for the comparison measured in bins.

        Returns
        -------
        aligned : bool
            Whether the axes are aligned
        """
        cross_pix = np.append(
            self.coord_to_pix(other.center), other.coord_to_pix(self.center)
        )
        # Distance of every cross-mapped center from its nearest integer pixel.
        residuals = np.round(cross_pix) - cross_pix
        same_interp = self.interp == other.interp
        return np.allclose(residuals, 0, atol=atol) and same_interp
159
160
    def __eq__(self, other):
        """Compare two axes by edges, node type, interp and (case-insensitive) name.

        Edges are compared after conversion to ``other``'s unit, with
        ``atol=rtol=1e-6``. Returns ``NotImplemented`` for objects that are
        not instances of this class.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented

        # TODO: implement an allclose method for MapAxis and call it here
        if self.edges.shape != other.edges.shape or not self.unit.is_equivalent(
            other.unit
        ):
            return False

        edges_close = np.allclose(
            self.edges.to(other.unit).value, other.edges.value, atol=1e-6, rtol=1e-6
        )
        return (
            edges_close
            and self._node_type == other._node_type
            and self._interp == other._interp
            and self.name.upper() == other.name.upper()
        )
177
178
    def __ne__(self, other):
        """Negation of ``__eq__``.

        ``NotImplemented`` from ``__eq__`` is propagated unchanged, so Python
        can fall back to the reflected comparison. Negating it with ``not``
        would wrongly report "equal" (and is an error in recent Python
        versions).
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result
180
181
    def __hash__(self):
        """Identity-based hash of the axis object.

        NOTE(review): ``__eq__`` compares axes by value, so two distinct axes
        that compare equal will usually hash differently. That breaks the
        usual hash/eq contract for set/dict membership. Presumably a value
        hash is avoided because ``name`` is mutable via its setter; confirm.
        """
        return id(self)
183
184
    @property
    def is_energy_axis(self):
        """bool: Whether the axis is named "energy" or "energy_true"."""
        energy_names = ("energy", "energy_true")
        return self.name in energy_names
187
188
    @property
    def interp(self):
        """str: Interpolation scheme used between axis and pixel coordinates."""
        scheme = self._interp
        return scheme
192
193
    @property
    def name(self):
        """str: Name of the axis (settable)."""
        axis_name = self._name
        return axis_name
197
198
    @name.setter
    def name(self, value):
        """Rename the axis to ``value``."""
        new_name = value
        self._name = new_name
202
203
    @lazyproperty
    def edges(self):
        """Bin edges as a `~astropy.units.Quantity` with ``nbin + 1`` entries.

        Edges are the coordinates at the half-integer pixel positions
        ``-0.5, 0.5, ..., nbin - 0.5``; the result is cached.
        """
        edge_pix = np.arange(self.nbin + 1, dtype=float) - 0.5
        edge_coords = self.pix_to_coord(edge_pix)
        return u.Quantity(edge_coords, self._unit, copy=False)
208
209
    @property
    def edges_min(self):
        """Return array of bin edges min values (lower edge of every bin)."""
        return self.edges[:-1]
213
214
    @property
    def edges_max(self):
        """Return array of bin edges max values (upper edge of every bin)."""
        return self.edges[1:]
218
219
    @property
    def bounds(self):
        """First and last node of the axis as `~astropy.units.Quantity`.

        These are the outer edges for "edges" axes and the first/last bin
        centers otherwise.
        """
        nodes = self.edges if self.node_type == "edges" else self.center
        return nodes[[0, -1]]
227
228
    @property
    def as_plot_xerr(self):
        """Return tuple of xerr to be used with plt.errorbar().

        The tuple holds the (lower, upper) distances from each bin center to
        its bin edges.
        """
        center = self.center
        lower_err = center - self.edges_min
        upper_err = self.edges_max - center
        return lower_err, upper_err
235
236
    @property
    def as_plot_labels(self):
        """List of per-bin plot labels.

        "edges" axes get "min - max" range labels; other axes get the bin
        center. All values use ``.2e`` formatting.
        """
        if self.node_type != "edges":
            return [f"{val:.2e}" for val in self.center]
        return [f"{lo:.2e} - {hi:.2e}" for lo, hi in self.iter_by_edges]
248
249
    @property
    def as_plot_edges(self):
        """Edges to use when plotting (identical to ``edges``)."""
        plot_edges = self.edges
        return plot_edges
253
254
    @property
    def as_plot_center(self):
        """Centers to use when plotting (identical to ``center``)."""
        plot_center = self.center
        return plot_center
258
259
    @property
    def as_plot_scale(self):
        """Matplotlib axis scale matching the interpolation scheme.

        "log" maps to "log"; "lin" and "sqrt" map to "linear". Any other
        interp raises KeyError.
        """
        return {"lin": "linear", "sqrt": "linear", "log": "log"}[self.interp]
265
266
    def format_plot_xaxis(self, ax):
        """Apply this axis' scale, label and limits to the x-axis of ``ax``.

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            Formatted plot axis
        """
        ax.set_xscale(self.as_plot_scale)
        ax.set_xlabel(f"{self.name.capitalize()} [{ax.xaxis.units}]")
        ax.set_xlim(self.bounds)
        return ax
285
286
    def format_plot_yaxis(self, ax):
        """Apply this axis' scale, label and limits to the y-axis of ``ax``.

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            Formatted plot axis
        """
        ax.set_yscale(self.as_plot_scale)
        ax.set_ylabel(f"{self.name.capitalize()} [{ax.yaxis.units}]")
        ax.set_ylim(self.bounds)
        return ax
305
306
    @property
    def iter_by_edges(self):
        """Generator over ``(edge_min, edge_max)`` tuples, one per bin."""
        edges = self.edges
        yield from zip(edges[:-1], edges[1:])
311
312
    @lazyproperty
    def center(self):
        """Bin centers as a `~astropy.units.Quantity` (one per bin, cached).

        Centers are the coordinates at the integer pixel positions
        ``0, 1, ..., nbin - 1``.
        """
        center_pix = np.arange(self.nbin, dtype=float)
        center_coords = self.pix_to_coord(center_pix)
        return u.Quantity(center_coords, self._unit, copy=False)
317
318
    @lazyproperty
    def bin_width(self):
        """Width of every bin (difference of consecutive edges, cached)."""
        edges = self.edges
        return np.diff(edges)
322
323
    @property
    def nbin(self):
        """int: Number of bins of the axis."""
        n_bins = self._nbin
        return n_bins
327
328
    @property
    def nbin_per_decade(self):
        """Return number of bins per decade spanned by the axis.

        The span is taken from the edges for "edges" axes and from the bin
        centers otherwise.

        Raises
        ------
        ValueError
            If the axis interpolation is not "log".
        """
        if self.interp != "log":
            raise ValueError("Bins per decade can only be computed for log-spaced axes")

        if self.node_type == "edges":
            values = self.edges
        else:
            values = self.center

        ndecades = np.log10(values.max() / values.min())
        return (self._nbin / ndecades).value
341
342
    @property
    def node_type(self):
        """Return node type ('center' or 'edges')."""
        return self._node_type
346
347
    @property
    def unit(self):
        """`~astropy.units.Unit`: Unit of the axis coordinates."""
        axis_unit = self._unit
        return axis_unit
351
352
    @classmethod
    def from_bounds(cls, lo_bnd, hi_bnd, nbin, **kwargs):
        """Generate an axis object from a lower/upper bound and number of bins.

        If node_type = 'edges' then bounds correspond to the
        lower and upper bound of the first and last bin.  If node_type
        = 'center' then bounds correspond to the centers of the first
        and last bin.

        Parameters
        ----------
        lo_bnd : float
            Lower bound of first axis bin.
        hi_bnd : float
            Upper bound of last axis bin.
        nbin : int
            Number of bins (converted with ``int``).
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates.  Default: 'lin'.
        node_type : {'edges', 'center'}
            Node type of the axis. Default: 'edges'.
        **kwargs : dict
            Remaining keyword arguments are forwarded to the constructor
            (e.g. ``name``, ``unit``).

        Raises
        ------
        ValueError
            If ``node_type`` or ``interp`` is not one of the values above.
        """
        nbin = int(nbin)
        # setdefault also stores the defaults in kwargs so they reach ``cls``.
        interp = kwargs.setdefault("interp", "lin")
        node_type = kwargs.setdefault("node_type", "edges")

        if node_type == "edges":
            nnode = nbin + 1
        elif node_type == "center":
            nnode = nbin
        else:
            raise ValueError(f"Invalid node type: {node_type!r}")

        # Nodes are evenly spaced in the transformed (lin/log/sqrt) space.
        if interp == "lin":
            nodes = np.linspace(lo_bnd, hi_bnd, nnode)
        elif interp == "log":
            nodes = np.exp(np.linspace(np.log(lo_bnd), np.log(hi_bnd), nnode))
        elif interp == "sqrt":
            nodes = np.linspace(lo_bnd ** 0.5, hi_bnd ** 0.5, nnode) ** 2.0
        else:
            raise ValueError(f"Invalid interp: {interp}")

        return cls(nodes, **kwargs)
394
395
    @classmethod
    def from_energy_edges(cls, energy_edges, unit=None, name=None, interp="log"):
        """Make an energy axis from adjacent edges.

        Parameters
        ----------
        energy_edges : `~astropy.units.Quantity`, float
            Energy edges
        unit : `~astropy.units.Unit`
            Energy unit
        name : str
            Name of the energy axis, either 'energy' or 'energy_true'.
            Default is 'energy'.
        interp: str
            interpolation mode. Default is 'log'.

        Returns
        -------
        axis : `MapAxis`
            Edge-type axis with the given ``name`` and ``interp``.

        Raises
        ------
        ValueError
            If the edges do not have an energy unit, or if ``name`` is not
            'energy' or 'energy_true'.
        """
        energy_edges = u.Quantity(energy_edges, unit)

        if not energy_edges.unit.is_equivalent("TeV"):
            raise ValueError(
                f"Please provide a valid energy unit, got {energy_edges.unit} instead."
            )

        if name is None:
            name = "energy"

        if name not in ["energy", "energy_true"]:
            raise ValueError("Energy axis can only be named 'energy' or 'energy_true'")

        return cls.from_edges(energy_edges, unit=unit, interp=interp, name=name)
429
430
    @classmethod
    def from_energy_bounds(
        cls,
        energy_min,
        energy_max,
        nbin,
        unit=None,
        per_decade=False,
        name=None,
        node_type="edges",
    ):
        """Make an energy axis.

        Used frequently also to make energy grids, by making
        the axis, and then using ``axis.center`` or ``axis.edges``.

        Parameters
        ----------
        energy_min, energy_max : `~astropy.units.Quantity`, float
            Energy range
        nbin : int
            Number of bins
        unit : `~astropy.units.Unit`
            Energy unit. If None, the unit of ``energy_max`` is used.
        per_decade : bool
            Whether `nbin` is given per decade. The total number of bins is
            then rounded up to an integer.
        name : str
            Name of the energy axis, either 'energy' or 'energy_true'.
            Default is 'energy'.
        node_type : {'edges', 'center'}
            Whether the bounds are the outer bin edges or the first/last bin
            centers. Default is 'edges'.

        Returns
        -------
        axis : `MapAxis`
            Axis with the given ``name`` and interp "log".

        Raises
        ------
        ValueError
            If the bounds do not have an energy unit, or if ``name`` is not
            'energy' or 'energy_true'.
        """
        energy_min = u.Quantity(energy_min, unit)
        energy_max = u.Quantity(energy_max, unit)

        if unit is None:
            unit = energy_max.unit
            energy_min = energy_min.to(unit)

        if not energy_max.unit.is_equivalent("TeV"):
            raise ValueError(
                f"Please provide a valid energy unit, got {energy_max.unit} instead."
            )

        if per_decade:
            nbin = np.ceil(np.log10(energy_max / energy_min).value * nbin)

        if name is None:
            name = "energy"

        if name not in ["energy", "energy_true"]:
            raise ValueError("Energy axis can only be named 'energy' or 'energy_true'")

        return cls.from_bounds(
            energy_min.value,
            energy_max.value,
            nbin=nbin,
            unit=unit,
            interp="log",
            name=name,
            node_type=node_type,
        )
494
495
    @classmethod
    def from_nodes(cls, nodes, **kwargs):
        """Build a "center"-type axis whose bin centers are ``nodes``.

        Bin edges lie half-way (in pixel space) between consecutive nodes.
        Use this when bin centers must sit at specific values (e.g. a map of
        a continuous function).

        Parameters
        ----------
        nodes : `~numpy.ndarray`
            Axis nodes (bin center).
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates.  Default: 'lin'.

        Raises
        ------
        ValueError
            If ``nodes`` is empty.
        """
        if len(nodes) == 0:
            raise ValueError("Nodes array must have at least one element.")
        return cls(nodes, node_type="center", **kwargs)
516
517
    @classmethod
    def from_edges(cls, edges, **kwargs):
        """Build an "edges"-type axis whose bin boundaries are ``edges``.

        Use this when bin edges must sit at specific values (e.g. a
        histogram); the axis has ``len(edges) - 1`` bins.

        Parameters
        ----------
        edges : `~numpy.ndarray`
            Axis bin edges.
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates.  Default: 'lin'.

        Raises
        ------
        ValueError
            If fewer than two edges are given.
        """
        if len(edges) <= 1:
            raise ValueError("Edges array must have at least two elements.")
        return cls(edges, node_type="edges", **kwargs)
537
538
    def append(self, axis):
        """Append another map axis to this axis

        Name, interp type and node type must agree between the axes. If the node
        type is "edges", the edges must be contiguous and non-overlapping, i.e.
        the first edge of ``axis`` must equal the last edge of this axis.

        Parameters
        ----------
        axis : `MapAxis`
            Axis to append.

        Returns
        -------
        axis : `MapAxis`
            Appended axis

        Raises
        ------
        ValueError
            If node type, name or interp differ, or if "edges" axes are not
            contiguous.
        """
        if self.node_type != axis.node_type:
            raise ValueError(
                f"Node type must agree, got {self.node_type} and {axis.node_type}"
            )

        if self.name != axis.name:
            raise ValueError(f"Names must agree, got {self.name} and {axis.name} ")

        if self.interp != axis.interp:
            raise ValueError(
                f"Interp type must agree, got {self.interp} and {axis.interp}"
            )

        if self.node_type == "edges":
            # The shared boundary edge is kept only once below (``axis.edges[1:]``),
            # so a gap or overlap would otherwise be silently dropped.
            upper = self.edges[-1].value
            lower = axis.edges[0].to_value(self.unit)
            if not np.isclose(upper, lower):
                raise ValueError(
                    f"Edges must be contiguous, got {self.edges[-1]} and {axis.edges[0]}"
                )
            edges = np.append(self.edges, axis.edges[1:])
            return self.from_edges(edges=edges, interp=self.interp, name=self.name)
        else:
            nodes = np.append(self.center, axis.center)
            return self.from_nodes(nodes=nodes, interp=self.interp, name=self.name)
573
574
    def pad(self, pad_width):
        """Extend the axis by extra bins below and/or above its range.

        Parameters
        ----------
        pad_width : int or tuple of int
            A single int pads in both direction of the axis, a tuple specifies,
            which number of bins to pad at the low and high edge of the axis.

        Returns
        -------
        axis : `MapAxis`
            Padded axis
        """
        if isinstance(pad_width, tuple):
            pad_low, pad_high = pad_width
        else:
            pad_low = pad_high = pad_width

        if self.node_type != "edges":
            node_pix = np.arange(-pad_low, self.nbin + pad_high)
            return self.from_nodes(
                nodes=self.pix_to_coord(node_pix), interp=self.interp, name=self.name
            )

        # Edges sit on half-integer pixels, hence the extra node and the shift.
        edge_pix = np.arange(-pad_low, self.nbin + pad_high + 1) - 0.5
        return self.from_edges(
            edges=self.pix_to_coord(edge_pix), interp=self.interp, name=self.name
        )
601
602
    @classmethod
    def from_stack(cls, axes):
        """Merge a list of axes into one by appending them in order.

        If the node type is "edges" the bin edges in the provided axes must be
        contiguous and non-overlapping.

        Parameters
        ----------
        axes : list of `MapAxis`
            List of map axis to merge.

        Returns
        -------
        axis : `MapAxis`
            Merged axis
        """
        merged = axes[0]
        for other in axes[1:]:
            merged = merged.append(other)
        return merged
625
626
    def pix_to_coord(self, pix):
        """Map pixel positions onto axis coordinates.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.

        Returns
        -------
        coord : `~astropy.units.Quantity`
            Axis coordinate values in the axis unit.
        """
        # Shift into node-index space (edge nodes start half a pixel below 0).
        node_pix = pix - self._pix_offset
        raw = pix_to_coord(self._nodes, node_pix, interp=self._interp)
        return u.Quantity(raw, unit=self.unit, copy=False)
642
643 View Code Duplication
    def pix_to_idx(self, pix, clip=False):
        """Convert pixel coordinates to bin indices.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Pixel coordinates.
        clip : bool
            Choose whether to clip indices to the valid range of the
            axis.  If false then indices for coordinates outside
            the axis range will be set -1.

        Returns
        -------
        idx : `~numpy.ndarray`
            Pixel indices.
        """
        if not clip:
            outside = (pix < 0) | (pix >= self.nbin)
            return np.where(outside, -1, pix)
        return np.clip(pix, 0, self.nbin - 1)
667
668
    def coord_to_pix(self, coord):
        """Map axis coordinates onto (fractional) pixel positions.

        Parameters
        ----------
        coord : `~numpy.ndarray` or `~astropy.units.Quantity`
            Axis coordinate values; plain numbers are taken in the axis unit.

        Returns
        -------
        pix : `~numpy.ndarray`
            Pixel coordinate values, always at least 1-dimensional.
        """
        values = u.Quantity(coord, self.unit, copy=False).value
        node_pix = coord_to_pix(self._nodes, values, interp=self._interp)
        return np.array(node_pix + self._pix_offset, ndmin=1)
684
685
    def coord_to_idx(self, coord, clip=False):
        """Transform from axis coordinate to bin index.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis coordinate values.
        clip : bool
            Choose whether to clip the index to the valid range of the
            axis.  If false then indices for values outside the axis
            range will be set -1.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.
        """
        coord = u.Quantity(coord, self.unit, copy=False, ndmin=1).value
        edges = self.edges.value
        # np.digitize accepts ascending and descending edges; after the shift,
        # values before the first edge give -1 and values past the last edge
        # give nbin.
        idx = np.digitize(coord, edges) - 1

        if clip:
            idx = np.clip(idx, 0, self.nbin - 1)
        else:
            # Test the index range rather than ``coord > edges[-1]``. This also
            # works for descending axes, and it catches a coordinate exactly
            # on the last edge, which would otherwise return the out-of-range
            # index ``nbin``.
            idx[(idx < 0) | (idx >= self.nbin)] = INVALID_INDEX.int

        idx[~np.isfinite(coord)] = INVALID_INDEX.int

        return idx
715
716
    def slice(self, idx):
        """Build a new axis from the bins selected by ``idx``.

        Parameters
        ----------
        idx : slice
            Slice object selecting a subselection of the axis.

        Returns
        -------
        axis : `~MapAxis`
            Sliced axis object.
        """
        selected_centers = self.center[idx].value
        node_idx = self.coord_to_idx(selected_centers)

        if self._node_type == "edges":
            # N edge-type bins are described by N + 1 nodes.
            node_idx = tuple(list(node_idx) + [1 + node_idx[-1]])

        return MapAxis(
            self._nodes[(node_idx,)],
            interp=self._interp,
            name=self._name,
            node_type=self._node_type,
            unit=self._unit,
        )
743
744
    def squash(self):
        """Build a single-bin axis spanning the full edge range of this axis.

        Returns
        -------
        axis : `~MapAxis`
            Sliced axis object.
        """
        # TODO: Decide on handling node_type=center
        # See https://github.com/gammapy/gammapy/issues/1952
        lowest = self.edges[0].value
        highest = self.edges[-1].value
        return MapAxis.from_bounds(
            lo_bnd=lowest,
            hi_bnd=highest,
            nbin=1,
            interp=self._interp,
            name=self._name,
            unit=self._unit,
        )
762
763
    def __repr__(self):
        """Multi-line summary: class name followed by one tab-indented row per property."""
        fmt = "\t{:<10s} : {:<10s}\n"
        vals = self.edges if self.node_type == "edges" else self.center
        rows = [
            ("name", self.name),
            ("unit", "{!r}".format(str(self.unit))),
            ("nbins", str(self.nbin)),
            ("node type", self.node_type),
            (f"{self.node_type} min", "{:.1e}".format(vals.min())),
            (f"{self.node_type} max", "{:.1e}".format(vals.max())),
            ("interp", self._interp),
        ]
        body = "".join(fmt.format(key, value) for key, value in rows)
        return self.__class__.__name__ + "\n\n" + body
776
777 View Code Duplication
    def _init_copy(self, **kwargs):
        """Init map axis instance by copying missing init arguments from self.

        Every constructor argument ``arg`` not given in ``kwargs`` is filled
        with a deep copy of the private attribute ``self._<arg>``.
        """
        argnames = inspect.getfullargspec(self.__init__).args
        argnames.remove("self")

        inherited = {arg: copy.deepcopy(getattr(self, f"_{arg}")) for arg in argnames}
        return self.__class__(**{**inherited, **kwargs})
787
788
    def copy(self, **kwargs):
        """Return a deep copy of the axis, optionally overriding constructor arguments.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to overwrite in the map axis constructor.

        Returns
        -------
        copy : `MapAxis`
            Copied map axis.
        """
        duplicate = self._init_copy(**kwargs)
        return duplicate
802
803
    def round(self, coord, clip=False):
        """Snap coordinates to the nearest bin edge of the axis.

        Parameters
        ----------
        coord : `~astropy.units.Quantity`
            Coordinates
        clip : bool
            Choose whether to clip indices to the valid range of the axis.

        Returns
        -------
        coord : `~astropy.units.Quantity`
            Rounded coordinates
        """
        pix = self.coord_to_pix(coord)
        if clip:
            pix = np.clip(pix, -0.5, self.nbin - 0.5)
        # Edges live on half-integer pixels: shift, round, shift back.
        nearest_edge_pix = np.round(pix + 0.5) - 0.5
        return self.pix_to_coord(nearest_edge_pix)
825
826
    def group_table(self, edges):
        """Compute bin groups table for the map axis, given coarser bin edges.

        The requested ``edges`` are snapped to the nearest edges of this axis
        (clipped to the axis range). Each resulting interval becomes a
        "normal" group. Bins below the first or above the last snapped edge
        are collected into one "underflow" / "overflow" group.

        Parameters
        ----------
        edges : `~astropy.units.Quantity`
            Group bin edges.

        Returns
        -------
        groups : `~astropy.table.Table`
            Map axis group table with columns ``group_idx``,
            ``<name>_min``, ``<name>_max``, ``idx_min``, ``idx_max`` (inclusive
            bin index range) and ``bin_type``.

        Raises
        ------
        ValueError
            If the axis is not edge based, or if no group remains after
            snapping.
        """
        # TODO: try to simplify this code
        if self.node_type != "edges":
            raise ValueError("Only edge based map axis can be grouped")

        # Snap the group edges to this axis' edges, which sit on half-integer
        # pixel positions (-0.5 ... nbin - 0.5).
        edges_pix = self.coord_to_pix(edges)
        edges_pix = np.clip(edges_pix, -0.5, self.nbin - 0.5)
        edges_idx = np.round(edges_pix + 0.5) - 0.5
        edges_idx = np.unique(edges_idx)
        edges_ref = self.pix_to_coord(edges_idx)

        groups = Table()
        groups[f"{self.name}_min"] = edges_ref[:-1]
        groups[f"{self.name}_max"] = edges_ref[1:]

        # Convert half-integer edge positions to inclusive bin index ranges.
        groups["idx_min"] = (edges_idx[:-1] + 0.5).astype(int)
        groups["idx_max"] = (edges_idx[1:] - 0.5).astype(int)

        if len(groups) == 0:
            raise ValueError("No overlap between reference and target edges.")

        # Padded to 9 characters (the length of "underflow"), presumably so the
        # fixed-width string column can hold the labels added below.
        groups["bin_type"] = "normal   "

        edge_idx_start, edge_ref_start = edges_idx[0], edges_ref[0]
        if edge_idx_start > 0:
            # NOTE(review): edge_idx_start is a half-integer (e.g. 2.5); this
            # relies on insertion into the int ``idx_max`` column truncating it
            # to the last underflow bin (2). Confirm.
            underflow = {
                "bin_type": "underflow",
                "idx_min": 0,
                "idx_max": edge_idx_start,
                f"{self.name}_min": self.pix_to_coord(-0.5),
                f"{self.name}_max": edge_ref_start,
            }
            groups.insert_row(0, vals=underflow)

        edge_idx_end, edge_ref_end = edges_idx[-1], edges_ref[-1]

        if edge_idx_end < (self.nbin - 0.5):
            # NOTE(review): edge_idx_end + 1 is a half-integer (e.g. 6.5); this
            # relies on truncation to the first overflow bin (6). Confirm.
            overflow = {
                "bin_type": "overflow",
                "idx_min": edge_idx_end + 1,
                "idx_max": self.nbin - 1,
                f"{self.name}_min": edge_ref_end,
                f"{self.name}_max": self.pix_to_coord(self.nbin - 0.5),
            }
            groups.add_row(vals=overflow)

        group_idx = Column(np.arange(len(groups)))
        groups.add_column(group_idx, name="group_idx", index=0)
        return groups
887
888
    def upsample(self, factor):
        """Upsample map axis by a given factor.

        New nodes are inserted evenly (in pixel space) between the existing
        ones, which are kept. For node type "edges" this results in
        nbin * factor new bins. For node type "center" this results in
        (nbin - 1) * factor + 1 new bins.

        Parameters
        ----------
        factor : int
            Upsampling factor.

        Returns
        -------
        axis : `MapAxis`
            Upsampled map axis.
        """
        use_edges = self.node_type == "edges"

        if use_edges:
            pix = self.coord_to_pix(self.edges)
            n_nodes = int(self.nbin * factor) + 1
        else:
            pix = self.coord_to_pix(self.center)
            n_nodes = int((self.nbin - 1) * factor) + 1

        fine_pix = np.linspace(pix.min(), pix.max(), n_nodes)
        fine_coords = self.pix_to_coord(fine_pix)

        if use_edges:
            return self.from_edges(fine_coords, name=self.name, interp=self.interp)
        return self.from_nodes(fine_coords, name=self.name, interp=self.interp)
919
920
    def downsample(self, factor):
        """Downsample the map axis by a given factor.

        Every ``factor``-th node is kept so that the axis limits are
        preserved. Node type "edges" requires ``nbin`` to be divisible by
        ``factor``; any other node type requires ``nbin - 1`` to be.

        Parameters
        ----------
        factor : int
            Downsampling factor.

        Returns
        -------
        axis : `MapAxis`
            Downsampled map axis.

        Raises
        ------
        ValueError
            If the number of bins is not compatible with ``factor``.
        """
        if self.node_type != "edges":
            ratio = (self.nbin - 1) / factor
            if np.mod(ratio, 1) > 0:
                raise ValueError(
                    f"Number of {self.name} bins - 1 is not divisible by {factor}"
                )
            kept_nodes = self.center[::factor]
            return self.from_nodes(kept_nodes, name=self.name, interp=self.interp)

        ratio = self.nbin / factor
        if np.mod(ratio, 1) > 0:
            raise ValueError(
                f"Number of {self.name} bins is not divisible by {factor}"
            )
        kept_edges = self.edges[::factor]
        return self.from_edges(kept_edges, name=self.name, interp=self.interp)
959
960
    def to_header(self, format="ogip", idx=0):
        """Create a FITS header describing this axis.

        Parameters
        ----------
        format : {"ogip", "ogip-sherpa", "gadf", "fgst-ccube", "fgst-template"}
            Format specification. The OGIP formats produce the keywords of an
            ``EBOUNDS`` extension; the other formats produce the
            ``AXCOLS{idx}`` and ``INTERP{idx}`` keywords.
        idx : int
            Column index of the axis (only used by the non-OGIP formats).

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.

        Raises
        ------
        ValueError
            If the format or the axis node type is not supported.
        """
        header = fits.Header()

        if format in ["ogip", "ogip-sherpa"]:
            ogip_cards = [
                ("EXTNAME", "EBOUNDS", "Name of this binary table extension"),
                ("TELESCOP", "DUMMY", "Mission/satellite name"),
                ("INSTRUME", "DUMMY", "Instrument/detector"),
                ("FILTER", "None", "Filter information"),
                ("CHANTYPE", "PHA", "Type of channels (PHA, PI etc)"),
                ("DETCHANS", self.nbin, "Total number of detector PHA channels"),
                ("HDUCLASS", "OGIP", "Organisation devising file format"),
                ("HDUCLAS1", "RESPONSE", "File relates to response of instrument"),
                ("HDUCLAS2", "EBOUNDS", "This is an EBOUNDS extension"),
                ("HDUVERS", "1.2.0", "Version of file format"),
            ]
            for keyword, value, comment in ogip_cards:
                header[keyword] = value, comment
        elif format in ["gadf", "fgst-ccube", "fgst-template"]:
            upper_name = self.name.upper()
            is_energy = self.name == "energy"

            if self.node_type == "edges":
                axcols = "E_MIN,E_MAX" if is_energy else f"{upper_name}_MIN,{upper_name}_MAX"
            elif self.node_type == "center":
                axcols = "ENERGY" if is_energy else upper_name
            else:
                raise ValueError(f"Invalid node type {self.node_type!r}")

            header[f"AXCOLS{idx}"] = axcols
            header[f"INTERP{idx}"] = self.interp
        else:
            raise ValueError(f"Unknown format {format}")

        return header
1010
1011
    def to_table(self, format="ogip"):
        """Convert the axis to a `~astropy.table.Table`.

        The OGIP formats follow the ``EBOUNDS`` / ARF conventions, see
        https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html#tth_sEc3.2

        The '-sherpa' variants are equivalent to their plain counterparts but
        use keV energy units.

        Parameters
        ----------
        format : {"ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa", "gadf-sed", "gadf-dl3", "gtpsf"}
            Format specification

        Returns
        -------
        table : `~astropy.table.Table`
            Table with the axis columns.

        Raises
        ------
        ValueError
            If the format is unknown, if no DL3 column specification matches
            the axis name ("gadf-dl3"), or if the axis is neither
            "energy_true" nor "rad" ("gtpsf").
        """
        table = Table()
        edges = self.edges

        if format in ["ogip", "ogip-sherpa"]:
            self.assert_name("energy")

            if format == "ogip-sherpa":
                edges = edges.to("keV")

            table["CHANNEL"] = np.arange(self.nbin, dtype=np.int16)
            table["E_MIN"] = edges[:-1]
            table["E_MAX"] = edges[1:]
        elif format in ["ogip-arf", "ogip-arf-sherpa"]:
            self.assert_name("energy_true")

            if format == "ogip-arf-sherpa":
                edges = edges.to("keV")

            table["ENERG_LO"] = edges[:-1]
            table["ENERG_HI"] = edges[1:]
        elif format == "gadf-sed":
            # only energy axes contribute columns; other axes give an empty table
            if self.is_energy_axis:
                table["e_ref"] = self.center
                table["e_min"] = self.edges_min
                table["e_max"] = self.edges_max
        elif format == "gadf-dl3":
            from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

            if self.name == "energy":
                column_prefix = "ENERG"
            else:
                for column_prefix, spec in IRF_DL3_AXES_SPECIFICATION.items():
                    if spec["name"] == self.name:
                        break
                else:
                    # Without this branch the last prefix iterated over would be
                    # used silently, writing the axis under a wrong column name.
                    raise ValueError(
                        f"No 'gadf-dl3' column specification for axis {self.name!r}"
                    )

            if self.node_type == "edges":
                edges_lo, edges_hi = edges[:-1], edges[1:]
            else:
                edges_lo, edges_hi = self.center, self.center

            # each axis is stored as a single row (leading length-1 dimension)
            table[f"{column_prefix}_LO"] = edges_lo[np.newaxis]
            table[f"{column_prefix}_HI"] = edges_hi[np.newaxis]
        elif format == "gtpsf":
            if self.name == "energy_true":
                table["Energy"] = self.center.to("MeV")
            elif self.name == "rad":
                table["Theta"] = self.center.to("deg")
            else:
                raise ValueError(
                    "Can only convert true energy or rad axis to "
                    f"'gtpsf' format, got {self.name}"
                )
        else:
            raise ValueError(f"{format} is not a valid format")

        return table
1084
1085
    def to_table_hdu(self, format="ogip"):
        """Convert the axis to a FITS binary table HDU.

        The table is built by `to_table`. For the OGIP formats see
        https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html#tth_sEc3.2

        The 'ogip-sherpa' format is equivalent to 'ogip' but uses keV energy units.

        Parameters
        ----------
        format : {"ogip", "ogip-sherpa", "gtpsf"}
            Format specification

        Returns
        -------
        hdu : `~astropy.io.fits.BinTableHDU`
            Table HDU
        """
        table = self.to_table(format=format)
        hdu_name = "THETA" if format == "gtpsf" else None
        hdu = fits.BinTableHDU(table, name=hdu_name)

        if format in ["ogip", "ogip-sherpa"]:
            # add the EBOUNDS keywords produced by `to_header`
            hdu.header.update(self.to_header(format=format))

        return hdu
1115
1116
    @classmethod
    def from_table(cls, table, format="ogip", idx=0, column_prefix=""):
        """Instantiate a `MapAxis` from a table.

        Parameters
        ----------
        table : `~astropy.table.Table`
            Table
        format : {"ogip", "ogip-arf", "fgst-ccube", "fgst-template", "fgst-bexpcube", "gadf", "gadf-dl3", "gtpsf", "gadf-sed-energy", "gadf-sed-norm", "gadf-sed-counts"}
            Format specification
        idx : int
            Column index of the axis (zero-based; the "gadf" format reads the
            ``AXCOLS{idx + 1}`` and ``INTERP{idx + 1}`` keywords).
        column_prefix : str
            Column name prefix of the axis, used by the "gadf-dl3" format.

        Returns
        -------
        axis : `MapAxis`
            Map Axis

        Raises
        ------
        KeyError
            If a "gadf" table lacks the ``AXCOLS{idx + 1}`` keyword.
        ValueError
            If the format is not supported or required columns are missing.
        """
        if format in ["ogip", "fgst-ccube"]:
            energy_min = table["E_MIN"].quantity
            energy_max = table["E_MAX"].quantity
            energy_edges = (
                np.append(energy_min.value, energy_max.value[-1]) * energy_min.unit
            )
            axis = cls.from_edges(energy_edges, name="energy", interp="log")

        elif format == "ogip-arf":
            energy_min = table["ENERG_LO"].quantity
            energy_max = table["ENERG_HI"].quantity
            energy_edges = (
                np.append(energy_min.value, energy_max.value[-1]) * energy_min.unit
            )
            axis = cls.from_edges(energy_edges, name="energy_true", interp="log")

        elif format in ["fgst-template", "fgst-bexpcube"]:
            allowed_names = ["Energy", "ENERGY", "energy"]
            # first column (in table order) carrying an accepted energy name
            tag = next(
                (colname for colname in table.colnames if colname in allowed_names),
                None,
            )
            if tag is None:
                raise ValueError(
                    f"No energy column found in table, expected one of {allowed_names}"
                )

            nodes = table[tag].data
            axis = cls.from_nodes(
                nodes=nodes, name="energy_true", unit="MeV", interp="log"
            )

        elif format == "gadf":
            axcols = table.meta.get(f"AXCOLS{idx + 1}")
            if axcols is None:
                raise KeyError(f"Missing 'AXCOLS{idx + 1}' keyword in table meta")

            colnames = axcols.split(",")
            node_type = "edges" if len(colnames) == 2 else "center"

            # TODO: check why this extra case is needed
            if colnames[0] == "E_MIN":
                name = "energy"
            else:
                name = colnames[0].replace("_MIN", "").lower()
                # this is need for backward compatibility
                if name == "theta":
                    name = "rad"

            interp = table.meta.get(f"INTERP{idx + 1}", "lin")

            if node_type == "center":
                nodes = np.unique(table[colnames[0]].quantity)
            else:
                edges_min = np.unique(table[colnames[0]].quantity)
                edges_max = np.unique(table[colnames[1]].quantity)
                nodes = edges_from_lo_hi(edges_min, edges_max)

            axis = MapAxis(nodes=nodes, node_type=node_type, interp=interp, name=name)

        elif format == "gadf-dl3":
            from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

            spec = IRF_DL3_AXES_SPECIFICATION[column_prefix]
            name, interp = spec["name"], spec["interp"]

            # background models are stored in reconstructed energy
            hduclass = table.meta.get("HDUCLAS2")
            if hduclass in {"BKG", "RAD_MAX"} and column_prefix == "ENERG":
                name = "energy"

            edges_lo = table[f"{column_prefix}_LO"].quantity[0]
            edges_hi = table[f"{column_prefix}_HI"].quantity[0]

            # identical lo/hi columns mean the axis was written with node centers
            if np.allclose(edges_hi, edges_lo):
                axis = MapAxis.from_nodes(edges_hi, interp=interp, name=name)
            else:
                edges = edges_from_lo_hi(edges_lo, edges_hi)
                axis = MapAxis.from_edges(edges, interp=interp, name=name)
        elif format == "gtpsf":
            try:
                energy = table["Energy"].data * u.MeV
                axis = MapAxis.from_nodes(energy, name="energy_true", interp="log")
            except KeyError:
                rad = table["Theta"].data * u.deg
                axis = MapAxis.from_nodes(rad, name="rad")
        elif format == "gadf-sed-energy":
            if "e_min" in table.colnames and "e_max" in table.colnames:
                e_min = flat_if_equal(table["e_min"].quantity)
                e_max = flat_if_equal(table["e_max"].quantity)
                edges = edges_from_lo_hi(e_min, e_max)
                axis = MapAxis.from_energy_edges(edges)
            elif "e_ref" in table.colnames:
                e_ref = flat_if_equal(table["e_ref"].quantity)
                axis = MapAxis.from_nodes(e_ref, name="energy", interp="log")
            else:
                raise ValueError(
                    "Either 'e_ref', 'e_min' or 'e_max' column names are required"
                )
        elif format == "gadf-sed-norm":
            # TODO: guess interp here
            nodes = flat_if_equal(table["norm_scan"][0])
            axis = MapAxis.from_nodes(nodes, name="norm")
        elif format == "gadf-sed-counts":
            if "datasets" in table.colnames:
                labels = np.unique(table["datasets"])
                axis = LabelMapAxis(labels=labels, name="dataset")
            else:
                shape = table["counts"].shape
                edges = np.arange(shape[-1] + 1) - 0.5
                axis = MapAxis.from_edges(edges, name="dataset")
        else:
            raise ValueError(f"Format '{format}' not supported")

        return axis
1244
1245
    @classmethod
    def from_table_hdu(cls, hdu, format="ogip", idx=0):
        """Create a `MapAxis` from a FITS table HDU.

        The HDU is read into a `~astropy.table.Table`, which is then passed
        to `from_table`.

        Parameters
        ----------
        hdu : `~astropy.io.fits.BinTableHDU`
            Table HDU
        format : {"ogip", "ogip-arf", "fgst-ccube", "fgst-template"}
            Format specification
        idx : int
            Column index of the axis.

        Returns
        -------
        axis : `MapAxis`
            Map Axis
        """
        return cls.from_table(Table.read(hdu), format=format, idx=idx)
1265
1266
1267
class MapAxes(Sequence):
1268
    """MapAxis container class.
1269
1270
    Parameters
1271
    ----------
1272
    axes : list of `MapAxis`
1273
        List of map axis objects.
1274
    """
1275
1276
    def __init__(self, axes, n_spatial_axes=None):
1277
        unique_names = []
1278
1279
        for ax in axes:
1280
            if ax.name in unique_names:
1281
                raise (
1282
                    ValueError(f"Axis names must be unique, got: '{ax.name}' twice.")
1283
                )
1284
            unique_names.append(ax.name)
1285
1286
        self._axes = axes
1287
        self._n_spatial_axes = n_spatial_axes
1288
1289
    @property
1290
    def primary_axis(self):
1291
        """Primary extra axis, defined as the one longest
1292
1293
        Returns
1294
        -------
1295
        axis : `MapAxis`
1296
            Map axis
1297
        """
1298
        # get longest axis
1299
        idx = np.argmax(self.shape)
1300
        return self[int(idx)]
1301
1302
    @property
1303
    def is_flat(self):
1304
        """Whether axes is flat"""
1305
        return np.all(self.shape == 1)
1306
1307
    @property
1308
    def is_unidimensional(self):
1309
        """Whether axes is unidimensional"""
1310
        value = (np.array(self.shape) > 1).sum()
1311
        return value == 1
1312
1313
    @property
1314
    def reverse(self):
1315
        """Reverse axes order"""
1316
        return MapAxes(self[::-1])
1317
1318
    @property
1319
    def iter_with_reshape(self):
1320
        """Iterate by shape"""
1321
        for idx, axis in enumerate(self):
1322
            # Extract values for each axis, default: nodes
1323
            shape = [1] * len(self)
1324
            shape[idx] = -1
1325
            if self._n_spatial_axes:
1326
                shape = (
1327
                    shape[::-1]
1328
                    + [
1329
                        1,
1330
                    ]
1331
                    * self._n_spatial_axes
1332
                )
1333
            yield tuple(shape), axis
1334
1335
    def get_coord(self, mode="center", axis_name=None):
1336
        """Get axes coordinates
1337
1338
        Parameters
1339
        ----------
1340
        mode : {"center", "edges"}
1341
            Coordinate center or edges
1342
        axis_name : str
1343
            Axis name for which mode='edges' applies
1344
1345
        Returns
1346
        -------
1347
        coords : dict of `~astropy.units.Quanity`
1348
            Map coordinates
1349
        """
1350
        coords = {}
1351
1352
        for shape, axis in self.iter_with_reshape:
1353
            if mode == "edges" and axis.name == axis_name:
1354
                coord = axis.edges
1355
            else:
1356
                coord = axis.center
1357
            coords[axis.name] = coord.reshape(shape)
1358
1359
        return coords
1360
1361
    def bin_volume(self):
1362
        """Bin axes volume
1363
1364
        Returns
1365
        -------
1366
        bin_volume : `~astropy.units.Quantity`
1367
            Bin volume
1368
        """
1369
        bin_volume = np.array(1)
1370
1371
        for shape, axis in self.iter_with_reshape:
1372
            bin_volume = bin_volume * axis.bin_width.reshape(shape)
1373
1374
        return bin_volume
1375
1376
    @property
1377
    def shape(self):
1378
        """Shape of the axes"""
1379
        return tuple([ax.nbin for ax in self])
1380
1381
    @property
1382
    def names(self):
1383
        """Names of the axes"""
1384
        return [ax.name for ax in self]
1385
1386
    def index(self, axis_name):
1387
        """Get index in list"""
1388
        return self.names.index(axis_name)
1389
1390
    def index_data(self, axis_name):
1391
        """Get data index of the axes
1392
1393
        Parameters
1394
        ----------
1395
        axis_name : str
1396
            Name of the axis.
1397
1398
        Returns
1399
        -------
1400
        idx : int
1401
            Data index
1402
        """
1403
        idx = self.names.index(axis_name)
1404
        return len(self) - idx - 1
1405
1406
    def __len__(self):
1407
        return len(self._axes)
1408
1409
    def __add__(self, other):
1410
        return self.__class__(list(self) + list(other))
1411
1412
    def upsample(self, factor, axis_name):
1413
        """Upsample axis by a given factor
1414
1415
        Parameters
1416
        ----------
1417
        factor : int
1418
            Upsampling factor.
1419
        axis_name : str
1420
            Axis to upsample.
1421
1422
        Returns
1423
        -------
1424
        axes : `MapAxes`
1425
            Map axes
1426
        """
1427
        axes = []
1428
1429
        for ax in self:
1430
            if ax.name == axis_name:
1431
                ax = ax.upsample(factor=factor)
1432
1433
            axes.append(ax.copy())
1434
1435
        return self.__class__(axes=axes)
1436
1437
    def replace(self, axis):
1438
        """Replace a given axis
1439
1440
        Parameters
1441
        ----------
1442
        axis : `MapAxis`
1443
            Map axis
1444
1445
        Returns
1446
        -------
1447
        axes : MapAxes
1448
            Map axe
1449
        """
1450
        axes = []
1451
1452
        for ax in self:
1453
            if ax.name == axis.name:
1454
                ax = axis
1455
1456
            axes.append(ax)
1457
1458
        return self.__class__(axes=axes)
1459
1460
    def resample(self, axis):
        """Resample axis binning.

        The existing bins of the axis with the same name as ``axis`` are
        grouped into the new binning; the other axes are copied.

        Parameters
        ----------
        axis : `MapAxis`
            New map axis.

        Returns
        -------
        axes : `MapAxes`
            Axes object with resampled axis.
        """
        name = axis.name
        groups = self[name].group_table(axis.edges)

        # drop the underflow / overflow groups added by ``group_table``
        normal = groups[groups["bin_type"] == "normal   "]

        edges = edges_from_lo_hi(
            normal[f"{name}_min"].quantity,
            normal[f"{name}_max"].quantity,
        )
        axis_resampled = MapAxis.from_edges(edges=edges, interp=axis.interp, name=name)

        new_axes = [axis_resampled if ax.name == name else ax.copy() for ax in self]
        return self.__class__(axes=new_axes)
1498
1499
    def downsample(self, factor, axis_name):
1500
        """Downsample axis by a given factor
1501
1502
        Parameters
1503
        ----------
1504
        factor : int
1505
            Upsampling factor.
1506
        axis_name : str
1507
            Axis to upsample.
1508
1509
        Returns
1510
        -------
1511
        axes : `MapAxes`
1512
            Map axes
1513
1514
        """
1515
        axes = []
1516
1517
        for ax in self:
1518
            if ax.name == axis_name:
1519
                ax = ax.downsample(factor=factor)
1520
1521
            axes.append(ax.copy())
1522
1523
        return self.__class__(axes=axes)
1524
1525
    def squash(self, axis_name):
1526
        """Squash axis.
1527
1528
        Parameters
1529
        ----------
1530
        axis_name : str
1531
            Axis to squash.
1532
1533
        Returns
1534
        -------
1535
        axes : `MapAxes`
1536
            Axes with squashed axis.
1537
        """
1538
        axes = []
1539
1540
        for ax in self:
1541
            if ax.name == axis_name:
1542
                ax = ax.squash()
1543
            axes.append(ax.copy())
1544
1545
        return self.__class__(axes=axes)
1546
1547
    def pad(self, axis_name, pad_width):
1548
        """Pad axes
1549
1550
        Parameters
1551
        ----------
1552
        axis_name : str
1553
            Name of the axis to pad.
1554
        pad_width : int or tuple of int
1555
            Pad width
1556
1557
        Returns
1558
        -------
1559
        axes : `MapAxes`
1560
            Axes with squashed axis.
1561
1562
        """
1563
        axes = []
1564
1565
        for ax in self:
1566
            if ax.name == axis_name:
1567
                ax = ax.pad(pad_width=pad_width)
1568
            axes.append(ax)
1569
1570
        return self.__class__(axes=axes)
1571
1572
    def drop(self, axis_name):
1573
        """Drop an axis.
1574
1575
        Parameters
1576
        ----------
1577
        axis_name : str
1578
            Name of the axis to remove.
1579
1580
        Returns
1581
        -------
1582
        axes : `MapAxes`
1583
            Axes with squashed axis.
1584
        """
1585
        axes = []
1586
        for ax in self:
1587
            if ax.name == axis_name:
1588
                continue
1589
            axes.append(ax.copy())
1590
1591
        return self.__class__(axes=axes)
1592
1593
    def __getitem__(self, idx):
1594
        if isinstance(idx, (int, slice)):
1595
            return self._axes[idx]
1596
        elif isinstance(idx, str):
1597
            for ax in self._axes:
1598
                if ax.name == idx:
1599
                    return ax
1600
            raise KeyError(f"No axes: {idx!r}")
1601
        elif isinstance(idx, list):
1602
            axes = []
1603
            for name in idx:
1604
                axes.append(self[name])
1605
1606
            return self.__class__(axes=axes)
1607
        else:
1608
            raise TypeError(f"Invalid type: {type(idx)!r}")
1609
1610
    def coord_to_idx(self, coord, clip=True):
1611
        """Transform from axis to pixel indices.
1612
1613
        Parameters
1614
        ----------
1615
        coord : dict of `~numpy.ndarray` or `MapCoord`
1616
            Array of axis coordinate values.
1617
1618
        Returns
1619
        -------
1620
        pix : tuple of `~numpy.ndarray`
1621
            Array of pixel indices values.
1622
        """
1623
        return tuple([ax.coord_to_idx(coord[ax.name], clip=clip) for ax in self])
1624
1625
    def coord_to_pix(self, coord):
1626
        """Transform from axis to pixel coordinates.
1627
1628
        Parameters
1629
        ----------
1630
        coord : dict of `~numpy.ndarray`
1631
            Array of axis coordinate values.
1632
1633
        Returns
1634
        -------
1635
        pix : tuple of `~numpy.ndarray`
1636
            Array of pixel coordinate values.
1637
        """
1638
        return tuple([ax.coord_to_pix(coord[ax.name]) for ax in self])
1639
1640
    def pix_to_coord(self, pix):
1641
        """Convert pixel coordinates to map coordinates.
1642
1643
        Parameters
1644
        ----------
1645
        pix : tuple
1646
            Tuple of pixel coordinates.
1647
1648
        Returns
1649
        -------
1650
        coords : tuple
1651
            Tuple of map coordinates.
1652
        """
1653
        return tuple([ax.pix_to_coord(p) for ax, p in zip(self, pix)])
1654
1655
    def pix_to_idx(self, pix, clip=False):
1656
        """Convert pix to idx
1657
1658
        Parameters
1659
        ----------
1660
        pix : tuple of `~numpy.ndarray`
1661
            Pixel coordinates.
1662
        clip : bool
1663
            Choose whether to clip indices to the valid range of the
1664
            axis.  If false then indices for coordinates outside
1665
            the axi range will be set -1.
1666
1667
        Returns
1668
        -------
1669
        idx : tuple `~numpy.ndarray`
1670
            Pixel indices.
1671
        """
1672
        idx = []
1673
1674
        for pix_array, ax in zip(pix, self):
1675
            idx.append(ax.pix_to_idx(pix_array, clip=clip))
1676
1677
        return tuple(idx)
1678
1679
    def slice_by_idx(self, slices):
1680
        """Create a new geometry by slicing the non-spatial axes.
1681
1682
        Parameters
1683
        ----------
1684
        slices : dict
1685
            Dict of axes names and integers or `slice` object pairs. Contains one
1686
            element for each non-spatial dimension. For integer indexing the
1687
            corresponding axes is dropped from the map. Axes not specified in the
1688
            dict are kept unchanged.
1689
1690
        Returns
1691
        -------
1692
        geom : `~Geom`
1693
            Sliced geometry.
1694
        """
1695
        axes = []
1696
        for ax in self:
1697
            ax_slice = slices.get(ax.name, slice(None))
1698
1699
            # in the case where isinstance(ax_slice, int) the axes is dropped
1700
            if isinstance(ax_slice, slice):
1701
                ax_sliced = ax.slice(ax_slice)
1702
                axes.append(ax_sliced.copy())
1703
1704
        return self.__class__(axes=axes)
1705
1706
    def to_header(self, format="gadf"):
        """Convert axes to a FITS header.

        The header of each axis is merged in turn, using one-based axis
        indices.

        Parameters
        ----------
        format : {"gadf"}
            Header format

        Returns
        -------
        header : `~astropy.io.fits.Header`
            FITS header.
        """
        header = fits.Header()
        for idx, ax in enumerate(self, start=1):
            header.update(ax.to_header(format=format, idx=idx))
        return header
1726
1727
    def to_table(self, format="gadf"):
1728
        """Convert axes to table
1729
1730
        Parameters
1731
        ----------
1732
        format : {"gadf", "gadf-dl3", "fgst-ccube", "fgst-template", "ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa"}
1733
            Format to use.
1734
1735
        Returns
1736
        -------
1737
        table : `~astropy.table.Table`
1738
            Table with axis data
1739
        """
1740
        if format == "gadf-dl3":
1741
            tables = []
1742
1743
            for ax in self:
1744
                tables.append(ax.to_table(format=format))
1745
1746
            table = hstack(tables)
1747
        elif format in ["gadf", "fgst-ccube", "fgst-template"]:
1748
            table = Table()
1749
            table["CHANNEL"] = np.arange(np.prod(self.shape))
1750
1751
            axes_ctr = np.meshgrid(*[ax.center for ax in self])
1752
            axes_min = np.meshgrid(*[ax.edges_min for ax in self])
1753
            axes_max = np.meshgrid(*[ax.edges_max for ax in self])
1754
1755
            for idx, ax in enumerate(self):
1756
                name = ax.name.upper()
1757
1758
                if name == "ENERGY":
1759
                    colnames = ["ENERGY", "E_MIN", "E_MAX"]
1760
                else:
1761
                    colnames = [name, name + "_MIN", name + "_MAX"]
1762
1763
                for colname, v in zip(colnames, [axes_ctr, axes_min, axes_max]):
1764
                    # do not store edges for label axis
1765
                    if ax.node_type == "label" and colname != name:
1766
                        continue
1767
1768
                    table[colname] = np.ravel(v[idx])
1769
1770
                if isinstance(ax, TimeMapAxis):
1771
                    ref_dict = time_ref_to_dict(ax.reference_time)
1772
                    table.meta.update(ref_dict)
1773
1774
        elif format in ["ogip", "ogip-sherpa", "ogip", "ogip-arf"]:
1775
            energy_axis = self["energy"]
1776
            table = energy_axis.to_table(format=format)
1777
        else:
1778
            raise ValueError(f"Unsupported format: '{format}'")
1779
1780
        return table
1781
1782
    def to_table_hdu(self, format="gadf", hdu_bands=None):
1783
        """Make FITS table columns for map axes.
1784
1785
        Parameters
1786
        ----------
1787
        format : {"gadf", "fgst-ccube", "fgst-template"}
1788
            Format to use.
1789
        hdu_bands : str
1790
            Name of the bands HDU to use.
1791
1792
        Returns
1793
        -------
1794
        hdu : `~astropy.io.fits.BinTableHDU`
1795
            Bin table HDU.
1796
        """
1797
        # FIXME: Check whether convention is compatible with
1798
        #  dimensionality of geometry and simplify!!!
1799
1800
        if format in ["fgst-ccube", "ogip", "ogip-sherpa"]:
1801
            hdu_bands = "EBOUNDS"
1802
        elif format == "fgst-template":
1803
            hdu_bands = "ENERGIES"
1804
        elif format == "gadf" or format is None:
1805
            if hdu_bands is None:
1806
                hdu_bands = "BANDS"
1807
        else:
1808
            raise ValueError(f"Unknown format {format}")
1809
1810
        table = self.to_table(format=format)
1811
        header = self.to_header(format=format)
1812
        return fits.BinTableHDU(table, name=hdu_bands, header=header)
1813
1814
    @classmethod
    def from_table_hdu(cls, hdu, format="gadf"):
        """Create MapAxes from a FITS bin table HDU.

        Parameters
        ----------
        hdu : `~astropy.io.fits.BinTableHDU` or None
            Bin table HDU. When None, an empty `MapAxes` is returned.
        format : str
            Table format, forwarded unchanged to `MapAxes.from_table`.

        Returns
        -------
        axes : `MapAxes`
            Map axes object
        """
        if hdu is not None:
            return cls.from_table(Table.read(hdu), format=format)

        return cls([])
1834
1835
    @classmethod
    def from_table(cls, table, format="gadf"):
        """Create MapAxes from a table.

        Parameters
        ----------
        table : `~astropy.table.Table`
            Table holding the axes definition.
        format : {"gadf", "gadf-dl3", "gadf-sed", "lightcurve", "fgst-ccube", "fgst-template", "fgst-bexpcube", "ogip", "ogip-arf"}
            Format to use.

        Returns
        -------
        axes : `MapAxes`
            Map axes object

        Raises
        ------
        ValueError
            If ``format`` is not supported.
        """
        from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

        # Formats that support only one energy axis
        single_energy_axis_formats = (
            "fgst-ccube",
            "fgst-template",
            "fgst-bexpcube",
            "ogip",
            "ogip-arf",
        )

        if format in single_energy_axis_formats:
            axes = [MapAxis.from_table(table, format=format)]
        elif format == "gadf":
            axes = []
            # AXCOLS keywords are 1-based; this limits the max number of axes to 5
            for idx in range(5):
                if table.meta.get(f"AXCOLS{idx + 1}") is None:
                    break

                # TODO: what is good way to check whether it is a given axis type?
                # Try label axis, then time axis, then fall back to a regular axis.
                try:
                    axis = LabelMapAxis.from_table(table, format=format, idx=idx)
                except (KeyError, TypeError):
                    try:
                        axis = TimeMapAxis.from_table(table, format=format, idx=idx)
                    except (KeyError, ValueError):
                        axis = MapAxis.from_table(table, format=format, idx=idx)

                axes.append(axis)
        elif format == "gadf-dl3":
            axes = []
            # Axes whose columns are absent from the table are skipped
            for column_prefix in IRF_DL3_AXES_SPECIFICATION:
                try:
                    axis = MapAxis.from_table(
                        table, format=format, column_prefix=column_prefix
                    )
                except KeyError:
                    pass
                else:
                    axes.append(axis)
        elif format == "gadf-sed":
            axes = []
            for axis_format in ["gadf-sed-norm", "gadf-sed-energy", "gadf-sed-counts"]:
                try:
                    axis = MapAxis.from_table(table=table, format=axis_format)
                except KeyError:
                    pass
                else:
                    axes.append(axis)
        elif format == "lightcurve":
            # SED axes first, then the time axis
            axes = list(cls.from_table(table=table, format="gadf-sed"))
            axes.append(TimeMapAxis.from_table(table, format="lightcurve"))
        else:
            raise ValueError(f"Unsupported format: '{format}'")

        return cls(axes)
1905
1906
    @classmethod
    def from_default(cls, axes, n_spatial_axes=None):
        """Make a `MapAxes` from a sequence of axes and/or node arrays.

        Parameters
        ----------
        axes : sequence of axis objects or `~numpy.ndarray`, or None
            Input axes. Numpy arrays are converted with ``MapAxis(array)``.
            When None, an empty `MapAxes` is returned.
        n_spatial_axes : int, optional
            Forwarded to the `MapAxes` constructor.

        Returns
        -------
        axes : `MapAxes`
            Map axes object.
        """
        if axes is None:
            return cls([])

        converted = []
        for position, axis in enumerate(axes):
            if isinstance(axis, np.ndarray):
                axis = MapAxis(axis)

            # Unnamed axes get a positional default name. This assigns to the
            # passed axis object's ``name``, i.e. it mutates the caller's axis.
            if axis.name == "":
                axis.name = f"axis{position}"

            converted.append(axis)

        return cls(converted, n_spatial_axes=n_spatial_axes)
1923
1924
    def assert_names(self, required_names):
        """Assert required axis names and order.

        Parameters
        ----------
        required_names : list of str
            Required axis names, in the expected order.

        Raises
        ------
        ValueError
            If the number of axes differs from ``len(required_names)`` or if
            any axis name does not match the name at the same position. In
            the latter case the per-axis error is chained as ``__cause__``.
        """
        message = (
            "Incorrect axis order or names. Expected axis "
            f"order: {required_names}, got: {self.names}."
        )

        if len(self) != len(required_names):
            raise ValueError(message)

        try:
            for ax, required_name in zip(self, required_names):
                ax.assert_name(required_name)
        except ValueError as error:
            # Keep the per-axis error visible instead of discarding it
            raise ValueError(message) from error
1946
1947
    @property
    def center_coord(self):
        """Center coordinates.

        Returns
        -------
        coords : tuple
            One entry per axis: the result of ``ax.pix_to_coord`` at the
            central pixel position ``(nbin - 1) / 2``.
        """
        coords = []
        for ax in self:
            central_pix = (float(ax.nbin) - 1.0) / 2.0
            coords.append(ax.pix_to_coord(central_pix))
        return tuple(coords)
1951
1952
1953
class TimeMapAxis:
    """Class representing a time axis.

    Provides methods for transforming to/from axis and pixel coordinates.
    A time axis can represent non-contiguous sequences of non-overlapping
    time intervals.

    Time intervals must be provided in increasing order.

    Parameters
    ----------
    edges_min : `~astropy.units.Quantity`
        Array of lower edges of the time intervals. This is the time delta
        w.r.t. the reference time.
    edges_max : `~astropy.units.Quantity`
        Array of upper edges of the time intervals. This is the time delta
        w.r.t. the reference time.
    reference_time : `~astropy.time.Time`
        Reference time to use.
    name : str
        Axis name
    interp : str
        Interpolation method used to transform between axis and pixel
        coordinates.  For now only 'lin' is supported.

    Raises
    ------
    ValueError
        If the edges have no time unit or differ in shape, or are not sorted.
        Also raised if any ``edges_max <= edges_min``, or if intervals overlap.
    NotImplementedError
        If ``interp`` is not 'lin'.
    """

    node_type = "intervals"
    # Format used by the plotting helpers below: "iso" (datetime) or "mjd"
    time_format = "iso"

    def __init__(self, edges_min, edges_max, reference_time, name="time", interp="lin"):
        self._name = name

        edges_min = u.Quantity(edges_min, ndmin=1)
        edges_max = u.Quantity(edges_max, ndmin=1)

        if not edges_min.unit.is_equivalent("s"):
            raise ValueError(
                f"Time edges min must have a valid time unit, got {edges_min.unit}"
            )

        if not edges_max.unit.is_equivalent("s"):
            raise ValueError(
                f"Time edges max must have a valid time unit, got {edges_max.unit}"
            )

        if not edges_min.shape == edges_max.shape:
            raise ValueError(
                "Edges min and edges max must have the same shape,"
                f" got {edges_min.shape} and {edges_max.shape}."
            )

        if not np.all(edges_max > edges_min):
            raise ValueError("Edges max must all be larger than edge min")

        if not np.all(edges_min == np.sort(edges_min)):
            raise ValueError("Time edges min values must be sorted")

        if not np.all(edges_max == np.sort(edges_max)):
            raise ValueError("Time edges max values must be sorted")

        if interp != "lin":
            raise NotImplementedError(
                f"Non-linear scaling scheme are not supported yet, got {interp}"
            )

        self._edges_min = edges_min
        self._edges_max = edges_max
        self._reference_time = Time(reference_time)
        self._pix_offset = -0.5
        self._interp = interp

        # Each interval must start at or after the end of the previous one
        delta = edges_min[1:] - edges_max[:-1]
        if np.any(delta < 0 * u.s):
            raise ValueError("Time intervals must not overlap.")

    @property
    def is_contiguous(self):
        """Whether each interval starts exactly where the previous one ends."""
        return np.all(self.edges_min[1:] == self.edges_max[:-1])

    def to_contiguous(self):
        """Make the time axis contiguous.

        All distinct edge values (lower and upper) become consecutive bin
        edges, so gaps between intervals turn into bins of their own.

        Returns
        -------
        axis : `TimeMapAxis`
            Contiguous time axis
        """
        edges = np.unique(np.stack([self.edges_min, self.edges_max]))
        return self.__class__(
            edges_min=edges[:-1],
            edges_max=edges[1:],
            reference_time=self.reference_time,
            name=self.name,
            interp=self.interp,
        )

    @property
    def unit(self):
        """Axes unit (the unit of ``edges_max``)."""
        return self.edges_max.unit

    @property
    def interp(self):
        """Interpolation scheme."""
        return self._interp

    @property
    def reference_time(self):
        """Return reference time used for the axis."""
        return self._reference_time

    @property
    def name(self):
        """Return axis name."""
        return self._name

    @property
    def nbin(self):
        """Return number of bins in the axis."""
        return len(self.edges_min.flatten())

    @property
    def edges_min(self):
        """Return array of bin edges min values."""
        return self._edges_min

    @property
    def edges_max(self):
        """Return array of bin edges max values."""
        return self._edges_max

    @property
    def edges(self):
        """Return array of bin edges values.

        Raises
        ------
        ValueError
            If the axis is not contiguous.
        """
        if not self.is_contiguous:
            raise ValueError("Time axis is not contiguous")

        return edges_from_lo_hi(self.edges_min, self.edges_max)

    @property
    def time_min(self):
        """Return axis lower edges as Time objects."""
        return self._edges_min + self.reference_time

    @property
    def time_max(self):
        """Return axis upper edges as Time objects."""
        return self._edges_max + self.reference_time

    @property
    def time_delta(self):
        """Return axis time bin width (`~astropy.time.TimeDelta`)."""
        return self.time_max - self.time_min

    @property
    def time_mid(self):
        """Return time bin center (`~astropy.time.Time`)."""
        return self.time_min + 0.5 * self.time_delta

    @property
    def time_edges(self):
        """Time edges (requires a contiguous axis, see `edges`)."""
        return self.reference_time + self.edges

    @property
    def as_plot_xerr(self):
        """Plot x error as ``(lower, upper)`` distances from the bin centers.

        Raises
        ------
        ValueError
            If ``time_format`` is neither "iso" nor "mjd".
        """
        xn, xp = self.time_mid - self.time_min, self.time_max - self.time_mid

        if self.time_format == "iso":
            x_errn = xn.to_datetime()
            x_errp = xp.to_datetime()
        elif self.time_format == "mjd":
            x_errn = xn.to("day")
            x_errp = xp.to("day")
        else:
            raise ValueError(f"Invalid time_format: {self.time_format}")

        return x_errn, x_errp

    @property
    def as_plot_labels(self):
        """Plot labels, one ``"<t_min> - <t_max>"`` string per bin."""
        labels = []

        for t_min, t_max in self.iter_by_edges:
            label = f"{getattr(t_min, self.time_format)} - {getattr(t_max, self.time_format)}"
            labels.append(label)

        return labels

    @property
    def as_plot_edges(self):
        """Plot edges.

        Raises
        ------
        ValueError
            If ``time_format`` is neither "iso" nor "mjd".
        """
        if self.time_format == "iso":
            edges = self.time_edges.to_datetime()
        elif self.time_format == "mjd":
            edges = self.time_edges.mjd * u.day
        else:
            raise ValueError(f"Invalid time_format: {self.time_format}")

        return edges

    @property
    def as_plot_center(self):
        """Plot center.

        Raises
        ------
        ValueError
            If ``time_format`` is neither "iso" nor "mjd".
        """
        if self.time_format == "iso":
            center = self.time_mid.datetime
        elif self.time_format == "mjd":
            center = self.time_mid.mjd * u.day
        else:
            # Previously fell through and failed with an unbound local
            raise ValueError(f"Invalid time_format: {self.time_format}")

        return center

    def format_plot_xaxis(self, ax):
        """Format plot axis.

        Sets the x label and, for the "iso" time format, a date formatter
        with rotated tick labels.

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            Formatted plot axis
        """
        import matplotlib.pyplot as plt
        from matplotlib.dates import DateFormatter

        xlabel = self.name.capitalize() + f" [{self.time_format}]"

        ax.set_xlabel(xlabel)

        if self.time_format == "iso":
            ax.xaxis.set_major_formatter(DateFormatter("%Y-%m-%d %H:%M:%S"))
            plt.setp(
                ax.xaxis.get_majorticklabels(),
                rotation=30,
                ha="right",
                rotation_mode="anchor",
            )

        return ax

    def assert_name(self, required_name):
        """Assert axis name if a specific one is required.

        Parameters
        ----------
        required_name : str
            Required axis name.

        Raises
        ------
        ValueError
            If the axis name differs from ``required_name``.
        """
        if self.name != required_name:
            raise ValueError(
                "Unexpected axis name,"
                f' expected "{required_name}", got: "{self.name}"'
            )

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented

        if self._edges_min.shape != other._edges_min.shape:
            return False

        # This will test equality at microsec level.
        delta_min = self.time_min - other.time_min
        delta_max = self.time_max - other.time_max

        return (
            np.allclose(delta_min.to_value("s"), 0.0, atol=1e-6)
            and np.allclose(delta_max.to_value("s"), 0.0, atol=1e-6)
            and self._interp == other._interp
            and self.name.upper() == other.name.upper()
        )

    def __ne__(self, other):
        # Propagate NotImplemented so Python can try the reflected operation;
        # ``not NotImplemented`` would wrongly evaluate to False.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # NOTE(review): identity-based hash, so axes comparing equal via
        # __eq__ generally hash differently.
        return id(self)

    def is_aligned(self, other, atol=2e-2):
        """Not implemented for time axes."""
        raise NotImplementedError

    @property
    def iter_by_edges(self):
        """Iterate by intervals defined by the edges, yielding (time_min, time_max)."""
        for time_min, time_max in zip(self.time_min, self.time_max):
            yield (time_min, time_max)

    def coord_to_idx(self, coord, **kwargs):
        """Transform from axis time coordinate to bin index.

        A time ``t`` falls into bin ``i`` if ``time_min[i] < t <= time_max[i]``.
        Indices of time values falling outside all time bins are set to
        ``INVALID_INDEX.int``.

        Parameters
        ----------
        coord : `~astropy.time.Time` or `~astropy.units.Quantity`
            Array of axis coordinate values. The quantity is assumed
            to be relative to the reference time.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.
        """
        if isinstance(coord, u.Quantity):
            coord = self.reference_time + coord

        # Broadcast each time against all intervals along a trailing axis
        time = Time(coord[..., np.newaxis])
        delta_plus = (time - self.time_min).value > 0.0
        delta_minus = (time - self.time_max).value <= 0.0
        mask = np.logical_and(delta_plus, delta_minus)

        idx = np.asanyarray(np.argmax(mask, axis=-1))
        idx[~np.any(mask, axis=-1)] = INVALID_INDEX.int
        return idx

    def coord_to_pix(self, coord, **kwargs):
        """Transform from time coordinate to pixel position.

        Pixels of time values falling outside all time bins are derived
        from ``INVALID_INDEX.float``.

        Parameters
        ----------
        coord : `~astropy.time.Time` or `~astropy.units.Quantity`
            Array of axis coordinate values. A quantity is interpreted
            relative to the reference time.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel positions.
        """
        if isinstance(coord, u.Quantity):
            coord = self.reference_time + coord

        idx = np.atleast_1d(self.coord_to_idx(coord))

        valid_pix = idx != INVALID_INDEX.int
        pix = np.atleast_1d(idx).astype("float")

        # TODO: is there the equivalent of np.atleast1d for astropy.time.Time?
        if coord.shape == ():
            coord = coord.reshape((1,))

        relative_time = coord[valid_pix] - self.reference_time

        # Fractional position of each valid time inside its own interval
        scale = interpolation_scale(self._interp)
        valid_idx = idx[valid_pix]
        s_min = scale(self._edges_min[valid_idx])
        s_max = scale(self._edges_max[valid_idx])
        s_coord = scale(relative_time.to(self._edges_min.unit))

        pix[valid_pix] += (s_coord - s_min) / (s_max - s_min)
        pix[~valid_pix] = INVALID_INDEX.float
        # Shift so that integer pixel values are bin centers
        return pix - 0.5

    @staticmethod
    def pix_to_idx(pix, clip=False):
        """Return ``pix`` unchanged (``clip`` is ignored)."""
        return pix

    @property
    def center(self):
        """Return `~astropy.units.Quantity` at interval centers."""
        return self.edges_min + 0.5 * self.bin_width

    @property
    def bin_width(self):
        """Return time interval width (`~astropy.units.Quantity`, in hours)."""
        return self.time_delta.to("h")

    def __repr__(self):
        str_ = self.__class__.__name__ + "\n"
        str_ += "-" * len(self.__class__.__name__) + "\n\n"
        fmt = "\t{:<14s} : {:<10s}\n"
        str_ += fmt.format("name", self.name)
        str_ += fmt.format("nbins", str(self.nbin))
        str_ += fmt.format("reference time", self.reference_time.iso)
        str_ += fmt.format("scale", self.reference_time.scale)
        str_ += fmt.format("time min.", self.time_min.min().iso)
        str_ += fmt.format("time max.", self.time_max.max().iso)
        str_ += fmt.format("total time", np.sum(self.bin_width))
        return str_.expandtabs(tabsize=2)

    def upsample(self):
        """Not implemented for time axes."""
        raise NotImplementedError

    def downsample(self):
        """Not implemented for time axes."""
        raise NotImplementedError

    def _init_copy(self, **kwargs):
        """Init map axis instance by copying missing init arguments from self."""
        argnames = inspect.getfullargspec(self.__init__).args
        argnames.remove("self")

        # Every __init__ argument ``x`` is stored as ``self._x``
        for arg in argnames:
            value = getattr(self, "_" + arg)
            kwargs.setdefault(arg, copy.deepcopy(value))

        return self.__class__(**kwargs)

    def copy(self, **kwargs):
        """Copy `TimeMapAxis` instance and overwrite given attributes.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to overwrite in the map axis constructor.

        Returns
        -------
        copy : `TimeMapAxis`
            Copied map axis.
        """
        return self._init_copy(**kwargs)

    def slice(self, idx):
        """Create a new axis object by extracting a slice from this axis.

        Parameters
        ----------
        idx : slice
            Slice object selecting a subselection of the axis.

        Returns
        -------
        axis : `~TimeMapAxis`
            Sliced axis object.
        """
        return self.__class__(
            self._edges_min[idx].copy(),
            self._edges_max[idx].copy(),
            self.reference_time,
            interp=self._interp,
            name=self.name,
        )

    def squash(self):
        """Create a new axis object by squashing the axis into one bin.

        The single bin spans from the first lower edge to the last upper edge.

        Returns
        -------
        axis : `~TimeMapAxis`
            Squashed axis object.
        """
        return self.__class__(
            self._edges_min[0],
            self._edges_max[-1],
            self.reference_time,
            interp=self._interp,
            name=self._name,
        )

    # TODO: if we are to allow log or sqrt bins the reference time should always
    #  be strictly lower than all times
    #  Should we define a mechanism to ensure this is always correct?
    @classmethod
    def from_time_edges(cls, time_min, time_max, unit="d", interp="lin", name="time"):
        """Create TimeMapAxis from the time interval edges defined as `~astropy.time.Time`.

        The reference time is defined as the lower edge of the first interval.

        Parameters
        ----------
        time_min : `~astropy.time.Time`
            Array of lower edge times.
        time_max : `~astropy.time.Time`
            Array of upper edge times.
        unit : `~astropy.units.Unit` or str
            The unit to convert the edges to. Default is 'd' (day).
        interp : str
            Interpolation method used to transform between axis and pixel
            coordinates.  For now only 'lin' is supported.
        name : str
            Axis name

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.
        """
        unit = u.Unit(unit)
        reference_time = time_min[0]
        edges_min = time_min - reference_time
        edges_max = time_max - reference_time

        return cls(
            edges_min.to(unit),
            edges_max.to(unit),
            reference_time,
            interp=interp,
            name=name,
        )

    # TODO: how configurable should that be? column names?
    @classmethod
    def from_table(cls, table, format="gadf", idx=0):
        """Create time map axis from table

        Parameters
        ----------
        table : `~astropy.table.Table`
            Bin table HDU
        format : {"gadf", "fermi-fgl", "lightcurve"}
            Format to use.
        idx : int
            Zero-based axis index; for "gadf" the column names are read from
            the ``AXCOLS{idx + 1}`` keyword of ``table.meta``.

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.

        Raises
        ------
        ValueError
            If ``format`` is not supported.
        """
        if format == "gadf":
            axcols = table.meta.get("AXCOLS{}".format(idx + 1))
            colnames = axcols.split(",")
            name = colnames[0].replace("_MIN", "").lower()
            reference_time = time_ref_from_dict(table.meta)
            # Tables may repeat each edge per row of other axes; keep unique values
            edges_min = np.unique(table[colnames[0]].quantity)
            edges_max = np.unique(table[colnames[1]].quantity)
        elif format == "fermi-fgl":
            meta = table.meta.copy()
            # Fortran-style exponent "D-4" is not parseable, convert to "e-4"
            meta["MJDREFF"] = str(meta["MJDREFF"]).replace("D-4", "e-4")
            reference_time = time_ref_from_dict(meta=meta)
            name = "time"
            edges_min = table["Hist_Start"][:-1]
            edges_max = table["Hist_Start"][1:]
        elif format == "lightcurve":
            # TODO: is this a good format? It just supports mjd...
            name = "time"
            scale = table.meta.get("TIMESYS", "utc")
            time_min = Time(table["time_min"].data, format="mjd", scale=scale)
            time_max = Time(table["time_max"].data, format="mjd", scale=scale)
            reference_time = Time("2001-01-01T00:00:00")
            reference_time.format = "mjd"
            edges_min = (time_min - reference_time).to("s")
            edges_max = (time_max - reference_time).to("s")
        else:
            raise ValueError(f"Not a supported format: {format}")

        return cls(
            edges_min=edges_min,
            edges_max=edges_max,
            reference_time=reference_time,
            name=name,
        )

    @classmethod
    def from_gti(cls, gti, name="time"):
        """Create a time axis from an input GTI.

        Parameters
        ----------
        gti : `GTI`
            GTI table
        name : str
            Axis name

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis with edges in seconds w.r.t. ``gti.time_ref``.
        """
        tmin = gti.time_start - gti.time_ref
        tmax = gti.time_stop - gti.time_ref

        return cls(
            edges_min=tmin.to("s"),
            edges_max=tmax.to("s"),
            reference_time=gti.time_ref,
            name=name,
        )

    @classmethod
    def from_time_bounds(cls, time_min, time_max, nbin, unit="d", name="time"):
        """Create linearly spaced time axis from bounds.

        Parameters
        ----------
        time_min : `~astropy.time.Time`
            Lower bound
        time_max : `~astropy.time.Time`
            Upper bound
        nbin : int
            Number of bins
        unit : `~astropy.units.Unit` or str
            The unit to convert the edges to. Default is 'd' (day).
        name : str
            Name of the axis.

        Returns
        -------
        axis : `TimeMapAxis`
            Contiguous time map axis with ``nbin`` equal-width bins.
        """
        delta = time_max - time_min
        time_edges = time_min + delta * np.linspace(0, 1, nbin + 1)
        return cls.from_time_edges(
            time_min=time_edges[:-1],
            time_max=time_edges[1:],
            interp="lin",
            unit=unit,
            name=name,
        )

    def to_header(self, format="gadf", idx=0):
        """Create FITS header

        Parameters
        ----------
        format : {"gadf"}
            Format specification
        idx : int
            Column index of the axis.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.

        Raises
        ------
        ValueError
            If ``format`` is not supported.
        """
        header = fits.Header()

        if format == "gadf":
            key = f"AXCOLS{idx}"
            name = self.name.upper()
            header[key] = f"{name}_MIN,{name}_MAX"
            key_interp = f"INTERP{idx}"
            header[key_interp] = self.interp

            ref_dict = time_ref_to_dict(self.reference_time)
            header.update(ref_dict)
        else:
            raise ValueError(f"Unknown format {format}")

        return header
2582
2583
class LabelMapAxis:
    """Map axis using labels

    Parameters
    ----------
    labels : list of str
        Labels to be used for the axis nodes.
    name : str
        Name of the axis.

    """

    node_type = "label"

    def __init__(self, labels, name=""):
        unique_labels = set(labels)

        # Labels identify bins, so duplicates would make lookups ambiguous.
        if not len(unique_labels) == len(labels):
            raise ValueError("Node labels must be unique")

        self._labels = np.array(labels)
        self._name = name

    @property
    def unit(self):
        """Unit of the axis (dimensionless for a label axis)."""
        return u.Unit("")

    @property
    def name(self):
        """Name of the axis"""
        return self._name

    def assert_name(self, required_name):
        """Assert axis name if a specific one is required.

        Parameters
        ----------
        required_name : str
            Required name of the axis.

        Raises
        ------
        ValueError
            If the axis name differs from ``required_name``.
        """
        if self.name != required_name:
            raise ValueError(
                "Unexpected axis name,"
                f' expected "{required_name}", got: "{self.name}"'
            )

    @property
    def nbin(self):
        """Number of bins"""
        return len(self._labels)

    def pix_to_coord(self, pix):
        """Transform from pixel to axis coordinates.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.

        Returns
        -------
        coord : `~numpy.ndarray`
            Array of axis coordinate values (labels).
        """
        # Pixels are rounded to the nearest bin index.
        # NOTE(review): negative indices wrap around (numpy indexing) rather
        # than raising; indices >= nbin raise IndexError.
        idx = np.round(pix).astype(int)
        return self._labels[idx]

    def coord_to_idx(self, coord, **kwargs):
        """Transform labels to indices

        If the label is not present an error is raised.

        Parameters
        ----------
        coord : array-like of str
            Array of axis label values.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.

        Raises
        ------
        ValueError
            If any of the given labels is not a label of this axis.
        """
        # Add a trailing axis so each coord is compared against all labels.
        coord = np.array(coord)[..., np.newaxis]
        is_equal = coord == self._labels

        if not np.all(np.any(is_equal, axis=-1)):
            label = coord[~np.any(is_equal, axis=-1)]
            raise ValueError(f"Not a valid label: {label}")

        # Labels are unique, so exactly one entry per coord is True.
        return np.argmax(is_equal, axis=-1)

    def coord_to_pix(self, coord):
        """Transform from axis labels to pixel coordinates.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis label values.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.
        """
        return self.coord_to_idx(coord).astype("float")

    def pix_to_idx(self, pix, clip=False):
        """Convert pix to idx

        Parameters
        ----------
        pix : tuple of `~numpy.ndarray`
            Pixel coordinates.
        clip : bool
            Choose whether to clip indices to the valid range of the
            axis.  If false then indices for coordinates outside
            the axis range will be set -1.

        Returns
        -------
        idx : tuple `~numpy.ndarray`
            Pixel indices.
        """
        if clip:
            idx = np.clip(pix, 0, self.nbin - 1)
        else:
            condition = (pix < 0) | (pix >= self.nbin)
            idx = np.where(condition, -1, pix)

        return idx

    @property
    def center(self):
        """Center of the label axis (the labels themselves)."""
        return self._labels

    @property
    def edges(self):
        """Edges of the label axis (not defined, always raises ValueError)."""
        raise ValueError("A LabelMapAxis does not define edges")

    @property
    def edges_min(self):
        """Lower edges of the label axis (the labels themselves)."""
        return self._labels

    @property
    def edges_max(self):
        """Upper edges of the label axis (the labels themselves)."""
        return self._labels

    @property
    def bin_width(self):
        """Bin width is unity"""
        return np.ones(self.nbin)

    @property
    def as_plot_xerr(self):
        """Half bin width (0.5) per bin, for plotting x error bars."""
        return 0.5 * np.ones(self.nbin)

    @property
    def as_plot_labels(self):
        """Plot labels as a list"""
        return self._labels.tolist()

    @property
    def as_plot_center(self):
        """Integer bin positions 0, ..., nbin - 1 used for plotting."""
        return np.arange(self.nbin)

    @property
    def as_plot_edges(self):
        """Bin edge positions for plotting, half a bin around each center."""
        return np.arange(self.nbin + 1) - 0.5

    def format_plot_xaxis(self, ax):
        """Format plot axis.

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format.

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            Formatted plot axis.
        """
        ax.set_xticks(self.as_plot_center)
        ax.set_xticklabels(
            self.as_plot_labels,
            rotation=30,
            ha="right",
            rotation_mode="anchor",
        )
        return ax

    def to_header(self, format="gadf", idx=0):
        """Create FITS header

        Parameters
        ----------
        format : {"gadf"}
            Format specification. Only "gadf" is supported.
        idx : int
            Column index of the axis, used as suffix of the ``AXCOLS`` key.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.

        Raises
        ------
        ValueError
            If ``format`` is not "gadf".
        """
        header = fits.Header()

        if format == "gadf":
            key = f"AXCOLS{idx}"
            header[key] = self.name.upper()
        else:
            raise ValueError(f"Unknown format {format}")

        return header

    # TODO: how configurable should that be? column names?
    @classmethod
    def from_table(cls, table, format="gadf", idx=0):
        """Create label map axis from table

        Parameters
        ----------
        table : `~astropy.table.Table`
            Bin table HDU
        format : {"gadf"}
            Format to use.
        idx : int
            Zero-based index of the axis; the column name is read from the
            ``AXCOLS{idx + 1}`` meta keyword.

        Returns
        -------
        axis : `LabelMapAxis`
            Label map axis.

        Raises
        ------
        TypeError
            If the column does not have a string dtype.
        ValueError
            If ``format`` is not supported.
        """
        if format == "gadf":
            colname = table.meta.get("AXCOLS{}".format(idx + 1))
            column = table[colname]
            if not np.issubdtype(column.dtype, np.str_):
                raise TypeError(f"Not a valid dtype for label axis: '{column.dtype}'")
            # NOTE: np.unique returns the unique labels in sorted order.
            labels = np.unique(column.data)
        else:
            raise ValueError(f"Not a supported format: {format}")

        return cls(labels=labels, name=colname.lower())

    def __repr__(self):
        str_ = self.__class__.__name__ + "\n"
        str_ += "-" * len(self.__class__.__name__) + "\n\n"
        fmt = "\t{:<10s} : {:<10s}\n"
        str_ += fmt.format("name", self.name)
        str_ += fmt.format("nbins", str(self.nbin))
        str_ += fmt.format("node type", self.node_type)
        str_ += fmt.format("labels", "{0}".format(list(self._labels)))
        return str_.expandtabs(tabsize=2)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented

        name_equal = self.name.upper() == other.name.upper()
        # array_equal handles axes with a different number of labels, where an
        # elementwise `==` would fail to broadcast.
        labels_equal = np.array_equal(self.center, other.center)
        return name_equal and labels_equal

    def __ne__(self, other):
        result = self.__eq__(other)
        # Propagate NotImplemented so Python can fall back to the reflected
        # operation instead of treating it as a truth value.
        if result is NotImplemented:
            return result
        return not result

    # TODO: could create sub-labels here using dashes like "label-1-a", etc.
    def upsample(self, *args, **kwargs):
        """Upsample axis (not supported, always raises NotImplementedError)."""
        raise NotImplementedError("Upsampling a LabelMapAxis is not supported")

    # TODO: could merge labels here like "label-1-label2", etc.
    def downsample(self, *args, **kwargs):
        """Downsample axis (not supported, always raises NotImplementedError)."""
        raise NotImplementedError("Downsampling a LabelMapAxis is not supported")

    # TODO: could merge labels here like "label-1-label2", etc.
    def resample(self, *args, **kwargs):
        """Resample axis (not supported, always raises NotImplementedError)."""
        raise NotImplementedError("Resampling a LabelMapAxis is not supported")

    # TODO: could create new labels here like "label-10-a"
    def pad(self, *args, **kwargs):
        """Pad axis (not supported, always raises NotImplementedError)."""
        raise NotImplementedError("Padding a LabelMapAxis is not supported")

    def copy(self):
        """Copy axis"""
        return copy.deepcopy(self)

    def slice(self, idx):
        """Create a new axis object by extracting a slice from this axis.

        Parameters
        ----------
        idx : slice
            Slice object selecting a subselection of the axis.

        Returns
        -------
        axis : `~LabelMapAxis`
            Sliced axis object.
        """
        return self.__class__(
            labels=self._labels[idx],
            name=self.name,
        )
2897