1
|
|
|
# Licensed under a 3-clause BSD style license - see LICENSE.rst |
2
|
|
|
import logging |
3
|
|
|
from collections import OrderedDict |
4
|
|
|
import numpy as np |
5
|
|
|
import scipy.interpolate |
6
|
|
|
import scipy.ndimage |
7
|
|
|
import scipy.signal |
8
|
|
|
import astropy.units as u |
9
|
|
|
from astropy.convolution import Tophat2DKernel |
10
|
|
|
from astropy.io import fits |
11
|
|
|
from astropy.nddata import Cutout2D |
12
|
|
|
from gammapy.extern.skimage import block_reduce |
13
|
|
|
from gammapy.utils.interpolation import ScaledRegularGridInterpolator |
14
|
|
|
from gammapy.utils.random import InverseCDFSampler, get_random_state |
15
|
|
|
from gammapy.utils.units import unit_from_fits_image_hdu |
16
|
|
|
from .geom import MapCoord, pix_tuple_to_idx |
17
|
|
|
from .reproject import reproject_car_to_hpx, reproject_car_to_wcs |
18
|
|
|
from .utils import INVALID_INDEX, interp_to_order |
19
|
|
|
from .wcs import _check_width |
20
|
|
|
from .wcsmap import WcsGeom, WcsMap |
21
|
|
|
|
22
|
|
|
# Names exported by a star-import of this module.
__all__ = ["WcsNDMap"]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
25
|
|
|
|
26
|
|
|
|
27
|
|
|
class WcsNDMap(WcsMap):
    """WCS map with any number of non-spatial dimensions.

    This class uses an ND numpy array to store map values. For maps with
    non-spatial dimensions and variable pixel size it will allocate an
    array with dimensions commensurate with the largest image plane.

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        WCS geometry object.
    data : `~numpy.ndarray`
        Data array. If None, a zero-filled array is allocated.
    dtype : str, optional
        Data type, default is float32
    meta : `dict`
        Dictionary to store meta data.
    unit : str or `~astropy.units.Unit`
        The map unit
    """
47
|
|
|
|
48
|
|
|
def __init__(self, geom, data=None, dtype="float32", meta=None, unit=""):
    """Create the map, allocating a default data array if none is given.

    Parameters
    ----------
    geom : `~gammapy.maps.WcsGeom`
        WCS geometry object.
    data : `~numpy.ndarray`, optional
        Data array. When None, an array of shape ``geom.data_shape`` is
        created by ``_make_default_data``.
    dtype : str, optional
        Data type of the default array (only used when ``data`` is None).
    meta : dict, optional
        Dictionary to store meta data.
    unit : str or `~astropy.units.Unit`
        The map unit.
    """
    # TODO: Figure out how to mask pixels for integer data types
    shape_np = geom.data_shape
    if data is None:
        data = self._make_default_data(geom, shape_np, dtype)

    super().__init__(geom, data, meta, unit)
57
|
|
|
|
58
|
|
|
@staticmethod |
59
|
|
|
def _make_default_data(geom, shape_np, dtype): |
60
|
|
|
# Check whether corners of each image plane are valid |
61
|
|
|
|
62
|
|
|
data = np.zeros(shape_np, dtype=dtype) |
63
|
|
|
|
64
|
|
|
if not geom.is_regular or geom.is_allsky: |
65
|
|
|
coords = geom.get_coord() |
66
|
|
|
is_nan = np.isnan(coords.lon) |
67
|
|
|
data[is_nan] = np.nan |
68
|
|
|
|
69
|
|
|
return data |
70
|
|
|
|
71
|
|
|
@classmethod
def from_hdu(cls, hdu, hdu_bands=None):
    """Make a WcsNDMap object from a FITS HDU.

    Parameters
    ----------
    hdu : `~astropy.io.fits.BinTableHDU` or `~astropy.io.fits.ImageHDU`
        The map FITS HDU.
    hdu_bands : `~astropy.io.fits.BinTableHDU`
        The BANDS table HDU.

    Returns
    -------
    map_out : `WcsNDMap`
        Map with geometry, meta data and unit taken from the HDU header.
        For a binary table HDU the values are filled from the
        ``PIX`` / ``VALUE`` (and optional ``CHANNEL``) columns; otherwise
        the HDU data array is assigned directly.
    """
    header = hdu.header
    geom = WcsGeom.from_header(header, hdu_bands)
    axes_shape = tuple(ax.nbin for ax in geom.axes)
    wcs_shape = (np.max(geom.npix[0]), np.max(geom.npix[1]))

    map_out = cls(
        geom,
        meta=cls._get_meta_from_header(header),
        unit=unit_from_fits_image_hdu(header),
    )

    # TODO: Should we support extracting slices?
    if not isinstance(hdu, fits.BinTableHDU):
        map_out.data = hdu.data
        return map_out

    table = hdu.data
    # Flat pixel numbers are unravelled against the reversed WCS shape.
    idx = np.unravel_index(table.field("PIX"), wcs_shape[::-1])
    vals = table.field("VALUE")
    if "CHANNEL" in table.columns.names and axes_shape:
        channel = np.unravel_index(table.field("CHANNEL"), axes_shape[::-1])
        idx = channel + idx

    map_out.set_by_idx(idx[::-1], vals)
    return map_out
