gammapy.maps.axes   F
last analyzed

Complexity

Total Complexity 408

Size/Duplication

Total Lines 3121
Duplicated Lines 2.24 %

Importance

Changes 0
Metric Value
eloc 1385
dl 70
loc 3121
rs 0.8
c 0
b 0
f 0
wmc 408

1 Function

Rating   Name   Duplication   Size   Complexity  
A flat_if_equal() 0 5 3

179 Methods

Rating   Name   Duplication   Size   Complexity  
A AxisCoordInterpolator.coord_to_pix() 0 6 1
A AxisCoordInterpolator.__init__() 0 10 2
A AxisCoordInterpolator.pix_to_coord() 0 6 1
A LabelMapAxis.edges_min() 0 4 1
A TimeMapAxis.iter_by_edges() 0 5 2
A TimeMapAxis.slice() 0 19 1
A LabelMapAxis.edges() 0 4 1
A MapAxes.__ne__() 0 2 1
A TimeMapAxis.__hash__() 0 2 1
A LabelMapAxis.pix_to_coord() 0 15 1
A TimeMapAxis.upsample() 0 2 1
A MapAxes.center_coord() 0 4 1
A TimeMapAxis.nbin() 0 4 1
A TimeMapAxis.coord_to_pix() 0 39 3
A TimeMapAxis.interp() 0 4 1
A TimeMapAxis.from_time_bounds() 0 23 1
A TimeMapAxis.copy() 0 14 1
A TimeMapAxis.to_header() 0 30 2
A TimeMapAxis.squash() 0 14 1
A LabelMapAxis.to_header() 0 24 2
A LabelMapAxis.nbin() 0 4 1
A TimeMapAxis.center() 0 4 1
A LabelMapAxis.name() 0 4 1
A TimeMapAxis.time_edges() 0 4 1
A LabelMapAxis.unit() 0 4 1
A TimeMapAxis.format_plot_xaxis() 0 31 2
A TimeMapAxis.as_plot_edges() 0 11 3
A TimeMapAxis.pix_to_idx() 0 3 1
A LabelMapAxis.pix_to_idx() 24 24 2
A TimeMapAxis.from_table() 0 48 4
A TimeMapAxis.downsample() 0 2 1
A LabelMapAxis.as_plot_center() 0 4 1
A LabelMapAxis.__eq__() 0 5 2
A TimeMapAxis.bounds() 0 4 1
A TimeMapAxis.from_gti() 0 25 1
C TimeMapAxis.__init__() 0 45 9
A MapAxes.is_allclose() 0 19 2
A TimeMapAxis.from_time_edges() 0 36 1
A LabelMapAxis.slice() 0 16 1
A TimeMapAxis.edges_min() 0 4 1
A LabelMapAxis.assert_name() 0 11 2
A TimeMapAxis.time_delta() 0 4 1
A LabelMapAxis.bin_width() 0 4 1
A MapAxes.__eq__() 0 5 2
A TimeMapAxis.unit() 0 4 1
A TimeMapAxis.edges() 0 7 2
A TimeMapAxis.as_plot_xerr() 0 15 3
A LabelMapAxis.upsample() 0 3 1
A LabelMapAxis.center() 0 4 1
A LabelMapAxis.edges_max() 0 4 1
A TimeMapAxis.__ne__() 0 2 1
A TimeMapAxis.time_mid() 0 4 1
A LabelMapAxis.downsample() 0 3 1
A LabelMapAxis.format_plot_xaxis() 0 21 1
A LabelMapAxis.coord_to_idx() 0 23 2
A TimeMapAxis._init_copy() 10 10 2
A TimeMapAxis.assert_name() 0 11 2
A TimeMapAxis.reference_time() 0 4 1
A TimeMapAxis.is_contiguous() 0 4 1
A TimeMapAxis.coord_to_idx() 0 28 2
A TimeMapAxis.time_min() 0 4 1
A TimeMapAxis.as_plot_labels() 0 10 2
A LabelMapAxis.is_allclose() 0 19 2
A LabelMapAxis.from_table() 0 26 3
A LabelMapAxis.resample() 0 3 1
A LabelMapAxis.__init__() 0 8 2
A TimeMapAxis.is_allclose() 0 30 3
A TimeMapAxis.edges_max() 0 4 1
A TimeMapAxis.__repr__() 0 12 1
A LabelMapAxis.copy() 0 3 1
A MapAxes.copy() 0 3 1
A LabelMapAxis.as_plot_xerr() 0 4 1
A LabelMapAxis.__repr__() 0 9 1
A TimeMapAxis.time_max() 0 4 1
A TimeMapAxis.to_contiguous() 0 15 1
A TimeMapAxis.name() 0 4 1
A TimeMapAxis.is_aligned() 0 2 1
A LabelMapAxis.coord_to_pix() 0 14 1
A LabelMapAxis.as_plot_labels() 0 4 1
A TimeMapAxis.as_plot_center() 0 9 3
A TimeMapAxis.bin_width() 0 4 1
A LabelMapAxis.as_plot_edges() 0 4 1
A LabelMapAxis.pad() 0 3 1
A TimeMapAxis.time_bounds() 0 5 1
A LabelMapAxis.__ne__() 0 2 1
A TimeMapAxis.__eq__() 0 5 2
B MapAxes.__getitem__() 0 19 8
A MapAxis.from_edges() 0 20 2
A MapAxis._init_copy() 10 10 2
A MapAxis.upsample() 0 31 2
A MapAxis._transform() 0 4 1
A MapAxes.upsample() 0 24 3
A MapAxis.from_stack() 0 23 2
A MapAxes.iter_with_reshape() 0 16 3
A MapAxis.rename() 0 14 1
A MapAxes.__add__() 0 2 1
A MapAxes.names() 0 4 1
F MapAxis.from_table() 0 135 26
A MapAxes.downsample() 0 25 3
A MapAxis.edges() 0 5 1
A MapAxes.shape() 0 4 1
A MapAxes.get_coord() 0 25 4
A MapAxis.slice() 0 26 2
A MapAxis.as_plot_labels() 0 11 2
A MapAxis.from_table_hdu() 0 20 1
A MapAxes.from_default() 0 17 5
A MapAxes.drop() 0 20 3
A MapAxis.unit() 0 4 1
A MapAxes.__init__() 0 12 3
A MapAxis.round() 0 22 2
A MapAxes.slice_by_idx() 0 26 3
A MapAxis.nbin() 0 4 1
A MapAxis.node_type() 0 4 1
B MapAxis.group_table() 0 61 5
B MapAxis.from_bounds() 0 42 6
A MapAxis.edges_min() 0 4 1
A MapAxis.is_energy_axis() 0 3 1
A MapAxes.squash() 0 21 3
A MapAxes.is_flat() 0 5 1
A MapAxes.pix_to_idx() 0 23 2
B MapAxes.to_table_hdu() 0 31 6
A MapAxes.pad() 0 24 3
A MapAxis.to_table_hdu() 0 30 3
A MapAxes.to_header() 0 20 2
A MapAxis.squash() 0 17 1
A MapAxis.center() 0 5 1
A MapAxis.__ne__() 0 2 1
A MapAxes.pix_to_coord() 0 14 1
A MapAxis.use_center_as_plot_labels() 0 7 1
A MapAxes.is_unidimensional() 0 6 1
A MapAxes.resample() 0 38 3
A MapAxis.edges_max() 0 4 1
A MapAxes.index() 0 3 1
A MapAxis.as_plot_scale() 0 6 1
A MapAxes.primary_axis() 0 12 1
F MapAxes.from_table() 0 73 15
A MapAxis.is_allclose() 0 27 4
A MapAxes.reverse() 0 4 1
A MapAxis.is_aligned() 0 23 1
F MapAxis.to_table() 0 73 15
A MapAxis.__hash__() 0 2 1
A MapAxes.assert_names() 0 22 4
A MapAxis.pix_to_idx() 24 24 2
A MapAxis.from_nodes() 0 21 2
A MapAxes.from_table_hdu() 0 20 2
C MapAxes.to_table() 0 54 11
A MapAxis.pad() 0 27 3
A MapAxis.coord_to_idx() 0 30 3
B MapAxis.from_energy_bounds() 0 63 6
A MapAxis.iter_by_edges() 0 5 2
A MapAxes.coord_to_idx() 0 14 1
A MapAxes.bin_volume() 0 14 2
A MapAxis.bounds() 0 8 2
A MapAxes.replace() 0 22 3
A MapAxis.from_energy_edges() 0 34 4
A MapAxis.copy() 0 14 1
A MapAxis.to_node_type() 0 26 3
A MapAxes.__len__() 0 2 1
C MapAxis.to_header() 0 50 9
A MapAxis.format_plot_yaxis() 0 22 1
A MapAxis.nbin_per_decade() 0 13 3
A MapAxis.downsample() 0 39 4
C MapAxis.__init__() 0 40 11
A MapAxis.as_plot_xerr() 0 6 1
A MapAxis.__repr__() 0 13 2
A MapAxis.format_plot_xaxis() 0 24 2
A MapAxis.name() 0 4 1
A MapAxis.interp() 0 4 1
A MapAxis.as_plot_center() 0 4 1
A MapAxis.append() 0 35 5
A MapAxis.bin_width() 0 4 1
A MapAxis.as_plot_edges() 0 4 1
A MapAxes.index_data() 0 15 1
A MapAxis.coord_to_pix() 0 16 1
A MapAxis.assert_name() 0 11 2
A MapAxes.coord_to_pix() 0 14 1
A MapAxis.pix_to_coord() 0 16 1
A MapAxis.__eq__() 0 5 2
A MapAxes.rename_axes() 0 23 4

How to fix   Duplicated Code    Complexity   

Duplicated Code

Duplicate code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.

Common duplication problems, and corresponding solutions are:

Complexity

 Tip: Before tackling complexity, make sure you eliminate any duplication first; this can often reduce the size of classes significantly.

Complex modules like gammapy.maps.axes often do a lot of different things. To break such a module down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
import copy
3
import inspect
4
from collections.abc import Sequence
5
import numpy as np
6
import scipy
7
import astropy.units as u
8
from astropy.io import fits
9
from astropy.table import Column, Table, hstack
10
from astropy.time import Time
11
from astropy.utils import lazyproperty
12
import matplotlib.pyplot as plt
13
from gammapy.utils.interpolation import interpolation_scale
14
from gammapy.utils.time import time_ref_from_dict, time_ref_to_dict
15
from .utils import INVALID_INDEX, edges_from_lo_hi
16
17
__all__ = ["MapAxes", "MapAxis", "TimeMapAxis", "LabelMapAxis"]
18
19
20
def flat_if_equal(array):
    """Collapse a 2D array to its first row if all rows are identical.

    Parameters
    ----------
    array : `~numpy.ndarray`
        Input array.

    Returns
    -------
    array : `~numpy.ndarray`
        ``array[0]`` if ``array`` is 2D, non-empty and every row equals the
        first row; otherwise ``array`` itself, unchanged.
    """
    # An empty 2D array has no first row to compare against, so it is
    # returned as-is instead of failing on ``array[0]``.
    if array.ndim == 2 and len(array) > 0 and np.all(array == array[0]):
        return array[0]
    return array
25
26
27
class AxisCoordInterpolator:
    """Axis coord interpolator.

    Maps between axis coordinates and pixel coordinates by piecewise
    interpolation in the scaled space returned by ``interpolation_scale``.

    Parameters
    ----------
    edges : `~numpy.ndarray`
        Node values; node ``i`` is assigned pixel coordinate ``i``.
    interp : str
        Interpolation scale name passed to ``interpolation_scale``.
        Default is 'lin'.
    """

    def __init__(self, edges, interp="lin"):
        self.scale = interpolation_scale(interp)
        # Node coordinates in scaled space, and their pixel positions 0, 1, ...
        self.x = self.scale(edges)
        self.y = np.arange(len(edges), dtype=float)
        self.fill_value = "extrapolate"

        # Linear interpolation needs at least two nodes; with a single node
        # fall back to a zeroth-order (piecewise constant) interpolation.
        if len(edges) == 1:
            self.kind = 0
        else:
            self.kind = 1

    def coord_to_pix(self, coord):
        """Coord to pix.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Axis coordinate values (unscaled).

        Returns
        -------
        pix : `~numpy.ndarray`
            Pixel coordinates; values outside the node range are extrapolated.
        """
        interp_fn = scipy.interpolate.interp1d(
            x=self.x, y=self.y, kind=self.kind, fill_value=self.fill_value
        )
        return interp_fn(self.scale(coord))

    def pix_to_coord(self, pix):
        """Pix to coord.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Pixel coordinate values.

        Returns
        -------
        coord : `~numpy.ndarray`
            Axis coordinates, mapped back from scaled space with
            ``self.scale.inverse``; out-of-range pixels are extrapolated.
        """
        interp_fn = scipy.interpolate.interp1d(
            x=self.y, y=self.x, kind=self.kind, fill_value=self.fill_value
        )
        return self.scale.inverse(interp_fn(pix))
54
55
56
# Human-readable plot labels keyed by axis name; unknown names fall back to
# a capitalized version of the name at the call sites.
PLOT_AXIS_LABEL = dict(
    energy="Energy",
    energy_true="True Energy",
    offset="FoV Offset",
    rad="Source Offset",
    migra="Energy / True Energy",
    fov_lon="FoV Lon.",
    fov_lat="FoV Lat.",
    time="Time",
)

# Template combining a quantity label with its unit, e.g. "Energy [TeV]".
DEFAULT_LABEL_TEMPLATE = "{quantity} [{unit}]"
68
69
70
class MapAxis:
71
    """Class representing an axis of a map.
72
73
    Provides methods for
74
    transforming to/from axis and pixel coordinates.  An axis is
75
    defined by a sequence of node values that lie at the center of
76
    each bin.  The pixel coordinate at each node is equal to its index
77
    in the node array (0, 1, ..).  Bin edges are offset by 0.5 in
78
    pixel coordinates from the nodes such that the lower/upper edge of
79
    the first bin is (-0.5,0.5).
80
81
    Parameters
82
    ----------
83
    nodes : `~numpy.ndarray` or `~astropy.units.Quantity`
84
        Array of node values.  These will be interpreted as either bin
85
        edges or centers according to ``node_type``.
86
    interp : str
87
        Interpolation method used to transform between axis and pixel
88
        coordinates.  Valid options are 'log', 'lin', and 'sqrt'.
89
    name : str
90
        Axis name
91
    node_type : str
92
        Flag indicating whether coordinate nodes correspond to pixel
93
        edges (node_type = 'edges') or pixel centers (node_type =
94
        'center').  'center' should be used where the map values are
95
        defined at a specific coordinate (e.g. differential
96
        quantities). 'edges' should be used where map values are
97
        defined by an integral over coordinate intervals (e.g. a
98
        counts histogram).
99
    unit : str
100
        String specifying the data units.
101
    """
