Total Complexity | 408 |
Total Lines | 3121 |
Duplicated Lines | 2.24 % |
Changes | 0 |
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
Complex classes like those in gammapy.maps.axes often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # Licensed under a 3-clause BSD style license - see LICENSE.rst |
||
2 | import copy |
||
3 | import inspect |
||
4 | from collections.abc import Sequence |
||
5 | import numpy as np |
||
6 | import scipy |
||
7 | import astropy.units as u |
||
8 | from astropy.io import fits |
||
9 | from astropy.table import Column, Table, hstack |
||
10 | from astropy.time import Time |
||
11 | from astropy.utils import lazyproperty |
||
12 | import matplotlib.pyplot as plt |
||
13 | from gammapy.utils.interpolation import interpolation_scale |
||
14 | from gammapy.utils.time import time_ref_from_dict, time_ref_to_dict |
||
15 | from .utils import INVALID_INDEX, edges_from_lo_hi |
||
16 | |||
# Public names exported by ``from gammapy.maps.axes import *``.
__all__ = [
    "MapAxes",
    "MapAxis",
    "TimeMapAxis",
    "LabelMapAxis",
]
||
18 | |||
19 | |||
def flat_if_equal(array):
    """Collapse a 2D array to its first row if all rows are identical.

    Parameters
    ----------
    array : `~numpy.ndarray`
        Input array.

    Returns
    -------
    array : `~numpy.ndarray`
        ``array[0]`` when ``array`` is 2D, has at least one row and every row
        equals the first one; otherwise ``array`` unchanged.
    """
    # Guard the zero-row case: ``array[0]`` would raise IndexError for shape (0, n).
    if array.ndim == 2 and array.shape[0] > 0 and np.all(array == array[0]):
        return array[0]
    return array
||
25 | |||
26 | |||
class AxisCoordInterpolator:
    """Axis coord interpolator.

    Converts between axis coordinates and pixel coordinates (fractional node
    indices) by interpolating in the scale returned by
    ``interpolation_scale(interp)``.

    Parameters
    ----------
    edges : `~numpy.ndarray`
        Node values of the axis.
    interp : str
        Interpolation scale name passed to ``interpolation_scale``.
        Default is "lin".
    """

    def __init__(self, edges, interp="lin"):
        self.scale = interpolation_scale(interp)
        # Nodes expressed in the interpolation scale, and their index positions.
        self.x = self.scale(edges)
        self.y = np.arange(len(edges), dtype=float)
        self.fill_value = "extrapolate"

        # A single node cannot define a first-order segment, so fall back to a
        # zero-order (piecewise constant) interpolator in that case.
        self.kind = 0 if len(edges) == 1 else 1

    def coord_to_pix(self, coord):
        """Coord to pix: convert axis coordinates to pixel coordinates.

        Values outside the node range are extrapolated.
        """
        interp_fn = scipy.interpolate.interp1d(
            x=self.x, y=self.y, kind=self.kind, fill_value=self.fill_value
        )
        return interp_fn(self.scale(coord))

    def pix_to_coord(self, pix):
        """Pix to coord: convert pixel coordinates to axis coordinates.

        Values outside the pixel range are extrapolated.
        """
        interp_fn = scipy.interpolate.interp1d(
            x=self.y, y=self.x, kind=self.kind, fill_value=self.fill_value
        )
        return self.scale.inverse(interp_fn(pix))
||
54 | |||
55 | |||
# Human-readable quantity names used to label plot axes, keyed by axis name.
PLOT_AXIS_LABEL = dict(
    energy="Energy",
    energy_true="True Energy",
    offset="FoV Offset",
    rad="Source Offset",
    migra="Energy / True Energy",
    fov_lon="FoV Lon.",
    fov_lat="FoV Lat.",
    time="Time",
)

# Template combining a quantity name and its unit into a plot axis label.
DEFAULT_LABEL_TEMPLATE = "{quantity} [{unit}]"
||
68 | |||
69 | |||
class MapAxis:
    """Class representing an axis of a map.

    Provides methods for transforming to/from axis and pixel coordinates.
    An axis is defined by a sequence of node values that lie at the center
    of each bin. The pixel coordinate at each node is equal to its index in
    the node array (0, 1, ..). Bin edges are offset by 0.5 in pixel
    coordinates from the nodes, such that the lower/upper edge of the first
    bin is (-0.5, 0.5).

    Parameters
    ----------
    nodes : `~numpy.ndarray` or `~astropy.units.Quantity`
        Array of node values. These will be interpreted as either bin
        edges or centers according to ``node_type``. Values must be unique
        and sorted (ascending or descending).
    interp : str
        Interpolation method used to transform between axis and pixel
        coordinates. Valid options are 'log', 'lin', and 'sqrt'. Scales
        other than 'lin' require non-negative node values. Default is 'lin'.
    name : str
        Axis name. Default is "".
    node_type : str
        Flag indicating whether coordinate nodes correspond to pixel
        edges (node_type = 'edges') or pixel centers (node_type =
        'center'). 'center' should be used where the map values are
        defined at a specific coordinate (e.g. differential
        quantities). 'edges' should be used where map values are
        defined by an integral over coordinate intervals (e.g. a
        counts histogram). Default is 'edges'.
    unit : str
        String specifying the data units. Ignored when ``nodes`` is a
        `~astropy.units.Quantity`, whose own unit is used instead.
        Default is "".
    """
||
102 | |||
# TODO: Cache an interpolation object?
def __init__(self, nodes, interp="lin", name="", node_type="edges", unit=""):
    """Create the axis; see the class docstring for the parameters.

    Raises
    ------
    TypeError
        If ``name`` is not a string.
    ValueError
        If the nodes are not unique or not sorted, if a non-'lin' ``interp``
        is used with negative nodes, or if ``node_type`` is invalid.
    """
    if not isinstance(name, str):
        raise TypeError(f"Name must be a string, got: {type(name)!r}")

    if len(nodes) != len(np.unique(nodes)):
        raise ValueError("MapAxis: node values must be unique")

    # Nodes may be sorted ascending or descending. Use logical ``not`` rather
    # than bitwise ``~``: ``~`` only negates correctly on ``numpy.bool_`` and
    # yields -1/-2 (both truthy) on a plain Python bool.
    sorted_nodes = np.sort(nodes)
    if not (np.all(nodes == sorted_nodes) or np.all(nodes[::-1] == sorted_nodes)):
        raise ValueError("MapAxis: node values must be sorted")

    # A Quantity carries its own unit, which overrides the ``unit`` argument.
    if isinstance(nodes, u.Quantity):
        unit = nodes.unit if nodes.unit is not None else ""
        nodes = nodes.value
    else:
        nodes = np.array(nodes)

    self._name = name
    self._unit = u.Unit(unit)
    self._nodes = nodes.astype(float)
    self._node_type = node_type
    self._interp = interp

    if (self._nodes < 0).any() and interp != "lin":
        raise ValueError(
            f"Interpolation scaling {interp!r} only support for positive node values."
        )

    # Set pixel coordinate of first node
    if node_type == "edges":
        self._pix_offset = -0.5
        nbin = len(nodes) - 1
    elif node_type == "center":
        self._pix_offset = 0.0
        nbin = len(nodes)
    else:
        raise ValueError(f"Invalid node type: {node_type!r}")

    self._nbin = nbin
    # None means "use the node_type default" (see use_center_as_plot_labels).
    self._use_center_as_plot_labels = None
||
144 | |||
def assert_name(self, required_name):
    """Raise if the axis name differs from a required name.

    Parameters
    ----------
    required_name : str
        Name the axis is required to have.

    Raises
    ------
    ValueError
        If ``self.name`` does not equal ``required_name``.
    """
    if self.name == required_name:
        return
    raise ValueError(
        "Unexpected axis name,"
        f' expected "{required_name}", got: "{self.name}"'
    )
||
158 | |||
def is_aligned(self, other, atol=2e-2):
    """Check whether another map axis is aligned with this one.

    Two axes are aligned when the bin centers of each axis fall on integer
    pixel positions of the other axis, and both use the same interpolation
    mode.

    Parameters
    ----------
    other : `MapAxis`
        Other map axis.
    atol : float
        Absolute tolerance of the integer check, measured in bins.

    Returns
    -------
    aligned : bool
        Whether the axes are aligned.
    """
    # Pixel positions of each axis' centers, expressed on the other axis.
    pix_all = np.append(
        self.coord_to_pix(other.center), other.coord_to_pix(self.center)
    )
    residual = np.round(pix_all) - pix_all
    return np.allclose(residual, 0, atol=atol) and self.interp == other.interp
||
182 | |||
def is_allclose(self, other, **kwargs):
    """Check if other map axis is all close.

    Parameters
    ----------
    other : `MapAxis`
        Other map axis
    **kwargs : dict
        Keyword arguments forwarded to `~numpy.allclose`

    Returns
    -------
    is_allclose : bool
        Whether other axis is allclose: same number of edges, equivalent
        units, close edges, and equal node type, interp and
        (case-insensitive) name.

    Raises
    ------
    TypeError
        If ``other`` is not an instance of this axis class.
    """
    if not isinstance(other, self.__class__):
        # The exception used to be *returned*; an exception instance is truthy,
        # so ``if axis.is_allclose(x)`` silently passed for mismatched types.
        raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

    if self.edges.shape != other.edges.shape:
        return False
    if not self.unit.is_equivalent(other.unit):
        return False
    return (
        np.allclose(self.edges, other.edges, **kwargs)
        and self._node_type == other._node_type
        and self._interp == other._interp
        and self.name.upper() == other.name.upper()
    )
||
211 | |||
def __eq__(self, other):
    """Axes are equal when ``other`` has this class and is all-close.

    Edges are compared with ``rtol=1e-6`` and ``atol=1e-6``.
    """
    return isinstance(other, self.__class__) and self.is_allclose(
        other, rtol=1e-6, atol=1e-6
    )
||
217 | |||
def __ne__(self, other):
    """Logical negation of `__eq__`."""
    is_equal = self.__eq__(other)
    return not is_equal
||
220 | |||
def __hash__(self):
    """Return an identity-based hash of the axis."""
    # NOTE(review): ``__eq__`` compares axes by value (all-close edges), so two
    # equal but distinct axes hash differently here. Python requires equal
    # objects to have equal hashes. Confirm whether callers rely on identity
    # hashing (e.g. axes as dict keys) before changing this.
    return id(self)
||
223 | |||
@lazyproperty
def _transform(self):
    """Coordinate/pixel interpolator built from the raw node values."""
    return AxisCoordInterpolator(interp=self.interp, edges=self._nodes)
||
228 | |||
@property
def is_energy_axis(self):
    """Whether the axis is named "energy" or "energy_true"."""
    return self.name in ("energy", "energy_true")
||
232 | |||
@property
def interp(self):
    """Interpolation scale name of the axis (e.g. 'lin', 'log', 'sqrt')."""
    scale_name = self._interp
    return scale_name
||
237 | |||
@property
def name(self):
    """Axis name as given at construction."""
    axis_name = self._name
    return axis_name
||
242 | |||
@lazyproperty
def edges(self):
    """Bin edges as a `~astropy.units.Quantity` with ``nbin + 1`` entries."""
    # Edges sit half a pixel below each bin index, plus one past the last bin.
    edge_pix = np.arange(self.nbin + 1, dtype=float) - 0.5
    edge_coords = self.pix_to_coord(edge_pix)
    return u.Quantity(edge_coords, self._unit, copy=False)
||
248 | |||
@property
def edges_min(self):
    """Return array of the lower bin edges (all edges except the last)."""
    return self.edges[:-1]
||
253 | |||
@property
def edges_max(self):
    """Return array of the upper bin edges (all edges except the first)."""
    return self.edges[1:]
||
258 | |||
@property
def bounds(self):
    """First and last node of the axis (`~astropy.units.Quantity`).

    Uses the bin edges for ``node_type == "edges"`` and the bin centers
    otherwise.
    """
    values = self.edges if self.node_type == "edges" else self.center
    return values[[0, -1]]
||
267 | |||
@property
def as_plot_xerr(self):
    """Asymmetric errors ``(lower, upper)`` suitable for ``plt.errorbar(xerr=...)``.

    ``lower`` is the distance from each bin center to its lower edge and
    ``upper`` the distance to its upper edge.
    """
    lower = self.center - self.edges_min
    upper = self.edges_max - self.center
    return lower, upper
||
275 | |||
@property
def use_center_as_plot_labels(self):
    """Whether plot labels show bin centers rather than edge ranges.

    Unless set explicitly, this is True exactly for center-type axes.
    """
    override = self._use_center_as_plot_labels
    if override is None:
        return self.node_type == "center"
    return override
||
283 | |||
@use_center_as_plot_labels.setter
def use_center_as_plot_labels(self, value):
    """Explicitly choose center-based plot labels (``value`` is cast to bool)."""
    flag = bool(value)
    self._use_center_as_plot_labels = flag