107
|
|
|
|
108
|
|
|
def get_by_idx(self, idx):
    """Return map values at the given pixel indices.

    Parameters
    ----------
    idx : tuple
        Pixel indices; normalised with ``pix_tuple_to_idx`` and then
        used to index the transposed data array.

    Returns
    -------
    vals : `~numpy.ndarray`
        Values of the transposed data array at ``idx``.
    """
    pix_idx = pix_tuple_to_idx(idx)
    transposed = self.data.T
    return transposed[pix_idx]
111
|
|
|
|
112
|
|
|
def interp_by_coord(self, coords, interp=None, fill_value=None):
    """Interpolate map values at the given map coordinates.

    Regular geometries convert the coordinates to pixel coordinates and
    delegate to ``interp_by_pix``. Other geometries fall back to
    ``_interp_by_coord_griddata``, which does not receive ``fill_value``.

    Parameters
    ----------
    coords : tuple or `~gammapy.maps.MapCoord`
        Coordinates at which to interpolate.
    interp : optional
        Interpolation method, forwarded unchanged.
    fill_value : optional
        Fill value, forwarded to ``interp_by_pix`` only.
    """
    if not self.geom.is_regular:
        return self._interp_by_coord_griddata(coords, interp=interp)

    pix = self.geom.coord_to_pix(coords)
    return self.interp_by_pix(pix, interp=interp, fill_value=fill_value)
119
|
|
|
|
120
|
|
|
def interp_by_pix(self, pix, interp=None, fill_value=None):
    """Interpolate map values at the given pixel coordinates.

    Parameters
    ----------
    pix : tuple
        Pixel coordinates at which to interpolate.
    interp : optional
        Interpolation method, converted to an order with
        ``interp_to_order``.
    fill_value : optional
        Forwarded to the order 0 / 1 interpolator. It is not used for
        orders 2 and 3.

    Raises
    ------
    ValueError
        If the geometry is not regular, or the order is not 0, 1, 2 or 3.
    """
    if not self.geom.is_regular:
        raise ValueError("interp_by_pix only supported for regular geom.")

    order = interp_to_order(interp)

    if order in (0, 1):
        return self._interp_by_pix_linear_grid(
            pix, order=order, fill_value=fill_value
        )

    if order in (2, 3):
        return self._interp_by_pix_map_coordinates(pix, order=order)

    raise ValueError(f"Invalid interpolation order: {order!r}")
135
|
|
|
|
136
|
|
|
def _interp_by_pix_linear_grid(self, pix, order=1, fill_value=None):
    """Nearest-neighbour (order 0) or linear (order 1) interpolation.

    The interpolator is built on a pixel-index grid over the transposed
    data array. If the data contain at least one finite value, the
    non-finite entries are replaced by zero in a copy before
    interpolating; the map data themselves are not modified.

    Raises
    ------
    ValueError
        If ``order`` is neither 0 nor 1.
    """
    # TODO: Cache interpolator
    methods = {0: "nearest", 1: "linear"}
    if order not in methods:
        raise ValueError(f"Invalid interpolation order: {order!r}")
    method = methods[order]

    # One pixel-centre axis per dimension, in transposed (reversed) order.
    grid_pix = [np.arange(npix, dtype=float) for npix in reversed(self.data.shape)]

    finite = np.isfinite(self.data)
    if finite.any():
        values = self.data.copy()
        values[~finite] = 0.0
        values = values.T
    else:
        values = self.data.T

    interpolator = ScaledRegularGridInterpolator(
        grid_pix, values, fill_value=fill_value, bounds_error=False, method=method
    )
    return interpolator(tuple(pix), clip=False)