102
103
    # TODO: Cache an interpolation object?
104
    def __init__(self, nodes, interp="lin", name="", node_type="edges", unit=""):
        # Validate the name and the node values before any state is stored.
        if not isinstance(name, str):
            raise TypeError(f"Name must be a string, got: {type(name)!r}")

        if len(nodes) != len(np.unique(nodes)):
            raise ValueError("MapAxis: node values must be unique")

        # Nodes may be sorted in ascending or descending order. Use a logical
        # ``not``: bitwise ``~`` only negates numpy booleans correctly and
        # turns a plain Python bool into a truthy -1/-2.
        sorted_nodes = np.sort(nodes)
        if not (
            np.all(nodes == sorted_nodes) or np.all(nodes[::-1] == sorted_nodes)
        ):
            raise ValueError("MapAxis: node values must be sorted")

        # A Quantity carries its own unit, which takes precedence over ``unit``.
        if isinstance(nodes, u.Quantity):
            unit = nodes.unit if nodes.unit is not None else ""
            nodes = nodes.value
        else:
            nodes = np.array(nodes)

        self._name = name
        self._unit = u.Unit(unit)
        self._nodes = nodes.astype(float)
        self._node_type = node_type
        self._interp = interp

        # Negative nodes are rejected for every interpolation other than 'lin'.
        if (self._nodes < 0).any() and interp != "lin":
            raise ValueError(
                f"Interpolation scaling {interp!r} only support for positive node values."
            )

        # Pixel coordinate of the first node: for 'edges' the first node is
        # the lower edge of bin 0 (pixel -0.5); for 'center' it is bin 0 itself.
        if node_type == "edges":
            self._pix_offset = -0.5
            nbin = len(nodes) - 1
        elif node_type == "center":
            self._pix_offset = 0.0
            nbin = len(nodes)
        else:
            raise ValueError(f"Invalid node type: {node_type!r}")

        self._nbin = nbin
        # None means "derive from node_type" in use_center_as_plot_labels.
        self._use_center_as_plot_labels = None
144
145
    def assert_name(self, required_name):
        """Raise if the axis name differs from the one required.

        Parameters
        ----------
        required_name : str
            Name the axis is expected to have.

        Raises
        ------
        ValueError
            If ``self.name`` is not equal to ``required_name``.
        """
        if self.name == required_name:
            return

        message = (
            "Unexpected axis name,"
            f' expected "{required_name}", got: "{self.name}"'
        )
        raise ValueError(message)
158
159
    def is_aligned(self, other, atol=2e-2):
        """Check whether another map axis is aligned with this one.

        The axes are aligned when the bin centers of each axis land on
        (near-)integer pixel positions of the other axis, and both axes use
        the same interpolation scale.

        Parameters
        ----------
        other : `MapAxis`
            Other map axis.
        atol : float
            Absolute tolerance, in bins, on the distance to the nearest integer.

        Returns
        -------
        aligned : bool
            Whether the axes are aligned.
        """
        # Centers of each axis expressed as pixels of the other axis.
        pix_all = np.append(
            self.coord_to_pix(other.center), other.coord_to_pix(self.center)
        )
        residuals = np.round(pix_all) - pix_all

        if not np.allclose(residuals, 0, atol=atol):
            return False
        return self.interp == other.interp
182
183
    def is_allclose(self, other, **kwargs):
        """Check if other map axis is all close.

        Parameters
        ----------
        other : `MapAxis`
            Other map axis
        **kwargs : dict
            Keyword arguments forwarded to `~numpy.allclose`

        Returns
        -------
        is_allclose : bool
            Whether other axis is allclose: same edge shape, equivalent units,
            close edges, and identical node type, interpolation and
            (case-insensitive) name.

        Raises
        ------
        TypeError
            If ``other`` is not an instance of this axis class.
        """
        # Previously the exception was *returned* (a truthy object) instead of
        # raised, so mismatched types silently compared as "close".
        if not isinstance(other, self.__class__):
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        if self.edges.shape != other.edges.shape:
            return False
        if not self.unit.is_equivalent(other.unit):
            return False
        return (
            np.allclose(self.edges, other.edges, **kwargs)
            and self._node_type == other._node_type
            and self._interp == other._interp
            and self.name.upper() == other.name.upper()
        )
211
212
    def __eq__(self, other):
        """Axes are equal if of the same class and allclose (rtol=atol=1e-6)."""
        return isinstance(other, self.__class__) and self.is_allclose(
            other, rtol=1e-6, atol=1e-6
        )
217
218
    def __ne__(self, other):
        """Logical negation of ``__eq__``."""
        is_equal = self.__eq__(other)
        return not is_equal
220
221
    def __hash__(self):
        """Hash by object identity.

        NOTE(review): ``__eq__`` is tolerance-based, so two distinct axes may
        compare equal while hashing differently.
        """
        return id(self)
223
224
    @lazyproperty
    def _transform(self):
        """Coordinate <-> pixel interpolator built from the stored nodes.

        Cached via ``lazyproperty`` after first access.
        """
        interpolator = AxisCoordInterpolator(edges=self._nodes, interp=self.interp)
        return interpolator
228
229
    @property
    def is_energy_axis(self):
        """Whether the axis is named 'energy' or 'energy_true'."""
        energy_names = ("energy", "energy_true")
        return self.name in energy_names
232
233
    @property
    def interp(self):
        """Interpolation scale of the axis, as given at construction."""
        scale_name = self._interp
        return scale_name
237
238
    @property
    def name(self):
        """Name of the axis (always a string, validated at construction)."""
        axis_name = self._name
        return axis_name
242
243
    @lazyproperty
    def edges(self):
        """Bin edges as a `~astropy.units.Quantity` of length ``nbin + 1``.

        Edges are evaluated at pixel positions -0.5, 0.5, ..., nbin - 0.5.
        """
        edge_pix = np.arange(self.nbin + 1, dtype=float)
        edge_pix -= 0.5
        return u.Quantity(self.pix_to_coord(edge_pix), self._unit, copy=False)
248
249
    @property
    def edges_min(self):
        """Return array of bin edges min values (lower edge of each bin)."""
        return self.edges[:-1]
253
254
    @property
    def edges_max(self):
        """Return array of bin edges max values (upper edge of each bin)."""
        return self.edges[1:]
258
259
    @property
    def bounds(self):
        """Bounds of the axis (~astropy.units.Quantity).

        First and last edge for 'edges' axes, first and last center otherwise.
        """
        values = self.edges if self.node_type == "edges" else self.center
        return values[[0, -1]]
267
268
    @property
    def as_plot_xerr(self):
        """Return tuple of xerr to be used with plt.errorbar()

        The tuple holds the distances from each bin center down to its lower
        edge and up to its upper edge.
        """
        center = self.center
        err_low = center - self.edges_min
        err_high = self.edges_max - center
        return err_low, err_high
275
276
    @property
    def use_center_as_plot_labels(self):
        """Whether plot labels show bin centers rather than edge ranges.

        Uses the explicitly set value if any, else ``node_type == "center"``.
        """
        override = self._use_center_as_plot_labels
        return self.node_type == "center" if override is None else override
283
284
    @use_center_as_plot_labels.setter
    def use_center_as_plot_labels(self, value):
        """Explicitly choose the plot-label mode (value is cast to bool)."""
        self._use_center_as_plot_labels = True if value else False
288
289
    @property
    def as_plot_labels(self):
        """Return list of axis plot labels, one per bin.

        Labels are ``"<center>"`` when ``use_center_as_plot_labels`` is true,
        otherwise ``"<min> - <max>"``; numbers are formatted with ``.2e``.
        """
        if not self.use_center_as_plot_labels:
            return [f"{low:.2e} - {high:.2e}" for low, high in self.iter_by_edges]
        return [f"{value:.2e}" for value in self.center]
300
301
    @property
    def as_plot_edges(self):
        """Edges used for plotting (the bin edges themselves)."""
        plot_edges = self.edges
        return plot_edges
305
306
    @property
    def as_plot_center(self):
        """Centers used for plotting (the bin centers themselves)."""
        plot_center = self.center
        return plot_center
310
311
    @property
    def as_plot_scale(self):
        """Matplotlib axis scale matching the interpolation.

        'log' maps to 'log'; 'lin' and 'sqrt' both map to 'linear'.
        """
        scale_for_interp = {"lin": "linear", "sqrt": "linear", "log": "log"}
        return scale_for_interp[self.interp]
317
318
    def to_node_type(self, node_type):
        """Return the axis with its node type changed to ``node_type``.

        Parameters
        ----------
        node_type : {'edges', 'center'}
            The target node type.

        Returns
        -------
        axis : `~gammapy.maps.MapAxis`
            ``self`` if the node type already matches, otherwise a new axis
            built from the centers ('center') or the edges (any other value).
        """
        if node_type == self.node_type:
            return self

        nodes = self.center if node_type == "center" else self.edges
        return self.__class__(
            nodes=nodes,
            interp=self.interp,
            name=self.name,
            node_type=node_type,
            unit=self.unit,
        )
345
346
    def rename(self, new_name):
        """Return a renamed copy of the axis.

        Delegates to ``self.copy(name=new_name)``.

        Parameters
        ----------
        new_name : str
            Name for the returned axis.

        Returns
        -------
        axis : `~gammapy.maps.MapAxis`
            Copy of this axis carrying ``new_name``.
        """
        renamed = self.copy(name=new_name)
        return renamed
360
361
    def format_plot_xaxis(self, ax):
        """Apply this axis' scale, label and limits to a plot x-axis.

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            Formatted plot axis
        """
        ax.set_xscale(self.as_plot_scale)

        quantity = PLOT_AXIS_LABEL.get(self.name, self.name.capitalize())
        ax.set_xlabel(
            DEFAULT_LABEL_TEMPLATE.format(quantity=quantity, unit=ax.xaxis.units)
        )

        # Limits are only applied when the bounds differ (e.g. not for a
        # single-node center axis), presumably to avoid degenerate limits.
        xmin, xmax = self.bounds
        if xmin != xmax:
            ax.set_xlim(self.bounds)
        return ax
385
386
    def format_plot_yaxis(self, ax):
        """Format plot axis

        Sets the y scale from ``as_plot_scale``, a "<quantity> [<unit>]" label,
        and the y limits from ``bounds`` when the bounds are distinct.

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            Formatted plot axis
        """
        ax.set_yscale(self.as_plot_scale)

        ylabel = DEFAULT_LABEL_TEMPLATE.format(
            quantity=PLOT_AXIS_LABEL.get(self.name, self.name.capitalize()),
            unit=ax.yaxis.units,
        )
        ax.set_ylabel(ylabel)

        # Same guard as format_plot_xaxis: identical low/high limits (e.g. a
        # single-node center axis) would give a degenerate y range.
        ymin, ymax = self.bounds
        if ymin != ymax:
            ax.set_ylim(self.bounds)
        return ax
408
409
    @property
    def iter_by_edges(self):
        """Generator over ``(edge_min, edge_max)`` pairs, one per bin."""
        edges = self.edges
        yield from zip(edges[:-1], edges[1:])
414
415
    @lazyproperty
    def center(self):
        """Bin centers as a `~astropy.units.Quantity` of length ``nbin``.

        Centers are evaluated at integer pixel positions 0, 1, ..., nbin - 1.
        """
        center_pix = np.arange(self.nbin, dtype=float)
        coords = self.pix_to_coord(center_pix)
        return u.Quantity(coords, self._unit, copy=False)
420
421
    @lazyproperty
    def bin_width(self):
        """Array of bin widths (upper edge minus lower edge of each bin)."""
        edges = self.edges
        return edges[1:] - edges[:-1]
425
426
    @property
    def nbin(self):
        """Number of bins, fixed at construction from the node count."""
        n_bins = self._nbin
        return n_bins
430
431
    @property
    def nbin_per_decade(self):
        """Return number of bins per decade.

        Computed as ``nbin / log10(max / min)``, where ``max``/``min`` are taken
        over the edges for 'edges' axes and over the centers otherwise.

        Raises
        ------
        ValueError
            If the axis interpolation is not 'log'.
        """
        if self.interp != "log":
            raise ValueError("Bins per decade can only be computed for log-spaced axes")

        values = self.edges if self.node_type == "edges" else self.center
        ndecades = np.log10(values.max() / values.min())
        return (self._nbin / ndecades).value
444
445
    @property
    def node_type(self):
        """Node type of the axis: 'center' or 'edges'."""
        kind = self._node_type
        return kind
449
450
    @property
    def unit(self):
        """Coordinate unit of the axis (an `~astropy.units.Unit`)."""
        axis_unit = self._unit
        return axis_unit
454
455
    @classmethod
    def from_bounds(cls, lo_bnd, hi_bnd, nbin, **kwargs):
        """Generate an axis object from a lower/upper bound and number of bins.

        With node_type 'edges' the bounds are the outer edges of the first and
        last bin; with node_type 'center' they are the centers of the first
        and last bin.

        Parameters
        ----------
        lo_bnd : float
            Lower bound of first axis bin.
        hi_bnd : float
            Upper bound of last axis bin.
        nbin : int
            Number of bins.
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates.  Default: 'lin'.
        node_type : {'edges', 'center'}
            Node type of the axis. Default: 'edges'.
        **kwargs : dict
            Remaining keyword arguments forwarded to the constructor.
        """
        nbin = int(nbin)
        interp = kwargs.setdefault("interp", "lin")
        node_type = kwargs.setdefault("node_type", "edges")

        # 'edges' axes need one more node than bins; 'center' axes one per bin.
        if node_type == "edges":
            n_nodes = nbin + 1
        elif node_type == "center":
            n_nodes = nbin
        else:
            raise ValueError(f"Invalid node type: {node_type!r}")

        # Nodes are evenly spaced in the interpolation's own scale.
        if interp == "lin":
            nodes = np.linspace(lo_bnd, hi_bnd, n_nodes)
        elif interp == "log":
            log_nodes = np.linspace(np.log(lo_bnd), np.log(hi_bnd), n_nodes)
            nodes = np.exp(log_nodes)
        elif interp == "sqrt":
            sqrt_nodes = np.linspace(lo_bnd**0.5, hi_bnd**0.5, n_nodes)
            nodes = sqrt_nodes**2.0
        else:
            raise ValueError(f"Invalid interp: {interp}")

        return cls(nodes, **kwargs)
