Total Complexity | 205 |
Total Lines | 1908 |
Duplicated Lines | 1.47 % |
Changes | 0 |
Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.
Common duplication problems and their corresponding solutions are:
Complex classes like gammapy.maps.core often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | # Licensed under a 3-clause BSD style license - see LICENSE.rst |
||
2 | import abc |
||
3 | import copy |
||
4 | import inspect |
||
5 | import json |
||
6 | from collections import OrderedDict |
||
7 | import numpy as np |
||
8 | from astropy import units as u |
||
9 | from astropy.io import fits |
||
10 | import matplotlib.pyplot as plt |
||
11 | from gammapy.utils.random import InverseCDFSampler, get_random_state |
||
12 | from gammapy.utils.scripts import make_path |
||
13 | from gammapy.utils.units import energy_unit_format |
||
14 | from .axes import MapAxis |
||
15 | from .coord import MapCoord |
||
16 | from .geom import pix_tuple_to_idx |
||
17 | from .io import JsonQuantityDecoder |
||
18 | |||
19 | __all__ = ["Map"] |
||
20 | |||
21 | |||
22 | class Map(abc.ABC): |
||
23 | """Abstract map class. |
||
24 | |||
25 | This can represent WCS- or HEALPIX-based maps |
||
26 | with 2 spatial dimensions and N non-spatial dimensions. |
||
27 | |||
28 | Parameters |
||
29 | ---------- |
||
30 | geom : `~gammapy.maps.Geom` |
||
31 | Geometry |
||
32 | data : `~numpy.ndarray` or `~astropy.units.Quantity` |
||
33 | Data array |
||
34 | meta : `dict` |
||
35 | Dictionary to store meta data |
||
36 | unit : str or `~astropy.units.Unit` |
||
37 | Data unit, ignored if data is a Quantity. |
||
38 | """ |
||
39 | |||
40 | tag = "map" |
||
41 | |||
def __init__(self, geom, data, meta=None, unit=""):
    """Store geometry, data, unit and meta data on the new map.

    Parameters
    ----------
    geom : `~gammapy.maps.Geom`
        Geometry
    data : `~numpy.ndarray` or `~astropy.units.Quantity`
        Data array
    meta : `dict`
        Dictionary to store meta data, an empty dict is used if None.
    unit : str or `~astropy.units.Unit`
        Data unit, ignored if data is a Quantity.
    """
    self._geom = geom

    if not isinstance(data, u.Quantity):
        self.data = data
        self._unit = u.Unit(unit)
    else:
        # ``unit`` is still parsed, but the quantity setter then replaces it
        # with the unit carried by ``data``.
        self._unit = u.Unit(unit)
        self.quantity = data

    self.meta = {} if meta is None else meta
||
56 | |||
def _init_copy(self, **kwargs):
    """Init map instance by copying missing init arguments from self.

    Parameters
    ----------
    **kwargs : dict
        Init arguments that override the values stored on ``self``.
        Every ``__init__`` argument that is not given is deep-copied
        from the matching private attribute (``"_" + name``).

    Returns
    -------
    map : `Map`
        Result of ``self.from_geom(**kwargs)``.
    """
    # ``self`` is not an init value; ``dtype`` is never copied, so ``from_geom``
    # applies its own default. Not every ``__init__`` declares ``dtype``, so it
    # is filtered out rather than removed unconditionally.
    skipped = {"self", "dtype"}

    for arg in inspect.getfullargspec(self.__init__).args:
        # Overridden values are kept as-is; skipping them also avoids
        # a useless deep copy of e.g. the data array.
        if arg in skipped or arg in kwargs:
            continue
        kwargs[arg] = copy.deepcopy(getattr(self, "_" + arg))

    return self.from_geom(**kwargs)
||
68 | |||
@property
def is_mask(self):
    """Whether the map data has bool dtype, i.e. the map is a mask."""
    dtype = self.data.dtype
    return dtype == bool
||
73 | |||
@property
def geom(self):
    """Geometry of the map (`~gammapy.maps.Geom`), as stored at init."""
    geometry = self._geom
    return geometry
||
78 | |||
@property
def data(self):
    """Data array (`~numpy.ndarray`)"""
    return self._data


@data.setter
def data(self, value):
    """Set the map data.

    A scalar is expanded to an array of the geometry's data shape; an array
    of a different shape is reshaped to the geometry's data shape.

    Parameters
    ----------
    value : array-like
        Data array

    Raises
    ------
    TypeError
        If ``value`` is a `~astropy.units.Quantity`.
    """
    # A Quantity is never a numpy scalar, so rejecting it first is equivalent
    # to rejecting it after the scalar expansion.
    if isinstance(value, u.Quantity):
        raise TypeError("Map data must be a Numpy array. Set unit separately")

    expected_shape = self.geom.data_shape

    if np.isscalar(value):
        # Keep the scalar's Python type as dtype of the filled array
        value = value * np.ones(expected_shape, dtype=type(value))

    if value.shape != expected_shape:
        value = value.reshape(expected_shape)

    self._data = value
||
103 | |||
@property
def unit(self):
    """Unit of the map data (`~astropy.units.Unit`)."""
    current_unit = self._unit
    return current_unit
||
108 | |||
@property
def meta(self):
    """Meta data dictionary of the map (`dict`)."""
    return self._meta


@meta.setter
def meta(self, val):
    """Replace the stored meta data with ``val`` (no copy is made)."""
    self._meta = val
||
117 | |||
@property
def quantity(self):
    """Map data times unit (`~astropy.units.Quantity`)"""
    return u.Quantity(self.data, self.unit, copy=False)


@quantity.setter
def quantity(self, val):
    """Set data and unit from a single quantity.

    Parameters
    ----------
    val : `~astropy.units.Quantity`
        Quantity; its value goes through the ``data`` setter and its unit
        becomes the map unit.
    """
    as_quantity = u.Quantity(val, copy=False)

    self.data = as_quantity.value
    self._unit = as_quantity.unit
||
136 | |||
def rename_axes(self, names, new_names):
    """Return a copy of the map with renamed axes.

    Parameters
    ----------
    names : list or str
        Names of the axes.
    new_names : list or str
        New names of the axes (list must be of the same length as `names`).

    Returns
    -------
    map : `~Map`
        Renamed Map.
    """
    renamed_geom = self.geom.rename_axes(names=names, new_names=new_names)
    return self._init_copy(geom=renamed_geom)
||
154 | |||
@staticmethod
def create(**kwargs):
    """Create an empty map object.

    This method accepts generic options listed below, as well as options
    for `HpxMap` and `WcsMap` objects. For WCS-specific options, see
    `WcsMap.create` and for HPX-specific options, see `HpxMap.create`.

    Parameters
    ----------
    frame : str
        Coordinate system, either Galactic ("galactic") or Equatorial
        ("icrs").
    map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse', 'region'}
        Map type. Selects the class that will be used to
        instantiate the map. Defaults to 'wcs'.
    binsz : float or `~numpy.ndarray`
        Pixel size in degrees.
    skydir : `~astropy.coordinates.SkyCoord`
        Coordinate of map center.
    axes : list
        List of `~MapAxis` objects for each non-spatial dimension.
        If None then the map will be a 2D image.
    dtype : str
        Data type, default is 'float32'
    unit : str or `~astropy.units.Unit`
        Data unit.
    meta : `dict`
        Dictionary to store meta data.
    region : `~regions.SkyRegion`
        Sky region used for the region map.

    Returns
    -------
    map : `Map`
        Empty map object.

    Raises
    ------
    ValueError
        If ``map_type`` is not recognized.
    """
    from .hpx import HpxMap
    from .region import RegionNDMap
    from .wcs import WcsMap

    map_type = kwargs.setdefault("map_type", "wcs")
    lowered = map_type.lower()

    # WCS and HPX are matched case-insensitively by substring, so the
    # "-sparse" variants are dispatched to the same factory.
    if "wcs" in lowered:
        return WcsMap.create(**kwargs)

    if "hpx" in lowered:
        return HpxMap.create(**kwargs)

    # The region type is matched exactly (case-sensitive)
    if map_type != "region":
        raise ValueError(f"Unrecognized map type: {map_type!r}")

    # "map_type" is removed from the options before calling RegionNDMap.create
    kwargs.pop("map_type")
    return RegionNDMap.create(**kwargs)
||
206 | |||
@staticmethod
def read(
    filename, hdu=None, hdu_bands=None, map_type="auto", format=None, colname=None
):
    """Read a map from a FITS file.

    Parameters
    ----------
    filename : str or `~pathlib.Path`
        Name of the FITS file.
    hdu : str
        Name or index of the HDU with the map data.
    hdu_bands : str
        Name or index of the HDU with the BANDS table. If not
        defined this will be inferred from the FITS header of the
        map HDU.
    map_type : {'wcs', 'wcs-sparse', 'hpx', 'hpx-sparse', 'auto', 'region'}
        Map type. Selects the class that will be used to
        instantiate the map. The map type should be consistent
        with the format of the input file. If map_type is 'auto'
        then an appropriate map type will be inferred from the
        input file.
    format : str, optional
        FITS format convention, passed on to `Map.from_hdulist`.
    colname : str, optional
        Data column name to be used for the HEALPix map.

    Returns
    -------
    map_out : `Map`
        Map object
    """
    path = str(make_path(filename))

    # The map is built while the file is still open; the file is closed on
    # leaving the ``with`` block.
    with fits.open(path, memmap=False) as hdulist:
        map_out = Map.from_hdulist(
            hdulist, hdu, hdu_bands, map_type, format=format, colname=colname
        )

    return map_out