156
|
|
|
|
157
|
|
|
def _interp_by_pix_map_coordinates(self, pix, order=1): |
158
|
|
|
pix = tuple( |
159
|
|
|
[ |
160
|
|
|
np.array(x, ndmin=1) |
161
|
|
|
if not isinstance(x, np.ndarray) or x.ndim == 0 |
162
|
|
|
else x |
163
|
|
|
for x in pix |
164
|
|
|
] |
165
|
|
|
) |
166
|
|
|
return scipy.ndimage.map_coordinates( |
167
|
|
|
self.data.T, pix, order=order, mode="nearest" |
168
|
|
|
) |
169
|
|
|
|
170
|
|
|
def _interp_by_coord_griddata(self, coords, interp=None):
    """Interpolate at map coordinates with `scipy.interpolate.griddata`.

    Only the finite data values are used. Positions where the requested
    method returns a non-finite value are filled with a second,
    nearest-neighbour pass.

    Raises
    ------
    ValueError
        If ``interp`` does not map to order 0, 1 or 3.
    """
    order = interp_to_order(interp)
    methods = {0: "nearest", 1: "linear", 3: "cubic"}
    if order not in methods:
        raise ValueError(f"Invalid interp: {interp!r}")
    method = methods[order]

    points = tuple(self.geom.get_coord(flat=True))
    values = self.data[np.isfinite(self.data)]
    vals = scipy.interpolate.griddata(points, values, tuple(coords), method=method)

    missing = ~np.isfinite(vals)
    if np.any(missing):
        vals[missing] = scipy.interpolate.griddata(
            points, values, tuple([c[missing] for c in coords]), method="nearest"
        )

    return vals
191
|
|
|
|
192
|
|
|
def fill_by_idx(self, idx, weights=None):
    """Add weights to the map at the given pixel indices (in place).

    Entries for which any index component equals ``INVALID_INDEX.int``
    are dropped. Repeated indices are accumulated.

    Parameters
    ----------
    idx : tuple
        Pixel indices, normalised with ``pix_tuple_to_idx``.
    weights : `~numpy.ndarray` or `~astropy.units.Quantity`, optional
        Weight per index entry. A quantity is converted to the map unit.
        With None, `numpy.bincount` counts each entry once.
    """
    pix_idx = pix_tuple_to_idx(idx)
    invalid = INVALID_INDEX.int
    valid = np.all(np.stack([t != invalid for t in pix_idx]), axis=0)
    pix_idx = [t[valid] for t in pix_idx]

    if weights is not None:
        if isinstance(weights, u.Quantity):
            weights = weights.to_value(self.unit)
        weights = weights[valid]

    data_t = self.data.T
    flat_idx = np.ravel_multi_index(pix_idx, data_t.shape)
    # Collapse duplicates so each target pixel is updated exactly once.
    flat_idx, inverse = np.unique(flat_idx, return_inverse=True)
    summed = np.bincount(inverse, weights=weights).astype(self.data.dtype)
    data_t.flat[flat_idx] += summed
206
|
|
|
|
207
|
|
|
def set_by_idx(self, idx, vals):
    """Set map values at the given pixel indices (in place).

    Parameters
    ----------
    idx : tuple
        Pixel indices; normalised with ``pix_tuple_to_idx`` and then
        used to index the transposed data array.
    vals : `~numpy.ndarray`
        Values to assign.
    """
    pix_idx = pix_tuple_to_idx(idx)
    transposed = self.data.T
    transposed[pix_idx] = vals
210
|
|
|
|
211
|
|
|
def sum_over_axes(self, keepdims=False):
    """Sum map values over all non-spatial axes (NaNs count as zero).

    Parameters
    ----------
    keepdims : bool, optional
        If True, the summed axes are kept in the map, each squashed to
        a single bin.

    Returns
    -------
    map_out : WcsNDMap
        Map with non-spatial axes summed over
    """
    # All leading data axes except the last two are summed.
    n_summed = self.data.ndim - 2

    geom = self.geom.to_image()
    if keepdims:
        for ax in self.geom.axes:
            geom = geom.to_cube([ax.squash()])

    summed = np.nansum(self.data, axis=tuple(range(n_summed)), keepdims=keepdims)
    # TODO: summing over the axis can change the unit, handle this correctly
    return self._init_copy(geom=geom, data=summed)
233
|
|
|
|
234
|
|
|
def _reproject_to_wcs(self, geom, mode="interp", order=1):
    """Reproject every image plane onto a WCS geometry.

    All-sky CAR maps always use ``reproject_car_to_wcs``. Otherwise
    ``mode`` selects ``reproject_interp`` or ``reproject_exact``.
    The ``order`` argument is accepted but not used here.

    Raises
    ------
    TypeError
        If ``mode`` is neither 'interp' nor 'exact' (only checked for
        maps that are not all-sky CAR).
    """
    from reproject import reproject_interp, reproject_exact

    data = np.empty(geom.data_shape)

    for img, idx in self.iter_by_image():
        # TODO: Create WCS object for image plane if
        # multi-resolution geom
        shape_out = geom.get_image_shape(idx)[::-1]

        if self.geom.projection == "CAR" and self.geom.is_allsky:
            reproject_func = reproject_car_to_wcs
        elif mode == "interp":
            reproject_func = reproject_interp
        elif mode == "exact":
            reproject_func = reproject_exact
        else:
            raise TypeError(f"mode must be 'interp' or 'exact'. Got: {mode!r}")

        # The footprint returned alongside the values is not needed.
        vals, _ = reproject_func(
            (img, self.geom.wcs), geom.wcs, shape_out=shape_out
        )
        data[idx] = vals

    return self._init_copy(geom=geom, data=data)