497
498
    @classmethod
    def from_energy_edges(cls, energy_edges, unit=None, name=None, interp="log"):
        """Make an energy axis from adjacent edges.

        Parameters
        ----------
        energy_edges : `~astropy.units.Quantity`, float
            Energy edges
        unit : `~astropy.units.Unit`
            Energy unit
        name : str
            Name of the energy axis, either 'energy' (default) or 'energy_true'
        interp: str
            interpolation mode. Default is 'log'.

        Returns
        -------
        axis : `MapAxis`
            Energy axis built with ``from_edges``.

        Raises
        ------
        ValueError
            If the unit is not an energy unit or the name is not allowed.
        """
        energy_edges = u.Quantity(energy_edges, unit)

        if not energy_edges.unit.is_equivalent("TeV"):
            raise ValueError(
                f"Please provide a valid energy unit, got {energy_edges.unit} instead."
            )

        name = "energy" if name is None else name
        if name not in ("energy", "energy_true"):
            raise ValueError("Energy axis can only be named 'energy' or 'energy_true'")

        return cls.from_edges(energy_edges, unit=unit, interp=interp, name=name)
532
533
    @classmethod
    def from_energy_bounds(
        cls,
        energy_min,
        energy_max,
        nbin,
        unit=None,
        per_decade=False,
        name=None,
        node_type="edges",
    ):
        """Make a log-spaced energy axis between two bounds.

        Often used to build energy grids via ``axis.center`` or ``axis.edges``.

        Parameters
        ----------
        energy_min, energy_max : `~astropy.units.Quantity`, float
            Energy range
        nbin : int
            Number of bins (total, or per decade if ``per_decade`` is set)
        unit : `~astropy.units.Unit`
            Energy unit; defaults to the unit of ``energy_max``
        per_decade : bool
            Whether `nbin` is given per decade.
        name : str
            Name of the energy axis, either 'energy' (default) or 'energy_true'
        node_type : {'edges', 'center'}
            Whether the bounds are outer edges or first/last bin centers.

        Returns
        -------
        axis : `MapAxis`
            Axis with interp "log".
        """
        energy_min = u.Quantity(energy_min, unit)
        energy_max = u.Quantity(energy_max, unit)

        if unit is None:
            # Express both bounds in the unit of energy_max.
            unit = energy_max.unit
            energy_min = energy_min.to(unit)

        if not energy_max.unit.is_equivalent("TeV"):
            raise ValueError(
                f"Please provide a valid energy unit, got {energy_max.unit} instead."
            )

        if per_decade:
            ndecades = np.log10(energy_max / energy_min).value
            nbin = np.ceil(ndecades * nbin)

        name = "energy" if name is None else name
        if name not in ("energy", "energy_true"):
            raise ValueError("Energy axis can only be named 'energy' or 'energy_true'")

        return cls.from_bounds(
            energy_min.value,
            energy_max.value,
            nbin=nbin,
            unit=unit,
            interp="log",
            name=name,
            node_type=node_type,
        )
597
598
    @classmethod
    def from_nodes(cls, nodes, **kwargs):
        """Build an axis whose nodes are bin centers.

        Bin edges lie half-way between neighbouring node values, so use this
        when bin centers must sit at specific values (e.g. a map of a
        continuous function).

        Parameters
        ----------
        nodes : `~numpy.ndarray`
            Axis nodes (bin center).
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates.  Default: 'lin'.

        Raises
        ------
        ValueError
            If ``nodes`` is empty.
        """
        if len(nodes) == 0:
            raise ValueError("Nodes array must have at least one element.")
        return cls(nodes, node_type="center", **kwargs)
619
620
    @classmethod
    def from_edges(cls, edges, **kwargs):
        """Build an axis whose nodes are bin edges.

        Use this when bin edges must lie at specific values (e.g. a
        histogram); the axis has one bin fewer than there are edges.

        Parameters
        ----------
        edges : `~numpy.ndarray`
            Axis bin edges.
        interp : {'lin', 'log', 'sqrt'}
            Interpolation method used to transform between axis and pixel
            coordinates.  Default: 'lin'.

        Raises
        ------
        ValueError
            If fewer than two edges are given.
        """
        n_edges = len(edges)
        if n_edges <= 1:
            raise ValueError("Edges array must have at least two elements.")
        return cls(edges, node_type="edges", **kwargs)
640
641
    def append(self, axis):
        """Append another map axis to this axis

        Name, interp type and node type must agree between the axes. If the node
        type is "edges", the edges must be contiguous and non-overlapping.

        Parameters
        ----------
        axis : `MapAxis`
            Axis to append.

        Returns
        -------
        axis : `MapAxis`
            Appended axis

        Raises
        ------
        ValueError
            If node type, name or interp differ, or if for "edges" axes the
            last edge of this axis does not match the first edge of ``axis``.
        """
        if self.node_type != axis.node_type:
            raise ValueError(
                f"Node type must agree, got {self.node_type} and {axis.node_type}"
            )

        if self.name != axis.name:
            raise ValueError(f"Names must agree, got {self.name} and {axis.name} ")

        if self.interp != axis.interp:
            raise ValueError(
                f"Interp type must agree, got {self.interp} and {axis.interp}"
            )

        if self.node_type == "edges":
            # ``axis.edges[0]`` is dropped below, so it must coincide with our
            # last edge; otherwise a gap or an overlap would silently be
            # merged into a single bin.
            last_edge = self.edges[-1].value
            first_edge = axis.edges[0].to_value(self.unit)
            if not np.isclose(last_edge, first_edge):
                raise ValueError(
                    "Edges must be contiguous, got last edge "
                    f"{self.edges[-1]} and first edge {axis.edges[0]}"
                )
            edges = np.append(self.edges, axis.edges[1:])
            return self.from_edges(edges=edges, interp=self.interp, name=self.name)

        nodes = np.append(self.center, axis.center)
        return self.from_nodes(nodes=nodes, interp=self.interp, name=self.name)
676
677
    def pad(self, pad_width):
        """Extend the axis by extra bins below and/or above its range.

        Parameters
        ----------
        pad_width : int or tuple of int
            An int pads the same number of bins on both sides; a tuple gives
            ``(n_low, n_high)`` bins to add at the low and high ends.

        Returns
        -------
        axis : `MapAxis`
            Padded axis, with new nodes obtained by extrapolating
            ``pix_to_coord`` beyond the current pixel range.
        """
        if isinstance(pad_width, tuple):
            pad_low, pad_high = pad_width
        else:
            pad_low = pad_high = pad_width

        if self.node_type != "edges":
            center_pix = np.arange(-pad_low, self.nbin + pad_high)
            nodes = self.pix_to_coord(center_pix)
            return self.from_nodes(nodes=nodes, interp=self.interp, name=self.name)

        edge_pix = np.arange(-pad_low, self.nbin + pad_high + 1) - 0.5
        edges = self.pix_to_coord(edge_pix)
        return self.from_edges(edges=edges, interp=self.interp, name=self.name)
704
705
    @classmethod
    def from_stack(cls, axes):
        """Merge a sequence of map axes into one by appending them in order.

        For "edges" axes the bin edges of consecutive axes must be contiguous
        and non-overlapping (see ``append``).

        Parameters
        ----------
        axes : list of `MapAxis`
            Axes to merge; must contain at least one axis.

        Returns
        -------
        axis : `MapAxis`
            Merged axis
        """
        merged = axes[0]
        for other in axes[1:]:
            merged = merged.append(other)
        return merged
728
729
    def pix_to_coord(self, pix):
        """Map pixel coordinates to axis coordinates.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.

        Returns
        -------
        coord : `~astropy.units.Quantity`
            Axis coordinate values in the axis unit.
        """
        # Shift into node-index space (node i sits at index i) before interpolating.
        node_index = pix - self._pix_offset
        coord_values = self._transform.pix_to_coord(pix=node_index)
        return u.Quantity(coord_values, unit=self.unit, copy=False)
745
746 View Code Duplication
    def pix_to_idx(self, pix, clip=False):
        """Convert pixel coordinates to bin indices.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Pixel coordinates.
        clip : bool
            If True, clip the indices to the valid range ``[0, nbin - 1]``.
            If False, pixels outside the axis range are set to -1.

        Returns
        -------
        idx : `~numpy.ndarray`
            Pixel indices.
        """
        if not clip:
            outside = (pix < 0) | (pix >= self.nbin)
            return np.where(outside, -1, pix)

        return np.clip(pix, 0, self.nbin - 1)
770
771
    def coord_to_pix(self, coord):
        """Convert axis coordinates into pixel coordinates.

        Parameters
        ----------
        coord : `~numpy.ndarray` or `~astropy.units.Quantity`
            Axis coordinate values. They are converted to the axis unit.

        Returns
        -------
        pix : `~numpy.ndarray`
            Pixel coordinate values, always at least 1-dimensional.
        """
        values = u.Quantity(coord, self.unit, copy=False).value
        pix = self._transform.coord_to_pix(coord=values) + self._pix_offset
        return np.array(pix, ndmin=1)
787
788
    def coord_to_idx(self, coord, clip=False):
        """Transform from axis coordinate to bin index.

        Bins are half-open intervals ``[edge_i, edge_{i+1})``, following the
        convention of `numpy.digitize`.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis coordinate values.
        clip : bool
            Choose whether to clip the index to the valid range of the
            axis. If False then indices for values outside the axis
            range (including values equal to the last edge) will be set
            to ``INVALID_INDEX.int``.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices. Non-finite coordinates always map to
            ``INVALID_INDEX.int``.
        """
        coord = u.Quantity(coord, self.unit, copy=False, ndmin=1).value
        edges = self.edges.value
        # np.digitize returns 0 below the first edge (idx -1) and len(edges)
        # at or above the last edge (idx == nbin).
        idx = np.digitize(coord, edges) - 1

        if clip:
            idx = np.clip(idx, 0, self.nbin - 1)
        else:
            # idx == nbin is not a valid bin, so values equal to the last edge
            # must be invalidated as well as values beyond it.
            with np.errstate(invalid="ignore"):
                idx[coord >= edges[-1]] = INVALID_INDEX.int

        idx[~np.isfinite(coord)] = INVALID_INDEX.int

        return idx
818
819
    def slice(self, idx):
        """Create a new axis from a sub-selection of the bins of this axis.

        Parameters
        ----------
        idx : slice
            Slice selecting the bins to keep.

        Returns
        -------
        axis : `~MapAxis`
            Sliced axis object.
        """
        bin_idx = self.coord_to_idx(self.center[idx].value)

        if self._node_type == "edges":
            # N selected bins are described by N + 1 edge nodes
            bin_idx = tuple([*bin_idx, 1 + bin_idx[-1]])

        selected_nodes = self._nodes[(bin_idx,)]
        return MapAxis(
            selected_nodes,
            interp=self._interp,
            name=self._name,
            node_type=self._node_type,
            unit=self._unit,
        )
846
847
    def squash(self):
        """Create a new axis with a single bin spanning this whole axis.

        Returns
        -------
        axis : `~MapAxis`
            Axis with one bin from the first to the last edge of this axis.
        """
        # TODO: Decide on handling node_type=center
        # See https://github.com/gammapy/gammapy/issues/1952
        edges = self.edges
        return MapAxis.from_bounds(
            lo_bnd=edges[0].value,
            hi_bnd=edges[-1].value,
            nbin=1,
            interp=self._interp,
            name=self._name,
            unit=self._unit,
        )
865
866
    def __repr__(self):
        """Summary of name, unit, binning and interpolation of the axis."""
        template = "\t{:<10s} : {:<10s}\n"
        values = self.edges if self.node_type == "edges" else self.center
        rows = [
            ("name", self.name),
            ("unit", f"{str(self.unit)!r}"),
            ("nbins", str(self.nbin)),
            ("node type", self.node_type),
            (f"{self.node_type} min", f"{values.min():.1e}"),
            (f"{self.node_type} max", f"{values.max():.1e}"),
            ("interp", self._interp),
        ]
        body = "".join(template.format(key, val) for key, val in rows)
        return f"{self.__class__.__name__}\n\n{body}"
879
880 View Code Duplication
    def _init_copy(self, **kwargs):
        """Create a new instance, filling unspecified init arguments from self.

        Each ``__init__`` argument is looked up on ``self`` as the private
        attribute ``_<name>`` and deep-copied. Explicitly passed ``kwargs``
        take precedence over these values.
        """
        names = inspect.getfullargspec(self.__init__).args
        names.remove("self")
        defaults = {name: copy.deepcopy(getattr(self, "_" + name)) for name in names}
        return self.__class__(**{**defaults, **kwargs})
890
891
    def copy(self, **kwargs):
        """Return a copy of this `MapAxis`, optionally overriding attributes.

        Parameters
        ----------
        **kwargs : dict
            Constructor arguments that replace the values of this axis.

        Returns
        -------
        copy : `MapAxis`
            Copied map axis.
        """
        new_axis = self._init_copy(**kwargs)
        return new_axis
905
906
    def round(self, coord, clip=False):
        """Snap coordinates to the nearest bin edge of the axis.

        Parameters
        ----------
        coord : `~astropy.units.Quantity`
            Coordinates
        clip : bool
            If True, clip to the axis range first, so the result is always one
            of the axis edges.

        Returns
        -------
        coord : `~astropy.units.Quantity`
            Rounded coordinates
        """
        pix = self.coord_to_pix(coord)
        if clip:
            pix = np.clip(pix, -0.5, self.nbin - 0.5)
        # Edges sit at half-integer pixel positions: shift by half a pixel,
        # round to an integer, then shift back onto the edge grid.
        return self.pix_to_coord(np.round(pix + 0.5) - 0.5)