||
241 | |||
@staticmethod
def from_geom(geom, meta=None, data=None, unit="", dtype="float32"):
    """Generate an empty map from a `Geom` instance.

    Parameters
    ----------
    geom : `Geom`
        Map geometry.
    meta : `dict`
        Dictionary to store meta data.
    data : `numpy.ndarray`
        Data array.
    unit : str or `~astropy.units.Unit`
        Data unit.
    dtype : str
        Data type forwarded to the map class, default is 'float32'.

    Returns
    -------
    map_out : `Map`
        Map object

    Raises
    ------
    ValueError
        If the geometry is not a HPX, WCS or region geometry.
    """
    from .hpx import HpxGeom
    from .region import RegionGeom
    from .wcs import WcsGeom

    # Checked in this order; the first matching geometry class wins
    geom_to_map_type = (
        (HpxGeom, "hpx"),
        (WcsGeom, "wcs"),
        (RegionGeom, "region"),
    )

    for geom_cls, map_type in geom_to_map_type:
        if isinstance(geom, geom_cls):
            break
    else:
        raise ValueError("Unrecognized geom type.")

    cls_out = Map._get_map_cls(map_type)
    return cls_out(geom, data=data, meta=meta, unit=unit, dtype=dtype)
||
278 | |||
@staticmethod
def from_hdulist(
    hdulist, hdu=None, hdu_bands=None, map_type="auto", format=None, colname=None
):
    """Create from `astropy.io.fits.HDUList`.

    Parameters
    ----------
    hdulist : `~astropy.io.fits.HDUList`
        HDU list containing HDUs for map data and bands.
    hdu : str
        Name or index of the HDU with the map data.
    hdu_bands : str
        Name or index of the HDU with the BANDS table.
    map_type : {"auto", "wcs", "hpx", "region"}
        Map type. With "auto" the type is inferred from the HDU header.
    format : {'gadf', 'fgst-ccube', 'fgst-template'}
        FITS format convention.
    colname : str, optional
        Data column name to be used for the HEALPix map.

    Returns
    -------
    map_out : `Map`
        Map object
    """
    if map_type == "auto":
        map_type = Map._get_map_type(hdulist, hdu)

    cls_out = Map._get_map_cls(map_type)

    options = {"hdu": hdu, "hdu_bands": hdu_bands, "format": format}

    # Only the HEALPix reader is given the column name
    if map_type == "hpx":
        options["colname"] = colname

    return cls_out.from_hdulist(hdulist, **options)
||
316 | |||
@staticmethod
def _get_meta_from_header(header):
    """Load meta data from a FITS header.

    The "META" entry is decoded from JSON; an empty dict is returned if the
    header has no such entry.
    """
    if "META" not in header:
        return {}

    return json.loads(header["META"], cls=JsonQuantityDecoder)
||
324 | |||
@staticmethod
def _get_map_type(hdu_list, hdu_name):
    """Infer map type from a FITS HDU.

    Only read header, never data, to have good performance.

    Returns "hpx" if the header has PIXTYPE == "HEALPIX", "wcs" if it has
    a CTYPE1 entry, and "region" otherwise.
    """
    if hdu_name is not None:
        header = hdu_list[hdu_name].header
    else:
        # Use the first HDU, or the second one if the first is empty
        header = hdu_list[0].header
        if header["NAXIS"] == 0:
            header = hdu_list[1].header

    is_healpix = ("PIXTYPE" in header) and (header["PIXTYPE"] == "HEALPIX")
    if is_healpix:
        return "hpx"

    if "CTYPE1" in header:
        return "wcs"

    return "region"
||
345 | |||
@staticmethod
def _get_map_cls(map_type):
    """Get map class for given `map_type` string.

    This should probably be a registry dict so that users
    can add supported map types to the `gammapy.maps` I/O
    (see e.g. the Astropy table format I/O registry),
    but that's non-trivial to implement without avoiding circular imports.

    Raises
    ------
    NotImplementedError
        For the "wcs-sparse" and "hpx-sparse" map types.
    ValueError
        For any other unknown map type.
    """
    if map_type in ("wcs-sparse", "hpx-sparse"):
        raise NotImplementedError()

    # The classes are imported lazily, inside the matching branch only
    if map_type == "wcs":
        from .wcs import WcsNDMap

        return WcsNDMap

    if map_type == "hpx":
        from .hpx import HpxNDMap

        return HpxNDMap

    if map_type == "region":
        from .region import RegionNDMap

        return RegionNDMap

    raise ValueError(f"Unrecognized map type: {map_type!r}")
||
373 | |||
def write(self, filename, overwrite=False, **kwargs):
    """Write to a FITS file.

    All extra keyword arguments are forwarded to ``to_hdulist``.

    Parameters
    ----------
    filename : str
        Output file name.
    overwrite : bool
        Overwrite existing file?
    hdu : str
        Set the name of the image extension. By default this will
        be set to SKYMAP (for BINTABLE HDU) or PRIMARY (for IMAGE
        HDU).
    hdu_bands : str
        Set the name of the bands table extension. By default this will
        be set to BANDS.
    format : str, optional
        FITS format convention. By default files will be written
        to the gamma-astro-data-formats (GADF) format. This
        option can be used to write files that are compliant with
        format conventions required by specific software (e.g. the
        Fermi Science Tools). The following formats are supported:

            - "gadf" (default)
            - "fgst-ccube"
            - "fgst-ltcube"
            - "fgst-bexpcube"
            - "fgst-srcmap"
            - "fgst-template"
            - "fgst-srcmap-sparse"
            - "galprop"
            - "galprop2"

    sparse : bool
        Sparsify the map by dropping pixels with zero amplitude.
        This option is only compatible with the 'gadf' format.
    """
    hdu_list = self.to_hdulist(**kwargs)
    target = str(make_path(filename))
    hdu_list.writeto(target, overwrite=overwrite)
||
413 | |||
def iter_by_axis(self, axis_name, keepdims=False):
    """Iterate over the bins of a given axis.

    Parameters
    ----------
    axis_name : str
        Name of the axis to iterate over.
    keepdims : bool
        If True the axis is sliced with a length-one slice per bin,
        otherwise with the integer bin index.

    Yields
    ------
    map : `Map`
        Map iteration.

    See also
    --------
    iter_by_image : iterate by image returning a map
    """
    n_bins = self.geom.axes[axis_name].nbin

    for idx in range(n_bins):
        if keepdims:
            idx_axis = slice(idx, idx + 1)
        else:
            idx_axis = idx

        yield self.slice_by_idx(slices={axis_name: idx_axis})
||
431 | |||
def iter_by_image(self, keepdims=False):
    """Iterate over image planes of a map.

    Parameters
    ----------
    keepdims : bool
        Keep dimensions. If True every non-spatial axis is sliced with a
        length-one slice, otherwise the image at the index is returned.

    Yields
    ------
    map : `Map`
        Map iteration.

    See also
    --------
    iter_by_image_data : iterate by image returning data and index
    """
    for idx in np.ndindex(self.geom.shape_axes):
        if not keepdims:
            yield self.get_image_by_idx(idx=idx)
            continue

        slices = {}
        for name, position in zip(self.geom.axes.names, idx):
            slices[name] = slice(position, position + 1)

        yield self.slice_by_idx(slices=slices)
||
456 | |||
def iter_by_image_data(self):
    """Iterate over image planes of the map.

    The image plane index is in data order, so that the data array can be
    indexed directly.

    Yields
    ------
    (data, idx) : tuple
        Where ``data`` is a `numpy.ndarray` view of the image plane data,
        and ``idx`` is a tuple of int, the index of the image plane.

    See also
    --------
    iter_by_image : iterate by image returning a map
    """
    for axes_idx in np.ndindex(self.geom.shape_axes):
        # Reverse the axes-ordered index to obtain the data-ordered index
        data_idx = axes_idx[::-1]
        yield self.data[data_idx], data_idx
||
475 | |||
def coadd(self, map_in, weights=None):
    """Add the contents of ``map_in`` to this map.

    This method can be used to combine maps containing integral quantities (e.g. counts)
    or differential quantities if the maps have the same binning.

    Parameters
    ----------
    map_in : `Map`
        Input map.
    weights: `Map` or `~numpy.ndarray`
        The weight factors while adding

    Raises
    ------
    ValueError
        If the unit of ``map_in`` is not equivalent to the unit of this map.
    """
    if not self.unit.is_equivalent(map_in.unit):
        raise ValueError("Incompatible units")

    # TODO: Check whether geometries are aligned and if so sum the
    # data vectors directly
    if weights is not None:
        map_in = map_in * weights

    geom_in = map_in.geom
    idx = geom_in.get_idx()
    coords = geom_in.get_coord()

    # Values are passed with their unit to ``fill_by_coord``
    values = u.Quantity(map_in.get_by_idx(idx), map_in.unit)
    self.fill_by_coord(coords, values)