262
|
|
|
|
263
|
|
|
def _reproject_to_hpx(self, geom, mode="interp", order=1):
    """Reproject every image plane onto a HEALPix geometry.

    All-sky CAR maps use ``reproject_car_to_hpx``, all other maps use
    ``reproject.reproject_to_healpix``. Both are called with the same
    arguments. The ``mode`` argument is accepted but not used here.
    """
    from reproject import reproject_to_healpix

    data = np.empty(geom.data_shape)
    coordsys = "galactic" if geom.coordsys == "GAL" else "icrs"

    for img, idx in self.iter_by_image():
        # TODO: For partial-sky HPX we need to map from full- to
        # partial-sky indices
        if self.geom.projection == "CAR" and self.geom.is_allsky:
            reproject_func = reproject_car_to_hpx
        else:
            reproject_func = reproject_to_healpix

        # The footprint returned alongside the values is not needed.
        vals, _ = reproject_func(
            (img, self.geom.wcs),
            coordsys,
            nside=geom.nside,
            nested=geom.nest,
            order=order,
        )
        data[idx] = vals

    return self._init_copy(geom=geom, data=data)
291
|
|
|
|
292
|
|
|
def pad(self, pad_width, mode="constant", cval=0, order=1):
    """Pad the spatial dimensions of the map.

    Parameters
    ----------
    pad_width : int or sequence of int
        Number of pixels padded to the edges. A scalar is applied to
        both spatial dimensions. For a sequence, the first two entries
        are the spatial pad widths handed to ``geom.pad``. A sequence
        shorter than the number of map dimensions is extended with
        zeros, so non-spatial axes are never padded unless requested.
    mode : str
        Padding mode. For regular geometries any mode other than
        'interp' is applied with `numpy.pad`. 'interp', and every mode
        on non-regular geometries, goes through the coadd-based padding.
    cval : float
        Padding value when ``mode='constant'``.
    order : int
        Interpolation order, used by the coadd-based padding when
        ``mode='interp'``.

    Returns
    -------
    map : `WcsNDMap`
        Padded map.
    """
    if np.isscalar(pad_width):
        pad_width = (pad_width, pad_width)

    # Both padding back-ends need one entry per map dimension: ``np.pad``
    # rejects a shorter list, and the coadd path zips the widths with the
    # full index tuple. Extend a purely spatial pad width accordingly.
    pad_width = tuple(pad_width)
    n_missing = self.geom.ndim - len(pad_width)
    if n_missing > 0:
        pad_width += (0,) * n_missing

    geom = self.geom.pad(pad_width[:2])
    if self.geom.is_regular and mode != "interp":
        return self._pad_np(geom, pad_width, mode, cval)

    return self._pad_coadd(geom, pad_width, mode, cval, order)
302
|
|
|
|
303
|
|
|
def _pad_np(self, geom, pad_width, mode, cval): |
304
|
|
|
"""Pad a map using ``numpy.pad``. |
305
|
|
|
|
306
|
|
|
This method only works for regular geometries but should be more |
307
|
|
|
efficient when working with large maps. |
308
|
|
|
""" |
309
|
|
|
kwargs = {} |
310
|
|
|
if mode == "constant": |
311
|
|
|
kwargs["constant_values"] = cval |
312
|
|
|
|
313
|
|
|
pad_width = [(t, t) for t in pad_width] |
314
|
|
|
data = np.pad(self.data, pad_width[::-1], mode) |
315
|
|
|
return self._init_copy(geom=geom, data=data) |
316
|
|
|
|
317
|
|
|
def _pad_coadd(self, geom, pad_width, mode, cval, order):
    """Pad a map manually by coadding the original map with the new map.

    A new map is created on the padded geometry and this map is coadded
    into it. With ``mode='constant'`` the pixels that do not come from
    the original map are set to ``cval``. With ``mode='interp'`` the
    coordinates not contained in the original geometry are filled by
    interpolating the original map with the given ``order``.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'constant' nor 'interp'. This is raised
        after the padded map has been created and coadded.
    """
    # Indices of the original pixels, shifted into the padded frame.
    inner = self.geom.get_idx(flat=True)
    inner = tuple([i + w for i, w in zip(inner, pad_width)])[::-1]
    outer = geom.get_idx(flat=True)[::-1]

    padded = self._init_copy(geom=geom, data=None)
    padded.coadd(self)

    if mode == "constant":
        is_pad = np.zeros_like(padded.data, dtype=bool)
        is_pad[outer] = True
        is_pad[inner] = False
        padded.data[is_pad] = cval
        return padded

    if mode != "interp":
        raise ValueError(f"Invalid mode: {mode!r}")

    coords = geom.pix_to_coord(outer[::-1])
    outside = ~self.geom.contains(coords)
    coords = tuple([c[outside] for c in coords])
    vals = self.interp_by_coord(coords, interp=order)
    padded.set_by_coord(coords, vals)
    return padded