928
929
    def group_table(self, edges):
        """Compute bin groups table for the map axis, given coarser bin edges.

        The given edges are snapped to the nearest bin edges of this axis.
        Each interval between consecutive snapped edges becomes a "normal"
        group. Bins below the first snapped edge are collected in an
        "underflow" group, and bins above the last snapped edge in an
        "overflow" group.

        Parameters
        ----------
        edges : `~astropy.units.Quantity`
            Group bin edges.

        Returns
        -------
        groups : `~astropy.table.Table`
            Map axis group table with columns ``group_idx``, ``{name}_min``,
            ``{name}_max``, ``idx_min``, ``idx_max`` and ``bin_type``.
            ``idx_min`` and ``idx_max`` are inclusive bin indices.

        Raises
        ------
        ValueError
            If the axis is not edge based, or if the given edges do not
            overlap with the axis.
        """
        # TODO: try to simplify this code
        if self.node_type != "edges":
            raise ValueError("Only edge based map axis can be grouped")

        # Snap the group edges onto the axis bin edges, which sit at the
        # half-integer pixel positions -0.5, 0.5, ..., nbin - 0.5.
        edges_pix = self.coord_to_pix(edges)
        edges_pix = np.clip(edges_pix, -0.5, self.nbin - 0.5)
        edges_idx = np.round(edges_pix + 0.5) - 0.5
        edges_idx = np.unique(edges_idx)
        edges_ref = self.pix_to_coord(edges_idx)

        groups = Table()
        groups[f"{self.name}_min"] = edges_ref[:-1]
        groups[f"{self.name}_max"] = edges_ref[1:]

        # A group between edge pixels a and b covers bins a + 0.5 ... b - 0.5
        groups["idx_min"] = (edges_idx[:-1] + 0.5).astype(int)
        groups["idx_max"] = (edges_idx[1:] - 0.5).astype(int)

        if len(groups) == 0:
            raise ValueError("No overlap between reference and target edges.")

        # Padded to 9 characters (the length of "underflow"), presumably so
        # the fixed-width string column can also hold the other bin types.
        groups["bin_type"] = "normal   "

        edge_idx_start, edge_ref_start = edges_idx[0], edges_ref[0]
        if edge_idx_start > 0:
            underflow = {
                "bin_type": "underflow",
                "idx_min": 0,
                # last bin before the first snapped edge
                "idx_max": int(edge_idx_start - 0.5),
                f"{self.name}_min": self.pix_to_coord(-0.5),
                f"{self.name}_max": edge_ref_start,
            }
            groups.insert_row(0, vals=underflow)

        edge_idx_end, edge_ref_end = edges_idx[-1], edges_ref[-1]

        if edge_idx_end < (self.nbin - 0.5):
            overflow = {
                "bin_type": "overflow",
                # first bin after the last snapped edge
                "idx_min": int(edge_idx_end + 0.5),
                "idx_max": self.nbin - 1,
                f"{self.name}_min": edge_ref_end,
                f"{self.name}_max": self.pix_to_coord(self.nbin - 0.5),
            }
            groups.add_row(vals=overflow)

        group_idx = Column(np.arange(len(groups)))
        groups.add_column(group_idx, name="group_idx", index=0)
        return groups
990
991
    def upsample(self, factor):
        """Upsample the map axis by a given factor.

        Sub-nodes are inserted between the existing nodes, which are kept.
        For node type "edges" the result has ``nbin * factor`` bins. For node
        type "center" it has ``(nbin - 1) * factor + 1`` bins.

        Parameters
        ----------
        factor : int
            Upsampling factor.

        Returns
        -------
        axis : `MapAxis`
            Upsampled map axis.
        """
        is_edges = self.node_type == "edges"
        values = self.edges if is_edges else self.center
        n_nodes = int((self.nbin if is_edges else self.nbin - 1) * factor) + 1

        pix = self.coord_to_pix(values)
        new_values = self.pix_to_coord(np.linspace(pix.min(), pix.max(), n_nodes))

        if is_edges:
            return self.from_edges(new_values, name=self.name, interp=self.interp)
        return self.from_nodes(new_values, name=self.name, interp=self.interp)
1022
1023
    def downsample(self, factor):
        """Downsample map axis by a given factor.

        Every ``factor``-th node of the axis is kept, which preserves the axis
        limits. For node type "edges" ``nbin`` must be divisible by
        ``factor``. For node type "center" ``nbin - 1`` must be divisible by
        ``factor``.

        Parameters
        ----------
        factor : int
            Downsampling factor.

        Returns
        -------
        axis : `MapAxis`
            Downsampled map axis.

        Raises
        ------
        ValueError
            If the divisibility requirement is not met.
        """
        if self.node_type != "edges":
            if np.mod((self.nbin - 1) / factor, 1) > 0:
                raise ValueError(
                    f"Number of {self.name} bins - 1 is not divisible by {factor}"
                )
            return self.from_nodes(
                self.center[::factor], name=self.name, interp=self.interp
            )

        if np.mod(self.nbin / factor, 1) > 0:
            raise ValueError(
                f"Number of {self.name} bins is not divisible by {factor}"
            )
        return self.from_edges(self.edges[::factor], name=self.name, interp=self.interp)
1062
1063
    def to_header(self, format="ogip", idx=0):
        """Create FITS header

        Parameters
        ----------
        format : {"ogip", "ogip-sherpa", "gadf", "fgst-ccube", "fgst-template"}
            Format specification.
            - "ogip" and "ogip-sherpa" produce the keywords of an OGIP
              ``EBOUNDS`` extension.
            - "gadf", "fgst-ccube" and "fgst-template" produce the
              ``AXCOLS{idx}`` and ``INTERP{idx}`` keywords that describe the
              axis columns.
        idx : int
            Column index of the axis, used as suffix of the ``AXCOLS`` and
            ``INTERP`` keywords. Ignored for the OGIP formats.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            New header containing the axis keywords.

        Raises
        ------
        ValueError
            If the format is unknown, or the node type is invalid.
        """
        header = fits.Header()

        if format in ["ogip", "ogip-sherpa"]:
            header["EXTNAME"] = "EBOUNDS", "Name of this binary table extension"
            header["TELESCOP"] = "DUMMY", "Mission/satellite name"
            header["INSTRUME"] = "DUMMY", "Instrument/detector"
            header["FILTER"] = "None", "Filter information"
            header["CHANTYPE"] = "PHA", "Type of channels (PHA, PI etc)"
            header["DETCHANS"] = self.nbin, "Total number of detector PHA channels"
            header["HDUCLASS"] = "OGIP", "Organisation devising file format"
            header["HDUCLAS1"] = "RESPONSE", "File relates to response of instrument"
            header["HDUCLAS2"] = "EBOUNDS", "This is an EBOUNDS extension"
            header["HDUVERS"] = "1.2.0", "Version of file format"
        elif format in ["gadf", "fgst-ccube", "fgst-template"]:
            key = f"AXCOLS{idx}"
            name = self.name.upper()

            # energy axes use the fixed column names E_MIN/E_MAX and ENERGY
            if self.name == "energy" and self.node_type == "edges":
                header[key] = "E_MIN,E_MAX"
            elif self.name == "energy" and self.node_type == "center":
                header[key] = "ENERGY"
            elif self.node_type == "edges":
                header[key] = f"{name}_MIN,{name}_MAX"
            elif self.node_type == "center":
                header[key] = name
            else:
                raise ValueError(f"Invalid node type {self.node_type!r}")

            key_interp = f"INTERP{idx}"
            header[key_interp] = self.interp

        else:
            raise ValueError(f"Unknown format {format}")

        return header
1113
1114
    def to_table(self, format="ogip"):
        """Convert the axis to a `~astropy.table.Table`.

        See https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html#tth_sEc3.2  # noqa: E501

        The 'ogip-sherpa' and 'ogip-arf-sherpa' formats are equivalent to
        'ogip' and 'ogip-arf', but use keV energy units.

        Parameters
        ----------
        format : {"ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa", "gadf-sed", "gadf-dl3", "gtpsf"}
            Format specification

        Returns
        -------
        table : `~astropy.table.Table`
            Table with the axis columns.

        Raises
        ------
        ValueError
            If the format is unknown, or the axis cannot be represented in the
            requested format.
        """
        table = Table()
        edges = self.edges

        if format in ["ogip", "ogip-sherpa"]:
            self.assert_name("energy")

            if format == "ogip-sherpa":
                edges = edges.to("keV")

            table["CHANNEL"] = np.arange(self.nbin, dtype=np.int16)
            table["E_MIN"] = edges[:-1]
            table["E_MAX"] = edges[1:]
        elif format in ["ogip-arf", "ogip-arf-sherpa"]:
            self.assert_name("energy_true")

            if format == "ogip-arf-sherpa":
                edges = edges.to("keV")

            table["ENERG_LO"] = edges[:-1]
            table["ENERG_HI"] = edges[1:]
        elif format == "gadf-sed":
            if self.is_energy_axis:
                table["e_ref"] = self.center
                table["e_min"] = self.edges_min
                table["e_max"] = self.edges_max
        elif format == "gadf-dl3":
            from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

            if self.name == "energy":
                column_prefix = "ENERG"
            else:
                for column_prefix, spec in IRF_DL3_AXES_SPECIFICATION.items():
                    if spec["name"] == self.name:
                        break
                else:
                    # Without this, the prefix of the last specification entry
                    # would be used silently for an unsupported axis.
                    raise ValueError(
                        f"Axis name {self.name!r} is not supported by the "
                        "'gadf-dl3' format"
                    )

            if self.node_type == "edges":
                edges_lo, edges_hi = edges[:-1], edges[1:]
            else:
                # node based axes store identical LO and HI columns
                edges_lo, edges_hi = self.center, self.center

            table[f"{column_prefix}_LO"] = edges_lo[np.newaxis]
            table[f"{column_prefix}_HI"] = edges_hi[np.newaxis]
        elif format == "gtpsf":
            if self.name == "energy_true":
                table["Energy"] = self.center.to("MeV")
            elif self.name == "rad":
                table["Theta"] = self.center.to("deg")
            else:
                raise ValueError(
                    "Can only convert true energy or rad axis to "
                    f"'gtpsf' format, got {self.name}"
                )
        else:
            raise ValueError(f"{format} is not a valid format")

        return table
1187
1188
    def to_table_hdu(self, format="ogip"):
        """Convert the axis to a FITS binary table HDU.

        See https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html#tth_sEc3.2  # noqa: E501

        The 'ogip-sherpa' format is equivalent to 'ogip' but uses keV energy units.

        Parameters
        ----------
        format : {"ogip", "ogip-sherpa", "gtpsf"}
            Format specification, passed on to ``to_table``.

        Returns
        -------
        hdu : `~astropy.io.fits.BinTableHDU`
            Table HDU. For the "gtpsf" format it is named "THETA". For the OGIP
            formats its header includes the ``EBOUNDS`` keywords.
        """
        table = self.to_table(format=format)
        hdu_name = "THETA" if format == "gtpsf" else None
        hdu = fits.BinTableHDU(table, name=hdu_name)

        if format in ["ogip", "ogip-sherpa"]:
            hdu.header.update(self.to_header(format=format))

        return hdu
1218
1219
    @classmethod
    def from_table(cls, table, format="ogip", idx=0, column_prefix=""):
        """Instantiate MapAxis from table HDU

        Parameters
        ----------
        table : `~astropy.table.Table`
            Table
        format : str
            Format specification, one of "ogip", "ogip-arf", "fgst-ccube",
            "fgst-template", "fgst-bexpcube", "gadf", "gadf-dl3", "gtpsf",
            "gadf-sed-energy", "gadf-sed-norm", "gadf-sed-counts" or "profile".
        idx : int
            Column index of the axis.
        column_prefix : str
            Column name prefix of the axis, used for creating the axis.

        Returns
        -------
        axis : `MapAxis`
            Map Axis

        Raises
        ------
        ValueError
            If the format is not supported, or the table lacks a column or
            meta data keyword required by the format.
        """
        if format in ["ogip", "fgst-ccube"]:
            energy_min = table["E_MIN"].quantity
            energy_max = table["E_MAX"].quantity
            energy_edges = (
                np.append(energy_min.value, energy_max.value[-1]) * energy_min.unit
            )
            axis = cls.from_edges(energy_edges, name="energy", interp="log")

        elif format == "ogip-arf":
            energy_min = table["ENERG_LO"].quantity
            energy_max = table["ENERG_HI"].quantity
            energy_edges = (
                np.append(energy_min.value, energy_max.value[-1]) * energy_min.unit
            )
            axis = cls.from_edges(energy_edges, name="energy_true", interp="log")

        elif format in ["fgst-template", "fgst-bexpcube"]:
            allowed_names = ["Energy", "ENERGY", "energy"]
            # first column, in table order, whose name is one of the allowed ones
            tag = next(
                (colname for colname in table.colnames if colname in allowed_names),
                None,
            )
            if tag is None:
                raise ValueError(
                    f"Format '{format}' requires one of the columns "
                    f"{allowed_names}, got {table.colnames}"
                )

            nodes = table[tag].data
            axis = cls.from_nodes(
                nodes=nodes, name="energy_true", unit="MeV", interp="log"
            )

        elif format == "gadf":
            key_axcols = "AXCOLS{}".format(idx + 1)
            axcols = table.meta.get(key_axcols)
            if axcols is None:
                raise ValueError(
                    f"Table meta data is missing the '{key_axcols}' keyword"
                )
            colnames = axcols.split(",")
            node_type = "edges" if len(colnames) == 2 else "center"

            # TODO: check why this extra case is needed
            if colnames[0] == "E_MIN":
                name = "energy"
            else:
                name = colnames[0].replace("_MIN", "").lower()
                # this is need for backward compatibility
                if name == "theta":
                    name = "rad"

            interp = table.meta.get("INTERP{}".format(idx + 1), "lin")

            if node_type == "center":
                nodes = np.unique(table[colnames[0]].quantity)
            else:
                edges_min = np.unique(table[colnames[0]].quantity)
                edges_max = np.unique(table[colnames[1]].quantity)
                nodes = edges_from_lo_hi(edges_min, edges_max)

            axis = MapAxis(nodes=nodes, node_type=node_type, interp=interp, name=name)

        elif format == "gadf-dl3":
            from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

            spec = IRF_DL3_AXES_SPECIFICATION[column_prefix]
            name, interp = spec["name"], spec["interp"]

            # background models are stored in reconstructed energy
            hduclass = table.meta.get("HDUCLAS2")
            if hduclass in {"BKG", "RAD_MAX"} and column_prefix == "ENERG":
                name = "energy"

            edges_lo = table[f"{column_prefix}_LO"].quantity[0]
            edges_hi = table[f"{column_prefix}_HI"].quantity[0]

            # identical LO and HI columns encode a node based axis
            if np.allclose(edges_hi, edges_lo):
                axis = MapAxis.from_nodes(edges_hi, interp=interp, name=name)
            else:
                edges = edges_from_lo_hi(edges_lo, edges_hi)
                axis = MapAxis.from_edges(edges, interp=interp, name=name)
        elif format == "gtpsf":
            try:
                energy = table["Energy"].data * u.MeV
                axis = MapAxis.from_nodes(energy, name="energy_true", interp="log")
            except KeyError:
                rad = table["Theta"].data * u.deg
                axis = MapAxis.from_nodes(rad, name="rad")
        elif format == "gadf-sed-energy":
            if "e_min" in table.colnames and "e_max" in table.colnames:
                e_min = flat_if_equal(table["e_min"].quantity)
                e_max = flat_if_equal(table["e_max"].quantity)
                edges = edges_from_lo_hi(e_min, e_max)
                axis = MapAxis.from_energy_edges(edges)
            elif "e_ref" in table.colnames:
                e_ref = flat_if_equal(table["e_ref"].quantity)
                axis = MapAxis.from_nodes(e_ref, name="energy", interp="log")
            else:
                raise ValueError(
                    "Either 'e_ref', 'e_min' or 'e_max' column " "names are required"
                )
        elif format == "gadf-sed-norm":
            # TODO: guess interp here
            nodes = flat_if_equal(table["norm_scan"][0])
            axis = MapAxis.from_nodes(nodes, name="norm")
        elif format == "gadf-sed-counts":
            if "datasets" in table.colnames:
                labels = np.unique(table["datasets"])
                axis = LabelMapAxis(labels=labels, name="dataset")
            else:
                shape = table["counts"].shape
                edges = np.arange(shape[-1] + 1) - 0.5
                axis = MapAxis.from_edges(edges, name="dataset")
        elif format == "profile":
            if "datasets" in table.colnames:
                labels = np.unique(table["datasets"])
                axis = LabelMapAxis(labels=labels, name="dataset")
            else:
                x_ref = table["x_ref"].quantity
                axis = MapAxis.from_nodes(x_ref, name="projected-distance")
        else:
            raise ValueError(f"Format '{format}' not supported")

        return axis