||
500 | |||
def pad(self, pad_width, axis_name=None, mode="constant", cval=0, method="linear"):
    """Pad the spatial dimensions or a single non-spatial axis of the map.

    Parameters
    ----------
    pad_width : {sequence, array_like, int}
        Number of pixels padded to the edges of each axis.
    axis_name : str
        Which axis to pad. By default spatial axes are padded.
    mode : {'edge', 'constant', 'interp'}
        Padding mode. 'edge' pads with the closest edge value.
        'constant' pads with a constant value. 'interp' pads with
        an extrapolated value. When ``axis_name`` is given the mode is
        passed to `numpy.pad` as is.
    cval : float
        Padding value when mode='constant'.
    method : str
        Accepted for interface compatibility; not used by this method.

    Returns
    -------
    map : `Map`
        Padded map.

    """
    if not axis_name:
        # Forward the requested mode instead of always padding with a constant
        return self._pad_spatial(pad_width, mode=mode, cval=cval)

    if np.isscalar(pad_width):
        pad_width = (pad_width, pad_width)

    geom = self.geom.pad(pad_width=pad_width, axis_name=axis_name)

    # Pad only along the data dimension of the requested axis
    idx = self.geom.axes.index_data(axis_name)
    pad_width_np = [(0, 0)] * self.data.ndim
    pad_width_np[idx] = pad_width

    kwargs = {}
    if mode == "constant":
        kwargs["constant_values"] = cval

    data = np.pad(self.data, pad_width=pad_width_np, mode=mode, **kwargs)
    return self.__class__(
        geom=geom, data=data, unit=self.unit, meta=self.meta.copy()
    )
||
542 | |||
@abc.abstractmethod
def _pad_spatial(self, pad_width, mode="constant", cval=0, order=1):
    """Pad the spatial dimensions of the map; implemented by sub-classes."""
||
546 | |||
@abc.abstractmethod
def crop(self, crop_width):
    """Crop the spatial dimensions of the map.

    Abstract method, implemented by sub-classes.

    Parameters
    ----------
    crop_width : {sequence, array_like, int}
        Number of pixels cropped from the edges of each axis.
        Defined analogously to ``pad_width`` from `numpy.pad`.

    Returns
    -------
    map : `Map`
        Cropped map.
    """
||
563 | |||
@abc.abstractmethod
def downsample(self, factor, preserve_counts=True, axis_name=None):
    """Downsample the spatial dimension by a given factor.

    Abstract method, implemented by sub-classes.

    Parameters
    ----------
    factor : int
        Downsampling factor.
    preserve_counts : bool
        Preserve the integral over each bin. This should be true
        if the map is an integral quantity (e.g. counts) and false if
        the map is a differential quantity (e.g. intensity).
    axis_name : str
        Which axis to downsample. By default spatial axes are downsampled.

    Returns
    -------
    map : `Map`
        Downsampled map.
    """
||
585 | |||
@abc.abstractmethod
def upsample(self, factor, order=0, preserve_counts=True, axis_name=None):
    """Upsample the spatial dimension by a given factor.

    Abstract method, implemented by sub-classes.

    Parameters
    ----------
    factor : int
        Upsampling factor.
    order : int
        Order of the interpolation used for upsampling.
    preserve_counts : bool
        Preserve the integral over each bin. This should be true
        if the map is an integral quantity (e.g. counts) and false if
        the map is a differential quantity (e.g. intensity).
    axis_name : str
        Which axis to upsample. By default spatial axes are upsampled.

    Returns
    -------
    map : `Map`
        Upsampled map.
    """
||
610 | |||
def resample(self, geom, weights=None, preserve_counts=True):
    """Resample pixels to ``geom`` with given ``weights``.

    Parameters
    ----------
    geom : `~gammapy.maps.Geom`
        Target Map geometry
    weights : `~numpy.ndarray`
        Weights vector. Default is weight of one. Must have same shape as
        the data of the map.
    preserve_counts : bool
        Preserve the integral over each bin. This should be true
        if the map is an integral quantity (e.g. counts) and false if
        the map is a differential quantity (e.g. intensity)

    Returns
    -------
    resampled_map : `Map`
        Resampled map, carrying the unit and a copy of the meta data
        of this map.
    """
    coords = self.geom.get_coord()
    idx = geom.coord_to_idx(coords)

    weights = 1 if weights is None else weights

    # Keep unit and meta data of the source map; without them the result
    # would silently become a unit-less map with empty meta data.
    resampled = self.from_geom(
        geom=geom, unit=self.unit, meta=copy.deepcopy(self.meta)
    )
    resampled._resample_by_idx(
        idx, weights=self.data * weights, preserve_counts=preserve_counts
    )
    return resampled
||
641 | |||
@abc.abstractmethod
def _resample_by_idx(self, idx, weights=None, preserve_counts=False):
    """Resample pixels at ``idx`` with given ``weights``.

    Abstract method, implemented by sub-classes.

    Parameters
    ----------
    idx : tuple
        Tuple of pixel index arrays for each dimension of the map.
        Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
        for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
    weights : `~numpy.ndarray`
        Weights vector. Default is weight of one.
    preserve_counts : bool
        Preserve the integral over each bin. This should be true
        if the map is an integral quantity (e.g. counts) and false if
        the map is a differential quantity (e.g. intensity)
    """
||
660 | |||
def resample_axis(self, axis, weights=None, ufunc=np.add):
    """Resample map to a new axis by grouping and reducing smaller bins by a given ufunc

    By default, the map content are summed over the smaller bins. Other numpy ufunc can be
    used, e.g. `numpy.logical_and` or `numpy.logical_or`.

    Parameters
    ----------
    axis : `MapAxis`
        New map axis.
    weights : `Map`
        Map whose data is multiplied with the map data before the
        reduction, so it must be broadcastable to the map data.
    ufunc : `~numpy.ufunc`
        ufunc to use to resample the axis. Default is numpy.add.

    Returns
    -------
    map : `Map`
        Map with resampled axis.
    """
    from .hpx import HpxGeom

    geom_out = self.geom.resample_axis(axis)

    # We don't use MapAxis.coord_to_idx because is does not behave as needed with boundaries
    old_edges = self.geom.axes[axis.name].edges.value
    new_edges = geom_out.axes[axis.name].edges.value
    group_starts = np.digitize(new_edges, old_edges) - 1

    data_axis = self.geom.axes.index_data(axis.name)

    weights = 1 if weights is None else weights.data

    # Shape of the zero block appended along the resampled data axis,
    # built in axes order and reversed below to match the data order.
    if isinstance(self.geom, HpxGeom):
        shape = (self.geom.data_shape[-1],)
    else:
        shape = self.geom._shape[:2]

    shape += tuple([ax.nbin if ax != axis else 1 for ax in self.geom.axes])

    weighted = self.data * weights
    padded = np.append(weighted, np.zeros(shape[::-1]), axis=data_axis)

    # Trim the reduced array to the data shape of the new geometry
    trim = tuple([slice(0, size) for size in geom_out.data_shape])
    reduced = ufunc.reduceat(padded, indices=group_starts, axis=data_axis)
    data = reduced[trim]

    return self._init_copy(data=data, geom=geom_out)
||
711 | |||
def slice_by_idx(
    self,
    slices,
):
    """Slice sub map from map object.

    Parameters
    ----------
    slices : dict
        Dict of axes names and integers or `slice` object pairs. Contains one
        element for each non-spatial dimension. For integer indexing the
        corresponding axes is dropped from the map. Axes not specified in the
        dict are kept unchanged.

    Returns
    -------
    map_out : `Map`
        Sliced map object.
    """
    sliced_geom = self.geom.slice_by_idx(slices)

    # One entry per non-spatial axis, a full slice for axes not mentioned
    per_axis = [slices.get(ax.name, slice(None)) for ax in self.geom.axes]

    # The data array is indexed in reversed axes order
    data = self.data[tuple(reversed(per_axis))]

    return self.__class__(
        geom=sliced_geom, data=data, unit=self.unit, meta=self.meta
    )