||
288 | |||
@property
def as_plot_labels(self):
    """List of per-bin plot labels formatted with ``.2e``.

    Each label is either the bin center or a ``"min - max"`` edge range,
    depending on `use_center_as_plot_labels`.
    """
    if not self.use_center_as_plot_labels:
        return [f"{lo:.2e} - {hi:.2e}" for lo, hi in self.iter_by_edges]
    return [f"{value:.2e}" for value in self.center]
||
300 | |||
@property
def as_plot_edges(self):
    """Edges to use when plotting; identical to `edges` for this class."""
    plot_edges = self.edges
    return plot_edges
||
305 | |||
@property
def as_plot_center(self):
    """Centers to use when plotting; identical to `center` for this class."""
    plot_center = self.center
    return plot_center
||
310 | |||
@property
def as_plot_scale(self):
    """Matplotlib axis scale matching the interpolation mode.

    'log' maps to "log"; 'lin' and 'sqrt' map to "linear". Any other
    ``interp`` raises KeyError.
    """
    scale_by_interp = dict(lin="linear", sqrt="linear", log="log")
    return scale_by_interp[self.interp]
||
317 | |||
def to_node_type(self, node_type):
    """Return the axis expressed with a different node type.

    Parameters
    ----------
    node_type : {'edges', 'center'}
        Target node type.

    Returns
    -------
    axis : `~gammapy.maps.MapAxis`
        ``self`` if the node type already matches, otherwise a new axis
        built from the current centers (for 'center') or edges.
    """
    if node_type == self.node_type:
        return self

    nodes = self.center if node_type == "center" else self.edges
    return self.__class__(
        nodes=nodes,
        interp=self.interp,
        name=self.name,
        node_type=node_type,
        unit=self.unit,
    )
||
345 | |||
def rename(self, new_name):
    """Return a copy of the axis carrying a different name.

    Parameters
    ----------
    new_name : str
        Name of the returned axis.

    Returns
    -------
    axis : `~gammapy.maps.MapAxis`
        Copy of this axis with ``name=new_name``.
    """
    renamed = self.copy(name=new_name)
    return renamed
||
360 | |||
def format_plot_xaxis(self, ax):
    """Apply this axis' scale, label and limits to the x axis of ``ax``.

    Parameters
    ----------
    ax : `~matplotlib.pyplot.Axis`
        Plot axis to format.

    Returns
    -------
    ax : `~matplotlib.pyplot.Axis`
        The same, formatted plot axis.
    """
    quantity = PLOT_AXIS_LABEL.get(self.name, self.name.capitalize())
    ax.set_xscale(self.as_plot_scale)
    ax.set_xlabel(
        DEFAULT_LABEL_TEMPLATE.format(quantity=quantity, unit=ax.xaxis.units)
    )

    # Only set limits when they span a non-empty range.
    xmin, xmax = self.bounds
    if xmin != xmax:
        ax.set_xlim(self.bounds)
    return ax
||
385 | |||
def format_plot_yaxis(self, ax):
    """Format plot axis

    Parameters
    ----------
    ax : `~matplotlib.pyplot.Axis`
        Plot axis to format

    Returns
    -------
    ax : `~matplotlib.pyplot.Axis`
        Formatted plot axis
    """
    ax.set_yscale(self.as_plot_scale)

    ylabel = DEFAULT_LABEL_TEMPLATE.format(
        quantity=PLOT_AXIS_LABEL.get(self.name, self.name.capitalize()),
        unit=ax.yaxis.units,
    )
    ax.set_ylabel(ylabel)

    # Mirror format_plot_xaxis: the bounds coincide for e.g. a single-node
    # center axis, and matplotlib warns about identical limits, so only set
    # them when they differ.
    ymin, ymax = self.bounds
    if not ymin == ymax:
        ax.set_ylim(self.bounds)
    return ax
||
408 | |||
@property
def iter_by_edges(self):
    """Generator over ``(edge_min, edge_max)`` pairs, one per bin."""
    edges = self.edges
    yield from zip(edges[:-1], edges[1:])
||
414 | |||
@lazyproperty
def center(self):
    """Bin centers as a `~astropy.units.Quantity` with ``nbin`` entries.

    Centers are the axis coordinates of the integer pixel positions.
    """
    center_pix = np.arange(self.nbin, dtype=float)
    center_coords = self.pix_to_coord(center_pix)
    return u.Quantity(center_coords, self._unit, copy=False)
||
420 | |||
@lazyproperty
def bin_width(self):
    """Width of every bin: the difference between consecutive edges."""
    edge_values = self.edges
    return np.diff(edge_values)
||
425 | |||
@property
def nbin(self):
    """Number of bins (edges minus one, or number of centers)."""
    count = self._nbin
    return count
||
430 | |||
@property
def nbin_per_decade(self):
    """Return number of bins per decade.

    The decade span is taken from the edges for edge-type axes and from the
    centers otherwise.

    Raises
    ------
    ValueError
        If the axis is not log-spaced (``interp != "log"``).
    """
    if self.interp != "log":
        raise ValueError("Bins per decade can only be computed for log-spaced axes")

    if self.node_type == "edges":
        values = self.edges
    else:
        values = self.center

    ndecades = np.log10(values.max() / values.min())
    return (self._nbin / ndecades).value
||
444 | |||
@property
def node_type(self):
    """Whether nodes are bin 'edges' or bin 'center' values."""
    kind = self._node_type
    return kind
||
449 | |||
@property
def unit(self):
    """Unit of the axis coordinates."""
    axis_unit = self._unit
    return axis_unit
||
454 | |||
@classmethod
def from_bounds(cls, lo_bnd, hi_bnd, nbin, **kwargs):
    """Generate an axis object from a lower/upper bound and number of bins.

    For ``node_type='edges'`` the bounds are the outer edges of the first and
    last bin; for ``node_type='center'`` they are the centers of the first
    and last bin. Nodes are spaced evenly in the scale given by ``interp``.

    Parameters
    ----------
    lo_bnd : float
        Lower bound of first axis bin.
    hi_bnd : float
        Upper bound of last axis bin.
    nbin : int
        Number of bins (converted with ``int``).
    **kwargs : dict
        Forwarded to the constructor. ``interp`` ({'lin', 'log', 'sqrt'},
        default 'lin') and ``node_type`` ({'edges', 'center'}, default
        'edges') are also used here to build the nodes.

    Raises
    ------
    ValueError
        For an unknown ``node_type`` or ``interp``.
    """
    nbin = int(nbin)
    interp = kwargs.setdefault("interp", "lin")
    node_type = kwargs.setdefault("node_type", "edges")

    if node_type not in ("edges", "center"):
        raise ValueError(f"Invalid node type: {node_type!r}")
    # Edge-type axes need one more node than bins.
    nnode = nbin + 1 if node_type == "edges" else nbin

    if interp == "lin":
        nodes = np.linspace(lo_bnd, hi_bnd, nnode)
    elif interp == "log":
        log_nodes = np.linspace(np.log(lo_bnd), np.log(hi_bnd), nnode)
        nodes = np.exp(log_nodes)
    elif interp == "sqrt":
        sqrt_nodes = np.linspace(lo_bnd**0.5, hi_bnd**0.5, nnode)
        nodes = sqrt_nodes**2.0
    else:
        raise ValueError(f"Invalid interp: {interp}")

    return cls(nodes, **kwargs)
||
497 | |||
@classmethod
def from_energy_edges(cls, energy_edges, unit=None, name=None, interp="log"):
    """Make an energy axis from adjacent edges.

    Parameters
    ----------
    energy_edges : `~astropy.units.Quantity`, float
        Energy edges.
    unit : `~astropy.units.Unit`
        Energy unit.
    name : str
        Either 'energy' (used when None) or 'energy_true'.
    interp : str
        Interpolation mode. Default is 'log'.

    Returns
    -------
    axis : `MapAxis`
        Edge-type energy axis.

    Raises
    ------
    ValueError
        If the unit is not an energy unit or the name is not allowed.
    """
    energy_edges = u.Quantity(energy_edges, unit)
    if not energy_edges.unit.is_equivalent("TeV"):
        raise ValueError(
            f"Please provide a valid energy unit, got {energy_edges.unit} instead."
        )

    name = "energy" if name is None else name
    if name not in ("energy", "energy_true"):
        raise ValueError("Energy axis can only be named 'energy' or 'energy_true'")

    return cls.from_edges(energy_edges, unit=unit, interp=interp, name=name)
||
532 | |||
@classmethod
def from_energy_bounds(
    cls,
    energy_min,
    energy_max,
    nbin,
    unit=None,
    per_decade=False,
    name=None,
    node_type="edges",
):
    """Make a log-spaced energy axis between two bounds.

    Often used to build energy grids via ``axis.center`` or ``axis.edges``.

    Parameters
    ----------
    energy_min, energy_max : `~astropy.units.Quantity`, float
        Energy range.
    nbin : int
        Number of bins, or bins per decade if ``per_decade`` is True.
    unit : `~astropy.units.Unit`
        Energy unit. If None, the unit of ``energy_max`` is used.
    per_decade : bool
        Whether ``nbin`` is given per decade (total is rounded up).
    name : str
        Either 'energy' (used when None) or 'energy_true'.
    node_type : str
        Node type of the axis. Default is 'edges'.

    Returns
    -------
    axis : `MapAxis`
        Energy axis with interp "log".
    """
    energy_min = u.Quantity(energy_min, unit)
    energy_max = u.Quantity(energy_max, unit)

    if unit is None:
        # Adopt the unit of energy_max and express energy_min in it.
        unit = energy_max.unit
        energy_min = energy_min.to(unit)

    if not energy_max.unit.is_equivalent("TeV"):
        raise ValueError(
            f"Please provide a valid energy unit, got {energy_max.unit} instead."
        )

    if per_decade:
        ndecades = np.log10(energy_max / energy_min).value
        nbin = np.ceil(ndecades * nbin)

    name = "energy" if name is None else name
    if name not in ("energy", "energy_true"):
        raise ValueError("Energy axis can only be named 'energy' or 'energy_true'")

    return cls.from_bounds(
        energy_min.value,
        energy_max.value,
        nbin=nbin,
        unit=unit,
        interp="log",
        name=name,
        node_type=node_type,
    )
||
597 | |||
@classmethod
def from_nodes(cls, nodes, **kwargs):
    """Generate an axis whose nodes are the bin centers.

    Bin edges lie half-way (in pixel space) between the node values. Use
    this when bin centers must sit at specific values, e.g. for a map of a
    continuous function.

    Parameters
    ----------
    nodes : `~numpy.ndarray`
        Axis nodes (bin centers).
    **kwargs : dict
        Forwarded to the constructor, e.g. ``interp`` ({'lin', 'log',
        'sqrt'}, default 'lin').

    Raises
    ------
    ValueError
        If ``nodes`` is empty.
    """
    if not len(nodes):
        raise ValueError("Nodes array must have at least one element.")
    return cls(nodes, node_type="center", **kwargs)
||
619 | |||
@classmethod
def from_edges(cls, edges, **kwargs):
    """Generate an axis from a sequence of bin edges.

    Use this when bin edges must sit at specific values (e.g. a histogram).
    The axis has one bin fewer than the number of edges.

    Parameters
    ----------
    edges : `~numpy.ndarray`
        Axis bin edges.
    **kwargs : dict
        Forwarded to the constructor, e.g. ``interp`` ({'lin', 'log',
        'sqrt'}, default 'lin').

    Raises
    ------
    ValueError
        If fewer than two edges are given.
    """
    if len(edges) <= 1:
        raise ValueError("Edges array must have at least two elements.")
    return cls(edges, node_type="edges", **kwargs)
||
640 | |||
def append(self, axis):
    """Return a new axis made of this axis followed by ``axis``.

    Name, interp and node type must agree. For edge-type axes the edges must
    be contiguous and non-overlapping.

    Parameters
    ----------
    axis : `MapAxis`
        Axis to append.

    Returns
    -------
    axis : `MapAxis`
        Combined axis.

    Raises
    ------
    ValueError
        If node type, name or interp differ.
    """
    if self.node_type != axis.node_type:
        raise ValueError(
            f"Node type must agree, got {self.node_type} and {axis.node_type}"
        )
    if self.name != axis.name:
        raise ValueError(f"Names must agree, got {self.name} and {axis.name} ")
    if self.interp != axis.interp:
        raise ValueError(
            f"Interp type must agree, got {self.interp} and {axis.interp}"
        )

    if self.node_type != "edges":
        nodes = np.append(self.center, axis.center)
        return self.from_nodes(nodes=nodes, interp=self.interp, name=self.name)

    # Skip the first edge of ``axis``: contiguity means it equals our last edge.
    edges = np.append(self.edges, axis.edges[1:])
    return self.from_edges(edges=edges, interp=self.interp, name=self.name)