1354
1355
    @classmethod
    def from_table_hdu(cls, hdu, format="ogip", idx=0):
        """Create a MapAxis from a FITS table HDU.

        The HDU is read into a `~astropy.table.Table` and handed to
        ``from_table``.

        Parameters
        ----------
        hdu : `~astropy.io.fits.BinTableHDU`
            Table HDU
        format : {"ogip", "ogip-arf", "fgst-ccube", "fgst-template"}
            Format specification
        idx : int
            Column index of the axis.

        Returns
        -------
        axis : `MapAxis`
            Map Axis
        """
        return cls.from_table(Table.read(hdu), format=format, idx=idx)
1375
1376
1377
class MapAxes(Sequence):
1378
    """MapAxis container class.
1379
1380
    Parameters
1381
    ----------
1382
    axes : list of `MapAxis`
1383
        List of map axis objects.
1384
    """
1385
1386
    def __init__(self, axes, n_spatial_axes=None):
        """Create an axes container.

        Parameters
        ----------
        axes : iterable of `MapAxis`
            Map axes. Their names must be unique.
        n_spatial_axes : int, optional
            Number of spatial axes. ``iter_with_reshape`` uses it to append
            trailing length-1 dimensions to the yielded shapes.

        Raises
        ------
        ValueError
            If two axes share the same name.
        """
        # Materialize once: ``axes`` may be a one-shot iterable such as a
        # generator, which the uniqueness check below would otherwise exhaust,
        # leaving an empty container.
        axes = list(axes)
        unique_names = set()

        for ax in axes:
            if ax.name in unique_names:
                raise ValueError(f"Axis names must be unique, got: '{ax.name}' twice.")
            unique_names.add(ax.name)

        self._axes = axes
        self._n_spatial_axes = n_spatial_axes
1398
1399
    @property
    def primary_axis(self):
        """Primary extra axis, defined as the axis with the most bins.

        If several axes have the same, largest number of bins, the first of
        them is returned.

        Returns
        -------
        axis : `MapAxis`
            Map axis
        """
        longest = int(np.argmax(self.shape))
        return self[longest]
1411
1412
    @property
    def is_flat(self):
        """Whether every axis has exactly one bin."""
        return np.all(np.array(self.shape) == 1)
1417
1418
    @property
    def is_unidimensional(self):
        """Whether at most one axis has more than one bin."""
        n_extended = np.count_nonzero(np.array(self.shape) > 1)
        return self.is_flat or n_extended == 1
1424
1425
    @property
    def reverse(self):
        """New `MapAxes` with the axes in reversed order."""
        reversed_axes = self[::-1]
        return MapAxes(reversed_axes)
1429
1430
    @property
    def iter_with_reshape(self):
        """Iterate over the axes, yielding a broadcastable shape with each.

        For the axis at position ``idx`` the shape has ``-1`` at that position
        and ``1`` everywhere else. If a number of spatial axes is set, the
        shape is reversed and one trailing ``1`` per spatial axis is appended.

        Yields
        ------
        shape, axis : tuple of int, `MapAxis`
        """
        n_axes = len(self)
        for idx, axis in enumerate(self):
            # Extract values for each axis, default: nodes
            shape = [1] * n_axes
            shape[idx] = -1
            if self._n_spatial_axes:
                shape = shape[::-1] + [1] * self._n_spatial_axes
            yield tuple(shape), axis
1446
1447
    def get_coord(self, mode="center", axis_name=None):
1448
        """Get axes coordinates
1449
1450
        Parameters
1451
        ----------
1452
        mode : {"center", "edges"}
1453
            Coordinate center or edges
1454
        axis_name : str
1455
            Axis name for which mode='edges' applies
1456
1457
        Returns
1458
        -------
1459
        coords : dict of `~astropy.units.Quanity`
1460
            Map coordinates
1461
        """
1462
        coords = {}
1463
1464
        for shape, axis in self.iter_with_reshape:
1465
            if mode == "edges" and axis.name == axis_name:
1466
                coord = axis.edges
1467
            else:
1468
                coord = axis.center
1469
            coords[axis.name] = coord.reshape(shape)
1470
1471
        return coords
1472
1473
    def bin_volume(self):
        """Compute the bin volume as the broadcast product of all bin widths.

        Returns
        -------
        bin_volume : `~astropy.units.Quantity`
            Bin volume
        """
        widths = [axis.bin_width.reshape(shape) for shape, axis in self.iter_with_reshape]
        volume = np.array(1)
        for width in widths:
            volume = volume * width
        return volume
1487
1488
    @property
    def shape(self):
        """Number of bins of each axis, as a tuple."""
        return tuple(axis.nbin for axis in self)
1492
1493
    @property
    def names(self):
        """List of the axis names, in axis order."""
        return [axis.name for axis in self]
1497
1498
    def index(self, axis_name):
        """Position of the axis with the given name; ValueError if absent."""
        all_names = self.names
        return all_names.index(axis_name)
1501
1502
    def index_data(self, axis_name):
        """Get the data-array index of an axis.

        The index is counted from the end of the axes list, i.e. the data
        array stores the axes in reverse order.

        Parameters
        ----------
        axis_name : str
            Name of the axis.

        Returns
        -------
        idx : int
            Data index
        """
        position = self.names.index(axis_name)
        return len(self) - 1 - position
1517
1518
    def __len__(self):
        """Number of axes in the container."""
        axes = self._axes
        return len(axes)
1520
1521
    def __add__(self, other):
        """Concatenate two axes containers into a new one of this class."""
        combined = [*self, *other]
        return self.__class__(combined)
1523
1524
    def upsample(self, factor, axis_name):
        """Upsample one axis by a given factor.

        Parameters
        ----------
        factor : int
            Upsampling factor.
        axis_name : str
            Name of the axis to upsample.

        Returns
        -------
        axes : `MapAxes`
            Copies of all axes, with the selected one upsampled.
        """
        axes = [
            (ax.upsample(factor=factor) if ax.name == axis_name else ax).copy()
            for ax in self
        ]
        return self.__class__(axes=axes)
1548
1549
    def replace(self, axis):
        """Replace the axis that has the same name as the given one.

        Parameters
        ----------
        axis : `MapAxis`
            Map axis

        Returns
        -------
        axes : MapAxes
            Map axes with the replacement in place. Other axes are not copied.
        """
        axes = [axis if ax.name == axis.name else ax for ax in self]
        return self.__class__(axes=axes)
1571
1572
    def resample(self, axis):
        """Resample axis binning.

        The existing bins of the axis with the same name are grouped into the
        new binning. Underflow and overflow groups are dropped.

        Parameters
        ----------
        axis : `MapAxis`
            New map axis.

        Returns
        -------
        axes : `MapAxes`
            Axes object with resampled axis.
        """
        groups = self[axis.name].group_table(axis.edges)

        # Keep only normal bins
        normal = groups[groups["bin_type"] == "normal   "]

        edges = edges_from_lo_hi(
            normal[axis.name + "_min"].quantity,
            normal[axis.name + "_max"].quantity,
        )
        resampled = MapAxis.from_edges(edges=edges, interp=axis.interp, name=axis.name)

        axes = [resampled if ax.name == axis.name else ax.copy() for ax in self]
        return self.__class__(axes=axes)
1610
1611
    def downsample(self, factor, axis_name):
        """Downsample one axis by a given factor.

        Parameters
        ----------
        factor : int
            Downsampling factor.
        axis_name : str
            Name of the axis to downsample.

        Returns
        -------
        axes : `MapAxes`
            Copies of all axes, with the selected one downsampled.
        """
        axes = [
            (ax.downsample(factor=factor) if ax.name == axis_name else ax).copy()
            for ax in self
        ]
        return self.__class__(axes=axes)
1636
1637
    def squash(self, axis_name):
        """Squash one axis.

        Parameters
        ----------
        axis_name : str
            Name of the axis to squash.

        Returns
        -------
        axes : `MapAxes`
            New map axes; every axis, including the squashed one, is a copy.
        """
        result = []
        for current in self:
            target = current.squash() if current.name == axis_name else current
            result.append(target.copy())
        return self.__class__(axes=result)
1658
1659
    def pad(self, axis_name, pad_width):
        """Pad one axis.

        Parameters
        ----------
        axis_name : str
            Name of the axis to pad.
        pad_width : int or tuple of int
            Pad width, forwarded to the axis ``pad`` method.

        Returns
        -------
        axes : `MapAxes`
            Axes with the padded axis. The other axes are reused, not copied.
        """
        return self.__class__(
            axes=[
                ax.pad(pad_width=pad_width) if ax.name == axis_name else ax
                for ax in self
            ]
        )
1683
1684
    def drop(self, axis_name):
        """Drop an axis.

        Parameters
        ----------
        axis_name : str
            Name of the axis to remove.

        Returns
        -------
        axes : `MapAxes`
            Copies of all remaining axes, in their original order.
        """
        kept = [ax.copy() for ax in self if ax.name != axis_name]
        return self.__class__(axes=kept)
1704
1705
    def __getitem__(self, idx):
        """Select axes by position, name, slice or list.

        Parameters
        ----------
        idx : int, str, slice or list
            Integer position (builtin or numpy integer), axis name, slice of
            positions, or list of positions and/or names.

        Returns
        -------
        axis : `MapAxis` or `MapAxes`
            A single axis for an integer or string index, otherwise a new
            `MapAxes` holding the selected axes (not copied).

        Raises
        ------
        KeyError
            If no axis has the requested name.
        TypeError
            If ``idx`` has an unsupported type.
        """
        # numpy integer scalars (e.g. the result of np.argmax) are not
        # subclasses of the builtin int, so accept them explicitly
        if isinstance(idx, (int, np.integer)):
            return self._axes[idx]
        elif isinstance(idx, str):
            for ax in self._axes:
                if ax.name == idx:
                    return ax
            raise KeyError(f"No axes: {idx!r}")
        elif isinstance(idx, slice):
            return self.__class__(axes=self._axes[idx])
        elif isinstance(idx, list):
            return self.__class__(axes=[self[name] for name in idx])
        else:
            raise TypeError(f"Invalid type: {type(idx)!r}")
1724
1725
    def coord_to_idx(self, coord, clip=True):
        """Convert axis coordinates to bin indices.

        Parameters
        ----------
        coord : dict of `~numpy.ndarray` or `MapCoord`
            Axis coordinate values, looked up by axis name.
        clip : bool
            Forwarded to each axis ``coord_to_idx``.

        Returns
        -------
        idx : tuple of `~numpy.ndarray`
            Bin indices, one entry per axis.
        """
        indices = []
        for axis in self:
            indices.append(axis.coord_to_idx(coord[axis.name], clip=clip))
        return tuple(indices)
1739
1740
    def coord_to_pix(self, coord):
        """Convert axis coordinates to pixel coordinates.

        Parameters
        ----------
        coord : dict of `~numpy.ndarray`
            Axis coordinate values, looked up by axis name.

        Returns
        -------
        pix : tuple of `~numpy.ndarray`
            Pixel coordinates, one entry per axis.
        """
        return tuple(axis.coord_to_pix(coord[axis.name]) for axis in self)
1754
1755
    def pix_to_coord(self, pix):
        """Convert pixel coordinates to axis coordinates.

        Parameters
        ----------
        pix : tuple
            Pixel coordinates, one entry per axis in axis order.

        Returns
        -------
        coords : tuple
            Axis coordinates, one entry per axis.
        """
        coords = []
        for axis, pix_values in zip(self, pix):
            coords.append(axis.pix_to_coord(pix_values))
        return tuple(coords)
1769
1770
    def pix_to_idx(self, pix, clip=False):
        """Convert pixel coordinates to pixel indices.

        Parameters
        ----------
        pix : tuple of `~numpy.ndarray`
            Pixel coordinates, one entry per axis in axis order.
        clip : bool
            Whether to clip indices to the valid range of each axis. If
            false, indices for coordinates outside the axis range are set
            to -1 by the individual axes.

        Returns
        -------
        idx : tuple of `~numpy.ndarray`
            Pixel indices, one entry per axis.
        """
        return tuple(ax.pix_to_idx(values, clip=clip) for values, ax in zip(pix, self))
1793
1794
    def slice_by_idx(self, slices):
        """Create new axes by slicing selected axes.

        Parameters
        ----------
        slices : dict
            Mapping of axis names to integers or `slice` objects. Any value
            that is not a `slice` (e.g. an integer) drops the axis. Axes not
            present in the dict are kept unchanged.

        Returns
        -------
        axes : `MapAxes`
            Copies of the kept (and possibly sliced) axes.
        """
        sliced = []
        for ax in self:
            ax_slice = slices.get(ax.name, slice(None))
            if not isinstance(ax_slice, slice):
                # integer indexing removes the axis altogether
                continue
            sliced.append(ax.slice(ax_slice).copy())
        return self.__class__(axes=sliced)
1820
1821
    def to_header(self, format="gadf"):
        """Build a FITS header describing all axes.

        Parameters
        ----------
        format : {"gadf"}
            Header format, forwarded to each axis ``to_header``.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header merging the per-axis headers; axes are numbered from 1.
        """
        merged = fits.Header()
        for position, ax in enumerate(self):
            merged.update(ax.to_header(format=format, idx=position + 1))
        return merged