||
735 | |||
def get_image_by_coord(self, coords):
    """Return spatial map at the given axis coordinates.

    Parameters
    ----------
    coords : tuple or dict
        Tuple should be ordered as (x_0, ..., x_n) where x_i are coordinates
        for non-spatial dimensions of the map. Dict should specify the axis
        names of the non-spatial axes such as {'axes0': x_0, ..., 'axesn': x_n}.

    Returns
    -------
    map_out : `Map`
        Map with spatial dimensions only.

    See Also
    --------
    get_image_by_idx, get_image_by_pix

    Examples
    --------
    ::

        import numpy as np
        from gammapy.maps import Map, MapAxis
        from astropy.coordinates import SkyCoord
        from astropy import units as u

        # Define map axes
        energy_axis = MapAxis.from_edges(
            np.logspace(-1., 1., 4), unit='TeV', name='energy',
        )

        time_axis = MapAxis.from_edges(
            np.linspace(0., 10, 20), unit='h', name='time',
        )

        # Define map center
        skydir = SkyCoord(0, 0, frame='galactic', unit='deg')

        # Create map
        m_wcs = Map.create(
            map_type='wcs',
            binsz=0.02,
            skydir=skydir,
            width=10.0,
            axes=[energy_axis, time_axis],
        )

        # Get image by coord tuple
        image = m_wcs.get_image_by_coord(('500 GeV', '1 h'))

        # Get image by coord dict with strings
        image = m_wcs.get_image_by_coord({'energy': '500 GeV', 'time': '1 h'})

        # Get image by coord dict with quantities
        image = m_wcs.get_image_by_coord({'energy': 0.5 * u.TeV, 'time': 1 * u.h})
    """
    axes = self.geom.axes

    if isinstance(coords, tuple):
        # Pair the positional coordinates with the axis names, in axes order
        coords = {name: coord for name, coord in zip(axes.names, coords)}

    return self.get_image_by_idx(axes.coord_to_idx(coords))
||
799 | |||
def get_image_by_pix(self, pix):
    """Return spatial map at the given axis pixel coordinates

    The pixel coordinates are converted with ``self.geom.pix_to_idx`` and
    the result is passed to `get_image_by_idx`.

    Parameters
    ----------
    pix : tuple
        Tuple of scalar pixel coordinates for each non-spatial dimension of
        the map. Tuple should be ordered as (I_0, ..., I_n). Pixel coordinates
        can be either float or integer type.

    See Also
    --------
    get_image_by_coord, get_image_by_idx

    Returns
    -------
    map_out : `Map`
        Map with spatial dimensions only.
    """
    return self.get_image_by_idx(self.geom.pix_to_idx(pix))
||
821 | |||
def get_image_by_idx(self, idx):
    """Return spatial map at the given axis pixel indices.

    Parameters
    ----------
    idx : tuple
        Tuple of scalar indices for each non spatial dimension of the map.
        Tuple should be ordered as (I_0, ..., I_n).

    See Also
    --------
    get_image_by_coord, get_image_by_pix

    Returns
    -------
    map_out : `Map`
        Map with spatial dimensions only.

    Raises
    ------
    ValueError
        If the length of ``idx`` differs from the number of non-spatial axes.
    """
    n_axes = len(self.geom.axes)
    if len(idx) != n_axes:
        raise ValueError("Tuple length must equal number of non-spatial dimensions")

    # Only support scalar indices per axis
    scalar_idx = tuple(int(value) for value in idx)

    image_geom = self.geom.to_image()

    # The data array is indexed in reversed axes order
    image_data = self.data[scalar_idx[::-1]]

    return self.__class__(
        geom=image_geom, data=image_data, unit=self.unit, meta=self.meta
    )
||
849 | |||
def get_by_coord(self, coords, fill_value=np.nan):
    """Return map values at the given map coordinates.

    The coordinates are converted to pixel coordinates by the geometry and
    looked up with `get_by_pix`.

    Parameters
    ----------
    coords : tuple or `~gammapy.maps.MapCoord`
        Coordinate arrays for each dimension of the map. Tuple
        should be ordered as (lon, lat, x_0, ..., x_n) where x_i
        are coordinates for non-spatial dimensions of the map.
    fill_value : float
        Value which is returned if the position is outside of the projection
        footprint

    Returns
    -------
    vals : `~numpy.ndarray`
        Values of pixels in the map. np.nan used to flag coords
        outside of map.
    """
    return self.get_by_pix(
        self.geom.coord_to_pix(coords=coords), fill_value=fill_value
    )
||
872 | |||
def get_by_pix(self, pix, fill_value=np.nan):
    """Return map values at the given pixel coordinates.

    Parameters
    ----------
    pix : tuple
        Tuple of pixel index arrays for each dimension of the map.
        Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
        for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
        Pixel indices can be either float or integer type.
    fill_value : float
        Value which is returned if the position is outside of the projection
        footprint

    Returns
    -------
    vals : `~numpy.ndarray`
        Array of pixel values. ``fill_value`` (np.nan by default) is used
        to flag coordinates outside of the map.
    """
    # FIXME: Support local indexing here?
    # FIXME: Support slicing?
    pix = np.broadcast_arrays(*pix)
    idx = self.geom.pix_to_idx(pix)
    vals = self.get_by_idx(idx)
    mask = self.geom.contains_pix(pix)

    if not mask.all():
        # Promote to a dtype that holds both the data and the fill value.
        # Casting to the fill value type alone would truncate float data
        # when an integer fill value (e.g. 0) is given.
        dtype = np.result_type(vals.dtype, type(fill_value))
        vals = vals.astype(dtype)
        vals[~mask] = fill_value

    return vals
||
905 | |||
@abc.abstractmethod
def get_by_idx(self, idx):
    """Return map values at the given pixel indices.

    Abstract method without a default implementation; concrete map
    classes have to override it.

    Parameters
    ----------
    idx : tuple
        Tuple of pixel index arrays for each dimension of the map.
        Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
        for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.

    Returns
    -------
    vals : `~numpy.ndarray`
        Array of pixel values. np.nan is documented as the flag for
        coordinates outside of the map.
    """
||
924 | |||
@abc.abstractmethod
def interp_by_coord(self, coords, method="linear", fill_value=None):
    """Interpolate map values at the given map coordinates.

    Abstract method without a default implementation; concrete map
    classes have to override it.

    Parameters
    ----------
    coords : tuple or `~gammapy.maps.MapCoord`
        Coordinate arrays for each dimension of the map. Tuple
        should be ordered as (lon, lat, x_0, ..., x_n) where x_i
        are coordinates for non-spatial dimensions of the map.
    method : {"linear", "nearest"}
        Method to interpolate data values. By default linear
        interpolation is performed.
    fill_value : None or float value
        The value to use for points outside of the interpolation domain.
        If None, values outside the domain are extrapolated.

    Returns
    -------
    vals : `~numpy.ndarray`
        Interpolated pixel values.
    """
||
948 | |||
@abc.abstractmethod
def interp_by_pix(self, pix, method="linear", fill_value=None):
    """Interpolate map values at the given pixel coordinates.

    Abstract method without a default implementation; concrete map
    classes have to override it.

    Parameters
    ----------
    pix : tuple
        Tuple of pixel coordinate arrays for each dimension of the
        map. Tuple should be ordered as (p_lon, p_lat, p_0, ...,
        p_n) where p_i are pixel coordinates for non-spatial
        dimensions of the map.
    method : {"linear", "nearest"}
        Method to interpolate data values. By default linear
        interpolation is performed.
    fill_value : None or float value
        The value to use for points outside of the interpolation domain.
        If None, values outside the domain are extrapolated.

    Returns
    -------
    vals : `~numpy.ndarray`
        Interpolated pixel values.
    """
||
973 | |||
def interp_to_geom(self, geom, preserve_counts=False, fill_value=0, **kwargs):
    """Interpolate map to input geometry.

    Parameters
    ----------
    geom : `~gammapy.maps.Geom`
        Target Map geometry
    preserve_counts : bool
        Preserve the integral over each bin. This should be true
        if the map is an integral quantity (e.g. counts) and false if
        the map is a differential quantity (e.g. intensity)
    fill_value : float
        Value passed to `Map.interp_by_coord` as ``fill_value``; for mask
        maps it replaces the NaN values returned by `Map.get_by_coord`.
        Default is 0.
    **kwargs : dict
        Keyword arguments passed to `Map.interp_by_coord`

    Returns
    -------
    interp_map : `Map`
        Interpolated Map

    Raises
    ------
    ValueError
        If ``preserve_counts`` is set, ``geom`` has more than two
        dimensions and its first non-spatial axis differs from the first
        non-spatial axis of this map.
    """
    # Validate first, so that no coordinates are computed and no copy is
    # made when the request cannot be fulfilled anyway.
    if preserve_counts and geom.ndim > 2 and geom.axes[0] != self.geom.axes[0]:
        raise ValueError(
            f"Energy axis do not match: expected {self.geom.axes[0]},"
            f" but got {geom.axes[0]}."
        )

    coords = geom.get_coord()
    map_copy = self.copy()

    if preserve_counts:
        # Convert the per-pixel integral into a density per deg2, which is
        # the quantity that can be interpolated.
        map_copy.data /= map_copy.geom.solid_angle().to_value("deg2")

    if map_copy.is_mask:
        # TODO: check this NaN handling is needed
        data = map_copy.get_by_coord(coords)
        data = np.nan_to_num(data, nan=fill_value).astype(bool)
    else:
        data = map_copy.interp_by_coord(coords, fill_value=fill_value, **kwargs)

    if preserve_counts:
        # Back from density to per-pixel integral on the target geometry.
        data *= geom.solid_angle().to_value("deg2")

    return Map.from_geom(geom, data=data, unit=self.unit)