340
|
|
|
|
341
|
|
|
def crop(self, crop_width):
    """Crop the spatial dimensions of the map.

    Parameters
    ----------
    crop_width : int or sequence of int
        Number of pixels removed from each edge. A scalar is applied to
        both spatial dimensions. ``crop_width[0]`` applies to the last
        data axis and ``crop_width[1]`` to the second to last.

    Returns
    -------
    map : `WcsNDMap`
        Cropped map. Regular geometries are cropped by slicing the data
        array; other geometries are rebuilt by coadding this map into
        an empty map on the cropped geometry.
    """
    if np.isscalar(crop_width):
        crop_width = (crop_width, crop_width)

    geom = self.geom.crop(crop_width)

    if not self.geom.is_regular:
        # FIXME: This could be done more efficiently by
        # constructing the appropriate slices for each image plane
        cropped = self._init_copy(geom=geom, data=None)
        cropped.coadd(self)
        return cropped

    # Non-spatial axes are kept whole; only the last two axes are cut.
    keep_all = (slice(None),) * len(self.geom.axes)
    w0, w1 = crop_width[0], crop_width[1]
    second_to_last = slice(w1, int(self.geom.npix[1] - w1))
    last = slice(w0, int(self.geom.npix[0] - w0))

    data = self.data[keep_all + (second_to_last, last)]
    return self._init_copy(geom=geom, data=data)
361
|
|
|
|
362
|
|
|
def upsample(self, factor, order=0, preserve_counts=True, axis=None):
    """Upsample the map by an integer factor.

    Parameters
    ----------
    factor : int
        Upsampling factor.
    order : int
        Spline order passed to `scipy.ndimage.map_coordinates`.
    preserve_counts : bool
        If True, the result is divided by ``factor ** 2`` (spatial
        upsampling) or by ``factor`` (upsampling along ``axis``).
    axis : str, optional
        Name of the axis to upsample. With None, the first two index
        dimensions (the spatial ones) are upsampled.

    Returns
    -------
    map : `WcsNDMap`
        Upsampled map.
    """
    geom = self.geom.upsample(factor, axis=axis)
    idx = geom.get_idx()

    def to_coarse(fine_idx):
        # Pixel coordinate in this map for a pixel index of the new map.
        return (fine_idx - 0.5 * (factor - 1)) / factor

    if axis is None:
        pix = (to_coarse(idx[0]), to_coarse(idx[1])) + idx[2:]
    else:
        pix = list(idx)
        # NOTE(review): idx_ax is used directly as a position in the index
        # tuple, whose first two entries are treated as spatial above.
        # Confirm that get_axis_index_by_name accounts for that offset.
        idx_ax = self.geom.get_axis_index_by_name(axis)
        pix[idx_ax] = to_coarse(pix[idx_ax])

    data = scipy.ndimage.map_coordinates(
        self.data.T, tuple(pix), order=order, mode="nearest"
    )

    if preserve_counts:
        if axis is None:
            data /= factor ** 2
        else:
            data /= factor

    return self._init_copy(geom=geom, data=data)
387
|
|
|
|
388
|
|
|
def downsample(self, factor, preserve_counts=True, axis=None):
    """Downsample the map by an integer factor.

    Parameters
    ----------
    factor : int
        Downsampling factor.
    preserve_counts : bool
        If True the blocks are reduced with `numpy.nansum`, otherwise
        with `numpy.nanmean`.
    axis : str, optional
        Name of the axis to downsample. With None, the two spatial
        dimensions are downsampled.

    Returns
    -------
    map : `WcsNDMap`
        Downsampled map.
    """
    geom = self.geom.downsample(factor, axis=axis)

    # ``block_size`` is built in reversed data-axis order and flipped below.
    if axis is None:
        block_size = (factor, factor) + (1,) * len(self.geom.axes)
    else:
        block_size = [1] * self.data.ndim
        idx = self.geom.get_axis_index_by_name(axis)
        block_size[-1 - idx] = factor

    if preserve_counts:
        func = np.nansum
    else:
        func = np.nanmean

    data = block_reduce(self.data, tuple(block_size[::-1]), func=func)
    return self._init_copy(geom=geom, data=data)