1841
1842
    def to_table(self, format="gadf"):
        """Convert axes to table.

        Parameters
        ----------
        format : {"gadf", "gadf-dl3", "fgst-ccube", "fgst-template", "ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa"}  # noqa E501
            Format to use.

        Returns
        -------
        table : `~astropy.table.Table`
            Table with axis data.

        Raises
        ------
        ValueError
            If ``format`` is not supported.
        """
        if format == "gadf-dl3":
            # One table per axis, stacked side by side
            table = hstack([ax.to_table(format=format) for ax in self])
        elif format in ["gadf", "fgst-ccube", "fgst-template"]:
            table = Table()
            table["CHANNEL"] = np.arange(np.prod(self.shape))

            # Broadcast the axes against each other, so that after ravel()
            # each table row corresponds to one combination of bins
            axes_ctr = np.meshgrid(*[ax.center for ax in self])
            axes_min = np.meshgrid(*[ax.edges_min for ax in self])
            axes_max = np.meshgrid(*[ax.edges_max for ax in self])

            for idx, ax in enumerate(self):
                name = ax.name.upper()

                if name == "ENERGY":
                    colnames = ["ENERGY", "E_MIN", "E_MAX"]
                else:
                    colnames = [name, name + "_MIN", name + "_MAX"]

                for colname, v in zip(colnames, [axes_ctr, axes_min, axes_max]):
                    # do not store edges for label axis
                    if ax.node_type == "label" and colname != name:
                        continue

                    table[colname] = np.ravel(v[idx])

                if isinstance(ax, TimeMapAxis):
                    ref_dict = time_ref_to_dict(ax.reference_time)
                    table.meta.update(ref_dict)

        elif format in ["ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa"]:
            # OGIP formats only describe the energy axis
            energy_axis = self["energy"]
            table = energy_axis.to_table(format=format)
        else:
            raise ValueError(f"Unsupported format: '{format}'")

        return table
1896
1897
    def to_table_hdu(self, format="gadf", hdu_bands=None):
        """Make FITS table columns for map axes.

        Parameters
        ----------
        format : {"gadf", "fgst-ccube", "fgst-template", "ogip", "ogip-sherpa"} or None
            Format to use. ``None`` is treated as ``"gadf"``.
        hdu_bands : str
            Name of the bands HDU to use. Only honoured for the ``"gadf"``
            format (default ``"BANDS"``); the other formats impose a fixed name.

        Returns
        -------
        hdu : `~astropy.io.fits.BinTableHDU`
            Bin table HDU.

        Raises
        ------
        ValueError
            If ``format`` is not supported.
        """
        # FIXME: Check whether convention is compatible with
        #  dimensionality of geometry and simplify!!!

        # `to_table` has no branch for ``None``, so map it to the default
        # format here instead of failing further down
        if format is None:
            format = "gadf"

        if format in ["fgst-ccube", "ogip", "ogip-sherpa"]:
            hdu_bands = "EBOUNDS"
        elif format == "fgst-template":
            hdu_bands = "ENERGIES"
        elif format == "gadf":
            if hdu_bands is None:
                hdu_bands = "BANDS"
        else:
            raise ValueError(f"Unknown format {format}")

        table = self.to_table(format=format)
        header = self.to_header(format=format)
        return fits.BinTableHDU(table, name=hdu_bands, header=header)
1928
1929
    @classmethod
    def from_table_hdu(cls, hdu, format="gadf"):
        """Create `MapAxes` from a FITS bin table HDU.

        Parameters
        ----------
        hdu : `~astropy.io.fits.BinTableHDU` or None
            Bin table HDU. ``None`` gives an empty `MapAxes`.
        format : str
            Format forwarded to ``from_table``.

        Returns
        -------
        axes : `MapAxes`
            Map axes object.
        """
        if hdu is not None:
            return cls.from_table(Table.read(hdu), format=format)
        return cls([])
1949
1950
    @classmethod
    def from_table(cls, table, format="gadf"):
        """Create MapAxes from table.

        Parameters
        ----------
        table : `~astropy.table.Table`
            Table holding the axis definitions.
        format : {"gadf", "gadf-dl3", "gadf-sed", "fgst-ccube", "fgst-template", "fgst-bexpcube", "ogip", "ogip-arf", "lightcurve", "profile"}
            Format to use.

        Returns
        -------
        axes : `MapAxes`
            Map axes object.

        Raises
        ------
        ValueError
            If ``format`` is not supported.
        """
        from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

        axes = []

        # Formats that support only one energy axis
        if format in [
            "fgst-ccube",
            "fgst-template",
            "fgst-bexpcube",
            "ogip",
            "ogip-arf",
        ]:
            axes.append(MapAxis.from_table(table, format=format))
        elif format == "gadf":
            # This limits the max number of axes to 5; reading also stops at
            # the first missing AXCOLS<n> meta keyword
            for idx in range(5):
                axcols = table.meta.get("AXCOLS{}".format(idx + 1))
                if axcols is None:
                    break

                # TODO: what is good way to check whether it is a given axis type?
                # Fallback chain: label axis, then time axis, then regular
                # MapAxis. The caught exception types are presumably what each
                # reader raises when the columns do not match its axis type;
                # confirm against the readers.
                try:
                    axis = LabelMapAxis.from_table(table, format=format, idx=idx)
                except (KeyError, TypeError):
                    try:
                        axis = TimeMapAxis.from_table(table, format=format, idx=idx)
                    except (KeyError, ValueError):
                        axis = MapAxis.from_table(table, format=format, idx=idx)

                axes.append(axis)
        elif format == "gadf-dl3":
            # Try every known DL3 column prefix; prefixes without matching
            # columns (KeyError) are skipped
            for column_prefix in IRF_DL3_AXES_SPECIFICATION:
                try:
                    axis = MapAxis.from_table(
                        table, format=format, column_prefix=column_prefix
                    )
                except KeyError:
                    continue

                axes.append(axis)
        elif format == "gadf-sed":
            # Optional SED axes; those absent from the table are skipped
            for axis_format in ["gadf-sed-norm", "gadf-sed-energy", "gadf-sed-counts"]:
                try:
                    axis = MapAxis.from_table(table=table, format=axis_format)
                except KeyError:
                    continue
                axes.append(axis)
        elif format == "lightcurve":
            # SED axes first, then the time axis
            axes.extend(cls.from_table(table=table, format="gadf-sed"))
            axes.append(TimeMapAxis.from_table(table, format="lightcurve"))
        elif format == "profile":
            # SED axes first, then the profile axis
            axes.extend(cls.from_table(table=table, format="gadf-sed"))
            axes.append(MapAxis.from_table(table, format="profile"))
        else:
            raise ValueError(f"Unsupported format: '{format}'")

        return cls(axes)
2023
2024
    @classmethod
    def from_default(cls, axes, n_spatial_axes=None):
        """Make a `MapAxes` from a sequence of `~MapAxis` objects or arrays.

        Parameters
        ----------
        axes : sequence of `~MapAxis` or `~numpy.ndarray`, or None
            Axes to use. Arrays are wrapped with ``MapAxis(array)``. Axes whose
            name is the empty string are renamed, in place, to ``"axis{idx}"``
            where ``idx`` is their position.
        n_spatial_axes : int, optional
            Forwarded to the constructor (ignored when ``axes`` is None).

        Returns
        -------
        axes : `MapAxes`
            Map axes object; empty when ``axes`` is None.
        """
        if axes is None:
            return cls([])

        collected = []
        for position, item in enumerate(axes):
            axis = MapAxis(item) if isinstance(item, np.ndarray) else item
            if axis.name == "":
                # NOTE: this mutates the caller's axis object
                axis._name = f"axis{position}"
            collected.append(axis)

        return cls(collected, n_spatial_axes=n_spatial_axes)
2041
2042
    def assert_names(self, required_names):
        """Check that the axes have exactly the required names, in order.

        Parameters
        ----------
        required_names : list of str
            Expected axis names, in the expected order.

        Raises
        ------
        ValueError
            If the number of axes differs or any axis name does not match.
        """
        message = (
            "Incorrect axis order or names. Expected axis "
            f"order: {required_names}, got: {self.names}."
        )

        if len(self) != len(required_names):
            raise ValueError(message)

        try:
            for ax, expected in zip(self, required_names):
                ax.assert_name(expected)
        except ValueError:
            # Replace the per-axis error by one describing the whole order
            raise ValueError(message)
2064
2065
    def rename_axes(self, names, new_names):
        """Rename the axes.

        Parameters
        ----------
        names : list or str
            Names of the axes to rename.
        new_names : list or str
            New names of the axes (list must be of same length as `names`).

        Returns
        -------
        axes : `MapAxes`
            Copy of the map axes with the selected axes renamed.

        Raises
        ------
        ValueError
            If ``names`` and ``new_names`` have different lengths.
        """
        if isinstance(names, str):
            names = [names]
        if isinstance(new_names, str):
            new_names = [new_names]

        names, new_names = list(names), list(new_names)
        # zip() would silently drop the surplus entries of the longer list
        if len(names) != len(new_names):
            raise ValueError(
                f"Number of names ({len(names)}) and new names ({len(new_names)})"
                " must be equal."
            )

        axes = self.copy()
        for name, new_name in zip(names, new_names):
            axes[name]._name = new_name
        return axes
2088
2089
    @property
    def center_coord(self):
        """Coordinates of the central pixel of every axis.

        Returns
        -------
        coords : tuple
            For each axis, ``pix_to_coord`` evaluated at pixel ``(nbin - 1) / 2``.
        """
        centers = []
        for ax in self:
            mid_pix = 0.5 * (float(ax.nbin) - 1.0)
            centers.append(ax.pix_to_coord(mid_pix))
        return tuple(centers)
2093
2094
    def is_allclose(self, other, **kwargs):
        """Check if other map axes are all close.

        Parameters
        ----------
        other : `MapAxes`
            Other map axes.
        **kwargs : dict
            Keyword arguments forwarded to `~MapAxis.is_allclose`.

        Returns
        -------
        is_allclose : bool
            Whether other axes are all close. Axes collections with a
            different number of axes are never close.

        Raises
        ------
        TypeError
            If ``other`` is not an instance of this class.
        """
        if not isinstance(other, self.__class__):
            # Previously the exception was returned (and is truthy), which
            # made incompatible objects look "all close"
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        # zip() stops at the shorter sequence, so compare lengths explicitly;
        # otherwise surplus axes on either side would be silently ignored
        if len(self) != len(other):
            return False

        return np.all([ax0.is_allclose(ax1, **kwargs) for ax0, ax1 in zip(other, self)])
2113
2114
    def __eq__(self, other):
        # Same class and all axes close within 1e-6 relative/absolute tolerance
        return isinstance(other, self.__class__) and self.is_allclose(
            other, rtol=1e-6, atol=1e-6
        )
2119
2120
    def __ne__(self, other):
        # Defined as the logical negation of __eq__
        is_equal = self.__eq__(other)
        return not is_equal
2122
2123
    def copy(self):
        """Return a new instance holding a copy of every axis."""
        copied = []
        for ax in self:
            copied.append(ax.copy())
        return self.__class__(copied)
2126
2127
2128
class TimeMapAxis:
2129
    """Class representing a time axis.
2130
2131
    Provides methods for transforming to/from axis and pixel coordinates.
2132
    A time axis can represent non-contiguous sequences of non-overlapping time intervals.
2133
2134
    Time intervals must be provided in increasing order.
2135
2136
    Parameters
2137
    ----------
2138
    edges_min : `~astropy.units.Quantity`
2139
        Array of edge time values. This the time delta w.r.t. to the reference time.
2140
    edges_max : `~astropy.units.Quantity`
2141
        Array of edge time values. This the time delta w.r.t. to the reference time.
2142
    reference_time : `~astropy.time.Time`
2143
        Reference time to use.
2144
    name : str
2145
        Axis name
2146
    interp : str
2147
        Interpolation method used to transform between axis and pixel
2148
        coordinates.  For now only 'lin' is supported.
2149
    """
2150
2151
    node_type = "intervals"
2152
    time_format = "iso"
2153
2154
    def __init__(self, edges_min, edges_max, reference_time, name="time", interp="lin"):
        """Validate and store the time intervals.

        Parameters
        ----------
        edges_min, edges_max : `~astropy.units.Quantity`
            Interval start and end offsets w.r.t. ``reference_time``; must
            have a time unit, the same shape, be sorted, and not overlap.
        reference_time : `~astropy.time.Time`
            Reference time; passed through ``Time(...)``.
        name : str
            Axis name.
        interp : str
            Interpolation scheme; only ``"lin"`` is supported.

        Raises
        ------
        ValueError
            If any of the edge checks below fail.
        NotImplementedError
            If ``interp`` is not ``"lin"``.
        """
        self._name = name

        # ndmin=1 so that scalar edges still give a one-bin axis
        edges_min = u.Quantity(edges_min, ndmin=1)
        edges_max = u.Quantity(edges_max, ndmin=1)

        if not edges_min.unit.is_equivalent("s"):
            raise ValueError(
                f"Time edges min must have a valid time unit, got {edges_min.unit}"
            )

        if not edges_max.unit.is_equivalent("s"):
            raise ValueError(
                f"Time edges max must have a valid time unit, got {edges_max.unit}"
            )

        if not edges_min.shape == edges_max.shape:
            raise ValueError(
                "Edges min and edges max must have the same shape,"
                f" got {edges_min.shape} and {edges_max.shape}."
            )

        # Every interval must have a strictly positive length
        if not np.all(edges_max > edges_min):
            raise ValueError("Edges max must all be larger than edge min")

        if not np.all(edges_min == np.sort(edges_min)):
            raise ValueError("Time edges min values must be sorted")

        if not np.all(edges_max == np.sort(edges_max)):
            raise ValueError("Time edges max values must be sorted")

        if interp != "lin":
            raise NotImplementedError(
                f"Non-linear scaling scheme are not supported yet, got {interp}"
            )

        self._edges_min = edges_min
        self._edges_max = edges_max
        self._reference_time = Time(reference_time)
        self._pix_offset = -0.5
        self._interp = interp

        # Each interval must start at or after the end of the previous one;
        # touching intervals (delta == 0) are allowed
        delta = edges_min[1:] - edges_max[:-1]
        if np.any(delta < 0 * u.s):
            raise ValueError("Time intervals must not overlap.")
2199
2200
    @property
2201
    def is_contiguous(self):
2202
        """Whether the axis is contiguous"""
2203
        return np.all(self.edges_min[1:] == self.edges_max[:-1])