||
676 | |||
def pad(self, pad_width):
    """Extend the axis by extra bins, extrapolated in the axis scale.

    Parameters
    ----------
    pad_width : int or tuple of int
        An int pads both ends by that many bins; a ``(low, high)`` tuple pads
        the low and high ends separately.

    Returns
    -------
    axis : `MapAxis`
        Padded axis.
    """
    pad_low, pad_high = (
        pad_width if isinstance(pad_width, tuple) else (pad_width, pad_width)
    )

    if self.node_type != "edges":
        center_pix = np.arange(-pad_low, self.nbin + pad_high)
        nodes = self.pix_to_coord(center_pix)
        return self.from_nodes(nodes=nodes, interp=self.interp, name=self.name)

    edge_pix = np.arange(-pad_low, self.nbin + pad_high + 1) - 0.5
    edges = self.pix_to_coord(edge_pix)
    return self.from_edges(edges=edges, interp=self.interp, name=self.name)
||
704 | |||
@classmethod
def from_stack(cls, axes):
    """Merge a list of map axes into one by successive `append` calls.

    For edge-type axes the bin edges must be contiguous and non-overlapping.

    Parameters
    ----------
    axes : list of `MapAxis`
        Map axes to merge, in order.

    Returns
    -------
    axis : `MapAxis`
        Merged axis.
    """
    merged = axes[0]
    for other in axes[1:]:
        merged = merged.append(other)
    return merged
||
728 | |||
def pix_to_coord(self, pix):
    """Convert pixel coordinates to axis coordinates.

    Parameters
    ----------
    pix : `~numpy.ndarray`
        Pixel coordinate values.

    Returns
    -------
    coord : `~astropy.units.Quantity`
        Axis coordinate values in the axis unit.
    """
    # Shift so that the first node sits at interpolator index 0.
    node_pix = pix - self._pix_offset
    coord_values = self._transform.pix_to_coord(pix=node_pix)
    return u.Quantity(coord_values, unit=self.unit, copy=False)
||
745 | |||
def pix_to_idx(self, pix, clip=False):
    """Convert pixel coordinates to bin indices.

    Parameters
    ----------
    pix : `~numpy.ndarray`
        Pixel coordinates.
    clip : bool
        If True, clip to the valid range ``[0, nbin - 1]``. Otherwise values
        outside the axis range are replaced by -1.

    Returns
    -------
    idx : `~numpy.ndarray`
        Pixel indices.
    """
    if not clip:
        outside = (pix < 0) | (pix >= self.nbin)
        return np.where(outside, -1, pix)
    return np.clip(pix, 0, self.nbin - 1)
||
770 | |||
def coord_to_pix(self, coord):
    """Convert axis coordinates to pixel coordinates.

    Parameters
    ----------
    coord : `~numpy.ndarray` or `~astropy.units.Quantity`
        Axis coordinate values (plain numbers are taken in the axis unit).

    Returns
    -------
    pix : `~numpy.ndarray`
        Pixel coordinate values, at least 1D.
    """
    coord_values = u.Quantity(coord, self.unit, copy=False).value
    pix = self._transform.coord_to_pix(coord=coord_values) + self._pix_offset
    return np.array(pix, ndmin=1)
||
787 | |||
def coord_to_idx(self, coord, clip=False):
    """Convert axis coordinates to bin indices.

    Parameters
    ----------
    coord : `~numpy.ndarray` or `~astropy.units.Quantity`
        Axis coordinate values.
    clip : bool
        If True, clip indices to ``[0, nbin - 1]``. Otherwise values above
        the last edge get ``INVALID_INDEX.int``.

    Returns
    -------
    idx : `~numpy.ndarray`
        Bin indices; non-finite coordinates always get ``INVALID_INDEX.int``.
    """
    values = u.Quantity(coord, self.unit, copy=False, ndmin=1).value
    edge_values = self.edges.value
    # NOTE(review): np.digitize (right=False) returns len(edges) for a value
    # equal to the last edge, so such a coordinate maps to index ``nbin``
    # when clip=False. Confirm this is intended.
    idx = np.digitize(values, edge_values) - 1

    if not clip:
        with np.errstate(invalid="ignore"):
            idx[values > edge_values[-1]] = INVALID_INDEX.int
    else:
        idx = np.clip(idx, 0, self.nbin - 1)

    idx[~np.isfinite(values)] = INVALID_INDEX.int
    return idx
||
818 | |||
def slice(self, idx):
    """Return a new axis made of a subset of this axis' bins.

    Parameters
    ----------
    idx : slice
        Selection applied to the bins.

    Returns
    -------
    axis : `~MapAxis`
        Sliced axis object.
    """
    # Map the selected bin centers back to their bin indices.
    bin_idx = self.coord_to_idx(self.center[idx].value)
    if self._node_type == "edges":
        # N bins need N + 1 edge nodes: also keep the upper edge of the last bin.
        bin_idx = tuple(list(bin_idx) + [1 + bin_idx[-1]])

    nodes = self._nodes[(bin_idx,)]
    return MapAxis(
        nodes,
        interp=self._interp,
        name=self._name,
        node_type=self._node_type,
        unit=self._unit,
    )
||
846 | |||
def squash(self):
    """Return a single-bin, edge-type axis spanning this axis' outer edges.

    Returns
    -------
    axis : `~MapAxis`
        Squashed axis object.
    """
    # TODO: Decide on handling node_type=center
    # See https://github.com/gammapy/gammapy/issues/1952
    lo_bnd = self.edges[0].value
    hi_bnd = self.edges[-1].value
    return MapAxis.from_bounds(
        lo_bnd=lo_bnd,
        hi_bnd=hi_bnd,
        nbin=1,
        interp=self._interp,
        name=self._name,
        unit=self._unit,
    )
||
865 | |||
def __repr__(self):
    """Multi-line summary: name, unit, bin count, node type, range, interp."""
    fmt = "\t{:<10s} : {:<10s}\n"
    vals = self.edges if self.node_type == "edges" else self.center
    rows = [
        ("name", self.name),
        ("unit", "{!r}".format(str(self.unit))),
        ("nbins", str(self.nbin)),
        ("node type", self.node_type),
        (f"{self.node_type} min", "{:.1e}".format(vals.min())),
        (f"{self.node_type} max", "{:.1e}".format(vals.max())),
        ("interp", self._interp),
    ]
    body = "".join(fmt.format(key, value) for key, value in rows)
    return self.__class__.__name__ + "\n\n" + body
||
879 | |||
def _init_copy(self, **kwargs):
    """Build a new instance, filling missing ``__init__`` args from ``self``.

    Every constructor argument ``arg`` not given in ``kwargs`` is taken as
    a deep copy of the attribute ``self._<arg>``.
    """
    init_args = inspect.getfullargspec(self.__init__).args
    init_args.remove("self")

    for arg_name in init_args:
        current = copy.deepcopy(getattr(self, f"_{arg_name}"))
        kwargs.setdefault(arg_name, current)

    return self.__class__(**kwargs)
||
890 | |||
def copy(self, **kwargs):
    """Return a copy of the axis, overriding the given constructor arguments.

    Parameters
    ----------
    **kwargs : dict
        Constructor arguments to override; all others are copied from self.

    Returns
    -------
    copy : `MapAxis`
        Copied map axis.
    """
    duplicate = self._init_copy(**kwargs)
    return duplicate
||
905 | |||
def round(self, coord, clip=False):
    """Snap coordinates to the nearest bin edge.

    Parameters
    ----------
    coord : `~astropy.units.Quantity`
        Coordinates.
    clip : bool
        If True, clip to the axis range ``[-0.5, nbin - 0.5]`` in pixels
        before rounding.

    Returns
    -------
    coord : `~astropy.units.Quantity`
        Rounded coordinates.
    """
    pix = self.coord_to_pix(coord)
    if clip:
        pix = np.clip(pix, -0.5, self.nbin - 0.5)
    # Edges sit at half-integer pixels: shift, round to integer, shift back.
    edge_pix = np.round(pix + 0.5) - 0.5
    return self.pix_to_coord(edge_pix)
||
928 | |||
def group_table(self, edges):
    """Compute bin groups table for the map axis, given coarser bin edges.

    Each target edge is snapped to the nearest edge of this axis (after
    clipping to the axis range). Consecutive snapped edges define the
    "normal " groups. Bins of this axis below the first or above the last
    snapped edge are collected into an extra "underflow" or "overflow" group.

    Parameters
    ----------
    edges : `~astropy.units.Quantity`
        Group bin edges.

    Returns
    -------
    groups : `~astropy.table.Table`
        Map axis group table. Columns are ``group_idx``, ``{name}_min``,
        ``{name}_max``, ``idx_min``, ``idx_max`` (inclusive bin indices of
        this axis) and ``bin_type``.

    Raises
    ------
    ValueError
        If the axis is not edge based, or no normal group can be formed.
    """
    # TODO: try to simplify this code
    if self.node_type != "edges":
        raise ValueError("Only edge based map axis can be grouped")

    # pixel edges sit at half-integer positions -0.5, 0.5, ..., nbin - 0.5
    edges_pix = self.coord_to_pix(edges)
    edges_pix = np.clip(edges_pix, -0.5, self.nbin - 0.5)
    edges_idx = np.round(edges_pix + 0.5) - 0.5
    edges_idx = np.unique(edges_idx)
    edges_ref = self.pix_to_coord(edges_idx)

    groups = Table()
    groups[f"{self.name}_min"] = edges_ref[:-1]
    groups[f"{self.name}_max"] = edges_ref[1:]

    # a group covers the bins between two consecutive pixel edges (inclusive)
    groups["idx_min"] = (edges_idx[:-1] + 0.5).astype(int)
    groups["idx_max"] = (edges_idx[1:] - 0.5).astype(int)

    if len(groups) == 0:
        raise ValueError("No overlap between reference and target edges.")

    # NOTE: the value carries a trailing space; MapAxes.resample filters on
    # this exact string, so it must not be changed
    groups["bin_type"] = "normal "

    edge_idx_start, edge_ref_start = edges_idx[0], edges_ref[0]
    if edge_idx_start > 0:
        # bins 0 .. (first snapped edge - 0.5) are not covered by any group
        underflow = {
            "bin_type": "underflow",
            "idx_min": 0,
            "idx_max": int(edge_idx_start - 0.5),
            f"{self.name}_min": self.pix_to_coord(-0.5),
            f"{self.name}_max": edge_ref_start,
        }
        groups.insert_row(0, vals=underflow)

    edge_idx_end, edge_ref_end = edges_idx[-1], edges_ref[-1]

    if edge_idx_end < (self.nbin - 0.5):
        # bins (last snapped edge + 0.5) .. nbin - 1 are not covered either
        overflow = {
            "bin_type": "overflow",
            "idx_min": int(edge_idx_end + 0.5),
            "idx_max": self.nbin - 1,
            f"{self.name}_min": edge_ref_end,
            f"{self.name}_max": self.pix_to_coord(self.nbin - 0.5),
        }
        groups.add_row(vals=overflow)

    group_idx = Column(np.arange(len(groups)))
    groups.add_column(group_idx, name="group_idx", index=0)
    return groups
990 | |||
def upsample(self, factor):
    """Upsample map axis by a given factor.

    Each existing node is preserved and ``factor - 1`` equally spaced (in
    pixel space) sub-nodes are inserted between neighbours. Node type
    "edges" gives ``nbin * factor`` new bins. Node type "center" gives
    ``(nbin - 1) * factor + 1`` new bins.

    Parameters
    ----------
    factor : int
        Upsampling factor.

    Returns
    -------
    axis : `MapAxis`
        Upsampled map axis.
    """
    if self.node_type == "edges":
        nodes, n_nodes, build = self.edges, int(self.nbin * factor) + 1, self.from_edges
    else:
        nodes, n_nodes, build = (
            self.center,
            int((self.nbin - 1) * factor) + 1,
            self.from_nodes,
        )

    pix = self.coord_to_pix(nodes)
    pix_fine = np.linspace(pix.min(), pix.max(), n_nodes)
    return build(self.pix_to_coord(pix_fine), name=self.name, interp=self.interp)
1022 | |||
def downsample(self, factor):
    """Downsample map axis by a given factor.

    Every ``factor``-th node is kept, which preserves the axis limits. Node
    type "edges" requires ``nbin`` to be divisible by ``factor``. Node type
    "center" requires ``nbin - 1`` to be divisible by ``factor``.

    Parameters
    ----------
    factor : int
        Downsampling factor.

    Returns
    -------
    axis : `MapAxis`
        Downsampled map axis.

    Raises
    ------
    ValueError
        If the number of bins is not compatible with ``factor``.
    """
    if self.node_type != "edges":
        if np.mod((self.nbin - 1) / factor, 1) > 0:
            raise ValueError(
                f"Number of {self.name} bins - 1 is not divisible by {factor}"
            )
        kept_nodes = self.center[::factor]
        return self.from_nodes(kept_nodes, name=self.name, interp=self.interp)

    if np.mod(self.nbin / factor, 1) > 0:
        raise ValueError(f"Number of {self.name} bins is not divisible by {factor}")
    kept_edges = self.edges[::factor]
    return self.from_edges(kept_edges, name=self.name, interp=self.interp)