||
1015 | |||
def reproject_to_geom(self, geom, preserve_counts=False, precision_factor=10):
    """Reproject map to input geometry.

    The map is first resampled onto a geometry that has the spatial part of
    ``geom`` and the non-spatial axes of this map. If ``geom`` has
    non-spatial axes that differ from those of this map, the result is then
    resampled onto ``geom`` itself. Before each resampling step the input is
    upsampled when its bins are not fine enough with respect to
    ``precision_factor``.

    Parameters
    ----------
    geom : `~gammapy.maps.Geom`
        Target Map geometry
    preserve_counts : bool
        Preserve the integral over each bin. This should be true
        if the map is an integral quantity (e.g. counts) and false if
        the map is a differential quantity (e.g. intensity)
    precision_factor : int
        Minimal factor between the bin size of the output map and the oversampled base map.
        Used only for the oversampling method.

    Returns
    -------
    output_map : `Map`
        Reprojected Map

    Raises
    ------
    ValueError
        If ``geom`` is not an image and its axis names (or their order)
        differ from those of this map.
    TypeError
        If ``geom`` is not an image, its axes differ from those of this map
        and either geometry is a `HpxGeom`.
    """
    # Imported here rather than at module level (presumably to avoid a
    # circular import between the map modules).
    from .hpx import HpxGeom
    from .region import RegionGeom

    # Target spatial geometry combined with the axes of this map.
    axes = [ax.copy() for ax in self.geom.axes]
    geom3d = geom.copy(axes=axes)

    if not geom.is_image:
        if geom.axes.names != geom3d.axes.names:
            raise ValueError("Axis names and order should be the same.")
        if geom.axes != geom3d.axes and (
            isinstance(geom3d, HpxGeom) or isinstance(self.geom, HpxGeom)
        ):
            raise TypeError(
                "Reprojection to 3d geom with non-identical axes is not supported for HpxGeom. "
                "Reproject to 2d geom first and then use interp_to_geom method."
            )

    # Ratio of the smallest target pixel scale to the smallest pixel scale of
    # this map; region geometries are converted to a WCS geometry first.
    if isinstance(geom3d, RegionGeom):
        base_factor = (
            geom3d.to_wcs_geom().pixel_scales.min() / self.geom.pixel_scales.min()
        )
    elif isinstance(self.geom, RegionGeom):
        base_factor = (
            geom3d.pixel_scales.min() / self.geom.to_wcs_geom().pixel_scales.min()
        )
    else:
        base_factor = geom3d.pixel_scales.min() / self.geom.pixel_scales.min()

    if base_factor >= precision_factor:
        input_map = self
    else:
        factor = precision_factor / base_factor
        if isinstance(self.geom, HpxGeom):
            # Round the factor up to the next power of two.
            factor = int(2 ** np.ceil(np.log(factor) / np.log(2)))
        else:
            factor = int(np.ceil(factor))
        input_map = self.upsample(factor=factor, preserve_counts=preserve_counts)

    output_map = input_map.resample(geom3d, preserve_counts=preserve_counts)

    if not geom.is_image and geom.axes != geom3d.axes:
        for base_ax, target_ax in zip(geom3d.axes, geom.axes):
            # NOTE(review): this ratio is (axis of this map) / (target axis),
            # whereas the spatial ratio above is target / source - confirm
            # that the direction is intended.
            base_factor = base_ax.bin_width.min() / target_ax.bin_width.min()
            if not base_factor >= precision_factor:
                factor = precision_factor / base_factor
                factor = int(np.ceil(factor))
                output_map = output_map.upsample(
                    factor=factor,
                    preserve_counts=preserve_counts,
                    axis_name=base_ax.name,
                )
        output_map = output_map.resample(geom, preserve_counts=preserve_counts)
    return output_map
||
1088 | |||
def fill_events(self, events):
    """Fill the map with the coordinates of an event list (`~gammapy.data.EventList`).

    The map coordinates are obtained from ``events.map_coord`` for the
    geometry of this map and handed to `Map.fill_by_coord`.
    """
    event_coords = events.map_coord(self.geom)
    self.fill_by_coord(event_coords)
||
1092 | |||
def fill_by_coord(self, coords, weights=None):
    """Fill pixels at ``coords`` with given ``weights``.

    The coordinates are converted to pixel indices by the map geometry
    and the filling is delegated to `Map.fill_by_idx`.

    Parameters
    ----------
    coords : tuple or `~gammapy.maps.MapCoord`
        Coordinate arrays for each dimension of the map. Tuple
        should be ordered as (lon, lat, x_0, ..., x_n) where x_i
        are coordinates for non-spatial dimensions of the map.
    weights : `~numpy.ndarray`
        Weights vector. Default is weight of one.
    """
    self.fill_by_idx(self.geom.coord_to_idx(coords), weights=weights)
||
1107 | |||
def fill_by_pix(self, pix, weights=None):
    """Fill pixels at ``pix`` with given ``weights``.

    The pixel coordinates are converted to indices with
    ``pix_tuple_to_idx`` and the filling is delegated to
    `Map.fill_by_idx`, whose return value is passed through.

    Parameters
    ----------
    pix : tuple
        Tuple of pixel index arrays for each dimension of the map.
        Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
        for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
        Pixel indices can be either float or integer type. Float
        indices will be rounded to the nearest integer.
    weights : `~numpy.ndarray`
        Weights vector. Default is weight of one.
    """
    pixel_idx = pix_tuple_to_idx(pix)
    result = self.fill_by_idx(pixel_idx, weights=weights)
    return result
||
1124 | |||
@abc.abstractmethod
def fill_by_idx(self, idx, weights=None):
    """Fill pixels at ``idx`` with given ``weights``.

    Abstract method without a default implementation; concrete map
    classes have to override it.

    Parameters
    ----------
    idx : tuple
        Tuple of pixel index arrays for each dimension of the map.
        Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
        for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
    weights : `~numpy.ndarray`
        Weights vector. Default is weight of one.
    """
||
1139 | |||
def set_by_coord(self, coords, vals):
    """Set pixels at ``coords`` with given ``vals``.

    The coordinates are converted to pixel coordinates by the map
    geometry and the assignment is delegated to `Map.set_by_pix`.

    Parameters
    ----------
    coords : tuple or `~gammapy.maps.MapCoord`
        Coordinate arrays for each dimension of the map. Tuple
        should be ordered as (lon, lat, x_0, ..., x_n) where x_i
        are coordinates for non-spatial dimensions of the map.
    vals : `~numpy.ndarray`
        Values vector.
    """
    self.set_by_pix(self.geom.coord_to_pix(coords), vals)
||
1154 | |||
def set_by_pix(self, pix, vals):
    """Set pixels at ``pix`` with given ``vals``.

    The pixel coordinates are converted to indices with
    ``pix_tuple_to_idx`` and the assignment is delegated to
    `Map.set_by_idx`, whose return value is passed through.

    Parameters
    ----------
    pix : tuple
        Tuple of pixel index arrays for each dimension of the map.
        Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
        for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
        Pixel indices can be either float or integer type. Float
        indices will be rounded to the nearest integer.
    vals : `~numpy.ndarray`
        Values vector.
    """
    pixel_idx = pix_tuple_to_idx(pix)
    result = self.set_by_idx(pixel_idx, vals)
    return result
||
1171 | |||
@abc.abstractmethod
def set_by_idx(self, idx, vals):
    """Set pixels at ``idx`` with given ``vals``.

    Abstract method without a default implementation; concrete map
    classes have to override it.

    Parameters
    ----------
    idx : tuple
        Tuple of pixel index arrays for each dimension of the map.
        Tuple should be ordered as (I_lon, I_lat, I_0, ..., I_n)
        for WCS maps and (I_hpx, I_0, ..., I_n) for HEALPix maps.
    vals : `~numpy.ndarray`
        Values vector.
    """