2204
2205
    def to_contiguous(self):
        """Return a contiguous version of this time axis.

        All distinct interval edges (starts and ends alike) become bin edges,
        so gaps between the original intervals turn into bins of their own.

        Returns
        -------
        axis : `TimeMapAxis`
            Contiguous time axis with the same reference time, name and interp.
        """
        unique_edges = np.unique(np.stack([self.edges_min, self.edges_max]))
        lower, upper = unique_edges[:-1], unique_edges[1:]
        settings = dict(
            reference_time=self.reference_time, name=self.name, interp=self.interp
        )
        return self.__class__(edges_min=lower, edges_max=upper, **settings)
2221
2222
    @property
2223
    def unit(self):
2224
        """Axes unit"""
2225
        return self.edges_max.unit
2226
2227
    @property
2228
    def interp(self):
2229
        """Interp"""
2230
        return self._interp
2231
2232
    @property
2233
    def reference_time(self):
2234
        """Return reference time used for the axis."""
2235
        return self._reference_time
2236
2237
    @property
2238
    def name(self):
2239
        """Return axis name."""
2240
        return self._name
2241
2242
    @property
2243
    def nbin(self):
2244
        """Return number of bins in the axis."""
2245
        return len(self.edges_min.flatten())
2246
2247
    @property
2248
    def edges_min(self):
2249
        """Return array of bin edges max values."""
2250
        return self._edges_min
2251
2252
    @property
2253
    def edges_max(self):
2254
        """Return array of bin edges min values."""
2255
        return self._edges_max
2256
2257
    @property
2258
    def edges(self):
2259
        """Return array of bin edges values."""
2260
        if not self.is_contiguous:
2261
            raise ValueError("Time axis is not contiguous")
2262
2263
        return edges_from_lo_hi(self.edges_min, self.edges_max)
2264
2265
    @property
2266
    def bounds(self):
2267
        """Bounds of the axis (~astropy.units.Quantity)"""
2268
        return self.edges_min[0], self.edges_max[-1]
2269
2270
    @property
2271
    def time_bounds(self):
2272
        """Bounds of the axis (~astropy.units.Quantity)"""
2273
        t_min, t_max = self.bounds
2274
        return t_min + self.reference_time, t_max + self.reference_time
2275
2276
    @property
2277
    def time_min(self):
2278
        """Return axis lower edges as Time objects."""
2279
        return self._edges_min + self.reference_time
2280
2281
    @property
2282
    def time_max(self):
2283
        """Return axis upper edges as Time objects."""
2284
        return self._edges_max + self.reference_time
2285
2286
    @property
2287
    def time_delta(self):
2288
        """Return axis time bin width (`~astropy.time.TimeDelta`)."""
2289
        return self.time_max - self.time_min
2290
2291
    @property
2292
    def time_mid(self):
2293
        """Return time bin center (`~astropy.time.Time`)."""
2294
        return self.time_min + 0.5 * self.time_delta
2295
2296
    @property
2297
    def time_edges(self):
2298
        """Time edges"""
2299
        return self.reference_time + self.edges
2300
2301
    @property
2302
    def as_plot_xerr(self):
2303
        """Plot x error"""
2304
        xn, xp = self.time_mid - self.time_min, self.time_max - self.time_mid
2305
2306
        if self.time_format == "iso":
2307
            x_errn = xn.to_datetime()
2308
            x_errp = xp.to_datetime()
2309
        elif self.time_format == "mjd":
2310
            x_errn = xn.to("day")
2311
            x_errp = xp.to("day")
2312
        else:
2313
            raise ValueError(f"Invalid time_format: {self.time_format}")
2314
2315
        return x_errn, x_errp
2316
2317
    @property
2318
    def as_plot_labels(self):
2319
        """Plot labels"""
2320
        labels = []
2321
2322
        for t_min, t_max in self.iter_by_edges:
2323
            label = f"{getattr(t_min, self.time_format)} - {getattr(t_max, self.time_format)}"
2324
            labels.append(label)
2325
2326
        return labels
2327
2328
    @property
2329
    def as_plot_edges(self):
2330
        """Plot edges"""
2331
        if self.time_format == "iso":
2332
            edges = self.time_edges.to_datetime()
2333
        elif self.time_format == "mjd":
2334
            edges = self.time_edges.mjd * u.day
2335
        else:
2336
            raise ValueError(f"Invalid time_format: {self.time_format}")
2337
2338
        return edges
2339
2340
    @property
2341
    def as_plot_center(self):
2342
        """Plot center"""
2343
        if self.time_format == "iso":
2344
            center = self.time_mid.datetime
2345
        elif self.time_format == "mjd":
2346
            center = self.time_mid.mjd * u.day
2347
2348
        return center
0 ignored issues
show
introduced by
The variable center does not seem to be defined for all execution paths.
Loading history...
2349
2350
    def format_plot_xaxis(self, ax):
        """Set the x label and, for ISO time format, date tick formatting.

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format.

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            The same axis, formatted.
        """
        from matplotlib.dates import DateFormatter

        quantity = PLOT_AXIS_LABEL.get(self.name, self.name.capitalize())
        ax.set_xlabel(
            DEFAULT_LABEL_TEMPLATE.format(quantity=quantity, unit=self.time_format)
        )

        if self.time_format != "iso":
            return ax

        # Presumably rotated so that the long date labels stay readable
        ax.xaxis.set_major_formatter(DateFormatter("%Y-%m-%d %H:%M:%S"))
        plt.setp(
            ax.xaxis.get_majorticklabels(),
            rotation=30,
            ha="right",
            rotation_mode="anchor",
        )
        return ax
2381
2382
    def assert_name(self, required_name):
2383
        """Assert axis name if a specific one is required.
2384
2385
        Parameters
2386
        ----------
2387
        required_name : str
2388
            Required
2389
        """
2390
        if self.name != required_name:
2391
            raise ValueError(
2392
                "Unexpected axis name,"
2393
                f' expected "{required_name}", got: "{self.name}"'
2394
            )
2395
2396
    def is_allclose(self, other, **kwargs):
2397
        """Check if other map axis is all close.
2398
2399
        Parameters
2400
        ----------
2401
        other : `TimeMapAxis`
2402
            Other map axis
2403
        **kwargs : dict
2404
            Keyword arguments forwarded to `~numpy.allclose`
2405
2406
        Returns
2407
        -------
2408
        is_allclose : bool
2409
            Whether other axis is allclose
2410
        """
2411
        if not isinstance(other, self.__class__):
2412
            return TypeError(f"Cannot compare {type(self)} and {type(other)}")
2413
2414
        if self._edges_min.shape != other._edges_min.shape:
2415
            return False
2416
2417
        # This will test equality at microsec level.
2418
        delta_min = self.time_min - other.time_min
2419
        delta_max = self.time_max - other.time_max
2420
2421
        return (
2422
            np.allclose(delta_min.to_value("s"), 0.0, **kwargs)
2423
            and np.allclose(delta_max.to_value("s"), 0.0, **kwargs)
2424
            and self._interp == other._interp
2425
            and self.name.upper() == other.name.upper()
2426
        )
2427
2428
    def __eq__(self, other):
2429
        if not isinstance(other, self.__class__):
2430
            return False
2431
2432
        return self.is_allclose(other=other, atol=1e-6)
2433
2434
    def __ne__(self, other):
2435
        return not self.__eq__(other)
2436
2437
    def __hash__(self):
2438
        return id(self)
2439
2440
    def is_aligned(self, other, atol=2e-2):
2441
        raise NotImplementedError
2442
2443
    @property
2444
    def iter_by_edges(self):
2445
        """Iterate by intervals defined by the edges"""
2446
        for time_min, time_max in zip(self.time_min, self.time_max):
2447
            yield (time_min, time_max)
2448
2449
    def coord_to_idx(self, coord, **kwargs):
        """Transform from axis time coordinate to bin index.

        A time ``t`` is assigned to interval ``i`` when
        ``time_min[i] < t <= time_max[i]``: the lower edge is excluded and the
        upper edge included. Indices of time values falling outside all time
        bins are set to ``INVALID_INDEX.int`` (described as -1 in earlier
        docs; the constant is defined elsewhere).

        Parameters
        ----------
        coord : `~astropy.time.Time` or `~astropy.units.Quantity`
            Array of axis coordinate values. The quantity is assumed
            to be relative to the reference time.
        **kwargs : dict
            Accepted for interface compatibility (e.g. ``clip`` passed by
            `MapAxes.coord_to_idx`) and ignored.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.
        """
        if isinstance(coord, u.Quantity):
            # Convert offsets w.r.t. the reference time to absolute times
            coord = self.reference_time + coord

        # Trailing axis so every time is compared against every interval
        time = Time(coord[..., np.newaxis])
        delta_plus = (time - self.time_min).value > 0.0
        delta_minus = (time - self.time_max).value <= 0.0
        mask = np.logical_and(delta_plus, delta_minus)

        # argmax returns the first matching interval (0 when none match);
        # rows without any match are overwritten with the invalid index
        idx = np.asanyarray(np.argmax(mask, axis=-1))
        idx[~np.any(mask, axis=-1)] = INVALID_INDEX.int
        return idx
2477
2478
    def coord_to_pix(self, coord, **kwargs):
        """Transform from time coordinate to pixel position.

        For a time inside interval ``i`` the position is ``i - 0.5`` plus the
        fractional position inside the interval, computed with the axis
        interpolation scale (``__init__`` only accepts ``"lin"``). Bin ``i``
        therefore spans pixels ``[i - 0.5, i + 0.5]``.

        Positions of time values outside all time bins are set to
        ``INVALID_INDEX.float`` before the final ``- 0.5`` shift.
        NOTE(review): earlier docs said these are set to -1; confirm the
        intended sentinel (harmless if ``INVALID_INDEX.float`` is NaN).

        Parameters
        ----------
        coord : `~astropy.time.Time` or `~astropy.units.Quantity`
            Array of axis coordinate values; a quantity is interpreted
            relative to the reference time.
        **kwargs : dict
            Accepted for interface compatibility and ignored.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel positions.
        """
        if isinstance(coord, u.Quantity):
            coord = self.reference_time + coord

        idx = np.atleast_1d(self.coord_to_idx(coord))

        valid_pix = idx != INVALID_INDEX.int
        pix = np.atleast_1d(idx).astype("float")

        # TODO: is there the equivalent of np.atleast1d for astropy.time.Time?
        if coord.shape == ():
            coord = coord.reshape((1,))

        relative_time = coord[valid_pix] - self.reference_time

        # Fractional position within the matched interval, in scaled space
        scale = interpolation_scale(self._interp)
        valid_idx = idx[valid_pix]
        s_min = scale(self._edges_min[valid_idx])
        s_max = scale(self._edges_max[valid_idx])
        s_coord = scale(relative_time.to(self._edges_min.unit))

        pix[valid_pix] += (s_coord - s_min) / (s_max - s_min)
        pix[~valid_pix] = INVALID_INDEX.float
        # Shift so that integer pixel values fall on bin centers
        return pix - 0.5
2517
2518
    @staticmethod
2519
    def pix_to_idx(pix, clip=False):
2520
        return pix
2521
2522
    @property
    def center(self):
        """Interval centers as offsets w.r.t. the reference time (Quantity)."""
        half_width = 0.5 * self.bin_width
        return self.edges_min + half_width
2526
2527
    @property
    def bin_width(self):
        """Interval durations as a `~astropy.units.Quantity` in hours."""
        durations = self.time_delta
        return durations.to("h")
2531
2532
    def __repr__(self):
        # Title, underline, then one aligned "label : value" row per property
        title = self.__class__.__name__
        rows = [
            ("name", self.name),
            ("nbins", str(self.nbin)),
            ("reference time", self.reference_time.iso),
            ("scale", self.reference_time.scale),
            ("time min.", self.time_min.min().iso),
            ("time max.", self.time_max.max().iso),
            ("total time", np.sum(self.bin_width)),
        ]
        row_fmt = "\t{:<14s} : {:<10s}\n"
        text = title + "\n" + "-" * len(title) + "\n\n"
        text += "".join(row_fmt.format(label, value) for label, value in rows)
        return text.expandtabs(tabsize=2)
2544
2545
    def upsample(self):
2546
        raise NotImplementedError
2547
2548
    def downsample(self):
2549
        raise NotImplementedError
2550
2551 View Code Duplication
    def _init_copy(self, **kwargs):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
2552
        """Init map axis instance by copying missing init arguments from self."""
2553
        argnames = inspect.getfullargspec(self.__init__).args
2554
        argnames.remove("self")
2555
2556
        for arg in argnames:
2557
            value = getattr(self, "_" + arg)
2558
            kwargs.setdefault(arg, copy.deepcopy(value))
2559
2560
        return self.__class__(**kwargs)
2561
2562
    def copy(self, **kwargs):
2563
        """Copy `MapAxis` instance and overwrite given attributes.
2564
2565
        Parameters
2566
        ----------
2567
        **kwargs : dict
2568
            Keyword arguments to overwrite in the map axis constructor.
2569
2570
        Returns
2571
        -------
2572
        copy : `MapAxis`
2573
            Copied map axis.
2574
        """
2575
        return self._init_copy(**kwargs)
2576
2577
    def slice(self, idx):
2578
        """Create a new axis object by extracting a slice from this axis.
2579
2580
        Parameters
2581
        ----------
2582
        idx : slice
2583
            Slice object selecting a subselection of the axis.
2584
2585
        Returns
2586
        -------
2587
        axis : `~TimeMapAxis`
2588
            Sliced axis object.
2589
        """
2590
        return TimeMapAxis(
2591
            self._edges_min[idx].copy(),
2592
            self._edges_max[idx].copy(),
2593
            self.reference_time,
2594
            interp=self._interp,
2595
            name=self.name,
2596
        )
2597
2598
    def squash(self):
2599
        """Create a new axis object by squashing the axis into one bin.
2600
2601
        Returns
2602
        -------
2603
        axis : `~MapAxis`
2604
            Sliced axis object.
2605
        """
2606
        return TimeMapAxis(
2607
            self._edges_min[0],
2608
            self._edges_max[-1],
2609
            self.reference_time,
2610
            interp=self._interp,
2611
            name=self._name,
2612
        )