1062 | |||
def to_header(self, format="ogip", idx=0):
    """Create a FITS header describing the axis.

    Parameters
    ----------
    format : {"ogip", "ogip-sherpa", "gadf", "fgst-ccube", "fgst-template"}
        Format specification
    idx : int
        Column index of the axis, used as suffix of the ``AXCOLS`` and
        ``INTERP`` keywords for the non-OGIP formats.

    Returns
    -------
    header : `~astropy.io.fits.Header`
        Header to extend.

    Raises
    ------
    ValueError
        If the format is unknown or the node type is invalid.
    """
    header = fits.Header()

    if format in ["ogip", "ogip-sherpa"]:
        cards = [
            ("EXTNAME", "EBOUNDS", "Name of this binary table extension"),
            ("TELESCOP", "DUMMY", "Mission/satellite name"),
            ("INSTRUME", "DUMMY", "Instrument/detector"),
            ("FILTER", "None", "Filter information"),
            ("CHANTYPE", "PHA", "Type of channels (PHA, PI etc)"),
            ("DETCHANS", self.nbin, "Total number of detector PHA channels"),
            ("HDUCLASS", "OGIP", "Organisation devising file format"),
            ("HDUCLAS1", "RESPONSE", "File relates to response of instrument"),
            ("HDUCLAS2", "EBOUNDS", "This is an EBOUNDS extension"),
            ("HDUVERS", "1.2.0", "Version of file format"),
        ]
        for keyword, value, comment in cards:
            header[keyword] = value, comment
    elif format in ["gadf", "fgst-ccube", "fgst-template"]:
        upper_name = self.name.upper()
        is_energy = self.name == "energy"

        if self.node_type == "edges":
            axcols = "E_MIN,E_MAX" if is_energy else f"{upper_name}_MIN,{upper_name}_MAX"
        elif self.node_type == "center":
            axcols = "ENERGY" if is_energy else upper_name
        else:
            raise ValueError(f"Invalid node type {self.node_type!r}")

        header[f"AXCOLS{idx}"] = axcols
        header[f"INTERP{idx}"] = self.interp
    else:
        raise ValueError(f"Unknown format {format}")

    return header
1113 | |||
def to_table(self, format="ogip"):
    """Convert the map axis to a table.

    See https://heasarc.gsfc.nasa.gov/docs/heasarc/caldb/docs/memos/cal_gen_92_002/cal_gen_92_002.html#tth_sEc3.2 # noqa: E501

    The 'ogip-sherpa' and 'ogip-arf-sherpa' formats are equivalent to 'ogip'
    and 'ogip-arf' but use keV energy units.

    Parameters
    ----------
    format : {"ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa", "gadf-sed", "gadf-dl3", "gtpsf"}
        Format specification

    Returns
    -------
    table : `~astropy.table.Table`
        Table with the axis columns.

    Raises
    ------
    ValueError
        If the format is not valid, or the axis cannot be written in it.
    """
    table = Table()
    edges = self.edges

    if format in ["ogip", "ogip-sherpa"]:
        self.assert_name("energy")

        if format == "ogip-sherpa":
            edges = edges.to("keV")

        table["CHANNEL"] = np.arange(self.nbin, dtype=np.int16)
        table["E_MIN"] = edges[:-1]
        table["E_MAX"] = edges[1:]
    elif format in ["ogip-arf", "ogip-arf-sherpa"]:
        self.assert_name("energy_true")

        if format == "ogip-arf-sherpa":
            edges = edges.to("keV")

        table["ENERG_LO"] = edges[:-1]
        table["ENERG_HI"] = edges[1:]
    elif format == "gadf-sed":
        # NOTE(review): a non-energy axis yields an empty table here
        if self.is_energy_axis:
            table["e_ref"] = self.center
            table["e_min"] = self.edges_min
            table["e_max"] = self.edges_max
    elif format == "gadf-dl3":
        from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

        if self.name == "energy":
            column_prefix = "ENERG"
        else:
            for column_prefix, spec in IRF_DL3_AXES_SPECIFICATION.items():
                if spec["name"] == self.name:
                    break
            else:
                # an unknown axis must not silently reuse the last loop prefix
                raise ValueError(
                    f"No 'gadf-dl3' column specification for axis {self.name!r}"
                )

        if self.node_type == "edges":
            edges_lo, edges_hi = edges[:-1], edges[1:]
        else:
            # node based axes are stored with identical LO and HI columns
            edges_lo, edges_hi = self.center, self.center

        # the axis array is stored as a single table row
        table[f"{column_prefix}_LO"] = edges_lo[np.newaxis]
        table[f"{column_prefix}_HI"] = edges_hi[np.newaxis]
    elif format == "gtpsf":
        if self.name == "energy_true":
            table["Energy"] = self.center.to("MeV")
        elif self.name == "rad":
            table["Theta"] = self.center.to("deg")
        else:
            raise ValueError(
                "Can only convert true energy or rad axis to "
                f"'gtpsf' format, got {self.name}"
            )
    else:
        raise ValueError(f"{format} is not a valid format")

    return table
1187 | |||
def to_table_hdu(self, format="ogip"):
    """Convert the map axis to a FITS binary table HDU.

    The table content comes from ``to_table``. The HDU is named "THETA" for
    the "gtpsf" format. For the OGIP formats, the ``EBOUNDS`` header keywords
    from ``to_header`` are added.

    Parameters
    ----------
    format : {"ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa", "gadf-sed", "gadf-dl3", "gtpsf"}
        Format specification

    Returns
    -------
    hdu : `~astropy.io.fits.BinTableHDU`
        Table HDU
    """
    hdu = fits.BinTableHDU(
        self.to_table(format=format),
        name="THETA" if format == "gtpsf" else None,
    )

    if format in ["ogip", "ogip-sherpa"]:
        hdu.header.update(self.to_header(format=format))

    return hdu
1218 | |||
@classmethod
def from_table(cls, table, format="ogip", idx=0, column_prefix=""):
    """Instantiate MapAxis from a table.

    Parameters
    ----------
    table : `~astropy.table.Table`
        Table
    format : {"ogip", "ogip-arf", "fgst-ccube", "fgst-template", "fgst-bexpcube", "gadf", "gadf-dl3", "gtpsf", "gadf-sed-energy", "gadf-sed-norm", "gadf-sed-counts", "profile"}  # noqa: E501
        Format specification
    idx : int
        Zero-based index of the axis. The "gadf" format reads the
        ``AXCOLS{idx + 1}`` and ``INTERP{idx + 1}`` metadata keywords.
    column_prefix : str
        Column name prefix of the axis (e.g. "ENERG"). Only used by the
        "gadf-dl3" format.

    Returns
    -------
    axis : `MapAxis` or `LabelMapAxis`
        Map Axis

    Raises
    ------
    ValueError
        If the format is not supported, or if the table lacks:
        the energy column ("fgst-template", "fgst-bexpcube"),
        the ``AXCOLS`` keyword ("gadf"), or
        the SED energy columns ("gadf-sed-energy").
    """
    if format in ["ogip", "fgst-ccube"]:
        energy_min = table["E_MIN"].quantity
        energy_max = table["E_MAX"].quantity
        energy_edges = (
            np.append(energy_min.value, energy_max.value[-1]) * energy_min.unit
        )
        axis = cls.from_edges(energy_edges, name="energy", interp="log")

    elif format == "ogip-arf":
        energy_min = table["ENERG_LO"].quantity
        energy_max = table["ENERG_HI"].quantity
        energy_edges = (
            np.append(energy_min.value, energy_max.value[-1]) * energy_min.unit
        )
        axis = cls.from_edges(energy_edges, name="energy_true", interp="log")

    elif format in ["fgst-template", "fgst-bexpcube"]:
        allowed_names = ["Energy", "ENERGY", "energy"]
        # first column (in table order) matching one of the allowed spellings
        tag = next(
            (colname for colname in table.colnames if colname in allowed_names),
            None,
        )
        if tag is None:
            raise ValueError(
                f"No energy column found for format {format!r}, "
                f"expected one of {allowed_names}"
            )

        nodes = table[tag].data
        axis = cls.from_nodes(
            nodes=nodes, name="energy_true", unit="MeV", interp="log"
        )

    elif format == "gadf":
        axcols_key = "AXCOLS{}".format(idx + 1)
        axcols = table.meta.get(axcols_key)
        if axcols is None:
            raise ValueError(f"Missing {axcols_key!r} keyword in table metadata")

        colnames = axcols.split(",")
        node_type = "edges" if len(colnames) == 2 else "center"

        # TODO: check why this extra case is needed
        if colnames[0] == "E_MIN":
            name = "energy"
        else:
            name = colnames[0].replace("_MIN", "").lower()
            # this is need for backward compatibility
            if name == "theta":
                name = "rad"

        interp = table.meta.get("INTERP{}".format(idx + 1), "lin")

        if node_type == "center":
            nodes = np.unique(table[colnames[0]].quantity)
        else:
            edges_min = np.unique(table[colnames[0]].quantity)
            edges_max = np.unique(table[colnames[1]].quantity)
            nodes = edges_from_lo_hi(edges_min, edges_max)

        axis = MapAxis(nodes=nodes, node_type=node_type, interp=interp, name=name)

    elif format == "gadf-dl3":
        from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

        spec = IRF_DL3_AXES_SPECIFICATION[column_prefix]
        name, interp = spec["name"], spec["interp"]

        # background models are stored in reconstructed energy
        hduclass = table.meta.get("HDUCLAS2")
        if hduclass in {"BKG", "RAD_MAX"} and column_prefix == "ENERG":
            name = "energy"

        # the axis array is stored in the first table row
        edges_lo = table[f"{column_prefix}_LO"].quantity[0]
        edges_hi = table[f"{column_prefix}_HI"].quantity[0]

        # identical LO / HI columns encode a node based axis
        if np.allclose(edges_hi, edges_lo):
            axis = MapAxis.from_nodes(edges_hi, interp=interp, name=name)
        else:
            edges = edges_from_lo_hi(edges_lo, edges_hi)
            axis = MapAxis.from_edges(edges, interp=interp, name=name)
    elif format == "gtpsf":
        try:
            energy = table["Energy"].data * u.MeV
            axis = MapAxis.from_nodes(energy, name="energy_true", interp="log")
        except KeyError:
            rad = table["Theta"].data * u.deg
            axis = MapAxis.from_nodes(rad, name="rad")
    elif format == "gadf-sed-energy":
        if "e_min" in table.colnames and "e_max" in table.colnames:
            e_min = flat_if_equal(table["e_min"].quantity)
            e_max = flat_if_equal(table["e_max"].quantity)
            edges = edges_from_lo_hi(e_min, e_max)
            axis = MapAxis.from_energy_edges(edges)
        elif "e_ref" in table.colnames:
            e_ref = flat_if_equal(table["e_ref"].quantity)
            axis = MapAxis.from_nodes(e_ref, name="energy", interp="log")
        else:
            raise ValueError(
                "Either 'e_ref', 'e_min' or 'e_max' column names are required"
            )
    elif format == "gadf-sed-norm":
        # TODO: guess interp here
        nodes = flat_if_equal(table["norm_scan"][0])
        axis = MapAxis.from_nodes(nodes, name="norm")
    elif format == "gadf-sed-counts":
        if "datasets" in table.colnames:
            labels = np.unique(table["datasets"])
            axis = LabelMapAxis(labels=labels, name="dataset")
        else:
            shape = table["counts"].shape
            edges = np.arange(shape[-1] + 1) - 0.5
            axis = MapAxis.from_edges(edges, name="dataset")
    elif format == "profile":
        if "datasets" in table.colnames:
            labels = np.unique(table["datasets"])
            axis = LabelMapAxis(labels=labels, name="dataset")
        else:
            x_ref = table["x_ref"].quantity
            axis = MapAxis.from_nodes(x_ref, name="projected-distance")
    else:
        raise ValueError(f"Format '{format}' not supported")

    return axis
1354 | |||
@classmethod
def from_table_hdu(cls, hdu, format="ogip", idx=0):
    """Instantiate MapAxis from a FITS table HDU.

    The HDU is read into a `~astropy.table.Table` and passed on to
    ``from_table``.

    Parameters
    ----------
    hdu : `~astropy.io.fits.BinTableHDU`
        Table HDU
    format : {"ogip", "ogip-arf", "fgst-ccube", "fgst-template"}
        Format specification
    idx : int
        Column index of the axis.

    Returns
    -------
    axis : `MapAxis`
        Map Axis
    """
    return cls.from_table(Table.read(hdu), format=format, idx=idx)