401
|
|
|
|
402
|
|
|
def plot(self, ax=None, fig=None, add_cbar=False, stretch="linear", **kwargs):
    """
    Plot image on matplotlib WCS axes.

    Parameters
    ----------
    ax : `~astropy.visualization.wcsaxes.WCSAxes`, optional
        WCS axis object to plot on. If None, a subplot with the map WCS
        as projection is added to ``fig``.
    fig : `~matplotlib.figure.Figure`
        Figure object. If None, the current figure is used.
    add_cbar : bool
        Add color bar?
    stretch : str
        Passed to `astropy.visualization.simple_norm`.
    **kwargs : dict
        Keyword arguments passed to `~matplotlib.pyplot.imshow`.
        ``interpolation``, ``origin``, ``cmap`` and ``norm`` get
        defaults when not given.

    Returns
    -------
    fig : `~matplotlib.figure.Figure`
        Figure object.
    ax : `~astropy.visualization.wcsaxes.WCSAxes`
        WCS axis object
    cbar : `~matplotlib.colorbar.Colorbar` or None
        Colorbar object.

    Raises
    ------
    TypeError
        If the map geometry is not an image.
    """
    import matplotlib.pyplot as plt
    from astropy.visualization import simple_norm
    from astropy.visualization.wcsaxes.frame import EllipticalFrame

    if not self.geom.is_image:
        raise TypeError("Use .plot_interactive() for Map dimension > 2")

    if fig is None:
        fig = plt.gcf()

    if ax is None:
        # All-sky maps are drawn in an elliptical frame.
        frame_kwargs = {"frame_class": EllipticalFrame} if self.geom.is_allsky else {}
        ax = fig.add_subplot(1, 1, 1, projection=self.geom.wcs, **frame_kwargs)

    data = self.data.astype(float)

    defaults = {
        "interpolation": "nearest",
        "origin": "lower",
        "cmap": "afmhot",
        # The normalisation is derived from the finite values only.
        "norm": simple_norm(data[np.isfinite(data)], stretch),
    }
    for key, value in defaults.items():
        kwargs.setdefault(key, value)

    caxes = ax.imshow(data, **kwargs)

    cbar = None
    if add_cbar:
        cbar = fig.colorbar(caxes, ax=ax, label=str(self.unit))

    if self.geom.is_allsky:
        ax = self._plot_format_allsky(ax)
    else:
        ax = self._plot_format(ax)

    # without this the axis limits are changed when calling scatter
    ax.autoscale(enable=False)
    return fig, ax, cbar
466
|
|
|
|
467
|
|
|
def _plot_format(self, ax): |
468
|
|
|
try: |
469
|
|
|
ax.coords["glon"].set_axislabel("Galactic Longitude") |
470
|
|
|
ax.coords["glat"].set_axislabel("Galactic Latitude") |
471
|
|
|
except KeyError: |
472
|
|
|
ax.coords["ra"].set_axislabel("Right Ascension") |
473
|
|
|
ax.coords["dec"].set_axislabel("Declination") |
474
|
|
|
except AttributeError: |
475
|
|
|
log.info("Can't set coordinate axes. No WCS information available.") |
476
|
|
|
return ax |
477
|
|
|
|
478
|
|
|
def _plot_format_allsky(self, ax):
    """Format the axes of an all-sky plot.

    Hides the frame line, sets the axis limits from the pixel positions
    of (lon=180, lat=0) and (lon=0, lat=-90), writes the coordinate
    system name at the top left, and configures ticks and grid lines.

    Returns
    -------
    ax : axes object
        The axes passed in.
    """
    # Remove frame
    ax.coords.frame.set_linewidth(0)

    # Set plot axis limits
    ymax, xmax = self.data.shape
    xmargin, _ = self.geom.coord_to_pix({"lon": 180, "lat": 0})
    _, ymargin = self.geom.coord_to_pix({"lon": 0, "lat": -90})

    ax.set_xlim(xmargin, xmax - xmargin)
    ax.set_ylim(ymargin, ymax - ymargin)

    ax.text(0, ymax, self.geom.coordsys + " coords")

    # Grid and ticks
    lon_spacing = 45 * u.deg
    lat_spacing = 15 * u.deg
    faint_white = {"color": "w", "alpha": 0.8}

    lon, lat = ax.coords
    lon.set_ticks(spacing=lon_spacing, **faint_white)
    lat.set_ticks(spacing=lat_spacing)
    lon.set_ticks_visible(False)

    lon.set_ticklabel(**faint_white)
    for coord in (lon, lat):
        coord.grid(alpha=0.2, linestyle="solid", color="w")

    return ax