2613
2614
    # TODO: if we are to allow log or sqrt bins the reference time should always
    #  be strictly lower than all times
    #  Should we define a mechanism to ensure this is always correct?
    @classmethod
    def from_time_edges(cls, time_min, time_max, unit="d", interp="lin", name="time"):
        """Create TimeMapAxis from the time interval edges defined as `~astropy.time.Time`.

        The reference time is defined as the lower edge of the first interval
        (``time_min[0]``); the edges are stored relative to it.

        Parameters
        ----------
        time_min : `~astropy.time.Time`
            Array of lower edge times.
        time_max : `~astropy.time.Time`
            Array of upper edge times.
        unit : `~astropy.units.Unit` or str
            The unit to convert the edges to. Default is 'd' (day).
        interp : str
            Interpolation method used to transform between axis and pixel
            coordinates.  Valid options are 'log', 'lin', and 'sqrt'.
        name : str
            Axis name

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.
        """
        target_unit = u.Unit(unit)
        reference_time = time_min[0]

        # Offsets of every edge from the reference time, in ``target_unit``.
        offsets_min = (time_min - reference_time).to(target_unit)
        offsets_max = (time_max - reference_time).to(target_unit)

        return cls(offsets_min, offsets_max, reference_time, interp=interp, name=name)
    # TODO: how configurable should that be? column names?
    @classmethod
    def from_table(cls, table, format="gadf", idx=0):
        """Create time map axis from table

        Parameters
        ----------
        table : `~astropy.table.Table`
            Bin table HDU
        format : {"gadf", "fermi-fgl", "lightcurve"}
            Format to use.
        idx : int
            Index of the axis. For "gadf" the column names are read from the
            ``AXCOLS{idx + 1}`` meta keyword and, if present, the interpolation
            from ``INTERP{idx + 1}``. Not used by the other formats.

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.

        Raises
        ------
        ValueError
            If the format is not supported, or the "gadf" table lacks the
            ``AXCOLS{idx + 1}`` keyword.
        """
        # Optional constructor arguments; omitted ones use the `cls` defaults.
        kwargs = {}

        if format == "gadf":
            axcols_key = "AXCOLS{}".format(idx + 1)
            axcols = table.meta.get(axcols_key)
            if axcols is None:
                raise ValueError(f"Missing '{axcols_key}' keyword in table meta data")
            colnames = axcols.split(",")
            name = colnames[0].replace("_MIN", "").lower()
            reference_time = time_ref_from_dict(table.meta)
            edges_min = np.unique(table[colnames[0]].quantity)
            edges_max = np.unique(table[colnames[1]].quantity)
            # `to_header` writes the interpolation as INTERP{idx}; read it back
            # so it survives a write/read round trip. Tables without the
            # keyword fall back to the constructor default.
            interp = table.meta.get("INTERP{}".format(idx + 1))
            if interp is not None:
                kwargs["interp"] = interp
        elif format == "fermi-fgl":
            meta = table.meta.copy()
            # Fortran-style exponent ("D-4") in MJDREFF is rewritten to "e-4".
            meta["MJDREFF"] = str(meta["MJDREFF"]).replace("D-4", "e-4")
            reference_time = time_ref_from_dict(meta=meta)
            name = "time"
            edges_min = table["Hist_Start"][:-1]
            edges_max = table["Hist_Start"][1:]
        elif format == "lightcurve":
            # TODO: is this a good format? It just supports mjd...
            name = "time"
            scale = table.meta.get("TIMESYS", "utc")
            time_min = Time(table["time_min"].data, format="mjd", scale=scale)
            time_max = Time(table["time_max"].data, format="mjd", scale=scale)
            reference_time = Time("2001-01-01T00:00:00")
            reference_time.format = "mjd"
            edges_min = (time_min - reference_time).to("s")
            edges_max = (time_max - reference_time).to("s")
        else:
            raise ValueError(f"Not a supported format: {format}")

        return cls(
            edges_min=edges_min,
            edges_max=edges_max,
            reference_time=reference_time,
            name=name,
            **kwargs,
        )
    @classmethod
2707
    def from_gti(cls, gti, name="time"):
2708
        """Create a time axis from an input GTI.
2709
2710
        Parameters
2711
        ----------
2712
        gti : `GTI`
2713
            GTI table
2714
        name : str
2715
            Axis name
2716
2717
        Returns
2718
        -------
2719
        axis : `TimeMapAxis`
2720
            Time map axis.
2721
2722
        """
2723
        tmin = gti.time_start - gti.time_ref
2724
        tmax = gti.time_stop - gti.time_ref
2725
2726
        return cls(
2727
            edges_min=tmin.to("s"),
2728
            edges_max=tmax.to("s"),
2729
            reference_time=gti.time_ref,
2730
            name=name,
2731
        )
2732
2733
    @classmethod
2734
    def from_time_bounds(cls, time_min, time_max, nbin, unit="d", name="time"):
2735
        """Create linearly spaced time axis from bounds
2736
2737
        Parameters
2738
        ----------
2739
        time_min : `~astropy.time.Time`
2740
            Lower bound
2741
        time_max : `~astropy.time.Time`
2742
            Upper bound
2743
        nbin : int
2744
            Number of bins
2745
        name : str
2746
            Name of the axis.
2747
        """
2748
        delta = time_max - time_min
2749
        time_edges = time_min + delta * np.linspace(0, 1, nbin + 1)
2750
        return cls.from_time_edges(
2751
            time_min=time_edges[:-1],
2752
            time_max=time_edges[1:],
2753
            interp="lin",
2754
            unit=unit,
2755
            name=name,
2756
        )
2757
2758
    def to_header(self, format="gadf", idx=0):
        """Create FITS header

        For "gadf" the header holds the axis column names (``AXCOLS{idx}``),
        the interpolation (``INTERP{idx}``) and the reference-time keywords.

        Parameters
        ----------
        format : {"gadf"}
            Format specification
        idx : int
            Column index of the axis.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.

        Raises
        ------
        ValueError
            If ``format`` is not "gadf".
        """
        header = fits.Header()

        if format != "gadf":
            raise ValueError(f"Unknown format {format}")

        axis_name = self.name.upper()
        header[f"AXCOLS{idx}"] = f"{axis_name}_MIN,{axis_name}_MAX"
        header[f"INTERP{idx}"] = self.interp
        header.update(time_ref_to_dict(self.reference_time))
        return header
class LabelMapAxis:
    """Map axis using labels

    Parameters
    ----------
    labels : list of str
        Labels to be used for the axis nodes. They must be unique; their order
        is preserved and defines the order of the axis bins.
    name : str
        Name of the axis.

    """

    node_type = "label"

    def __init__(self, labels, name=""):
        unique_labels = np.unique(labels)

        if not len(unique_labels) == len(labels):
            raise ValueError("Node labels must be unique")

        # Store the labels in the given order: `np.unique` sorts its output, so
        # keeping `unique_labels` would silently reorder the bins.
        self._labels = np.array(labels)
        self._name = name

    @property
    def unit(self):
        """Unit (dimensionless)"""
        return u.Unit("")

    @property
    def name(self):
        """Name of the axis"""
        return self._name

    def assert_name(self, required_name):
        """Assert axis name if a specific one is required.

        Parameters
        ----------
        required_name : str
            Required axis name.

        Raises
        ------
        ValueError
            If the axis name differs from ``required_name``.
        """
        if self.name != required_name:
            raise ValueError(
                "Unexpected axis name,"
                f' expected "{required_name}", got: "{self.name}"'
            )

    @property
    def nbin(self):
        """Number of bins"""
        return len(self._labels)

    def pix_to_coord(self, pix):
        """Transform from pixel to axis coordinates.

        Pixel values are rounded to the nearest integer and used as indices
        into the labels.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.

        Returns
        -------
        coord : `~numpy.ndarray`
            Array of axis coordinate values (labels).
        """
        idx = np.round(pix).astype(int)
        return self._labels[idx]

    def coord_to_idx(self, coord, **kwargs):
        """Transform labels to indices

        If the label is not present an error is raised.

        Parameters
        ----------
        coord : str or array-like of str
            Labels to look up.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.

        Raises
        ------
        ValueError
            If any of the labels is not part of the axis.
        """
        # Trailing axis so every coord is compared against all labels at once.
        coord = np.array(coord)[..., np.newaxis]
        is_equal = coord == self._labels

        if not np.all(np.any(is_equal, axis=-1)):
            label = coord[~np.any(is_equal, axis=-1)]
            raise ValueError(f"Not a valid label: {label}")

        return np.argmax(is_equal, axis=-1)

    def coord_to_pix(self, coord):
        """Transform from axis labels to pixel coordinates.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis label values.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.
        """
        return self.coord_to_idx(coord).astype("float")

    def pix_to_idx(self, pix, clip=False):
        """Convert pix to idx

        Parameters
        ----------
        pix : tuple of `~numpy.ndarray`
            Pixel coordinates.
        clip : bool
            Choose whether to clip indices to the valid range of the
            axis.  If false then indices for coordinates outside
            the axis range will be set -1.

        Returns
        -------
        idx : tuple `~numpy.ndarray`
            Pixel indices.
        """
        if clip:
            idx = np.clip(pix, 0, self.nbin - 1)
        else:
            condition = (pix < 0) | (pix >= self.nbin)
            idx = np.where(condition, -1, pix)

        return idx

    @property
    def center(self):
        """Center of the label axis (the labels themselves)"""
        return self._labels

    @property
    def edges(self):
        """Edges of the label axis (not defined; always raises ValueError)"""
        raise ValueError("A LabelMapAxis does not define edges")

    @property
    def edges_min(self):
        """Lower edges of the label axis (the labels themselves)"""
        return self._labels

    @property
    def edges_max(self):
        """Upper edges of the label axis (the labels themselves)"""
        return self._labels

    @property
    def bin_width(self):
        """Bin width is unity"""
        return np.ones(self.nbin)

    @property
    def as_plot_xerr(self):
        """Plot x-errors: half a bin (0.5) for every label"""
        return 0.5 * np.ones(self.nbin)

    @property
    def as_plot_labels(self):
        """Plot labels as a list of str"""
        return self._labels.tolist()

    @property
    def as_plot_center(self):
        """Plot centers: integer positions 0 ... nbin - 1"""
        return np.arange(self.nbin)

    @property
    def as_plot_edges(self):
        """Plot edges: half-integer positions -0.5 ... nbin - 0.5"""
        return np.arange(self.nbin + 1) - 0.5

    def format_plot_xaxis(self, ax):
        """Format plot axis.

        Places one tick per label at the bin centers and uses the labels as
        rotated tick labels.

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format.

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            Formatted plot axis.
        """
        ax.set_xticks(self.as_plot_center)
        ax.set_xticklabels(
            self.as_plot_labels,
            rotation=30,
            ha="right",
            rotation_mode="anchor",
        )
        return ax

    def to_header(self, format="gadf", idx=0):
        """Create FITS header

        Parameters
        ----------
        format : {"gadf"}
            Format specification
        idx : int
            Column index of the axis.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.

        Raises
        ------
        ValueError
            If ``format`` is not "gadf".
        """
        header = fits.Header()

        if format == "gadf":
            key = f"AXCOLS{idx}"
            header[key] = self.name.upper()
        else:
            raise ValueError(f"Unknown format {format}")

        return header

    # TODO: how configurable should that be? column names?
    @classmethod
    def from_table(cls, table, format="gadf", idx=0):
        """Create label map axis from table

        Parameters
        ----------
        table : `~astropy.table.Table`
            Bin table HDU
        format : {"gadf"}
            Format to use.
        idx : int
            Index of the axis; the label column name is read from the
            ``AXCOLS{idx + 1}`` meta keyword.

        Returns
        -------
        axis : `LabelMapAxis`
            Label map axis.

        Raises
        ------
        TypeError
            If the label column does not hold strings.
        ValueError
            If the format is not supported.
        """
        if format == "gadf":
            colname = table.meta.get("AXCOLS{}".format(idx + 1))
            column = table[colname]
            if not np.issubdtype(column.dtype, np.str_):
                raise TypeError(f"Not a valid dtype for label axis: '{column.dtype}'")
            # Keep each distinct label once, in order of first appearance, so
            # the axis order written to the table is restored.
            data = np.asarray(column.data).ravel()
            _, first_idx = np.unique(data, return_index=True)
            labels = data[np.sort(first_idx)]
        else:
            raise ValueError(f"Not a supported format: {format}")

        return cls(labels=labels, name=colname.lower())

    def __repr__(self):
        str_ = self.__class__.__name__ + "\n"
        str_ += "-" * len(self.__class__.__name__) + "\n\n"
        fmt = "\t{:<10s} : {:<10s}\n"
        str_ += fmt.format("name", self.name)
        str_ += fmt.format("nbins", str(self.nbin))
        str_ += fmt.format("node type", self.node_type)
        str_ += fmt.format("labels", "{0}".format(list(self._labels)))
        return str_.expandtabs(tabsize=2)

    def is_allclose(self, other, **kwargs):
        """Check if other map axis is all close.

        Names are compared case-insensitively and labels exactly (same labels
        in the same order). Extra keyword arguments are accepted but ignored.

        Parameters
        ----------
        other : `LabelMapAxis`
            Other map axis

        Returns
        -------
        is_allclose : bool
            Whether other axis is allclose

        Raises
        ------
        TypeError
            If ``other`` is not an instance of this class.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        name_equal = self.name.upper() == other.name.upper()
        # `np.array_equal` also copes with a different number of labels, where
        # an element-wise `==` cannot broadcast.
        labels_equal = np.array_equal(self.center, other.center)
        return name_equal and labels_equal

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        return self.is_allclose(other=other)

    def __ne__(self, other):
        return not self.__eq__(other)

    # TODO: could create sub-labels here using dashes like "label-1-a", etc.
    def upsample(self, *args, **kwargs):
        """Upsample axis (not supported)"""
        raise NotImplementedError("Upsampling a LabelMapAxis is not supported")

    # TODO: could merge labels here like "label-1-label2", etc.
    def downsample(self, *args, **kwargs):
        """Downsample axis (not supported)"""
        raise NotImplementedError("Downsampling a LabelMapAxis is not supported")

    # TODO: could merge labels here like "label-1-label2", etc.
    def resample(self, *args, **kwargs):
        """Resample axis (not supported)"""
        raise NotImplementedError("Resampling a LabelMapAxis is not supported")

    # TODO: could create new labels here like "label-10-a"
    def pad(self, *args, **kwargs):
        """Pad axis (not supported)"""
        raise NotImplementedError("Padding a LabelMapAxis is not supported")

    def copy(self):
        """Copy axis (deep copy)"""
        return copy.deepcopy(self)

    def slice(self, idx):
        """Create a new axis object by extracting a slice from this axis.

        Parameters
        ----------
        idx : slice
            Slice object selecting a subselection of the axis.

        Returns
        -------
        axis : `~LabelMapAxis`
            Sliced axis object.
        """
        return self.__class__(
            labels=self._labels[idx],
            name=self.name,
        )