1375 | |||
1376 | |||
1377 | class MapAxes(Sequence): |
||
1378 | """MapAxis container class. |
||
1379 | |||
1380 | Parameters |
||
1381 | ---------- |
||
1382 | axes : list of `MapAxis` |
||
1383 | List of map axis objects. |
||
1384 | """ |
||
1385 | |||
def __init__(self, axes, n_spatial_axes=None):
    """Create a container of uniquely named map axes.

    Parameters
    ----------
    axes : iterable of `MapAxis`
        Map axis objects. Any iterable is accepted. It is materialized into a
        list, so one-shot iterators such as generators work too.
    n_spatial_axes : int, optional
        Number of spatial dimensions. When set, ``iter_with_reshape`` appends
        that many length-1 entries to each reshape tuple.

    Raises
    ------
    ValueError
        If two axes share the same name.
    """
    # materialize first: the uniqueness check below would otherwise exhaust
    # a generator before it is stored
    axes = list(axes)

    seen_names = set()
    for ax in axes:
        if ax.name in seen_names:
            raise ValueError(f"Axis names must be unique, got: '{ax.name}' twice.")
        seen_names.add(ax.name)

    self._axes = axes
    self._n_spatial_axes = n_spatial_axes
1398 | |||
@property
def primary_axis(self):
    """Primary extra axis, i.e. the axis with the most bins.

    Ties resolve to the first such axis, as with ``np.argmax``.

    Returns
    -------
    axis : `MapAxis`
        Map axis
    """
    longest = np.argmax(self.shape)
    # convert the NumPy integer returned by argmax to a plain int for indexing
    return self[int(longest)]
1411 | |||
@property
def is_flat(self):
    """Whether every axis has exactly one bin (True for no axes)."""
    return np.all(np.asarray(self.shape) == 1)
1417 | |||
@property
def is_unidimensional(self):
    """Whether at most one axis has more than one bin."""
    n_extended = np.count_nonzero(np.array(self.shape) > 1)
    return self.is_flat or n_extended == 1
1424 | |||
@property
def reverse(self):
    """Axes in reversed order, as a new container of the same class.

    Returns
    -------
    axes : `MapAxes`
        New axes container holding the same axis objects in reverse order.
    """
    # build from a plain list instead of wrapping a sliced MapAxes inside a
    # new MapAxes; use self.__class__ like the other methods of this class
    return self.__class__(list(self)[::-1])
1429 | |||
@property
def iter_with_reshape(self):
    """Iterate over ``(shape, axis)`` pairs for broadcasting axis values.

    ``shape`` has ``-1`` at the axis position and ``1`` everywhere else. When
    ``n_spatial_axes`` is set, the order is reversed and that many trailing
    ``1`` entries are appended.
    """
    n_axes = len(self)
    for position, axis in enumerate(self):
        shape = [1] * n_axes
        shape[position] = -1
        if self._n_spatial_axes:
            shape = list(reversed(shape)) + [1] * self._n_spatial_axes
        yield tuple(shape), axis
1446 | |||
def get_coord(self, mode="center", axis_name=None):
    """Get axes coordinates, reshaped for broadcasting.

    Parameters
    ----------
    mode : {"center", "edges"}
        Coordinate center or edges
    axis_name : str
        Name of the only axis for which ``mode="edges"`` applies. All other
        axes always use their centers.

    Returns
    -------
    coords : dict of `~astropy.units.Quantity`
        Map coordinates keyed by axis name.
    """
    coords = {}
    for shape, axis in self.iter_with_reshape:
        wants_edges = mode == "edges" and axis.name == axis_name
        values = axis.edges if wants_edges else axis.center
        coords[axis.name] = values.reshape(shape)
    return coords
1472 | |||
def bin_volume(self):
    """Bin axes volume, i.e. the broadcast product of all bin widths.

    Returns
    -------
    bin_volume : `~astropy.units.Quantity`
        Bin volume
    """
    widths = [axis.bin_width.reshape(shape) for shape, axis in self.iter_with_reshape]

    volume = np.array(1)
    for width in widths:
        volume = volume * width
    return volume
1487 | |||
@property
def shape(self):
    """Number of bins of each axis, as a tuple in axes order."""
    return tuple(axis.nbin for axis in self)
1492 | |||
@property
def names(self):
    """Names of the axes, as a list in axes order."""
    axis_names = []
    for axis in self:
        axis_names.append(axis.name)
    return axis_names
1497 | |||
def index(self, axis_name):
    """Position of the named axis in the container (raises ValueError if absent)."""
    axis_names = self.names
    return axis_names.index(axis_name)
1501 | |||
def index_data(self, axis_name):
    """Get the data-array index of the named axis.

    Data arrays store the axes in reverse order, so the first axis maps to
    the last data dimension among these axes.

    Parameters
    ----------
    axis_name : str
        Name of the axis.

    Returns
    -------
    idx : int
        Data index
    """
    position = self.names.index(axis_name)
    return (len(self) - 1) - position
1517 | |||
def __len__(self):
    """Number of axes in the container."""
    n_axes = len(self._axes)
    return n_axes
1520 | |||
def __add__(self, other):
    """Concatenate two axes containers into a new one (self's axes first)."""
    combined = [*self, *other]
    return self.__class__(combined)
1523 | |||
def upsample(self, factor, axis_name):
    """Upsample one axis by a given factor.

    Parameters
    ----------
    factor : int
        Upsampling factor.
    axis_name : str
        Axis to upsample.

    Returns
    -------
    axes : `MapAxes`
        New map axes. Every axis is copied; the named one is upsampled first.
    """
    new_axes = [
        (ax.upsample(factor=factor) if ax.name == axis_name else ax).copy()
        for ax in self
    ]
    return self.__class__(axes=new_axes)
1548 | |||
def replace(self, axis):
    """Replace the axis that has the same name as ``axis``.

    Parameters
    ----------
    axis : `MapAxis`
        Map axis to insert in place of the equally named one.

    Returns
    -------
    axes : `MapAxes`
        New map axes. Other axes are kept as-is (not copied).
    """
    swapped = [axis if ax.name == axis.name else ax for ax in self]
    return self.__class__(axes=swapped)
1571 | |||
def resample(self, axis):
    """Resample axis binning by grouping existing bins.

    The edges of ``axis`` are snapped to the existing bin edges of the
    equally named axis. Only the resulting "normal " groups are kept (no
    underflow/overflow).

    Parameters
    ----------
    axis : `MapAxis`
        New map axis.

    Returns
    -------
    axes : `MapAxes`
        Axes object with resampled axis; all other axes are copied.
    """
    current = self[axis.name]
    group_table = current.group_table(axis.edges)
    normal_groups = group_table[group_table["bin_type"] == "normal "]

    new_edges = edges_from_lo_hi(
        normal_groups[axis.name + "_min"].quantity,
        normal_groups[axis.name + "_max"].quantity,
    )
    resampled = MapAxis.from_edges(edges=new_edges, interp=axis.interp, name=axis.name)

    new_axes = [resampled if ax.name == axis.name else ax.copy() for ax in self]
    return self.__class__(axes=new_axes)
1610 | |||
def downsample(self, factor, axis_name):
    """Downsample one axis by a given factor.

    Parameters
    ----------
    factor : int
        Downsampling factor.
    axis_name : str
        Axis to downsample.

    Returns
    -------
    axes : `MapAxes`
        New map axes. Every axis is copied; the named one is downsampled first.
    """
    new_axes = [
        (ax.downsample(factor=factor) if ax.name == axis_name else ax).copy()
        for ax in self
    ]
    return self.__class__(axes=new_axes)
1636 | |||
def squash(self, axis_name):
    """Squash one axis.

    Parameters
    ----------
    axis_name : str
        Axis to squash.

    Returns
    -------
    axes : `MapAxes`
        New map axes. Every axis is copied; the named one is squashed first.
    """
    new_axes = [
        (ax.squash() if ax.name == axis_name else ax).copy() for ax in self
    ]
    return self.__class__(axes=new_axes)
1658 | |||
def pad(self, axis_name, pad_width):
    """Pad one axis.

    Parameters
    ----------
    axis_name : str
        Name of the axis to pad.
    pad_width : int or tuple of int
        Pad width

    Returns
    -------
    axes : `MapAxes`
        New map axes with the named axis padded. Other axes are kept as-is
        (not copied).
    """
    new_axes = [
        ax.pad(pad_width=pad_width) if ax.name == axis_name else ax for ax in self
    ]
    return self.__class__(axes=new_axes)
1683 | |||
def drop(self, axis_name):
    """Drop an axis.

    Parameters
    ----------
    axis_name : str
        Name of the axis to remove.

    Returns
    -------
    axes : `MapAxes`
        New map axes with copies of all remaining axes. An unknown name
        removes nothing.
    """
    remaining = [ax.copy() for ax in self if ax.name != axis_name]
    return self.__class__(axes=remaining)
1704 | |||
def __getitem__(self, idx):
    """Select axes by position, name, slice or list.

    Parameters
    ----------
    idx : int, `numpy.integer`, str, slice or list
        Integer position (Python or NumPy integer), axis name, slice of
        positions, or list of names/positions.

    Returns
    -------
    axis : `MapAxis` or `MapAxes`
        A single axis for an integer or a name. A new axes container for a
        slice or a list.

    Raises
    ------
    KeyError
        If no axis with the given name exists.
    TypeError
        If ``idx`` has an unsupported type.
    """
    # accept NumPy integers as well (e.g. the result of np.argmax)
    if isinstance(idx, (int, np.integer)):
        return self._axes[idx]
    elif isinstance(idx, str):
        for ax in self._axes:
            if ax.name == idx:
                return ax
        raise KeyError(f"No axes: {idx!r}")
    elif isinstance(idx, slice):
        return self.__class__(axes=self._axes[idx])
    elif isinstance(idx, list):
        return self.__class__(axes=[self[name] for name in idx])
    else:
        raise TypeError(f"Invalid type: {type(idx)!r}")
1724 | |||
def coord_to_idx(self, coord, clip=True):
    """Transform from axis coordinates to pixel indices.

    Parameters
    ----------
    coord : dict of `~numpy.ndarray` or `MapCoord`
        Array of axis coordinate values, looked up by axis name.
    clip : bool
        Forwarded to each axis' ``coord_to_idx``.

    Returns
    -------
    pix : tuple of `~numpy.ndarray`
        Array of pixel indices values, one entry per axis.
    """
    return tuple(axis.coord_to_idx(coord[axis.name], clip=clip) for axis in self)
1739 | |||
def coord_to_pix(self, coord):
    """Transform from axis coordinates to pixel coordinates.

    Parameters
    ----------
    coord : dict of `~numpy.ndarray`
        Array of axis coordinate values, looked up by axis name.

    Returns
    -------
    pix : tuple of `~numpy.ndarray`
        Array of pixel coordinate values, one entry per axis.
    """
    return tuple(axis.coord_to_pix(coord[axis.name]) for axis in self)
1754 | |||
def pix_to_coord(self, pix):
    """Convert pixel coordinates to map coordinates.

    Parameters
    ----------
    pix : tuple
        Tuple of pixel coordinates, one entry per axis in axes order.

    Returns
    -------
    coords : tuple
        Tuple of map coordinates.
    """
    return tuple(axis.pix_to_coord(axis_pix) for axis, axis_pix in zip(self, pix))
1769 | |||
def pix_to_idx(self, pix, clip=False):
    """Convert pixel coordinates to pixel indices.

    Parameters
    ----------
    pix : tuple of `~numpy.ndarray`
        Pixel coordinates, one entry per axis in axes order.
    clip : bool
        Choose whether to clip indices to the valid range of the axis.
        Forwarded to each axis' ``pix_to_idx``.

    Returns
    -------
    idx : tuple of `~numpy.ndarray`
        Pixel indices.
    """
    return tuple(axis.pix_to_idx(axis_pix, clip=clip) for axis_pix, axis in zip(pix, self))
1793 | |||
def slice_by_idx(self, slices):
    """Create new axes by slicing the non-spatial axes.

    Parameters
    ----------
    slices : dict
        Dict of axes names and integers or `slice` object pairs. Axes given
        any non-slice value (e.g. an integer) are dropped. Axes not present
        in the dict are kept unchanged.

    Returns
    -------
    axes : `MapAxes`
        Sliced axes (copies).
    """
    kept = []
    for axis in self:
        selection = slices.get(axis.name, slice(None))
        if not isinstance(selection, slice):
            # integer (non-slice) selection removes the axis
            continue
        kept.append(axis.slice(selection).copy())
    return self.__class__(axes=kept)
1820 | |||
def to_header(self, format="gadf"):
    """Convert axes to a FITS header.

    The header of each axis is generated with a 1-based axis index and
    merged in axes order.

    Parameters
    ----------
    format : {"gadf"}
        Header format

    Returns
    -------
    header : `~astropy.io.fits.Header`
        FITS header.
    """
    header = fits.Header()
    axis_headers = [
        axis.to_header(format=format, idx=position + 1)
        for position, axis in enumerate(self)
    ]
    for axis_header in axis_headers:
        header.update(axis_header)
    return header