503
|
|
|
|
504
|
|
|
def smooth(self, width, kernel="gauss", **kwargs):
    """Smooth the map.

    Iterates over 2D image planes, processing one at a time.

    Parameters
    ----------
    width : `~astropy.units.Quantity`, str or float
        Smoothing width given as quantity or float. A float is
        interpreted as smoothing width in pixels. A quantity (or a
        string parsed as a quantity) is converted to pixels by dividing
        by the mean of ``geom.pixel_scales``.
        It corresponds to the standard deviation in case of a Gaussian kernel,
        the radius in case of a disk kernel, and the side length in case
        of a box kernel.
    kernel : {'gauss', 'disk', 'box'}
        Kernel shape
    kwargs : dict
        Keyword arguments passed to `~scipy.ndimage.uniform_filter`
        ('box'), `~scipy.ndimage.gaussian_filter` ('gauss') or
        `~scipy.ndimage.convolve` ('disk').

    Returns
    -------
    image : `WcsNDMap`
        Smoothed image (a copy, the original object is unchanged).

    Raises
    ------
    ValueError
        If ``kernel`` is not one of the supported names.
    """
    if isinstance(width, (u.Quantity, str)):
        width = u.Quantity(width) / self.geom.pixel_scales.mean()
        width = width.to_value("")

    def smooth_image(image):
        # Apply the selected kernel to a single float image plane.
        if kernel == "gauss":
            return scipy.ndimage.gaussian_filter(image, width, **kwargs)
        if kernel == "disk":
            disk = Tophat2DKernel(width)
            disk.normalize("integral")
            return scipy.ndimage.convolve(image, disk.array, **kwargs)
        if kernel == "box":
            return scipy.ndimage.uniform_filter(image, width, **kwargs)
        raise ValueError(f"Invalid kernel: {kernel!r}")

    smoothed_data = np.empty(self.data.shape, dtype=float)
    for img, idx in self.iter_by_image():
        smoothed_data[idx] = smooth_image(img.astype(float))

    return self._init_copy(data=smoothed_data)
551
|
|
|
|
552
|
|
|
def get_spectrum(self, region=None, func=np.nansum):
    """Extract spectrum in a given region.

    The spectrum is computed by applying ``func`` (a sum by default)
    along the spatial axes in each energy bin. Only pixels inside
    ``region`` are used; without a region the whole map is reduced
    over data axes 1 and 2.

    Parameters
    ----------
    region: `~regions.Region`
        Region (pixel or sky regions accepted).
    func : numpy.ufunc
        Function to reduce the data.

    Returns
    -------
    spectrum : `~gammapy.spectrum.CountsSpectrum`
        Spectrum in the given region.
    """
    from gammapy.spectrum import CountsSpectrum

    energy_axis = self.geom.get_axis_by_name("energy")

    if not region:
        data = func(self.data, axis=(1, 2))
    else:
        mask = self.geom.region_mask([region])
        # One row per energy bin, holding the selected pixels of that bin.
        selected = self.data[mask].reshape(energy_axis.nbin, -1)
        data = func(selected, axis=1)

    edges = energy_axis.edges
    return CountsSpectrum(
        data=data, energy_lo=edges[:-1], energy_hi=edges[1:], unit=self.unit
    )
586
|
|
|
|
587
|
|
|
def convolve(self, kernel, use_fft=True, **kwargs):
    """
    Convolve map with a kernel.

    If the kernel is two dimensional, it is applied to all image planes likewise.
    If the kernel is higher dimensional it must match the map in the number of
    dimensions and the corresponding kernel is selected for every image plane.

    Parameters
    ----------
    kernel : `~gammapy.cube.PSFKernel` or `numpy.ndarray`
        Convolution kernel.
    use_fft : bool
        Use `scipy.signal.fftconvolve` or `scipy.ndimage.convolve`.
    kwargs : dict
        Keyword arguments passed to `scipy.signal.fftconvolve` or
        `scipy.ndimage.convolve`. With ``use_fft=True`` the ``mode``
        defaults to "same".

    Returns
    -------
    map : `WcsNDMap`
        Convolved map (float32 data).

    Raises
    ------
    ValueError
        If a `~gammapy.cube.PSFKernel` is given whose pixel size does not
        match the pixel size of the map.
    """
    from gammapy.cube import PSFKernel

    conv_function = scipy.signal.fftconvolve if use_fft else scipy.ndimage.convolve
    convolved_data = np.empty(self.data.shape, dtype=np.float32)
    if use_fft:
        kwargs.setdefault("mode", "same")

    if isinstance(kernel, PSFKernel):
        kmap = kernel.psf_kernel_map
        if not np.allclose(
            self.geom.pixel_scales.deg, kmap.geom.pixel_scales.deg, rtol=1e-5
        ):
            raise ValueError("Pixel size of kernel and map not compatible.")
        kernel = kmap.data

    for img, idx in self.iter_by_image():
        # A 2D kernel is shared by all image planes, so only the *kernel*
        # lookup collapses to Ellipsis. The result must still be written to
        # the plane it belongs to; indexing the output with Ellipsis would
        # overwrite every plane with the current plane's result.
        kernel_idx = Ellipsis if kernel.ndim == 2 else idx
        convolved_data[idx] = conv_function(img, kernel[kernel_idx], **kwargs)

    return self._init_copy(data=convolved_data)