||
1186 | |||
def plot_grid(self, figsize=None, ncols=3, **kwargs):
    """Plot map as a grid of subplots for non-spatial axes

    One subplot is drawn per bin of the (single) non-spatial axis. Grid
    cells that have no corresponding bin are hidden.

    Parameters
    ----------
    figsize : tuple of int
        Figsize to plot on. If None, a width of 12 is used and the height
        is scaled with the number of rows and columns.
    ncols : int
        Number of columns to plot
    **kwargs : dict
        Keyword arguments passed to `Map.plot`.

    Returns
    -------
    axes : `~numpy.ndarray` of `~matplotlib.pyplot.Axes`
        Axes grid. Dimensions of length one are squeezed out; a grid with
        a single subplot is returned as an array with one element.

    Raises
    ------
    ValueError
        If the map does not have exactly one non-spatial axis.
    """
    # ``!= 1`` also rejects maps without any non-spatial axis, which would
    # otherwise fail with an IndexError on ``axes[0]`` below.
    if len(self.geom.axes) != 1:
        raise ValueError("Grid plotting is only supported for one non spatial axis")

    axis = self.geom.axes[0]

    cols = min(ncols, axis.nbin)
    rows = 1 + (axis.nbin - 1) // cols

    if figsize is None:
        width = 12
        figsize = (width, width * rows / cols)

    if self.geom.is_hpx:
        wcs = self.geom.to_wcs_geom().wcs
    else:
        wcs = self.geom.wcs

    # squeeze=False always yields a 2D array of axes. With the default
    # squeezing a 1x1 grid comes back as a bare Axes object, on which the
    # ``axes.flat`` access below fails.
    _, axes = plt.subplots(
        ncols=cols,
        nrows=rows,
        subplot_kw={"projection": wcs},
        figsize=figsize,
        gridspec_kw={"hspace": 0.1, "wspace": 0.1},
        squeeze=False,
    )

    is_energy = axis.name == "energy" or axis.name == "energy_true"

    for idx in range(cols * rows):
        ax = axes.flat[idx]

        try:
            image = self.get_image_by_idx((idx,))
        except IndexError:
            # No bin left for this grid cell.
            ax.set_visible(False)
            continue

        if image.geom.is_hpx:
            image_wcs = image.to_wcs(
                normalize=False,
                proj="AIT",
                oversample=2,
            )
        else:
            image_wcs = image

        image_wcs.plot(ax=ax, **kwargs)

        # Title: bin center for node-type "center" axes, bin edges otherwise.
        if axis.node_type == "center":
            if is_energy:
                info = energy_unit_format(axis.center[idx])
            else:
                info = f"{axis.center[idx]:.1f}"
        else:
            if is_energy:
                info = (
                    f"{energy_unit_format(axis.edges[idx])} - "
                    f"{energy_unit_format(axis.edges[idx+1])}"
                )
            else:
                info = f"{axis.edges[idx]:.1f} - {axis.edges[idx + 1]:.1f} "
        ax.set_title(f"{axis.name.capitalize()} " + info)
        lon, lat = ax.coords[0], ax.coords[1]
        lon.set_ticks_position("b")
        lat.set_ticks_position("l")

        row, col = np.unravel_index(idx, shape=(rows, cols))

        # Only the first column keeps latitude labels and only the last row
        # keeps longitude labels.
        if col > 0:
            lat.set_ticklabel_visible(False)
            lat.set_axislabel("")

        if row < (rows - 1):
            lon.set_ticklabel_visible(False)
            lon.set_axislabel("")

    # Keep the shapes that the default squeezing produced for grids with
    # more than one subplot.
    return np.squeeze(axes) if axes.size > 1 else axes.ravel()
||
1278 | |||
def plot_interactive(self, rc_params=None, **kwargs):
    """
    Plot map with interactive widgets to explore the non spatial axes.

    A selection slider is created for every non-spatial axis and a set of
    radio buttons for the stretch; the image selected by the sliders is
    plotted whenever a widget changes.

    Parameters
    ----------
    rc_params : dict
        Passed to ``matplotlib.rc_context(rc=rc_params)`` to style the plot.
    **kwargs : dict
        Keyword arguments passed to `WcsNDMap.plot`.

    Raises
    ------
    TypeError
        If the map is a 2D image.

    Examples
    --------
    You can try this out e.g. using a Fermi-LAT diffuse model cube with an energy axis::

        from gammapy.maps import Map

        m = Map.read("$GAMMAPY_DATA/fermi_3fhl/gll_iem_v06_cutout.fits")
        m.plot_interactive(add_cbar=True, stretch="sqrt")

    If you would like to adjust the figure size you can use the ``rc_params`` argument::

        rc_params = {'figure.figsize': (12, 6), 'font.size': 12}
        m.plot_interactive(rc_params=rc_params)
    """
    import matplotlib as mpl
    from ipywidgets import RadioButtons, SelectionSlider
    from ipywidgets.widgets.interaction import fixed, interact

    if self.geom.is_image:
        raise TypeError("Use .plot() for 2D Maps")

    # Plot defaults, applied only where the caller did not choose a value.
    for key, default in (
        ("interpolation", "nearest"),
        ("origin", "lower"),
        ("cmap", "afmhot"),
    ):
        kwargs.setdefault(key, default)

    rc_params = rc_params or {}
    initial_stretch = kwargs.pop("stretch", "sqrt")

    widgets = {}

    for axis in self.geom.axes:
        is_center = axis.node_type == "center"
        is_energy = axis.name == "energy" or axis.name == "energy_true"

        if not is_energy:
            options = axis.as_plot_labels
        elif is_center:
            options = energy_unit_format(axis.center)
        else:
            edges = energy_unit_format(axis.edges)
            options = [
                f"{edges[i]} - {edges[i+1]}" for i in range(len(edges) - 1)
            ]

        widgets[axis.name] = SelectionSlider(
            options=options,
            description=f"Select {axis.name}:",
            continuous_update=False,
            style={"description_width": "initial"},
            layout={"width": "50%"},
        )
        # The options are passed along as a fixed argument so that the
        # callback can translate the selected label back into a bin index.
        widgets[axis.name + "_options"] = fixed(options)

    widgets["stretch"] = RadioButtons(
        options=["linear", "sqrt", "log"],
        value=initial_stretch,
        description="Select stretch:",
        style={"description_width": "initial"},
    )

    @interact(**widgets)
    def _plot_interactive(**selection):
        bin_idx = []
        for ax in self.geom.axes:
            labels = selection[ax.name + "_options"]
            bin_idx.append(labels.index(selection[ax.name]))

        image = self.get_image_by_idx(bin_idx)
        with mpl.rc_context(rc=rc_params):
            image.plot(stretch=selection["stretch"], **kwargs)
            plt.show()
||
1358 | |||
def copy(self, **kwargs):
    """Copy map instance and overwrite given attributes.

    A replacement geometry may be given as ``geom``, but only if its
    ``data_shape`` is equal to the one of the current geometry.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments to overwrite in the map constructor.

    Returns
    -------
    copy : `Map`
        Copied Map.

    Raises
    ------
    ValueError
        If a ``geom`` with a different data shape is requested.
    """
    if "geom" not in kwargs:
        return self._init_copy(**kwargs)

    requested = kwargs["geom"]
    same_shape = requested.data_shape == self.geom.data_shape

    if not same_shape:
        raise ValueError(
            "Can't copy and change data size of the map. "
            f" Current shape {self.geom.data_shape},"
            f" requested shape {requested.data_shape}"
        )

    return self._init_copy(**kwargs)
||
1382 | |||
def apply_edisp(self, edisp):
    """Apply energy dispersion to map. Requires an "energy_true" axis.

    The data are multiplied with ``edisp.pdf_matrix`` along the
    "energy_true" axis. If ``edisp`` is None the data are left untouched
    and the "energy_true" axis is only renamed to "energy".

    Parameters
    ----------
    edisp : `gammapy.irf.EDispKernel`
        Energy dispersion matrix

    Returns
    -------
    map : `WcsNDMap`
        Map with energy dispersion applied.
    """
    # TODO: either use sparse matrix multiplication or something like edisp.is_diagonal
    if edisp is None:
        dispersed = self.data
        reco_axis = self.geom.axes["energy_true"].copy(name="energy")
    else:
        position = self.geom.axes.index("energy_true")
        # Bring the energy axis to the end for the matrix product, then
        # put the resulting axis back to where it came from.
        energy_last = np.moveaxis(self.data, position, -1)
        product = np.dot(energy_last, edisp.pdf_matrix)
        dispersed = np.moveaxis(product, -1, position)
        reco_axis = edisp.axes["energy"].copy(name="energy")

    image_geom = self.geom.to_image()
    cube_geom = image_geom.to_cube(axes=[reco_axis])
    return self._init_copy(geom=cube_geom, data=dispersed)
||
1409 | |||
def mask_nearest_position(self, position):
    """Given a sky coordinate return nearest valid position in the mask

    Only 2D image masks are supported. Pixels for which the inverted
    map data is set are excluded from the search by assigning them an
    infinite separation.

    Parameters
    ----------
    position : `~astropy.coordinates.SkyCoord`
        Test position

    Returns
    -------
    position : `~astropy.coordinates.SkyCoord`
        Nearest position in the mask

    Raises
    ------
    ValueError
        If the map is not a 2D image.
    """
    if not self.geom.is_image:
        raise ValueError("Method only supported for 2D images")

    pixel_positions = self.geom.to_image().get_coord().skycoord
    distance = pixel_positions.separation(position)
    # Masked-out pixels must never win the minimum search below.
    distance[~self.data] = np.inf
    nearest = np.argmin(distance)
    return pixel_positions.flatten()[nearest]
||
1433 | |||
def sum_over_axes(self, axes_names=None, keepdims=True, weights=None):
    """Sum map values over non-spatial axes.

    Convenience wrapper around `Map.reduce_over_axes` with ``np.add`` as
    the reducing function.

    Parameters
    ----------
    axes_names: list of str
        Names of MapAxis to reduce over. If None, all will summed over
    keepdims : bool, optional
        If this is set to true, the axes which are summed over are left in
        the map with a single bin
    weights : `Map`
        Weights to be applied. The Map should have the same geometry.

    Returns
    -------
    map_out : `~Map`
        Map with non-spatial axes summed over
    """
    options = {
        "func": np.add,
        "axes_names": axes_names,
        "keepdims": keepdims,
        "weights": weights,
    }
    return self.reduce_over_axes(**options)