1841 | |||
def to_table(self, format="gadf"):
    """Convert axes to table

    Parameters
    ----------
    format : {"gadf", "gadf-dl3", "fgst-ccube", "fgst-template", "ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa"} # noqa E501
        Format to use.

    Returns
    -------
    table : `~astropy.table.Table`
        Table with axis data

    Raises
    ------
    ValueError
        If the format is not supported.
    """
    if format == "gadf-dl3":
        table = hstack([ax.to_table(format=format) for ax in self])
    elif format in ["gadf", "fgst-ccube", "fgst-template"]:
        table = Table()
        table["CHANNEL"] = np.arange(np.prod(self.shape))

        axes_ctr = np.meshgrid(*[ax.center for ax in self])
        axes_min = np.meshgrid(*[ax.edges_min for ax in self])
        axes_max = np.meshgrid(*[ax.edges_max for ax in self])

        for idx, ax in enumerate(self):
            name = ax.name.upper()

            if name == "ENERGY":
                colnames = ["ENERGY", "E_MIN", "E_MAX"]
            else:
                colnames = [name, name + "_MIN", name + "_MAX"]

            for colname, v in zip(colnames, [axes_ctr, axes_min, axes_max]):
                # do not store edges for label axis
                if ax.node_type == "label" and colname != name:
                    continue

                table[colname] = np.ravel(v[idx])

            if isinstance(ax, TimeMapAxis):
                ref_dict = time_ref_to_dict(ax.reference_time)
                table.meta.update(ref_dict)

    elif format in ["ogip", "ogip-sherpa", "ogip-arf", "ogip-arf-sherpa"]:
        # MapAxis.to_table requires "energy" for the OGIP EBOUNDS formats and
        # "energy_true" for the ARF formats
        if format in ["ogip-arf", "ogip-arf-sherpa"]:
            energy_axis = self["energy_true"]
        else:
            energy_axis = self["energy"]
        table = energy_axis.to_table(format=format)
    else:
        raise ValueError(f"Unsupported format: '{format}'")

    return table
1896 | |||
def to_table_hdu(self, format="gadf", hdu_bands=None):
    """Make FITS table columns for map axes.

    Parameters
    ----------
    format : {"gadf", "fgst-ccube", "fgst-template", "ogip", "ogip-sherpa"} or None
        Format to use. ``None`` is treated as "gadf".
    hdu_bands : str
        Name of the bands HDU to use. Only honoured for the "gadf" format
        (default "BANDS"); the other formats use a fixed HDU name.

    Returns
    -------
    hdu : `~astropy.io.fits.BinTableHDU`
        Bin table HDU.

    Raises
    ------
    ValueError
        If ``format`` is not supported.
    """
    # FIXME: Check whether convention is compatible with
    # dimensionality of geometry and simplify!!!

    if format is None:
        # ``None`` was already accepted as an alias of "gadf" when choosing
        # the HDU name; forward the same alias to ``to_table``/``to_header``,
        # which would otherwise reject ``None``.
        format = "gadf"

    if format in ["fgst-ccube", "ogip", "ogip-sherpa"]:
        hdu_bands = "EBOUNDS"
    elif format == "fgst-template":
        hdu_bands = "ENERGIES"
    elif format == "gadf":
        if hdu_bands is None:
            hdu_bands = "BANDS"
    else:
        raise ValueError(f"Unknown format {format}")

    table = self.to_table(format=format)
    header = self.to_header(format=format)
    return fits.BinTableHDU(table, name=hdu_bands, header=header)
1928 | |||
@classmethod
def from_table_hdu(cls, hdu, format="gadf"):
    """Create MapAxes from a BinTableHDU.

    Parameters
    ----------
    hdu : `~astropy.io.fits.BinTableHDU` or None
        Bin table HDU. ``None`` yields an empty `MapAxes`.
    format : str
        Format forwarded to `MapAxes.from_table`.

    Returns
    -------
    axes : `MapAxes`
        Map axes object
    """
    if hdu is not None:
        return cls.from_table(Table.read(hdu), format=format)

    return cls([])
1949 | |||
@classmethod
def from_table(cls, table, format="gadf"):
    """Create MapAxes from table

    Parameters
    ----------
    table : `~astropy.table.Table`
        Bin table HDU
    format : {"gadf", "gadf-dl3", "gadf-sed", "fgst-ccube", "fgst-template",
              "fgst-bexpcube", "ogip", "ogip-arf", "lightcurve", "profile"}
        Format to use.

    Returns
    -------
    axes : `MapAxes`
        Map axes object

    Raises
    ------
    ValueError
        If ``format`` is not supported.
    """
    axes = []

    # Formats that support only one energy axis
    if format in [
        "fgst-ccube",
        "fgst-template",
        "fgst-bexpcube",
        "ogip",
        "ogip-arf",
    ]:
        axes.append(MapAxis.from_table(table, format=format))
    elif format == "gadf":
        # This limits the max number of axes to 5
        for idx in range(5):
            axcols = table.meta.get(f"AXCOLS{idx + 1}")
            if axcols is None:
                break

            # TODO: what is good way to check whether it is a given axis type?
            try:
                axis = LabelMapAxis.from_table(table, format=format, idx=idx)
            except (KeyError, TypeError):
                try:
                    axis = TimeMapAxis.from_table(table, format=format, idx=idx)
                except (KeyError, ValueError):
                    axis = MapAxis.from_table(table, format=format, idx=idx)

            axes.append(axis)
    elif format == "gadf-dl3":
        # Only this branch needs the DL3 axes specification, so import it
        # here rather than on every call. It is kept as a local import, as
        # before (presumably to avoid a circular import -- confirm).
        from gammapy.irf.io import IRF_DL3_AXES_SPECIFICATION

        for column_prefix in IRF_DL3_AXES_SPECIFICATION:
            try:
                axis = MapAxis.from_table(
                    table, format=format, column_prefix=column_prefix
                )
            except KeyError:
                continue

            axes.append(axis)
    elif format == "gadf-sed":
        for axis_format in ["gadf-sed-norm", "gadf-sed-energy", "gadf-sed-counts"]:
            try:
                axis = MapAxis.from_table(table=table, format=axis_format)
            except KeyError:
                continue
            axes.append(axis)
    elif format == "lightcurve":
        axes.extend(cls.from_table(table=table, format="gadf-sed"))
        axes.append(TimeMapAxis.from_table(table, format="lightcurve"))
    elif format == "profile":
        axes.extend(cls.from_table(table=table, format="gadf-sed"))
        axes.append(MapAxis.from_table(table, format="profile"))
    else:
        raise ValueError(f"Unsupported format: '{format}'")

    return cls(axes)
2023 | |||
@classmethod
def from_default(cls, axes, n_spatial_axes=None):
    """Make a sequence of `~MapAxis` objects.

    Plain `~numpy.ndarray` entries are wrapped in ``MapAxis``, and axes with
    an empty name are renamed ``axis<position>`` (in place).

    Parameters
    ----------
    axes : sequence of `MapAxis` or `~numpy.ndarray`, or None
        Input axes. ``None`` gives an empty instance.
    n_spatial_axes : int, optional
        Forwarded to the constructor (not used when ``axes`` is None).

    Returns
    -------
    axes : `MapAxes`
        Map axes object
    """
    if axes is None:
        return cls([])

    def _prepare(position, axis):
        if isinstance(axis, np.ndarray):
            axis = MapAxis(axis)

        if axis.name == "":
            axis._name = f"axis{position}"

        return axis

    prepared = [_prepare(position, axis) for position, axis in enumerate(axes)]
    return cls(prepared, n_spatial_axes=n_spatial_axes)
2041 | |||
def assert_names(self, required_names):
    """Check that the axes carry exactly the required names, in order.

    Parameters
    ----------
    required_names : list of str
        Required axis names, in the expected order.

    Raises
    ------
    ValueError
        If the number of axes differs from ``len(required_names)`` or any
        axis name does not match its required name.
    """
    error_text = (
        "Incorrect axis order or names. Expected axis "
        f"order: {required_names}, got: {self.names}."
    )

    if len(self) != len(required_names):
        raise ValueError(error_text)

    try:
        for axis, expected_name in zip(self, required_names):
            axis.assert_name(expected_name)
    except ValueError:
        # Replace the single-axis error with one listing the full expected order.
        raise ValueError(error_text)
2064 | |||
def rename_axes(self, names, new_names):
    """Rename the axes.

    The axes are copied first, so ``self`` is left untouched.

    Parameters
    ----------
    names : list or str
        Names of the axes
    new_names : list or str
        New names of the axes (list must be of same length than `names`).

    Returns
    -------
    axes : `MapAxes`
        Renamed Map axes object
    """
    if isinstance(names, str):
        names = [names]

    if isinstance(new_names, str):
        new_names = [new_names]

    renamed = self.copy()
    for old_name, replacement in zip(names, new_names):
        renamed[old_name]._name = replacement

    return renamed
2088 | |||
@property
def center_coord(self):
    """Center coordinates.

    Tuple with, for every axis, the coordinate at pixel ``(nbin - 1) / 2``.
    """
    centers = []
    for axis in self:
        middle_pix = (float(axis.nbin) - 1.0) / 2.0
        centers.append(axis.pix_to_coord(middle_pix))
    return tuple(centers)
2093 | |||
def is_allclose(self, other, **kwargs):
    """Check if other map axes are all close.

    Parameters
    ----------
    other : `MapAxes`
        Other map axes
    **kwargs : dict
        Keyword arguments forwarded to `~MapAxis.is_allclose`

    Returns
    -------
    is_allclose : bool
        Whether other axes are all close. Axes sequences of different
        length are never all close.

    Raises
    ------
    TypeError
        If ``other`` is not an instance of the same class.
    """
    if not isinstance(other, self.__class__):
        # Previously the exception was *returned*; an exception instance is
        # truthy, so a type mismatch silently looked like "all close".
        raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

    # zip() stops at the shorter sequence, so extra axes would otherwise be
    # ignored and the comparison could wrongly succeed.
    if len(self) != len(other):
        return False

    return np.all([ax0.is_allclose(ax1, **kwargs) for ax0, ax1 in zip(other, self)])
2113 | |||
def __eq__(self, other):
    """Equal when ``other`` is the same class and all axes are close (rtol=atol=1e-6)."""
    return isinstance(other, self.__class__) and self.is_allclose(
        other, rtol=1e-6, atol=1e-6
    )
2119 | |||
def __ne__(self, other):
    """Logical negation of ``__eq__``."""
    equal = self.__eq__(other)
    return not equal
2122 | |||
def copy(self):
    """Return a new instance of the same class built from a copy of each axis."""
    copied_axes = [axis.copy() for axis in self]
    return self.__class__(copied_axes)