630
|
|
|
|
631
|
|
|
def apply_edisp(self, edisp):
    """Apply energy dispersion to map. Requires energy axis.

    The data are multiplied with ``edisp.pdf_matrix`` along the energy
    axis. The returned map has a single non-spatial axis, a copy of
    ``edisp.e_reco`` named "energy".

    Parameters
    ----------
    edisp : `gammapy.irf.EnergyDispersion`
        Energy dispersion matrix

    Returns
    -------
    map : `WcsNDMap`
        Map with energy dispersion applied.
    """
    loc = self.geom.get_axis_index_by_name("energy")

    # Bring the energy axis last so ``np.dot`` contracts over it, then
    # put the resulting axis back where the energy axis was.
    energy_last = np.moveaxis(self.data, loc, -1)
    dispersed = np.dot(energy_last, edisp.pdf_matrix)
    data = np.moveaxis(dispersed, -1, loc)

    e_reco_axis = edisp.e_reco.copy(name="energy")
    geom_reco = self.geom.to_image().to_cube(axes=[e_reco_axis])
    return self._init_copy(geom=geom_reco, data=data)
652
|
|
|
|
653
|
|
|
def cutout(self, position, width, mode="trim"):
    """
    Create a cutout around a given position.

    The cutout geometry is computed by ``geom.cutout``. The data are
    taken from this map with the "parent-slices" stored in the cutout
    geometry, applied to the last two data axes.

    Parameters
    ----------
    position : `~astropy.coordinates.SkyCoord`
        Center position of the cutout region.
    width : tuple of `~astropy.coordinates.Angle`
        Angular sizes of the region in (lon, lat) in that specific order.
        If only one value is passed, a square region is extracted.
    mode : {'trim', 'partial', 'strict'}
        Mode option for Cutout2D, for details see `~astropy.nddata.utils.Cutout2D`.

    Returns
    -------
    cutout : `~gammapy.maps.WcsNDMap`
        Cutout map
    """
    geom_cutout = self.geom.cutout(position=position, width=width, mode=mode)

    parent_slices = geom_cutout.cutout_info["parent-slices"]
    first, second = parent_slices[0], parent_slices[1]
    data = self.data[..., first, second]

    return self._init_copy(geom=geom_cutout, data=data)
680
|
|
|
|
681
|
|
|
def stack(self, other, weights=None):
    """Stack cutout into map.

    The data of ``other`` are added in place to this map. ``other`` must
    either share this map's geometry or be a cutout whose
    ``cutout_info["parent-geom"]`` equals this map's geometry.

    Parameters
    ----------
    other : `WcsNDMap`
        Other map to stack
    weights : `~numpy.ndarray`
        Array to be used as weights. The data of ``other`` are
        multiplied by it before being added.

    Raises
    ------
    ValueError
        If ``other`` is neither on the same geometry nor a cutout of it.
    """
    other_geom = other.geom

    if self.geom == other_geom:
        # Indexing with None adds a leading length-1 axis on both sides,
        # so the whole arrays are combined.
        parent_slices = cutout_slices = None
    else:
        info = other_geom.cutout_info
        if info is None or not self.geom == info["parent-geom"]:
            raise ValueError(
                "Can only stack equivalent maps or cutout of the same map."
            )

        in_parent = info["parent-slices"]
        parent_slices = (Ellipsis, in_parent[0], in_parent[1])

        in_cutout = info["cutout-slices"]
        cutout_slices = (Ellipsis, in_cutout[0], in_cutout[1])

    data = other.data[cutout_slices]
    if weights is not None:
        data = data * weights

    self.data[parent_slices] += data
708
|
|
|
|
709
|
|
|
def sample_coord(self, n_events, random_state=0):
    """Sample position and energy of events.

    The map data are used as the probability density of an
    ``InverseCDFSampler``; the sampled pixel coordinates are converted
    to map coordinates, and the "energy" entry is multiplied by the unit
    of the energy axis.

    Parameters
    ----------
    n_events : int
        Number of events to sample.
    random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
        Defines random number generator initialisation.
        Passed to `~gammapy.utils.random.get_random_state`.

    Returns
    -------
    coords : `~gammapy.maps.MapCoord` object.
        Sequence of coordinates and energies of the sampled events.
    """
    rng = get_random_state(random_state)
    sampler = InverseCDFSampler(pdf=self.data, random_state=rng)

    # The sampler works in data-axis order; the geometry expects the reverse.
    sampled_pix = sampler.sample(n_events)
    coords = self.geom.pix_to_coord(sampled_pix[::-1])

    # TODO: pix_to_coord should return a MapCoord object
    names = ["lon", "lat"]
    names += [ax.name for ax in self.geom.axes]
    cdict = OrderedDict(zip(names, coords))

    energy_unit = self.geom.get_axis_by_name("energy").unit
    cdict["energy"] *= energy_unit

    return MapCoord.create(cdict, coordsys=self.geom.coordsys)
738
|
|
|
|