||
1455 | |||
def reduce_over_axes(
    self, func=np.add, keepdims=False, axes_names=None, weights=None
):
    """Reduce map over non-spatial axes

    Starting from a copy of the map, `Map.reduce` is applied once per
    requested axis, each time on the result of the previous step.

    Parameters
    ----------
    func : `~numpy.ufunc`
        Function to use for reducing the data.
    keepdims : bool, optional
        If this is set to true, the axes which are summed over are left in
        the map with a single bin
    axes_names: list
        Names of MapAxis to reduce over
        If None, all will reduced
    weights : `Map`
        Weights to be applied. The same weights are passed to every
        reduction step.

    Returns
    -------
    map_out : `~Map`
        Map with non-spatial axes reduced
    """
    names = self.geom.axes.names if axes_names is None else axes_names

    reduced = self.copy()
    for name in names:
        reduced = reduced.reduce(
            name, func=func, keepdims=keepdims, weights=weights
        )

    return reduced
||
1489 | |||
def reduce(self, axis_name, func=np.add, keepdims=False, weights=None):
    """Reduce map over a single non-spatial axis

    NaN entries of the (optionally weighted) data are excluded from the
    reduction through the ``where`` argument of ``func.reduce``.

    Parameters
    ----------
    axis_name: str
        The name of the axis to reduce over
    func : `~numpy.ufunc`
        Function to use for reducing the data.
    keepdims : bool, optional
        If this is set to true, the axes which are summed over are left in
        the map with a single bin
    weights : `Map`
        Weights to be applied. The data are multiplied by the weights
        before the reduction.

    Returns
    -------
    map_out : `~Map`
        Map with the given non-spatial axes reduced
    """
    # The reduced axis is either kept as a single bin or removed.
    if keepdims:
        out_geom = self.geom.squash(axis_name=axis_name)
    else:
        out_geom = self.geom.drop(axis_name=axis_name)

    data_axis = self.geom.axes.index_data(axis_name)

    values = self.data if weights is None else self.data * weights
    valid = ~np.isnan(values)

    reduced = func.reduce(values, axis=data_axis, keepdims=keepdims, where=valid)
    return self._init_copy(geom=out_geom, data=reduced)
||
1524 | |||
def cumsum(self, axis_name):
    """Compute cumulative sum along a given axis

    The map quantity is multiplied by the bin width of the axis before
    it is accumulated; for an axis named "rad" an additional factor of
    ``2 * pi * center`` is applied. A leading zero is inserted along the
    axis, and the axis of the returned map has its nodes at the edges of
    the original axis.

    Parameters
    ----------
    axis_name : str
        Along which axis to integrate.

    Returns
    -------
    cumsum : `Map`
        Map with cumulative sum
    """
    axis = self.geom.axes[axis_name]
    data_axis = self.geom.axes.index_data(axis_name)

    # TODO: the broadcasting should be done by axis.center, axis.bin_width etc.
    broadcast_shape = [1] * len(self.geom.data_shape)
    broadcast_shape[data_axis] = -1

    weighted = self.quantity * axis.bin_width.reshape(broadcast_shape)

    if axis_name == "rad":
        # take Jacobian into account
        weighted = 2 * np.pi * axis.center.reshape(broadcast_shape) * weighted

    accumulated = weighted.cumsum(axis=data_axis)
    # Prepend a zero so that there is one value per axis edge.
    accumulated = np.insert(accumulated, 0, 0, axis=data_axis)

    edge_axis = MapAxis.from_nodes(
        axis.edges, name=axis.name, interp=axis.interp
    )
    new_axes = self.geom.axes.replace(edge_axis)
    new_geom = self.geom.to_image().to_cube(new_axes)
    return self.__class__(
        geom=new_geom, data=accumulated.value, unit=accumulated.unit
    )
||
1559 | |||
def integral(self, axis_name, coords, **kwargs):
    """Compute integral along a given axis

    This method uses interpolation of the cumulative sum: the result of
    `Map.cumsum` is padded along the axis (``pad_width=1``, mode "edge")
    and then interpolated at the given coordinates.

    Parameters
    ----------
    axis_name : str
        Along which axis to integrate.
    coords : dict or `MapCoord`
        Map coordinates at which the cumulative sum is evaluated.
    **kwargs : dict
        Keyword arguments passed to `Map.interp_by_coord`.

    Returns
    -------
    array : `~astropy.units.Quantity`
        Interpolated values in the unit of the cumulative sum map.
    """
    padded = self.cumsum(axis_name=axis_name).pad(
        pad_width=1, axis_name=axis_name, mode="edge"
    )
    values = padded.interp_by_coord(coords, **kwargs)
    return u.Quantity(values, padded.unit, copy=False)
||
1585 | |||
def normalize(self, axis_name=None):
    """Normalise data in place along a given axis.

    The map quantity is divided by the maximum of its cumulative sum
    along the axis. Non-finite results of that division are replaced via
    ``np.nan_to_num`` before the quantity is written back to the map.

    Parameters
    ----------
    axis_name : str
        Along which axis to normalize.

    """
    running_sum = self.cumsum(axis_name=axis_name).quantity

    # Divisions by zero are expected here and handled by nan_to_num below.
    with np.errstate(invalid="ignore", divide="ignore"):
        data_axis = self.geom.axes.index_data(axis_name=axis_name)
        total = running_sum.max(axis=data_axis, keepdims=True)
        scaled = self.quantity / total

    self.quantity = np.nan_to_num(scaled)
||
1602 | |||
@classmethod
def from_stack(cls, maps, axis=None, axis_name=None):
    """Create Map from list of images and a non-spatial axis.

    The image geometries must be aligned, except for the axis that is stacked.

    Parameters
    ----------
    maps : list of `Map` objects
        List of maps
    axis : `MapAxis`
        If a `MapAxis` is provided the maps are stacked along the last data
        axis and the new axis is introduced.
    axis_name : str
        If an axis name is given, the maps are stacked along the axis of
        that name. If neither ``axis`` nor ``axis_name`` is given, the name
        of the last axis of the first map is used.

    Returns
    -------
    map : `Map`
        Map with additional non-spatial axis.

    Raises
    ------
    ValueError
        If the geometry of one of the maps (without the stacked axis) is
        not equal to the one of the first map.
    """
    first = maps[0]
    reference_geom = first.geom

    if axis is None and axis_name is None:
        axis_name = reference_geom.axes.names[-1]

    if axis_name:
        # Build the stacked axis from the corresponding axis of every map and
        # compare the geometries without it.
        axis = MapAxis.from_stack(axes=[m.geom.axes[axis_name] for m in maps])
        reference_geom = reference_geom.drop(axis_name=axis_name)

    stacked = []

    for m in maps:
        m_geom = m.geom.drop(axis_name=axis_name) if axis_name else m.geom

        if not m_geom == reference_geom:
            raise ValueError(
                f"Image geometries not aligned: {m.geom} and {reference_geom}"
            )

        # All data are expressed in the unit of the first map.
        stacked.append(m.quantity.to_value(maps[0].unit))

    return cls.from_geom(
        data=np.stack(stacked),
        geom=reference_geom.to_cube(axes=[axis]),
        unit=maps[0].unit,
    )
||
1650 | |||
def split_by_axis(self, axis_name):
    """Split a Map along an axis into multiple maps.

    One map is created per bin of the axis by slicing the map at the
    corresponding index.

    Parameters
    ----------
    axis_name : str
        Name of the axis to split

    Returns
    -------
    maps : list
        A list of `~gammapy.maps.Map`
    """
    nbin = self.geom.axes[axis_name].nbin
    return [self.slice_by_idx({axis_name: bin_idx}) for bin_idx in range(nbin)]
||
1670 | |||
def to_cube(self, axes):
    """Append non-spatial axes to create a higher-dimensional Map.

    This will result in a Map with a new geometry with
    N+M dimensions where N is the number of current dimensions and
    M is the number of axes in the list. The data is reshaped onto the
    new geometry

    Parameters
    ----------
    axes : list
        Axes that will be appended to this Map.
        The axes should have only one bin

    Returns
    -------
    map : `~gammapy.maps.WcsNDMap`
        new map

    Raises
    ------
    ValueError
        If one of the axes has more than one bin.
    """
    for ax in axes:
        if ax.nbin > 1:
            # A single formatted message; passing the name and the text as
            # two arguments would render the error as a tuple.
            raise ValueError(f"{ax.name} should have only one bin")

    geom = self.geom.to_cube(axes)
    # Every new single-bin axis adds one leading dimension of length one.
    data = self.data.reshape((1,) * len(axes) + self.data.shape)
    return self.from_geom(data=data, geom=geom, unit=self.unit)