2126 | |||
2127 | |||
class TimeMapAxis:
    """Class representing a time axis.

    Provides methods for transforming to/from axis and pixel coordinates.
    A time axis can represent non-contiguous sequences of non-overlapping time intervals.

    Time intervals must be provided in increasing order.

    Parameters
    ----------
    edges_min : `~astropy.units.Quantity`
        Array of lower edge time values. This is the time delta w.r.t. the reference time.
    edges_max : `~astropy.units.Quantity`
        Array of upper edge time values. This is the time delta w.r.t. the reference time.
    reference_time : `~astropy.time.Time`
        Reference time to use.
    name : str
        Axis name
    interp : str
        Interpolation method used to transform between axis and pixel
        coordinates. For now only 'lin' is supported.

    Raises
    ------
    ValueError
        If the edges lack a time unit, differ in shape, are unsorted, define
        empty/negative intervals or overlapping intervals.
    NotImplementedError
        If ``interp`` is not "lin".
    """

    node_type = "intervals"
    time_format = "iso"

    def __init__(self, edges_min, edges_max, reference_time, name="time", interp="lin"):
        self._name = name

        edges_min = u.Quantity(edges_min, ndmin=1)
        edges_max = u.Quantity(edges_max, ndmin=1)

        if not edges_min.unit.is_equivalent("s"):
            raise ValueError(
                f"Time edges min must have a valid time unit, got {edges_min.unit}"
            )

        if not edges_max.unit.is_equivalent("s"):
            raise ValueError(
                f"Time edges max must have a valid time unit, got {edges_max.unit}"
            )

        if not edges_min.shape == edges_max.shape:
            raise ValueError(
                "Edges min and edges max must have the same shape,"
                f" got {edges_min.shape} and {edges_max.shape}."
            )

        if not np.all(edges_max > edges_min):
            raise ValueError("Edges max must all be larger than edge min")

        if not np.all(edges_min == np.sort(edges_min)):
            raise ValueError("Time edges min values must be sorted")

        if not np.all(edges_max == np.sort(edges_max)):
            raise ValueError("Time edges max values must be sorted")

        if interp != "lin":
            raise NotImplementedError(
                f"Non-linear scaling scheme are not supported yet, got {interp}"
            )

        self._edges_min = edges_min
        self._edges_max = edges_max
        self._reference_time = Time(reference_time)
        self._pix_offset = -0.5
        self._interp = interp

        # Gap between the start of each interval and the end of the previous
        # one; a negative gap means the two intervals overlap.
        delta = edges_min[1:] - edges_max[:-1]
        if np.any(delta < 0 * u.s):
            raise ValueError("Time intervals must not overlap.")

    @property
    def is_contiguous(self):
        """Whether the axis is contiguous (each interval starts where the previous one ends)."""
        return np.all(self.edges_min[1:] == self.edges_max[:-1])

    def to_contiguous(self):
        """Make the time axis contiguous

        The sorted, de-duplicated union of all lower and upper edges is used,
        so gaps between intervals become bins of their own.

        Returns
        -------
        axis : `TimeMapAxis`
            Contiguous time axis
        """
        edges = np.unique(np.stack([self.edges_min, self.edges_max]))
        return self.__class__(
            edges_min=edges[:-1],
            edges_max=edges[1:],
            reference_time=self.reference_time,
            name=self.name,
            interp=self.interp,
        )

    @property
    def unit(self):
        """Axes unit"""
        return self.edges_max.unit

    @property
    def interp(self):
        """Interp"""
        return self._interp

    @property
    def reference_time(self):
        """Return reference time used for the axis."""
        return self._reference_time

    @property
    def name(self):
        """Return axis name."""
        return self._name

    @property
    def nbin(self):
        """Return number of bins in the axis."""
        return len(self.edges_min.flatten())

    @property
    def edges_min(self):
        """Return array of bin edges min values."""
        return self._edges_min

    @property
    def edges_max(self):
        """Return array of bin edges max values."""
        return self._edges_max

    @property
    def edges(self):
        """Return array of bin edges values.

        Raises
        ------
        ValueError
            If the axis is not contiguous.
        """
        if not self.is_contiguous:
            raise ValueError("Time axis is not contiguous")

        return edges_from_lo_hi(self.edges_min, self.edges_max)

    @property
    def bounds(self):
        """Bounds of the axis (~astropy.units.Quantity)"""
        return self.edges_min[0], self.edges_max[-1]

    @property
    def time_bounds(self):
        """Bounds of the axis as absolute times (tuple of `~astropy.time.Time`)."""
        t_min, t_max = self.bounds
        return t_min + self.reference_time, t_max + self.reference_time

    @property
    def time_min(self):
        """Return axis lower edges as Time objects."""
        return self._edges_min + self.reference_time

    @property
    def time_max(self):
        """Return axis upper edges as Time objects."""
        return self._edges_max + self.reference_time

    @property
    def time_delta(self):
        """Return axis time bin width (`~astropy.time.TimeDelta`)."""
        return self.time_max - self.time_min

    @property
    def time_mid(self):
        """Return time bin center (`~astropy.time.Time`)."""
        return self.time_min + 0.5 * self.time_delta

    @property
    def time_edges(self):
        """Time edges (requires a contiguous axis, see `edges`)."""
        return self.reference_time + self.edges

    @property
    def as_plot_xerr(self):
        """Plot x error"""
        xn, xp = self.time_mid - self.time_min, self.time_max - self.time_mid

        if self.time_format == "iso":
            x_errn = xn.to_datetime()
            x_errp = xp.to_datetime()
        elif self.time_format == "mjd":
            x_errn = xn.to("day")
            x_errp = xp.to("day")
        else:
            raise ValueError(f"Invalid time_format: {self.time_format}")

        return x_errn, x_errp

    @property
    def as_plot_labels(self):
        """Plot labels"""
        labels = []

        for t_min, t_max in self.iter_by_edges:
            label = f"{getattr(t_min, self.time_format)} - {getattr(t_max, self.time_format)}"
            labels.append(label)

        return labels

    @property
    def as_plot_edges(self):
        """Plot edges"""
        if self.time_format == "iso":
            edges = self.time_edges.to_datetime()
        elif self.time_format == "mjd":
            edges = self.time_edges.mjd * u.day
        else:
            raise ValueError(f"Invalid time_format: {self.time_format}")

        return edges

    @property
    def as_plot_center(self):
        """Plot center

        Raises
        ------
        ValueError
            If ``time_format`` is neither "iso" nor "mjd".
        """
        if self.time_format == "iso":
            center = self.time_mid.datetime
        elif self.time_format == "mjd":
            center = self.time_mid.mjd * u.day
        else:
            # Previously missing: an unknown format raised UnboundLocalError.
            raise ValueError(f"Invalid time_format: {self.time_format}")

        return center

    def format_plot_xaxis(self, ax):
        """Format plot axis

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            Formatted plot axis
        """
        from matplotlib.dates import DateFormatter

        xlabel = DEFAULT_LABEL_TEMPLATE.format(
            quantity=PLOT_AXIS_LABEL.get(self.name, self.name.capitalize()),
            unit=self.time_format,
        )
        ax.set_xlabel(xlabel)

        if self.time_format == "iso":
            ax.xaxis.set_major_formatter(DateFormatter("%Y-%m-%d %H:%M:%S"))
            plt.setp(
                ax.xaxis.get_majorticklabels(),
                rotation=30,
                ha="right",
                rotation_mode="anchor",
            )

        return ax

    def assert_name(self, required_name):
        """Assert axis name if a specific one is required.

        Parameters
        ----------
        required_name : str
            Required

        Raises
        ------
        ValueError
            If the axis name differs from ``required_name``.
        """
        if self.name != required_name:
            raise ValueError(
                "Unexpected axis name,"
                f' expected "{required_name}", got: "{self.name}"'
            )

    def is_allclose(self, other, **kwargs):
        """Check if other map axis is all close.

        Parameters
        ----------
        other : `TimeMapAxis`
            Other map axis
        **kwargs : dict
            Keyword arguments forwarded to `~numpy.allclose`

        Returns
        -------
        is_allclose : bool
            Whether other axis is allclose

        Raises
        ------
        TypeError
            If ``other`` is not an instance of the same class.
        """
        if not isinstance(other, self.__class__):
            # Previously the exception was *returned* (truthy) instead of raised.
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        if self._edges_min.shape != other._edges_min.shape:
            return False

        # This will test equality at microsec level.
        delta_min = self.time_min - other.time_min
        delta_max = self.time_max - other.time_max

        return (
            np.allclose(delta_min.to_value("s"), 0.0, **kwargs)
            and np.allclose(delta_max.to_value("s"), 0.0, **kwargs)
            and self._interp == other._interp
            and self.name.upper() == other.name.upper()
        )

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        return self.is_allclose(other=other, atol=1e-6)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # NOTE(review): identity-based hash while __eq__ is tolerance-based,
        # so two equal axes may hash differently -- confirm this is intended.
        return id(self)

    def is_aligned(self, other, atol=2e-2):
        raise NotImplementedError

    @property
    def iter_by_edges(self):
        """Iterate by intervals defined by the edges"""
        for time_min, time_max in zip(self.time_min, self.time_max):
            yield (time_min, time_max)

    def coord_to_idx(self, coord, **kwargs):
        """Transform from axis time coordinate to bin index.

        Indices of time values falling outside time bins will be
        set to -1. A time equal to an interval's lower edge is not counted
        in that interval, while a time equal to its upper edge is
        (``time_min < t <= time_max``).

        Parameters
        ----------
        coord : `~astropy.time.Time` or `~astropy.units.Quantity`
            Array of axis coordinate values. The quantity is assumed
            to be relative to the reference time.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.
        """
        if isinstance(coord, u.Quantity):
            coord = self.reference_time + coord

        time = Time(coord[..., np.newaxis])
        delta_plus = (time - self.time_min).value > 0.0
        delta_minus = (time - self.time_max).value <= 0.0
        mask = np.logical_and(delta_plus, delta_minus)

        idx = np.asanyarray(np.argmax(mask, axis=-1))
        idx[~np.any(mask, axis=-1)] = INVALID_INDEX.int
        return idx

    def coord_to_pix(self, coord, **kwargs):
        """Transform from time to coordinate to pixel position.

        Pixels of time values falling outside time bins will be
        set to -1.

        Parameters
        ----------
        coord : `~astropy.time.Time`
            Array of axis coordinate values.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel positions.
        """
        if isinstance(coord, u.Quantity):
            coord = self.reference_time + coord

        idx = np.atleast_1d(self.coord_to_idx(coord))

        valid_pix = idx != INVALID_INDEX.int
        pix = np.atleast_1d(idx).astype("float")

        # TODO: is there the equivalent of np.atleast1d for astropy.time.Time?
        if coord.shape == ():
            coord = coord.reshape((1,))

        relative_time = coord[valid_pix] - self.reference_time

        # Fractional position of each valid time inside its interval.
        scale = interpolation_scale(self._interp)
        valid_idx = idx[valid_pix]
        s_min = scale(self._edges_min[valid_idx])
        s_max = scale(self._edges_max[valid_idx])
        s_coord = scale(relative_time.to(self._edges_min.unit))

        pix[valid_pix] += (s_coord - s_min) / (s_max - s_min)
        pix[~valid_pix] = INVALID_INDEX.float
        return pix - 0.5

    @staticmethod
    def pix_to_idx(pix, clip=False):
        return pix

    @property
    def center(self):
        """Return `~astropy.units.Quantity` at interval centers."""
        return self.edges_min + 0.5 * self.bin_width

    @property
    def bin_width(self):
        """Return time interval width (`~astropy.units.Quantity`), in hours."""
        return self.time_delta.to("h")

    def __repr__(self):
        str_ = self.__class__.__name__ + "\n"
        str_ += "-" * len(self.__class__.__name__) + "\n\n"
        fmt = "\t{:<14s} : {:<10s}\n"
        str_ += fmt.format("name", self.name)
        str_ += fmt.format("nbins", str(self.nbin))
        str_ += fmt.format("reference time", self.reference_time.iso)
        str_ += fmt.format("scale", self.reference_time.scale)
        str_ += fmt.format("time min.", self.time_min.min().iso)
        str_ += fmt.format("time max.", self.time_max.max().iso)
        str_ += fmt.format("total time", np.sum(self.bin_width))
        return str_.expandtabs(tabsize=2)

    def upsample(self):
        raise NotImplementedError

    def downsample(self):
        raise NotImplementedError

    def _init_copy(self, **kwargs):
        """Init map axis instance by copying missing init arguments from self."""
        argnames = inspect.getfullargspec(self.__init__).args
        argnames.remove("self")

        # Each constructor argument is stored on the instance as "_<arg>".
        for arg in argnames:
            value = getattr(self, "_" + arg)
            kwargs.setdefault(arg, copy.deepcopy(value))

        return self.__class__(**kwargs)

    def copy(self, **kwargs):
        """Copy `TimeMapAxis` instance and overwrite given attributes.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments to overwrite in the map axis constructor.

        Returns
        -------
        copy : `TimeMapAxis`
            Copied map axis.
        """
        return self._init_copy(**kwargs)

    def slice(self, idx):
        """Create a new axis object by extracting a slice from this axis.

        Parameters
        ----------
        idx : slice
            Slice object selecting a subselection of the axis.

        Returns
        -------
        axis : `~TimeMapAxis`
            Sliced axis object.
        """
        return self.__class__(
            self._edges_min[idx].copy(),
            self._edges_max[idx].copy(),
            self.reference_time,
            interp=self._interp,
            name=self.name,
        )

    def squash(self):
        """Create a new axis object by squashing the axis into one bin.

        The single bin spans from the first lower edge to the last upper edge.

        Returns
        -------
        axis : `~TimeMapAxis`
            Squashed axis object.
        """
        return self.__class__(
            self._edges_min[0],
            self._edges_max[-1],
            self.reference_time,
            interp=self._interp,
            name=self._name,
        )

    # TODO: if we are to allow log or sqrt bins the reference time should always
    # be strictly lower than all times
    # Should we define a mechanism to ensure this is always correct?
    @classmethod
    def from_time_edges(cls, time_min, time_max, unit="d", interp="lin", name="time"):
        """Create TimeMapAxis from the time interval edges defined as `~astropy.time.Time`.

        The reference time is defined as the lower edge of the first interval.

        Parameters
        ----------
        time_min : `~astropy.time.Time`
            Array of lower edge times.
        time_max : `~astropy.time.Time`
            Array of upper edge times.
        unit : `~astropy.units.Unit` or str
            The unit to convert the edges to. Default is 'd' (day).
        interp : str
            Interpolation method used to transform between axis and pixel
            coordinates. Currently only 'lin' is accepted by the constructor.
        name : str
            Axis name

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.
        """
        unit = u.Unit(unit)
        reference_time = time_min[0]
        edges_min = time_min - reference_time
        edges_max = time_max - reference_time

        return cls(
            edges_min.to(unit),
            edges_max.to(unit),
            reference_time,
            interp=interp,
            name=name,
        )

    # TODO: how configurable should that be? column names?
    @classmethod
    def from_table(cls, table, format="gadf", idx=0):
        """Create time map axis from table

        Parameters
        ----------
        table : `~astropy.table.Table`
            Bin table HDU
        format : {"gadf", "fermi-fgl", "lightcurve"}
            Format to use.
        idx : int
            Zero-based axis index; only used for "gadf", where the column
            names are read from the ``AXCOLS{idx + 1}`` keyword.

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.

        Raises
        ------
        ValueError
            If ``format`` is not supported.
        """
        if format == "gadf":
            axcols = table.meta.get("AXCOLS{}".format(idx + 1))
            colnames = axcols.split(",")
            name = colnames[0].replace("_MIN", "").lower()
            reference_time = time_ref_from_dict(table.meta)
            edges_min = np.unique(table[colnames[0]].quantity)
            edges_max = np.unique(table[colnames[1]].quantity)
        elif format == "fermi-fgl":
            meta = table.meta.copy()
            meta["MJDREFF"] = str(meta["MJDREFF"]).replace("D-4", "e-4")
            reference_time = time_ref_from_dict(meta=meta)
            name = "time"
            edges_min = table["Hist_Start"][:-1]
            edges_max = table["Hist_Start"][1:]
        elif format == "lightcurve":
            # TODO: is this a good format? It just supports mjd...
            name = "time"
            scale = table.meta.get("TIMESYS", "utc")
            time_min = Time(table["time_min"].data, format="mjd", scale=scale)
            time_max = Time(table["time_max"].data, format="mjd", scale=scale)
            reference_time = Time("2001-01-01T00:00:00")
            reference_time.format = "mjd"
            edges_min = (time_min - reference_time).to("s")
            edges_max = (time_max - reference_time).to("s")
        else:
            raise ValueError(f"Not a supported format: {format}")

        return cls(
            edges_min=edges_min,
            edges_max=edges_max,
            reference_time=reference_time,
            name=name,
        )

    @classmethod
    def from_gti(cls, gti, name="time"):
        """Create a time axis from an input GTI.

        Parameters
        ----------
        gti : `GTI`
            GTI table
        name : str
            Axis name

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.

        """
        tmin = gti.time_start - gti.time_ref
        tmax = gti.time_stop - gti.time_ref

        return cls(
            edges_min=tmin.to("s"),
            edges_max=tmax.to("s"),
            reference_time=gti.time_ref,
            name=name,
        )

    @classmethod
    def from_time_bounds(cls, time_min, time_max, nbin, unit="d", name="time"):
        """Create linearly spaced time axis from bounds

        Parameters
        ----------
        time_min : `~astropy.time.Time`
            Lower bound
        time_max : `~astropy.time.Time`
            Upper bound
        nbin : int
            Number of bins
        unit : `~astropy.units.Unit` or str
            Unit the edges are converted to. Default is 'd' (day).
        name : str
            Name of the axis.

        Returns
        -------
        axis : `TimeMapAxis`
            Time map axis.
        """
        delta = time_max - time_min
        time_edges = time_min + delta * np.linspace(0, 1, nbin + 1)
        return cls.from_time_edges(
            time_min=time_edges[:-1],
            time_max=time_edges[1:],
            interp="lin",
            unit=unit,
            name=name,
        )

    def to_header(self, format="gadf", idx=0):
        """Create FITS header

        Parameters
        ----------
        format : {"gadf"}
            Format specification
        idx : int
            Column index of the axis.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.

        Raises
        ------
        ValueError
            If ``format`` is not "gadf".
        """
        header = fits.Header()

        if format == "gadf":
            key = f"AXCOLS{idx}"
            name = self.name.upper()
            header[key] = f"{name}_MIN,{name}_MAX"
            key_interp = f"INTERP{idx}"
            header[key_interp] = self.interp

            ref_dict = time_ref_to_dict(self.reference_time)
            header.update(ref_dict)
        else:
            raise ValueError(f"Unknown format {format}")

        return header
2788 | |||
2789 | |||
class LabelMapAxis:
    """Map axis using labels.

    Parameters
    ----------
    labels : list of str
        Labels to be used for the axis nodes. The labels must be unique;
        their order is preserved and defines the order of the axis bins.
        A single scalar label creates an axis with one bin.
    name : str
        Name of the axis.

    """

    node_type = "label"

    def __init__(self, labels, name=""):
        # ``ndmin=1`` promotes a scalar label (e.g. the result of slicing
        # with an integer index) to a one-element array, so that ``len``
        # counts labels rather than the characters of a single string.
        # ``np.array`` also copies, so later changes to the caller's array
        # do not leak into the axis.
        labels = np.array(labels, ndmin=1)
        unique_labels = np.unique(labels)

        if not len(unique_labels) == len(labels):
            raise ValueError("Node labels must be unique")

        # Keep the labels in the order given: ``np.unique`` returns them
        # sorted, which would silently reorder the bins relative to the
        # data the caller associates with them.
        self._labels = labels
        self._name = name

    @property
    def unit(self):
        """Unit (a label axis is dimensionless)."""
        return u.Unit("")

    @property
    def name(self):
        """Name of the axis."""
        return self._name

    def assert_name(self, required_name):
        """Assert axis name if a specific one is required.

        Parameters
        ----------
        required_name : str
            Required name of the axis.

        Raises
        ------
        ValueError
            If the axis name differs from ``required_name``.
        """
        if self.name != required_name:
            raise ValueError(
                "Unexpected axis name,"
                f' expected "{required_name}", got: "{self.name}"'
            )

    @property
    def nbin(self):
        """Number of bins."""
        return len(self._labels)

    def pix_to_coord(self, pix):
        """Transform from pixel to axis coordinates.

        Pixel values are rounded to the nearest integer bin index.

        Parameters
        ----------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.

        Returns
        -------
        coord : `~numpy.ndarray`
            Array of axis labels.
        """
        idx = np.round(pix).astype(int)
        return self._labels[idx]

    def coord_to_idx(self, coord, **kwargs):
        """Transform labels to indices.

        If a label is not present an error is raised.

        Parameters
        ----------
        coord : str or array-like of str
            Label or array of axis labels.
        **kwargs : dict
            Ignored; accepted for interface compatibility with other axes.

        Returns
        -------
        idx : `~numpy.ndarray`
            Array of bin indices.

        Raises
        ------
        ValueError
            If any of the labels is not part of the axis.
        """
        coord = np.asarray(coord)
        # Compare every coordinate against every label (last axis = labels).
        is_equal = coord[..., np.newaxis] == self._labels
        is_valid = np.any(is_equal, axis=-1)

        if not np.all(is_valid):
            label = coord[~is_valid]
            raise ValueError(f"Not a valid label: {label}")

        return np.argmax(is_equal, axis=-1)

    def coord_to_pix(self, coord):
        """Transform from axis labels to pixel coordinates.

        Parameters
        ----------
        coord : `~numpy.ndarray`
            Array of axis label values.

        Returns
        -------
        pix : `~numpy.ndarray`
            Array of pixel coordinate values.
        """
        return self.coord_to_idx(coord).astype("float")

    def pix_to_idx(self, pix, clip=False):
        """Convert pix to idx.

        Parameters
        ----------
        pix : `~numpy.ndarray` or array-like
            Pixel coordinates.
        clip : bool
            Choose whether to clip indices to the valid range of the
            axis. If false then indices for coordinates outside
            the axis range will be set to -1.

        Returns
        -------
        idx : `~numpy.ndarray`
            Pixel indices.
        """
        # asarray allows plain lists, which do not support ``pix < 0``.
        pix = np.asarray(pix)

        if clip:
            return np.clip(pix, 0, self.nbin - 1)

        out_of_range = (pix < 0) | (pix >= self.nbin)
        return np.where(out_of_range, -1, pix)

    @property
    def center(self):
        """Center of the label axis, i.e. the labels."""
        return self._labels

    @property
    def edges(self):
        """Edges of the label axis (not defined; always raises ValueError)."""
        raise ValueError("A LabelMapAxis does not define edges")

    @property
    def edges_min(self):
        """Lower edges of the label axis, i.e. the labels."""
        return self._labels

    @property
    def edges_max(self):
        """Upper edges of the label axis, i.e. the labels."""
        return self._labels

    @property
    def bin_width(self):
        """Bin width is unity."""
        return np.ones(self.nbin)

    @property
    def as_plot_xerr(self):
        """Plot x error: half a bin on each side of the bin center."""
        return 0.5 * np.ones(self.nbin)

    @property
    def as_plot_labels(self):
        """Plot labels, as a list of str."""
        return self._labels.tolist()

    @property
    def as_plot_center(self):
        """Plot bin centers: 0, 1, ..., nbin - 1."""
        return np.arange(self.nbin)

    @property
    def as_plot_edges(self):
        """Plot bin edges: -0.5, 0.5, ..., nbin - 0.5."""
        return np.arange(self.nbin + 1) - 0.5

    def format_plot_xaxis(self, ax):
        """Format plot axis.

        Places one tick per bin and uses the labels as tick labels.

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axis`
            Plot axis to format.

        Returns
        -------
        ax : `~matplotlib.pyplot.Axis`
            Formatted plot axis.
        """
        ax.set_xticks(self.as_plot_center)
        ax.set_xticklabels(
            self.as_plot_labels,
            rotation=30,
            ha="right",
            rotation_mode="anchor",
        )
        return ax

    def to_header(self, format="gadf", idx=0):
        """Create FITS header.

        Parameters
        ----------
        format : {"gadf"}
            Format specification. Only "gadf" is supported.
        idx : int
            Column index of the axis, used as suffix of the ``AXCOLS``
            keyword.

        Returns
        -------
        header : `~astropy.io.fits.Header`
            Header to extend.

        Raises
        ------
        ValueError
            If ``format`` is not "gadf".
        """
        header = fits.Header()

        if format == "gadf":
            key = f"AXCOLS{idx}"
            header[key] = self.name.upper()
        else:
            raise ValueError(f"Unknown format {format}")

        return header

    # TODO: how configurable should that be? column names?
    @classmethod
    def from_table(cls, table, format="gadf", idx=0):
        """Create label map axis from table.

        Parameters
        ----------
        table : `~astropy.table.Table`
            Bin table HDU.
        format : {"gadf"}
            Format to use.
        idx : int
            Zero-based axis index; the column name is read from the
            ``AXCOLS{idx + 1}`` keyword of the table meta data.

        Returns
        -------
        axis : `LabelMapAxis`
            Label map axis, with the sorted unique values of the column as
            labels.

        Raises
        ------
        TypeError
            If the column does not have a string dtype.
        ValueError
            If ``format`` is not supported.
        """
        if format == "gadf":
            # AXCOLS keywords are 1-based, while ``idx`` is 0-based.
            colname = table.meta.get(f"AXCOLS{idx + 1}")
            column = table[colname]
            if not np.issubdtype(column.dtype, np.str_):
                raise TypeError(f"Not a valid dtype for label axis: '{column.dtype}'")
            # A table may repeat labels (one per row); keep each label once.
            labels = np.unique(column.data)
        else:
            raise ValueError(f"Not a supported format: {format}")

        return cls(labels=labels, name=colname.lower())

    def __repr__(self):
        str_ = self.__class__.__name__ + "\n"
        str_ += "-" * len(self.__class__.__name__) + "\n\n"
        fmt = "\t{:<10s} : {:<10s}\n"
        str_ += fmt.format("name", self.name)
        str_ += fmt.format("nbins", str(self.nbin))
        str_ += fmt.format("node type", self.node_type)
        str_ += fmt.format("labels", "{0}".format(list(self._labels)))
        return str_.expandtabs(tabsize=2)

    def is_allclose(self, other, **kwargs):
        """Check if other map axis is all close.

        Names are compared case-insensitively and labels must be identical
        and in the same order.

        Parameters
        ----------
        other : `LabelMapAxis`
            Other map axis.
        **kwargs : dict
            Ignored; accepted for interface compatibility with other axes.

        Returns
        -------
        is_allclose : bool
            Whether other axis is allclose.

        Raises
        ------
        TypeError
            If ``other`` is not a `LabelMapAxis`.
        """
        if not isinstance(other, self.__class__):
            raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

        name_equal = self.name.upper() == other.name.upper()
        # array_equal also handles axes with a different number of labels.
        labels_equal = np.array_equal(self.center, other.center)
        return name_equal and labels_equal

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False

        return self.is_allclose(other=other)

    def __ne__(self, other):
        return not self.__eq__(other)

    # TODO: could create sub-labels here using dashes like "label-1-a", etc.
    def upsample(self, *args, **kwargs):
        """Upsample axis (not supported)."""
        raise NotImplementedError("Upsampling a LabelMapAxis is not supported")

    # TODO: could merge labels here like "label-1-label2", etc.
    def downsample(self, *args, **kwargs):
        """Downsample axis (not supported)."""
        raise NotImplementedError("Downsampling a LabelMapAxis is not supported")

    # TODO: could merge labels here like "label-1-label2", etc.
    def resample(self, *args, **kwargs):
        """Resample axis (not supported)."""
        raise NotImplementedError("Resampling a LabelMapAxis is not supported")

    # TODO: could create new labels here like "label-10-a"
    def pad(self, *args, **kwargs):
        """Pad axis (not supported)."""
        raise NotImplementedError("Padding a LabelMapAxis is not supported")

    def copy(self):
        """Copy axis."""
        return copy.deepcopy(self)

    def slice(self, idx):
        """Create a new axis object by extracting a slice from this axis.

        Parameters
        ----------
        idx : slice or int
            Slice object (or integer index) selecting a subselection of the
            axis.

        Returns
        -------
        axis : `~LabelMapAxis`
            Sliced axis object.
        """
        return self.__class__(
            labels=self._labels[idx],
            name=self.name,
        )
||
3122 |