||
1696 | |||
def get_spectrum(self, region=None, func=np.nansum, weights=None):
    """Extract spectrum in a given region.

    The spectrum can be computed by summing (or, more generally, applying ``func``)
    along the spatial axes in each energy bin. This occurs only inside the ``region``,
    which by default is assumed to be the whole spatial extension of the map.
    The computation itself is delegated to ``to_region_nd_map``.

    Parameters
    ----------
    region: `~regions.Region`
        Region (pixel or sky regions accepted).
    func : numpy.func
        Function to reduce the data. Default is np.nansum.
        For a boolean Map, use np.any or np.all.
    weights : `WcsNDMap`
        Array to be used as weights. The geometry must be equivalent.

    Returns
    -------
    spectrum : `~gammapy.maps.RegionNDMap`
        Spectrum in the given region.

    Raises
    ------
    ValueError
        If the map geometry has no energy axis.
    """
    if self.geom.has_energy_axis:
        return self.to_region_nd_map(region=region, func=func, weights=weights)

    raise ValueError("Energy axis required")
||
1723 | |||
def to_unit(self, unit):
    """Convert map to different unit

    A new map is created on the same geometry from the converted values;
    the map itself is not modified.

    Parameters
    ----------
    unit : `~astropy.unit.Unit` or str
        New unit

    Returns
    -------
    map : `Map`
        Map with new unit and converted data
    """
    converted = self.quantity.to_value(unit)
    return self.from_geom(self.geom, data=converted, unit=unit)
||
1739 | |||
def is_allclose(self, other, rtol_axes=1e-3, atol_axes=1e-6, **kwargs):
    """Compare two Maps for close equivalency

    Parameters
    ----------
    other : `gammapy.maps.Map`
        The Map to compare against
    rtol_axes : float
        Relative tolerance for the axes comparison.
    atol_axes : float
        Absolute tolerance for the axes comparison.
    **kwargs : dict
        keywords passed to `numpy.allclose`

    Returns
    -------
    is_allclose : bool
        Whether the Map is all close.

    Raises
    ------
    TypeError
        If ``other`` is not an instance of the class of this map.
    """
    if not isinstance(other, self.__class__):
        # Raise instead of returning the exception object: a returned
        # exception instance is truthy and would read as "maps are close".
        raise TypeError(f"Cannot compare {type(self)} and {type(other)}")

    if self.data.shape != other.data.shape:
        return False

    # The non-spatial axes live on the geometry.
    axes_eq = self.geom.axes.is_allclose(
        other.geom.axes, rtol=rtol_axes, atol=atol_axes
    )
    data_eq = np.allclose(self.quantity, other.quantity, **kwargs)
    return axes_eq and data_eq
||
1768 | |||
1769 | def __repr__(self): |
||
1770 | geom = self.geom.__class__.__name__ |
||
1771 | axes = ["skycoord"] if self.geom.is_hpx else ["lon", "lat"] |
||
1772 | axes = axes + [_.name for _ in self.geom.axes] |
||
1773 | |||
1774 | return ( |
||
1775 | f"{self.__class__.__name__}\n\n" |
||
1776 | f"\tgeom : {geom} \n " |
||
1777 | f"\taxes : {axes}\n" |
||
1778 | f"\tshape : {self.geom.data_shape[::-1]}\n" |
||
1779 | f"\tndim : {self.geom.ndim}\n" |
||
1780 | f"\tunit : {self.unit}\n" |
||
1781 | f"\tdtype : {self.data.dtype}\n" |
||
1782 | ) |
||
1783 | |||
def _arithmetics(self, operator, other, copy):
    """Apply a binary operator to the map quantity and another operand.

    Parameters
    ----------
    operator : callable
        Binary function called as ``operator(self_quantity, operand)``.
    other : `Map` or quantity-like
        A map must share this map's geometry; anything else is wrapped
        in an `~astropy.units.Quantity`.
    copy : bool
        If True, operate on a copy of this map; otherwise modify this
        map in place.

    Returns
    -------
    map : `Map`
        The map (copy or ``self``) holding the result.

    Raises
    ------
    ValueError
        If ``other`` is a map with a different geometry.
    """
    if not isinstance(other, Map):
        operand = u.Quantity(other, copy=False)
    elif self.geom == other.geom:
        operand = other.quantity
    else:
        raise ValueError("Map Arithmetic: Inconsistent geometries.")

    result = self.copy() if copy else self
    result.quantity = operator(result.quantity, operand)
    return result
||
1797 | |||
1798 | def _boolean_arithmetics(self, operator, other, copy): |
||
1799 | """Perform arithmetic on maps after checking geometry consistency.""" |
||
1800 | if operator == np.logical_not: |
||
1801 | out = self.copy() |
||
1802 | out.data = operator(out.data) |
||
1803 | return out |
||
1804 | |||
1805 | if isinstance(other, Map): |
||
1806 | if self.geom == other.geom: |
||
1807 | other = other.data |
||
1808 | else: |
||
1809 | raise ValueError("Map Arithmetic: Inconsistent geometries.") |
||
1810 | |||
1811 | out = self.copy() if copy else self |
||
1812 | out.data = operator(out.data, other) |
||
1813 | return out |
||
1814 | |||
1815 | def __add__(self, other): |
||
1816 | return self._arithmetics(np.add, other, copy=True) |
||
1817 | |||
1818 | def __iadd__(self, other): |
||
1819 | return self._arithmetics(np.add, other, copy=False) |
||
1820 | |||
1821 | def __sub__(self, other): |
||
1822 | return self._arithmetics(np.subtract, other, copy=True) |
||
1823 | |||
1824 | def __isub__(self, other): |
||
1825 | return self._arithmetics(np.subtract, other, copy=False) |
||
1826 | |||
1827 | def __mul__(self, other): |
||
1828 | return self._arithmetics(np.multiply, other, copy=True) |
||
1829 | |||
1830 | def __imul__(self, other): |
||
1831 | return self._arithmetics(np.multiply, other, copy=False) |
||
1832 | |||
1833 | def __truediv__(self, other): |
||
1834 | return self._arithmetics(np.true_divide, other, copy=True) |
||
1835 | |||
1836 | def __itruediv__(self, other): |
||
1837 | return self._arithmetics(np.true_divide, other, copy=False) |
||
1838 | |||
1839 | def __le__(self, other): |
||
1840 | return self._arithmetics(np.less_equal, other, copy=True) |
||
1841 | |||
1842 | def __lt__(self, other): |
||
1843 | return self._arithmetics(np.less, other, copy=True) |
||
1844 | |||
1845 | def __ge__(self, other): |
||
1846 | return self._arithmetics(np.greater_equal, other, copy=True) |
||
1847 | |||
1848 | def __gt__(self, other): |
||
1849 | return self._arithmetics(np.greater, other, copy=True) |
||
1850 | |||
1851 | def __eq__(self, other): |
||
1852 | return self._arithmetics(np.equal, other, copy=True) |
||
1853 | |||
1854 | def __ne__(self, other): |
||
1855 | return self._arithmetics(np.not_equal, other, copy=True) |
||
1856 | |||
1857 | def __and__(self, other): |
||
1858 | return self._boolean_arithmetics(np.logical_and, other, copy=True) |
||
1859 | |||
1860 | def __or__(self, other): |
||
1861 | return self._boolean_arithmetics(np.logical_or, other, copy=True) |
||
1862 | |||
1863 | def __invert__(self): |
||
1864 | return self._boolean_arithmetics(np.logical_not, None, copy=True) |
||
1865 | |||
1866 | def __xor__(self, other): |
||
1867 | return self._boolean_arithmetics(np.logical_xor, other, copy=True) |
||
1868 | |||
1869 | def __iand__(self, other): |
||
1870 | return self._boolean_arithmetics(np.logical_and, other, copy=False) |
||
1871 | |||
1872 | def __ior__(self, other): |
||
1873 | return self._boolean_arithmetics(np.logical_or, other, copy=False) |
||
1874 | |||
1875 | def __ixor__(self, other): |
||
1876 | return self._boolean_arithmetics(np.logical_xor, other, copy=False) |
||
1877 | |||
1878 | def __array__(self): |
||
1879 | return self.data |
||
1880 | |||
def sample_coord(self, n_events, random_state=0):
    """Draw random event coordinates using the map data as the pdf.

    Parameters
    ----------
    n_events : int
        Number of events to sample.
    random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
        Defines random number generator initialisation.
        Passed to `~gammapy.utils.random.get_random_state`.

    Returns
    -------
    coords : `~gammapy.maps.MapCoord` object.
        Coordinates of the sampled events, keyed by the geometry axes names.
    """
    rng = get_random_state(random_state)

    # Sample pixel coordinates with the map data as the pdf.
    pix = InverseCDFSampler(pdf=self.data, random_state=rng).sample(n_events)

    # The sampled pixel coordinates are reversed before conversion.
    sky = self.geom.pix_to_coord(pix[::-1])

    # TODO: pix_to_coord should return a MapCoord object
    named = OrderedDict(zip(self.geom.axes_names, sky))
    return MapCoord.create(named, frame=self.geom.frame)
||
1908 |